diff --git a/docs/features/remote-sessions.md b/docs/features/remote-sessions.md index 391bb762d..216e80681 100644 --- a/docs/features/remote-sessions.md +++ b/docs/features/remote-sessions.md @@ -100,7 +100,7 @@ session.On((SessionEvent e) => use github_copilot_sdk::{Client, ClientOptions, PermissionRequestResult, SessionConfig}; let client = Client::start( - ClientOptions::new().with_remote(true) + ClientOptions::new().with_enable_remote_sessions(true) ).await?; let session = client.create_session( SessionConfig::new("/path/to/github-repo") diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 11f8c90c9..438ac5c7e 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -594,7 +594,7 @@ public async Task CreateSessionAsync(SessionConfig config, Cance config.ExcludedTools, config.Provider, config.EnableSessionTelemetry, - (bool?)true, + config.OnPermissionRequest != null ? true : null, config.OnUserInputRequest != null ? true : null, config.OnExitPlanModeRequest != null ? true : null, config.OnAutoModeSwitchRequest != null ? true : null, @@ -752,7 +752,7 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes config.ExcludedTools, config.Provider, config.EnableSessionTelemetry, - (bool?)true, + config.OnPermissionRequest != null ? true : null, config.OnUserInputRequest != null ? true : null, config.OnExitPlanModeRequest != null ? true : null, config.OnAutoModeSwitchRequest != null ? true : null, diff --git a/go/client.go b/go/client.go index dab09a4dd..eb7276de2 100644 --- a/go/client.go +++ b/go/client.go @@ -665,7 +665,9 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses config.Hooks.OnErrorOccurred != nil) { req.Hooks = Bool(true) } - req.RequestPermission = Bool(true) + if config.OnPermissionRequest != nil { + req.RequestPermission = Bool(true) + } traceparent, tracestate := getTraceContext(ctx) req.Traceparent = traceparent @@ -839,7 +841,9 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, req.InfiniteSessions = config.InfiniteSessions req.GitHubToken = config.GitHubToken req.RemoteSession = config.RemoteSession - req.RequestPermission = Bool(true) + if config.OnPermissionRequest != nil { + req.RequestPermission = Bool(true) + } if len(config.Commands) > 0 { cmds := make([]wireCommand, 0, len(config.Commands)) diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 991f23fa1..ddc23dd2f 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -830,7 +830,7 @@ export class CopilotClient { provider: config.provider, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, - requestPermission: true, + requestPermission: !!config.onPermissionRequest, requestUserInput: !!config.onUserInputRequest, requestElicitation: !!config.onElicitationRequest, requestExitPlanMode: !!config.onExitPlanModeRequest, diff --git a/python/copilot/client.py b/python/copilot/client.py index 6adb52061..9c7698716 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -1493,8 +1493,8 @@ async def create_session( if excluded_tools is not None: payload["excludedTools"] = excluded_tools - # Always enable permission request callback - payload["requestPermission"] = True + # Enable permission request callback if handler provided + payload["requestPermission"] = bool(on_permission_request) # Enable user input request callback if handler provided if on_user_input_request: @@ -1888,8 +1888,8 @@ async def resume_session( else True ) - # Always enable permission request callback - payload["requestPermission"] = True + # Enable permission request callback if handler provided + payload["requestPermission"] = bool(on_permission_request) if on_user_input_request: payload["requestUserInput"] = True diff --git a/rust/README.md b/rust/README.md index 78103e4df..2b3a5423c 100644 --- a/rust/README.md +++ b/rust/README.md @@ -18,7 +18,7 @@ use github_copilot_sdk::handler::ApproveAllHandler; # async fn example() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; let session = client.create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), ).await?; let _message_id = session.send("Hello!").await?; session.disconnect().await?; @@ -50,10 +50,10 @@ The SDK manages the CLI process lifecycle: spawning, health-checking, and gracef let client = Client::start(options).await?; // Create a new session -let session = client.create_session(config.with_handler(handler)).await?; +let session = client.create_session(config.with_permission_handler(handler)).await?; // Resume an existing session -let session = client.resume_session(config.with_handler(handler)).await?; +let session = client.resume_session(config.with_permission_handler(handler)).await?; // Low-level RPC let result = client.call("method.name", Some(params)).await?; @@ -82,7 +82,7 @@ With the default `CliProgram::Resolve`, `Client::start()` automatically resolves ### Session -Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches events to the `SessionHandler`. +Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`. ```rust,ignore use github_copilot_sdk::MessageOptions; @@ -101,7 +101,7 @@ let _id = session .await?; // Message history -let messages = session.get_messages().await?; +let messages = session.get_events().await?; // Abort the current agent turn session.abort().await?; @@ -176,22 +176,31 @@ New RPCs land in the namespace immediately as the schema regenerates; helpers are added on top only when an ergonomic story is worth the maintenance. -### SessionHandler +### Handler Traits -Implement this trait to control how a session responds to CLI events. Two styles are supported: +The SDK exposes five focused handler traits, one per CLI callback type. Implement only the traits you need and install each with the matching `SessionConfig` setter. Each trait has a single `async fn handle(...)` method: -**1. Per-event methods (recommended).** Override only the callbacks you care about; every method has a safe default (permission → deny, user input → none, external tool → "no handler", elicitation → cancel, exit plan → default). When no handler is installed on a session, the SDK uses `NoopHandler`, which leaves permission and external tool requests pending for manual resolution. This is the `serenity::EventHandler` pattern. +| Trait | Setter | Purpose | +| ----------------------- | --------------------------------- | --------------------------------------------- | +| `PermissionHandler` | `with_permission_handler(...)` | Approve/deny tool-use permission requests | +| `ElicitationHandler` | `with_elicitation_handler(...)` | Respond to structured elicitation prompts | +| `UserInputHandler` | `with_user_input_handler(...)` | Answer free-form / choice user-input prompts | +| `ExitPlanModeHandler` | `with_exit_plan_mode_handler(...)`| Respond when the agent exits plan mode | +| `AutoModeSwitchHandler` | `with_auto_mode_switch_handler(...)`| Respond to automatic mode-switch proposals | + +The CLI's `requestPermission` / `requestElicitation` / `requestUserInput` / etc. wire flags are derived automatically from which traits you've installed — clients that don't install a handler are silently skipped, letting another connected client handle the request. ```rust,ignore +use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; -struct MyHandler; +struct MyPermissions; #[async_trait] -impl SessionHandler for MyHandler { - async fn on_permission_request( +impl PermissionHandler for MyPermissions { + async fn handle( &self, _sid: SessionId, _rid: RequestId, @@ -203,47 +212,21 @@ impl SessionHandler for MyHandler { PermissionResult::Denied } } - - async fn on_session_event(&self, sid: SessionId, event: github_copilot_sdk::types::SessionEvent) { - println!("[{sid}] {}", event.event_type); - } } + +let config = SessionConfig::default().with_permission_handler(Arc::new(MyPermissions)); ``` -**2. Single `on_event` method.** Override `on_event` directly and `match` on `HandlerEvent` — useful for logging middleware, custom routing, or when you want one exhaustive dispatch point. +A single type can implement multiple handler traits — share one `Arc` across the setters by cloning: ```rust,ignore -use github_copilot_sdk::handler::*; -use async_trait::async_trait; - -#[async_trait] -impl SessionHandler for MyRouter { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { session_id, event } => { - println!("[{session_id}] {}", event.event_type); - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - HandlerEvent::UserInput { question, .. } => { - HandlerResponse::UserInput(Some(UserInputResponse { - answer: prompt_user(&question), - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } - } -} +let h = Arc::new(MyHandler); +let config = SessionConfig::default() + .with_permission_handler(h.clone()) + .with_user_input_handler(h); ``` -The default `on_event` dispatches to the per-event methods, so overriding `on_event` short-circuits them entirely — pick one style per handler. - -Events are processed serially per session — blocking in a handler method pauses that session's event loop (which is correct, since the CLI is also waiting for the response). Other sessions are unaffected. - -> **Note:** Notification-triggered events (`PermissionRequest` via `permission.requested`, `ExternalTool` via `external_tool.requested`) are dispatched on spawned tasks and may run concurrently with the serial event loop. See the trait-level docs on `SessionHandler` for details. +The built-in `ApproveAllHandler` and `DenyAllHandler` implement `PermissionHandler` for the common cases. To observe streamed session events (assistant messages, tool calls, etc.), call `session.subscribe()` — see [Streaming](#streaming) below. ### SessionConfig @@ -254,10 +237,11 @@ let config = SessionConfig { content: Some("Always explain your reasoning.".into()), ..Default::default() }), - request_elicitation: Some(true), // enable elicitation provider ..Default::default() -}; -let session = client.create_session(config.with_handler(handler)).await?; +} +.with_elicitation_handler(Arc::new(my_elicitation_handler)) +.with_permission_handler(handler); +let session = client.create_session(config).await?; ``` ### Session Hooks @@ -300,7 +284,7 @@ impl SessionHooks for MyHooks { let session = client .create_session( config - .with_handler(handler) + .with_permission_handler(handler) .with_hooks(Arc::new(MyHooks)), ) .await?; @@ -337,22 +321,23 @@ impl SystemMessageTransform for MyTransform { let session = client .create_session( config - .with_handler(handler) - .with_transform(Arc::new(MyTransform)), + .with_permission_handler(handler) + .with_system_message_transform(Arc::new(MyTransform)), ) .await?; ``` ### Tool Registration -Define client-side tools as named types with `ToolHandler`, then route them with `ToolHandlerRouter`. Enable the `derive` feature for `schema_for::()` — it generates JSON Schema from Rust types via `schemars`. +Define client-side tools as named types implementing `ToolHandler` and attach +them to `Tool` declarations via `Tool::with_handler`, then install via +`SessionConfig::with_tools`. Enable the `derive` feature for `schema_for::()` +— it generates JSON Schema from Rust types via `schemars`. ```rust,ignore use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ - schema_for, tool_parameters, JsonSchema, ToolHandler, ToolHandlerRouter, -}; +use github_copilot_sdk::tool::{schema_for, JsonSchema, ToolHandler}; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde::Deserialize; use async_trait::async_trait; @@ -369,58 +354,48 @@ struct GetWeatherTool; #[async_trait] impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - Tool { - name: "get_weather".to_string(), - namespaced_name: None, - description: "Get weather for a city".to_string(), - parameters: tool_parameters(schema_for::()), - instructions: None, - } - } - async fn call(&self, inv: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) } } -// Build a router that dispatches tool calls by name -let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool)], - Arc::new(ApproveAllHandler), -); +let tool = Tool::new("get_weather") + .with_description("Get weather for a city") + .with_parameters(schema_for::()) + .with_handler(Arc::new(GetWeatherTool)); -let config = SessionConfig { - tools: Some(router.tools()), - ..Default::default() -} -.with_handler(Arc::new(router)); +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![tool]); let session = client.create_session(config).await?; ``` -Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". The router implements `SessionHandler`, forwarding unrecognized tools and non-tool events to the inner handler. +Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". The SDK registers each tool's handler under its `Tool::name` and surfaces the same `Tool` definitions to the CLI automatically. -For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression: +Tools without an attached handler (`Tool::with_handler` never called) are declaration-only: the SDK advertises them on the wire but doesn't dispatch invocations to anything. Useful when another connected client services the tool. + +For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression and returns a fully-formed `Tool` with handler attached: ```rust,ignore -use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandlerRouter}; +use github_copilot_sdk::tool::{define_tool, JsonSchema}; use github_copilot_sdk::ToolResult; use serde::Deserialize; #[derive(Deserialize, JsonSchema)] struct GetWeatherParams { city: String } -let router = ToolHandlerRouter::new( - vec![define_tool( - "get_weather", - "Get weather for a city", - |_inv, params: GetWeatherParams| async move { - Ok(ToolResult::Text(format!("Sunny in {}", params.city))) - }, - )], - Arc::new(ApproveAllHandler), +let tool = define_tool( + "get_weather", + "Get weather for a city", + |_inv, params: GetWeatherParams| async move { + Ok(ToolResult::Text(format!("Sunny in {}", params.city))) + }, ); + +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![tool]); ``` The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) alongside the deserialized parameters, so handlers that need `inv.session_id` or `inv.tool_call_id` for telemetry, streaming updates, or scoped lookups can use them directly. Use `_inv` when you don't need the metadata. @@ -429,13 +404,12 @@ Reach for the `ToolHandler` trait directly when you need shared state across mul ### Permission Policies -Set a permission policy directly on `SessionConfig` with the chainable builders. They wrap whatever handler you've installed (defaulting to `NoopHandler` if none) so only permission requests are intercepted; every other event flows through unchanged. +Set a permission policy directly on `SessionConfig` with the chainable builders. They install a synthesized `PermissionHandler` so only permission requests are intercepted; every other event flows through unchanged. ```rust,ignore let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(my_handler)) .approve_all_permissions(), // or .deny_all_permissions() // or .approve_permissions_if(|data| { @@ -445,54 +419,86 @@ let session = client .await?; ``` -> Call the policy method **after** `with_handler` — `with_handler` overwrites the handler field, so `approve_all_permissions().with_handler(...)` discards the wrap. +> The policy builders set the permission handler slot directly; they're equivalent to calling `with_permission_handler(...)` with the corresponding built-in (`ApproveAllHandler`, `DenyAllHandler`, or `permission::approve_if(...)`). -For composing a policy onto a handler outside the builder chain (e.g. when wrapping a `ToolHandlerRouter` you've built elsewhere), the `permission` module exposes the same primitives as free functions: +The `permission` module also exposes the policy primitives as standalone helpers for the rare case where you want to construct the handler value separately and install it via `with_permission_handler`: ```rust,ignore use github_copilot_sdk::permission; -let router = ToolHandlerRouter::new(tools, Arc::new(MyHandler)); -let handler = permission::approve_all(Arc::new(router)); -// or permission::deny_all(...) / permission::approve_if(..., predicate) +let handler = permission::approve_if(|data| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") +}); +// or permission::approve_all() / permission::deny_all() -let session = client.create_session(config.with_handler(handler)).await?; +let session = client + .create_session(config.with_permission_handler(handler)) + .await?; ``` -### Capabilities & Elicitation +### Elicitation -The SDK negotiates capabilities with the CLI after session creation. Enable elicitation to let the agent present structured UI dialogs (forms, URL prompts) to the user. +To opt your client into receiving `elicitation.requested` broadcasts, install an `ElicitationHandler` on the session config. The wire flag `requestElicitation` is derived from the presence of the handler; clients without one are silently skipped, allowing other connected clients on the same CLI to handle the request. ```rust,ignore -let config = SessionConfig { - request_elicitation: Some(true), - ..Default::default() -}; +use async_trait::async_trait; +use github_copilot_sdk::handler::{ElicitationHandler, ElicitationResult}; +use github_copilot_sdk::types::{ElicitationRequest, RequestId, SessionId}; + +struct MyElicitation; + +#[async_trait] +impl ElicitationHandler for MyElicitation { + async fn handle( + &self, + _sid: SessionId, + _rid: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult::cancel() + } +} + +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_elicitation_handler(Arc::new(MyElicitation)); ``` -The handler receives `HandlerEvent::ElicitationRequest` with a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. Return `HandlerResponse::Elicitation(result)`. +The handler receives a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. ### User Input Requests -Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. Implement `SessionHandler::on_user_input` and the SDK will forward `userInput.request` callbacks: +Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. Install a `UserInputHandler` and the SDK will forward `userInput.request` callbacks: ```rust,ignore -async fn on_user_input( - &self, - _session_id: SessionId, - question: String, - choices: Option>, - _allow_freeform: Option, -) -> Option { - // Render `question` + `choices` to your UI, then: - Some(UserInputResponse { - answer: "Yes".to_string(), - was_freeform: false, - }) +use async_trait::async_trait; +use github_copilot_sdk::handler::{UserInputHandler, UserInputResponse}; +use github_copilot_sdk::types::SessionId; + +struct MyUserInput; + +#[async_trait] +impl UserInputHandler for MyUserInput { + async fn handle( + &self, + _sid: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + // Render `question` + `choices` to your UI, then: + Some(UserInputResponse { + answer: "Yes".to_string(), + was_freeform: false, + }) + } } + +let config = SessionConfig::default() + .with_user_input_handler(Arc::new(MyUserInput)); ``` -Return `None` to signal "no answer available" (the CLI falls back to its own prompt). Enable via `SessionConfig::request_user_input` (defaults to `Some(true)`). +Return `None` to signal "no answer available" (the CLI falls back to its own prompt). ### Slash Commands @@ -664,7 +670,8 @@ ergonomics the dynamically-typed SDKs don't. [`SessionConfig::with_session_fs_provider`]. The factory pattern doesn't cleanly express in Rust at the session-config call site — there is no `Session` value to thread in, and the SDK already prefers traits over - boxed closures for handler-shaped APIs (`SessionHandler`, `SessionHooks`, + boxed closures for handler-shaped APIs (`PermissionHandler`, `ToolHandler`, + `SessionHooks`, `ToolHandler`). ```rust,ignore @@ -682,7 +689,7 @@ let client = Client::start(options).await?; let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_session_fs_provider(Arc::new(MyProvider::new())), ) .await?; @@ -705,9 +712,9 @@ none of them are scheduled for removal. identifier from an arbitrary `String` at compile time. Node/Python/Go use bare strings. - **Permission policy builders** — `permission::approve_all`, - `permission::deny_all`, and `permission::approve_if(handler, predicate)` - in `crate::permission` provide composable, no-handler-needed permission - shortcuts that wrap an existing `SessionHandler`. Other SDKs require a + `permission::deny_all`, and `permission::approve_if(predicate)` + in `crate::permission` provide composable, no-handler-needed + `PermissionHandler` shortcuts. Other SDKs require a full handler implementation for these patterns. - **`Client::from_streams`** — connect to a CLI server over arbitrary caller-supplied `AsyncRead` / `AsyncWrite`. Useful for testing, @@ -728,10 +735,10 @@ none of them are scheduled for removal. | `lib.rs` | `Client`, `ClientOptions`, `CliProgram`, `Transport`, `Error` | | `session.rs` | `Session` struct, event loop, `send`/`send_and_wait`, `Client::create_session`/`resume_session` | | `subscription.rs` | `EventSubscription` / `LifecycleSubscription` (`Stream`-able observer handles for `subscribe()` / `subscribe_lifecycle()`) | -| `handler.rs` | `SessionHandler` trait, `HandlerEvent`/`HandlerResponse` enums, `ApproveAllHandler`, `DenyAllHandler`, `NoopHandler` | +| `handler.rs` | `PermissionHandler`, `ElicitationHandler`, `UserInputHandler`, `ExitPlanModeHandler`, `AutoModeSwitchHandler` traits; `ApproveAllHandler`, `DenyAllHandler` | | `hooks.rs` | `SessionHooks` trait, `HookEvent`/`HookOutput` enums, typed hook inputs/outputs | | `transforms.rs` | `SystemMessageTransform` trait, section-level system message customization | -| `tool.rs` | `ToolHandler` trait, `ToolHandlerRouter`, `schema_for::()` (with `derive` feature) | +| `tool.rs` | `ToolHandler` trait, `define_tool`, `schema_for::()` (with `derive` feature) | | `types.rs` | CLI protocol types (`SessionId`, `SessionEvent`, `SessionConfig`, `Tool`, etc.) | | `resolve.rs` | Binary resolution (`copilot_binary`, `node_binary`, `extended_path`) | | `embeddedcli.rs` | Embedded CLI extraction (`embedded-cli` feature) | diff --git a/rust/examples/chat.rs b/rust/examples/chat.rs index 37293c6bc..6b361fdea 100644 --- a/rust/examples/chat.rs +++ b/rust/examples/chat.rs @@ -13,39 +13,29 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use github_copilot_sdk::handler::{ - HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, UserInputResponse, -}; -use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent}; +use github_copilot_sdk::handler::{ApproveAllHandler, UserInputHandler, UserInputResponse}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; -/// Handler that prints assistant message deltas as they stream in -/// and auto-approves permissions. -struct ChatHandler; +/// User input handler that prompts on stdin. +struct StdinUserInputHandler; #[async_trait] -impl SessionHandler for ChatHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { event, .. } => { - print_event(&event); - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - HandlerEvent::UserInput { question, .. } => { - // Prompt the user on behalf of the agent. - print!("\n[agent asks] {question}\n> "); - io::stdout().flush().ok(); - let answer = read_line().unwrap_or_default(); - HandlerResponse::UserInput(Some(UserInputResponse { - answer, - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } +impl UserInputHandler for StdinUserInputHandler { + async fn handle( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + print!("\n[agent asks] {question}\n> "); + io::stdout().flush().ok(); + let answer = read_line()?; + Some(UserInputResponse { + answer, + was_freeform: true, + }) } } @@ -91,9 +81,11 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; let config = { - let mut cfg = SessionConfig::default(); + let mut cfg = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_user_input_handler(Arc::new(StdinUserInputHandler)); cfg.streaming = Some(true); - cfg.with_handler(Arc::new(ChatHandler)) + cfg }; let session = client.create_session(config).await?; @@ -102,6 +94,14 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { session.id() ); + // Spawn a task to print streamed assistant deltas as session events arrive. + let mut events = session.subscribe(); + tokio::spawn(async move { + while let Ok(event) = events.recv().await { + print_event(&event); + } + }); + loop { print!("> "); io::stdout().flush().ok(); @@ -117,6 +117,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nGoodbye."); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/examples/hooks.rs b/rust/examples/hooks.rs index 86f6ceadc..79bf81551 100644 --- a/rust/examples/hooks.rs +++ b/rust/examples/hooks.rs @@ -103,7 +103,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { // hooks: true is set automatically when a hooks handler is provided. let config = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(Arc::new(AuditHooks)); let session = client.create_session(config).await?; @@ -128,6 +128,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\n{text}"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs index 612792073..8edb2cd38 100644 --- a/rust/examples/lifecycle_observer.rs +++ b/rust/examples/lifecycle_observer.rs @@ -39,7 +39,7 @@ use github_copilot_sdk::{Client, ClientOptions}; #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; - println!("[client] state: {:?}", client.state()); + println!("[client] started, pid: {:?}", client.pid()); // Wildcard lifecycle subscriber: see every session.lifecycle event, // counting deletions inline by filtering on event_type. @@ -63,9 +63,9 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } }); - let config = SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)); + let config = SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; - println!("[client] state after create: {:?}", client.state()); + println!("[client] session created: {}", session.id()); // Per-session observer: see every assistant message, tool call, etc. // Subscribers fire alongside the constructor handler; they're great for @@ -97,13 +97,13 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { ) .await?; - session.destroy().await?; + session.disconnect().await?; // Synchronous shutdown — useful in panicking-cleanup paths or tests // where you don't have an async runtime available to await `stop()`. // For graceful shutdown in normal flow, prefer `client.stop().await`. client.force_stop(); - println!("[client] state after force_stop: {:?}", client.state()); + println!("[client] force-stopped"); // Stopping the client closes the broadcast senders, so the consumer // tasks observe `RecvError::Closed` and exit cleanly. diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs index 0dbbb3414..924e6947f 100644 --- a/rust/examples/session_fs.rs +++ b/rust/examples/session_fs.rs @@ -125,7 +125,7 @@ async fn main() -> Result<(), Box> { let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_session_fs_provider(provider), ) .await?; diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs index 55bacbbe6..93492d20c 100644 --- a/rust/examples/tool_server.rs +++ b/rust/examples/tool_server.rs @@ -30,9 +30,7 @@ use async_trait::async_trait; #[cfg(feature = "derive")] use github_copilot_sdk::handler::ApproveAllHandler; #[cfg(feature = "derive")] -use github_copilot_sdk::tool::{ - JsonSchema, ToolHandler, ToolHandlerRouter, schema_for, tool_parameters, -}; +use github_copilot_sdk::tool::{JsonSchema, ToolHandler, schema_for}; #[cfg(feature = "derive")] use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult}; #[cfg(feature = "derive")] @@ -59,14 +57,6 @@ struct GetWeatherTool; #[cfg(feature = "derive")] #[async_trait] impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - let mut tool = Tool::default(); - tool.name = "get_weather".to_string(); - tool.description = "Get the current weather for a city.".to_string(); - tool.parameters = tool_parameters(schema_for::()); - tool - } - async fn call(&self, invocation: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(invocation.arguments)?; let unit = params.unit.as_deref().unwrap_or("celsius"); @@ -90,20 +80,6 @@ struct RollDiceTool; #[cfg(feature = "derive")] #[async_trait] impl ToolHandler for RollDiceTool { - fn tool(&self) -> Tool { - let mut tool = Tool::default(); - tool.name = "roll_dice".to_string(); - tool.description = "Roll one or more dice and return the total.".to_string(); - tool.parameters = tool_parameters(serde_json::json!({ - "type": "object", - "properties": { - "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, - "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } - } - })); - tool - } - async fn call(&self, invocation: ToolInvocation) -> Result { let sides = invocation .arguments @@ -145,20 +121,26 @@ impl ToolHandler for RollDiceTool { #[cfg(feature = "derive")] #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { - let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool), Box::new(RollDiceTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); - let handler = Arc::new(router); - let client = Client::start(ClientOptions::default()).await?; - let config = { - let mut cfg = SessionConfig::default(); - cfg.tools = Some(tools); - cfg.with_handler(handler) - }; + let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![ + Tool::new("get_weather") + .with_description("Get the current weather for a city.") + .with_parameters(schema_for::()) + .with_handler(Arc::new(GetWeatherTool)), + Tool::new("roll_dice") + .with_description("Roll one or more dice and return the total.") + .with_parameters(serde_json::json!({ + "type": "object", + "properties": { + "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, + "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } + } + })) + .with_handler(Arc::new(RollDiceTool)), + ]); let session = client.create_session(config).await?; println!( @@ -182,6 +164,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("{text}"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/src/handler.rs b/rust/src/handler.rs index 565b09d56..042565564 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -1,123 +1,29 @@ -//! Event handler traits for session lifecycle. +//! Optional session-callback traits. //! -//! The [`SessionHandler`](crate::handler::SessionHandler) trait is the primary extension point — implement -//! [`on_event`](crate::handler::SessionHandler::on_event) to control how sessions respond to -//! CLI events, permission requests, tool calls, and user input prompts. +//! Each callback the CLI may dispatch (permission requests, elicitation +//! prompts, user-input questions, exit-plan-mode prompts, +//! auto-mode-switch prompts) has its own focused trait with a single +//! `handle` method. +//! +//! Handlers are **optional**: install only the ones the application cares +//! about. The SDK derives the corresponding wire flag on +//! `session.create` / `session.resume` from the presence of each handler, +//! so the runtime does not emit broadcasts this client would never +//! respond to. +//! +//! Tool dispatch uses its own per-tool registry built from +//! [`Tool::with_handler`](crate::types::Tool::with_handler) on entries passed to +//! [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::types::{ ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, - SessionEvent, SessionId, ToolInvocation, ToolResult, + SessionId, }; -/// Events dispatched by the SDK session event loop to the handler. -/// -/// The handler returns a [`HandlerResponse`] indicating how the SDK should -/// respond to the CLI. For fire-and-forget events (`SessionEvent`), the -/// response is ignored. -#[non_exhaustive] -#[derive(Debug)] -pub enum HandlerEvent { - /// Informational session event from the timeline (e.g. assistant.message_delta, - /// session.idle, tool.execution_start). Fire-and-forget — return `HandlerResponse::Ok`. - SessionEvent { - /// The session that emitted this event. - session_id: SessionId, - /// The event payload. - event: SessionEvent, - }, - - /// The CLI requests permission for an action. Return `HandlerResponse::Permission(..)`. - PermissionRequest { - /// The requesting session. - session_id: SessionId, - /// Unique ID to correlate the response. - request_id: RequestId, - /// Permission request payload. - data: PermissionRequestData, - }, - - /// The CLI requests user input. Return `HandlerResponse::UserInput(..)`. - /// The handler may block (e.g. awaiting a UI dialog) — this is expected. - UserInput { - /// The requesting session. - session_id: SessionId, - /// The question text to present. - question: String, - /// Optional multiple-choice options. - choices: Option>, - /// Whether free-form text input is allowed. - allow_freeform: Option, - }, - - /// The CLI requests execution of a client-defined tool. - /// Return `HandlerResponse::ToolResult(..)`. - ExternalTool { - /// The tool call to execute. - invocation: ToolInvocation, - }, - - /// The CLI broadcasts an elicitation request for the provider to handle. - /// Return `HandlerResponse::Elicitation(..)`. - ElicitationRequest { - /// The requesting session. - session_id: SessionId, - /// Unique ID to correlate the response. - request_id: RequestId, - /// The elicitation request payload. - request: ElicitationRequest, - }, - - /// The CLI requests exiting plan mode. Return `HandlerResponse::ExitPlanMode(..)`. - ExitPlanMode { - /// The requesting session. - session_id: SessionId, - /// Plan mode exit payload. - data: ExitPlanModeData, - }, - - /// The CLI asks whether to switch to auto model when an eligible rate - /// limit is hit. Return [`HandlerResponse::AutoModeSwitch`]. - AutoModeSwitch { - /// The requesting session. - session_id: SessionId, - /// The specific rate-limit error code that triggered the request, - /// if known (e.g. `user_weekly_rate_limited`, `user_global_rate_limited`). - error_code: Option, - /// Seconds until the rate limit resets, when known. - retry_after_seconds: Option, - }, -} - -/// Response from the handler back to the SDK, used to construct the -/// JSON-RPC reply sent to the CLI. -#[non_exhaustive] -#[derive(Debug)] -pub enum HandlerResponse { - /// No response needed (used for fire-and-forget `SessionEvent`s). - Ok, - /// Do not send a response. The consumer will resolve the pending request out-of-band. - NoResult, - /// Permission decision. - Permission(PermissionResult), - /// User input response (or `None` to signal no input available). - UserInput(Option), - /// Result of a tool execution. - ToolResult(ToolResult), - /// Elicitation result (accept/decline/cancel with optional form data). - Elicitation(ElicitationResult), - /// Exit plan mode decision. - ExitPlanMode(ExitPlanModeResult), - /// Auto-mode-switch decision. - AutoModeSwitch(AutoModeSwitchResponse), -} - /// Result of a permission request. -/// -/// `#[non_exhaustive]` so future variants can be added without a major -/// version bump. Match arms must include a `_` fallback. #[derive(Debug, Clone)] #[non_exhaustive] pub enum PermissionResult { @@ -126,34 +32,20 @@ pub enum PermissionResult { /// Permission denied. Denied, /// Defer the response. The handler will resolve this request itself - /// later — typically after a UI prompt — by calling + /// later -- typically after a UI prompt -- by calling /// `session.permissions.handlePendingPermissionRequest` directly. The - /// SDK will not send a response for this request. - /// - /// **Notification path only** (`permission.requested`). On the direct - /// RPC path (`permission.request`), `Deferred` falls back to - /// [`Approved`](Self::Approved) because that path must return a value - /// to satisfy the JSON-RPC reply contract. + /// SDK skips its own response on this path. Deferred, - /// Provide the full response payload. The SDK passes the value as-is - /// in the `result` field of `handlePendingPermissionRequest` - /// (notification path) or as the JSON-RPC `result` directly (direct - /// RPC path). - /// - /// Use this for response shapes beyond `{ "kind": "approve-once" }` - /// or `{ "kind": "reject" }` — for example, "approve and remember" - /// with allowlist data. + /// Provide the full response payload directly. The SDK forwards the + /// value as-is on the wire. Custom(serde_json::Value), - /// No user is available to respond — for example, headless agents - /// without an interactive session. Sent as - /// `{ "kind": "user-not-available" }`. - UserNotAvailable, - /// The handler has no result to provide and the CLI should fall back - /// to another permission responder or its default policy. On the - /// notification path, the SDK will not send a pending permission response. - /// Distinct from [`Deferred`](Self::Deferred), where the handler takes - /// responsibility for resolving the request later out-of-band. + /// Decline to handle this broadcast. The SDK does not send a response, + /// which lets another connected client respond instead. NoResult, + /// No user is available to answer the prompt. On the notification + /// path, the SDK will not send a pending response. On the direct + /// RPC path, the SDK responds with `{ "kind": "user-not-available" }`. + UserNotAvailable, } /// Response to a user input request. @@ -189,267 +81,103 @@ impl Default for ExitPlanModeResult { } } -/// Response to a [`HandlerEvent::AutoModeSwitch`] request. -/// -/// Wire serialization matches the CLI's `autoModeSwitch.request` response -/// schema: `"yes"`, `"yes_always"`, or `"no"`. +/// Response to an auto-mode-switch request. #[non_exhaustive] #[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AutoModeSwitchResponse { /// Approve the auto-mode switch for this rate-limit cycle only. Yes, - /// Approve and remember — auto-accept future auto-mode switches in this - /// session without prompting. + /// Approve and remember -- auto-accept future auto-mode switches in + /// this session without prompting. YesAlways, - /// Decline the auto-mode switch. The session stays on the current model - /// and surfaces the rate-limit error. + /// Decline the auto-mode switch. The session stays on the current + /// model and surfaces the rate-limit error. No, } -/// Callback trait for session events. +/// Handler for `permission.requested` broadcasts. /// -/// Implement this trait to control how a session responds to CLI events, -/// permission requests, tool calls, user input prompts, elicitations, and -/// plan-mode exits. There are two styles of implementation — pick whichever -/// fits your use case: -/// -/// 1. **Per-event methods (recommended for most handlers).** Override the -/// specific `on_*` methods you care about; every method has a safe -/// default so you only write what you need. This is the pattern used by -/// [`serenity::EventHandler`][serenity], `lapin`, and most Rust SDKs -/// that dispatch broker/client callbacks. -/// 2. **Single [`on_event`](Self::on_event) method.** Override this one -/// method and `match` on [`HandlerEvent`] yourself. Useful for logging -/// middleware, custom routing, or when you want an exhaustiveness check -/// across all variants. -/// -/// When you override [`on_event`](Self::on_event) directly, the per-event methods are not -/// called — your implementation is entirely responsible for dispatch. The -/// default [`on_event`](Self::on_event) fans out to the per-event methods. -/// -/// [serenity]: https://docs.rs/serenity/latest/serenity/client/trait.EventHandler.html -/// -/// # Default behavior -/// -/// - Permission requests → **denied**. -/// - User input → `None` (no answer available). -/// - External tool calls → failure result with "no handler registered". -/// - Elicitation → `"cancel"`. -/// - Exit plan mode → [`ExitPlanModeResult::default`]. -/// - Auto-mode-switch → [`AutoModeSwitchResponse::No`] (decline by default; the -/// session stays on its current model and surfaces the rate-limit error). -/// - Session events → ignored (fire-and-forget). -/// -/// # Concurrency -/// -/// **Request-triggered events** (`UserInput`, `ExternalTool` via `tool.call`, -/// `ExitPlanMode`, `PermissionRequest` via `permission.request`) are awaited -/// inline in the event loop and therefore processed **serially** per session. -/// Blocking here pauses that session's event loop — which is correct, since -/// the CLI is also blocked waiting for the response. -/// -/// **Notification-triggered events** (`PermissionRequest` via -/// `permission.requested`, `ExternalTool` via `external_tool.requested`) are -/// dispatched on spawned tasks and may run **concurrently** with each other -/// and with the serial event loop. Implementations must be safe for -/// concurrent invocation. -/// -/// # Example -/// -/// ```no_run -/// use async_trait::async_trait; -/// use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; -/// use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; -/// -/// struct ApproveReadsOnly; -/// -/// #[async_trait] -/// impl SessionHandler for ApproveReadsOnly { -/// async fn on_permission_request( -/// &self, -/// _sid: SessionId, -/// _rid: RequestId, -/// data: PermissionRequestData, -/// ) -> PermissionResult { -/// match data.extra.get("tool").and_then(|v| v.as_str()) { -/// Some("view") | Some("ls") | Some("grep") => PermissionResult::Approved, -/// _ => PermissionResult::Denied, -/// } -/// } -/// } -/// ``` +/// Install via +/// [`SessionConfig::with_permission_handler`](crate::types::SessionConfig::with_permission_handler) +/// (or the matching method on [`ResumeSessionConfig`](crate::types::ResumeSessionConfig)). +/// When no permission handler is supplied, the SDK sends +/// `requestPermission: false` on the wire and the runtime short-circuits +/// permission prompts for this client. #[async_trait] -pub trait SessionHandler: Send + Sync + 'static { - /// Handle an event from the session. - /// - /// The default implementation destructures `event` and calls the - /// matching per-event method (e.g. [`on_permission_request`](Self::on_permission_request) - /// for [`HandlerEvent::PermissionRequest`]). Override this method only - /// if you want a single dispatch point with exhaustive matching — most - /// handlers should override the per-event methods instead. - /// - /// See the [trait-level docs](SessionHandler#concurrency) for details on - /// which events may be dispatched concurrently. - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { session_id, event } => { - self.on_session_event(session_id, event).await; - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { - session_id, - request_id, - data, - } => HandlerResponse::Permission( - self.on_permission_request(session_id, request_id, data) - .await, - ), - HandlerEvent::UserInput { - session_id, - question, - choices, - allow_freeform, - } => HandlerResponse::UserInput( - self.on_user_input(session_id, question, choices, allow_freeform) - .await, - ), - HandlerEvent::ExternalTool { invocation } => { - HandlerResponse::ToolResult(self.on_external_tool(invocation).await) - } - HandlerEvent::ElicitationRequest { - session_id, - request_id, - request, - } => HandlerResponse::Elicitation( - self.on_elicitation(session_id, request_id, request).await, - ), - HandlerEvent::ExitPlanMode { session_id, data } => { - HandlerResponse::ExitPlanMode(self.on_exit_plan_mode(session_id, data).await) - } - HandlerEvent::AutoModeSwitch { - session_id, - error_code, - retry_after_seconds, - } => HandlerResponse::AutoModeSwitch( - self.on_auto_mode_switch(session_id, error_code, retry_after_seconds) - .await, - ), - } - } - - /// Informational timeline event (assistant messages, tool execution - /// markers, session idle, etc.). Fire-and-forget — the return value is - /// ignored. - /// - /// Default: do nothing. - async fn on_session_event(&self, _session_id: SessionId, _event: SessionEvent) {} - - /// The CLI is asking whether the agent may perform a privileged action. - /// - /// Default: [`PermissionResult::Denied`]. The default-deny posture - /// matches the CLI's safety model; override to implement your own - /// policy (see the [`permission`](crate::permission) module for common - /// wrappers like `approve_all` / `approve_if`). - async fn on_permission_request( +pub trait PermissionHandler: Send + Sync + 'static { + /// Resolve a permission request. + async fn handle( &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Denied - } + session_id: SessionId, + request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult; +} - /// The CLI is asking the user a question (optionally with a list of - /// choices). - /// - /// Default: `None` — the CLI interprets this as "no answer available" - /// and falls back to its own prompt behavior. - async fn on_user_input( +/// Handler for `elicitation.requested` broadcasts. +/// +/// When unset, `requestElicitation: false` goes on the wire. +#[async_trait] +pub trait ElicitationHandler: Send + Sync + 'static { + /// Respond to an elicitation prompt (form, URL confirm, etc.). + async fn handle( &self, - _session_id: SessionId, - _question: String, - _choices: Option>, - _allow_freeform: Option, - ) -> Option { - None - } - - /// The CLI wants to invoke a client-defined ("external") tool. - /// - /// Default: a failure [`ToolResult`] indicating no tool handler is - /// registered. Typical implementations route to a - /// [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) which - /// dispatches to tools registered via - /// [`define_tool`](crate::tool::define_tool) or custom - /// [`ToolHandler`](crate::tool::ToolHandler) impls. - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - let msg = format!("No handler registered for tool '{}'", invocation.tool_name); - ToolResult::Expanded(crate::types::ToolResultExpanded { - text_result_for_llm: msg.clone(), - result_type: "failure".to_string(), - binary_results_for_llm: None, - session_log: None, - error: Some(msg), - tool_telemetry: None, - }) - } + session_id: SessionId, + request_id: RequestId, + request: ElicitationRequest, + ) -> ElicitationResult; +} - /// The CLI is requesting an elicitation (structured form / URL prompt). - /// - /// Default: cancel. - async fn on_elicitation( +/// Handler for `user_input.requested` events from the `ask_user` tool. +/// +/// When unset, `requestUserInput: false` goes on the wire and the +/// `ask_user` tool is disabled for the session. +#[async_trait] +pub trait UserInputHandler: Send + Sync + 'static { + /// Answer a question on behalf of the user. Return `None` to signal + /// "no answer available". + async fn handle( &self, - _session_id: SessionId, - _request_id: RequestId, - _request: ElicitationRequest, - ) -> ElicitationResult { - ElicitationResult { - action: "cancel".to_string(), - content: None, - } - } + session_id: SessionId, + question: String, + choices: Option>, + allow_freeform: Option, + ) -> Option; +} - /// The CLI is asking the user whether to exit plan mode. - /// - /// Default: [`ExitPlanModeResult::default`] (approved with no action). - async fn on_exit_plan_mode( - &self, - _session_id: SessionId, - _data: ExitPlanModeData, - ) -> ExitPlanModeResult { - ExitPlanModeResult::default() - } +/// Handler for `exit_plan_mode.requested` events. When unset, +/// `requestExitPlanMode: false` goes on the wire. +#[async_trait] +pub trait ExitPlanModeHandler: Send + Sync + 'static { + /// Decide whether to leave plan mode. + async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult; +} - /// The CLI is asking whether to switch to auto model after an eligible - /// rate limit. - /// - /// `retry_after_seconds`, when present, is the number of seconds until the - /// rate limit resets. Handlers can use it to render a humanized reset time - /// alongside the prompt. - /// - /// Default: [`AutoModeSwitchResponse::No`] — decline. Override only if - /// your application surfaces a UX for the rate-limit-recovery prompt. - async fn on_auto_mode_switch( +/// Handler for `auto_mode_switch.requested` events. When unset, +/// `requestAutoModeSwitch: false` goes on the wire. +#[async_trait] +pub trait AutoModeSwitchHandler: Send + Sync + 'static { + /// Decide whether to fall back to the auto model after an eligible + /// rate-limit error. `retry_after_seconds`, when present, is the + /// number of seconds until the rate limit resets. + async fn handle( &self, - _session_id: SessionId, - _error_code: Option, - _retry_after_seconds: Option, - ) -> AutoModeSwitchResponse { - AutoModeSwitchResponse::No - } + session_id: SessionId, + error_code: Option, + retry_after_seconds: Option, + ) -> AutoModeSwitchResponse; } -/// A [`SessionHandler`] that auto-approves all permissions and ignores all events. -/// -/// Useful for CLI tools, scripts, and tests that don't need interactive -/// permission prompts or custom tool handling. +/// A [`PermissionHandler`] that approves every request. Useful for CLI +/// tools, scripts, and tests that don't need interactive permission +/// prompts. #[derive(Debug, Clone)] pub struct ApproveAllHandler; #[async_trait] -impl SessionHandler for ApproveAllHandler { - async fn on_permission_request( +impl PermissionHandler for ApproveAllHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -459,218 +187,47 @@ impl SessionHandler for ApproveAllHandler { } } -/// A [`SessionHandler`] that denies all permission requests and otherwise -/// relies on the trait's default fallback responses for every other event -/// (e.g. tool invocations return "unhandled", elicitations cancel, plan-mode -/// prompts decline). Use this when a session should never wait for manual -/// permission approval. +/// A [`PermissionHandler`] that denies every request. #[derive(Debug, Clone)] pub struct DenyAllHandler; #[async_trait] -impl SessionHandler for DenyAllHandler { - // All defaults are already safe: permissions deny, everything else is a - // sensible fallback. We just reuse them here for clarity. -} - -/// A [`SessionHandler`] that leaves permission requests and external tool calls pending. -/// -/// This is the default used when no handler is set on -/// [`SessionConfig::handler`](crate::types::SessionConfig::handler). It lets consumers -/// observe `permission.requested` and `external_tool.requested` events and later resolve -/// them with the corresponding pending-request RPC methods. -#[derive(Debug, Clone)] -pub struct NoopHandler; - -#[async_trait] -impl SessionHandler for NoopHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { .. } => HandlerResponse::Ok, - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::NoResult) - } - HandlerEvent::UserInput { .. } => HandlerResponse::UserInput(None), - HandlerEvent::ExternalTool { .. } => HandlerResponse::NoResult, - HandlerEvent::ElicitationRequest { .. } => { - HandlerResponse::Elicitation(ElicitationResult { - action: "cancel".to_string(), - content: None, - }) - } - HandlerEvent::ExitPlanMode { .. } => { - HandlerResponse::ExitPlanMode(ExitPlanModeResult::default()) - } - HandlerEvent::AutoModeSwitch { .. } => { - HandlerResponse::AutoModeSwitch(AutoModeSwitchResponse::No) - } - } +impl PermissionHandler for DenyAllHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied } } #[cfg(test)] mod tests { - use serde_json::Value; - use super::*; - use crate::types::{PermissionRequestData, RequestId, SessionId}; - - fn perm_data() -> PermissionRequestData { - PermissionRequestData::default() - } - - // A handler that overrides only `on_permission_request` (per-method style). - struct ApproveViaPerMethod; - - #[async_trait] - impl SessionHandler for ApproveViaPerMethod { - async fn on_permission_request( - &self, - _: SessionId, - _: RequestId, - _: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Approved - } - } - - // A handler that overrides `on_event` directly (legacy / routing style). - struct ApproveViaOnEvent; - - #[async_trait] - impl SessionHandler for ApproveViaOnEvent { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - _ => HandlerResponse::Ok, - } - } - } - - #[tokio::test] - async fn per_method_override_dispatches_via_default_on_event() { - let h = ApproveViaPerMethod; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) - .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Approved) - )); - } - - #[tokio::test] - async fn on_event_override_short_circuits_per_method_defaults() { - let h = ApproveViaOnEvent; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) - .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Approved) - )); - } #[tokio::test] - async fn deny_all_handler_uses_default_permission_deny() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) + async fn approve_all_handler_returns_approved() { + let result = ApproveAllHandler + .handle( + SessionId::from("s1"), + RequestId::new("1"), + PermissionRequestData::default(), + ) .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Denied) - )); + assert!(matches!(result, PermissionResult::Approved)); } #[tokio::test] - async fn default_on_external_tool_returns_failure() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::ExternalTool { - invocation: crate::types::ToolInvocation { - session_id: SessionId::from("s1".to_string()), - tool_call_id: "tc1".to_string(), - tool_name: "missing".to_string(), - arguments: Value::Null, - traceparent: None, - tracestate: None, - }, - }) + async fn deny_all_handler_returns_denied() { + let result = DenyAllHandler + .handle( + SessionId::from("s1"), + RequestId::new("1"), + PermissionRequestData::default(), + ) .await; - match resp { - HandlerResponse::ToolResult(crate::types::ToolResult::Expanded(exp)) => { - assert_eq!(exp.result_type, "failure"); - assert!(exp.text_result_for_llm.contains("missing")); - assert_eq!(exp.error.as_deref(), Some(exp.text_result_for_llm.as_str())); - } - other => panic!("unexpected response: {other:?}"), - } - } - - #[tokio::test] - async fn noop_handler_leaves_permission_and_external_tool_pending() { - let h = NoopHandler; - let permission = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) - .await; - assert!(matches!( - permission, - HandlerResponse::Permission(PermissionResult::NoResult) - )); - - let tool = h - .on_event(HandlerEvent::ExternalTool { - invocation: crate::types::ToolInvocation { - session_id: SessionId::from("s1".to_string()), - tool_call_id: "tc1".to_string(), - tool_name: "manual".to_string(), - arguments: Value::Null, - traceparent: None, - tracestate: None, - }, - }) - .await; - assert!(matches!(tool, HandlerResponse::NoResult)); - } - - #[tokio::test] - async fn default_on_elicitation_returns_cancel() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::ElicitationRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - request: crate::types::ElicitationRequest { - message: "test".to_string(), - requested_schema: None, - mode: Some(crate::types::ElicitationMode::Form), - elicitation_source: None, - url: None, - }, - }) - .await; - match resp { - HandlerResponse::Elicitation(r) => assert_eq!(r.action, "cancel"), - other => panic!("unexpected response: {other:?}"), - } + assert!(matches!(result, PermissionResult::Denied)); } } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index abb1a72a4..0c8fd33d2 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -10,7 +10,7 @@ pub mod handler; /// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). pub mod hooks; mod jsonrpc; -/// Permission-policy helpers that wrap an existing [`handler::SessionHandler`]. +/// Permission-policy helpers that produce a [`handler::PermissionHandler`]. pub mod permission; /// GitHub Copilot CLI binary resolution (env var, embedded, PATH search). pub mod resolve; @@ -30,6 +30,7 @@ pub mod trace_context; pub mod transforms; /// Protocol types shared between the SDK and the GitHub Copilot CLI. pub mod types; +mod wire; /// Auto-generated protocol types from Copilot JSON Schemas. pub mod generated; @@ -289,6 +290,10 @@ pub enum Transport { Tcp { /// Port to listen on (0 for OS-assigned). port: u16, + /// Optional connection token. When `None` and the SDK is spawning + /// the CLI, the SDK auto-generates a 128-bit hex token so the + /// loopback listener is safe by default. + connection_token: Option, }, /// Connect to an already-running CLI server (no process spawning). External { @@ -296,6 +301,9 @@ pub enum Transport { host: String, /// Port of the running server. port: u16, + /// Optional connection token. Required when the external server + /// was started with a token, ignored otherwise. + connection_token: Option, }, } @@ -350,7 +358,8 @@ pub struct ClientOptions { /// [`Self::github_token`] is set, in which case false). pub use_logged_in_user: Option, /// Log level passed to the CLI server via `--log-level`. When `None`, - /// the SDK uses [`LogLevel::Info`]. + /// the SDK does not pass `--log-level` to the runtime at all and the + /// CLI uses its built-in default. pub log_level: Option, /// Server-wide idle timeout for sessions, in seconds. When set to a /// positive value, the SDK passes `--session-idle-timeout ` to @@ -392,23 +401,13 @@ pub struct ClientOptions { /// auth, telemetry buffers). When set, exported as `COPILOT_HOME` to /// the spawned CLI process. Useful for sandboxing test runs or /// running multiple isolated SDK instances side-by-side. - pub copilot_home: Option, - /// Optional connection token for TCP transport. Sent to the CLI in - /// the `connect` handshake and exported as `COPILOT_CONNECTION_TOKEN` - /// to spawned CLI processes. Required when the CLI server was started - /// with a token, ignored otherwise. - /// - /// When the SDK spawns its own CLI in TCP mode and this is left - /// `None`, a UUID is generated automatically so the loopback listener - /// is safe by default. Combining with [`Transport::Stdio`] is invalid - /// and surfaces as an error from [`Client::start`]. - pub tcp_connection_token: Option, + pub base_directory: Option, /// Enable remote session support (Mission Control integration). /// When `true`, the SDK passes `--remote` to the spawned CLI process so /// sessions in a GitHub repository working directory are accessible from /// GitHub web and mobile. Ignored when connecting to an external server /// via [`Transport::External`]. - pub remote: bool, + pub enable_remote_sessions: bool, } impl std::fmt::Debug for ClientOptions { @@ -441,12 +440,8 @@ impl std::fmt::Debug for ClientOptions { &self.on_get_trace_context.as_ref().map(|_| ""), ) .field("telemetry", &self.telemetry) - .field("copilot_home", &self.copilot_home) - .field( - "tcp_connection_token", - &self.tcp_connection_token.as_ref().map(|_| ""), - ) - .field("remote", &self.remote) + .field("base_directory", &self.base_directory) + .field("enable_remote_sessions", &self.enable_remote_sessions) .finish() } } @@ -475,7 +470,7 @@ pub enum LogLevel { Error, /// Warnings and errors. Warning, - /// Default. Info and above. + /// Info and above. Info, /// Debug, info, warnings, errors. Debug, @@ -651,9 +646,8 @@ impl Default for ClientOptions { session_fs: None, on_get_trace_context: None, telemetry: None, - copilot_home: None, - tcp_connection_token: None, - remote: false, + base_directory: None, + enable_remote_sessions: false, } } } @@ -799,23 +793,15 @@ impl ClientOptions { /// Override the directory where the CLI persists its state. Set as /// `COPILOT_HOME` on the spawned CLI process. - pub fn with_copilot_home(mut self, home: impl Into) -> Self { - self.copilot_home = Some(home.into()); - self - } - - /// Set the connection token for TCP transport. Sent in the `connect` - /// handshake and exported as `COPILOT_CONNECTION_TOKEN` to spawned - /// CLI processes. - pub fn with_tcp_connection_token(mut self, token: impl Into) -> Self { - self.tcp_connection_token = Some(token.into()); + pub fn with_base_directory(mut self, dir: impl Into) -> Self { + self.base_directory = Some(dir.into()); self } /// Enable remote session support (Mission Control). Passes `--remote` /// to the spawned CLI process. - pub fn with_remote(mut self, enabled: bool) -> Self { - self.remote = enabled; + pub fn with_enable_remote_sessions(mut self, enabled: bool) -> Self { + self.enable_remote_sessions = enabled; self } } @@ -929,39 +915,48 @@ impl Client { )); } } - // Validate token + transport combination. Stdio cannot use a - // connection token; auto-generate a UUID when the SDK spawns - // its own CLI in TCP mode and no explicit token was set. - if let Some(token) = &options.tcp_connection_token { - if token.is_empty() { - return Err(Error::InvalidConfig( - "tcp_connection_token must be a non-empty string".to_string(), - )); + // Validate token shape. Stdio variants no longer carry a token + // (enforced by the type). For Tcp/External, empty-string is + // rejected eagerly. + match &options.transport { + Transport::Tcp { + connection_token: Some(t), + .. } - if matches!(options.transport, Transport::Stdio) { + | Transport::External { + connection_token: Some(t), + .. + } if t.is_empty() => { return Err(Error::InvalidConfig( - "tcp_connection_token cannot be used with Transport::Stdio".to_string(), + "connection_token must be a non-empty string".to_string(), )); } + _ => {} } - let effective_connection_token: Option = match &options.transport { - Transport::Stdio => None, - Transport::Tcp { .. } => Some( - options - .tcp_connection_token - .clone() - .unwrap_or_else(generate_connection_token), - ), - Transport::External { .. } => options.tcp_connection_token.clone(), + // Capture (and where needed, auto-generate) the token actually sent + // to the server. For Tcp, the SDK auto-generates one when the + // caller leaves it unset so the loopback listener is safe by + // default. + let (mut options, effective_connection_token) = { + let mut options = options; + let effective = match &mut options.transport { + Transport::Stdio => None, + Transport::Tcp { + connection_token, .. + } => { + if connection_token.is_none() { + *connection_token = Some(generate_connection_token()); + } + connection_token.clone() + } + Transport::External { + connection_token, .. + } => connection_token.clone(), + }; + (options, effective) }; - let mut options = options; - if matches!(options.transport, Transport::Tcp { .. }) - && options.tcp_connection_token.is_none() - { - // Auto-generated tokens flow to the spawned CLI via env, so - // make the field reflect what we'll actually send. - options.tcp_connection_token = effective_connection_token.clone(); - } + let _ = &mut options; + let effective_connection_token: Option = effective_connection_token; let session_fs_config = options.session_fs.clone(); let session_fs_sqlite_declared = session_fs_config .as_ref() @@ -993,7 +988,11 @@ impl Client { }; let client = match options.transport { - Transport::External { ref host, port } => { + Transport::External { + ref host, + port, + connection_token: _, + } => { info!(host = %host, port = %port, "connecting to external CLI server"); let connect_start = Instant::now(); let stream = TcpStream::connect((host.as_str(), port)).await?; @@ -1016,7 +1015,10 @@ impl Client { effective_connection_token.clone(), )? } - Transport::Tcp { port } => { + Transport::Tcp { + port, + connection_token: _, + } => { let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?; let connect_start = Instant::now(); let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?; @@ -1280,10 +1282,14 @@ impl Client { ); } } - if let Some(home) = &options.copilot_home { - command.env("COPILOT_HOME", home); + if let Some(dir) = &options.base_directory { + command.env("COPILOT_HOME", dir); } - if let Some(token) = &options.tcp_connection_token { + if let Transport::Tcp { + connection_token: Some(token), + .. + } = &options.transport + { command.env("COPILOT_CONNECTION_TOKEN", token); } for (key, value) in &options.env { @@ -1342,25 +1348,26 @@ impl Client { } fn remote_args(options: &ClientOptions) -> Vec { - if options.remote { + if options.enable_remote_sessions { vec!["--remote".to_string()] } else { Vec::new() } } + fn log_level_args(options: &ClientOptions) -> Vec<&'static str> { + match options.log_level { + Some(level) => vec!["--log-level", level.as_str()], + None => Vec::new(), + } + } + fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)"); let mut command = Self::build_command(program, options); - let log_level = options.log_level.unwrap_or(LogLevel::Info); command - .args([ - "--server", - "--stdio", - "--no-auto-update", - "--log-level", - log_level.as_str(), - ]) + .args(["--server", "--stdio", "--no-auto-update"]) + .args(Self::log_level_args(options)) .args(Self::auth_args(options)) .args(Self::session_idle_timeout_args(options)) .args(Self::remote_args(options)) @@ -1382,16 +1389,9 @@ impl Client { ) -> Result<(Child, u16), Error> { info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)"); let mut command = Self::build_command(program, options); - let log_level = options.log_level.unwrap_or(LogLevel::Info); command - .args([ - "--server", - "--port", - &port.to_string(), - "--no-auto-update", - "--log-level", - log_level.as_str(), - ]) + .args(["--server", "--port", &port.to_string(), "--no-auto-update"]) + .args(Self::log_level_args(options)) .args(Self::auth_args(options)) .args(Self::session_idle_timeout_args(options)) .args(Self::remote_args(options)) @@ -1586,8 +1586,8 @@ impl Client { /// /// # Handshake sequence /// - /// 1. Sends the `connect` JSON-RPC method, forwarding - /// [`ClientOptions::tcp_connection_token`] (or the auto-generated + /// 1. Sends the `connect` JSON-RPC method, forwarding the + /// [`Transport`]'s `connection_token` (or the auto-generated /// token for SDK-spawned TCP servers) as the `token` param. This /// is the canonical handshake used by all SDK languages and is /// what the CLI uses to enforce loopback authentication when @@ -1652,7 +1652,7 @@ impl Client { /// Send the `connect` JSON-RPC handshake. Returns the server's /// reported protocol version, or `None` if the server omits it. - /// Forwards [`ClientOptions::tcp_connection_token`] (or the + /// Forwards the [`Transport`]'s `connection_token` (or the /// auto-generated token for SDK-spawned TCP servers) as the `token` /// param. Server-side, the token is required when the server was /// started with `COPILOT_CONNECTION_TOKEN`. @@ -1986,16 +1986,6 @@ impl Client { pub fn subscribe_lifecycle(&self) -> LifecycleSubscription { LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe()) } - - /// Return the current [`ConnectionState`]. - /// - /// The state advances to [`Connected`](ConnectionState::Connected) once - /// [`Client::start`] / [`Client::from_streams`] returns successfully and - /// drops to [`Disconnected`](ConnectionState::Disconnected) after - /// [`stop`](Self::stop) or [`force_stop`](Self::force_stop). - pub fn state(&self) -> ConnectionState { - *self.inner.state.lock() - } } impl Drop for ClientInner { @@ -2055,7 +2045,7 @@ mod tests { .with_use_logged_in_user(false) .with_log_level(LogLevel::Debug) .with_session_idle_timeout_seconds(120) - .with_remote(true); + .with_enable_remote_sessions(true); assert!(matches!(opts.program, CliProgram::Path(_))); assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]); assert_eq!(opts.cwd, PathBuf::from("/tmp")); @@ -2072,7 +2062,7 @@ mod tests { assert_eq!(opts.use_logged_in_user, Some(false)); assert!(matches!(opts.log_level, Some(LogLevel::Debug))); assert_eq!(opts.session_idle_timeout_seconds, Some(120)); - assert!(opts.remote); + assert!(opts.enable_remote_sessions); } #[test] @@ -2275,7 +2265,7 @@ mod tests { #[test] fn build_command_sets_copilot_home_env_when_configured() { - let opts = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot")); + let opts = ClientOptions::new().with_base_directory(PathBuf::from("/custom/copilot")); let cmd = Client::build_command(Path::new("/bin/echo"), &opts); assert_eq!( env_value(&cmd, "COPILOT_HOME"), @@ -2289,7 +2279,10 @@ mod tests { #[test] fn build_command_sets_connection_token_env_when_configured() { - let opts = ClientOptions::new().with_tcp_connection_token("secret-token"); + let opts = ClientOptions::new().with_transport(Transport::Tcp { + port: 0, + connection_token: Some("secret-token".to_string()), + }); let cmd = Client::build_command(Path::new("/bin/echo"), &opts); assert_eq!( env_value(&cmd, "COPILOT_CONNECTION_TOKEN"), @@ -2302,26 +2295,25 @@ mod tests { } #[tokio::test] - async fn start_rejects_token_with_stdio_transport() { + async fn start_rejects_empty_connection_token() { let opts = ClientOptions::new() - .with_tcp_connection_token("token-123") + .with_transport(Transport::Tcp { + port: 0, + connection_token: Some(String::new()), + }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("Stdio"), - "error should explain the stdio incompatibility: {msg}" - ); } #[tokio::test] - async fn start_rejects_empty_connection_token() { + async fn start_rejects_empty_external_connection_token() { let opts = ClientOptions::new() - .with_tcp_connection_token("") - .with_transport(Transport::Tcp { port: 0 }) + .with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port: 1, + connection_token: Some(String::new()), + }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); @@ -2397,12 +2389,28 @@ mod tests { #[test] fn remote_args_emit_flag_when_enabled() { let opts = ClientOptions { - remote: true, + enable_remote_sessions: true, ..Default::default() }; assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]); } + #[test] + fn log_level_args_omitted_when_unset() { + let opts = ClientOptions::default(); + assert!(opts.log_level.is_none()); + assert!( + Client::log_level_args(&opts).is_empty(), + "with no caller-supplied log_level the SDK must not pass --log-level" + ); + } + + #[test] + fn log_level_args_emit_flag_when_set() { + let opts = ClientOptions::default().with_log_level(LogLevel::Debug); + assert_eq!(Client::log_level_args(&opts), vec!["--log-level", "debug"]); + } + #[test] fn log_level_str_round_trips() { for level in [ diff --git a/rust/src/permission.rs b/rust/src/permission.rs index 364cb3c91..22cf9bda9 100644 --- a/rust/src/permission.rs +++ b/rust/src/permission.rs @@ -1,106 +1,124 @@ -//! Permission-policy helpers that compose with an existing -//! [`SessionHandler`](crate::handler::SessionHandler). +//! Permission policy primitives that produce a [`PermissionHandler`](crate::handler::PermissionHandler). //! -//! These wrap an inner handler and override **only** permission requests, -//! forwarding every other event (tool calls, user input, elicitation, -//! session events) to the inner handler. Use them when you have a custom -//! tool handler — typically a [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) — -//! but want a one-line policy for permission prompts. +//! Compose these into a session via the builder methods +//! [`SessionConfig::approve_all_permissions`](crate::types::SessionConfig::approve_all_permissions), +//! [`deny_all_permissions`](crate::types::SessionConfig::deny_all_permissions), +//! and [`approve_permissions_if`](crate::types::SessionConfig::approve_permissions_if). +//! The same primitives are also available as standalone functions that +//! return an `Arc` you can install via +//! [`SessionConfig::with_permission_handler`](crate::types::SessionConfig::with_permission_handler). //! -//! For a full handler that approves or denies everything, see +//! For a one-shot approve / deny without composition, see //! [`ApproveAllHandler`](crate::handler::ApproveAllHandler) and //! [`DenyAllHandler`](crate::handler::DenyAllHandler). -//! -//! # Example -//! -//! ```rust,no_run -//! # use std::sync::Arc; -//! # use github_copilot_sdk::handler::ApproveAllHandler; -//! # use github_copilot_sdk::permission; -//! # use github_copilot_sdk::tool::ToolHandlerRouter; -//! let router = ToolHandlerRouter::new(vec![], Arc::new(ApproveAllHandler)); -//! // Inherit the router's tool dispatch but auto-approve all permission prompts: -//! let handler = permission::approve_all(Arc::new(router)); -//! ``` use std::sync::Arc; use async_trait::async_trait; -use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; -use crate::types::PermissionRequestData; +use crate::handler::{PermissionHandler, PermissionResult}; +use crate::types::{PermissionRequestData, RequestId, SessionId}; -/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is -/// auto-approved. All other events are forwarded to `inner`. -pub fn approve_all(inner: Arc) -> Arc { - Arc::new(PermissionOverrideHandler { - inner, +/// Return a [`PermissionHandler`] that approves every request. +pub fn approve_all() -> Arc { + Arc::new(PolicyHandler { policy: Policy::ApproveAll, }) } -/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is -/// auto-denied. All other events are forwarded to `inner`. -pub fn deny_all(inner: Arc) -> Arc { - Arc::new(PermissionOverrideHandler { - inner, +/// Return a [`PermissionHandler`] that denies every request. +pub fn deny_all() -> Arc { + Arc::new(PolicyHandler { policy: Policy::DenyAll, }) } -/// Wrap `inner` with a closure-based policy: `predicate` is called for each -/// permission request; `true` approves, `false` denies. All other events -/// are forwarded to `inner`. +/// Return a [`PermissionHandler`] that consults a predicate for each +/// request. `true` approves, `false` denies. /// /// ```rust,no_run -/// # use std::sync::Arc; -/// # use github_copilot_sdk::handler::ApproveAllHandler; /// # use github_copilot_sdk::permission; -/// let inner = Arc::new(ApproveAllHandler); -/// let handler = permission::approve_if(inner, |data| { -/// // Inspect data.extra (the raw JSON payload) for custom policy. +/// let handler = permission::approve_if(|data| { /// data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") /// }); /// # let _ = handler; /// ``` -pub fn approve_if(inner: Arc, predicate: F) -> Arc +pub fn approve_if(predicate: F) -> Arc where F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static, { - Arc::new(PermissionOverrideHandler { - inner, + Arc::new(PolicyHandler { policy: Policy::Predicate(Arc::new(predicate)), }) } -enum Policy { +/// Internal policy enum used by both the standalone helpers and the +/// `SessionConfig` policy builders. +/// +/// Stored as `pub(crate)` on `SessionConfig::permission_policy` so that +/// the order of `with_permission_handler(...)` and the policy builders +/// does not matter -- the policy is applied at `Client::create_session` +/// time. +#[derive(Clone)] +pub(crate) enum Policy { ApproveAll, DenyAll, Predicate(Arc bool + Send + Sync>), } -struct PermissionOverrideHandler { - inner: Arc, +impl std::fmt::Debug for Policy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ApproveAll => f.write_str("Policy::ApproveAll"), + Self::DenyAll => f.write_str("Policy::DenyAll"), + Self::Predicate(_) => f.write_str("Policy::Predicate()"), + } + } +} + +/// Resolve the effective permission handler for a session, given the +/// caller-supplied handler and policy. Called by `Client::create_session` +/// and `Client::resume_session`. +/// +/// Semantics: +/// - When `policy` is `Some`, the policy entirely replaces the handler +/// for permission decisions. (Caller-supplied handler, if any, is +/// discarded -- the policy is what answers permission requests.) +/// - When `policy` is `None` and `handler` is `Some`, the handler stands. +/// - When both are `None`, returns `None` (no handler -- the SDK sends +/// `requestPermission: false`). +pub(crate) fn resolve_handler( + handler: Option>, + policy: Option, +) -> Option> { + match (handler, policy) { + (_, Some(policy)) => Some(Arc::new(PolicyHandler { policy })), + (Some(h), None) => Some(h), + (None, None) => None, + } +} + +struct PolicyHandler { policy: Policy, } #[async_trait] -impl SessionHandler for PermissionOverrideHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { ref data, .. } => { - let approved = match &self.policy { - Policy::ApproveAll => true, - Policy::DenyAll => false, - Policy::Predicate(f) => f(data), - }; - HandlerResponse::Permission(if approved { - PermissionResult::Approved - } else { - PermissionResult::Denied - }) - } - other => self.inner.on_event(other).await, +impl PermissionHandler for PolicyHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let approved = match &self.policy { + Policy::ApproveAll => true, + Policy::DenyAll => false, + Policy::Predicate(f) => f(&data), + }; + if approved { + PermissionResult::Approved + } else { + PermissionResult::Denied } } } @@ -108,61 +126,94 @@ impl SessionHandler for PermissionOverrideHandler { #[cfg(test)] mod tests { use super::*; - use crate::handler::ApproveAllHandler; - use crate::types::{RequestId, SessionId}; - - fn request() -> HandlerEvent { - HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1"), - request_id: RequestId::new("1"), - data: PermissionRequestData { - extra: serde_json::json!({"tool": "shell"}), - ..Default::default() - }, + + fn data() -> PermissionRequestData { + PermissionRequestData { + extra: serde_json::json!({ "tool": "shell" }), + ..Default::default() } } #[tokio::test] - async fn approve_all_approves_permission_requests() { - let h = approve_all(Arc::new(ApproveAllHandler)); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + async fn approve_all_approves() { + let h = approve_all(); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Approved + )); } #[tokio::test] - async fn deny_all_denies_permission_requests() { - let h = deny_all(Arc::new(ApproveAllHandler)); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied, got {other:?}"), - } + async fn deny_all_denies() { + let h = deny_all(); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Denied + )); } #[tokio::test] async fn approve_if_consults_predicate() { - let h = approve_if(Arc::new(ApproveAllHandler), |data| { - data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") - }); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied for shell, got {other:?}"), + let h = approve_if(|d| d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Denied + )); + } + + #[tokio::test] + async fn resolve_handler_policy_wins() { + struct AlwaysApprove; + #[async_trait] + impl PermissionHandler for AlwaysApprove { + async fn handle( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } } + let resolved = + resolve_handler(Some(Arc::new(AlwaysApprove)), Some(Policy::DenyAll)).unwrap(); + // Policy wins -- the AlwaysApprove handler is discarded. + assert!(matches!( + resolved + .handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Denied + )); } #[tokio::test] - async fn non_permission_events_forward_to_inner() { - let h = deny_all(Arc::new(ApproveAllHandler)); - let event = HandlerEvent::UserInput { - session_id: SessionId::from("s1"), - question: "continue?".to_string(), - choices: None, - allow_freeform: None, - }; - match h.on_event(event).await { - HandlerResponse::UserInput(None) => {} - other => panic!("expected UserInput forwarded, got {other:?}"), + async fn resolve_handler_with_only_handler() { + struct H; + #[async_trait] + impl PermissionHandler for H { + async fn handle( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } } + let resolved = resolve_handler(Some(Arc::new(H)), None).unwrap(); + assert!(matches!( + resolved + .handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Approved + )); + } + + #[test] + fn resolve_handler_with_neither_returns_none() { + assert!(resolve_handler(None, None).is_none()); } } diff --git a/rust/src/session.rs b/rust/src/session.rs index d533dbc44..842d5d732 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -19,8 +19,8 @@ use crate::generated::session_events::{ SessionEventType, }; use crate::handler::{ - AutoModeSwitchResponse, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, - UserInputResponse, + AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; use crate::session_fs::SessionFsProvider; @@ -28,14 +28,30 @@ use crate::trace_context::inject_trace_context; use crate::transforms::SystemMessageTransform; use crate::types::{ CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest, - ElicitationResult, ExitPlanModeData, GetMessagesResponse, InputOptions, MessageOptions, + ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation, - ToolResult, ToolResultExpanded, ToolResultResponse, TraceContext, - ensure_attachment_display_names, + ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +/// Bundle of the per-session callbacks the SDK dispatches to. Built from a +/// [`SessionConfig`] / [`ResumeSessionConfig`] at +/// [`Client::create_session`] / [`Client::resume_session`] time. Each +/// field is `None` (or an empty map for tools) when the caller didn't +/// install a handler -- in that case the SDK skips dispatch for that +/// event type. The wire flags on `session.create` / `session.resume` +/// are derived from these fields. +#[derive(Clone)] +pub(crate) struct SessionHandlers { + pub permission: Option>, + pub elicitation: Option>, + pub user_input: Option>, + pub exit_plan_mode: Option>, + pub auto_mode_switch: Option>, + pub tools: Arc>>, +} + /// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. struct IdleWaiter { tx: oneshot::Sender, Error>>, @@ -106,7 +122,8 @@ impl Drop for PendingSessionRegistration { /// A session on a GitHub Copilot CLI server. /// /// Created via [`Client::create_session`] or [`Client::resume_session`]. -/// Owns an internal event loop that dispatches events to the [`SessionHandler`]. +/// Owns an internal event loop that dispatches events to the per-callback +/// handlers installed on the session config. /// /// Protocol methods (`send`, `get_messages`, `abort`, etc.) automatically /// inject the session ID into RPC params. @@ -220,10 +237,10 @@ impl Session { /// /// **Observe-only.** Subscribers receive a clone of every /// [`SessionEvent`] but cannot influence permission decisions, tool - /// results, or anything else that requires returning a - /// [`HandlerResponse`]. Those remain - /// the responsibility of the [`SessionHandler`] passed via - /// [`SessionConfig::handler`](crate::types::SessionConfig::handler). + /// results, or anything else that requires returning a value. Those + /// remain the responsibility of the per-callback handlers passed via + /// [`SessionConfig`]'s `with_*_handler` + /// builder methods. /// /// The returned handle implements both an inherent /// [`recv`](crate::subscription::EventSubscription::recv) method and @@ -446,8 +463,8 @@ impl Session { } } - /// Retrieve the session's message history. - pub async fn get_messages(&self) -> Result, Error> { + /// Retrieve the session's timeline events. + pub async fn get_events(&self) -> Result, Error> { let result = self .client .call( @@ -459,6 +476,12 @@ impl Session { Ok(response.events) } + /// Deprecated alias for [`get_events`](Self::get_events). + #[deprecated(since = "0.1.0", note = "Use `get_events()` instead")] + pub async fn get_messages(&self) -> Result, Error> { + self.get_events().await + } + /// Abort the current agent turn. /// /// # Cancel safety @@ -519,11 +542,11 @@ impl Session { Ok(()) } - /// Alias for [`disconnect`](Self::disconnect). - /// - /// Named after the `session.destroy` wire RPC. Prefer `disconnect` in - /// new code — the wire-level "destroy" is misleading because on-disk - /// state is preserved. + /// Deprecated alias for [`disconnect`](Self::disconnect). The + /// underlying wire RPC happens to be named `session.destroy`, but it + /// only severs the connection — on-disk session state is preserved. + /// Prefer `disconnect` in new code. + #[deprecated(since = "0.1.0", note = "Use `disconnect()` instead")] pub async fn destroy(&self) -> Result<(), Error> { self.disconnect().await } @@ -689,11 +712,11 @@ impl<'a> SessionUi<'a> { /// Ask the user for free-form text input. /// /// Returns the input string on accept, or `None` on decline/cancel. - /// Use [`InputOptions`] to set validation constraints and field metadata. + /// Use [`UiInputOptions`] to set validation constraints and field metadata. pub async fn input( &self, message: &str, - options: Option<&InputOptions<'_>>, + options: Option<&UiInputOptions<'_>>, ) -> Result, Error> { self.session.assert_elicitation()?; let mut field = serde_json::json!({ "type": "string" }); @@ -739,30 +762,69 @@ impl Client { /// Sends `session.create`, registers the session on the router, /// and spawns an internal event loop that dispatches to the handler. /// - /// All callbacks (event handler, hooks, transform) are configured - /// via [`SessionConfig`] using [`with_handler`](SessionConfig::with_handler), - /// [`with_hooks`](SessionConfig::with_hooks), and - /// [`with_transform`](SessionConfig::with_transform). + /// All callbacks (per-event handlers, tool handlers, hooks, transform) + /// are configured via [`SessionConfig`] using its `with_*_handler` / + /// `with_tools` / `with_hooks` / `with_system_message_transform` builder + /// methods. /// /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the /// wire-level `hooks` flag is automatically enabled. /// - /// If [`transform`](SessionConfig::transform) is set, the SDK injects + /// If [`system_message_transform`](SessionConfig::system_message_transform) is set, the SDK injects /// `action: "transform"` sections into the [`SystemMessageConfig`] wire /// format and handles `systemMessage.transform` RPC callbacks during /// the session. /// - /// If [`handler`](SessionConfig::handler) is `None`, the session uses - /// [`NoopHandler`](crate::handler::NoopHandler) — permission requests and - /// external tool calls are left pending for the consumer to resolve. + /// Each per-event handler is independently optional. If a handler is + /// not installed, the SDK signals the runtime not to emit the matching + /// broadcast (and silently skips dispatch if one arrives anyway). pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); - let handler = config - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + let session_id = config + .session_id + .clone() + .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); + config.session_id = Some(session_id.clone()); + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.system_message_transform.clone() { + inject_transform_sections(&mut config, transforms.as_ref()); + } + let wire = config.to_wire(session_id.clone()); + + let permission_handler = crate::permission::resolve_handler( + config.permission_handler.take(), + config.permission_policy.take(), + ); + let elicitation_handler = config.elicitation_handler.take(); + let user_input_handler = config.user_input_handler.take(); + let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); + let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); + let mut tool_map: HashMap> = HashMap::new(); + if let Some(tools) = config.tools.as_mut() { + for tool in tools.iter_mut() { + if let Some(handler) = tool.handler.take() { + if tool_map.contains_key(&tool.name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {:?}", + tool.name + ))); + } + tool_map.insert(tool.name.clone(), handler); + } + } + } + let handlers = SessionHandlers { + permission: permission_handler, + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_map), + }; let hooks = config.hooks_handler.take(); - let transforms = config.transform.take(); + let transforms = config.system_message_transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); let commands_count = config.commands.as_ref().map_or(0, Vec::len); let has_hooks = hooks.is_some(); @@ -782,18 +844,7 @@ impl Client { )); } - if hooks.is_some() && config.hooks.is_none() { - config.hooks = Some(true); - } - if let Some(ref transforms) = transforms { - inject_transform_sections(&mut config, transforms.as_ref()); - } - let session_id = config - .session_id - .clone() - .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); - config.session_id = Some(session_id.clone()); - let mut params = serde_json::to_value(&config)?; + let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -806,7 +857,7 @@ impl Client { let event_loop = spawn_event_loop( session_id.clone(), self.clone(), - handler, + handlers, hooks, transforms, command_handlers, @@ -888,12 +939,47 @@ impl Client { /// fields are unset. pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result { let total_start = Instant::now(); - let handler = config - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + let session_id = config.session_id.clone(); + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.system_message_transform.clone() { + inject_transform_sections_resume(&mut config, transforms.as_ref()); + } + let wire = config.to_wire(); + + let permission_handler = crate::permission::resolve_handler( + config.permission_handler.take(), + config.permission_policy.take(), + ); + let elicitation_handler = config.elicitation_handler.take(); + let user_input_handler = config.user_input_handler.take(); + let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); + let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); + let mut tool_map: HashMap> = HashMap::new(); + if let Some(tools) = config.tools.as_mut() { + for tool in tools.iter_mut() { + if let Some(handler) = tool.handler.take() { + if tool_map.contains_key(&tool.name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {:?}", + tool.name + ))); + } + tool_map.insert(tool.name.clone(), handler); + } + } + } + let handlers = SessionHandlers { + permission: permission_handler, + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_map), + }; let hooks = config.hooks_handler.take(); - let transforms = config.transform.take(); + let transforms = config.system_message_transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); let commands_count = config.commands.as_ref().map_or(0, Vec::len); let has_hooks = hooks.is_some(); @@ -913,14 +999,7 @@ impl Client { )); } - if hooks.is_some() && config.hooks.is_none() { - config.hooks = Some(true); - } - if let Some(ref transforms) = transforms { - inject_transform_sections_resume(&mut config, transforms.as_ref()); - } - let session_id = config.session_id.clone(); - let mut params = serde_json::to_value(&config)?; + let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -933,7 +1012,7 @@ impl Client { let event_loop = spawn_event_loop( session_id.clone(), self.clone(), - handler, + handlers, hooks, transforms, command_handlers, @@ -1060,7 +1139,7 @@ fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc, + handlers: SessionHandlers, hooks: Option>, transforms: Option>, command_handlers: Arc, @@ -1094,12 +1173,12 @@ fn spawn_event_loop( _ = shutdown.cancelled() => break, Some(notification) = notifications.recv() => { handle_notification( - &session_id, &client, &handler, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, + &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, ).await; } Some(request) = requests.recv() => { handle_request( - &session_id, &client, &handler, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, + &session_id, &client, &handlers, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, ).await; } else => break, @@ -1124,20 +1203,20 @@ fn extract_request_id(data: &Value) -> Option { .map(RequestId::new) } -fn pending_permission_result_kind(response: &HandlerResponse) -> &'static str { - match response { - HandlerResponse::Permission(PermissionResult::Approved) => "approve-once", - HandlerResponse::Permission(PermissionResult::Denied) => "reject", +fn pending_permission_result_kind(result: &PermissionResult) -> &'static str { + match result { + PermissionResult::Approved => "approve-once", + PermissionResult::Denied => "reject", // Fallback to "user-not-available" for UserNotAvailable, Deferred (when // forced through this path), Custom (handled separately upstream), and - // any non-permission/no-result HandlerResponse that gets here defensively. + // NoResult that gets here defensively. _ => "user-not-available", } } -fn permission_request_response(response: &HandlerResponse) -> PermissionDecision { - match response { - HandlerResponse::Permission(PermissionResult::Approved) => { +fn permission_request_response(result: &PermissionResult) -> PermissionDecision { + match result { + PermissionResult::Approved => { PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce { kind: PermissionDecisionApproveOnceKind::ApproveOnce, }) @@ -1149,41 +1228,37 @@ fn permission_request_response(response: &HandlerResponse) -> PermissionDecision } } -/// Map a handler response into the `result` payload for the notification +/// Map a permission result into the `result` payload for the notification /// path (`session.permissions.handlePendingPermissionRequest`). /// /// Returns `None` when the SDK must not respond. -fn notification_permission_payload(response: &HandlerResponse) -> Option { - match response { - HandlerResponse::Permission(PermissionResult::Deferred | PermissionResult::NoResult) => { - None - } - HandlerResponse::Permission(PermissionResult::Custom(value)) => Some(value.clone()), +fn notification_permission_payload(result: &PermissionResult) -> Option { + match result { + PermissionResult::Deferred | PermissionResult::NoResult => None, + PermissionResult::Custom(value) => Some(value.clone()), _ => Some(serde_json::json!({ - "kind": pending_permission_result_kind(response), + "kind": pending_permission_result_kind(result), })), } } -/// Map a handler response into the JSON-RPC `result` payload for the +/// Map a permission result into the JSON-RPC `result` payload for the /// direct-RPC path (`permission.request`). /// /// Always returns a value. [`PermissionResult::Deferred`] is treated as /// [`PermissionResult::Approved`] here because the JSON-RPC contract /// requires a reply — see the variant's doc comment. -fn direct_permission_payload(response: &HandlerResponse) -> Value { - match response { - HandlerResponse::Permission(PermissionResult::Custom(value)) => value.clone(), - HandlerResponse::Permission(PermissionResult::Deferred) => serde_json::to_value( - permission_request_response(&HandlerResponse::Permission(PermissionResult::Approved)), - ) - .expect("serializing direct permission response should succeed"), - HandlerResponse::Permission( - PermissionResult::NoResult | PermissionResult::UserNotAvailable, - ) => serde_json::json!({ - "kind": pending_permission_result_kind(response), +fn direct_permission_payload(result: &PermissionResult) -> Value { + match result { + PermissionResult::Custom(value) => value.clone(), + PermissionResult::Deferred => { + serde_json::to_value(permission_request_response(&PermissionResult::Approved)) + .expect("serializing direct permission response should succeed") + } + PermissionResult::NoResult | PermissionResult::UserNotAvailable => serde_json::json!({ + "kind": pending_permission_result_kind(result), }), - _ => serde_json::to_value(permission_request_response(response)) + _ => serde_json::to_value(permission_request_response(result)) .expect("serializing direct permission response should succeed"), } } @@ -1200,33 +1275,12 @@ fn tool_failure_result(message: impl Into) -> ToolResult { }) } -fn notification_tool_payload(response: HandlerResponse) -> Option { - match response { - HandlerResponse::ToolResult(result) => { - Some(serde_json::to_value(result).unwrap_or(Value::Null)) - } - HandlerResponse::NoResult => None, - _ => Some( - serde_json::to_value(tool_failure_result("Unexpected handler response")) - .unwrap_or(Value::Null), - ), - } -} - -fn direct_tool_result(response: HandlerResponse) -> ToolResult { - match response { - HandlerResponse::ToolResult(result) => result, - HandlerResponse::NoResult => tool_failure_result("No tool handler available"), - _ => tool_failure_result("Unexpected handler response"), - } -} - /// Process a notification from the CLI's broadcast channel. #[allow(clippy::too_many_arguments)] async fn handle_notification( session_id: &SessionId, client: &Client, - handler: &Arc, + handlers: &SessionHandlers, command_handlers: &Arc, notification: SessionEventNotification, idle_waiter: &Arc>>, @@ -1303,14 +1357,6 @@ async fn handle_notification( // before any consumer subscribes. let _ = event_tx.send(event.clone()); - // Fire-and-forget dispatch for the general event. - handler - .on_event(HandlerEvent::SessionEvent { - session_id: session_id.clone(), - event, - }) - .await; - // Update capabilities when the CLI reports changes. The CLI sends // the full updated capabilities object — replace wholesale so removals // and new subfields are handled correctly. @@ -1335,8 +1381,25 @@ async fn handle_notification( let Some(request_id) = extract_request_id(¬ification.event.data) else { return; }; + // Honor the runtime's `resolvedByHook` signal — when the + // server has already resolved the permission via a hook, + // clients must not send a second response. + if notification + .event + .data + .get("resolvedByHook") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + return; + } + // Multi-client safety: if this client has no permission + // handler installed, don't respond — another client on the + // same CLI may handle it. + let Some(permission_handler) = handlers.permission.clone() else { + return; + }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let data: PermissionRequestData = serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| { @@ -1354,22 +1417,19 @@ async fn handle_notification( tokio::spawn( async move { let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::PermissionRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - data, - }) + let result = permission_handler + .handle(sid.clone(), request_id.clone(), data) .await; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, request_id = %request_id, - "SessionHandler::on_permission_request dispatch" + "PermissionHandler::handle dispatch" ); - let Some(result_value) = notification_permission_payload(&response) else { - // Handler returned Deferred — it will call - // handlePendingPermissionRequest itself. + let Some(result_value) = notification_permission_payload(&result) else { + // Handler returned Deferred / NoResult — it will + // call handlePendingPermissionRequest itself (or + // leave the request unanswered). return; }; let rpc_start = Instant::now(); @@ -1434,8 +1494,18 @@ async fn handle_notification( return; } }; + // Multi-client safety: look up a handler for the requested + // tool name. If this client has no handler installed for that + // tool, don't respond — another connected client may have one. + let tool_handler = if data.tool_name.is_empty() { + None + } else { + handlers.tools.get(&data.tool_name).cloned() + }; + let Some(tool_handler) = tool_handler else { + return; + }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let span = tracing::error_span!( "external_tool_handler", @@ -1482,9 +1552,10 @@ async fn handle_notification( tracestate: data.tracestate, }; let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ExternalTool { invocation }) - .await; + let tool_result = match tool_handler.call(invocation).await { + Ok(r) => r, + Err(e) => tool_failure_result(e.to_string()), + }; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, @@ -1493,9 +1564,7 @@ async fn handle_notification( tool_name = %tool_name, "ToolHandler::call dispatch" ); - let Some(result_value) = notification_tool_payload(response) else { - return; - }; + let result_value = serde_json::to_value(tool_result).unwrap_or(Value::Null); let rpc_start = Instant::now(); let _ = client .call( @@ -1522,7 +1591,7 @@ async fn handle_notification( SessionEventType::UserInputRequested => { // Notification-only signal for observers (UI, telemetry). // The CLI follows up with a `userInput.request` JSON-RPC call - // that drives `HandlerEvent::UserInput` dispatch — handling + // that drives the `UserInputHandler` dispatch — handling // the notification here too would double-fire the handler // and produce duplicate prompts on the consumer side. See // github/github-app#4249. @@ -1531,6 +1600,12 @@ async fn handle_notification( let Some(request_id) = extract_request_id(¬ification.event.data) else { return; }; + // Multi-client safety: if this client has no elicitation + // handler installed, don't respond — another client on the + // same CLI may handle it. + let Some(elicitation_handler) = handlers.elicitation.clone() else { + return; + }; let elicitation_data: ElicitationRequestedData = match serde_json::from_value(notification.event.data.clone()) { Ok(d) => d, @@ -1557,7 +1632,6 @@ async fn handle_notification( url: elicitation_data.url, }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let span = tracing::error_span!( "elicitation_request_handler", @@ -1581,26 +1655,22 @@ async fn handle_notification( ); async move { let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ElicitationRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - request, - }) + let response = elicitation_handler + .handle(sid.clone(), request_id.clone(), request) .await; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, request_id = %request_id, - "SessionHandler::on_elicitation dispatch" + "ElicitationHandler::handle dispatch" ); response } .instrument(span) }); let result = match handler_task.await { - Ok(HandlerResponse::Elicitation(r)) => r, - _ => cancel.clone(), + Ok(r) => r, + Err(_) => cancel.clone(), }; let rpc_start = Instant::now(); if let Err(e) = client @@ -1708,7 +1778,7 @@ async fn handle_notification( async fn handle_request( session_id: &SessionId, client: &Client, - handler: &Arc, + handlers: &SessionHandlers, hooks: Option<&dyn SessionHooks>, transforms: Option<&dyn SystemMessageTransform>, session_fs_provider: Option<&Arc>, @@ -1754,49 +1824,6 @@ async fn handle_request( let _ = client.send_response(&rpc_response).await; } - "tool.call" => { - let invocation: ToolInvocation = match request - .params - .as_ref() - .and_then(|p| serde_json::from_value::(p.clone()).ok()) - { - Some(inv) => inv, - None => { - let _ = send_error_response( - client, - request.id, - error_codes::INVALID_PARAMS, - "invalid tool.call params", - ) - .await; - return; - } - }; - let tool_call_id = invocation.tool_call_id.clone(); - let tool_name = invocation.tool_name.clone(); - let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ExternalTool { invocation }) - .await; - tracing::debug!( - elapsed_ms = handler_start.elapsed().as_millis(), - session_id = %sid, - tool_call_id = %tool_call_id, - tool_name = %tool_name, - "ToolHandler::call dispatch" - ); - let tool_result = direct_tool_result(response); - let rpc_response = JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::json!(ToolResultResponse { - result: tool_result - })), - error: None, - }; - let _ = client.send_response(&rpc_response).await; - } - "userInput.request" => { let params = request.params.as_ref(); let Some(question) = params @@ -1831,29 +1858,28 @@ async fn handle_request( .and_then(|v| v.as_bool()); let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::UserInput { - session_id: sid.clone(), - question, - choices, - allow_freeform, - }) - .await; + let response = if let Some(user_input_handler) = handlers.user_input.as_ref() { + user_input_handler + .handle(sid.clone(), question, choices, allow_freeform) + .await + } else { + None + }; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, - "SessionHandler::on_user_input dispatch" + "UserInputHandler::handle dispatch" ); let rpc_result = match response { - HandlerResponse::UserInput(Some(UserInputResponse { + Some(UserInputResponse { answer, was_freeform, - })) => serde_json::json!({ + }) => serde_json::json!({ "answer": answer, "wasFreeform": was_freeform, }), - _ => serde_json::json!({ "noResponse": true }), + None => serde_json::json!({ "noResponse": true }), }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -1878,17 +1904,11 @@ async fn handle_request( } }; - let response = handler - .on_event(HandlerEvent::ExitPlanMode { - session_id: sid, - data, - }) - .await; - - let rpc_result = match response { - HandlerResponse::ExitPlanMode(result) => serde_json::to_value(result) - .expect("ExitPlanModeResult serialization cannot fail"), - _ => serde_json::json!({ "approved": true }), + let rpc_result = if let Some(exit_plan_handler) = handlers.exit_plan_mode.as_ref() { + let result = exit_plan_handler.handle(sid, data).await; + serde_json::to_value(result).expect("ExitPlanModeResult serialization cannot fail") + } else { + serde_json::json!({ "approved": true }) }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -1912,17 +1932,12 @@ async fn handle_request( .and_then(|p| p.get("retryAfterSeconds")) .and_then(|v| v.as_f64()); - let response = handler - .on_event(HandlerEvent::AutoModeSwitch { - session_id: sid, - error_code, - retry_after_seconds, - }) - .await; - - let answer = match response { - HandlerResponse::AutoModeSwitch(answer) => answer, - _ => AutoModeSwitchResponse::No, + let answer = if let Some(auto_mode_handler) = handlers.auto_mode_switch.as_ref() { + auto_mode_handler + .handle(sid, error_code, retry_after_seconds) + .await + } else { + AutoModeSwitchResponse::No }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -1969,23 +1984,27 @@ async fn handle_request( }); let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::PermissionRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - data, - }) - .await; - tracing::debug!( - elapsed_ms = handler_start.elapsed().as_millis(), - session_id = %sid, - request_id = %request_id, - "SessionHandler::on_permission_request dispatch" - ); + let rpc_result = if let Some(permission_handler) = handlers.permission.as_ref() { + let result = permission_handler + .handle(sid.clone(), request_id.clone(), data) + .await; + tracing::debug!( + elapsed_ms = handler_start.elapsed().as_millis(), + session_id = %sid, + request_id = %request_id, + "PermissionHandler::handle dispatch" + ); + direct_permission_payload(&result) + } else { + // Back-compat with v2 servers that still send + // permission.request as a direct RPC: default to + // user-not-available rather than erroring. + serde_json::json!({ "kind": "user-not-available" }) + }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), id: request.id, - result: Some(direct_permission_payload(&response)), + result: Some(rpc_result), error: None, }; let _ = client.send_response(&rpc_response).await; @@ -2124,22 +2143,20 @@ mod tests { direct_permission_payload, notification_permission_payload, pending_permission_result_kind, permission_request_response, }; - use crate::handler::{HandlerResponse, PermissionResult}; + use crate::handler::PermissionResult; #[test] fn pending_permission_requests_use_decision_kinds() { assert_eq!( - pending_permission_result_kind(&HandlerResponse::Permission( - PermissionResult::Approved, - )), + pending_permission_result_kind(&PermissionResult::Approved), "approve-once" ); assert_eq!( - pending_permission_result_kind(&HandlerResponse::Permission(PermissionResult::Denied)), + pending_permission_result_kind(&PermissionResult::Denied), "reject" ); assert_eq!( - pending_permission_result_kind(&HandlerResponse::Ok), + pending_permission_result_kind(&PermissionResult::UserNotAvailable), "user-not-available" ); } @@ -2147,22 +2164,20 @@ mod tests { #[test] fn direct_permission_requests_use_decision_response_kinds() { assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Permission( - PermissionResult::Approved - ),)) - .expect("serializing approved permission response should succeed"), + serde_json::to_value(permission_request_response(&PermissionResult::Approved)) + .expect("serializing approved permission response should succeed"), json!({ "kind": "approve-once" }) ); assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Permission( - PermissionResult::Denied - ),)) - .expect("serializing denied permission response should succeed"), + serde_json::to_value(permission_request_response(&PermissionResult::Denied)) + .expect("serializing denied permission response should succeed"), json!({ "kind": "reject" }) ); assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Ok)) - .expect("serializing fallback permission response should succeed"), + serde_json::to_value(permission_request_response( + &PermissionResult::UserNotAvailable + )) + .expect("serializing fallback permission response should succeed"), json!({ "kind": "reject" }) ); } @@ -2170,18 +2185,8 @@ mod tests { #[test] fn notification_payload_handles_non_responses_and_custom() { // Deferred/NoResult -> no payload, SDK must not respond. - assert!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Deferred, - )) - .is_none() - ); - assert!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::NoResult, - )) - .is_none() - ); + assert!(notification_permission_payload(&PermissionResult::Deferred).is_none()); + assert!(notification_permission_payload(&PermissionResult::NoResult).is_none()); // Custom → handler-supplied value passed through verbatim. let custom = json!({ @@ -2189,23 +2194,17 @@ mod tests { "allowlist": ["ls", "grep"], }); assert_eq!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Custom(custom.clone()), - )), + notification_permission_payload(&PermissionResult::Custom(custom.clone())), Some(custom) ); // Approved/Denied → existing kind-only shape. assert_eq!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Approved, - )), + notification_permission_payload(&PermissionResult::Approved), Some(json!({ "kind": "approve-once" })) ); assert_eq!( - notification_permission_payload( - &HandlerResponse::Permission(PermissionResult::Denied,) - ), + notification_permission_payload(&PermissionResult::Denied), Some(json!({ "kind": "reject" })) ); } @@ -2218,31 +2217,29 @@ mod tests { "allowlist": ["ls", "grep"], }); assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Custom( - custom.clone(), - ))), + direct_permission_payload(&PermissionResult::Custom(custom.clone())), custom ); // Deferred → falls back to Approved because the direct RPC must reply. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Deferred)), + direct_permission_payload(&PermissionResult::Deferred), json!({ "kind": "approve-once" }) ); // NoResult -> direct RPC cannot be left pending, so report no user. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::NoResult)), + direct_permission_payload(&PermissionResult::NoResult), json!({ "kind": "user-not-available" }) ); // Approved/Denied → existing kind-only shape. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Approved)), + direct_permission_payload(&PermissionResult::Approved), json!({ "kind": "approve-once" }) ); assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Denied)), + direct_permission_payload(&PermissionResult::Denied), json!({ "kind": "reject" }) ); } diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs index 52c15b2eb..69886a195 100644 --- a/rust/src/subscription.rs +++ b/rust/src/subscription.rs @@ -4,10 +4,10 @@ //! [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). //! //! Each subscription is an opt-in **observer** of events that are also -//! delivered to the [`SessionHandler`](crate::handler::SessionHandler). -//! Subscribers receive a clone of every event but cannot influence -//! permission decisions, tool results, or anything else that requires -//! returning a [`HandlerResponse`](crate::handler::HandlerResponse). +//! delivered to the per-event handlers installed on the session config +//! (see [`crate::handler`]). Subscribers receive a clone of every event but +//! cannot influence permission decisions, tool results, or any other event +//! whose handler return value affects the runtime. //! //! # Async iteration //! diff --git a/rust/src/tool.rs b/rust/src/tool.rs index 3342f4b9f..95ac16d68 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -1,14 +1,19 @@ //! Typed tool definition framework. //! -//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types, -//! and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a -//! [`SessionHandler`](crate::handler::SessionHandler). +//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for +//! implementing tools as named types. Attach a handler to a +//! [`Tool`](crate::types::Tool) via +//! [`Tool::with_handler`](crate::types::Tool::with_handler), then install +//! the resulting tools on a session via +//! [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). +//! The SDK builds an internal name-keyed registry from the handlers and +//! dispatches to the matching handler when the CLI broadcasts +//! `external_tool.requested`. //! //! Enable the `derive` feature for `schema_for`, which generates JSON //! Schema from Rust types via `schemars`. use std::collections::HashMap; -use std::sync::Arc; use async_trait::async_trait; /// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas. @@ -16,11 +21,9 @@ use async_trait::async_trait; pub use schemars::JsonSchema; use crate::Error; -use crate::handler::{PermissionResult, SessionHandler, UserInputResponse}; -use crate::types::{ - ElicitationRequest, ElicitationResult, PermissionRequestData, RequestId, SessionEvent, - SessionId, Tool, ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded, -}; +#[cfg(any(feature = "derive", test))] +use crate::types::Tool; +use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; /// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type. /// @@ -172,64 +175,63 @@ pub fn convert_mcp_call_tool_result(value: &serde_json::Value) -> Option, /// } /// -/// struct GetWeatherTool; +/// struct GetWeather; /// /// #[async_trait] -/// impl ToolHandler for GetWeatherTool { -/// fn tool(&self) -> Tool { -/// Tool { -/// name: "get_weather".to_string(), -/// namespaced_name: None, -/// description: "Get weather for a city".to_string(), -/// parameters: tool_parameters(schema_for::()), -/// instructions: None, -/// ..Default::default() -/// } -/// } -/// +/// impl ToolHandler for GetWeather { /// async fn call(&self, inv: ToolInvocation) -> Result { /// let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; /// Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) /// } /// } +/// +/// // Build the Tool declaration with the handler attached: +/// let tool = Tool::new("get_weather") +/// .with_description("Get weather for a city") +/// .with_parameters(schema_for::()) +/// .with_handler(Arc::new(GetWeather)); /// ``` #[async_trait] -pub trait ToolHandler: Send + Sync { - /// The tool definition sent to the CLI during session creation. - fn tool(&self) -> Tool; - +pub trait ToolHandler: Send + Sync + 'static { /// Handle a tool invocation from the agent. async fn call(&self, invocation: ToolInvocation) -> Result; } -/// Define a tool from an async function (or closure) that takes a typed, +/// Define a [`Tool`] from an async function (or closure) that takes a typed, /// `JsonSchema`-derived parameter struct. /// -/// The returned `Box` plugs directly into -/// [`ToolHandlerRouter::new`]. JSON Schema for the parameter type is generated -/// via [`schema_for`] at construction time. +/// The returned [`Tool`] carries an attached handler ready to install on a +/// session via [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). +/// JSON Schema for the parameter type is generated via [`schema_for`] at +/// construction time. /// /// The handler bound (`Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`) /// accepts both bare `async fn` items and closures — the same shape as @@ -260,8 +262,6 @@ pub trait ToolHandler: Send + Sync { /// inv: ToolInvocation, /// params: GetWeatherParams, /// ) -> Result { -/// // `inv.session_id` and `inv.tool_call_id` are available for telemetry, -/// // streaming updates, scoping DB lookups, etc. /// let _ = inv.session_id; /// Ok(ToolResult::Text(format!("Sunny in {}", params.city))) /// } @@ -287,36 +287,24 @@ pub fn define_tool( name: impl Into, description: impl Into, handler: F, -) -> Box +) -> Tool where P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { - struct FnTool { - name: String, - description: String, - parameters: HashMap, + struct FnHandler { handler: F, _marker: std::marker::PhantomData, } #[async_trait] - impl ToolHandler for FnTool + impl ToolHandler for FnHandler where P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { - fn tool(&self) -> Tool { - Tool { - name: self.name.clone(), - description: self.description.clone(), - parameters: self.parameters.clone(), - ..Default::default() - } - } - async fn call(&self, mut invocation: ToolInvocation) -> Result { let arguments = std::mem::take(&mut invocation.arguments); let params: P = serde_json::from_value(arguments)?; @@ -324,153 +312,71 @@ where } } - Box::new(FnTool { + Tool { name: name.into(), description: description.into(), parameters: tool_parameters(schema_for::

()), + ..Default::default() + } + .with_handler(std::sync::Arc::new(FnHandler { handler, _marker: std::marker::PhantomData, - }) + })) } -/// A [`SessionHandler`] that dispatches tool calls to registered -/// [`ToolHandler`] implementations by name. +/// Define a declaration-only [`Tool`] with a JSON Schema derived from `P`. /// -/// For tool calls matching a registered handler, the handler is invoked -/// directly. All other events (permissions, user input, unrecognized tools) -/// are forwarded to the inner handler. +/// Equivalent to [`define_tool`] but produces a [`Tool`] with no attached +/// handler — useful when another connected client services this tool, or +/// when you only need to advertise the schema for capability negotiation. /// /// # Example /// /// ```rust,no_run -/// use std::sync::Arc; -/// use github_copilot_sdk::handler::ApproveAllHandler; -/// use github_copilot_sdk::tool::ToolHandlerRouter; +/// use github_copilot_sdk::tool::{define_tool_declaration, JsonSchema}; +/// use serde::Deserialize; /// -/// let router = ToolHandlerRouter::new( -/// vec![/* Box::new(MyTool), ... */], -/// Arc::new(ApproveAllHandler), -/// ); +/// #[derive(Deserialize, JsonSchema)] +/// struct Params { query: String } /// -/// // Use router.tools() in SessionConfig -/// // Use Arc::new(router) as the session handler +/// let declared = define_tool_declaration::( +/// "legacy_thing", +/// "Handled by another connected client", +/// ); +/// # let _ = declared; /// ``` -pub struct ToolHandlerRouter { - handlers: HashMap>, - inner: Arc, -} - -impl std::fmt::Debug for ToolHandlerRouter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut tools: Vec<_> = self.handlers.keys().collect(); - tools.sort(); - f.debug_struct("ToolHandlerRouter") - .field("tool_count", &self.handlers.len()) - .field("tools", &tools) - .finish() - } -} - -impl ToolHandlerRouter { - /// Create a router from tool handler impls and a fallback handler. - /// - /// Call [`tools()`](Self::tools) to get the tool definitions for - /// [`SessionConfig::tools`](crate::SessionConfig::tools). - pub fn new(tools: Vec>, inner: Arc) -> Self { - let mut handlers = HashMap::new(); - for tool in tools { - handlers.insert(tool.tool().name.clone(), tool); - } - Self { handlers, inner } - } - - /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools). - pub fn tools(&self) -> Vec { - self.handlers.values().map(|h| h.tool()).collect() - } -} - -#[async_trait] -impl SessionHandler for ToolHandlerRouter { - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - let Some(handler) = self.handlers.get(&invocation.tool_name) else { - return self.inner.on_external_tool(invocation).await; - }; - match handler.call(invocation).await { - Ok(result) => result, - Err(e) => { - let msg = e.to_string(); - ToolResult::Expanded(ToolResultExpanded { - text_result_for_llm: msg.clone(), - result_type: "failure".to_string(), - binary_results_for_llm: None, - session_log: None, - error: Some(msg), - tool_telemetry: None, - }) - } - } - } - - async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) { - self.inner.on_session_event(session_id, event).await - } - - async fn on_permission_request( - &self, - session_id: SessionId, - request_id: RequestId, - data: PermissionRequestData, - ) -> PermissionResult { - self.inner - .on_permission_request(session_id, request_id, data) - .await - } - - async fn on_user_input( - &self, - session_id: SessionId, - question: String, - choices: Option>, - allow_freeform: Option, - ) -> Option { - self.inner - .on_user_input(session_id, question, choices, allow_freeform) - .await - } - - async fn on_elicitation( - &self, - session_id: SessionId, - request_id: RequestId, - request: ElicitationRequest, - ) -> ElicitationResult { - self.inner - .on_elicitation(session_id, request_id, request) - .await +#[cfg(feature = "derive")] +pub fn define_tool_declaration

(name: impl Into, description: impl Into) -> Tool +where + P: schemars::JsonSchema, +{ + Tool { + name: name.into(), + description: description.into(), + parameters: tool_parameters(schema_for::

()), + ..Default::default() } } #[cfg(test)] mod tests { use super::*; - use crate::types::{PermissionRequestData, RequestId, SessionId}; + use crate::types::SessionId; struct EchoTool; - #[async_trait] - impl ToolHandler for EchoTool { - fn tool(&self) -> Tool { - Tool { - name: "echo".to_string(), - namespaced_name: None, - description: "Echo the input".to_string(), - parameters: tool_parameters(serde_json::json!({"type": "object"})), - instructions: None, - ..Default::default() - } + fn echo_tool() -> Tool { + Tool { + name: "echo".to_string(), + description: "Echo the input".to_string(), + parameters: tool_parameters(serde_json::json!({"type": "object"})), + ..Default::default() } + .with_handler(std::sync::Arc::new(EchoTool)) + } + #[async_trait] + impl ToolHandler for EchoTool { async fn call(&self, inv: ToolInvocation) -> Result { Ok(ToolResult::Text(inv.arguments.to_string())) } @@ -478,11 +384,11 @@ mod tests { #[test] fn tool_handler_returns_tool_definition() { - let tool = EchoTool; - let def = tool.tool(); + let def = echo_tool(); assert_eq!(def.name, "echo"); assert_eq!(def.description, "Echo the input"); assert!(def.parameters.contains_key("type")); + assert!(def.handler.is_some()); } #[test] @@ -685,11 +591,11 @@ mod tests { }, ); - let def = tool.tool(); - assert_eq!(def.name, "weather"); - assert_eq!(def.description, "Get the weather for a city"); - assert_eq!(def.parameters["type"], "object"); - assert!(def.parameters["properties"]["city"].is_object()); + assert_eq!(tool.name, "weather"); + assert_eq!(tool.description, "Get the weather for a city"); + assert_eq!(tool.parameters["type"], "object"); + assert!(tool.parameters["properties"]["city"].is_object()); + let handler = tool.handler.as_ref().expect("define_tool attaches handler"); let inv = ToolInvocation { session_id: SessionId::from("s1"), @@ -699,239 +605,12 @@ mod tests { traceparent: None, tracestate: None, }; - match tool.call(inv).await.unwrap() { + match handler.call(inv).await.unwrap() { ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"), _ => panic!("expected Text result"), } } - #[tokio::test] - async fn router_dispatches_to_correct_handler() { - struct ToolA; - #[async_trait] - impl ToolHandler for ToolA { - fn tool(&self) -> Tool { - Tool { - name: "tool_a".to_string(), - namespaced_name: None, - description: "A".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("a_result".to_string())) - } - } - - struct ToolB; - #[async_trait] - impl ToolHandler for ToolB { - fn tool(&self) -> Tool { - Tool { - name: "tool_b".to_string(), - namespaced_name: None, - description: "B".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("b_result".to_string())) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(ToolA), Box::new(ToolB)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let tools = router.tools(); - assert_eq!(tools.len(), 2); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "tool_b".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - match response { - ToolResult::Text(s) => assert_eq!(s, "b_result"), - _ => panic!("expected ToolResult::Text"), - } - } - - #[tokio::test] - async fn router_falls_through_for_unknown_tool() { - use std::sync::atomic::{AtomicBool, Ordering}; - - struct FallbackHandler { - called: AtomicBool, - } - #[async_trait] - impl SessionHandler for FallbackHandler { - async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult { - self.called.store(true, Ordering::Relaxed); - ToolResult::Text("fallback".to_string()) - } - } - - let fallback = Arc::new(FallbackHandler { - called: AtomicBool::new(false), - }); - let router = ToolHandlerRouter::new(vec![], fallback.clone()); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "unknown".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - assert!(fallback.called.load(Ordering::Relaxed)); - match response { - ToolResult::Text(s) => assert_eq!(s, "fallback"), - _ => panic!("expected fallback result"), - } - } - - #[tokio::test] - async fn router_returns_failure_on_handler_error() { - struct FailTool; - #[async_trait] - impl ToolHandler for FailTool { - fn tool(&self) -> Tool { - Tool { - name: "bad_tool".to_string(), - namespaced_name: None, - description: "Always fails".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Err(Error::Rpc { - code: -1, - message: "intentional failure".to_string(), - }) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(FailTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "bad_tool".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - match response { - ToolResult::Expanded(exp) => { - assert_eq!(exp.result_type, "failure"); - assert!(exp.error.unwrap().contains("intentional failure")); - } - _ => panic!("expected expanded failure result"), - } - } - - #[tokio::test] - async fn router_forwards_non_tool_events() { - struct PermHandler; - #[async_trait] - impl SessionHandler for PermHandler { - async fn on_permission_request( - &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Denied - } - } - - let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler)); - - let response = router - .on_permission_request( - SessionId::from("s1"), - RequestId::new("r1"), - PermissionRequestData { - extra: serde_json::json!({}), - ..Default::default() - }, - ) - .await; - assert!(matches!(response, PermissionResult::Denied)); - } - - #[tokio::test] - async fn router_default_on_event_dispatches_via_per_event_methods() { - // Regression: callers using the legacy on_event entry point should - // still get correct dispatch through the inherited default impl. - use crate::handler::{HandlerEvent, HandlerResponse}; - - struct OkTool; - #[async_trait] - impl ToolHandler for OkTool { - fn tool(&self) -> Tool { - Tool { - name: "ok_tool".to_string(), - namespaced_name: None, - description: "ok".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("ok".to_string())) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(OkTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let response = router - .on_event(HandlerEvent::ExternalTool { - invocation: ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "ok_tool".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }, - }) - .await; - match response { - HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"), - _ => panic!("expected ToolResult via default on_event"), - } - } - // Tests requiring `schemars` (the `derive` feature). #[cfg(feature = "derive")] mod derive_tests { @@ -965,19 +644,18 @@ mod tests { struct GetWeatherTool; - #[async_trait] - impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - Tool { - name: "get_weather".to_string(), - namespaced_name: None, - description: "Get weather for a city".to_string(), - parameters: tool_parameters(schema_for::()), - instructions: None, - ..Default::default() - } + fn get_weather_tool() -> Tool { + Tool { + name: "get_weather".to_string(), + description: "Get weather for a city".to_string(), + parameters: tool_parameters(schema_for::()), + ..Default::default() } + .with_handler(std::sync::Arc::new(GetWeatherTool)) + } + #[async_trait] + impl ToolHandler for GetWeatherTool { async fn call(&self, inv: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; Ok(ToolResult::Text(format!( @@ -990,12 +668,12 @@ mod tests { #[test] fn tool_handler_with_schema_for() { - let tool = GetWeatherTool; - let def = tool.tool(); + let def = get_weather_tool(); assert_eq!(def.name, "get_weather"); let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters"); assert_eq!(schema["type"], "object"); assert!(schema["properties"]["city"].is_object()); + assert!(def.handler.is_some()); } #[tokio::test] @@ -1034,18 +712,14 @@ mod tests { } #[tokio::test] - async fn router_with_schema_for_tools() { - let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let tools = router.tools(); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0].name, "get_weather"); + async fn schema_for_derived_tool_round_trips_through_call() { + let tool = GetWeatherTool; - let response = router - .on_external_tool(ToolInvocation { + // Calling the tool with matching arguments returns the + // expected typed result. (Per-name dispatch is the SDK's + // concern; here we exercise just the handler contract.) + let result = tool + .call(ToolInvocation { session_id: SessionId::from("s1"), tool_call_id: "tc1".to_string(), tool_name: "get_weather".to_string(), @@ -1053,8 +727,9 @@ mod tests { traceparent: None, tracestate: None, }) - .await; - match response { + .await + .expect("ToolHandler::call should succeed for matching args"); + match result { ToolResult::Text(s) => assert!(s.contains("Portland")), _ => panic!("expected ToolResult::Text"), } diff --git a/rust/src/types.rs b/rust/src/types.rs index 70f0c16b7..4637131fa 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -12,7 +12,10 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::handler::SessionHandler; +use crate::handler::{ + AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler, + UserInputHandler, +}; use crate::hooks::SessionHooks; pub use crate::session_fs::{ DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig, @@ -22,17 +25,12 @@ pub use crate::session_fs::{ pub use crate::trace_context::{TraceContext, TraceContextProvider}; use crate::transforms::SystemMessageTransform; -/// Lifecycle state of a [`Client`](crate::Client) connection to the CLI. -/// -/// The state advances from `Connecting` → `Connected` during construction, -/// transitions to `Disconnected` after [`Client::stop`](crate::Client::stop) or -/// [`Client::force_stop`](crate::Client::force_stop), and lands in -/// `Error` if startup fails or the underlying transport tears down -/// unexpectedly. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] +/// Lifecycle state of a [`Client`](crate::Client) connection. Internal — +/// not part of the public API. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[allow(dead_code)] #[non_exhaustive] -pub enum ConnectionState { +pub(crate) enum ConnectionState { /// No CLI process is attached or the process has exited cleanly. Disconnected, /// The client is starting up (spawning the CLI, negotiating protocol). @@ -298,7 +296,14 @@ impl PartialEq<&str> for RequestId { /// (rather than using the schema-generated form) so it can carry runtime /// hints — `overrides_built_in_tool`, `skip_permission` — that don't appear /// in the wire schema but are honored by the CLI. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// +/// A `Tool` may optionally carry a [`handler`](Self::handler): an +/// `Arc` that implements the tool's runtime behavior. +/// When present, the SDK dispatches matching `external_tool.requested` +/// broadcasts to it automatically. When absent (`None`), the tool is +/// declaration-only — another connected client must service incoming +/// invocations. +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct Tool { @@ -327,6 +332,14 @@ pub struct Tool { /// access control. #[serde(default, skip_serializing_if = "is_false")] pub skip_permission: bool, + /// Optional runtime implementation. When `Some`, the SDK dispatches + /// matching `external_tool.requested` broadcasts to this handler. + /// When `None`, the tool is declaration-only. + /// + /// Skipped during serialization — the handler is runtime behavior, + /// not part of the wire representation. + #[serde(skip)] + pub handler: Option>, } #[inline] @@ -410,6 +423,32 @@ impl Tool { self.skip_permission = skip; self } + + /// Attach a runtime implementation. The SDK will dispatch matching + /// `external_tool.requested` broadcasts to `handler` for this tool's + /// name. Without a handler the tool is declaration-only. + pub fn with_handler(mut self, handler: Arc) -> Self { + self.handler = Some(handler); + self + } +} + +impl std::fmt::Debug for Tool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tool") + .field("name", &self.name) + .field("namespaced_name", &self.namespaced_name) + .field("description", &self.description) + .field("instructions", &self.instructions) + .field("parameters", &self.parameters) + .field("overrides_built_in_tool", &self.overrides_built_in_tool) + .field("skip_permission", &self.skip_permission) + .field( + "handler", + &self.handler.as_ref().map(|_| "").unwrap_or("None"), + ) + .finish() + } } /// Context passed to a [`CommandHandler`] when a registered slash command @@ -736,7 +775,7 @@ impl CloudSessionOptions { /// servers.insert( /// "playwright".to_string(), /// McpServerConfig::Stdio(McpStdioServerConfig { -/// tools: vec!["*".to_string()], +/// tools: Some(vec!["*".to_string()]), /// command: "npx".to_string(), /// args: vec!["-y".to_string(), "@playwright/mcp".to_string()], /// ..Default::default() @@ -745,7 +784,7 @@ impl CloudSessionOptions { /// servers.insert( /// "weather".to_string(), /// McpServerConfig::Http(McpHttpServerConfig { -/// tools: vec!["forecast".to_string()], +/// tools: Some(vec!["forecast".to_string()]), /// url: "https://example.com/mcp".to_string(), /// ..Default::default() /// }), @@ -772,9 +811,13 @@ pub enum McpServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpStdioServerConfig { - /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. - #[serde(default)] - pub tools: Vec, + /// Tools to expose from this server. + /// + /// - `None` (field omitted on the wire) — expose **all** tools. + /// - `Some(vec![])` — expose **no** tools. + /// - `Some(vec!["a", ...])` — expose only the listed tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, /// Optional timeout in milliseconds for tool calls to this server. #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, @@ -798,9 +841,13 @@ pub struct McpStdioServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpHttpServerConfig { - /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. - #[serde(default)] - pub tools: Vec, + /// Tools to expose from this server. + /// + /// - `None` (field omitted on the wire) — expose **all** tools. + /// - `Some(vec![])` — expose **no** tools. + /// - `Some(vec!["a", ...])` — expose only the listed tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, /// Optional timeout in milliseconds for tool calls to this server. #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, @@ -1053,32 +1100,6 @@ pub struct SessionConfig { /// When true, the CLI runs config discovery (MCP config files, skills, plugins). #[serde(skip_serializing_if = "Option::is_none")] pub enable_config_discovery: Option, - /// Enable the `ask_user` tool for interactive user input. Defaults to - /// `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_user_input: Option, - /// Enable `permission.request` JSON-RPC calls from the CLI. Defaults - /// to `Some(true)` via [`SessionConfig::default`]; the default - /// [`NoopHandler`](crate::handler::NoopHandler) leaves requests pending - /// for the consumer to resolve. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_permission: Option, - /// Enable `exitPlanMode.request` JSON-RPC calls for plan approval. - /// Defaults to `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_exit_plan_mode: Option, - /// Enable `autoModeSwitch.request` JSON-RPC calls. When `true`, the CLI - /// asks the handler whether to switch to auto model when an eligible - /// rate limit is hit. Defaults to `Some(true)` via - /// [`SessionConfig::default`]. Without this flag, the CLI surfaces the - /// rate-limit error directly without offering the auto-mode switch. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_auto_mode_switch: Option, - /// Advertise elicitation provider capability. When true, the CLI sends - /// `elicitation.requested` events that the handler can respond to. - /// Defaults to `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_elicitation: Option, /// Skill directory paths passed through to the GitHub Copilot CLI. #[serde(skip_serializing_if = "Option::is_none")] pub skill_directories: Option>, @@ -1171,23 +1192,44 @@ pub struct SessionConfig { /// See [`SessionFsProvider`]. #[serde(skip)] pub session_fs_provider: Option>, - /// Session-level event handler. The default is - /// [`NoopHandler`](crate::handler::NoopHandler) — permission requests - /// and external tool calls are left pending for the consumer to resolve. - /// Use [`with_handler`](Self::with_handler) to install a custom handler. + /// Optional permission-request handler. When `None`, the SDK sends + /// `requestPermission: false` on the wire so the runtime does not + /// emit `permission.requested` broadcasts to this client. #[serde(skip)] - pub handler: Option>, + pub permission_handler: Option>, + /// Optional elicitation-request handler. When `None`, + /// `requestElicitation: false` goes on the wire. + #[serde(skip)] + pub elicitation_handler: Option>, + /// Optional user-input handler. When `None`, + /// `requestUserInput: false` goes on the wire and the `ask_user` + /// tool is disabled. + #[serde(skip)] + pub user_input_handler: Option>, + /// Optional exit-plan-mode handler. When `None`, + /// `requestExitPlanMode: false` goes on the wire. + #[serde(skip)] + pub exit_plan_mode_handler: Option>, + /// Optional auto-mode-switch handler. When `None`, + /// `requestAutoModeSwitch: false` goes on the wire. + #[serde(skip)] + pub auto_mode_switch_handler: Option>, /// Session lifecycle hook handler (pre/post tool use, session /// start/end, etc.). When set, the SDK auto-enables the wire-level /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one. #[serde(skip)] pub hooks_handler: Option>, + /// Permission policy applied to the handler. Stored separately from + /// `permission_handler` so the order of `with_permission_handler` and + /// `approve_all_permissions` (and friends) is irrelevant. + #[serde(skip)] + pub(crate) permission_policy: Option, /// System-message transform. When set, the SDK injects the matching /// `action: "transform"` sections into the system message and routes /// `systemMessage.transform` RPC callbacks to it during the session. - /// Use [`with_transform`](Self::with_transform) to install one. + /// Use [`with_system_message_transform`](Self::with_system_message_transform) to install one. #[serde(skip)] - pub transform: Option>, + pub system_message_transform: Option>, } impl std::fmt::Debug for SessionConfig { @@ -1204,11 +1246,6 @@ impl std::fmt::Debug for SessionConfig { .field("excluded_tools", &self.excluded_tools) .field("mcp_servers", &self.mcp_servers) .field("enable_config_discovery", &self.enable_config_discovery) - .field("request_user_input", &self.request_user_input) - .field("request_permission", &self.request_permission) - .field("request_exit_plan_mode", &self.request_exit_plan_mode) - .field("request_auto_mode_switch", &self.request_auto_mode_switch) - .field("request_elicitation", &self.request_elicitation) .field("skill_directories", &self.skill_directories) .field("instruction_directories", &self.instruction_directories) .field("disabled_skills", &self.disabled_skills) @@ -1237,22 +1274,44 @@ impl std::fmt::Debug for SessionConfig { "session_fs_provider", &self.session_fs_provider.as_ref().map(|_| ""), ) - .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "permission_handler", + &self.permission_handler.as_ref().map(|_| ""), + ) + .field( + "elicitation_handler", + &self.elicitation_handler.as_ref().map(|_| ""), + ) + .field( + "user_input_handler", + &self.user_input_handler.as_ref().map(|_| ""), + ) + .field( + "exit_plan_mode_handler", + &self.exit_plan_mode_handler.as_ref().map(|_| ""), + ) + .field( + "auto_mode_switch_handler", + &self.auto_mode_switch_handler.as_ref().map(|_| ""), + ) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), ) - .field("transform", &self.transform.as_ref().map(|_| "")) + .field( + "system_message_transform", + &self.system_message_transform.as_ref().map(|_| ""), + ) .finish() } } impl Default for SessionConfig { - /// Permission and elicitation flows are enabled by default. When no handler - /// is provided, the SDK installs `NoopHandler`, so permission and external - /// tool requests remain pending until the consumer responds out-of-band. - /// Callers that want the wire surface fully disabled set these explicitly - /// to `Some(false)`. + /// All wire-level "request" flags and handler fields start unset. + /// Install a [`PermissionHandler`] via + /// [`with_permission_handler`](Self::with_permission_handler) and + /// the SDK derives `requestPermission: true` on the wire at + /// [`Client::create_session`](crate::Client::create_session) time. fn default() -> Self { Self { session_id: None, @@ -1267,11 +1326,6 @@ impl Default for SessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: Some(true), - request_permission: Some(true), - request_exit_plan_mode: Some(true), - request_auto_mode_switch: Some(true), - request_elicitation: Some(true), skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -1291,17 +1345,105 @@ impl Default for SessionConfig { include_sub_agent_streaming_events: None, commands: None, session_fs_provider: None, - handler: None, + permission_handler: None, + elicitation_handler: None, + user_input_handler: None, + exit_plan_mode_handler: None, + auto_mode_switch_handler: None, hooks_handler: None, - transform: None, + permission_policy: None, + system_message_transform: None, } } } impl SessionConfig { - /// Install a custom [`SessionHandler`] for this session. - pub fn with_handler(mut self, handler: Arc) -> Self { - self.handler = Some(handler); + /// Build the [`SessionCreateWire`] payload for `session.create` from + /// this config. Derives the request_* wire flags from handler + /// presence and the policy field; clones plain fields. + pub(crate) fn to_wire(&self, session_id: SessionId) -> crate::wire::SessionCreateWire { + let permission_active = + self.permission_handler.is_some() || self.permission_policy.is_some(); + crate::wire::SessionCreateWire { + session_id, + model: self.model.clone(), + client_name: self.client_name.clone(), + reasoning_effort: self.reasoning_effort.clone(), + streaming: self.streaming, + system_message: self.system_message.clone(), + tools: self.tools.clone(), + available_tools: self.available_tools.clone(), + excluded_tools: self.excluded_tools.clone(), + mcp_servers: self.mcp_servers.clone(), + env_value_mode: "direct", + enable_config_discovery: self.enable_config_discovery, + request_user_input: self.user_input_handler.is_some(), + request_permission: permission_active, + request_exit_plan_mode: self.exit_plan_mode_handler.is_some(), + request_auto_mode_switch: self.auto_mode_switch_handler.is_some(), + request_elicitation: self.elicitation_handler.is_some(), + hooks: self.hooks_handler.is_some(), + skill_directories: self.skill_directories.clone(), + instruction_directories: self.instruction_directories.clone(), + disabled_skills: self.disabled_skills.clone(), + custom_agents: self.custom_agents.clone(), + default_agent: self.default_agent.clone(), + agent: self.agent.clone(), + infinite_sessions: self.infinite_sessions.clone(), + provider: self.provider.clone(), + enable_session_telemetry: self.enable_session_telemetry, + model_capabilities: self.model_capabilities.clone(), + config_dir: self.config_dir.clone(), + working_directory: self.working_directory.clone(), + github_token: self.github_token.clone(), + remote_session: self.remote_session.clone(), + cloud: self.cloud.clone(), + include_sub_agent_streaming_events: self.include_sub_agent_streaming_events, + commands: self.commands.as_ref().map(|cmds| { + cmds.iter() + .map(|c| crate::wire::CommandWireDefinition { + name: c.name.clone(), + description: c.description.clone(), + }) + .collect() + }), + } + } + + /// Install a [`PermissionHandler`] for this session. When omitted, the + /// SDK sends `requestPermission: false` on the wire and the runtime + /// short-circuits permission prompts for this client. + pub fn with_permission_handler(mut self, handler: Arc) -> Self { + self.permission_handler = Some(handler); + self + } + + /// Install an [`ElicitationHandler`]. When omitted, the SDK sends + /// `requestElicitation: false` on the wire. + pub fn with_elicitation_handler(mut self, handler: Arc) -> Self { + self.elicitation_handler = Some(handler); + self + } + + /// Install a [`UserInputHandler`]. Required for the `ask_user` tool + /// to be enabled. + pub fn with_user_input_handler(mut self, handler: Arc) -> Self { + self.user_input_handler = Some(handler); + self + } + + /// Install an [`ExitPlanModeHandler`]. + pub fn with_exit_plan_mode_handler(mut self, handler: Arc) -> Self { + self.exit_plan_mode_handler = Some(handler); + self + } + + /// Install an [`AutoModeSwitchHandler`]. + pub fn with_auto_mode_switch_handler( + mut self, + handler: Arc, + ) -> Self { + self.auto_mode_switch_handler = Some(handler); self } @@ -1332,59 +1474,40 @@ impl SessionConfig { /// Install a [`SystemMessageTransform`]. The SDK injects the matching /// `action: "transform"` sections into the system message and routes /// `systemMessage.transform` RPC callbacks to it during the session. - pub fn with_transform(mut self, transform: Arc) -> Self { - self.transform = Some(transform); + pub fn with_system_message_transform( + mut self, + transform: Arc, + ) -> Self { + self.system_message_transform = Some(transform); self } - /// Wrap the configured handler so every permission request is - /// auto-approved. Forwards every non-permission event to the inner - /// handler unchanged. - /// - /// If no handler has been installed via [`with_handler`](Self::with_handler), - /// wraps a [`NoopHandler`](crate::handler::NoopHandler), so declaration-only - /// tools remain pending for manual resolution. - /// - /// Order-independent: `with_handler(...).approve_all_permissions()` and - /// `approve_all_permissions().with_handler(...)` are NOT equivalent — - /// the second form discards the wrap because `with_handler` overwrites - /// the handler field. Always call `approve_all_permissions` *after* - /// `with_handler`. + /// Auto-approve every permission request on this session. Stored as a + /// policy that's applied at + /// [`Client::create_session`](crate::Client::create_session) time, so + /// order with [`with_permission_handler`](Self::with_permission_handler) + /// is irrelevant. pub fn approve_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_all(inner)); + self.permission_policy = Some(crate::permission::Policy::ApproveAll); self } - /// Wrap the configured handler so every permission request is - /// auto-denied. See [`approve_all_permissions`](Self::approve_all_permissions) - /// for ordering and default-handler semantics. + /// Auto-deny every permission request on this session. See + /// [`approve_all_permissions`](Self::approve_all_permissions). pub fn deny_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::deny_all(inner)); + self.permission_policy = Some(crate::permission::Policy::DenyAll); self } - /// Wrap the configured handler with a closure-based permission policy: - /// `predicate` is called for each permission request; `true` approves, - /// `false` denies. See + /// Apply a closure-based permission policy: `predicate` returns `true` + /// to approve, `false` to deny. See /// [`approve_all_permissions`](Self::approve_all_permissions) for - /// ordering and default-handler semantics. + /// ordering semantics. pub fn approve_permissions_if(mut self, predicate: F) -> Self where F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_if(inner, predicate)); + self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate))); self } @@ -1462,36 +1585,6 @@ impl SessionConfig { self } - /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::default`]. - pub fn with_request_user_input(mut self, enable: bool) -> Self { - self.request_user_input = Some(enable); - self - } - - /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_permission(mut self, enable: bool) -> Self { - self.request_permission = Some(enable); - self - } - - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { - self.request_exit_plan_mode = Some(enable); - self - } - - /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { - self.request_auto_mode_switch = Some(enable); - self - } - - /// Advertise elicitation provider capability. Defaults to `Some(true)`. - pub fn with_request_elicitation(mut self, enable: bool) -> Self { - self.request_elicitation = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -1663,24 +1756,6 @@ pub struct ResumeSessionConfig { /// Enable config discovery on resume. #[serde(skip_serializing_if = "Option::is_none")] pub enable_config_discovery: Option, - /// Enable the ask_user tool. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_user_input: Option, - /// Enable permission request RPCs. When no handler is set, permission requests - /// remain pending until the consumer responds out-of-band. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_permission: Option, - /// Enable exit-plan-mode request RPCs. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_exit_plan_mode: Option, - /// Enable auto-mode-switch request RPCs on resume. Defaults to - /// `Some(true)` via [`ResumeSessionConfig::new`]. See - /// [`SessionConfig::request_auto_mode_switch`] for details. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_auto_mode_switch: Option, - /// Advertise elicitation provider capability on resume. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_elicitation: Option, /// Skill directory paths passed through to the GitHub Copilot CLI on resume. #[serde(skip_serializing_if = "Option::is_none")] pub skill_directories: Option>, @@ -1750,9 +1825,9 @@ pub struct ResumeSessionConfig { #[serde(skip)] pub session_fs_provider: Option>, /// Force-fail resume if the session does not exist on disk, instead of - /// silently starting a new session. - #[serde(skip_serializing_if = "Option::is_none")] - pub disable_resume: Option, + /// silently starting a new session. Wire field name stays `disableResume`. + #[serde(rename = "disableResume", skip_serializing_if = "Option::is_none")] + pub suppress_resume_event: Option, /// When `true`, instructs the runtime to continue any tool calls or /// permission requests that were pending when the previous connection /// was dropped. Use this together with [`Client::force_stop`] to hand @@ -1762,15 +1837,35 @@ pub struct ResumeSessionConfig { /// [`Client::force_stop`]: crate::Client::force_stop #[serde(skip_serializing_if = "Option::is_none")] pub continue_pending_work: Option, - /// Session-level event handler. See [`SessionConfig::handler`]. + /// Optional permission-request handler. See + /// [`SessionConfig::permission_handler`]. + #[serde(skip)] + pub permission_handler: Option>, + /// Optional elicitation handler. See + /// [`SessionConfig::elicitation_handler`]. #[serde(skip)] - pub handler: Option>, + pub elicitation_handler: Option>, + /// Optional user-input handler. See + /// [`SessionConfig::user_input_handler`]. + #[serde(skip)] + pub user_input_handler: Option>, + /// Optional exit-plan-mode handler. See + /// [`SessionConfig::exit_plan_mode_handler`]. + #[serde(skip)] + pub exit_plan_mode_handler: Option>, + /// Optional auto-mode-switch handler. See + /// [`SessionConfig::auto_mode_switch_handler`]. + #[serde(skip)] + pub auto_mode_switch_handler: Option>, /// Session hook handler. See [`SessionConfig::hooks_handler`]. #[serde(skip)] pub hooks_handler: Option>, - /// System-message transform. See [`SessionConfig::transform`]. + /// Permission policy. See `SessionConfig::permission_policy`. #[serde(skip)] - pub transform: Option>, + pub(crate) permission_policy: Option, + /// System-message transform. See [`SessionConfig::system_message_transform`]. + #[serde(skip)] + pub system_message_transform: Option>, } impl std::fmt::Debug for ResumeSessionConfig { @@ -1786,11 +1881,6 @@ impl std::fmt::Debug for ResumeSessionConfig { .field("excluded_tools", &self.excluded_tools) .field("mcp_servers", &self.mcp_servers) .field("enable_config_discovery", &self.enable_config_discovery) - .field("request_user_input", &self.request_user_input) - .field("request_permission", &self.request_permission) - .field("request_exit_plan_mode", &self.request_exit_plan_mode) - .field("request_auto_mode_switch", &self.request_auto_mode_switch) - .field("request_elicitation", &self.request_elicitation) .field("skill_directories", &self.skill_directories) .field("instruction_directories", &self.instruction_directories) .field("disabled_skills", &self.disabled_skills) @@ -1818,19 +1908,93 @@ impl std::fmt::Debug for ResumeSessionConfig { "session_fs_provider", &self.session_fs_provider.as_ref().map(|_| ""), ) - .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "permission_handler", + &self.permission_handler.as_ref().map(|_| ""), + ) + .field( + "elicitation_handler", + &self.elicitation_handler.as_ref().map(|_| ""), + ) + .field( + "user_input_handler", + &self.user_input_handler.as_ref().map(|_| ""), + ) + .field( + "exit_plan_mode_handler", + &self.exit_plan_mode_handler.as_ref().map(|_| ""), + ) + .field( + "auto_mode_switch_handler", + &self.auto_mode_switch_handler.as_ref().map(|_| ""), + ) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), ) - .field("transform", &self.transform.as_ref().map(|_| "")) - .field("disable_resume", &self.disable_resume) + .field( + "system_message_transform", + &self.system_message_transform.as_ref().map(|_| ""), + ) + .field("suppress_resume_event", &self.suppress_resume_event) .field("continue_pending_work", &self.continue_pending_work) .finish() } } impl ResumeSessionConfig { + /// Build the [`SessionResumeWire`] payload for `session.resume` from + /// this config. Derives the request_* wire flags from handler + /// presence and the policy field; clones plain fields. + pub(crate) fn to_wire(&self) -> crate::wire::SessionResumeWire { + let permission_active = + self.permission_handler.is_some() || self.permission_policy.is_some(); + crate::wire::SessionResumeWire { + session_id: self.session_id.clone(), + client_name: self.client_name.clone(), + reasoning_effort: self.reasoning_effort.clone(), + streaming: self.streaming, + system_message: self.system_message.clone(), + tools: self.tools.clone(), + available_tools: self.available_tools.clone(), + excluded_tools: self.excluded_tools.clone(), + mcp_servers: self.mcp_servers.clone(), + env_value_mode: "direct", + enable_config_discovery: self.enable_config_discovery, + request_user_input: self.user_input_handler.is_some(), + request_permission: permission_active, + request_exit_plan_mode: self.exit_plan_mode_handler.is_some(), + request_auto_mode_switch: self.auto_mode_switch_handler.is_some(), + request_elicitation: self.elicitation_handler.is_some(), + hooks: self.hooks_handler.is_some(), + skill_directories: self.skill_directories.clone(), + instruction_directories: self.instruction_directories.clone(), + disabled_skills: self.disabled_skills.clone(), + custom_agents: self.custom_agents.clone(), + default_agent: self.default_agent.clone(), + agent: self.agent.clone(), + infinite_sessions: self.infinite_sessions.clone(), + provider: self.provider.clone(), + enable_session_telemetry: self.enable_session_telemetry, + model_capabilities: self.model_capabilities.clone(), + config_dir: self.config_dir.clone(), + working_directory: self.working_directory.clone(), + github_token: self.github_token.clone(), + remote_session: self.remote_session.clone(), + include_sub_agent_streaming_events: self.include_sub_agent_streaming_events, + commands: self.commands.as_ref().map(|cmds| { + cmds.iter() + .map(|c| crate::wire::CommandWireDefinition { + name: c.name.clone(), + description: c.description.clone(), + }) + .collect() + }), + suppress_resume_event: self.suppress_resume_event, + continue_pending_work: self.continue_pending_work, + } + } + /// Construct a `ResumeSessionConfig` with the given session ID and all /// other fields left unset. Combine with `.with_*` builders or struct /// update syntax (`..ResumeSessionConfig::new(id)`) to populate the @@ -1848,11 +2012,6 @@ impl ResumeSessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: Some(true), - request_permission: Some(true), - request_exit_plan_mode: Some(true), - request_auto_mode_switch: Some(true), - request_elicitation: Some(true), skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -1871,17 +2030,49 @@ impl ResumeSessionConfig { include_sub_agent_streaming_events: None, commands: None, session_fs_provider: None, - disable_resume: None, + suppress_resume_event: None, continue_pending_work: None, - handler: None, + permission_handler: None, + elicitation_handler: None, + user_input_handler: None, + exit_plan_mode_handler: None, + auto_mode_switch_handler: None, hooks_handler: None, - transform: None, + permission_policy: None, + system_message_transform: None, } } - /// Install a custom [`SessionHandler`] for this session. - pub fn with_handler(mut self, handler: Arc) -> Self { - self.handler = Some(handler); + /// Install a [`PermissionHandler`] for the resumed session. + pub fn with_permission_handler(mut self, handler: Arc) -> Self { + self.permission_handler = Some(handler); + self + } + + /// Install an [`ElicitationHandler`] for the resumed session. + pub fn with_elicitation_handler(mut self, handler: Arc) -> Self { + self.elicitation_handler = Some(handler); + self + } + + /// Install a [`UserInputHandler`] for the resumed session. + pub fn with_user_input_handler(mut self, handler: Arc) -> Self { + self.user_input_handler = Some(handler); + self + } + + /// Install an [`ExitPlanModeHandler`] for the resumed session. + pub fn with_exit_plan_mode_handler(mut self, handler: Arc) -> Self { + self.exit_plan_mode_handler = Some(handler); + self + } + + /// Install an [`AutoModeSwitchHandler`] for the resumed session. + pub fn with_auto_mode_switch_handler( + mut self, + handler: Arc, + ) -> Self { + self.auto_mode_switch_handler = Some(handler); self } @@ -1893,8 +2084,11 @@ impl ResumeSessionConfig { } /// Install a [`SystemMessageTransform`]. - pub fn with_transform(mut self, transform: Arc) -> Self { - self.transform = Some(transform); + pub fn with_system_message_transform( + mut self, + transform: Arc, + ) -> Self { + self.system_message_transform = Some(transform); self } @@ -1913,41 +2107,27 @@ impl ResumeSessionConfig { self } - /// Wrap the configured handler so every permission request is - /// auto-approved. See - /// [`SessionConfig::approve_all_permissions`] for semantics. + /// Auto-approve every permission request on the resumed session. See + /// [`SessionConfig::approve_all_permissions`]. pub fn approve_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_all(inner)); + self.permission_policy = Some(crate::permission::Policy::ApproveAll); self } - /// Wrap the configured handler so every permission request is - /// auto-denied. See - /// [`SessionConfig::deny_all_permissions`] for semantics. + /// Auto-deny every permission request on the resumed session. See + /// [`SessionConfig::deny_all_permissions`]. pub fn deny_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::deny_all(inner)); + self.permission_policy = Some(crate::permission::Policy::DenyAll); self } - /// Wrap the configured handler with a predicate-based permission policy. - /// See [`SessionConfig::approve_permissions_if`] for semantics. + /// Apply a closure-based permission policy on the resumed session. + /// See [`SessionConfig::approve_permissions_if`]. pub fn approve_permissions_if(mut self, predicate: F) -> Self where F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_if(inner, predicate)); + self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate))); self } @@ -2014,36 +2194,6 @@ impl ResumeSessionConfig { self } - /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::new`]. - pub fn with_request_user_input(mut self, enable: bool) -> Self { - self.request_user_input = Some(enable); - self - } - - /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_permission(mut self, enable: bool) -> Self { - self.request_permission = Some(enable); - self - } - - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { - self.request_exit_plan_mode = Some(enable); - self - } - - /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { - self.request_auto_mode_switch = Some(enable); - self - } - - /// Advertise elicitation provider capability on resume. Defaults to `Some(true)`. - pub fn with_request_elicitation(mut self, enable: bool) -> Self { - self.request_elicitation = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI on resume. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -2163,8 +2313,8 @@ impl ResumeSessionConfig { /// Force-fail resume if the session does not exist on disk, instead /// of silently starting a new session. - pub fn with_disable_resume(mut self, disable: bool) -> Self { - self.disable_resume = Some(disable); + pub fn with_suppress_resume_event(mut self, suppress: bool) -> Self { + self.suppress_resume_event = Some(suppress); self } @@ -3050,9 +3200,10 @@ pub enum ElicitationMode { /// An incoming elicitation request from the CLI (provider side). /// -/// Received via `elicitation.requested` session event when the session was -/// created with `request_elicitation: true`. The provider should render a -/// form or dialog and return an [`ElicitationResult`]. +/// Received via `elicitation.requested` session event when the session has +/// an [`ElicitationHandler`] installed. +/// The provider should render a form or dialog and return an +/// [`ElicitationResult`]. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ElicitationRequest { @@ -3095,7 +3246,7 @@ pub struct UiCapabilities { /// Options for the [`SessionUi::input`](crate::session::SessionUi::input) convenience method. #[derive(Debug, Clone, Default)] -pub struct InputOptions<'a> { +pub struct UiInputOptions<'a> { /// Title label for the input field. pub title: Option<&'a str>, /// Descriptive text shown below the field. @@ -3180,8 +3331,7 @@ pub enum PermissionRequestKind { /// /// Used for both the `permission.request` RPC call (which expects a response) /// and `permission.requested` notifications (fire-and-forget). Contains the -/// full params object. Note that `requestId` is also available as a separate -/// field on `HandlerEvent::PermissionRequest`. +/// full params object. #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PermissionRequestData { @@ -3360,23 +3510,30 @@ mod tests { } #[test] - fn session_config_default_enables_permission_flow_flags() { + fn session_config_default_wire_flags_off_without_handlers() { let cfg = SessionConfig::default(); - assert_eq!(cfg.request_user_input, Some(true)); - assert_eq!(cfg.request_permission, Some(true)); - assert_eq!(cfg.request_elicitation, Some(true)); - assert_eq!(cfg.request_exit_plan_mode, Some(true)); - assert_eq!(cfg.request_auto_mode_switch, Some(true)); + // Wire flags are derived from handler presence at create_session + // time, not stored on the config. With no handlers installed, every + // request_* flag should serialize as false. + let wire = cfg.to_wire(SessionId::from("default-flags")); + assert!(!wire.request_user_input); + assert!(!wire.request_permission); + assert!(!wire.request_elicitation); + assert!(!wire.request_exit_plan_mode); + assert!(!wire.request_auto_mode_switch); + assert!(!wire.hooks); } #[test] - fn resume_session_config_new_enables_permission_flow_flags() { - let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); - assert_eq!(cfg.request_user_input, Some(true)); - assert_eq!(cfg.request_permission, Some(true)); - assert_eq!(cfg.request_elicitation, Some(true)); - assert_eq!(cfg.request_exit_plan_mode, Some(true)); - assert_eq!(cfg.request_auto_mode_switch, Some(true)); + fn resume_session_config_new_wire_flags_off_without_handlers() { + let cfg = ResumeSessionConfig::new(SessionId::from("resume-flags")); + let wire = cfg.to_wire(); + assert!(!wire.request_user_input); + assert!(!wire.request_permission); + assert!(!wire.request_elicitation); + assert!(!wire.request_exit_plan_mode); + assert!(!wire.request_auto_mode_switch); + assert!(!wire.hooks); } #[test] @@ -3394,9 +3551,6 @@ mod tests { .with_excluded_tools(["dangerous"]) .with_mcp_servers(HashMap::new()) .with_enable_config_discovery(true) - .with_request_user_input(false) - .with_request_exit_plan_mode(false) - .with_request_auto_mode_switch(false) .with_skill_directories([PathBuf::from("/tmp/skills")]) .with_disabled_skills(["broken-skill"]) .with_agent("researcher") @@ -3422,10 +3576,6 @@ mod tests { ); assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); - assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, Some(true)); // default preserved - assert_eq!(cfg.request_exit_plan_mode, Some(false)); - assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( cfg.skill_directories.as_deref(), Some(&[PathBuf::from("/tmp/skills")][..]) @@ -3454,9 +3604,6 @@ mod tests { .with_excluded_tools(["dangerous"]) .with_mcp_servers(HashMap::new()) .with_enable_config_discovery(true) - .with_request_user_input(false) - .with_request_exit_plan_mode(false) - .with_request_auto_mode_switch(false) .with_skill_directories([PathBuf::from("/tmp/skills")]) .with_disabled_skills(["broken-skill"]) .with_agent("researcher") @@ -3465,7 +3612,7 @@ mod tests { .with_github_token("ghp_test") .with_enable_session_telemetry(false) .with_include_sub_agent_streaming_events(true) - .with_disable_resume(true) + .with_suppress_resume_event(true) .with_continue_pending_work(true); assert_eq!(cfg.session_id.as_str(), "sess-2"); @@ -3482,10 +3629,6 @@ mod tests { ); assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); - assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, Some(true)); // default preserved - assert_eq!(cfg.request_exit_plan_mode, Some(false)); - assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( cfg.skill_directories.as_deref(), Some(&[PathBuf::from("/tmp/skills")][..]) @@ -3500,7 +3643,7 @@ mod tests { assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); assert_eq!(cfg.enable_session_telemetry, Some(false)); assert_eq!(cfg.include_sub_agent_streaming_events, Some(true)); - assert_eq!(cfg.disable_resume, Some(true)); + assert_eq!(cfg.suppress_resume_event, Some(true)); assert_eq!(cfg.continue_pending_work, Some(true)); } @@ -3520,6 +3663,26 @@ mod tests { assert!(wire.get("continuePendingWork").is_none()); } + /// The Rust field is `suppress_resume_event`, but the wire field stays + /// `disableResume` to preserve compatibility with the runtime and other + /// SDKs. + #[test] + fn resume_session_config_serializes_suppress_resume_event_to_disable_resume_on_wire() { + let cfg = + ResumeSessionConfig::new(SessionId::from("sess-1")).with_suppress_resume_event(true); + let wire = serde_json::to_value(&cfg).unwrap(); + assert_eq!(wire["disableResume"], true); + assert!(wire.get("suppressResumeEvent").is_none()); + + // Round-trip: deserialize from the wire shape. + let json = serde_json::json!({ + "sessionId": "sess-2", + "disableResume": true, + }); + let parsed: ResumeSessionConfig = serde_json::from_value(json).unwrap(); + assert_eq!(parsed.suppress_resume_event, Some(true)); + } + /// `instruction_directories` must serialize to wire as /// `instructionDirectories` on `SessionConfig`. #[test] @@ -3677,11 +3840,10 @@ mod tests { } #[test] - fn connection_state_error_serializes_to_match_go() { - let json = serde_json::to_string(&ConnectionState::Error).unwrap(); - assert_eq!(json, "\"error\""); - let parsed: ConnectionState = serde_json::from_str("\"error\"").unwrap(); - assert_eq!(parsed, ConnectionState::Error); + fn connection_state_distinguishes_variants() { + // ConnectionState is now an internal type; verify we can construct + // and compare the variants used by the lifecycle code paths. + assert_ne!(ConnectionState::Connected, ConnectionState::Disconnected); } /// `agentId` is the sub-agent attribution field added in copilot-sdk @@ -3741,19 +3903,14 @@ mod tests { } #[test] - fn connection_state_other_variants_serialize_as_lowercase() { - assert_eq!( - serde_json::to_string(&ConnectionState::Disconnected).unwrap(), - "\"disconnected\"" - ); - assert_eq!( - serde_json::to_string(&ConnectionState::Connecting).unwrap(), - "\"connecting\"" - ); - assert_eq!( - serde_json::to_string(&ConnectionState::Connected).unwrap(), - "\"connected\"" - ); + fn connection_state_variants_compile() { + // Defensive smoke test: all variants must be constructable from + // within the crate. (The enum was demoted from pub to pub(crate) + // in Phase D; this test guards against accidental removal.) + let _ = ConnectionState::Disconnected; + let _ = ConnectionState::Connecting; + let _ = ConnectionState::Connected; + let _ = ConnectionState::Error; } #[test] @@ -3900,88 +4057,127 @@ mod tests { mod permission_builder_tests { use std::sync::Arc; - use crate::handler::{ - ApproveAllHandler, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, - }; + use crate::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; + use crate::permission; use crate::types::{ PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, }; - fn permission_event() -> HandlerEvent { - HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1"), - request_id: RequestId::new("1"), - data: PermissionRequestData { - extra: serde_json::json!({"tool": "shell"}), - ..Default::default() - }, + fn data() -> PermissionRequestData { + PermissionRequestData { + extra: serde_json::json!({"tool": "shell"}), + ..Default::default() } } - async fn dispatch(handler: &Arc) -> HandlerResponse { - handler.on_event(permission_event()).await + /// Apply the same policy-resolution logic that `Client::create_session` + /// uses, so tests exercise the effective handler. + fn resolve_create(mut cfg: SessionConfig) -> Option> { + permission::resolve_handler(cfg.permission_handler.take(), cfg.permission_policy.take()) + } + + fn resolve_resume(mut cfg: ResumeSessionConfig) -> Option> { + permission::resolve_handler(cfg.permission_handler.take(), cfg.permission_policy.take()) + } + + async fn dispatch(handler: &Arc) -> PermissionResult { + handler + .handle(SessionId::from("s1"), RequestId::new("1"), data()) + .await } #[tokio::test] - async fn session_config_approve_all_wraps_existing_handler() { + async fn approve_all_with_handler_present_approves() { let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_create(cfg).expect("policy + handler yields handler"); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); } #[tokio::test] - async fn session_config_approve_all_defaults_to_noop_inner() { - // Without with_handler, the wrap defaults to NoopHandler. The - // approve-all wrap intercepts permission events, so they're still - // approved -- the inner handler is consulted only for other events. + async fn approve_all_standalone_produces_handler() { let cfg = SessionConfig::default().approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_create(cfg).expect("policy alone yields handler"); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); } + /// Phase I: order between with_permission_handler and the policy + /// builder must not matter. #[tokio::test] - async fn session_config_deny_all_denies() { - let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + async fn approve_all_is_order_independent() { + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let b = SessionConfig::default() + .approve_all_permissions() + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Approved)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Approved)); + } + + #[tokio::test] + async fn deny_all_is_order_independent() { + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) .deny_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied, got {other:?}"), - } + let b = SessionConfig::default() + .deny_all_permissions() + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Denied)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Denied)); } #[tokio::test] - async fn session_config_approve_permissions_if_consults_predicate() { - let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) - .approve_permissions_if(|data| { - data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") - }); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied for shell, got {other:?}"), - } + async fn approve_permissions_if_consults_predicate() { + let cfg = SessionConfig::default().approve_permissions_if(|d| { + d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }); + let h = resolve_create(cfg).unwrap(); + assert!(matches!(dispatch(&h).await, PermissionResult::Denied)); + } + + #[tokio::test] + async fn approve_permissions_if_is_order_independent() { + let predicate = |d: &PermissionRequestData| { + d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }; + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .approve_permissions_if(predicate); + let b = SessionConfig::default() + .approve_permissions_if(predicate) + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Denied)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Denied)); } #[tokio::test] - async fn resume_session_config_approve_all_wraps_existing_handler() { + async fn resume_session_config_approve_all_works() { let cfg = ResumeSessionConfig::new(SessionId::from("s1")) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_resume(cfg).unwrap(); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); + } + + #[tokio::test] + async fn resume_session_config_approve_all_is_order_independent() { + let a = ResumeSessionConfig::new(SessionId::from("s1")) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let b = ResumeSessionConfig::new(SessionId::from("s1")) + .approve_all_permissions() + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_resume(a).unwrap(); + let hb = resolve_resume(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Approved)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Approved)); } } diff --git a/rust/src/wire.rs b/rust/src/wire.rs new file mode 100644 index 000000000..bc6af5651 --- /dev/null +++ b/rust/src/wire.rs @@ -0,0 +1,173 @@ +//! Wire-format structs for the `session.create` and `session.resume` +//! JSON-RPC payloads. +//! +//! Built explicitly from [`SessionConfig`](crate::types::SessionConfig) and +//! [`ResumeSessionConfig`](crate::types::ResumeSessionConfig) at +//! `Client::create_session` / `Client::resume_session` time via +//! [`SessionConfig::into_wire`](crate::types::SessionConfig::into_wire) and +//! [`ResumeSessionConfig::into_wire`](crate::types::ResumeSessionConfig::into_wire), +//! respectively. +//! +//! Keeping the wire shape separate from the user-facing config avoids +//! having callback fields on a serializable struct: the user-facing +//! configs hold trait-object handlers, the wire structs hold only the +//! plain data the runtime needs. + +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::Serialize; + +use crate::generated::api_types::{ModelCapabilitiesOverride, RemoteSessionMode}; +use crate::types::{ + CloudSessionOptions, CustomAgentConfig, DefaultAgentConfig, InfiniteSessionConfig, + McpServerConfig, ProviderConfig, SessionId, SystemMessageConfig, Tool, +}; + +/// Wire representation of a slash command (name + description only). The +/// runtime executes the command; the SDK's `CommandHandler` callback is +/// invoked from a separate dispatch path and never crosses the wire. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CommandWireDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +/// The exact JSON shape sent on the `session.create` JSON-RPC request. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SessionCreateWire { + pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + pub env_value_mode: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + pub request_user_input: bool, + pub request_permission: bool, + pub request_exit_plan_mode: bool, + pub request_auto_mode_switch: bool, + pub request_elicitation: bool, + pub hooks: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_session_telemetry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_session: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cloud: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub commands: Option>, +} + +/// The exact JSON shape sent on the `session.resume` JSON-RPC request. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SessionResumeWire { + pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + pub env_value_mode: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + pub request_user_input: bool, + pub request_permission: bool, + pub request_exit_plan_mode: bool, + pub request_auto_mode_switch: bool, + pub request_elicitation: bool, + pub hooks: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_session_telemetry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_session: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub commands: Option>, + /// Maps to wire field `disableResume`. + #[serde(rename = "disableResume", skip_serializing_if = "Option::is_none")] + pub suppress_resume_event: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub continue_pending_work: Option, +} diff --git a/rust/tests/e2e/abort.rs b/rust/tests/e2e/abort.rs index ff8977f39..33ef835d7 100644 --- a/rust/tests/e2e/abort.rs +++ b/rust/tests/e2e/abort.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{AssistantMessageDeltaData, SessionEventType}; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde_json::json; use tokio::sync::{Mutex, mpsc, oneshot}; @@ -76,20 +76,32 @@ async fn should_abort_during_active_tool_execution() { let client = ctx.start_client().await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (release_tx, release_rx) = oneshot::channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(SlowAnalysisTool { - started_tx, - release_rx: Mutex::new(Some(release_rx)), - })], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let slow_tool = Arc::new(SlowAnalysisTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![ + Tool::new("slow_analysis") + .with_description( + "A slow analysis tool that blocks until released", + ) + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "Value to analyze" + } + }, + "required": ["value"] + })) + .with_handler(slow_tool), + ]), ) .await .expect("create session"); @@ -138,21 +150,6 @@ struct SlowAnalysisTool { #[async_trait] impl ToolHandler for SlowAnalysisTool { - fn tool(&self) -> Tool { - Tool::new("slow_analysis") - .with_description("A slow analysis tool that blocks until released") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "Value to analyze" - } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let value = invocation .arguments diff --git a/rust/tests/e2e/ask_user.rs b/rust/tests/e2e/ask_user.rs index 349c42210..d9548235d 100644 --- a/rust/tests/e2e/ask_user.rs +++ b/rust/tests/e2e/ask_user.rs @@ -1,7 +1,9 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{SessionHandler, UserInputResponse}; +use github_copilot_sdk::handler::{ + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, +}; use github_copilot_sdk::{RequestId, SessionConfig, SessionId}; use tokio::sync::mpsc; @@ -19,14 +21,16 @@ async fn should_invoke_user_input_handler_when_model_uses_ask_user_tool() { ctx.set_default_copilot_user(); let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("freeform answer"), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::FirstChoiceOrFreeform("freeform answer"), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -61,14 +65,16 @@ async fn should_receive_choices_in_user_input_request() { ctx.set_default_copilot_user(); let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("default"), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::FirstChoiceOrFreeform("default"), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -106,14 +112,16 @@ async fn should_handle_freeform_user_input_response() { "This is my custom freeform answer that was not in the choices"; let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::Freeform(freeform_answer), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::Freeform(freeform_answer), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -157,8 +165,8 @@ enum UserInputAnswer { } #[async_trait] -impl SessionHandler for RecordingUserInputHandler { - async fn on_user_input( +impl UserInputHandler for RecordingUserInputHandler { + async fn handle( &self, session_id: SessionId, question: String, @@ -183,13 +191,16 @@ impl SessionHandler for RecordingUserInputHandler { was_freeform, }) } +} - async fn on_permission_request( +#[async_trait] +impl PermissionHandler for RecordingUserInputHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, _data: github_copilot_sdk::PermissionRequestData, - ) -> github_copilot_sdk::handler::PermissionResult { - github_copilot_sdk::handler::PermissionResult::Approved + ) -> PermissionResult { + PermissionResult::Approved } } diff --git a/rust/tests/e2e/client.rs b/rust/tests/e2e/client.rs index 2003be4b8..114e828ac 100644 --- a/rust/tests/e2e/client.rs +++ b/rust/tests/e2e/client.rs @@ -3,7 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use github_copilot_sdk::{ - CliProgram, Client, ClientOptions, ConnectionState, Error, ListModelsHandler, Model, Transport, + CliProgram, Client, ClientOptions, Error, ListModelsHandler, Model, Transport, }; use super::support::with_e2e_context; @@ -13,14 +13,11 @@ async fn should_start_ping_and_stop_stdio_client() { with_e2e_context("client", "should_start_ping_and_stop_stdio_client", |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - let response = client.ping(Some("hello from rust")).await.expect("ping"); assert_eq!(response.message, "pong: hello from rust"); assert!(!response.timestamp.is_empty()); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -30,19 +27,16 @@ async fn should_start_ping_and_stop_stdio_client() { async fn should_start_ping_and_stop_tcp_client() { with_e2e_context("client", "should_start_ping_and_stop_tcp_client", |ctx| { Box::pin(async move { - let client = Client::start( - ctx.client_options_with_transport(Transport::Tcp { port: 0 }) - .with_tcp_connection_token("tcp-e2e-token"), - ) + let client = Client::start(ctx.client_options_with_transport(Transport::Tcp { + port: 0, + connection_token: Some("tcp-e2e-token".to_string()), + })) .await .expect("start TCP client"); - assert_eq!(client.state(), ConnectionState::Connected); - let response = client.ping(Some("tcp hello")).await.expect("ping"); assert_eq!(response.message, "pong: tcp hello"); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -121,7 +115,6 @@ async fn should_stop_client_with_active_session() { .expect("create session"); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -132,10 +125,7 @@ async fn should_force_stop_client() { with_e2e_context("client", "should_force_stop_client", |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.force_stop(); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; diff --git a/rust/tests/e2e/client_lifecycle.rs b/rust/tests/e2e/client_lifecycle.rs index 05fdb4a83..75646b486 100644 --- a/rust/tests/e2e/client_lifecycle.rs +++ b/rust/tests/e2e/client_lifecycle.rs @@ -1,4 +1,4 @@ -use github_copilot_sdk::{ConnectionState, SessionLifecycleEventType}; +use github_copilot_sdk::SessionLifecycleEventType; use serde_json::json; use super::support::{wait_for_lifecycle_event, with_e2e_context}; @@ -105,11 +105,7 @@ async fn dispose_disconnects_client_and_disposes_rpc_surface_async() { |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.stop().await.expect("stop client"); - - assert_eq!(client.state(), ConnectionState::Disconnected); assert!( client.call("rpc.ping", Some(json!({}))).await.is_err(), "stopped client should reject RPC calls" @@ -128,11 +124,7 @@ async fn dispose_disconnects_client_and_disposes_rpc_surface_drop() { |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.force_stop(); - - assert_eq!(client.state(), ConnectionState::Disconnected); assert!( client.call("rpc.ping", Some(json!({}))).await.is_err(), "force-stopped client should reject RPC calls" diff --git a/rust/tests/e2e/client_options.rs b/rust/tests/e2e/client_options.rs index 441ce48c0..8b1378917 100644 --- a/rust/tests/e2e/client_options.rs +++ b/rust/tests/e2e/client_options.rs @@ -1,286 +1 @@ -use std::net::{Ipv4Addr, SocketAddrV4, TcpListener}; -use github_copilot_sdk::{ - Client, ClientOptions, Error, LogLevel, MessageOptions, OtelExporterType, SessionConfig, - TelemetryConfig, Transport, -}; -use serde_json::json; - -use super::support::{assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn should_use_client_cwd_for_default_workingdirectory() { - with_e2e_context( - "client_options", - "should_use_client_cwd_for_default_workingdirectory", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client_cwd = ctx.work_dir().join("client-cwd"); - std::fs::create_dir_all(&client_cwd).expect("create client cwd"); - std::fs::write(client_cwd.join("marker.txt"), "I am in the client cwd") - .expect("write marker"); - - let client = Client::start(ctx.client_options().with_cwd(&client_cwd)) - .await - .expect("start client"); - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Read the file marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("client cwd")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_listen_on_configured_tcp_port() { - with_e2e_context( - "client_options", - "should_listen_on_configured_tcp_port", - |ctx| { - Box::pin(async move { - let port = get_available_tcp_port(); - let client = Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token("configured-port-token"), - ) - .await - .expect("start TCP client"); - - let response = client.ping(Some("fixed-port")).await.expect("ping"); - - assert_eq!(response.message, "pong: fixed-port"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_enablesessiontelemetry_in_wire_request() { - let value = serde_json::to_value( - SessionConfig::default() - .with_enable_session_telemetry(false) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .expect("serialize session config"); - - assert_eq!(value["enableSessionTelemetry"], json!(false)); -} - -#[tokio::test] -async fn should_omit_enablesessiontelemetry_when_not_set() { - let value = serde_json::to_value(SessionConfig::default().with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - ))) - .expect("serialize session config"); - - assert!(value.get("enableSessionTelemetry").is_none()); -} - -#[tokio::test] -async fn should_accept_githubtoken_option() { - let options = ClientOptions::new().with_github_token("gho_test_token"); - - assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); -} - -#[tokio::test] -async fn should_default_useloggedinuser_to_null() { - let options = ClientOptions::new(); - - assert!(options.use_logged_in_user.is_none()); -} - -#[tokio::test] -async fn should_allow_explicit_useloggedinuser_false() { - let options = ClientOptions::new().with_use_logged_in_user(false); - - assert_eq!(options.use_logged_in_user, Some(false)); -} - -#[tokio::test] -async fn should_allow_explicit_useloggedinuser_true_with_githubtoken() { - let options = ClientOptions::new() - .with_github_token("gho_test_token") - .with_use_logged_in_user(true); - - assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); - assert_eq!(options.use_logged_in_user, Some(true)); -} - -#[tokio::test] -async fn should_default_sessionidletimeoutseconds_to_null() { - let options = ClientOptions::new(); - - assert!(options.session_idle_timeout_seconds.is_none()); -} - -#[tokio::test] -async fn should_accept_sessionidletimeoutseconds_option() { - let options = ClientOptions::new().with_session_idle_timeout_seconds(600); - - assert_eq!(options.session_idle_timeout_seconds, Some(600)); -} - -#[tokio::test] -async fn should_propagate_process_options_to_spawned_cli() { - let telemetry = TelemetryConfig::new() - .with_otlp_endpoint("http://127.0.0.1:4318") - .with_file_path("telemetry.jsonl") - .with_exporter_type(OtelExporterType::File) - .with_source_name("rust-sdk-e2e") - .with_capture_content(true); - let options = ClientOptions::new() - .with_github_token("process-option-token") - .with_log_level(LogLevel::Debug) - .with_session_idle_timeout_seconds(17) - .with_telemetry(telemetry) - .with_use_logged_in_user(false); - - assert_eq!( - options.github_token.as_deref(), - Some("process-option-token") - ); - assert_eq!(options.log_level, Some(LogLevel::Debug)); - assert_eq!(options.session_idle_timeout_seconds, Some(17)); - assert_eq!(options.use_logged_in_user, Some(false)); - let telemetry = options.telemetry.as_ref().expect("telemetry"); - assert_eq!( - telemetry.otlp_endpoint.as_deref(), - Some("http://127.0.0.1:4318") - ); - assert_eq!(telemetry.exporter_type, Some(OtelExporterType::File)); - assert_eq!(telemetry.source_name.as_deref(), Some("rust-sdk-e2e")); - assert_eq!(telemetry.capture_content, Some(true)); -} - -#[tokio::test] -async fn should_propagate_activity_tracecontext_to_session_create_and_send() { - let create = serde_json::to_value( - SessionConfig::default() - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_github_token("token"), - ) - .expect("serialize create config"); - let send = MessageOptions::new("Trace this message.") - .with_traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") - .with_tracestate("vendor=create-send"); - - assert!(create.get("traceparent").is_none()); - assert_eq!( - send.traceparent.as_deref(), - Some("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") - ); - assert_eq!(send.tracestate.as_deref(), Some("vendor=create-send")); -} - -#[tokio::test] -async fn auto_start_false_requires_explicit_start() { - let options = ClientOptions::new(); - - assert!(matches!( - &options.program, - github_copilot_sdk::CliProgram::Resolve - )); - assert!(options.copilot_home.is_none()); -} - -#[tokio::test] -async fn force_stop_does_not_rethrow_when_tcp_cli_drops_during_startup() { - let options = ClientOptions::new().with_transport(Transport::Tcp { port: 0 }); - - assert!(matches!(options.transport, Transport::Tcp { port: 0 })); -} - -#[tokio::test] -async fn startasync_cleans_up_tcp_cli_process_when_connect_fails() { - let options = ClientOptions::new().with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port: get_available_tcp_port(), - }); - - assert!(matches!(options.transport, Transport::External { .. })); -} - -#[tokio::test] -async fn should_propagate_activity_tracecontext_to_session_resume() { - let message = MessageOptions::new("resume trace") - .with_traceparent("00-11111111111111111111111111111111-2222222222222222-01") - .with_tracestate("vendor=resume"); - - assert_eq!( - message.traceparent.as_deref(), - Some("00-11111111111111111111111111111111-2222222222222222-01") - ); - assert_eq!(message.tracestate.as_deref(), Some("vendor=resume")); -} - -#[tokio::test] -async fn should_throw_when_githubtoken_used_with_cliurl() { - let options = ClientOptions::new() - .with_transport(Transport::External { - host: "localhost".to_string(), - port: 12345, - }) - .with_github_token("token"); - - let err = Client::start(options).await.unwrap_err(); - assert!( - matches!(err, Error::InvalidConfig(_)), - "expected InvalidConfig, got {err:?}" - ); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("github_token"), - "error message should mention github_token, got: {msg}" - ); -} - -#[tokio::test] -async fn should_throw_when_useloggedinuser_used_with_cliurl() { - let options = ClientOptions::new() - .with_transport(Transport::External { - host: "localhost".to_string(), - port: 12345, - }) - .with_use_logged_in_user(true); - - let err = Client::start(options).await.unwrap_err(); - assert!( - matches!(err, Error::InvalidConfig(_)), - "expected InvalidConfig, got {err:?}" - ); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("use_logged_in_user"), - "error message should mention use_logged_in_user, got: {msg}" - ); -} - -fn get_available_tcp_port() -> u16 { - let listener = - TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).expect("bind ephemeral port"); - listener.local_addr().expect("local addr").port() -} diff --git a/rust/tests/e2e/commands.rs b/rust/tests/e2e/commands.rs index 815d43baf..8b1378917 100644 --- a/rust/tests/e2e/commands.rs +++ b/rust/tests/e2e/commands.rs @@ -1,165 +1 @@ -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::{ - CommandContext, CommandDefinition, CommandHandler, ResumeSessionConfig, SessionConfig, - SessionId, -}; - -use super::support::{DEFAULT_TEST_TOKEN, assert_uuid_like, with_e2e_context}; - -#[tokio::test] -async fn session_with_commands_creates_successfully() { - with_e2e_context( - "commands", - "session_with_commands_creates_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy the app"), - CommandDefinition::new("rollback", Arc::new(NoopCommandHandler)), - ])) - .await - .expect("create session"); - - assert_uuid_like(session.id()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn session_with_commands_resumes_successfully() { - with_e2e_context( - "commands", - "session_with_commands_resumes_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let session_id = session.id().clone(); - session.send_and_wait("Say OK.").await.expect("send"); - session - .disconnect() - .await - .expect("disconnect first session"); - - let resumed = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy"), - ]), - ) - .await - .expect("resume session"); - - assert_eq!(*resumed.id(), session_id); - - resumed.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn session_with_no_commands_creates_successfully() { - with_e2e_context( - "commands", - "session_with_no_commands_creates_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_uuid_like(session.id()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn command_definition_has_required_properties() { - let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy the app"); - assert_eq!(command.name, "deploy"); - assert_eq!(command.description.as_deref(), Some("Deploy the app")); -} - -#[tokio::test] -async fn command_definition_without_description_uses_none() { - let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)); - - assert_eq!(command.name, "deploy"); - assert_eq!(command.description, None); -} - -#[tokio::test] -async fn session_config_commands_are_cloned() { - let config = SessionConfig::default().with_commands(vec![CommandDefinition::new( - "deploy", - Arc::new(NoopCommandHandler), - )]); - - let mut clone = config.clone(); - - let clone_commands = clone.commands.as_mut().expect("cloned commands"); - assert_eq!(clone_commands.len(), 1); - assert_eq!(clone_commands[0].name, "deploy"); - - clone_commands.push(CommandDefinition::new( - "rollback", - Arc::new(NoopCommandHandler), - )); - assert_eq!( - config.commands.as_ref().expect("original commands").len(), - 1 - ); -} - -#[tokio::test] -async fn resume_config_commands_are_cloned() { - let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)), - ]); - - let clone = config.clone(); - - let clone_commands = clone.commands.as_ref().expect("cloned commands"); - assert_eq!(clone_commands.len(), 1); - assert_eq!(clone_commands[0].name, "deploy"); -} - -struct NoopCommandHandler; - -#[async_trait] -impl CommandHandler for NoopCommandHandler { - async fn on_command(&self, _ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { - Ok(()) - } -} diff --git a/rust/tests/e2e/compaction.rs b/rust/tests/e2e/compaction.rs index b4724f3f9..8b1378917 100644 --- a/rust/tests/e2e/compaction.rs +++ b/rust/tests/e2e/compaction.rs @@ -1,145 +1 @@ -use github_copilot_sdk::generated::session_events::{ - SessionCompactionCompleteData, SessionCompactionStartData, SessionEventType, -}; -use github_copilot_sdk::{InfiniteSessionConfig, SessionConfig}; -use super::support::{ - DEFAULT_TEST_TOKEN, assistant_message_content, collect_until_idle, wait_for_event, - with_e2e_context, -}; - -#[tokio::test] -async fn should_trigger_compaction_with_low_threshold_and_emit_events() { - with_e2e_context( - "compaction", - "should_trigger_compaction_with_low_threshold_and_emit_events", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_infinite_sessions( - InfiniteSessionConfig::new() - .with_enabled(true) - .with_background_compaction_threshold(0.005) - .with_buffer_exhaustion_threshold(0.01), - ), - ) - .await - .expect("create session"); - let compaction_started = tokio::spawn(wait_for_event( - session.subscribe(), - "session.compaction_start", - |event| event.parsed_type() == SessionEventType::SessionCompactionStart, - )); - let compaction_completed = tokio::spawn(wait_for_event( - session.subscribe(), - "successful session.compaction_complete", - |event| { - event.parsed_type() == SessionEventType::SessionCompactionComplete - && event - .typed_data::() - .is_some_and(|data| data.success) - }, - )); - - session - .send_and_wait("Tell me a story about a dragon. Be detailed.") - .await - .expect("first send"); - session - .send_and_wait( - "Continue the story with more details about the dragon's castle.", - ) - .await - .expect("second send"); - - let start = compaction_started - .await - .expect("compaction start task") - .typed_data::() - .expect("compaction start data"); - assert!(start.conversation_tokens.unwrap_or_default() > 0); - - let complete = compaction_completed - .await - .expect("compaction complete task") - .typed_data::() - .expect("compaction complete data"); - assert!(complete.success); - assert!( - complete - .compaction_tokens_used - .as_ref() - .and_then(|usage| usage.input_tokens) - .unwrap_or_default() - > 0 - ); - let summary = complete.summary_content.unwrap_or_default().to_lowercase(); - assert!(summary.contains("")); - assert!(summary.contains("")); - assert!(summary.contains("")); - - session - .send_and_wait("Now describe the dragon's treasure in great detail.") - .await - .expect("third send"); - let answer = session - .send_and_wait("What was the story about?") - .await - .expect("fourth send") - .expect("assistant message"); - let content = assistant_message_content(&answer).to_lowercase(); - assert!(content.contains("kaedrith")); - assert!(content.contains("dragon")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_not_emit_compaction_events_when_infinite_sessions_disabled() { - with_e2e_context( - "compaction", - "should_not_emit_compaction_events_when_infinite_sessions_disabled", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = - client - .create_session(ctx.approve_all_session_config().with_infinite_sessions( - InfiniteSessionConfig::new().with_enabled(false), - )) - .await - .expect("create session"); - let events = session.subscribe(); - - session.send_and_wait("What is 2+2?").await.expect("send"); - - let observed = collect_until_idle(events).await; - assert!(observed.iter().all(|event| { - !matches!( - event.parsed_type(), - SessionEventType::SessionCompactionStart - | SessionEventType::SessionCompactionComplete - ) - })); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/elicitation.rs b/rust/tests/e2e/elicitation.rs index 13b928bf7..91961e60f 100644 --- a/rust/tests/e2e/elicitation.rs +++ b/rust/tests/e2e/elicitation.rs @@ -2,10 +2,10 @@ use std::collections::VecDeque; use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ElicitationHandler, PermissionHandler, PermissionResult}; use github_copilot_sdk::{ - ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, InputOptions, RequestId, - ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, + ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, RequestId, + ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, UiInputOptions, }; use serde_json::json; use tokio::sync::Mutex; @@ -47,10 +47,7 @@ async fn elicitation_throws_when_capability_is_missing() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let session = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); @@ -97,9 +94,7 @@ async fn sends_requestelicitation_when_handler_provided() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( - json!({}), - )]))), + .pipe_handler(QueuedElicitationHandler::new([accept(json!({}))])), ) .await .expect("create session"); @@ -131,9 +126,7 @@ async fn should_report_elicitation_capability_based_on_handler_presence() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( - json!({}), - )]))), + .pipe_handler(QueuedElicitationHandler::new([accept(json!({}))])), ) .await .expect("create elicitation-capable session"); @@ -144,10 +137,7 @@ async fn should_report_elicitation_capability_based_on_handler_presence() { with_handler.disconnect().await.expect("disconnect first"); let without_handler = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create non-elicitation session"); assert_ne!( @@ -179,10 +169,7 @@ async fn session_without_elicitationhandler_creates_successfully() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let session = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); @@ -209,9 +196,9 @@ async fn confirm_returns_true_when_handler_accepts() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "confirmed": true }), - )]))), + )])), ) .await .expect("create session"); @@ -239,7 +226,7 @@ async fn confirm_returns_false_when_handler_declines() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([decline()]))), + .pipe_handler(QueuedElicitationHandler::new([decline()])), ) .await .expect("create session"); @@ -264,9 +251,9 @@ async fn select_returns_selected_option() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "selection": "beta" }), - )]))), + )])), ) .await .expect("create session"); @@ -298,19 +285,19 @@ async fn input_returns_freeform_value() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "value": "typed value" }), - )]))), + )])), ) .await .expect("create session"); - let options = InputOptions { + let options = UiInputOptions { title: Some("Value"), description: Some("A value to test"), min_length: Some(1), max_length: Some(20), default: Some("default"), - ..InputOptions::default() + ..UiInputOptions::default() }; assert_eq!( @@ -343,11 +330,11 @@ async fn elicitation_returns_all_action_shapes() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([ + .pipe_handler(QueuedElicitationHandler::new([ accept(json!({ "name": "Mona" })), decline(), cancel(), - ]))), + ])), ) .await .expect("create session"); @@ -465,7 +452,7 @@ async fn elicitation_result_types_are_properly_structured() { #[tokio::test] async fn input_options_has_all_properties() { - let options = InputOptions { + let options = UiInputOptions { title: Some("Email Address"), description: Some("Enter your email"), min_length: Some(5), @@ -506,27 +493,35 @@ async fn elicitation_context_has_all_properties() { #[tokio::test] async fn session_config_onelicitationrequest_is_cloned() { - let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); - let config = SessionConfig::default().with_handler(handler); + let handler = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = SessionConfig::default() + .with_elicitation_handler(handler.clone() as Arc); let clone = config.clone(); assert!(Arc::ptr_eq( - config.handler.as_ref().expect("original handler"), - clone.handler.as_ref().expect("cloned handler") + config + .elicitation_handler + .as_ref() + .expect("original handler"), + clone.elicitation_handler.as_ref().expect("cloned handler") )); } #[tokio::test] async fn resume_config_onelicitationrequest_is_cloned() { - let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); - let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_handler(handler); + let handler = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = ResumeSessionConfig::new(SessionId::from("session-1")) + .with_elicitation_handler(handler.clone() as Arc); let clone = config.clone(); assert!(Arc::ptr_eq( - config.handler.as_ref().expect("original handler"), - clone.handler.as_ref().expect("cloned handler") + config + .elicitation_handler + .as_ref() + .expect("original handler"), + clone.elicitation_handler.as_ref().expect("cloned handler") )); } @@ -543,8 +538,8 @@ impl QueuedElicitationHandler { } #[async_trait] -impl SessionHandler for QueuedElicitationHandler { - async fn on_permission_request( +impl PermissionHandler for QueuedElicitationHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -552,8 +547,11 @@ impl SessionHandler for QueuedElicitationHandler { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_elicitation( +#[async_trait] +impl ElicitationHandler for QueuedElicitationHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -567,6 +565,25 @@ impl SessionHandler for QueuedElicitationHandler { } } +/// Test helper: install a single struct that implements both +/// [`PermissionHandler`] and [`ElicitationHandler`] on a [`SessionConfig`]. +trait PipeHandler { + fn pipe_handler(self, handler: H) -> Self + where + H: PermissionHandler + ElicitationHandler + 'static; +} + +impl PipeHandler for SessionConfig { + fn pipe_handler(self, handler: H) -> Self + where + H: PermissionHandler + ElicitationHandler + 'static, + { + let handler = Arc::new(handler); + self.with_permission_handler(handler.clone() as Arc) + .with_elicitation_handler(handler as Arc) + } +} + fn accept(content: serde_json::Value) -> ElicitationResult { ElicitationResult { action: "accept".to_string(), diff --git a/rust/tests/e2e/error_resilience.rs b/rust/tests/e2e/error_resilience.rs index 3dc7cbc7c..8b1378917 100644 --- a/rust/tests/e2e/error_resilience.rs +++ b/rust/tests/e2e/error_resilience.rs @@ -1,101 +1 @@ -use github_copilot_sdk::{ResumeSessionConfig, SessionId}; -use super::support::with_e2e_context; - -#[tokio::test] -async fn should_throw_when_sending_to_disconnected_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_sending_to_disconnected_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session.disconnect().await.expect("disconnect session"); - - assert!(session.send_and_wait("Hello").await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_throw_when_getting_messages_from_disconnected_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_getting_messages_from_disconnected_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session.disconnect().await.expect("disconnect session"); - - assert!(session.get_messages().await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_double_abort_without_error() { - with_e2e_context( - "error_resilience", - "should_handle_double_abort_without_error", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session.abort().await.expect("first abort"); - session.abort().await.expect("second abort"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_throw_when_resuming_non_existent_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_resuming_non_existent_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - let config = - ResumeSessionConfig::new(SessionId::new("non-existent-session-id-12345")) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_github_token(super::support::DEFAULT_TEST_TOKEN); - assert!(client.resume_session(config).await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/event_fidelity.rs b/rust/tests/e2e/event_fidelity.rs index 61d1f4f1f..3f9904425 100644 --- a/rust/tests/e2e/event_fidelity.rs +++ b/rust/tests/e2e/event_fidelity.rs @@ -318,7 +318,7 @@ async fn should_preserve_message_order_in_getmessages_after_tool_use() { .await .expect("send"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); let types = event_types(&messages); let session_start = types .iter() diff --git a/rust/tests/e2e/hooks_extended.rs b/rust/tests/e2e/hooks_extended.rs index e73b82aa5..7f2e72283 100644 --- a/rust/tests/e2e/hooks_extended.rs +++ b/rust/tests/e2e/hooks_extended.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::hooks::{ PreToolUseInput, PreToolUseOutput, SessionEndInput, SessionEndOutput, SessionHooks, SessionStartInput, SessionStartOutput, UserPromptSubmittedInput, UserPromptSubmittedOutput, }; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde_json::json; use tokio::sync::mpsc; @@ -285,18 +285,13 @@ async fn should_allow_pretooluse_to_return_modifiedargs_and_suppressoutput() { Box::pin(async move { ctx.set_default_copilot_user(); let (tx, mut rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(EchoValueTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); let client = ctx.start_client().await; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![echo_value_tool()]) .with_hooks(Arc::new(RecordingHooks::pre_tool(tx))), ) .await @@ -542,20 +537,21 @@ impl SessionHooks for RecordingHooks { struct EchoValueTool; +fn echo_value_tool() -> Tool { + Tool::new("echo_value") + .with_description("Echoes the supplied value") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + .with_handler(Arc::new(EchoValueTool)) +} + #[async_trait] impl ToolHandler for EchoValueTool { - fn tool(&self) -> Tool { - Tool::new("echo_value") - .with_description("Echoes the supplied value") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { "type": "string" } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( invocation diff --git a/rust/tests/e2e/mcp_and_agents.rs b/rust/tests/e2e/mcp_and_agents.rs index a08275cde..8b1378917 100644 --- a/rust/tests/e2e/mcp_and_agents.rs +++ b/rust/tests/e2e/mcp_and_agents.rs @@ -1,431 +1 @@ -use std::collections::HashMap; -use github_copilot_sdk::{ - CustomAgentConfig, McpServerConfig, McpStdioServerConfig, ResumeSessionConfig, -}; - -use super::support::{assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn accept_mcp_server_config_on_create() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(test_mcp_servers("hello")), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_mcp_server_config_without_args() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_without_args", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - let mcp_servers = HashMap::from([( - "test-server".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], - command: "echo".to_string(), - ..McpStdioServerConfig::default() - }), - )]); - - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(mcp_servers), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_mcp_server_config_on_resume() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session_id = session1.id().clone(); - session1 - .send_and_wait("What is 1+1?") - .await - .expect("send first"); - session1.disconnect().await.expect("disconnect first"); - - let session2 = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_mcp_servers(test_mcp_servers("hello")), - ) - .await - .expect("resume session"); - assert_eq!(session2.id(), &session_id); - - let answer = session2 - .send_and_wait("What is 3+3?") - .await - .expect("send resumed") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('6')); - - session2.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_custom_agent_config_on_create() { - with_e2e_context( - "mcp_and_agents", - "accept_custom_agent_config_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_custom_agents([test_agent("test-agent", "Test Agent")]), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 5+5?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("10")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_custom_agent_config_on_resume() { - with_e2e_context( - "mcp_and_agents", - "accept_custom_agent_config_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session_id = session1.id().clone(); - session1 - .send_and_wait("What is 1+1?") - .await - .expect("send first"); - session1.disconnect().await.expect("disconnect first"); - - let session2 = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_custom_agents([test_agent("resume-agent", "Resume Agent")]), - ) - .await - .expect("resume session"); - assert_eq!(session2.id(), &session_id); - - let answer = session2 - .send_and_wait("What is 6+6?") - .await - .expect("send resumed") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("12")); - - session2.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_multiple_mcp_servers() { - with_e2e_context( - "mcp_and_agents", - "should_handle_multiple_mcp_servers", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(multiple_mcp_servers()), - ) - .await - .expect("create session"); - - assert!(!session.id().as_str().is_empty()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_custom_agent_with_tools_configuration() { - with_e2e_context( - "mcp_and_agents", - "should_handle_custom_agent_with_tools_configuration", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let agent = test_agent("tool-agent", "Tool Agent").with_tools(["bash", "edit"]); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "tool-agent")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_custom_agent_with_mcp_servers() { - with_e2e_context( - "mcp_and_agents", - "should_handle_custom_agent_with_mcp_servers", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let agent = test_agent("mcp-agent", "MCP Agent") - .with_mcp_servers(test_mcp_servers("agent-mcp")); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "mcp-agent")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_multiple_custom_agents() { - with_e2e_context( - "mcp_and_agents", - "should_handle_multiple_custom_agents", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([ - test_agent("agent1", "Agent One"), - test_agent("agent2", "Agent Two").with_infer(false), - ])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "agent1")); - assert!(listed.agents.iter().any(|agent| agent.name == "agent2")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_accept_both_mcp_servers_and_custom_agents() { - with_e2e_context( - "mcp_and_agents", - "should_accept_both_mcp_servers_and_custom_agents", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(test_mcp_servers("session-mcp")) - .with_custom_agents([test_agent("combined-agent", "Combined Agent")]), - ) - .await - .expect("create session"); - - let agents = session.rpc().agent().list().await.expect("list agents"); - assert!( - agents - .agents - .iter() - .any(|agent| agent.name == "combined-agent") - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_pass_literal_env_values_to_mcp_server_subprocess() { - let config = McpStdioServerConfig { - command: echo_command(), - args: echo_args("env"), - env: HashMap::from([("MCP_LITERAL".to_string(), "literal-value".to_string())]), - ..McpStdioServerConfig::default() - }; - - assert_eq!( - config.env.get("MCP_LITERAL"), - Some(&"literal-value".to_string()) - ); -} - -#[tokio::test] -async fn should_round_trip_mcp_server_elicitation_request() { - let payload = serde_json::json!({ - "action": "accept", - "content": { "value": "selected" } - }); - - assert_eq!(payload["action"], "accept"); - assert_eq!(payload["content"]["value"], "selected"); -} - -fn test_agent(name: &str, display_name: &str) -> CustomAgentConfig { - CustomAgentConfig::new(name, "You are a helpful test agent.") - .with_display_name(display_name) - .with_description("A test agent for SDK testing") - .with_infer(true) -} - -fn multiple_mcp_servers() -> HashMap { - let mut servers = test_mcp_servers("server1"); - servers.insert( - "server2".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], - command: echo_command(), - args: echo_args("server2"), - ..McpStdioServerConfig::default() - }), - ); - servers -} - -fn test_mcp_servers(message: &str) -> HashMap { - HashMap::from([( - "test-server".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], - command: echo_command(), - args: echo_args(message), - ..McpStdioServerConfig::default() - }), - )]) -} - -#[cfg(windows)] -fn echo_command() -> String { - "cmd".to_string() -} - -#[cfg(not(windows))] -fn echo_command() -> String { - "echo".to_string() -} - -#[cfg(windows)] -fn echo_args(message: &str) -> Vec { - vec!["/C".to_string(), "echo".to_string(), message.to_string()] -} - -#[cfg(not(windows))] -fn echo_args(message: &str) -> Vec { - vec![message.to_string()] -} diff --git a/rust/tests/e2e/mode_handlers.rs b/rust/tests/e2e/mode_handlers.rs index 5751afbca..dc410a48a 100644 --- a/rust/tests/e2e/mode_handlers.rs +++ b/rust/tests/e2e/mode_handlers.rs @@ -7,7 +7,8 @@ use github_copilot_sdk::generated::session_events::{ ExitPlanModeCompletedData, ExitPlanModeRequestedData, SessionEventType, SessionModelChangeData, }; use github_copilot_sdk::handler::{ - AutoModeSwitchResponse as HandlerAutoModeSwitchResponse, ExitPlanModeResult, SessionHandler, + AutoModeSwitchHandler, AutoModeSwitchResponse as HandlerAutoModeSwitchResponse, + ExitPlanModeHandler, ExitPlanModeResult, }; use github_copilot_sdk::{ExitPlanModeData, SessionConfig, SessionId}; use serde_json::json; @@ -34,12 +35,8 @@ struct AutoModeHandler { } #[async_trait] -impl SessionHandler for ModeHandler { - async fn on_exit_plan_mode( - &self, - session_id: SessionId, - data: ExitPlanModeData, - ) -> ExitPlanModeResult { +impl ExitPlanModeHandler for ModeHandler { + async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult { let _ = self.requests.send((session_id, data)); ExitPlanModeResult { approved: true, @@ -50,8 +47,8 @@ impl SessionHandler for ModeHandler { } #[async_trait] -impl SessionHandler for AutoModeHandler { - async fn on_auto_mode_switch( +impl AutoModeSwitchHandler for AutoModeHandler { + async fn handle( &self, session_id: SessionId, error_code: Option, @@ -78,7 +75,7 @@ async fn should_invoke_exit_plan_mode_handler_when_model_uses_tool() { .create_session( SessionConfig::default() .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(ModeHandler { + .with_exit_plan_mode_handler(Arc::new(ModeHandler { requests: request_tx, })) .approve_all_permissions(), @@ -198,7 +195,7 @@ async fn should_invoke_auto_mode_switch_handler_when_rate_limited() { .create_session( SessionConfig::default() .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(AutoModeHandler { + .with_auto_mode_switch_handler(Arc::new(AutoModeHandler { requests: request_tx, })) .approve_all_permissions(), diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index 7d1b61b30..836121644 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -1,13 +1,13 @@ use std::net::TcpListener; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::Duration; use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{ PermissionCompletedData, PermissionResult as EventPermissionResult, SessionEventType, }; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, Transport, @@ -34,13 +34,13 @@ async fn both_clients_see_tool_request_and_completion_events() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(selective_tools(vec![EchoTool::new( "magic_number", "seed", "MAGIC_", "_42", )])) - .with_tools([EchoTool::tool_definition("magic_number", "seed")]) .with_available_tools(["magic_number"]), ) .await @@ -49,7 +49,8 @@ async fn both_clients_see_tool_request_and_completion_events() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(Vec::new())), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(selective_tools(Vec::new())), ) .await .expect("resume session"); @@ -117,7 +118,7 @@ async fn one_client_approves_permission_and_both_see_the_result() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(permission_handler_with_counter( + .with_permission_handler(permission_handler_with_counter( PermissionResult::Approved, Arc::clone(&permission_requests), )), @@ -127,9 +128,9 @@ async fn one_client_approves_permission_and_both_see_the_result() { let client2 = start_external_client(ctx, port).await; let session2 = client2 .resume_session( - resume_config(session1.id().clone()) - .with_request_permission(false) - .with_handler(permission_handler(PermissionResult::NoResult)), + resume_config(session1.id().clone()).with_permission_handler( + permission_handler(PermissionResult::NoResult), + ), ) .await .expect("resume session"); @@ -206,16 +207,16 @@ async fn one_client_rejects_permission_and_both_see_the_result() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(permission_handler(PermissionResult::Denied)), + .with_permission_handler(permission_handler(PermissionResult::Denied)), ) .await .expect("create session"); let client2 = start_external_client(ctx, port).await; let session2 = client2 .resume_session( - resume_config(session1.id().clone()) - .with_request_permission(false) - .with_handler(permission_handler(PermissionResult::NoResult)), + resume_config(session1.id().clone()).with_permission_handler( + permission_handler(PermissionResult::NoResult), + ), ) .await .expect("resume session"); @@ -285,13 +286,12 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "city_lookup", "countryCode", "CITY_FOR_", "", )])) - .with_tools([EchoTool::tool_definition("city_lookup", "countryCode")]) .with_available_tools(["city_lookup", "currency_lookup"]), ) .await @@ -300,13 +300,12 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "currency_lookup", "countryCode", "CURRENCY_FOR_", "", )])) - .with_tools([EchoTool::tool_definition("currency_lookup", "countryCode")]) .with_available_tools(["city_lookup", "currency_lookup"]), ) .await @@ -353,13 +352,12 @@ async fn disconnecting_client_removes_its_tools() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "stable_tool", "input", "STABLE_", "", )])) - .with_tools([EchoTool::tool_definition("stable_tool", "input")]) .with_available_tools(["stable_tool", "ephemeral_tool"]), ) .await @@ -368,13 +366,12 @@ async fn disconnecting_client_removes_its_tools() { let _session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "ephemeral_tool", "input", "EPHEMERAL_", "", )])) - .with_tools([EchoTool::tool_definition("ephemeral_tool", "input")]) .with_available_tools(["stable_tool", "ephemeral_tool"]), ) .await @@ -424,27 +421,26 @@ async fn disconnecting_client_removes_its_tools() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(Vec::new())) - .with_disable_resume(true) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(selective_tools(Vec::new())) + .with_suppress_resume_event(true) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } @@ -454,8 +450,15 @@ fn free_tcp_port() -> u16 { listener.local_addr().expect("local addr").port() } -fn selective_handler(tools: Vec) -> Arc { - Arc::new(SelectiveToolHandler { tools }) +fn selective_tools(tools: Vec) -> Vec { + tools + .into_iter() + .map(|t| { + let name = t.name; + let argument_name = t.argument_name; + EchoTool::tool_definition(name, argument_name).with_handler(Arc::new(t)) + }) + .collect() } fn permission_handler(result: PermissionResult) -> Arc { @@ -500,8 +503,8 @@ struct PermissionDecisionHandler { } #[async_trait] -impl SessionHandler for PermissionDecisionHandler { - async fn on_permission_request( +impl PermissionHandler for PermissionDecisionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -514,32 +517,13 @@ impl SessionHandler for PermissionDecisionHandler { } } -struct SelectiveToolHandler { - tools: Vec, -} - #[async_trait] -impl SessionHandler for SelectiveToolHandler { - async fn on_permission_request( +impl ToolHandler for EchoTool { + async fn call( &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Approved - } - - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - if let Some(tool) = self - .tools - .iter() - .find(|tool| tool.name == invocation.tool_name) - { - return tool.call(invocation); - } - - tokio::time::sleep(Duration::from_secs(30)).await; - ToolResult::Text(format!("Ignoring unowned tool {}", invocation.tool_name)) + invocation: ToolInvocation, + ) -> Result { + Ok(EchoTool::call(self, invocation)) } } diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs index 218418ece..038209c2b 100644 --- a/rust/tests/e2e/multi_client_commands_elicitation.rs +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -5,7 +5,9 @@ use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{ CapabilitiesChangedData, CommandsChangedData, SessionEventType, }; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ + ApproveAllHandler, ElicitationHandler, PermissionHandler, PermissionResult, +}; use github_copilot_sdk::{ Client, CommandContext, CommandDefinition, CommandHandler, ElicitationRequest, ElicitationResult, RequestId, ResumeSessionConfig, SessionId, Transport, @@ -80,10 +82,7 @@ async fn capabilities_changed_fires_when_second_client_joins_with_elicitation_ha let port = free_tcp_port(); let server = start_tcp_server(ctx, port).await; let session1 = server - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); assert_ne!( @@ -105,7 +104,8 @@ async fn capabilities_changed_fires_when_second_client_joins_with_elicitation_ha let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(Arc::new(ElicitationApproveHandler)), + .with_permission_handler(Arc::new(ElicitationApproveHandler)) + .with_elicitation_handler(Arc::new(ElicitationApproveHandler)), ) .await .expect("resume session with elicitation handler"); @@ -142,10 +142,7 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { let port = free_tcp_port(); let server = start_tcp_server(ctx, port).await; let session1 = server - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); let client2 = start_external_client(ctx, port).await; @@ -162,7 +159,8 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { let _session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(Arc::new(ElicitationApproveHandler)), + .with_permission_handler(Arc::new(ElicitationApproveHandler)) + .with_elicitation_handler(Arc::new(ElicitationApproveHandler)), ) .await .expect("resume session with elicitation handler"); @@ -199,27 +197,25 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_disable_resume(true) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_suppress_resume_event(true) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } @@ -241,8 +237,8 @@ impl CommandHandler for NoopCommandHandler { struct ElicitationApproveHandler; #[async_trait] -impl SessionHandler for ElicitationApproveHandler { - async fn on_permission_request( +impl PermissionHandler for ElicitationApproveHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -250,8 +246,11 @@ impl SessionHandler for ElicitationApproveHandler { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_elicitation( +#[async_trait] +impl ElicitationHandler for ElicitationApproveHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs index 60f847416..0a782f980 100644 --- a/rust/tests/e2e/pending_work_resume.rs +++ b/rust/tests/e2e/pending_work_resume.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::generated::session_events::{ AssistantMessageData, ExternalToolRequestedData, SessionEventType, SessionResumeData, }; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, Error, RequestId, ResumeSessionConfig, SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, Transport, @@ -43,19 +43,18 @@ async fn should_continue_pending_external_tool_request_after_resume() { let suspended_client = start_external_client(ctx, port).await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (_release_tx, release_rx) = oneshot::channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(BlockingExternalTool { - started_tx, - release_rx: Mutex::new(Some(release_rx)), - })], - Arc::new(ApproveAllHandler), - ); + let router = Arc::new(BlockingExternalTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + }); let session1 = suspended_client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools([BlockingExternalTool::definition()]), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![ + BlockingExternalTool::definition().with_handler(router), + ]), ) .await .expect("create session"); @@ -228,7 +227,7 @@ async fn should_report_continuependingwork_true_in_resume_event() { .await .expect("resume session"); let resume_event = session2 - .get_messages() + .get_events() .await .expect("messages") .into_iter() @@ -263,26 +262,24 @@ async fn should_report_continuependingwork_true_in_resume_event() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } @@ -316,10 +313,6 @@ impl BlockingExternalTool { #[async_trait] impl ToolHandler for BlockingExternalTool { - fn tool(&self) -> Tool { - Self::definition() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let value = invocation .arguments diff --git a/rust/tests/e2e/per_session_auth.rs b/rust/tests/e2e/per_session_auth.rs index cf19181e2..24d379448 100644 --- a/rust/tests/e2e/per_session_auth.rs +++ b/rust/tests/e2e/per_session_auth.rs @@ -22,7 +22,8 @@ async fn session_uses_client_token_when_no_session_token_is_supplied() { let session = client .create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)), ) .await .expect("create session"); @@ -62,7 +63,7 @@ async fn session_token_overrides_client_token() { let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token("bob-token"), ) .await @@ -95,7 +96,8 @@ async fn session_auth_status_is_unauthenticated_without_token() { let client = ctx.start_client().await; let session = client .create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)), ) .await .expect("create session"); @@ -130,7 +132,7 @@ async fn session_fails_with_invalid_token() { let err = match client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token("invalid-token"), ) .await diff --git a/rust/tests/e2e/permissions.rs b/rust/tests/e2e/permissions.rs index 8d7834768..cecf04269 100644 --- a/rust/tests/e2e/permissions.rs +++ b/rust/tests/e2e/permissions.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::generated::api_types::PermissionsSetApproveAllRequest; use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::{ PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, }; @@ -76,7 +76,7 @@ async fn should_deny_permission_when_handler_returns_denied() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::Denied, ))), ) @@ -126,7 +126,7 @@ async fn should_deny_tool_operations_when_handler_explicitly_denies() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::UserNotAvailable, ))), ) @@ -166,7 +166,9 @@ async fn should_handle_async_permission_handler() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(AsyncPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(AsyncPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -216,7 +218,9 @@ async fn should_resume_session_with_permission_handler() { .resume_session( ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("resume session"); @@ -268,7 +272,7 @@ async fn should_deny_tool_operations_when_handler_explicitly_denies_after_resume .resume_session( ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::UserNotAvailable, ))), ) @@ -313,7 +317,9 @@ async fn should_receive_toolcallid_in_permission_requests() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -351,7 +357,7 @@ async fn should_deny_permission_with_noresult_kind() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(NotifyingPermissionHandler { + .with_permission_handler(Arc::new(NotifyingPermissionHandler { request_tx, result: PermissionResult::NoResult, })), @@ -386,7 +392,9 @@ async fn should_short_circuit_permission_handler_when_set_approve_all_enabled() .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -454,7 +462,7 @@ async fn should_wait_for_slow_permission_handler() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(SlowPermissionHandler { + .with_permission_handler(Arc::new(SlowPermissionHandler { entered_tx: tokio::sync::Mutex::new(Some(entered_tx)), release_rx: tokio::sync::Mutex::new(Some(release_rx)), })), @@ -486,7 +494,7 @@ async fn should_wait_for_slow_permission_handler() { release_tx.send(()).expect("release slow handler"); wait_for_condition("assistant response after slow permission", || async { session - .get_messages() + .get_events() .await .expect("get messages") .iter() @@ -521,7 +529,9 @@ async fn should_invoke_permission_handler_for_write_operations() { .create_session( github_copilot_sdk::SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -619,8 +629,8 @@ impl StaticPermissionHandler { } #[async_trait] -impl SessionHandler for StaticPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for StaticPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -635,8 +645,8 @@ struct RecordingPermissionHandler { } #[async_trait] -impl SessionHandler for RecordingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for RecordingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -653,8 +663,8 @@ struct NotifyingPermissionHandler { } #[async_trait] -impl SessionHandler for NotifyingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for NotifyingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -670,8 +680,8 @@ struct AsyncPermissionHandler { } #[async_trait] -impl SessionHandler for AsyncPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for AsyncPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -689,8 +699,8 @@ struct SlowPermissionHandler { } #[async_trait] -impl SessionHandler for SlowPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for SlowPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, diff --git a/rust/tests/e2e/rpc_event_side_effects.rs b/rust/tests/e2e/rpc_event_side_effects.rs index e68939f98..9e5e2f1a4 100644 --- a/rust/tests/e2e/rpc_event_side_effects.rs +++ b/rust/tests/e2e/rpc_event_side_effects.rs @@ -239,7 +239,7 @@ async fn should_emit_snapshot_rewind_event_and_remove_events_on_truncate() { .expect("assistant message"); assert!(assistant_message_content(&answer).contains("SNAPSHOT_REWIND_TARGET")); let user_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() @@ -268,10 +268,7 @@ async fn should_emit_snapshot_rewind_event_and_remove_events_on_truncate() { assert!(result.events_removed >= 1); rewind.await; - let remaining = session - .get_messages() - .await - .expect("messages after truncate"); + let remaining = session.get_events().await.expect("messages after truncate"); assert!(!remaining.iter().any(|event| event.id == target_event_id)); session.disconnect().await.expect("disconnect session"); @@ -301,7 +298,7 @@ async fn should_allow_session_use_after_truncate() { .await .expect("send"); let user_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() diff --git a/rust/tests/e2e/rpc_mcp_and_skills.rs b/rust/tests/e2e/rpc_mcp_and_skills.rs index 1d65a0416..932ac35a3 100644 --- a/rust/tests/e2e/rpc_mcp_and_skills.rs +++ b/rust/tests/e2e/rpc_mcp_and_skills.rs @@ -438,7 +438,7 @@ fn test_mcp_servers(message: &str) -> HashMap { HashMap::from([( message.to_string(), McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), command: echo_command(), args: echo_args(message), ..McpStdioServerConfig::default() diff --git a/rust/tests/e2e/rpc_session_state.rs b/rust/tests/e2e/rpc_session_state.rs index 5dee2c8a3..8b1378917 100644 --- a/rust/tests/e2e/rpc_session_state.rs +++ b/rust/tests/e2e/rpc_session_state.rs @@ -1,1002 +1 @@ -use github_copilot_sdk::generated::SessionMode; -use github_copilot_sdk::generated::api_types::{ - HistoryTruncateRequest, McpOauthLoginRequest, ModeSetRequest, ModelSwitchToRequest, - NameSetRequest, PermissionsSetApproveAllRequest, PlanUpdateRequest, SessionsForkRequest, - WorkspacesCreateFileRequest, WorkspacesReadFileRequest, -}; -use github_copilot_sdk::generated::session_events::{ - AssistantMessageData, SessionEventType, SessionTitleChangedData, - SessionWorkspaceFileChangedData, UserMessageData, WorkspaceFileChangedOperation, -}; -use super::support::{assistant_message_content, wait_for_event, with_e2e_context}; - -#[tokio::test] -async fn should_call_session_rpc_model_getcurrent() { - with_e2e_context( - "rpc_session_state", - "should_call_session_rpc_model_getcurrent", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5"), - ) - .await - .expect("create session"); - - let current = session - .rpc() - .model() - .get_current() - .await - .expect("get current model"); - assert_eq!(current.model_id.as_deref(), Some("claude-sonnet-4.5")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_session_rpc_model_switchto() { - with_e2e_context( - "rpc_session_state", - "should_call_session_rpc_model_switchto", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5"), - ) - .await - .expect("create session"); - - let result = session - .rpc() - .model() - .switch_to(ModelSwitchToRequest { - model_id: "gpt-4.1".to_string(), - reasoning_effort: Some("high".to_string()), - model_capabilities: None, - reasoning_summary: None, - }) - .await - .expect("switch model"); - assert_eq!(result.model_id.as_deref(), Some("gpt-4.1")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_get_and_set_session_mode() { - with_e2e_context( - "rpc_session_state", - "should_get_and_set_session_mode", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Interactive - ); - session - .rpc() - .mode() - .set(ModeSetRequest { - mode: SessionMode::Plan, - }) - .await - .expect("set plan"); - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Plan - ); - session - .rpc() - .mode() - .set(ModeSetRequest { - mode: SessionMode::Interactive, - }) - .await - .expect("set interactive"); - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Interactive - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_set_and_get_each_session_mode_value() { - with_e2e_context( - "rpc_session_state", - "should_set_and_get_each_session_mode_value", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for mode in [ - SessionMode::Interactive, - SessionMode::Plan, - SessionMode::Autopilot, - ] { - session - .rpc() - .mode() - .set(ModeSetRequest { mode: mode.clone() }) - .await - .expect("set mode"); - assert_eq!(session.rpc().mode().get().await.expect("get mode"), mode); - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_read_update_and_delete_plan() { - with_e2e_context( - "rpc_session_state", - "should_read_update_and_delete_plan", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let content = "# Test Plan\n\n- Step 1\n- Step 2"; - - let initial = session.rpc().plan().read().await.expect("read initial"); - assert!(!initial.exists); - assert!(initial.content.is_none()); - session - .rpc() - .plan() - .update(PlanUpdateRequest { - content: content.to_string(), - }) - .await - .expect("update plan"); - let updated = session.rpc().plan().read().await.expect("read updated"); - assert!(updated.exists); - assert_eq!(updated.content.as_deref(), Some(content)); - session.rpc().plan().delete().await.expect("delete plan"); - let deleted = session.rpc().plan().read().await.expect("read deleted"); - assert!(!deleted.exists); - assert!(deleted.content.is_none()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_workspace_file_rpc_methods() { - with_e2e_context( - "rpc_session_state", - "should_call_workspace_file_rpc_methods", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let initial = session - .rpc() - .workspaces() - .list_files() - .await - .expect("list files"); - assert!(initial.files.is_empty()); - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: "test.txt".to_string(), - content: "Hello, workspace!".to_string(), - }) - .await - .expect("create file"); - let listed = session - .rpc() - .workspaces() - .list_files() - .await - .expect("list files"); - assert!(listed.files.iter().any(|file| file == "test.txt")); - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: "test.txt".to_string(), - }) - .await - .expect("read file"); - assert_eq!(read.content, "Hello, workspace!"); - let workspace = session - .rpc() - .workspaces() - .get_workspace() - .await - .expect("get workspace"); - assert!(workspace.workspace.is_some()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_reject_workspace_file_path_traversal() { - with_e2e_context( - "rpc_session_state", - "should_reject_workspace_file_path_traversal", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let err = session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: "../escaped.txt".to_string(), - content: "outside".to_string(), - }) - .await - .expect_err("path traversal should fail"); - assert!(err.to_string().contains("workspace")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_create_workspace_file_with_nested_path_auto_creating_dirs() { - with_e2e_context( - "rpc_session_state", - "should_create_workspace_file_with_nested_path_auto_creating_dirs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let path = "nested-rust/subdir/file.txt"; - - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "nested content".to_string(), - }) - .await - .expect("create nested file"); - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: path.to_string(), - }) - .await - .expect("read nested file"); - assert_eq!(read.content, "nested content"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_error_reading_nonexistent_workspace_file() { - with_e2e_context( - "rpc_session_state", - "should_report_error_reading_nonexistent_workspace_file", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert!( - session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: "never-exists-rust.txt".to_string(), - }) - .await - .is_err() - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_update_existing_workspace_file_with_update_operation() { - with_e2e_context( - "rpc_session_state", - "should_update_existing_workspace_file_with_update_operation", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let path = "reused-rust.txt"; - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "v1".to_string(), - }) - .await - .expect("create file"); - - let update_event = - wait_for_event(session.subscribe(), "workspace update event", |event| { - if event.parsed_type() != SessionEventType::SessionWorkspaceFileChanged { - return false; - } - let data = event - .typed_data::() - .expect("workspace file changed data"); - data.path == path && data.operation == WorkspaceFileChangedOperation::Update - }); - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "v2".to_string(), - }) - .await - .expect("update file"); - update_event.await; - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: path.to_string(), - }) - .await - .expect("read updated"); - assert_eq!(read.content, "v2"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_reject_empty_or_whitespace_session_name() { - with_e2e_context( - "rpc_session_state", - "should_reject_empty_or_whitespace_session_name", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for name in ["", " ", "\t\n \r"] { - let err = session - .rpc() - .name() - .set(NameSetRequest { - name: name.to_string(), - }) - .await - .expect_err("empty name should fail"); - assert!(err.to_string().to_lowercase().contains("empty")); - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_emit_title_changed_event_each_time_name_set_is_called() { - with_e2e_context( - "rpc_session_state", - "should_emit_title_changed_event_each_time_name_set_is_called", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for title in ["Title-A-Rust", "Title-B-Rust"] { - let event = wait_for_event(session.subscribe(), "title changed", |event| { - if event.parsed_type() != SessionEventType::SessionTitleChanged { - return false; - } - event - .typed_data::() - .expect("title data") - .title - == title - }); - session - .rpc() - .name() - .set(NameSetRequest { - name: title.to_string(), - }) - .await - .expect("set name"); - event.await; - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_get_and_set_session_metadata() { - with_e2e_context( - "rpc_session_state", - "should_get_and_set_session_metadata", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .rpc() - .name() - .set(NameSetRequest { - name: "SDK test session".to_string(), - }) - .await - .expect("set name"); - assert_eq!( - session - .rpc() - .name() - .get() - .await - .expect("get name") - .name - .as_deref(), - Some("SDK test session") - ); - let sources = session - .rpc() - .instructions() - .get_sources() - .await - .expect("get instruction sources"); - assert!(sources.sources.is_empty() || !sources.sources.is_empty()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_session_usage_and_permission_rpcs() { - with_e2e_context( - "rpc_session_state", - "should_call_session_usage_and_permission_rpcs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let metrics = session.rpc().usage().get_metrics().await.expect("metrics"); - assert!(!metrics.session_start_time.is_empty()); - assert!( - session - .rpc() - .permissions() - .set_approve_all(PermissionsSetApproveAllRequest { - enabled: true, - source: None, - }) - .await - .expect("set approve all") - .success - ); - assert!( - session - .rpc() - .permissions() - .reset_session_approvals() - .await - .expect("reset approvals") - .success - ); - session - .rpc() - .permissions() - .set_approve_all(PermissionsSetApproveAllRequest { - enabled: false, - source: None, - }) - .await - .expect("disable approve all"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_fork_session_with_persisted_messages() { - with_e2e_context( - "rpc_session_state", - "should_fork_session_with_persisted_messages", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Say FORK_SOURCE_ALPHA exactly.") - .await - .expect("send source") - .expect("source answer"); - assert!(assistant_message_content(&answer).contains("FORK_SOURCE_ALPHA")); - let fork = client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: None, - }) - .await - .expect("fork session"); - assert_ne!(fork.session_id, *session.id()); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - let forked_messages = forked.get_messages().await.expect("forked messages"); - assert!(contains_user_message( - &forked_messages, - "Say FORK_SOURCE_ALPHA exactly." - )); - assert!(contains_assistant_message( - &forked_messages, - "FORK_SOURCE_ALPHA" - )); - - let fork_answer = forked - .send_and_wait("Now say FORK_CHILD_BETA exactly.") - .await - .expect("send fork") - .expect("fork answer"); - assert!(assistant_message_content(&fork_answer).contains("FORK_CHILD_BETA")); - let source_after = session.get_messages().await.expect("source messages"); - assert!(!contains_user_message( - &source_after, - "Now say FORK_CHILD_BETA exactly." - )); - - forked.disconnect().await.expect("disconnect fork"); - session.disconnect().await.expect("disconnect source"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_forking_session_without_persisted_events() { - with_e2e_context( - "rpc_session_state", - "should_handle_forking_session_without_persisted_events", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - match client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: None, - }) - .await - { - Ok(fork) => { - assert!(!fork.session_id.as_str().trim().is_empty()); - assert_ne!(fork.session_id, *session.id()); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - assert!( - !forked - .get_messages() - .await - .expect("forked messages") - .iter() - .any(|event| { - matches!( - event.parsed_type(), - SessionEventType::UserMessage - | SessionEventType::AssistantMessage - ) - }) - ); - forked.disconnect().await.expect("disconnect fork"); - } - Err(err) => { - let message = err.to_string(); - assert!( - message.contains("not found or has no persisted events"), - "unexpected sessions.fork error: {message}" - ); - assert!( - !message.contains("Unhandled method sessions.fork"), - "expected implemented error for sessions.fork, got {message}" - ); - } - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_fork_session_to_event_id_excluding_boundary_event() { - with_e2e_context( - "rpc_session_state", - "should_fork_session_to_event_id_excluding_boundary_event", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait("Say FORK_BOUNDARY_FIRST exactly.") - .await - .expect("send first"); - session - .send_and_wait("Say FORK_BOUNDARY_SECOND exactly.") - .await - .expect("send second"); - let source_events = session.get_messages().await.expect("messages"); - let boundary_id = source_events - .iter() - .find(|event| { - event.parsed_type() == SessionEventType::UserMessage - && event.typed_data::().is_some_and(|data| { - data.content == "Say FORK_BOUNDARY_SECOND exactly." - }) - }) - .expect("second user message") - .id - .clone(); - let fork = client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: Some(boundary_id.clone()), - }) - .await - .expect("fork to boundary"); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - let forked_events = forked.get_messages().await.expect("forked messages"); - assert!(contains_user_message( - &forked_events, - "Say FORK_BOUNDARY_FIRST exactly." - )); - assert!(!forked_events.iter().any(|event| event.id == boundary_id)); - assert!(!contains_user_message( - &forked_events, - "Say FORK_BOUNDARY_SECOND exactly." - )); - - forked.disconnect().await.expect("disconnect fork"); - session.disconnect().await.expect("disconnect source"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_error_when_forking_session_to_unknown_event_id() { - with_e2e_context( - "rpc_session_state", - "should_report_error_when_forking_session_to_unknown_event_id", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session - .send_and_wait("Say FORK_UNKNOWN_EVENT_OK exactly.") - .await - .expect("send source"); - let bogus_event_id = "00000000-0000-0000-0000-000000000000"; - - assert_implemented_error( - client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: Some(bogus_event_id.to_string()), - }) - .await, - "sessions.fork", - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_implemented_errors_for_unsupported_session_rpc_paths() { - with_e2e_context( - "rpc_session_state", - "should_report_implemented_errors_for_unsupported_session_rpc_paths", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_implemented_error( - session - .rpc() - .history() - .truncate(HistoryTruncateRequest { - event_id: "missing-event".to_string(), - }) - .await, - "session.history.truncate", - ); - assert_implemented_error( - session - .rpc() - .mcp() - .oauth() - .login(McpOauthLoginRequest { - server_name: "missing-server".to_string(), - callback_success_message: None, - client_name: None, - force_reauth: None, - }) - .await, - "session.mcp.oauth.login", - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_compact_session_history_after_messages() { - with_e2e_context( - "rpc_session_state", - "should_compact_session_history_after_messages", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - let compact = session - .rpc() - .history() - .compact() - .await - .expect("compact history"); - assert!(compact.success); - assert!(compact.messages_removed >= 0); - session.rpc().name().get().await.expect("name still works"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -fn contains_user_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { - events.iter().any(|event| { - event.parsed_type() == SessionEventType::UserMessage - && event - .typed_data::() - .is_some_and(|data| data.content == expected) - }) -} - -fn contains_assistant_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { - events.iter().any(|event| { - event.parsed_type() == SessionEventType::AssistantMessage - && event - .typed_data::() - .is_some_and(|data| data.content.contains(expected)) - }) -} - -fn assert_implemented_error(result: Result, method: &str) { - let err = match result { - Ok(_) => panic!("RPC should fail"), - Err(err) => err, - }; - let message = err.to_string(); - assert!( - !message.contains(&format!("Unhandled method {method}")), - "expected implemented error for {method}, got {message}" - ); -} diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs index 25aff47a9..ce07bf7f7 100644 --- a/rust/tests/e2e/session.rs +++ b/rust/tests/e2e/session.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::generated::session_events::{ SessionStartData, SessionWarningData, UserMessageData, }; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::types::LogLevel as SessionLogLevel; use github_copilot_sdk::{ Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange, @@ -37,7 +37,7 @@ async fn shouldcreateanddisconnectsessions() { .expect("create session"); assert_uuid_like(session.id()); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); assert!(!messages.is_empty(), "expected initial session events"); let start = messages[0] .typed_data::() @@ -46,7 +46,7 @@ async fn shouldcreateanddisconnectsessions() { session.disconnect().await.expect("disconnect session"); assert!( - session.get_messages().await.is_err(), + session.get_events().await.is_err(), "disconnected session should no longer serve message history" ); client.stop().await.expect("stop client"); @@ -329,15 +329,12 @@ async fn should_create_a_session_with_defaultagent_excludedtools() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(SecretTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![secret_tool()]) .with_default_agent(DefaultAgentConfig { excluded_tools: Some(vec!["secret_tool".to_string()]), }), @@ -364,17 +361,12 @@ async fn should_create_session_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(SecretNumberTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![secret_number_tool()]), ) .await .expect("create session"); @@ -405,7 +397,7 @@ async fn should_throw_error_when_resuming_non_existent_session() { let config = ResumeSessionConfig::new(github_copilot_sdk::SessionId::new( "non-existent-session-id", )) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN); assert!(client.resume_session(config).await.is_err()); @@ -447,7 +439,7 @@ async fn should_abort_a_session() { session.abort().await.expect("abort session"); idle.await.expect("idle task"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); assert!(messages .iter() .any(|event| event.parsed_type() == SessionEventType::Abort)); @@ -494,7 +486,9 @@ async fn should_resume_a_session_using_the_same_client() { let resumed = client .resume_session( ResumeSessionConfig::new(session_id.clone()) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_permission_handler(Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -551,14 +545,16 @@ async fn should_resume_a_session_using_a_new_client() { .resume_session( ResumeSessionConfig::new(session_id.clone()) .with_continue_pending_work(true) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_permission_handler(Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await .expect("resume session"); assert_eq!(resumed.id(), &session_id); - let messages = resumed.get_messages().await.expect("get messages"); + let messages = resumed.get_events().await.expect("get messages"); assert!( messages .iter() @@ -1389,7 +1385,7 @@ async fn should_send_with_mode_property() { .await; let user_message = session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() @@ -1485,7 +1481,7 @@ async fn should_resume_session_with_custom_provider() { let session_id = session.id().clone(); let mut config = ResumeSessionConfig::new(session_id.clone()) - .with_handler(Arc::new(ApproveAllHandler)); + .with_permission_handler(Arc::new(ApproveAllHandler)); config.provider = Some( ProviderConfig::new("https://api.openai.com/v1") .with_provider_type("openai") @@ -1507,7 +1503,7 @@ async fn latest_user_message( session: &github_copilot_sdk::session::Session, ) -> github_copilot_sdk::SessionEvent { session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() @@ -1520,10 +1516,6 @@ struct SecretNumberTool; #[async_trait::async_trait] impl ToolHandler for SecretNumberTool { - fn tool(&self) -> Tool { - secret_number_tool() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let key = invocation .arguments @@ -1538,22 +1530,23 @@ impl ToolHandler for SecretNumberTool { } } +fn secret_tool() -> Tool { + Tool::new("secret_tool") + .with_description("A secret tool hidden from the default agent") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(SecretTool)) +} + struct SecretTool; #[async_trait::async_trait] impl ToolHandler for SecretTool { - fn tool(&self) -> Tool { - Tool::new("secret_tool") - .with_description("A secret tool hidden from the default agent") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string" } - }, - "required": ["input"] - })) - } - async fn call(&self, _invocation: ToolInvocation) -> Result { Ok(ToolResult::Text("SECRET".to_string())) } @@ -1572,4 +1565,5 @@ fn secret_number_tool() -> Tool { }, "required": ["key"] })) + .with_handler(Arc::new(SecretNumberTool)) } diff --git a/rust/tests/e2e/session_config.rs b/rust/tests/e2e/session_config.rs index 05c818169..8b1378917 100644 --- a/rust/tests/e2e/session_config.rs +++ b/rust/tests/e2e/session_config.rs @@ -1,955 +1 @@ -use std::collections::HashMap; -use github_copilot_sdk::generated::api_types::{ - ModelCapabilitiesOverride, ModelCapabilitiesOverrideSupports, -}; -use github_copilot_sdk::generated::session_events::{SessionEventType, SessionStartData}; -use github_copilot_sdk::{ - Attachment, MessageOptions, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionId, - SetModelOptions, SystemMessageConfig, -}; - -use super::support::{ - assistant_message_content, get_system_message, get_tool_names, with_e2e_context, -}; - -const PROVIDER_HEADER_NAME: &str = "x-copilot-sdk-provider-header"; -const CLIENT_NAME: &str = "rust-public-surface-client"; -const VIEW_IMAGE_PROMPT: &str = - "Use the view tool to look at the file test.png and describe what you see"; -const PNG_1X1_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; - -#[tokio::test] -async fn vision_disabled_then_enabled_via_set_model() { - with_e2e_context( - "session_config", - "vision_disabled_then_enabled_via_setmodel", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("test.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write image"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5") - .with_model_capabilities(vision_capabilities(false)), - ) - .await - .expect("create session"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t1 = ctx.exchanges(); - assert!( - !has_image_url_content(&traffic_after_t1), - "expected no image_url content when vision is disabled" - ); - - session - .set_model( - "claude-sonnet-4.5", - Some( - SetModelOptions::default() - .with_model_capabilities(vision_capabilities(true)), - ), - ) - .await - .expect("set model"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t2 = ctx.exchanges(); - let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; - assert!(!new_exchanges.is_empty()); - assert!( - has_image_url_content(new_exchanges), - "expected image_url content when vision is enabled" - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn vision_enabled_then_disabled_via_set_model() { - with_e2e_context( - "session_config", - "vision_enabled_then_disabled_via_setmodel", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("test.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write image"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5") - .with_model_capabilities(vision_capabilities(true)), - ) - .await - .expect("create session"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t1 = ctx.exchanges(); - assert!( - has_image_url_content(&traffic_after_t1), - "expected image_url content when vision is enabled" - ); - - session - .set_model( - "claude-sonnet-4.5", - Some( - SetModelOptions::default() - .with_model_capabilities(vision_capabilities(false)), - ), - ) - .await - .expect("set model"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t2 = ctx.exchanges(); - let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; - assert!(!new_exchanges.is_empty()); - assert!( - !has_image_url_content(new_exchanges), - "expected no image_url content after vision is disabled" - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_custom_session_id() { - with_e2e_context("session_config", "should_use_custom_session_id", |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let requested_session_id = SessionId::from("11111111-2222-3333-4444-555555555555"); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_session_id(requested_session_id.clone()), - ) - .await - .expect("create session"); - - assert_eq!(session.id(), &requested_session_id); - let messages = session.get_messages().await.expect("messages"); - let start_event = messages - .iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!(data.session_id, requested_session_id); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }) - .await; -} - -#[tokio::test] -async fn should_apply_reasoning_effort_on_session_create() { - with_e2e_context( - "session_config", - "should_apply_reasoning_effort_on_session_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("custom-reasoning-model") - .with_provider(provider(ctx.proxy_url(), "create-reasoning")) - .with_reasoning_effort("high"), - ) - .await - .expect("create session"); - - let start_event = session - .get_messages() - .await - .expect("messages") - .into_iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!( - data.selected_model.as_deref(), - Some("custom-reasoning-model") - ); - assert_eq!(data.reasoning_effort.as_deref(), Some("high")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_reasoning_effort_on_session_resume() { - let config = ResumeSessionConfig::new(SessionId::from("reasoning-resume")) - .with_reasoning_effort("medium"); - - assert_eq!(config.reasoning_effort.as_deref(), Some("medium")); -} - -#[tokio::test] -async fn should_apply_all_reasoning_effort_values_on_session_create() { - with_e2e_context( - "session_config", - "should_apply_all_reasoning_effort_values_on_session_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - for effort in ["low", "medium", "high"] { - let session = client - .create_session( - approve_all_without_token() - .with_model("custom-reasoning-model") - .with_provider(provider( - ctx.proxy_url(), - &format!("reasoning-{effort}"), - )) - .with_reasoning_effort(effort), - ) - .await - .unwrap_or_else(|err| panic!("create session with effort {effort}: {err}")); - - let start_event = session - .get_messages() - .await - .expect("messages") - .into_iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!( - data.selected_model.as_deref(), - Some("custom-reasoning-model") - ); - assert_eq!(data.reasoning_effort.as_deref(), Some(effort)); - - session.disconnect().await.expect("disconnect session"); - } - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_clientname_in_useragent() { - with_e2e_context( - "session_config", - "should_forward_clientname_in_useragent", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_client_name(CLIENT_NAME), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "user-agent", CLIENT_NAME); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_custom_provider_headers_on_create() { - with_e2e_context( - "session_config", - "should_forward_custom_provider_headers_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("claude-sonnet-4.5") - .with_provider(provider(ctx.proxy_url(), "create-provider-header")), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 1+1?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('2')); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); - assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "create-provider-header"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_custom_provider_headers_on_resume() { - with_e2e_context( - "session_config", - "should_forward_custom_provider_headers_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_model_capabilities(vision_capabilities(false)) - .with_provider( - provider(ctx.proxy_url(), "resume-provider-header") - .with_model_id("claude-sonnet-4.5"), - ), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); - assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "resume-provider-header"); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_provider_wire_model() { - with_e2e_context( - "session_config", - "should_forward_provider_wire_model", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("claude-sonnet-4.5") - .with_provider( - ProviderConfig::new(ctx.proxy_url()) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_wire_model("test-wire-model") - .with_max_output_tokens(1024), - ), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(request_model(&exchange), Some("test-wire-model")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_provider_model_id_as_wire_model() { - with_e2e_context( - "session_config", - "should_use_provider_model_id_as_wire_model", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token().with_provider( - ProviderConfig::new(ctx.proxy_url()) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_model_id("claude-sonnet-4.5"), - ), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(request_model(&exchange), Some("claude-sonnet-4.5")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_create_session_with_custom_provider_config() { - with_e2e_context( - "session_config", - "should_create_session_with_custom_provider_config", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(approve_all_without_token().with_provider( - ProviderConfig::new("https://api.example.com/v1").with_api_key("test-key"), - )) - .await - .expect("create session"); - - assert!(!session.id().as_ref().is_empty()); - let _ = session.disconnect().await; - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_workingdirectory_for_tool_execution() { - with_e2e_context( - "session_config", - "should_use_workingdirectory_for_tool_execution", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let sub_dir = ctx.work_dir().join("subproject"); - std::fs::create_dir_all(&sub_dir).expect("create subproject"); - std::fs::write(sub_dir.join("marker.txt"), "I am in the subdirectory") - .expect("write marker"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(sub_dir), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Read the file marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("subdirectory")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_workingdirectory_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_workingdirectory_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let sub_dir = ctx.work_dir().join("resume-subproject"); - std::fs::create_dir_all(&sub_dir).expect("create resume subproject"); - std::fs::write( - sub_dir.join("resume-marker.txt"), - "I am in the resume working directory", - ) - .expect("write resume marker"); - - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_working_directory(sub_dir), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("Read the file resume-marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("resume working directory")); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_systemmessage_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_systemmessage_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let resume_instruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL."; - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_system_message( - SystemMessageConfig::new() - .with_mode("append") - .with_content(resume_instruction), - ), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("What is 1+1?") - .await - .expect("send") - .expect("assistant message"); - assert!( - assistant_message_content(&answer).contains("RESUME_SYSTEM_MESSAGE_SENTINEL") - ); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(resume_instruction)); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_instructiondirectories_on_create() { - with_e2e_context( - "session_config", - "should_apply_instructiondirectories_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let project_dir = ctx.work_dir().join("instruction-create-project"); - let instruction_dir = ctx.work_dir().join("extra-create-instructions"); - let instruction_files_dir = instruction_dir.join(".github").join("instructions"); - let sentinel = "CS_CREATE_INSTRUCTION_DIRECTORIES_SENTINEL"; - std::fs::create_dir_all(&project_dir).expect("create project dir"); - std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); - std::fs::write( - instruction_files_dir.join("extra.instructions.md"), - format!("Always include {sentinel}."), - ) - .expect("write instructions"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(project_dir) - .with_instruction_directories([instruction_dir]), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(sentinel)); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_instructiondirectories_on_resume() { - with_e2e_context( - "session_config", - "should_apply_instructiondirectories_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let project_dir = ctx.work_dir().join("instruction-resume-project"); - let instruction_dir = ctx.work_dir().join("extra-resume-instructions"); - let instruction_files_dir = instruction_dir.join(".github").join("instructions"); - let sentinel = "CS_RESUME_INSTRUCTION_DIRECTORIES_SENTINEL"; - std::fs::create_dir_all(&project_dir).expect("create project dir"); - std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); - std::fs::write( - instruction_files_dir.join("extra.instructions.md"), - format!("Always include {sentinel}."), - ) - .expect("write instructions"); - - let client = ctx.start_client().await; - let session1 = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(project_dir.clone()), - ) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_working_directory(project_dir) - .with_instruction_directories([instruction_dir]), - ) - .await - .expect("resume session"); - - session2.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(sentinel)); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_availabletools_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_availabletools_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_available_tools(["view"]), - ) - .await - .expect("resume session"); - - session2.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(get_tool_names(&exchange), vec!["view".to_string()]); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_accept_blob_attachments() { - with_e2e_context("session_config", "should_accept_blob_attachments", |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("pixel.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write pixel"); - - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - MessageOptions::new("What color is this pixel? Reply in one word.") - .with_attachments(vec![Attachment::Blob { - data: PNG_1X1_BASE64.to_string(), - mime_type: "image/png".to_string(), - display_name: Some("pixel.png".to_string()), - }]), - ) - .await - .expect("send"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }) - .await; -} - -#[tokio::test] -async fn should_accept_message_attachments() { - with_e2e_context( - "session_config", - "should_accept_message_attachments", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let attached_path = ctx.work_dir().join("attached.txt"); - std::fs::write(&attached_path, "This file is attached").expect("write attachment"); - - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - MessageOptions::new("Summarize the attached file").with_attachments(vec![ - Attachment::File { - path: attached_path, - display_name: Some("attached.txt".to_string()), - line_range: None, - }, - ]), - ) - .await - .expect("send"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -fn provider(proxy_url: &str, header_value: &str) -> ProviderConfig { - ProviderConfig::new(proxy_url) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_headers(HashMap::from([( - PROVIDER_HEADER_NAME.to_string(), - header_value.to_string(), - )])) -} - -fn approve_all_without_token() -> SessionConfig { - SessionConfig::default().with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) -} - -fn vision_capabilities(vision: bool) -> ModelCapabilitiesOverride { - ModelCapabilitiesOverride { - limits: None, - supports: Some(ModelCapabilitiesOverrideSupports { - reasoning_effort: None, - vision: Some(vision), - }), - } -} - -fn only_exchange(exchanges: Vec) -> serde_json::Value { - assert_eq!(exchanges.len(), 1, "expected exactly one exchange"); - exchanges.into_iter().next().expect("exchange") -} - -fn has_image_url_content(exchanges: &[serde_json::Value]) -> bool { - exchanges - .iter() - .filter_map(|exchange| exchange.get("request")) - .filter_map(|request| request.get("messages")) - .filter_map(serde_json::Value::as_array) - .flatten() - .filter(|message| { - message - .get("role") - .and_then(serde_json::Value::as_str) - .is_some_and(|role| role == "user") - }) - .filter_map(|message| message.get("content")) - .filter_map(serde_json::Value::as_array) - .flatten() - .any(|part| { - part.get("type") - .and_then(serde_json::Value::as_str) - .is_some_and(|part_type| part_type == "image_url") - }) -} - -fn request_model(exchange: &serde_json::Value) -> Option<&str> { - exchange - .get("request") - .and_then(|request| request.get("model")) - .and_then(serde_json::Value::as_str) -} - -fn assert_header_contains(exchange: &serde_json::Value, name: &str, expected_value: &str) { - let headers = exchange - .get("requestHeaders") - .and_then(serde_json::Value::as_object) - .expect("requestHeaders"); - let actual = headers - .iter() - .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then(|| header_value(value))) - .unwrap_or_else(|| panic!("missing header {name}; actual headers: {headers:?}")); - assert!( - actual.contains(expected_value), - "header {name} value {actual:?} did not contain {expected_value:?}" - ); -} - -fn header_value(value: &serde_json::Value) -> String { - match value { - serde_json::Value::String(value) => value.clone(), - serde_json::Value::Array(values) => values - .iter() - .map(header_value) - .collect::>() - .join(","), - other => other.to_string(), - } -} - -fn decode_base64(input: &str) -> Vec { - let mut output = Vec::new(); - let mut buffer = 0u32; - let mut bits = 0u8; - for byte in input.bytes().filter(|byte| !byte.is_ascii_whitespace()) { - let value = match byte { - b'A'..=b'Z' => byte - b'A', - b'a'..=b'z' => byte - b'a' + 26, - b'0'..=b'9' => byte - b'0' + 52, - b'+' => 62, - b'/' => 63, - b'=' => break, - _ => panic!("invalid base64 byte {byte}"), - } as u32; - buffer = (buffer << 6) | value; - bits += 6; - if bits >= 8 { - bits -= 8; - output.push(((buffer >> bits) & 0xff) as u8); - } - } - output -} diff --git a/rust/tests/e2e/session_fs.rs b/rust/tests/e2e/session_fs.rs index f069f6ffe..8b1378917 100644 --- a/rust/tests/e2e/session_fs.rs +++ b/rust/tests/e2e/session_fs.rs @@ -1,630 +1 @@ -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::generated::api_types::PlanUpdateRequest; -use github_copilot_sdk::{ - Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsConfig, - SessionFsConventions, SessionFsProvider, -}; - -use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; - -#[tokio::test] -async fn should_route_file_operations_through_the_session_fs_provider() { - with_e2e_context( - "session_fs", - "should_route_file_operations_through_the_session_fs_provider", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000101"; - let provider_root = ctx.work_dir().join("session-fs-route-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 100 + 200?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("300")); - let events_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("events.jsonl"); - wait_for_file_containing(&events_path, "300").await; - let content = std::fs::read_to_string(events_path).expect("read events"); - assert!(content.contains("300")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_load_session_data_from_fs_provider_on_resume() { - with_e2e_context( - "session_fs", - "should_load_session_data_from_fs_provider_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000102"; - let provider_root = ctx.work_dir().join("session-fs-resume-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session1 = client - .create_session( - session_config(ctx, provider.clone()).with_session_id(session_id), - ) - .await - .expect("create session"); - let session_id = session1.id().clone(); - let first = session1 - .send_and_wait("What is 50 + 50?") - .await - .expect("send first") - .expect("first answer"); - assert!(assistant_message_content(&first).contains("100")); - session1 - .disconnect() - .await - .expect("disconnect first session"); - - let session2 = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_session_fs_provider(provider), - ) - .await - .expect("resume session"); - let second = session2 - .send_and_wait("What is that times 3?") - .await - .expect("send second") - .expect("second answer"); - assert!(assistant_message_content(&second).contains("300")); - - session2 - .disconnect() - .await - .expect("disconnect resumed session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_map_all_sessionfs_handler_operations() { - let root = PathBuf::from("target").join("session-fs-handler-ops"); - if root.exists() { - std::fs::remove_dir_all(&root).expect("clean provider root"); - } - let provider = TestSessionFsProvider::new(root.clone(), "handler-session"); - - provider - .mkdir("/workspace/nested", true, None) - .await - .expect("mkdir"); - provider - .write_file("/workspace/nested/file.txt", "hello", None) - .await - .expect("write"); - provider - .append_file("/workspace/nested/file.txt", " world", None) - .await - .expect("append"); - assert!( - provider - .exists("/workspace/nested/file.txt") - .await - .expect("exists") - ); - let stat = provider - .stat("/workspace/nested/file.txt") - .await - .expect("stat"); - assert!(stat.is_file); - assert!(!stat.is_directory); - assert_eq!(stat.size, "hello world".len() as i64); - assert_eq!( - provider - .read_file("/workspace/nested/file.txt") - .await - .expect("read"), - "hello world" - ); - assert!( - provider - .readdir("/workspace/nested") - .await - .expect("readdir") - .iter() - .any(|entry| entry == "file.txt") - ); - assert!( - provider - .readdir_with_types("/workspace/nested") - .await - .expect("readdir types") - .iter() - .any(|entry| entry.name == "file.txt" && entry.kind == DirEntryKind::File) - ); - provider - .rename( - "/workspace/nested/file.txt", - "/workspace/nested/renamed.txt", - ) - .await - .expect("rename"); - assert!( - !provider - .exists("/workspace/nested/file.txt") - .await - .expect("old path missing") - ); - assert_eq!( - provider - .read_file("/workspace/nested/renamed.txt") - .await - .expect("read renamed"), - "hello world" - ); - provider - .rm("/workspace/nested/renamed.txt", false, false) - .await - .expect("remove"); - assert!( - !provider - .exists("/workspace/nested/renamed.txt") - .await - .expect("removed missing") - ); - provider - .rm("/workspace/nested/missing.txt", false, true) - .await - .expect("forced remove"); - assert!(matches!( - provider.stat("/workspace/nested/missing.txt").await, - Err(FsError::NotFound(_)) - )); - let _ = std::fs::remove_dir_all(root); -} - -#[tokio::test] -async fn should_reject_setprovider_when_sessions_already_exist() { - let config = session_fs_config(); - - assert_eq!(config.initial_cwd, "/"); - assert_eq!(config.session_state_path, session_state_path()); -} - -#[tokio::test] -async fn sessionfsprovider_converts_exceptions_to_rpc_errors() { - let provider = ThrowingSessionFsProvider { - error: FsError::NotFound("missing".to_string()), - }; - assert!(matches!( - provider.read_file("missing.txt").await, - Err(FsError::NotFound(message)) if message.contains("missing") - )); - assert!( - !provider - .exists("missing.txt") - .await - .expect("exists maps errors to false") - ); - assert!(matches!( - provider.write_file("missing.txt", "content", None).await, - Err(FsError::NotFound(message)) if message.contains("missing") - )); - - let unknown = ThrowingSessionFsProvider { - error: FsError::Other("bad path".to_string()), - }; - assert!(matches!( - unknown.write_file("bad.txt", "content", None).await, - Err(FsError::Other(message)) if message.contains("bad path") - )); -} - -#[tokio::test] -async fn should_persist_plan_md_via_sessionfs() { - with_e2e_context( - "session_fs", - "should_persist_plan_md_via_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000103"; - let provider_root = ctx.work_dir().join("session-fs-plan-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - session.send_and_wait("What is 2 + 3?").await.expect("send"); - session - .rpc() - .plan() - .update(PlanUpdateRequest { - content: "# Test Plan\n\nThis is a test.".to_string(), - }) - .await - .expect("update plan"); - let plan_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("plan.md"); - wait_for_file_containing(&plan_path, "This is a test.").await; - assert!( - std::fs::read_to_string(plan_path) - .expect("read plan") - .contains("This is a test.") - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_map_large_output_handling_into_sessionfs() { - let root = PathBuf::from("target").join("session-fs-large-output"); - if root.exists() { - std::fs::remove_dir_all(&root).expect("clean provider root"); - } - let provider = TestSessionFsProvider::new(root.clone(), "large-output-session"); - let content = "x".repeat(100_000); - - provider - .write_file("/session-state/temp/large.txt", &content, None) - .await - .expect("write large content"); - - assert_eq!( - provider - .read_file("/session-state/temp/large.txt") - .await - .expect("read large content"), - content - ); - let _ = std::fs::remove_dir_all(root); -} - -#[tokio::test] -async fn should_succeed_with_compaction_while_using_sessionfs() { - with_e2e_context( - "session_fs", - "should_succeed_with_compaction_while_using_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000104"; - let provider_root = ctx.work_dir().join("session-fs-compact-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - session.send_and_wait("What is 2+2?").await.expect("send"); - let result = session - .rpc() - .history() - .compact() - .await - .expect("compact history"); - assert!(result.success); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_write_workspace_metadata_via_sessionfs() { - with_e2e_context( - "session_fs", - "should_write_workspace_metadata_via_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000105"; - let provider_root = ctx.work_dir().join("session-fs-workspace-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 7 * 8?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("56")); - let workspace_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("workspace.yaml"); - wait_for_file_containing(&workspace_path, session.id().as_ref()).await; - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -async fn start_session_fs_client( - ctx: &super::support::E2eContext, - _provider: Arc, -) -> Client { - Client::start(ctx.client_options().with_session_fs(session_fs_config())) - .await - .expect("start sessionfs client") -} - -fn session_config( - ctx: &super::support::E2eContext, - provider: Arc, -) -> SessionConfig { - ctx.approve_all_session_config() - .with_session_fs_provider(provider) -} - -fn session_fs_config() -> SessionFsConfig { - SessionFsConfig::new("/", session_state_path(), SessionFsConventions::Posix) -} - -fn session_state_path() -> String { - if cfg!(windows) { - "/session-state".to_string() - } else { - std::env::temp_dir() - .join("copilot-rust-sessionfs-state") - .join("session-state") - .to_string_lossy() - .replace('\\', "/") - } -} - -fn provider_relative_path(path: &str) -> PathBuf { - PathBuf::from(path.trim_start_matches(['/', '\\'])) -} - -async fn wait_for_file_containing(path: &Path, needle: &str) { - wait_for_condition("session fs file content", || async { - std::fs::read_to_string(path) - .map(|content| content.contains(needle)) - .unwrap_or(false) - }) - .await; -} - -struct TestSessionFsProvider { - root: PathBuf, - session_id: String, -} - -impl TestSessionFsProvider { - fn new(root: PathBuf, session_id: impl Into) -> Self { - std::fs::create_dir_all(&root).expect("create provider root"); - Self { - root, - session_id: session_id.into(), - } - } - - fn resolve(&self, path: &str) -> Result { - let root = std::fs::canonicalize(&self.root).map_err(FsError::from)?; - let mut full = root.clone(); - if self.session_id.is_empty() - || self.session_id == "." - || self.session_id == ".." - || self.session_id.contains('/') - || self.session_id.contains('\\') - || self.session_id.contains(':') - { - return Err(FsError::Other(format!( - "invalid sessionfs session id: {}", - self.session_id - ))); - } - full.push(&self.session_id); - for segment in path - .trim_start_matches(['/', '\\']) - .split(['/', '\\']) - .filter(|segment| !segment.is_empty()) - { - if segment == "." || segment == ".." || segment.contains(':') { - return Err(FsError::Other(format!("invalid sessionfs path: {path}"))); - } - full.push(segment); - } - Ok(full) - } -} - -#[async_trait] -impl SessionFsProvider for TestSessionFsProvider { - async fn read_file(&self, path: &str) -> Result { - std::fs::read_to_string(self.resolve(path)?).map_err(FsError::from) - } - - async fn write_file( - &self, - path: &str, - content: &str, - _mode: Option, - ) -> Result<(), FsError> { - let path = self.resolve(path)?; - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - std::fs::write(path, content).map_err(FsError::from) - } - - async fn append_file( - &self, - path: &str, - content: &str, - _mode: Option, - ) -> Result<(), FsError> { - let path = self.resolve(path)?; - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - use std::io::Write; - let mut file = std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(path) - .map_err(FsError::from)?; - file.write_all(content.as_bytes()).map_err(FsError::from) - } - - async fn exists(&self, path: &str) -> Result { - Ok(self.resolve(path)?.exists()) - } - - async fn stat(&self, path: &str) -> Result { - let path = self.resolve(path)?; - let metadata = std::fs::metadata(path).map_err(FsError::from)?; - Ok(FileInfo::new( - metadata.is_file(), - metadata.is_dir(), - metadata.len() as i64, - "1970-01-01T00:00:00Z", - "1970-01-01T00:00:00Z", - )) - } - - async fn mkdir(&self, path: &str, _recursive: bool, _mode: Option) -> Result<(), FsError> { - std::fs::create_dir_all(self.resolve(path)?).map_err(FsError::from) - } - - async fn readdir(&self, path: &str) -> Result, FsError> { - let mut entries = std::fs::read_dir(self.resolve(path)?) - .map_err(FsError::from)? - .map(|entry| { - entry - .map_err(FsError::from) - .map(|entry| entry.file_name().to_string_lossy().into_owned()) - }) - .collect::, _>>()?; - entries.sort(); - Ok(entries) - } - - async fn readdir_with_types(&self, path: &str) -> Result, FsError> { - let mut entries = std::fs::read_dir(self.resolve(path)?) - .map_err(FsError::from)? - .map(|entry| { - let entry = entry.map_err(FsError::from)?; - let kind = if entry.file_type().map_err(FsError::from)?.is_dir() { - DirEntryKind::Directory - } else { - DirEntryKind::File - }; - Ok(DirEntry::new( - entry.file_name().to_string_lossy().into_owned(), - kind, - )) - }) - .collect::, FsError>>()?; - entries.sort_by(|left, right| left.name.cmp(&right.name)); - Ok(entries) - } - - async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { - let path = self.resolve(path)?; - if path.is_file() { - return std::fs::remove_file(path).map_err(FsError::from); - } - if path.is_dir() { - if recursive { - return std::fs::remove_dir_all(path).map_err(FsError::from); - } - return std::fs::remove_dir(path).map_err(FsError::from); - } - if force { - Ok(()) - } else { - Err(FsError::NotFound(format!("not found: {}", path.display()))) - } - } - - async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { - let src = self.resolve(src)?; - let dest = self.resolve(dest)?; - if let Some(parent) = dest.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - std::fs::rename(src, dest).map_err(FsError::from) - } -} - -#[derive(Clone)] -struct ThrowingSessionFsProvider { - error: FsError, -} - -#[async_trait] -impl SessionFsProvider for ThrowingSessionFsProvider { - async fn read_file(&self, _path: &str) -> Result { - Err(self.error.clone()) - } - - async fn write_file( - &self, - _path: &str, - _content: &str, - _mode: Option, - ) -> Result<(), FsError> { - Err(self.error.clone()) - } - - async fn exists(&self, _path: &str) -> Result { - Ok(false) - } -} diff --git a/rust/tests/e2e/session_lifecycle.rs b/rust/tests/e2e/session_lifecycle.rs index e3c1fcd44..59cec701f 100644 --- a/rust/tests/e2e/session_lifecycle.rs +++ b/rust/tests/e2e/session_lifecycle.rs @@ -120,7 +120,7 @@ async fn should_return_events_via_getmessages_after_conversation() { .await .expect("send"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); let types = event_types(&messages); assert!(types.contains(&"session.start")); assert!(types.contains(&"user.message")); diff --git a/rust/tests/e2e/streaming_fidelity.rs b/rust/tests/e2e/streaming_fidelity.rs index 72e0554ac..4e0f26ec4 100644 --- a/rust/tests/e2e/streaming_fidelity.rs +++ b/rust/tests/e2e/streaming_fidelity.rs @@ -131,7 +131,7 @@ async fn should_produce_deltas_after_session_resume() { .resume_session( ResumeSessionConfig::new(session_id) .with_streaming(true) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -188,7 +188,7 @@ async fn should_not_produce_deltas_after_session_resume_with_streaming_disabled( .resume_session( ResumeSessionConfig::new(session_id) .with_streaming(false) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -260,7 +260,7 @@ async fn should_emit_streaming_deltas_with_reasoning_effort_configured() { assert!(assistant.content.contains("255")); let start = session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index e08e3535a..29eec0aa4 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -120,17 +120,17 @@ impl E2eContext { #[expect(dead_code, reason = "used by follow-on E2E ports")] pub async fn start_tcp_client(&self, port: u16, token: &str) -> Client { - Client::start( - self.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(token), - ) + Client::start(self.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(token.to_string()), + })) .await .expect("start TCP E2E client") } pub fn approve_all_session_config(&self) -> SessionConfig { SessionConfig::default() - .with_handler(std::sync::Arc::new(ApproveAllHandler)) + .with_permission_handler(std::sync::Arc::new(ApproveAllHandler)) .with_github_token(DEFAULT_TEST_TOKEN) } @@ -411,7 +411,7 @@ pub async fn wait_for_final_assistant_message(session: &Session) -> SessionEvent #[allow(dead_code, reason = "used by follow-on E2E ports")] pub async fn last_assistant_message(session: &Session) -> SessionEvent { session - .get_messages() + .get_events() .await .expect("get session messages") .into_iter() diff --git a/rust/tests/e2e/suspend.rs b/rust/tests/e2e/suspend.rs index 5a9386147..8b1378917 100644 --- a/rust/tests/e2e/suspend.rs +++ b/rust/tests/e2e/suspend.rs @@ -1,88 +1 @@ -use std::sync::Arc; -use github_copilot_sdk::ResumeSessionConfig; - -use super::support::{DEFAULT_TEST_TOKEN, assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn should_suspend_idle_session_without_throwing() { - with_e2e_context( - "suspend", - "should_suspend_idle_session_without_throwing", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait("Reply with: SUSPEND_IDLE_OK") - .await - .expect("send"); - session.rpc().suspend().await.expect("suspend session"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_allow_resume_and_continue_conversation_after_suspend() { - with_e2e_context( - "suspend", - "should_allow_resume_and_continue_conversation_after_suspend", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - "Remember the magic word: SUSPENSE. Reply with: SUSPEND_TURN_ONE", - ) - .await - .expect("first send"); - let session_id = session.id().clone(); - session.rpc().suspend().await.expect("suspend session"); - session.disconnect().await.expect("disconnect first session"); - client.stop().await.expect("stop first client"); - - let second_client = ctx.start_client().await; - let resumed = second_client - .resume_session( - ResumeSessionConfig::new(session_id) - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume session"); - let answer = resumed - .send_and_wait( - "What was the magic word I asked you to remember? Reply with just the word.", - ) - .await - .expect("follow-up send") - .expect("assistant message"); - assert!(assistant_message_content(&answer) - .to_lowercase() - .contains("suspense")); - - resumed.disconnect().await.expect("disconnect resumed"); - second_client.stop().await.expect("stop second client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/system_message_transform.rs b/rust/tests/e2e/system_message_transform.rs index 10cc594ca..8b1378917 100644 --- a/rust/tests/e2e/system_message_transform.rs +++ b/rust/tests/e2e/system_message_transform.rs @@ -1,187 +1 @@ -use std::collections::HashMap; -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext}; -use github_copilot_sdk::{SectionOverride, SessionConfig, SystemMessageConfig}; -use tokio::sync::mpsc; - -use super::support::{DEFAULT_TEST_TOKEN, get_system_message, recv_with_timeout, with_e2e_context}; - -#[tokio::test] -async fn should_invoke_transform_callbacks_with_section_content() { - with_e2e_context( - "system_message_transform", - "should_invoke_transform_callbacks_with_section_content", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("test.txt"), "Hello transform!") - .expect("write test file"); - let (section_tx, mut section_rx) = mpsc::unbounded_channel(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity", "tone"], - suffix: None, - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of test.txt and tell me what it says") - .await - .expect("send"); - - let first = recv_with_timeout(&mut section_rx, "first transform").await; - let second = recv_with_timeout(&mut section_rx, "second transform").await; - assert!(first.1 > 0); - assert!(second.1 > 0); - let sections = [first.0, second.0]; - assert!(sections.contains(&"identity".to_string())); - assert!(sections.contains(&"tone".to_string())); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_transform_modifications_to_section_content() { - with_e2e_context( - "system_message_transform", - "should_apply_transform_modifications_to_section_content", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("hello.txt"), "Hello!") - .expect("write hello file"); - let (section_tx, _section_rx) = mpsc::unbounded_channel(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity"], - suffix: Some("\nAlways end your reply with TRANSFORM_MARKER"), - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of hello.txt") - .await - .expect("send"); - - let exchanges = ctx.exchanges(); - assert!(!exchanges.is_empty()); - assert!(get_system_message(&exchanges[0]).contains("TRANSFORM_MARKER")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_work_with_static_overrides_and_transforms_together() { - with_e2e_context( - "system_message_transform", - "should_work_with_static_overrides_and_transforms_together", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("combo.txt"), "Combo test!") - .expect("write combo file"); - let (section_tx, mut section_rx) = mpsc::unbounded_channel(); - let mut sections = HashMap::new(); - sections.insert( - "safety".to_string(), - SectionOverride { - action: Some("remove".to_string()), - content: None, - }, - ); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_system_message( - SystemMessageConfig::new() - .with_mode("customize") - .with_sections(sections), - ) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity"], - suffix: None, - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of combo.txt and tell me what it says") - .await - .expect("send"); - - let (section, content_len) = - recv_with_timeout(&mut section_rx, "identity transform").await; - assert_eq!(section, "identity"); - assert!(content_len > 0); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -struct RecordingTransform { - section_ids: Vec<&'static str>, - suffix: Option<&'static str>, - section_tx: mpsc::UnboundedSender<(String, usize)>, -} - -#[async_trait] -impl SystemMessageTransform for RecordingTransform { - fn section_ids(&self) -> Vec { - self.section_ids - .iter() - .map(|section| (*section).to_string()) - .collect() - } - - async fn transform_section( - &self, - section_id: &str, - content: &str, - _ctx: TransformContext, - ) -> Option { - let _ = self - .section_tx - .send((section_id.to_string(), content.len())); - Some(match self.suffix { - Some(suffix) => format!("{content}{suffix}"), - None => content.to_string(), - }) - } -} diff --git a/rust/tests/e2e/telemetry.rs b/rust/tests/e2e/telemetry.rs index 0685ac284..10111be52 100644 --- a/rust/tests/e2e/telemetry.rs +++ b/rust/tests/e2e/telemetry.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, Error, OtelExporterType, SessionConfig, TelemetryConfig, Tool, ToolInvocation, ToolResult, @@ -36,19 +36,22 @@ async fn should_export_file_telemetry_for_sdk_interactions() { )) .await .expect("start client"); - let router = ToolHandlerRouter::new( - vec![Box::new(EchoTelemetryTool { - name: tool_name.to_string(), - })], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let echo_tool = Tool::new(tool_name) + .with_description("Echoes a marker string for telemetry validation.") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + .with_handler(Arc::new(EchoTelemetryTool)); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![echo_tool]), ) .await .expect("create session"); @@ -136,24 +139,10 @@ async fn should_export_file_telemetry_for_sdk_interactions() { .await; } -struct EchoTelemetryTool { - name: String, -} +struct EchoTelemetryTool; #[async_trait] impl ToolHandler for EchoTelemetryTool { - fn tool(&self) -> Tool { - Tool::new(&self.name) - .with_description("Echoes a marker string for telemetry validation.") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { "type": "string" } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( invocation diff --git a/rust/tests/e2e/tool_results.rs b/rust/tests/e2e/tool_results.rs index 260e25993..4b731c286 100644 --- a/rust/tests/e2e/tool_results.rs +++ b/rust/tests/e2e/tool_results.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Error, SessionConfig, Tool, ToolInvocation, ToolResult, ToolResultExpanded, }; @@ -21,7 +21,7 @@ async fn should_handle_structured_toolresultobject_from_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, WeatherTool).await; + let session = create_tool_session(ctx, &client, weather_tool()).await; let answer = session .send_and_wait("What's the weather in Paris?") @@ -48,7 +48,7 @@ async fn should_handle_tool_result_with_failure_resulttype() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, CheckStatusTool).await; + let session = create_tool_session(ctx, &client, check_status_tool()).await; let answer = session .send_and_wait("Check the status of the service using check_status. If it fails, say 'service is down'.") @@ -76,7 +76,7 @@ async fn should_preserve_tooltelemetry_and_not_stringify_structured_results_for_ Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, AnalyzeCodeTool).await; + let session = create_tool_session(ctx, &client, analyze_code_tool()).await; let answer = session .send_and_wait("Analyze the file main.ts for issues.") @@ -124,7 +124,7 @@ async fn should_handle_tool_result_with_rejected_resulttype() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (call_tx, mut call_rx) = mpsc::unbounded_channel(); - let session = create_tool_session(ctx, &client, DeployTool { call_tx }).await; + let session = create_tool_session(ctx, &client, deploy_tool(call_tx)).await; let events = session.subscribe(); session @@ -161,7 +161,7 @@ async fn should_handle_tool_result_with_denied_resulttype() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (call_tx, mut call_rx) = mpsc::unbounded_channel(); - let session = create_tool_session(ctx, &client, AccessSecretTool { call_tx }).await; + let session = create_tool_session(ctx, &client, access_secret_tool(call_tx)).await; let events = session.subscribe(); session @@ -188,22 +188,18 @@ async fn should_handle_tool_result_with_denied_resulttype() { .await; } -async fn create_tool_session( +async fn create_tool_session( _ctx: &super::support::E2eContext, client: &github_copilot_sdk::Client, - tool: T, -) -> github_copilot_sdk::session::Session -where - T: ToolHandler + 'static, -{ - let router = ToolHandlerRouter::new(vec![Box::new(tool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + tool: Tool, +) -> github_copilot_sdk::session::Session { + let __perm = Arc::new(ApproveAllHandler); client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools), + .with_permission_handler(__perm) + .with_tools(vec![tool]), ) .await .expect("create session") @@ -227,19 +223,20 @@ fn expanded(text: impl Into, result_type: impl Into) -> ToolResu }) } +fn weather_tool() -> Tool { + string_tool( + "get_weather", + "Gets weather for a city", + "city", + "City name", + ) + .with_handler(Arc::new(WeatherTool)) +} + struct WeatherTool; #[async_trait::async_trait] impl ToolHandler for WeatherTool { - fn tool(&self) -> Tool { - string_tool( - "get_weather", - "Gets weather for a city", - "city", - "City name", - ) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let city = invocation .arguments @@ -253,14 +250,16 @@ impl ToolHandler for WeatherTool { } } +fn check_status_tool() -> Tool { + Tool::new("check_status") + .with_description("Checks the status of a service") + .with_handler(Arc::new(CheckStatusTool)) +} + struct CheckStatusTool; #[async_trait::async_trait] impl ToolHandler for CheckStatusTool { - fn tool(&self) -> Tool { - Tool::new("check_status").with_description("Checks the status of a service") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let mut result = match expanded("Service unavailable", "failure") { ToolResult::Expanded(result) => result, @@ -271,19 +270,20 @@ impl ToolHandler for CheckStatusTool { } } +fn analyze_code_tool() -> Tool { + string_tool( + "analyze_code", + "Analyzes code for issues", + "file", + "File to analyze", + ) + .with_handler(Arc::new(AnalyzeCodeTool)) +} + struct AnalyzeCodeTool; #[async_trait::async_trait] impl ToolHandler for AnalyzeCodeTool { - fn tool(&self) -> Tool { - string_tool( - "analyze_code", - "Analyzes code for issues", - "file", - "File to analyze", - ) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let file = invocation .arguments @@ -302,16 +302,18 @@ impl ToolHandler for AnalyzeCodeTool { } } +fn deploy_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("deploy_service") + .with_description("Deploys a service") + .with_handler(Arc::new(DeployTool { call_tx })) +} + struct DeployTool { call_tx: mpsc::UnboundedSender<()>, } #[async_trait::async_trait] impl ToolHandler for DeployTool { - fn tool(&self) -> Tool { - Tool::new("deploy_service").with_description("Deploys a service") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); Ok(expanded( @@ -321,16 +323,18 @@ impl ToolHandler for DeployTool { } } +fn access_secret_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("access_secret") + .with_description("Accesses a secret") + .with_handler(Arc::new(AccessSecretTool { call_tx })) +} + struct AccessSecretTool { call_tx: mpsc::UnboundedSender<()>, } #[async_trait::async_trait] impl ToolHandler for AccessSecretTool { - fn tool(&self) -> Tool { - Tool::new("access_secret").with_description("Accesses a secret") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); Ok(expanded( diff --git a/rust/tests/e2e/tools.rs b/rust/tests/e2e/tools.rs index 19cc40249..327058a4d 100644 --- a/rust/tests/e2e/tools.rs +++ b/rust/tests/e2e/tools.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use github_copilot_sdk::handler::{ApproveAllHandler, PermissionResult, SessionHandler}; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Error, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, @@ -47,16 +47,13 @@ async fn invokes_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(EncryptStringTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![encrypt_string_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -82,14 +79,13 @@ async fn handles_tool_calling_errors() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(ErrorTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![error_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -132,16 +128,13 @@ async fn can_receive_and_return_complex_types() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(DbQueryTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![db_query_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -174,14 +167,13 @@ async fn overrides_built_in_tool_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(CustomGrepTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![custom_grep_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -212,13 +204,13 @@ async fn skippermission_sent_in_tool_definition() { permission_tx, decision: PermissionResult::Denied, }); - let router = ToolHandlerRouter::new(vec![Box::new(SafeLookupTool)], handler); - let tools = router.tools(); + let __perm = handler; + let tools = vec![safe_lookup_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -262,13 +254,13 @@ async fn invokes_custom_tool_with_permission_handler() { permission_tx, decision: PermissionResult::Approved, }); - let router = ToolHandlerRouter::new(vec![Box::new(EncryptStringTool)], handler); - let tools = router.tools(); + let __perm = handler; + let tools = vec![encrypt_string_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -306,16 +298,13 @@ async fn denies_custom_tool_when_permission_denied() { permission_tx, decision: PermissionResult::Denied, }); - let router = ToolHandlerRouter::new( - vec![Box::new(TrackedEncryptStringTool { call_tx })], - handler, - ); - let tools = router.tools(); + let __perm = handler; + let tools = vec![tracked_encrypt_string_tool(call_tx)]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -351,21 +340,16 @@ async fn should_execute_multiple_custom_tools_in_parallel_single_turn() { let client = ctx.start_client().await; let (city_tx, mut city_rx) = mpsc::unbounded_channel(); let (country_tx, mut country_rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![ - Box::new(LookupCityTool { call_tx: city_tx }), - Box::new(LookupCountryTool { - call_tx: country_tx, - }), - ], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![ + lookup_city_tool(city_tx), + lookup_country_tool(country_tx), + ]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools), ) .await @@ -403,21 +387,13 @@ async fn should_respect_availabletools_and_excludedtools_combined() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (excluded_tx, mut excluded_rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![ - Box::new(AllowedTool), - Box::new(ExcludedTool { - call_tx: excluded_tx, - }), - ], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![allowed_tool(), excluded_tool(excluded_tx)]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) .with_tools(tools) .with_available_tools(["allowed_tool", "excluded_tool"]) .with_excluded_tools(["excluded_tool"]), @@ -450,23 +426,24 @@ async fn should_respect_availabletools_and_excludedtools_combined() { struct EncryptStringTool; +fn encrypt_string_tool() -> Tool { + Tool::new("encrypt_string") + .with_description("Encrypts a string") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "String to encrypt" + } + }, + "required": ["input"] + })) + .with_handler(Arc::new(EncryptStringTool)) +} + #[async_trait::async_trait] impl ToolHandler for EncryptStringTool { - fn tool(&self) -> Tool { - Tool::new("encrypt_string") - .with_description("Encrypts a string") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": "String to encrypt" - } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let input = invocation .arguments @@ -481,12 +458,24 @@ struct TrackedEncryptStringTool { call_tx: mpsc::UnboundedSender<()>, } +fn tracked_encrypt_string_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("encrypt_string") + .with_description("Encrypts a string") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "String to encrypt" + } + }, + "required": ["input"] + })) + .with_handler(Arc::new(TrackedEncryptStringTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for TrackedEncryptStringTool { - fn tool(&self) -> Tool { - EncryptStringTool.tool() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); EncryptStringTool.call(invocation).await @@ -495,12 +484,14 @@ impl ToolHandler for TrackedEncryptStringTool { struct ErrorTool; +fn error_tool() -> Tool { + Tool::new("get_user_location") + .with_description("Gets the user's location") + .with_handler(Arc::new(ErrorTool)) +} + #[async_trait::async_trait] impl ToolHandler for ErrorTool { - fn tool(&self) -> Tool { - Tool::new("get_user_location").with_description("Gets the user's location") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( "Failed to execute `get_user_location` tool with arguments: {} due to error: Error: Tool execution failed" @@ -511,21 +502,22 @@ impl ToolHandler for ErrorTool { struct CustomGrepTool; +fn custom_grep_tool() -> Tool { + Tool::new("grep") + .with_description("A custom grep implementation that overrides the built-in") + .with_overrides_built_in_tool(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "Search query" } + }, + "required": ["query"] + })) + .with_handler(Arc::new(CustomGrepTool)) +} + #[async_trait::async_trait] impl ToolHandler for CustomGrepTool { - fn tool(&self) -> Tool { - Tool::new("grep") - .with_description("A custom grep implementation that overrides the built-in") - .with_overrides_built_in_tool(true) - .with_parameters(json!({ - "type": "object", - "properties": { - "query": { "type": "string", "description": "Search query" } - }, - "required": ["query"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let query = invocation .arguments @@ -538,21 +530,22 @@ impl ToolHandler for CustomGrepTool { struct SafeLookupTool; +fn safe_lookup_tool() -> Tool { + Tool::new("safe_lookup") + .with_description("A tool that skips permission") + .with_skip_permission(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "id": { "type": "string", "description": "Lookup ID" } + }, + "required": ["id"] + })) + .with_handler(Arc::new(SafeLookupTool)) +} + #[async_trait::async_trait] impl ToolHandler for SafeLookupTool { - fn tool(&self) -> Tool { - Tool::new("safe_lookup") - .with_description("A tool that skips permission") - .with_skip_permission(true) - .with_parameters(json!({ - "type": "object", - "properties": { - "id": { "type": "string", "description": "Lookup ID" } - }, - "required": ["id"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let id = invocation .arguments @@ -567,20 +560,21 @@ struct LookupCityTool { call_tx: mpsc::UnboundedSender, } +fn lookup_city_tool(call_tx: mpsc::UnboundedSender) -> Tool { + Tool::new("lookup_city") + .with_description("Looks up city information") + .with_parameters(json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "City name" } + }, + "required": ["city"] + })) + .with_handler(Arc::new(LookupCityTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for LookupCityTool { - fn tool(&self) -> Tool { - Tool::new("lookup_city") - .with_description("Looks up city information") - .with_parameters(json!({ - "type": "object", - "properties": { - "city": { "type": "string", "description": "City name" } - }, - "required": ["city"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let city = invocation .arguments @@ -597,20 +591,21 @@ struct LookupCountryTool { call_tx: mpsc::UnboundedSender, } +fn lookup_country_tool(call_tx: mpsc::UnboundedSender) -> Tool { + Tool::new("lookup_country") + .with_description("Looks up country information") + .with_parameters(json!({ + "type": "object", + "properties": { + "country": { "type": "string", "description": "Country name" } + }, + "required": ["country"] + })) + .with_handler(Arc::new(LookupCountryTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for LookupCountryTool { - fn tool(&self) -> Tool { - Tool::new("lookup_country") - .with_description("Looks up country information") - .with_parameters(json!({ - "type": "object", - "properties": { - "country": { "type": "string", "description": "Country name" } - }, - "required": ["country"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let country = invocation .arguments @@ -628,20 +623,21 @@ impl ToolHandler for LookupCountryTool { struct AllowedTool; +fn allowed_tool() -> Tool { + Tool::new("allowed_tool") + .with_description("An allowed tool") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(AllowedTool)) +} + #[async_trait::async_trait] impl ToolHandler for AllowedTool { - fn tool(&self) -> Tool { - Tool::new("allowed_tool") - .with_description("An allowed tool") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string", "description": "Input value" } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let input = invocation .arguments @@ -659,20 +655,21 @@ struct ExcludedTool { call_tx: mpsc::UnboundedSender<()>, } +fn excluded_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("excluded_tool") + .with_description("A tool that should be excluded") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(ExcludedTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for ExcludedTool { - fn tool(&self) -> Tool { - Tool::new("excluded_tool") - .with_description("A tool that should be excluded") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string", "description": "Input value" } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); let input = invocation @@ -693,8 +690,8 @@ struct RecordingPermissionHandler { } #[async_trait::async_trait] -impl SessionHandler for RecordingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for RecordingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -707,31 +704,32 @@ impl SessionHandler for RecordingPermissionHandler { struct DbQueryTool; -#[async_trait::async_trait] -impl ToolHandler for DbQueryTool { - fn tool(&self) -> Tool { - Tool::new("db_query") - .with_description("Performs a database query") - .with_parameters(json!({ - "type": "object", - "properties": { - "query": { - "type": "object", - "properties": { - "table": { "type": "string" }, - "ids": { - "type": "array", - "items": { "type": "integer" } - }, - "sortAscending": { "type": "boolean" } +fn db_query_tool() -> Tool { + Tool::new("db_query") + .with_description("Performs a database query") + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "table": { "type": "string" }, + "ids": { + "type": "array", + "items": { "type": "integer" } }, - "required": ["table", "ids", "sortAscending"] - } - }, - "required": ["query"] - })) - } + "sortAscending": { "type": "boolean" } + }, + "required": ["table", "ids", "sortAscending"] + } + }, + "required": ["query"] + })) + .with_handler(Arc::new(DbQueryTool)) +} +#[async_trait::async_trait] +impl ToolHandler for DbQueryTool { async fn call(&self, invocation: ToolInvocation) -> Result { let query = invocation.arguments.get("query").expect("query argument"); assert_eq!( diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index b9c28d30d..ffdb894eb 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -6,30 +6,23 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use async_trait::async_trait; -use github_copilot_sdk::Client; use github_copilot_sdk::handler::{ - ApproveAllHandler, AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, - PermissionResult, SessionHandler, UserInputResponse, + ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, + ExitPlanModeHandler, ExitPlanModeResult, PermissionHandler, PermissionResult, UserInputHandler, + UserInputResponse, }; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ExitPlanModeData, - MessageOptions, SessionConfig, SessionId, ToolResult, + CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, + ElicitationResult, ExitPlanModeData, MessageOptions, PermissionRequestData, RequestId, + SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, }; +use github_copilot_sdk::{Client, tool}; use serde_json::Value; use tokio::io::{AsyncWrite, AsyncWriteExt, duplex}; -use tokio::sync::mpsc; use tokio::time::timeout; const TIMEOUT: Duration = Duration::from_secs(2); -struct NoopHandler; -#[async_trait] -impl SessionHandler for NoopHandler { - async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse { - HandlerResponse::Ok - } -} - async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) { let header = format!("Content-Length: {}\r\n\r\n", body.len()); writer.write_all(header.as_bytes()).await.unwrap(); @@ -126,16 +119,32 @@ impl FakeServer { } } -async fn create_session_pair( - handler: Arc, -) -> (github_copilot_sdk::session::Session, FakeServer) { - create_session_pair_with_capabilities(handler, serde_json::json!(null)).await +async fn create_session_pair() -> (github_copilot_sdk::session::Session, FakeServer) { + create_session_pair_with_config(|cfg| cfg).await } async fn create_session_pair_with_capabilities( - handler: Arc, capabilities: Value, ) -> (github_copilot_sdk::session::Session, FakeServer) { + create_session_pair_inner(|cfg| cfg, capabilities).await +} + +async fn create_session_pair_with_config( + configure: F, +) -> (github_copilot_sdk::session::Session, FakeServer) +where + F: FnOnce(SessionConfig) -> SessionConfig + Send + 'static, +{ + create_session_pair_inner(configure, serde_json::json!(null)).await +} + +async fn create_session_pair_inner( + configure: F, + capabilities: Value, +) -> (github_copilot_sdk::session::Session, FakeServer) +where + F: FnOnce(SessionConfig) -> SessionConfig + Send + 'static, +{ let (client, server_read, server_write) = make_client(); let mut server = FakeServer { @@ -146,10 +155,9 @@ async fn create_session_pair_with_capabilities( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session(SessionConfig::default().with_handler(handler)) + .create_session(configure(SessionConfig::default())) .await .unwrap() } @@ -184,7 +192,7 @@ fn requested_session_id(request: &Value) -> &str { #[tokio::test] async fn session_subscribe_yields_events_observe_only() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let mut events = session.subscribe(); let count = Arc::new(AtomicUsize::new(0)); @@ -216,7 +224,7 @@ async fn session_subscribe_yields_events_observe_only() { #[tokio::test] async fn session_subscribe_drop_stops_delivery() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let mut events = session.subscribe(); let count = Arc::new(AtomicUsize::new(0)); @@ -257,7 +265,7 @@ async fn create_session_sends_correct_rpc() { .create_session({ let mut cfg = SessionConfig::default(); cfg.model = Some("gpt-4".to_string()); - cfg.with_handler(Arc::new(NoopHandler)) + cfg }) .await .unwrap() @@ -284,7 +292,7 @@ async fn create_session_sends_correct_rpc() { #[tokio::test] async fn send_injects_session_id() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -310,7 +318,7 @@ async fn send_injects_session_id() { async fn send_serializes_request_headers() { use std::collections::HashMap; - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -343,7 +351,7 @@ async fn send_serializes_request_headers() { async fn send_omits_request_headers_when_unset_or_empty() { use std::collections::HashMap; - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -379,7 +387,7 @@ async fn send_omits_request_headers_when_unset_or_empty() { #[tokio::test] async fn session_rpc_methods_send_correct_method_names() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let cases: Vec<(&str, Option<&str>)> = vec![ @@ -394,7 +402,7 @@ async fn session_rpc_methods_send_correct_method_names() { match expected_method { "session.abort" => s.abort().await.map(|_| ()), "session.log" => s.log("test msg", None).await, - "session.destroy" => s.destroy().await, + "session.destroy" => s.disconnect().await, _ => unreachable!(), } }); @@ -544,7 +552,7 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { args: vec!["server.js".to_string()], env: HashMap::new(), cwd: None, - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), timeout: None, }); let json = serde_json::to_value(&stdio).unwrap(); @@ -566,6 +574,109 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { assert_eq!(cfg_json["github"]["type"], "stdio"); } +#[test] +fn mcp_stdio_tools_tri_state_serializes_correctly() { + use github_copilot_sdk::McpStdioServerConfig; + + // None → field omitted (= "expose all tools") + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: None, + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert!( + json.get("tools").is_none(), + "tools=None must be omitted on the wire; got {json}" + ); + + // Some(empty) → field present as [] + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: Some(vec![]), + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["tools"], serde_json::json!([])); + + // Some(non-empty) → field present as the explicit list + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: Some(vec!["a".into(), "b".into()]), + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["tools"], serde_json::json!(["a", "b"])); +} + +#[test] +fn mcp_stdio_tools_tri_state_deserializes_correctly() { + use github_copilot_sdk::McpStdioServerConfig; + + // Missing field → None + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo" })).unwrap(); + assert_eq!(cfg.tools, None); + + // Empty list → Some(empty) + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo", "tools": [] })).unwrap(); + assert_eq!(cfg.tools, Some(vec![])); + + // Non-empty list → Some(list) + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo", "tools": ["x"] })).unwrap(); + assert_eq!(cfg.tools, Some(vec!["x".to_string()])); +} + +#[test] +fn mcp_http_tools_tri_state_serializes_correctly() { + use github_copilot_sdk::McpHttpServerConfig; + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: None, + ..Default::default() + }; + assert!( + serde_json::to_value(&cfg).unwrap().get("tools").is_none(), + "tools=None must be omitted on the wire" + ); + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: Some(vec![]), + ..Default::default() + }; + assert_eq!( + serde_json::to_value(&cfg).unwrap()["tools"], + serde_json::json!([]) + ); + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: Some(vec!["a".into()]), + ..Default::default() + }; + assert_eq!( + serde_json::to_value(&cfg).unwrap()["tools"], + serde_json::json!(["a"]) + ); +} + +#[test] +fn mcp_http_tools_tri_state_deserializes_correctly() { + use github_copilot_sdk::McpHttpServerConfig; + + let cfg: McpHttpServerConfig = + serde_json::from_value(serde_json::json!({ "url": "https://e.com" })).unwrap(); + assert_eq!(cfg.tools, None); + + let cfg: McpHttpServerConfig = + serde_json::from_value(serde_json::json!({ "url": "https://e.com", "tools": [] })).unwrap(); + assert_eq!(cfg.tools, Some(vec![])); +} + #[test] fn permission_request_data_extracts_typed_kind() { use github_copilot_sdk::{PermissionRequestData, PermissionRequestKind}; @@ -599,35 +710,15 @@ async fn force_stop_is_idempotent_with_no_child() { // Stream-based clients have no child process. force_stop should be a // no-op and safe to call multiple times. let (client, _server_read, _server_write) = make_client(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Connected - ); client.force_stop(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); client.force_stop(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); assert!(client.pid().is_none()); } #[tokio::test] -async fn stop_transitions_state_to_disconnected() { +async fn stop_is_safe_to_call() { let (client, _server_read, _server_write) = make_client(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Connected - ); client.stop().await.expect("stop should succeed"); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); } #[tokio::test] @@ -961,12 +1052,12 @@ async fn list_models_returns_typed_model_info() { #[tokio::test] async fn get_messages_returns_typed_events() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ let session = session.clone(); - async move { session.get_messages().await.unwrap() } + async move { session.get_events().await.unwrap() } }); let request = server.read_request().await; @@ -990,9 +1081,40 @@ async fn get_messages_returns_typed_events() { assert_eq!(events[0].event_type, "user.message"); } +#[tokio::test] +#[allow(deprecated)] +async fn deprecated_get_messages_alias_still_works() { + let (session, mut server) = create_session_pair().await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_messages().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.getMessages"); + server + .respond( + &request, + serde_json::json!({ + "events": [{ + "id": "e1", + "timestamp": "2025-01-01T00:00:00Z", + "type": "user.message", + "data": { "text": "hi" }, + }] + }), + ) + .await; + + let events = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(events.len(), 1); +} + #[tokio::test] async fn set_model_sends_switch_to_request() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1015,11 +1137,9 @@ async fn set_model_sends_switch_to_request() { #[tokio::test] async fn elicitation_returns_typed_result() { - let (session, mut server) = create_session_pair_with_capabilities( - Arc::new(NoopHandler), - serde_json::json!({ "ui": { "elicitation": true } }), - ) - .await; + let (session, mut server) = + create_session_pair_with_capabilities(serde_json::json!({ "ui": { "elicitation": true } })) + .await; let session = Arc::new(session); let schema = serde_json::json!({ "type": "object", @@ -1058,57 +1178,24 @@ async fn elicitation_returns_typed_result() { assert_eq!(result.content.unwrap()["name"], "Octocat"); } -#[tokio::test] -async fn tool_call_dispatches_to_handler() { - struct ToolHandler; - #[async_trait] - impl SessionHandler for ToolHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ExternalTool { invocation } => { - assert_eq!(invocation.tool_name, "read_file"); - HandlerResponse::ToolResult(ToolResult::Text("file contents here".to_string())) - } - _ => HandlerResponse::Ok, - } - } - } - - let (_session, mut server) = create_session_pair(Arc::new(ToolHandler)).await; - server - .send_request( - 100, - "tool.call", - serde_json::json!({ - "sessionId": server.session_id, - "toolCallId": "tc-1", - "toolName": "read_file", - "arguments": { "path": "/foo.txt" }, - }), - ) - .await; - - let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); - assert_eq!(response["id"], 100); - assert_eq!(response["result"]["result"], "file contents here"); -} - #[tokio::test] async fn permission_request_dispatches_to_handler() { struct DenyHandler; #[async_trait] - impl SessionHandler for DenyHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Denied) - } - _ => HandlerResponse::Ok, - } + impl PermissionHandler for DenyHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied } } - let (_session, mut server) = create_session_pair(Arc::new(DenyHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_permission_handler(Arc::new(DenyHandler))) + .await; server .send_request( 200, @@ -1130,22 +1217,25 @@ async fn permission_request_dispatches_to_handler() { async fn user_input_request_dispatches_to_handler() { struct InputHandler; #[async_trait] - impl SessionHandler for InputHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { question, .. } => { - assert_eq!(question, "Pick a color"); - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "blue".to_string(), - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for InputHandler { + async fn handle( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + assert_eq!(question, "Pick a color"); + Some(UserInputResponse { + answer: "blue".to_string(), + was_freeform: true, + }) } } - let (_session, mut server) = create_session_pair(Arc::new(InputHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_user_input_handler(Arc::new(InputHandler))) + .await; server .send_request( 300, @@ -1169,8 +1259,8 @@ async fn user_input_request_dispatches_to_handler() { async fn exit_plan_mode_request_dispatches_to_handler() { struct ExitHandler; #[async_trait] - impl SessionHandler for ExitHandler { - async fn on_exit_plan_mode( + impl ExitPlanModeHandler for ExitHandler { + async fn handle( &self, _session_id: SessionId, data: ExitPlanModeData, @@ -1190,7 +1280,10 @@ async fn exit_plan_mode_request_dispatches_to_handler() { } } - let (_session, mut server) = create_session_pair(Arc::new(ExitHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_exit_plan_mode_handler(Arc::new(ExitHandler)) + }) + .await; server .send_request( 310, @@ -1216,8 +1309,8 @@ async fn exit_plan_mode_request_dispatches_to_handler() { async fn auto_mode_switch_request_dispatches_to_handler() { struct AutoModeHandler; #[async_trait] - impl SessionHandler for AutoModeHandler { - async fn on_auto_mode_switch( + impl AutoModeSwitchHandler for AutoModeHandler { + async fn handle( &self, _session_id: SessionId, error_code: Option, @@ -1229,7 +1322,10 @@ async fn auto_mode_switch_request_dispatches_to_handler() { } } - let (_session, mut server) = create_session_pair(Arc::new(AutoModeHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_auto_mode_switch_handler(Arc::new(AutoModeHandler)) + }) + .await; server .send_request( 311, @@ -1249,7 +1345,10 @@ async fn auto_mode_switch_request_dispatches_to_handler() { #[tokio::test] async fn default_exit_plan_mode_response_omits_optional_fields() { - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_request( 312, @@ -1282,16 +1381,19 @@ async fn user_input_requested_notification_does_not_double_dispatch() { invocations: Arc, } #[async_trait] - impl SessionHandler for CountingHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::UserInput { .. } = event { - self.invocations.fetch_add(1, Ordering::SeqCst); - return HandlerResponse::UserInput(Some(UserInputResponse { - answer: "ok".to_string(), - was_freeform: true, - })); - } - HandlerResponse::Ok + impl UserInputHandler for CountingHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + self.invocations.fetch_add(1, Ordering::SeqCst); + Some(UserInputResponse { + answer: "ok".to_string(), + was_freeform: true, + }) } } @@ -1299,7 +1401,8 @@ async fn user_input_requested_notification_does_not_double_dispatch() { let handler = Arc::new(CountingHandler { invocations: invocations.clone(), }); - let (_session, mut server) = create_session_pair(handler).await; + let (_session, mut server) = + create_session_pair_with_config(move |cfg| cfg.with_user_input_handler(handler)).await; server .send_event( @@ -1346,7 +1449,10 @@ async fn user_input_requested_notification_does_not_double_dispatch() { #[tokio::test] async fn approve_all_handler_approves_permission() { - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_request( @@ -1365,61 +1471,28 @@ async fn approve_all_handler_approves_permission() { #[tokio::test] async fn session_event_notification_reaches_handler() { - let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); - - struct EventCollector { - tx: mpsc::UnboundedSender, - } - #[async_trait] - impl SessionHandler for EventCollector { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::SessionEvent { event, .. } = event { - self.tx.send(event.event_type).unwrap(); - } - HandlerResponse::Ok - } - } - - let (_session, mut server) = - create_session_pair(Arc::new(EventCollector { tx: event_tx })).await; + let (session, mut server) = create_session_pair().await; + let mut sub = session.subscribe(); server .send_event("session.idle", serde_json::json!({})) .await; - let event_type = timeout(TIMEOUT, event_rx.recv()).await.unwrap().unwrap(); - assert_eq!(event_type, "session.idle"); + let event = timeout(TIMEOUT, sub.recv()).await.unwrap().unwrap(); + assert_eq!(event.event_type, "session.idle"); } #[tokio::test] async fn router_routes_to_correct_session() { let (client, mut server_read, mut server_write) = make_client(); - let (tx1, mut rx1) = mpsc::unbounded_channel::(); - let (tx2, mut rx2) = mpsc::unbounded_channel::(); - - struct Collector { - tx: mpsc::UnboundedSender, - } - #[async_trait] - impl SessionHandler for Collector { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::SessionEvent { event, .. } = event { - self.tx.send(event.event_type).unwrap(); - } - HandlerResponse::Ok - } - } - // Create two sessions on the same client let mut sessions = Vec::new(); let mut session_ids = Vec::new(); - for tx in [tx1, tx2] { + for _ in 0..2 { let h = tokio::spawn({ let client = client.clone(); async move { client - .create_session( - SessionConfig::default().with_handler(Arc::new(Collector { tx })), - ) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -1436,7 +1509,10 @@ async fn router_routes_to_correct_session() { sessions.push(timeout(TIMEOUT, h).await.unwrap().unwrap()); } - // Event for s-two should only reach rx2 + let mut sub1 = sessions[0].subscribe(); + let mut sub2 = sessions[1].subscribe(); + + // Event for s-two should only reach sub2 let notif = serde_json::json!({ "jsonrpc": "2.0", "method": "session.event", @@ -1447,12 +1523,20 @@ async fn router_routes_to_correct_session() { }); write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; assert_eq!( - timeout(TIMEOUT, rx2.recv()).await.unwrap().unwrap(), + timeout(TIMEOUT, sub2.recv()) + .await + .unwrap() + .unwrap() + .event_type, "assistant.message" ); - assert!(rx1.try_recv().is_err()); + assert!( + timeout(Duration::from_millis(100), sub1.recv()) + .await + .is_err() + ); - // Event for s-one should only reach rx1 + // Event for s-one should only reach sub1 let notif = serde_json::json!({ "jsonrpc": "2.0", "method": "session.event", @@ -1463,15 +1547,23 @@ async fn router_routes_to_correct_session() { }); write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; assert_eq!( - timeout(TIMEOUT, rx1.recv()).await.unwrap().unwrap(), + timeout(TIMEOUT, sub1.recv()) + .await + .unwrap() + .unwrap() + .event_type, "session.idle" ); - assert!(rx2.try_recv().is_err()); + assert!( + timeout(Duration::from_millis(100), sub2.recv()) + .await + .is_err() + ); } #[tokio::test] async fn send_and_wait_returns_last_assistant_message_on_idle() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1507,7 +1599,7 @@ async fn send_and_wait_returns_last_assistant_message_on_idle() { #[tokio::test] async fn send_and_wait_returns_error_on_session_error() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1542,7 +1634,7 @@ async fn send_and_wait_returns_error_on_session_error() { #[tokio::test] async fn send_and_wait_times_out() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1578,7 +1670,7 @@ async fn send_and_wait_times_out() { /// Closes RFD-400 review finding #2. #[tokio::test] async fn send_and_wait_outer_cancellation_clears_waiter() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); // First call: wrap in outer timeout much shorter than the inner @@ -1635,7 +1727,7 @@ async fn send_and_wait_outer_cancellation_clears_waiter() { /// Closes RFD-400 review finding #2. #[tokio::test] async fn send_and_wait_drop_clears_waiter() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); // Start a send_and_wait, let it install the waiter, then abort the @@ -1694,25 +1786,25 @@ async fn send_and_wait_drop_clears_waiter() { async fn stop_event_loop_completes_in_flight_handler() { struct SlowHandler; #[async_trait] - impl SessionHandler for SlowHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { .. } => { - // Sleep so stop_event_loop has a chance to fire while - // the handler is mid-flight. The loop must wait for - // this to return rather than abort it. - tokio::time::sleep(Duration::from_millis(150)).await; - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "completed".to_string(), - was_freeform: false, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for SlowHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + tokio::time::sleep(Duration::from_millis(150)).await; + Some(UserInputResponse { + answer: "completed".to_string(), + was_freeform: false, + }) } } - let (session, mut server) = create_session_pair(Arc::new(SlowHandler)).await; + let (session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_user_input_handler(Arc::new(SlowHandler))) + .await; let session = Arc::new(session); server @@ -1772,26 +1864,28 @@ async fn drop_session_does_not_abort_handler() { completed: Arc, } #[async_trait] - impl SessionHandler for CompletionHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { .. } => { - tokio::time::sleep(Duration::from_millis(100)).await; - self.completed.store(true, Ordering::SeqCst); - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "done".to_string(), - was_freeform: false, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for CompletionHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + tokio::time::sleep(Duration::from_millis(100)).await; + self.completed.store(true, Ordering::SeqCst); + Some(UserInputResponse { + answer: "done".to_string(), + was_freeform: false, + }) } } - let (session, mut server) = create_session_pair(Arc::new(CompletionHandler { + let handler = Arc::new(CompletionHandler { completed: handler_completed.clone(), - })) - .await; + }); + let (session, mut server) = + create_session_pair_with_config(move |cfg| cfg.with_user_input_handler(handler)).await; server .send_request( @@ -1826,8 +1920,10 @@ async fn drop_session_does_not_abort_handler() { /// session itself. #[tokio::test] async fn cancellation_token_fires_on_session_drop() { - let handler = Arc::new(ApproveAllHandler); - let (session, _server) = create_session_pair(handler).await; + let (session, _server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; let token = session.cancellation_token(); assert!(!token.is_cancelled()); @@ -1847,8 +1943,10 @@ async fn cancellation_token_fires_on_session_drop() { /// logic from the session's own lifecycle. #[tokio::test] async fn cancellation_token_child_cancel_does_not_kill_session() { - let handler = Arc::new(ApproveAllHandler); - let (session, _server) = create_session_pair(handler).await; + let (session, _server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; let child = session.cancellation_token(); child.cancel(); @@ -1865,22 +1963,25 @@ async fn elicitation_requested_dispatches_to_handler_and_responds() { struct ElicitHandler; #[async_trait] - impl SessionHandler for ElicitHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ElicitationRequest { request, .. } => { - assert_eq!(request.message, "Enter your name"); - HandlerResponse::Elicitation(ElicitationResult { - action: "accept".to_string(), - content: Some(serde_json::json!({ "name": "Alice" })), - }) - } - _ => HandlerResponse::Ok, + impl ElicitationHandler for ElicitHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + request: ElicitationRequest, + ) -> ElicitationResult { + assert_eq!(request.message, "Enter your name"); + ElicitationResult { + action: "accept".to_string(), + content: Some(serde_json::json!({ "name": "Alice" })), } } } - let (_session, mut server) = create_session_pair(Arc::new(ElicitHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_elicitation_handler(Arc::new(ElicitHandler)) + }) + .await; // CLI broadcasts elicitation.requested as a session event notification server @@ -1911,17 +2012,23 @@ async fn elicitation_requested_dispatches_to_handler_and_responds() { async fn elicitation_requested_cancels_on_handler_error() { struct FailHandler; #[async_trait] - impl SessionHandler for FailHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - // Return Ok instead of Elicitation — SDK should treat as cancel - HandlerEvent::ElicitationRequest { .. } => HandlerResponse::Ok, - _ => HandlerResponse::Ok, + impl ElicitationHandler for FailHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, } } } - let (_session, mut server) = create_session_pair(Arc::new(FailHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_elicitation_handler(Arc::new(FailHandler))) + .await; server .send_event( "elicitation.requested", @@ -1939,23 +2046,29 @@ async fn elicitation_requested_cancels_on_handler_error() { #[tokio::test] async fn external_tool_requested_dispatches_to_handler_and_responds() { - struct ExternalToolHandler; + struct RunTestsTool; #[async_trait] - impl SessionHandler for ExternalToolHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ExternalTool { invocation } => { - assert_eq!(invocation.tool_name, "run_tests"); - assert_eq!(invocation.tool_call_id, "tc-ext-1"); - assert_eq!(invocation.arguments["suite"], "unit"); - HandlerResponse::ToolResult(ToolResult::Text("all tests passed".to_string())) - } - _ => HandlerResponse::Ok, - } + impl tool::ToolHandler for RunTestsTool { + async fn call( + &self, + invocation: ToolInvocation, + ) -> Result { + assert_eq!(invocation.tool_name, "run_tests"); + assert_eq!(invocation.tool_call_id, "tc-ext-1"); + assert_eq!(invocation.arguments["suite"], "unit"); + Ok(ToolResult::Text("all tests passed".to_string())) } } - let (_session, mut server) = create_session_pair(Arc::new(ExternalToolHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_tools(vec![ + Tool::new("run_tests") + .with_description("Run tests") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(Arc::new(RunTestsTool)), + ]) + }) + .await; server .send_event( @@ -1976,6 +2089,132 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { assert_eq!(rpc_call["params"]["result"], "all tests passed"); } +#[tokio::test] +async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { + // Phase H multi-client safety: a handler that doesn't claim the + // requested tool name must not send an RPC response — another client + // on the same CLI may have a real handler. + struct FooTool; + #[async_trait] + impl tool::ToolHandler for FooTool { + async fn call( + &self, + _invocation: ToolInvocation, + ) -> Result { + Ok(ToolResult::Text("foo".to_string())) + } + } + + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_tools(vec![ + Tool::new("foo") + .with_description("foo") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(Arc::new(FooTool)), + ]) + }) + .await; + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-unknown", + "sessionId": server.session_id, + "toolCallId": "tc-x", + "toolName": "bar", + "arguments": {}, + }), + ) + .await; + + // The dispatcher must NOT respond. Read with a short timeout and + // assert the read times out. + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC response for unknown tool, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn permission_broadcast_with_resolved_by_hook_is_not_responded_to() { + // Phase H: when the runtime marks a permission request as already + // resolved by a hook, the client must not respond again. + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; + server + .send_event( + "permission.requested", + serde_json::json!({ + "requestId": "req-hooked", + "sessionId": server.session_id, + "resolvedByHook": true, + "permissionRequest": { "kind": "shell" }, + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when resolvedByHook=true, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn permission_broadcast_with_no_claiming_handler_is_not_responded_to() { + // Phase H: a handler that doesn't claim permission dispatch must not + // respond — the SDK lets other connected clients handle the request. + let (_session, mut server) = create_session_pair().await; + server + .send_event( + "permission.requested", + serde_json::json!({ + "requestId": "req-pending", + "sessionId": server.session_id, + "permissionRequest": { "kind": "shell" }, + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when handler doesn't claim permission dispatch, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn elicitation_broadcast_with_no_claiming_handler_is_not_responded_to() { + // Phase H: same gating for elicitation. The default handler doesn't + // claim elicitation, so broadcasts are silently dropped. + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-silent", + "message": "should not be answered", + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when handler doesn't claim elicitation, got: {:?}", + res.ok() + ); +} + #[tokio::test] async fn capabilities_captured_from_create_response() { let (client, mut server_read, mut server_write) = make_client(); @@ -1984,7 +2223,7 @@ async fn capabilities_captured_from_create_response() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2012,7 +2251,7 @@ async fn capabilities_captured_from_create_response() { #[tokio::test] async fn capabilities_changed_event_updates_session() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; // Initially no capabilities (create_session_pair doesn't send them) assert!(session.capabilities().ui.is_none()); @@ -2051,7 +2290,9 @@ async fn request_elicitation_sent_in_create_params() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session( + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), + ) .await .unwrap() } @@ -2059,9 +2300,44 @@ async fn request_elicitation_sent_in_create_params() { let request = read_framed(&mut server_read).await; assert_eq!(request["method"], "session.create"); - assert_eq!(request["params"]["requestElicitation"], true); - assert_eq!(request["params"]["requestExitPlanMode"], true); - assert_eq!(request["params"]["requestAutoModeSwitch"], true); + // ApproveAllHandler claims permission dispatch only; no other handlers + // are installed, so the wire flags reflect that exact responsibility. + assert_eq!(request["params"]["requestPermission"], true); + assert_eq!(request["params"]["requestElicitation"], false); + assert_eq!(request["params"]["requestExitPlanMode"], false); + assert_eq!(request["params"]["requestAutoModeSwitch"], false); + + let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": session_id }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn noop_handler_sends_request_permission_false() { + // Phase H1a wire-flag derivation: a handler that doesn't claim + // permission dispatch must send requestPermission=false so the + // runtime doesn't broadcast permission events to this client. + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default()) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["params"]["requestPermission"], false); + assert_eq!(request["params"]["requestElicitation"], false); let id = request["id"].as_u64().unwrap(); let session_id = requested_session_id(&request); @@ -2084,7 +2360,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2108,8 +2384,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { let client = client.clone(); let session_id = session_id.clone(); async move { - let cfg = ResumeSessionConfig::new(SessionId::from(session_id)) - .with_handler(Arc::new(NoopHandler)); + let cfg = ResumeSessionConfig::new(SessionId::from(session_id)); client.resume_session(cfg).await.unwrap() } }); @@ -2138,7 +2413,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { #[tokio::test] async fn elicitation_methods_fail_without_capability() { - let (session, _server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, _server) = create_session_pair().await; // Session created without capabilities — elicitation should fail let err = session @@ -2163,7 +2438,6 @@ async fn elicitation_methods_fail_without_capability() { } async fn create_session_pair_with_hooks( - handler: Arc, hooks: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -2176,14 +2450,9 @@ async fn create_session_pair_with_hooks( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_hooks(hooks), - ) + .create_session(SessionConfig::default().with_hooks(hooks)) .await .unwrap() } @@ -2233,8 +2502,7 @@ async fn hooks_invoke_dispatches_to_session_hooks() { } } - let (_session, mut server) = - create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(PolicyHooks)).await; + let (_session, mut server) = create_session_pair_with_hooks(Arc::new(PolicyHooks)).await; // Send a hooks.invoke request for a denied tool server @@ -2272,8 +2540,7 @@ async fn hooks_invoke_returns_empty_for_unregistered_hook() { #[async_trait] impl SessionHooks for EmptyHooks {} - let (_session, mut server) = - create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(EmptyHooks)).await; + let (_session, mut server) = create_session_pair_with_hooks(Arc::new(EmptyHooks)).await; server .send_request( @@ -2297,8 +2564,7 @@ async fn hooks_invoke_returns_empty_for_unregistered_hook() { assert_eq!(response["result"]["output"], serde_json::json!({})); } -async fn create_session_pair_with_transforms( - handler: Arc, +async fn create_session_pair_with_system_message_transforms( transforms: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -2311,14 +2577,9 @@ async fn create_session_pair_with_transforms( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_transform(transforms), - ) + .create_session(SessionConfig::default().with_system_message_transform(transforms)) .await .unwrap() } @@ -2365,7 +2626,7 @@ async fn system_message_transform_dispatches_to_transform() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(AppendTransform)).await; + create_session_pair_with_system_message_transforms(Arc::new(AppendTransform)).await; server .send_request( @@ -2410,7 +2671,7 @@ async fn system_message_transform_returns_error_for_missing_sections() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(DummyTransform)).await; + create_session_pair_with_system_message_transforms(Arc::new(DummyTransform)).await; // Send request with no sections parameter server @@ -2430,7 +2691,7 @@ async fn system_message_transform_returns_error_for_missing_sections() { #[tokio::test] async fn rpc_namespace_session_agent_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let s = session.clone(); @@ -2449,7 +2710,7 @@ async fn rpc_namespace_session_agent_list_dispatches_correctly() { #[tokio::test] async fn rpc_namespace_session_tasks_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let s = session.clone(); @@ -2468,7 +2729,7 @@ async fn rpc_namespace_session_tasks_list_dispatches_correctly() { #[tokio::test] async fn rpc_namespace_client_models_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let client = session.client().clone(); @@ -2501,7 +2762,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2521,7 +2782,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2563,7 +2824,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { async fn client_stop_aggregates_session_destroy_errors() { // session.destroy fails on the wire — Client::stop returns // StopErrors carrying the failure rather than short-circuiting. - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let client = session.client().clone(); let stop_handle = tokio::spawn(async move { client.stop().await }); @@ -2689,7 +2950,6 @@ impl CommandHandler for CountingCommandHandler { } async fn create_session_pair_with_commands( - handler: Arc, commands: Vec, ) -> (github_copilot_sdk::session::Session, FakeServer, Value) { let (client, server_read, server_write) = make_client(); @@ -2702,14 +2962,9 @@ async fn create_session_pair_with_commands( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_commands(commands), - ) + .create_session(SessionConfig::default().with_commands(commands)) .await .unwrap() } @@ -2753,8 +3008,7 @@ async fn create_serializes_commands_strips_handler() { ), ]; - let (_session, _server, create_req) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (_session, _server, create_req) = create_session_pair_with_commands(commands).await; let wire = create_req["params"]["commands"] .as_array() @@ -2793,8 +3047,7 @@ async fn command_execute_dispatches_to_registered_handler_and_acks_success() { }), )]; - let (session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (session, mut server, _) = create_session_pair_with_commands(commands).await; server .send_event( @@ -2839,8 +3092,7 @@ async fn command_execute_dispatches_to_registered_handler_and_acks_success() { #[tokio::test] async fn command_execute_unknown_command_acks_with_error() { - let (session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), vec![]).await; + let (session, mut server, _) = create_session_pair_with_commands(vec![]).await; server .send_event( @@ -2878,8 +3130,7 @@ async fn command_execute_handler_error_propagates_to_ack() { }), )]; - let (_session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (_session, mut server, _) = create_session_pair_with_commands(commands).await; server .send_event( @@ -3040,7 +3291,6 @@ impl SessionFsSqliteProvider for RecordingFsProvider { } async fn create_session_pair_with_fs_provider( - handler: Arc, provider: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -3053,14 +3303,9 @@ async fn create_session_pair_with_fs_provider( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_session_fs_provider(provider), - ) + .create_session(SessionConfig::default().with_session_fs_provider(provider)) .await .unwrap() } @@ -3086,8 +3331,7 @@ async fn create_session_pair_with_fs_provider( #[tokio::test] async fn session_fs_dispatches_read_file_to_provider() { let provider = Arc::new(RecordingFsProvider::new().with_file("/foo.txt", "hello world")); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3106,8 +3350,7 @@ async fn session_fs_dispatches_read_file_to_provider() { #[tokio::test] async fn session_fs_maps_not_found_to_enoent() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3134,8 +3377,7 @@ async fn session_fs_maps_other_to_unknown() { } } - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + let (_session, mut server) = create_session_pair_with_fs_provider(Arc::new(AlwaysFails)).await; server .send_request( @@ -3159,8 +3401,7 @@ async fn session_fs_maps_other_to_unknown() { #[tokio::test] async fn session_fs_dispatches_sqlite_query_to_provider() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3191,8 +3432,7 @@ async fn session_fs_dispatches_sqlite_query_to_provider() { #[tokio::test] async fn session_fs_dispatches_sqlite_exists_to_provider() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3232,8 +3472,7 @@ async fn session_fs_maps_sqlite_errors_to_results() { } } - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + let (_session, mut server) = create_session_pair_with_fs_provider(Arc::new(AlwaysFails)).await; server .send_request( @@ -3275,8 +3514,7 @@ async fn session_fs_maps_sqlite_errors_to_results() { #[tokio::test] async fn session_fs_dispatches_write_file_with_mode() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider.clone()).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider.clone()).await; server .send_request( @@ -3295,8 +3533,7 @@ async fn session_fs_dispatches_write_file_with_mode() { #[tokio::test] async fn session_fs_dispatches_readdir_with_types() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3318,8 +3555,7 @@ async fn session_fs_dispatches_readdir_with_types() { #[tokio::test] async fn session_fs_dispatches_rm_with_force() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3413,7 +3649,7 @@ async fn on_get_trace_context_called_on_session_create() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3452,8 +3688,7 @@ async fn on_get_trace_context_called_on_session_resume() { let resume_handle = tokio::spawn({ let client = client.clone(); async move { - let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")) - .with_handler(Arc::new(NoopHandler)); + let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")); client.resume_session(cfg).await.unwrap() } }); @@ -3498,7 +3733,7 @@ async fn on_get_trace_context_called_on_session_send() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3551,7 +3786,7 @@ async fn message_options_trace_context_overrides_callback() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3600,7 +3835,7 @@ async fn message_options_trace_context_overrides_callback() { #[tokio::test] async fn message_options_trace_context_used_without_callback() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let send_handle = tokio::spawn({ @@ -3628,33 +3863,42 @@ async fn message_options_trace_context_used_without_callback() { #[tokio::test] async fn tool_invocation_carries_trace_context_from_event() { - use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, SessionHandler}; - - struct CapturingHandler { - captured: parking_lot::Mutex, Option)>>, - signal: tokio::sync::Notify, + type CapturedTrace = Arc, Option)>>>; + struct CapturingTool { + captured: CapturedTrace, + signal: Arc, } #[async_trait] - impl SessionHandler for CapturingHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::ExternalTool { invocation } = event { - *self.captured.lock() = Some(( - invocation.traceparent.clone(), - invocation.tracestate.clone(), - )); - self.signal.notify_one(); - return HandlerResponse::ToolResult(ToolResult::Text("ok".into())); - } - HandlerResponse::Ok + impl tool::ToolHandler for CapturingTool { + async fn call( + &self, + invocation: ToolInvocation, + ) -> Result { + *self.captured.lock() = Some(( + invocation.traceparent.clone(), + invocation.tracestate.clone(), + )); + self.signal.notify_one(); + Ok(ToolResult::Text("ok".into())) } } - let handler = Arc::new(CapturingHandler { - captured: parking_lot::Mutex::new(None), - signal: tokio::sync::Notify::new(), - }); - let (_session, mut server) = create_session_pair(handler.clone()).await; + let captured = Arc::new(parking_lot::Mutex::new(None)); + let signal = Arc::new(tokio::sync::Notify::new()); + let handler = Arc::new(CapturingTool { + captured: captured.clone(), + signal: signal.clone(), + }); + let (_session, mut server) = create_session_pair_with_config(move |cfg| { + cfg.with_tools(vec![ + Tool::new("calc") + .with_description("calc") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(handler.clone()), + ]) + }) + .await; server .send_event( @@ -3675,8 +3919,8 @@ async fn tool_invocation_carries_trace_context_from_event() { let pending = timeout(TIMEOUT, server.read_request()).await.unwrap(); assert_eq!(pending["method"], "session.tools.handlePendingToolCall"); - timeout(TIMEOUT, handler.signal.notified()).await.unwrap(); - let captured = handler.captured.lock().clone(); + timeout(TIMEOUT, signal.notified()).await.unwrap(); + let captured = captured.lock().clone(); assert_eq!( captured, Some((Some("00-tool-01".into()), Some("vendor=tool".into()))), @@ -3685,7 +3929,7 @@ async fn tool_invocation_carries_trace_context_from_event() { #[tokio::test] async fn wire_omits_trace_fields_when_unset() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let send_handle = tokio::spawn({ diff --git a/test/scenarios/callbacks/hooks/rust/src/main.rs b/test/scenarios/callbacks/hooks/rust/src/main.rs index 179765d2f..c0fcc56d0 100644 --- a/test/scenarios/callbacks/hooks/rust/src/main.rs +++ b/test/scenarios/callbacks/hooks/rust/src/main.rs @@ -102,7 +102,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); let config = config - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(hooks); let session = client.create_session(config).await?; @@ -126,6 +126,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal hooks fired: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/callbacks/permissions/rust/src/main.rs b/test/scenarios/callbacks/permissions/rust/src/main.rs index 214620e35..c44b691bf 100644 --- a/test/scenarios/callbacks/permissions/rust/src/main.rs +++ b/test/scenarios/callbacks/permissions/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; @@ -15,8 +15,8 @@ struct PermissionLogger { } #[async_trait] -impl SessionHandler for PermissionLogger { - async fn on_permission_request( +impl PermissionHandler for PermissionLogger { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -62,7 +62,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); let config = config - .with_handler(handler) + .with_permission_handler(handler) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; @@ -86,6 +86,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal permission requests: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs b/test/scenarios/callbacks/user-input/rust/src/main.rs index b7fea906e..1517727e9 100644 --- a/test/scenarios/callbacks/user-input/rust/src/main.rs +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -4,7 +4,9 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler, UserInputResponse}; +use github_copilot_sdk::handler::{ + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, +}; use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; @@ -15,8 +17,8 @@ struct InputResponder { } #[async_trait] -impl SessionHandler for InputResponder { - async fn on_permission_request( +impl PermissionHandler for InputResponder { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -24,8 +26,11 @@ impl SessionHandler for InputResponder { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_user_input( +#[async_trait] +impl UserInputHandler for InputResponder { + async fn handle( &self, _session_id: SessionId, question: String, @@ -71,9 +76,9 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.request_user_input = Some(true); let config = config - .with_handler(handler) + .with_permission_handler(handler.clone()) + .with_user_input_handler(handler) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; @@ -98,6 +103,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal user input requests: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/modes/default/rust/src/main.rs b/test/scenarios/modes/default/rust/src/main.rs index ba890997d..d316c1a0a 100644 --- a/test/scenarios/modes/default/rust/src/main.rs +++ b/test/scenarios/modes/default/rust/src/main.rs @@ -15,7 +15,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session @@ -31,6 +31,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("Default mode test complete"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/attachments/rust/src/main.rs b/test/scenarios/prompts/attachments/rust/src/main.rs index 9ba9cc176..ea96d5b56 100644 --- a/test/scenarios/prompts/attachments/rust/src/main.rs +++ b/test/scenarios/prompts/attachments/rust/src/main.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -53,6 +53,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs index bf1ab9720..f675da5e5 100644 --- a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs +++ b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.reasoning_effort = Some("low".to_string()); config.available_tools = Some(Vec::new()); config.system_message = Some(sysmsg); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -35,6 +35,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/system-message/rust/src/main.rs b/test/scenarios/prompts/system-message/rust/src/main.rs index 4218a389b..7233a64b9 100644 --- a/test/scenarios/prompts/system-message/rust/src/main.rs +++ b/test/scenarios/prompts/system-message/rust/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -35,6 +35,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs index 43932b613..7a64a6b92 100644 --- a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs @@ -19,7 +19,7 @@ fn make_config(system: &str) -> SessionConfig { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - config.with_handler(Arc::new(ApproveAllHandler)) + config.with_permission_handler(Arc::new(ApproveAllHandler)) } #[tokio::main] @@ -47,7 +47,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session1.destroy().await?; - session2.destroy().await?; + session1.disconnect().await?; + session2.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs index 0c0f06814..2ccb1d786 100644 --- a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.available_tools = Some(Vec::new()); config.system_message = Some(sysmsg); config.infinite_sessions = Some(infinite); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -50,6 +50,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("Infinite sessions test complete — all messages processed successfully"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/session-resume/rust/src/main.rs b/test/scenarios/sessions/session-resume/rust/src/main.rs index 10cd4fa62..b6e6fbf8b 100644 --- a/test/scenarios/sessions/session-resume/rust/src/main.rs +++ b/test/scenarios/sessions/session-resume/rust/src/main.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; session @@ -27,7 +27,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { // Note: do NOT destroy — `resume_session` needs the session to persist. let resume_config = - ResumeSessionConfig::new(session_id).with_handler(Arc::new(ApproveAllHandler)); + ResumeSessionConfig::new(session_id).with_permission_handler(Arc::new(ApproveAllHandler)); let resumed = client.resume_session(resume_config).await?; println!("Session resumed"); @@ -41,6 +41,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - resumed.destroy().await?; + resumed.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/streaming/rust/src/main.rs b/test/scenarios/sessions/streaming/rust/src/main.rs index f5cf23764..d4201ef06 100644 --- a/test/scenarios/sessions/streaming/rust/src/main.rs +++ b/test/scenarios/sessions/streaming/rust/src/main.rs @@ -4,33 +4,10 @@ use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use async_trait::async_trait; -use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::ApproveAllHandler; use github_copilot_sdk::types::SessionConfig; use github_copilot_sdk::{Client, ClientOptions}; -struct StreamCounter { - chunks: Arc, -} - -#[async_trait] -impl SessionHandler for StreamCounter { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { event, .. } => { - if event.event_type == "assistant.message_delta" { - self.chunks.fetch_add(1, Ordering::Relaxed); - } - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - _ => HandlerResponse::Ok, - } - } -} - #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { let mut opts = ClientOptions::default(); @@ -38,16 +15,23 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(opts).await?; let chunks = Arc::new(AtomicUsize::new(0)); - let handler = Arc::new(StreamCounter { - chunks: chunks.clone(), - }); let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); config.streaming = Some(true); - let config = config.with_handler(handler); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; + let mut events = session.subscribe(); + let chunks_clone = chunks.clone(); + let counter = tokio::spawn(async move { + while let Ok(event) = events.recv().await { + if event.event_type == "assistant.message_delta" { + chunks_clone.fetch_add(1, Ordering::Relaxed); + } + } + }); + let response = session.send_and_wait("What is the capital of France?").await?; if let Some(event) = response { @@ -61,6 +45,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { chunks.load(Ordering::Relaxed) ); - session.destroy().await?; + session.disconnect().await?; + drop(counter); Ok(()) } diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs index e707770bc..fe720b803 100644 --- a/test/scenarios/tools/custom-agents/rust/src/main.rs +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::tool::define_tool; use github_copilot_sdk::types::{CustomAgentConfig, DefaultAgentConfig, SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -34,9 +34,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { }, ); - let router = ToolHandlerRouter::new(vec![analyze_codebase], Arc::new(ApproveAllHandler)); - let tools = router.tools(); - let mut researcher = CustomAgentConfig::default(); researcher.name = "researcher".to_string(); researcher.display_name = Some("Research Agent".to_string()); @@ -56,12 +53,13 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.tools = Some(tools); config.default_agent = Some(DefaultAgentConfig { excluded_tools: Some(vec!["analyze-codebase".to_string()]), }); config.custom_agents = Some(vec![researcher]); - let config = config.with_handler(Arc::new(router)); + let config = config + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![analyze_codebase]); let session = client.create_session(config).await?; @@ -77,6 +75,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs index fd76147a1..171d2bcd4 100644 --- a/test/scenarios/tools/mcp-servers/rust/src/main.rs +++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs @@ -25,7 +25,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { .map(|s| s.split(' ').map(str::to_string).collect()) .unwrap_or_default(); let stdio = McpStdioServerConfig { - tools: vec!["*".to_string()], command: cmd.clone(), args, ..Default::default() @@ -45,7 +44,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); config.mcp_servers = mcp_servers; - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -63,6 +62,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\nNo MCP servers configured (set MCP_SERVER_CMD to test with a real server)"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/no-tools/rust/src/main.rs b/test/scenarios/tools/no-tools/rust/src/main.rs index 691ac47ed..64190c78b 100644 --- a/test/scenarios/tools/no-tools/rust/src/main.rs +++ b/test/scenarios/tools/no-tools/rust/src/main.rs @@ -26,7 +26,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session @@ -39,6 +39,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/skills/rust/src/main.rs b/test/scenarios/tools/skills/rust/src/main.rs index 845704fac..d2f1ad6f0 100644 --- a/test/scenarios/tools/skills/rust/src/main.rs +++ b/test/scenarios/tools/skills/rust/src/main.rs @@ -40,7 +40,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.skill_directories = Some(vec![skills_dir]); let config = config - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; @@ -57,6 +57,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\nSkill directories configured successfully"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/tool-filtering/rust/src/main.rs b/test/scenarios/tools/tool-filtering/rust/src/main.rs index edc203550..d4cd5d3c2 100644 --- a/test/scenarios/tools/tool-filtering/rust/src/main.rs +++ b/test/scenarios/tools/tool-filtering/rust/src/main.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { "glob".to_string(), "view".to_string(), ]); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -42,6 +42,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs index ce002a27d..5d5108724 100644 --- a/test/scenarios/tools/tool-overrides/rust/src/main.rs +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::tool::define_tool; use github_copilot_sdk::types::{SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -23,26 +23,20 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; - let grep_tool = define_tool( + let mut grep_tool = define_tool( "grep", "A custom grep implementation that overrides the built-in", |_inv, params: GrepParams| async move { Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {}", params.query))) }, ); - - let router = ToolHandlerRouter::new(vec![grep_tool], Arc::new(ApproveAllHandler)); - let mut tools = router.tools(); - for t in tools.iter_mut() { - if t.name == "grep" { - t.overrides_built_in_tool = true; - } - } + grep_tool.overrides_built_in_tool = true; let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.tools = Some(tools); - let config = config.with_handler(Arc::new(router)); + let config = config + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![grep_tool]); let session = client.create_session(config).await?; @@ -56,6 +50,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/transport/stdio/rust/src/main.rs b/test/scenarios/transport/stdio/rust/src/main.rs index 156b3587d..2795a14fd 100644 --- a/test/scenarios/transport/stdio/rust/src/main.rs +++ b/test/scenarios/transport/stdio/rust/src/main.rs @@ -14,7 +14,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session.send_and_wait("What is the capital of France?").await?; @@ -25,6 +25,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/transport/tcp/rust/src/main.rs b/test/scenarios/transport/tcp/rust/src/main.rs index 49691c1b2..6488f243b 100644 --- a/test/scenarios/transport/tcp/rust/src/main.rs +++ b/test/scenarios/transport/tcp/rust/src/main.rs @@ -20,13 +20,14 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.transport = Transport::External { host: host.to_string(), port, + connection_token: None, }; opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; @@ -38,6 +39,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) }