Skip to content

Commit 2d1acde

Browse files
probicheauxclaudepre-commit-ci[bot]
authored
nas: 4 new CLI commands + --group on model list (G4) (#474)
* nas: 4 new CLI commands + --group on model list (G4) Wraps the new public API routes from the agentic-surface-area onsite plan: roboflow train cancel <project>/<version> [--continue-if-no-refund] roboflow train stop <project>/<version> roboflow train results <project>/<version> roboflow model star <model-id> [--unstar] Plus extends `roboflow model list -p <project>` with `-g/--group <modelGroup>`, the canonical "list NAS models per run" path. When --group is set, the list command hits the public /models endpoint (full enriched projection: hardware, latency, map5095, paretoOptimalFor, recommended ★) instead of walking versions via the SDK. Adapter additions in roboflow/adapters/rfapi.py: cancel_version_training, stop_version_training, get_training_results, list_project_models (with optional group), get_model_by_url, favorite_nas_model Backend companions: - roboflow#11603 (G1, validator) - roboflow#11605 (G6, projection + ?group=) - roboflow#11610 (G2, public train cancel/stop + favorite) - roboflow#11612 (G3, training results) Tests: +13 cases across test_train_handler.py and test_model_handler.py covering register, success paths, 409 + MODEL_NOT_NAS hint surfacing, unstar flow, and the --group endpoint switch. All 298 CLI tests pass locally; ruff check + ruff format clean. CLI-COMMANDS.md updated with two new sections (train lifecycle + NAS list/star/deploy). E2E: driven against staging (api.roboflow.one) on peter-robicheaux/beer-can-hackathon: - `train results .../410` returned full NAS bundle (52 models, recommendedByHardware, modelGroup) - `model list -p ... -g <modelGroup>` rendered 53-row leaderboard table with HARDWARE / LATENCY / MAP50 / MAP5095 / REC columns - `model star 14CwSGmGetWh6rB0EnjL` → success, favorites reflected - `model star --unstar` flips state - `train cancel .../318` (finished version) → 409 surfaces hint "Cancel only applies to in-flight runs." Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * nas: 'model star' takes URL slug instead of Firestore doc id The public favorite endpoint now accepts the model URL slug (roboflow#11646), so the CLI can drop the Firestore-doc-id wart. Changes: - star_model argument is now `model_url`, accepting either the bare slug (when -w is set / a default workspace exists) or the workspace-prefixed form `<ws>/<slug>` — same shape as `model get`. - rfapi.favorite_nas_model parameter renamed `model_id` → `model_url` with urllib.parse.quote() for safety, since the slug is now what appears in the path. - Hints updated to point at models[].modelUrl instead of modelId, and the workspace fallback hint mentions the prefix form. Tests: +2 cases for the new parsing (workspace-prefixed URL vs bare slug + -w fallback). 22/22 model handler tests pass; 36/36 across model + train. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * test: train_results fixture uses URL slugs, not doc ids Mirrors the backend cleanup in roboflow#11646. The training-results fixture now uses the public shape (trainingId is workspace/project/ version, models[].modelUrl, recommendedByHardware values are URL slugs). No behavior change in the CLI handler — it passes the response through. * nas: 'model star' arg + favorite_nas_model param: model_url → model_id Mirrors the wire rename in roboflow#11646. The public API field for the opaque model identifier is now `modelId` (the value is still the URL slug; that's an implementation detail callers shouldn't have to reason about). Changes: - `roboflow model star` argument: `model_url` → `model_id`. Help text and error hints updated to point at `models[].modelId`. - `rfapi.favorite_nas_model(model_url=...)` → `favorite_nas_model( model_id=...)`. Internal local var becomes `public_model_id` to keep the call-site readable. - Test fixtures: `model_url` arg → `model_id`, `models[].modelUrl` → `models[].modelId`. 36/36 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(pre_commit): 🎨 auto format pre-commit hooks --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e37327c commit 2d1acde

6 files changed

Lines changed: 733 additions & 2 deletions

File tree

CLI-COMMANDS.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,50 @@ roboflow download my-workspace/my-project/3 -f coco # alias
4747
roboflow infer photo.jpg -m my-project/3
4848
```
4949

50+
### Train, monitor, cancel, stop
51+
52+
```bash
53+
# Start training (any architecture). For NAS sweeps, use a NAS parent modelType:
54+
roboflow train start -p my-project -v 3 --type rfdetr-base
55+
roboflow train start -p my-project -v 3 --type rfdetr-nas-parent # NAS sweep
56+
roboflow train start -p my-project -v 3 --type rfdetr-nas-base-parent # NAS Base sweep
57+
roboflow train start -p my-project -v 3 --type rfdetr-nas-seg-parent # NAS instance-segmentation
58+
59+
# Cancel an in-flight training (any architecture; NAS-aware):
60+
roboflow train cancel my-project/3
61+
# Pass --continue-if-no-refund to cancel even past the refund window:
62+
roboflow train cancel my-project/3 --continue-if-no-refund
63+
64+
# Graceful early-stop:
65+
roboflow train stop my-project/3
66+
67+
# Run-level training results bundle (NAS leaderboard for NAS runs,
68+
# minimal bundle for non-NAS):
69+
roboflow train results my-project/3
70+
```
71+
72+
NAS sweeps require the version's validation split to have at least 15 images;
73+
the server returns `code: "insufficient_validation_images_for_nas"` otherwise.
74+
75+
### NAS models — list, star, deploy
76+
77+
```bash
78+
# Get a NAS run's modelGroup from training results:
79+
roboflow --json train results my-project/3 | jq -r .modelGroup
80+
# → rfdetrNasGroup-3
81+
82+
# List every model from one NAS run, with hardware/latency/mAP columns:
83+
roboflow model list -p my-project --group rfdetrNasGroup-3
84+
85+
# Star a NAS-trained model (triggers TRT compile for its recommended hardware):
86+
# --json train results … gives you the modelId per row.
87+
roboflow model star <modelId>
88+
roboflow model star <modelId> --unstar
89+
```
90+
91+
`model star` is NAS-only by server-side design; non-NAS modelTypes return
92+
`code: "MODEL_NOT_NAS"`.
93+
5094
### Search and export
5195

5296
```bash

roboflow/adapters/rfapi.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,99 @@ def start_version_training(
9696
return True
9797

9898

99+
def cancel_version_training(
100+
api_key: str,
101+
workspace_url: str,
102+
project_url: str,
103+
version: str,
104+
*,
105+
continue_if_no_refund: bool = False,
106+
):
107+
"""Cancel an in-flight training run.
108+
109+
Backend handler is canonical for both vanilla and NAS trainings — it
110+
accepts ``mining`` status, so this works for NAS sweeps too.
111+
"""
112+
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/cancel?api_key={api_key}"
113+
body: Dict[str, Union[str, int, bool]] = {}
114+
if continue_if_no_refund:
115+
body["continueIfNoRefund"] = True
116+
response = requests.post(url, json=body)
117+
if not response.ok:
118+
raise RoboflowError(response.text)
119+
return response.json() if response.content else {"success": True}
120+
121+
122+
def stop_version_training(api_key: str, workspace_url: str, project_url: str, version: str):
123+
"""Request an early stop on an in-flight training run.
124+
125+
The backend flips ``train.requestedStop``; the run finishes the current
126+
phase gracefully (mining or training).
127+
"""
128+
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/stop?api_key={api_key}"
129+
response = requests.post(url, json={})
130+
if not response.ok:
131+
raise RoboflowError(response.text)
132+
return response.json() if response.content else {"success": True}
133+
134+
135+
def get_training_results(api_key: str, workspace_url: str, project_url: str, version: str):
136+
"""Run-level training results bundle.
137+
138+
For NAS runs returns ``{ trainingId, status, modelGroup, modelCount,
139+
recommendedByHardware, mining?, models: [...] }``. For non-NAS runs
140+
returns a minimal bundle with the produced model(s).
141+
"""
142+
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/training/results?api_key={api_key}"
143+
response = requests.get(url)
144+
if not response.ok:
145+
raise RoboflowError(response.text)
146+
return response.json()
147+
148+
149+
def list_project_models(
150+
api_key: str,
151+
workspace_url: str,
152+
project_url: str,
153+
*,
154+
group: Optional[str] = None,
155+
):
156+
"""List models for a project; pass ``group`` to scope to one NAS run."""
157+
url = f"{API_URL}/{workspace_url}/{project_url}/models?api_key={api_key}"
158+
if group:
159+
url += f"&group={urllib.parse.quote(group, safe='')}"
160+
response = requests.get(url)
161+
if not response.ok:
162+
raise RoboflowError(response.text)
163+
return response.json()
164+
165+
166+
def get_model_by_url(api_key: str, workspace_url: str, model_url: str):
167+
"""Fetch a single model by its URL slug."""
168+
encoded = urllib.parse.quote(model_url, safe="/")
169+
url = f"{API_URL}/models/{workspace_url}/{encoded}?api_key={api_key}"
170+
response = requests.get(url)
171+
if not response.ok:
172+
raise RoboflowError(response.text)
173+
return response.json()
174+
175+
176+
def favorite_nas_model(api_key: str, workspace_url: str, model_id: str, *, starred: bool = True):
177+
"""Star or unstar a NAS-trained model.
178+
179+
``model_id`` is the opaque public model id (e.g. ``my-project-3-nas-gpu-b``),
180+
the same value the public API returns as ``models[].modelId`` on
181+
``GET /:workspace/:project/:version/training/results``. NAS-only on the
182+
server side.
183+
"""
184+
encoded = urllib.parse.quote(model_id, safe="")
185+
url = f"{API_URL}/{workspace_url}/models/{encoded}/favorite?api_key={api_key}"
186+
response = requests.post(url, json={"starred": bool(starred)})
187+
if not response.ok:
188+
raise RoboflowError(response.text)
189+
return response.json()
190+
191+
99192
def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False):
100193
"""
101194
Fetch detailed information about a specific dataset version.

roboflow/cli/handlers/model.py

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,23 @@
1515
def list_models(
1616
ctx: typer.Context,
1717
project: Annotated[str, typer.Option("-p", "--project", help="Project ID or shorthand (e.g. my-ws/my-project)")],
18+
group: Annotated[
19+
Optional[str],
20+
typer.Option(
21+
"-g",
22+
"--group",
23+
help=(
24+
"NAS modelGroup to scope the list to a single NAS run. "
25+
"Get the value from 'roboflow train results <project>/<version>'."
26+
),
27+
),
28+
] = None,
1829
) -> None:
19-
"""List trained models for a project."""
20-
args = ctx_to_args(ctx, project=project)
30+
"""List trained models for a project.
31+
32+
Pass --group <modelGroup> to filter to a single NAS run.
33+
"""
34+
args = ctx_to_args(ctx, project=project, group=group)
2135
_list_models(args)
2236

2337

@@ -31,6 +45,30 @@ def get_model(
3145
_get_model(args)
3246

3347

48+
@model_app.command("star")
49+
def star_model(
50+
ctx: typer.Context,
51+
model_id: Annotated[
52+
str,
53+
typer.Argument(
54+
help=(
55+
"Model id (e.g. workspace/model-id, or just the bare id if -w is set). "
56+
"Get it from 'roboflow train results <project>/<version>' (models[].modelId)."
57+
),
58+
),
59+
],
60+
unstar: Annotated[bool, typer.Option("--unstar", help="Unstar instead of starring")] = False,
61+
) -> None:
62+
"""Star or unstar a NAS-trained model.
63+
64+
NAS-only by design — the server rejects non-NAS modelTypes with a
65+
MODEL_NOT_NAS error. Starring triggers TRT compilation for the model's
66+
recommended hardware so the model becomes deployable as an edge target.
67+
"""
68+
args = ctx_to_args(ctx, model_id=model_id, starred=not unstar)
69+
_star_model(args)
70+
71+
3472
@model_app.command("infer")
3573
def model_infer(
3674
ctx: typer.Context,
@@ -86,16 +124,64 @@ def upload_model(
86124

87125
def _list_models(args): # noqa: ANN001
88126
import roboflow
127+
from roboflow.adapters import rfapi
89128
from roboflow.cli._output import output, output_error, suppress_sdk_output
90129
from roboflow.cli._resolver import resolve_resource
91130
from roboflow.cli._table import format_table
131+
from roboflow.config import load_roboflow_api_key
92132

93133
try:
94134
workspace_url, project_slug, _version = resolve_resource(args.project, workspace_override=args.workspace)
95135
except ValueError as exc:
96136
output_error(args, str(exc))
97137
return
98138

139+
group = getattr(args, "group", None)
140+
141+
if group:
142+
# NAS path — hit the public /models endpoint with ?group= filter.
143+
# Surfaces full per-row NAS metadata (nasFamily, group,
144+
# train.results.{hardware,latency,map5095,paretoOptimalFor},
145+
# favorites, recommended).
146+
api_key = args.api_key or load_roboflow_api_key(workspace_url)
147+
if not api_key:
148+
output_error(
149+
args,
150+
"No API key found.",
151+
hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.",
152+
exit_code=2,
153+
)
154+
return
155+
try:
156+
rows = rfapi.list_project_models(api_key, workspace_url, project_slug, group=group)
157+
except rfapi.RoboflowError as exc:
158+
output_error(args, str(exc), exit_code=3)
159+
return
160+
if not isinstance(rows, list):
161+
rows = []
162+
# Project a leaderboard view for the text table; full row stays in JSON.
163+
table_rows = []
164+
for r in rows:
165+
metrics = r.get("metrics") or {}
166+
table_rows.append(
167+
{
168+
"url": r.get("url", ""),
169+
"type": r.get("modelType", ""),
170+
"hardware": metrics.get("hardware", ""),
171+
"latency": metrics.get("latency", ""),
172+
"map50": metrics.get("map50", ""),
173+
"map5095": metrics.get("map5095", ""),
174+
"recommended": "★" if r.get("recommended") else "",
175+
}
176+
)
177+
table = format_table(
178+
table_rows,
179+
columns=["url", "type", "hardware", "latency", "map50", "map5095", "recommended"],
180+
headers=["URL", "TYPE", "HARDWARE", "LATENCY", "MAP50", "MAP5095", "REC"],
181+
)
182+
output(args, rows, text=table)
183+
return
184+
99185
api_key = args.api_key or None
100186

101187
try:
@@ -130,6 +216,68 @@ def _list_models(args): # noqa: ANN001
130216
output(args, models, text=table)
131217

132218

219+
def _star_model(args): # noqa: ANN001
220+
from roboflow.adapters import rfapi
221+
from roboflow.cli._output import output, output_error
222+
from roboflow.config import load_roboflow_api_key
223+
224+
# Accept either "workspace/model-id" or just "model-id" (when -w is
225+
# set). Mirrors the parsing pattern used by `roboflow model get`.
226+
raw = args.model_id.strip("/")
227+
if "/" in raw:
228+
ws_from_arg, _sep, public_model_id = raw.partition("/")
229+
else:
230+
ws_from_arg, public_model_id = None, raw
231+
232+
workspace_url = args.workspace or ws_from_arg
233+
if not workspace_url:
234+
from roboflow.cli._resolver import resolve_default_workspace
235+
236+
workspace_url = resolve_default_workspace(args.api_key)
237+
if not workspace_url:
238+
output_error(
239+
args,
240+
"Could not determine workspace.",
241+
hint=(
242+
"Pass -w/--workspace, prefix the model id (workspace/id), or run 'roboflow auth set-workspace <ws>'."
243+
),
244+
exit_code=2,
245+
)
246+
return
247+
248+
api_key = args.api_key or load_roboflow_api_key(workspace_url)
249+
if not api_key:
250+
output_error(
251+
args,
252+
"No API key found.",
253+
hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.",
254+
exit_code=2,
255+
)
256+
return
257+
258+
try:
259+
result = rfapi.favorite_nas_model(api_key, workspace_url, public_model_id, starred=args.starred)
260+
except rfapi.RoboflowError as exc:
261+
msg = str(exc)
262+
hint = None
263+
if "MODEL_NOT_NAS" in msg or "non-NAS" in msg:
264+
hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model ids (models[].modelId)."
265+
elif "MODEL_NOT_IN_WORKSPACE" in msg:
266+
hint = (
267+
"Verify the model id and workspace. The id is the same value "
268+
"'roboflow train results' returns as models[].modelId."
269+
)
270+
output_error(args, msg, hint=hint, exit_code=3)
271+
return
272+
273+
verb = "starred" if args.starred else "unstarred"
274+
output(
275+
args,
276+
result,
277+
text=f"Model {workspace_url}/{public_model_id} {verb}.",
278+
)
279+
280+
133281
def _get_model(args): # noqa: ANN001
134282
import json
135283

0 commit comments

Comments
 (0)