From b6ab0b8343000d62468479eb7c20870557a4b6a3 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 18:48:45 +0100 Subject: [PATCH 01/22] Phase A: renames (get_messages, disable_resume, destroy) - Rename Session::get_messages() to get_events(); keep get_messages as a #[deprecated] alias for one release. - Rename ResumeSessionConfig::disable_resume to suppress_resume_event. Wire field stays disableResume for runtime compatibility. - Add #[deprecated] to Session::destroy(); point callers at disconnect(). - Update examples (chat, hooks, lifecycle_observer, tool_server) and all unit/integration test callsites. - Add tests: serde round-trip for suppress_resume_event <-> disableResume, and an explicit #[allow(deprecated)] test exercising the get_messages alias. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 2 +- rust/examples/chat.rs | 2 +- rust/examples/hooks.rs | 2 +- rust/examples/lifecycle_observer.rs | 2 +- rust/examples/tool_server.rs | 2 +- rust/src/session.rs | 20 ++++++---- rust/src/types.rs | 38 ++++++++++++++----- rust/tests/e2e/error_resilience.rs | 2 +- rust/tests/e2e/event_fidelity.rs | 2 +- rust/tests/e2e/multi_client.rs | 2 +- .../e2e/multi_client_commands_elicitation.rs | 2 +- rust/tests/e2e/pending_work_resume.rs | 2 +- rust/tests/e2e/permissions.rs | 2 +- rust/tests/e2e/rpc_event_side_effects.rs | 9 ++--- rust/tests/e2e/rpc_session_state.rs | 10 ++--- rust/tests/e2e/session.rs | 12 +++--- rust/tests/e2e/session_config.rs | 6 +-- rust/tests/e2e/session_lifecycle.rs | 2 +- rust/tests/e2e/streaming_fidelity.rs | 2 +- rust/tests/e2e/support.rs | 2 +- rust/tests/session_test.rs | 35 ++++++++++++++++- 21 files changed, 106 insertions(+), 52 deletions(-) diff --git a/rust/README.md b/rust/README.md index 78103e4df..d8ff9c43c 100644 --- a/rust/README.md +++ b/rust/README.md @@ -101,7 +101,7 @@ let _id = session .await?; // Message history -let messages = session.get_messages().await?; +let messages = session.get_events().await?; // Abort the current agent turn session.abort().await?; diff --git a/rust/examples/chat.rs b/rust/examples/chat.rs index 37293c6bc..d017376ef 100644 --- a/rust/examples/chat.rs +++ b/rust/examples/chat.rs @@ -117,6 +117,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nGoodbye."); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/examples/hooks.rs b/rust/examples/hooks.rs index 86f6ceadc..865873f46 100644 --- a/rust/examples/hooks.rs +++ b/rust/examples/hooks.rs @@ -128,6 +128,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\n{text}"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs index 612792073..25ea64502 100644 --- a/rust/examples/lifecycle_observer.rs +++ b/rust/examples/lifecycle_observer.rs @@ -97,7 +97,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { ) .await?; - session.destroy().await?; + session.disconnect().await?; // Synchronous shutdown — useful in panicking-cleanup paths or tests // where you don't have an async runtime available to await `stop()`. diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs index 55bacbbe6..aefeaedb9 100644 --- a/rust/examples/tool_server.rs +++ b/rust/examples/tool_server.rs @@ -182,6 +182,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("{text}"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/rust/src/session.rs b/rust/src/session.rs index d533dbc44..2c7241f03 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -446,8 +446,8 @@ impl Session { } } - /// Retrieve the session's message history. - pub async fn get_messages(&self) -> Result, Error> { + /// Retrieve the session's timeline events. + pub async fn get_events(&self) -> Result, Error> { let result = self .client .call( @@ -459,6 +459,12 @@ impl Session { Ok(response.events) } + /// Deprecated alias for [`get_events`](Self::get_events). + #[deprecated(since = "0.1.0", note = "Use `get_events()` instead")] + pub async fn get_messages(&self) -> Result, Error> { + self.get_events().await + } + /// Abort the current agent turn. /// /// # Cancel safety @@ -519,11 +525,11 @@ impl Session { Ok(()) } - /// Alias for [`disconnect`](Self::disconnect). - /// - /// Named after the `session.destroy` wire RPC. Prefer `disconnect` in - /// new code — the wire-level "destroy" is misleading because on-disk - /// state is preserved. + /// Deprecated alias for [`disconnect`](Self::disconnect). The + /// underlying wire RPC happens to be named `session.destroy`, but it + /// only severs the connection — on-disk session state is preserved. + /// Prefer `disconnect` in new code. + #[deprecated(since = "0.1.0", note = "Use `disconnect()` instead")] pub async fn destroy(&self) -> Result<(), Error> { self.disconnect().await } diff --git a/rust/src/types.rs b/rust/src/types.rs index 70f0c16b7..b70adaa43 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1750,9 +1750,9 @@ pub struct ResumeSessionConfig { #[serde(skip)] pub session_fs_provider: Option>, /// Force-fail resume if the session does not exist on disk, instead of - /// silently starting a new session. - #[serde(skip_serializing_if = "Option::is_none")] - pub disable_resume: Option, + /// silently starting a new session. Wire field name stays `disableResume`. + #[serde(rename = "disableResume", skip_serializing_if = "Option::is_none")] + pub suppress_resume_event: Option, /// When `true`, instructs the runtime to continue any tool calls or /// permission requests that were pending when the previous connection /// was dropped. Use this together with [`Client::force_stop`] to hand @@ -1824,7 +1824,7 @@ impl std::fmt::Debug for ResumeSessionConfig { &self.hooks_handler.as_ref().map(|_| ""), ) .field("transform", &self.transform.as_ref().map(|_| "")) - .field("disable_resume", &self.disable_resume) + .field("suppress_resume_event", &self.suppress_resume_event) .field("continue_pending_work", &self.continue_pending_work) .finish() } @@ -1871,7 +1871,7 @@ impl ResumeSessionConfig { include_sub_agent_streaming_events: None, commands: None, session_fs_provider: None, - disable_resume: None, + suppress_resume_event: None, continue_pending_work: None, handler: None, hooks_handler: None, @@ -2163,8 +2163,8 @@ impl ResumeSessionConfig { /// Force-fail resume if the session does not exist on disk, instead /// of silently starting a new session. - pub fn with_disable_resume(mut self, disable: bool) -> Self { - self.disable_resume = Some(disable); + pub fn with_suppress_resume_event(mut self, suppress: bool) -> Self { + self.suppress_resume_event = Some(suppress); self } @@ -3465,7 +3465,7 @@ mod tests { .with_github_token("ghp_test") .with_enable_session_telemetry(false) .with_include_sub_agent_streaming_events(true) - .with_disable_resume(true) + .with_suppress_resume_event(true) .with_continue_pending_work(true); assert_eq!(cfg.session_id.as_str(), "sess-2"); @@ -3500,7 +3500,7 @@ mod tests { assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); assert_eq!(cfg.enable_session_telemetry, Some(false)); assert_eq!(cfg.include_sub_agent_streaming_events, Some(true)); - assert_eq!(cfg.disable_resume, Some(true)); + assert_eq!(cfg.suppress_resume_event, Some(true)); assert_eq!(cfg.continue_pending_work, Some(true)); } @@ -3520,6 +3520,26 @@ mod tests { assert!(wire.get("continuePendingWork").is_none()); } + /// The Rust field is `suppress_resume_event`, but the wire field stays + /// `disableResume` to preserve compatibility with the runtime and other + /// SDKs. + #[test] + fn resume_session_config_serializes_suppress_resume_event_to_disable_resume_on_wire() { + let cfg = + ResumeSessionConfig::new(SessionId::from("sess-1")).with_suppress_resume_event(true); + let wire = serde_json::to_value(&cfg).unwrap(); + assert_eq!(wire["disableResume"], true); + assert!(wire.get("suppressResumeEvent").is_none()); + + // Round-trip: deserialize from the wire shape. + let json = serde_json::json!({ + "sessionId": "sess-2", + "disableResume": true, + }); + let parsed: ResumeSessionConfig = serde_json::from_value(json).unwrap(); + assert_eq!(parsed.suppress_resume_event, Some(true)); + } + /// `instruction_directories` must serialize to wire as /// `instructionDirectories` on `SessionConfig`. #[test] diff --git a/rust/tests/e2e/error_resilience.rs b/rust/tests/e2e/error_resilience.rs index 3dc7cbc7c..218df9c7e 100644 --- a/rust/tests/e2e/error_resilience.rs +++ b/rust/tests/e2e/error_resilience.rs @@ -41,7 +41,7 @@ async fn should_throw_when_getting_messages_from_disconnected_session() { .expect("create session"); session.disconnect().await.expect("disconnect session"); - assert!(session.get_messages().await.is_err()); + assert!(session.get_events().await.is_err()); client.stop().await.expect("stop client"); }) diff --git a/rust/tests/e2e/event_fidelity.rs b/rust/tests/e2e/event_fidelity.rs index 61d1f4f1f..3f9904425 100644 --- a/rust/tests/e2e/event_fidelity.rs +++ b/rust/tests/e2e/event_fidelity.rs @@ -318,7 +318,7 @@ async fn should_preserve_message_order_in_getmessages_after_tool_use() { .await .expect("send"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); let types = event_types(&messages); let session_start = types .iter() diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index 7d1b61b30..6ea5c567e 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -425,7 +425,7 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) .with_handler(selective_handler(Vec::new())) - .with_disable_resume(true) + .with_suppress_resume_event(true) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs index 218418ece..e47504c03 100644 --- a/rust/tests/e2e/multi_client_commands_elicitation.rs +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -200,7 +200,7 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_disable_resume(true) + .with_suppress_resume_event(true) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs index 60f847416..5855d8030 100644 --- a/rust/tests/e2e/pending_work_resume.rs +++ b/rust/tests/e2e/pending_work_resume.rs @@ -228,7 +228,7 @@ async fn should_report_continuependingwork_true_in_resume_event() { .await .expect("resume session"); let resume_event = session2 - .get_messages() + .get_events() .await .expect("messages") .into_iter() diff --git a/rust/tests/e2e/permissions.rs b/rust/tests/e2e/permissions.rs index 8d7834768..d491f6d5b 100644 --- a/rust/tests/e2e/permissions.rs +++ b/rust/tests/e2e/permissions.rs @@ -486,7 +486,7 @@ async fn should_wait_for_slow_permission_handler() { release_tx.send(()).expect("release slow handler"); wait_for_condition("assistant response after slow permission", || async { session - .get_messages() + .get_events() .await .expect("get messages") .iter() diff --git a/rust/tests/e2e/rpc_event_side_effects.rs b/rust/tests/e2e/rpc_event_side_effects.rs index e68939f98..9e5e2f1a4 100644 --- a/rust/tests/e2e/rpc_event_side_effects.rs +++ b/rust/tests/e2e/rpc_event_side_effects.rs @@ -239,7 +239,7 @@ async fn should_emit_snapshot_rewind_event_and_remove_events_on_truncate() { .expect("assistant message"); assert!(assistant_message_content(&answer).contains("SNAPSHOT_REWIND_TARGET")); let user_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() @@ -268,10 +268,7 @@ async fn should_emit_snapshot_rewind_event_and_remove_events_on_truncate() { assert!(result.events_removed >= 1); rewind.await; - let remaining = session - .get_messages() - .await - .expect("messages after truncate"); + let remaining = session.get_events().await.expect("messages after truncate"); assert!(!remaining.iter().any(|event| event.id == target_event_id)); session.disconnect().await.expect("disconnect session"); @@ -301,7 +298,7 @@ async fn should_allow_session_use_after_truncate() { .await .expect("send"); let user_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() diff --git a/rust/tests/e2e/rpc_session_state.rs b/rust/tests/e2e/rpc_session_state.rs index 5dee2c8a3..a91842106 100644 --- a/rust/tests/e2e/rpc_session_state.rs +++ b/rust/tests/e2e/rpc_session_state.rs @@ -666,7 +666,7 @@ async fn should_fork_session_with_persisted_messages() { ) .await .expect("resume fork"); - let forked_messages = forked.get_messages().await.expect("forked messages"); + let forked_messages = forked.get_events().await.expect("forked messages"); assert!(contains_user_message( &forked_messages, "Say FORK_SOURCE_ALPHA exactly." @@ -682,7 +682,7 @@ async fn should_fork_session_with_persisted_messages() { .expect("send fork") .expect("fork answer"); assert!(assistant_message_content(&fork_answer).contains("FORK_CHILD_BETA")); - let source_after = session.get_messages().await.expect("source messages"); + let source_after = session.get_events().await.expect("source messages"); assert!(!contains_user_message( &source_after, "Now say FORK_CHILD_BETA exactly." @@ -736,7 +736,7 @@ async fn should_handle_forking_session_without_persisted_events() { .expect("resume fork"); assert!( !forked - .get_messages() + .get_events() .await .expect("forked messages") .iter() @@ -793,7 +793,7 @@ async fn should_fork_session_to_event_id_excluding_boundary_event() { .send_and_wait("Say FORK_BOUNDARY_SECOND exactly.") .await .expect("send second"); - let source_events = session.get_messages().await.expect("messages"); + let source_events = session.get_events().await.expect("messages"); let boundary_id = source_events .iter() .find(|event| { @@ -825,7 +825,7 @@ async fn should_fork_session_to_event_id_excluding_boundary_event() { ) .await .expect("resume fork"); - let forked_events = forked.get_messages().await.expect("forked messages"); + let forked_events = forked.get_events().await.expect("forked messages"); assert!(contains_user_message( &forked_events, "Say FORK_BOUNDARY_FIRST exactly." diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs index 25aff47a9..59b42a30a 100644 --- a/rust/tests/e2e/session.rs +++ b/rust/tests/e2e/session.rs @@ -37,7 +37,7 @@ async fn shouldcreateanddisconnectsessions() { .expect("create session"); assert_uuid_like(session.id()); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); assert!(!messages.is_empty(), "expected initial session events"); let start = messages[0] .typed_data::() @@ -46,7 +46,7 @@ async fn shouldcreateanddisconnectsessions() { session.disconnect().await.expect("disconnect session"); assert!( - session.get_messages().await.is_err(), + session.get_events().await.is_err(), "disconnected session should no longer serve message history" ); client.stop().await.expect("stop client"); @@ -447,7 +447,7 @@ async fn should_abort_a_session() { session.abort().await.expect("abort session"); idle.await.expect("idle task"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); assert!(messages .iter() .any(|event| event.parsed_type() == SessionEventType::Abort)); @@ -558,7 +558,7 @@ async fn should_resume_a_session_using_a_new_client() { .expect("resume session"); assert_eq!(resumed.id(), &session_id); - let messages = resumed.get_messages().await.expect("get messages"); + let messages = resumed.get_events().await.expect("get messages"); assert!( messages .iter() @@ -1389,7 +1389,7 @@ async fn should_send_with_mode_property() { .await; let user_message = session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() @@ -1507,7 +1507,7 @@ async fn latest_user_message( session: &github_copilot_sdk::session::Session, ) -> github_copilot_sdk::SessionEvent { session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() diff --git a/rust/tests/e2e/session_config.rs b/rust/tests/e2e/session_config.rs index 05c818169..38768a5b9 100644 --- a/rust/tests/e2e/session_config.rs +++ b/rust/tests/e2e/session_config.rs @@ -165,7 +165,7 @@ async fn should_use_custom_session_id() { .expect("create session"); assert_eq!(session.id(), &requested_session_id); - let messages = session.get_messages().await.expect("messages"); + let messages = session.get_events().await.expect("messages"); let start_event = messages .iter() .find(|event| event.parsed_type() == SessionEventType::SessionStart) @@ -202,7 +202,7 @@ async fn should_apply_reasoning_effort_on_session_create() { .expect("create session"); let start_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() @@ -258,7 +258,7 @@ async fn should_apply_all_reasoning_effort_values_on_session_create() { .unwrap_or_else(|err| panic!("create session with effort {effort}: {err}")); let start_event = session - .get_messages() + .get_events() .await .expect("messages") .into_iter() diff --git a/rust/tests/e2e/session_lifecycle.rs b/rust/tests/e2e/session_lifecycle.rs index e3c1fcd44..59cec701f 100644 --- a/rust/tests/e2e/session_lifecycle.rs +++ b/rust/tests/e2e/session_lifecycle.rs @@ -120,7 +120,7 @@ async fn should_return_events_via_getmessages_after_conversation() { .await .expect("send"); - let messages = session.get_messages().await.expect("get messages"); + let messages = session.get_events().await.expect("get messages"); let types = event_types(&messages); assert!(types.contains(&"session.start")); assert!(types.contains(&"user.message")); diff --git a/rust/tests/e2e/streaming_fidelity.rs b/rust/tests/e2e/streaming_fidelity.rs index 72e0554ac..c6d4c49c5 100644 --- a/rust/tests/e2e/streaming_fidelity.rs +++ b/rust/tests/e2e/streaming_fidelity.rs @@ -260,7 +260,7 @@ async fn should_emit_streaming_deltas_with_reasoning_effort_configured() { assert!(assistant.content.contains("255")); let start = session - .get_messages() + .get_events() .await .expect("get messages") .into_iter() diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index e08e3535a..2ead62dc2 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -411,7 +411,7 @@ pub async fn wait_for_final_assistant_message(session: &Session) -> SessionEvent #[allow(dead_code, reason = "used by follow-on E2E ports")] pub async fn last_assistant_message(session: &Session) -> SessionEvent { session - .get_messages() + .get_events() .await .expect("get session messages") .into_iter() diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index b9c28d30d..3def65ff6 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -394,7 +394,7 @@ async fn session_rpc_methods_send_correct_method_names() { match expected_method { "session.abort" => s.abort().await.map(|_| ()), "session.log" => s.log("test msg", None).await, - "session.destroy" => s.destroy().await, + "session.destroy" => s.disconnect().await, _ => unreachable!(), } }); @@ -966,7 +966,7 @@ async fn get_messages_returns_typed_events() { let handle = tokio::spawn({ let session = session.clone(); - async move { session.get_messages().await.unwrap() } + async move { session.get_events().await.unwrap() } }); let request = server.read_request().await; @@ -990,6 +990,37 @@ async fn get_messages_returns_typed_events() { assert_eq!(events[0].event_type, "user.message"); } +#[tokio::test] +#[allow(deprecated)] +async fn deprecated_get_messages_alias_still_works() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_messages().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.getMessages"); + server + .respond( + &request, + serde_json::json!({ + "events": [{ + "id": "e1", + "timestamp": "2025-01-01T00:00:00Z", + "type": "user.message", + "data": { "text": "hi" }, + }] + }), + ) + .await; + + let events = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(events.len(), 1); +} + #[tokio::test] async fn set_model_sends_switch_to_request() { let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; From 6b0cdf3d945d92ba48bc6b24da2cb2a7fc6fb2e3 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 18:54:46 +0100 Subject: [PATCH 02/22] Phase B: MCP tools tri-state Option> McpStdioServerConfig.tools and McpHttpServerConfig.tools change from Vec to Option> with #[serde(default, skip_serializing_if = Option::is_none)]. Semantics now match the runtime contract and the other SDKs: - None (field omitted on wire) -> expose ALL tools - Some(vec![]) -> expose NO tools - Some(non-empty) -> explicit allowlist Previously an empty Vec was treated as 'no tools' on the wire but the runtime treats omission as 'all tools', so there was no way to opt back into 'all' once the field was set. Adds four tri-state serde tests (serialize + deserialize for each of stdio and http). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/types.rs | 24 ++++-- rust/tests/e2e/mcp_and_agents.rs | 6 +- rust/tests/e2e/rpc_mcp_and_skills.rs | 2 +- rust/tests/session_test.rs | 105 ++++++++++++++++++++++++++- 4 files changed, 124 insertions(+), 13 deletions(-) diff --git a/rust/src/types.rs b/rust/src/types.rs index b70adaa43..c9106b0d6 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -736,7 +736,7 @@ impl CloudSessionOptions { /// servers.insert( /// "playwright".to_string(), /// McpServerConfig::Stdio(McpStdioServerConfig { -/// tools: vec!["*".to_string()], +/// tools: Some(vec!["*".to_string()]), /// command: "npx".to_string(), /// args: vec!["-y".to_string(), "@playwright/mcp".to_string()], /// ..Default::default() @@ -745,7 +745,7 @@ impl CloudSessionOptions { /// servers.insert( /// "weather".to_string(), /// McpServerConfig::Http(McpHttpServerConfig { -/// tools: vec!["forecast".to_string()], +/// tools: Some(vec!["forecast".to_string()]), /// url: "https://example.com/mcp".to_string(), /// ..Default::default() /// }), @@ -772,9 +772,13 @@ pub enum McpServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpStdioServerConfig { - /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. - #[serde(default)] - pub tools: Vec, + /// Tools to expose from this server. + /// + /// - `None` (field omitted on the wire) — expose **all** tools. + /// - `Some(vec![])` — expose **no** tools. + /// - `Some(vec!["a", ...])` — expose only the listed tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, /// Optional timeout in milliseconds for tool calls to this server. #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, @@ -798,9 +802,13 @@ pub struct McpStdioServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpHttpServerConfig { - /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. - #[serde(default)] - pub tools: Vec, + /// Tools to expose from this server. + /// + /// - `None` (field omitted on the wire) — expose **all** tools. + /// - `Some(vec![])` — expose **no** tools. + /// - `Some(vec!["a", ...])` — expose only the listed tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, /// Optional timeout in milliseconds for tool calls to this server. #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, diff --git a/rust/tests/e2e/mcp_and_agents.rs b/rust/tests/e2e/mcp_and_agents.rs index a08275cde..a2881d038 100644 --- a/rust/tests/e2e/mcp_and_agents.rs +++ b/rust/tests/e2e/mcp_and_agents.rs @@ -51,7 +51,7 @@ async fn accept_mcp_server_config_without_args() { let mcp_servers = HashMap::from([( "test-server".to_string(), McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), command: "echo".to_string(), ..McpStdioServerConfig::default() }), @@ -389,7 +389,7 @@ fn multiple_mcp_servers() -> HashMap { servers.insert( "server2".to_string(), McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), command: echo_command(), args: echo_args("server2"), ..McpStdioServerConfig::default() @@ -402,7 +402,7 @@ fn test_mcp_servers(message: &str) -> HashMap { HashMap::from([( "test-server".to_string(), McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), command: echo_command(), args: echo_args(message), ..McpStdioServerConfig::default() diff --git a/rust/tests/e2e/rpc_mcp_and_skills.rs b/rust/tests/e2e/rpc_mcp_and_skills.rs index 1d65a0416..932ac35a3 100644 --- a/rust/tests/e2e/rpc_mcp_and_skills.rs +++ b/rust/tests/e2e/rpc_mcp_and_skills.rs @@ -438,7 +438,7 @@ fn test_mcp_servers(message: &str) -> HashMap { HashMap::from([( message.to_string(), McpServerConfig::Stdio(McpStdioServerConfig { - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), command: echo_command(), args: echo_args(message), ..McpStdioServerConfig::default() diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 3def65ff6..70fbad408 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -544,7 +544,7 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { args: vec!["server.js".to_string()], env: HashMap::new(), cwd: None, - tools: vec!["*".to_string()], + tools: Some(vec!["*".to_string()]), timeout: None, }); let json = serde_json::to_value(&stdio).unwrap(); @@ -566,6 +566,109 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { assert_eq!(cfg_json["github"]["type"], "stdio"); } +#[test] +fn mcp_stdio_tools_tri_state_serializes_correctly() { + use github_copilot_sdk::McpStdioServerConfig; + + // None → field omitted (= "expose all tools") + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: None, + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert!( + json.get("tools").is_none(), + "tools=None must be omitted on the wire; got {json}" + ); + + // Some(empty) → field present as [] + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: Some(vec![]), + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["tools"], serde_json::json!([])); + + // Some(non-empty) → field present as the explicit list + let cfg = McpStdioServerConfig { + command: "echo".into(), + tools: Some(vec!["a".into(), "b".into()]), + ..Default::default() + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["tools"], serde_json::json!(["a", "b"])); +} + +#[test] +fn mcp_stdio_tools_tri_state_deserializes_correctly() { + use github_copilot_sdk::McpStdioServerConfig; + + // Missing field → None + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo" })).unwrap(); + assert_eq!(cfg.tools, None); + + // Empty list → Some(empty) + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo", "tools": [] })).unwrap(); + assert_eq!(cfg.tools, Some(vec![])); + + // Non-empty list → Some(list) + let cfg: McpStdioServerConfig = + serde_json::from_value(serde_json::json!({ "command": "echo", "tools": ["x"] })).unwrap(); + assert_eq!(cfg.tools, Some(vec!["x".to_string()])); +} + +#[test] +fn mcp_http_tools_tri_state_serializes_correctly() { + use github_copilot_sdk::McpHttpServerConfig; + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: None, + ..Default::default() + }; + assert!( + serde_json::to_value(&cfg).unwrap().get("tools").is_none(), + "tools=None must be omitted on the wire" + ); + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: Some(vec![]), + ..Default::default() + }; + assert_eq!( + serde_json::to_value(&cfg).unwrap()["tools"], + serde_json::json!([]) + ); + + let cfg = McpHttpServerConfig { + url: "https://example.com".into(), + tools: Some(vec!["a".into()]), + ..Default::default() + }; + assert_eq!( + serde_json::to_value(&cfg).unwrap()["tools"], + serde_json::json!(["a"]) + ); +} + +#[test] +fn mcp_http_tools_tri_state_deserializes_correctly() { + use github_copilot_sdk::McpHttpServerConfig; + + let cfg: McpHttpServerConfig = + serde_json::from_value(serde_json::json!({ "url": "https://e.com" })).unwrap(); + assert_eq!(cfg.tools, None); + + let cfg: McpHttpServerConfig = + serde_json::from_value(serde_json::json!({ "url": "https://e.com", "tools": [] })).unwrap(); + assert_eq!(cfg.tools, Some(vec![])); +} + #[test] fn permission_request_data_extracts_typed_kind() { use github_copilot_sdk::{PermissionRequestData, PermissionRequestKind}; From 9437696d8e876eb8d9ab88c88f0b6ea50950a309 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 18:57:21 +0100 Subject: [PATCH 03/22] Phase C: drop LogLevel::Info default When ClientOptions::log_level is None, the SDK no longer passes --log-level to the spawned CLI process at all. The CLI's built-in default takes over. Previously the SDK silently overrode the CLI default with Info, which made it impossible to opt back into the CLI default. - Extract a log_level_args() helper alongside auth_args / remote_args. - Update field rustdoc and remove 'Default.' annotation from LogLevel::Info. - Add two unit tests: log_level_args_omitted_when_unset and log_level_args_emit_flag_when_set. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/lib.rs | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index abb1a72a4..0528739af 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -350,7 +350,8 @@ pub struct ClientOptions { /// [`Self::github_token`] is set, in which case false). pub use_logged_in_user: Option, /// Log level passed to the CLI server via `--log-level`. When `None`, - /// the SDK uses [`LogLevel::Info`]. + /// the SDK does not pass `--log-level` to the runtime at all and the + /// CLI uses its built-in default. pub log_level: Option, /// Server-wide idle timeout for sessions, in seconds. When set to a /// positive value, the SDK passes `--session-idle-timeout ` to @@ -475,7 +476,7 @@ pub enum LogLevel { Error, /// Warnings and errors. Warning, - /// Default. Info and above. + /// Info and above. Info, /// Debug, info, warnings, errors. Debug, @@ -1349,18 +1350,19 @@ impl Client { } } + fn log_level_args(options: &ClientOptions) -> Vec<&'static str> { + match options.log_level { + Some(level) => vec!["--log-level", level.as_str()], + None => Vec::new(), + } + } + fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)"); let mut command = Self::build_command(program, options); - let log_level = options.log_level.unwrap_or(LogLevel::Info); command - .args([ - "--server", - "--stdio", - "--no-auto-update", - "--log-level", - log_level.as_str(), - ]) + .args(["--server", "--stdio", "--no-auto-update"]) + .args(Self::log_level_args(options)) .args(Self::auth_args(options)) .args(Self::session_idle_timeout_args(options)) .args(Self::remote_args(options)) @@ -1382,16 +1384,9 @@ impl Client { ) -> Result<(Child, u16), Error> { info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)"); let mut command = Self::build_command(program, options); - let log_level = options.log_level.unwrap_or(LogLevel::Info); command - .args([ - "--server", - "--port", - &port.to_string(), - "--no-auto-update", - "--log-level", - log_level.as_str(), - ]) + .args(["--server", "--port", &port.to_string(), "--no-auto-update"]) + .args(Self::log_level_args(options)) .args(Self::auth_args(options)) .args(Self::session_idle_timeout_args(options)) .args(Self::remote_args(options)) @@ -2403,6 +2398,22 @@ mod tests { assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]); } + #[test] + fn log_level_args_omitted_when_unset() { + let opts = ClientOptions::default(); + assert!(opts.log_level.is_none()); + assert!( + Client::log_level_args(&opts).is_empty(), + "with no caller-supplied log_level the SDK must not pass --log-level" + ); + } + + #[test] + fn log_level_args_emit_flag_when_set() { + let opts = ClientOptions::default().with_log_level(LogLevel::Debug); + assert_eq!(Client::log_level_args(&opts), vec!["--log-level", "debug"]); + } + #[test] fn log_level_str_round_trips() { for level in [ From 56ec07a4179b37dc2f65e41d6087598bbc362780 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 19:00:16 +0100 Subject: [PATCH 04/22] Phase F: InputOptions -> UiInputOptions Renames the public UI-input options struct to match the Go SDK and to make the connection to SessionUi::input explicit. Touches three files (src/types.rs, src/session.rs, tests/e2e/elicitation.rs); zero behaviour change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/session.rs | 8 ++++---- rust/src/types.rs | 2 +- rust/tests/e2e/elicitation.rs | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rust/src/session.rs b/rust/src/session.rs index 2c7241f03..3dee77849 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -28,10 +28,10 @@ use crate::trace_context::inject_trace_context; use crate::transforms::SystemMessageTransform; use crate::types::{ CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest, - ElicitationResult, ExitPlanModeData, GetMessagesResponse, InputOptions, MessageOptions, + ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation, - ToolResult, ToolResultExpanded, ToolResultResponse, TraceContext, + ToolResult, ToolResultExpanded, ToolResultResponse, TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; @@ -695,11 +695,11 @@ impl<'a> SessionUi<'a> { /// Ask the user for free-form text input. /// /// Returns the input string on accept, or `None` on decline/cancel. - /// Use [`InputOptions`] to set validation constraints and field metadata. + /// Use [`UiInputOptions`] to set validation constraints and field metadata. pub async fn input( &self, message: &str, - options: Option<&InputOptions<'_>>, + options: Option<&UiInputOptions<'_>>, ) -> Result, Error> { self.session.assert_elicitation()?; let mut field = serde_json::json!({ "type": "string" }); diff --git a/rust/src/types.rs b/rust/src/types.rs index c9106b0d6..442e70f35 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -3103,7 +3103,7 @@ pub struct UiCapabilities { /// Options for the [`SessionUi::input`](crate::session::SessionUi::input) convenience method. #[derive(Debug, Clone, Default)] -pub struct InputOptions<'a> { +pub struct UiInputOptions<'a> { /// Title label for the input field. pub title: Option<&'a str>, /// Descriptive text shown below the field. diff --git a/rust/tests/e2e/elicitation.rs b/rust/tests/e2e/elicitation.rs index 13b928bf7..b0d7d7714 100644 --- a/rust/tests/e2e/elicitation.rs +++ b/rust/tests/e2e/elicitation.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; use github_copilot_sdk::{ - ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, InputOptions, RequestId, - ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, + ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, RequestId, + ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, UiInputOptions, }; use serde_json::json; use tokio::sync::Mutex; @@ -304,13 +304,13 @@ async fn input_returns_freeform_value() { ) .await .expect("create session"); - let options = InputOptions { + let options = UiInputOptions { title: Some("Value"), description: Some("A value to test"), min_length: Some(1), max_length: Some(20), default: Some("default"), - ..InputOptions::default() + ..UiInputOptions::default() }; assert_eq!( @@ -465,7 +465,7 @@ async fn elicitation_result_types_are_properly_structured() { #[tokio::test] async fn input_options_has_all_properties() { - let options = InputOptions { + let options = UiInputOptions { title: Some("Email Address"), description: Some("Enter your email"), min_length: Some(5), From ca26819f4e730848dee951b38acb149069a43818 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 19:14:43 +0100 Subject: [PATCH 05/22] Phase D: connection consolidation Folds previously-scattered transport configuration into the Transport enum: Transport::Tcp { port, connection_token } Transport::External { host, port, connection_token } Removes: - ClientOptions::tcp_connection_token + with_tcp_connection_token - the runtime check that rejected (Stdio + token) -- now a type error - public ConnectionState enum and Client::state() accessor Renames: - ClientOptions::copilot_home -> base_directory + with_base_directory - ClientOptions::remote -> enable_remote_sessions + with_enable_remote_sessions - the runtime env var passed to the child stays COPILOT_HOME - the spawn arg stays --remote ConnectionState is demoted to pub(crate) and loses its serde derives; nothing outside the crate reads it. The internal state-tracking field on ClientInner is unchanged. All call sites in tests, examples, and e2e support helpers updated to the new Transport shape. Adds a unit test for empty-string connection_token on External transport (the existing Tcp test was adapted, the External one is new). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/examples/lifecycle_observer.rs | 6 +- rust/src/lib.rs | 188 +++++++++--------- rust/src/types.rs | 45 ++--- rust/tests/e2e/client.rs | 20 +- rust/tests/e2e/client_lifecycle.rs | 10 +- rust/tests/e2e/client_options.rs | 26 ++- rust/tests/e2e/multi_client.rs | 20 +- .../e2e/multi_client_commands_elicitation.rs | 20 +- rust/tests/e2e/pending_work_resume.rs | 20 +- rust/tests/e2e/support.rs | 8 +- rust/tests/session_test.rs | 22 +- 11 files changed, 169 insertions(+), 216 deletions(-) diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs index 25ea64502..fe3654099 100644 --- a/rust/examples/lifecycle_observer.rs +++ b/rust/examples/lifecycle_observer.rs @@ -39,7 +39,7 @@ use github_copilot_sdk::{Client, ClientOptions}; #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; - println!("[client] state: {:?}", client.state()); + println!("[client] started, pid: {:?}", client.pid()); // Wildcard lifecycle subscriber: see every session.lifecycle event, // counting deletions inline by filtering on event_type. @@ -65,7 +65,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let config = SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; - println!("[client] state after create: {:?}", client.state()); + println!("[client] session created: {}", session.id()); // Per-session observer: see every assistant message, tool call, etc. // Subscribers fire alongside the constructor handler; they're great for @@ -103,7 +103,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { // where you don't have an async runtime available to await `stop()`. // For graceful shutdown in normal flow, prefer `client.stop().await`. client.force_stop(); - println!("[client] state after force_stop: {:?}", client.state()); + println!("[client] force-stopped"); // Stopping the client closes the broadcast senders, so the consumer // tasks observe `RecvError::Closed` and exit cleanly. diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 0528739af..d917d6821 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -289,6 +289,10 @@ pub enum Transport { Tcp { /// Port to listen on (0 for OS-assigned). port: u16, + /// Optional connection token. When `None` and the SDK is spawning + /// the CLI, the SDK auto-generates a 128-bit hex token so the + /// loopback listener is safe by default. + connection_token: Option, }, /// Connect to an already-running CLI server (no process spawning). External { @@ -296,6 +300,9 @@ pub enum Transport { host: String, /// Port of the running server. port: u16, + /// Optional connection token. Required when the external server + /// was started with a token, ignored otherwise. + connection_token: Option, }, } @@ -393,23 +400,13 @@ pub struct ClientOptions { /// auth, telemetry buffers). When set, exported as `COPILOT_HOME` to /// the spawned CLI process. Useful for sandboxing test runs or /// running multiple isolated SDK instances side-by-side. - pub copilot_home: Option, - /// Optional connection token for TCP transport. Sent to the CLI in - /// the `connect` handshake and exported as `COPILOT_CONNECTION_TOKEN` - /// to spawned CLI processes. Required when the CLI server was started - /// with a token, ignored otherwise. - /// - /// When the SDK spawns its own CLI in TCP mode and this is left - /// `None`, a UUID is generated automatically so the loopback listener - /// is safe by default. Combining with [`Transport::Stdio`] is invalid - /// and surfaces as an error from [`Client::start`]. - pub tcp_connection_token: Option, + pub base_directory: Option, /// Enable remote session support (Mission Control integration). /// When `true`, the SDK passes `--remote` to the spawned CLI process so /// sessions in a GitHub repository working directory are accessible from /// GitHub web and mobile. Ignored when connecting to an external server /// via [`Transport::External`]. - pub remote: bool, + pub enable_remote_sessions: bool, } impl std::fmt::Debug for ClientOptions { @@ -442,12 +439,8 @@ impl std::fmt::Debug for ClientOptions { &self.on_get_trace_context.as_ref().map(|_| ""), ) .field("telemetry", &self.telemetry) - .field("copilot_home", &self.copilot_home) - .field( - "tcp_connection_token", - &self.tcp_connection_token.as_ref().map(|_| ""), - ) - .field("remote", &self.remote) + .field("base_directory", &self.base_directory) + .field("enable_remote_sessions", &self.enable_remote_sessions) .finish() } } @@ -652,9 +645,8 @@ impl Default for ClientOptions { session_fs: None, on_get_trace_context: None, telemetry: None, - copilot_home: None, - tcp_connection_token: None, - remote: false, + base_directory: None, + enable_remote_sessions: false, } } } @@ -800,23 +792,15 @@ impl ClientOptions { /// Override the directory where the CLI persists its state. Set as /// `COPILOT_HOME` on the spawned CLI process. - pub fn with_copilot_home(mut self, home: impl Into) -> Self { - self.copilot_home = Some(home.into()); - self - } - - /// Set the connection token for TCP transport. Sent in the `connect` - /// handshake and exported as `COPILOT_CONNECTION_TOKEN` to spawned - /// CLI processes. - pub fn with_tcp_connection_token(mut self, token: impl Into) -> Self { - self.tcp_connection_token = Some(token.into()); + pub fn with_base_directory(mut self, dir: impl Into) -> Self { + self.base_directory = Some(dir.into()); self } /// Enable remote session support (Mission Control). Passes `--remote` /// to the spawned CLI process. - pub fn with_remote(mut self, enabled: bool) -> Self { - self.remote = enabled; + pub fn with_enable_remote_sessions(mut self, enabled: bool) -> Self { + self.enable_remote_sessions = enabled; self } } @@ -930,39 +914,48 @@ impl Client { )); } } - // Validate token + transport combination. Stdio cannot use a - // connection token; auto-generate a UUID when the SDK spawns - // its own CLI in TCP mode and no explicit token was set. - if let Some(token) = &options.tcp_connection_token { - if token.is_empty() { - return Err(Error::InvalidConfig( - "tcp_connection_token must be a non-empty string".to_string(), - )); + // Validate token shape. Stdio variants no longer carry a token + // (enforced by the type). For Tcp/External, empty-string is + // rejected eagerly. + match &options.transport { + Transport::Tcp { + connection_token: Some(t), + .. } - if matches!(options.transport, Transport::Stdio) { + | Transport::External { + connection_token: Some(t), + .. + } if t.is_empty() => { return Err(Error::InvalidConfig( - "tcp_connection_token cannot be used with Transport::Stdio".to_string(), + "connection_token must be a non-empty string".to_string(), )); } + _ => {} } - let effective_connection_token: Option = match &options.transport { - Transport::Stdio => None, - Transport::Tcp { .. } => Some( - options - .tcp_connection_token - .clone() - .unwrap_or_else(generate_connection_token), - ), - Transport::External { .. } => options.tcp_connection_token.clone(), + // Capture (and where needed, auto-generate) the token actually sent + // to the server. For Tcp, the SDK auto-generates one when the + // caller leaves it unset so the loopback listener is safe by + // default. + let (mut options, effective_connection_token) = { + let mut options = options; + let effective = match &mut options.transport { + Transport::Stdio => None, + Transport::Tcp { + connection_token, .. + } => { + if connection_token.is_none() { + *connection_token = Some(generate_connection_token()); + } + connection_token.clone() + } + Transport::External { + connection_token, .. + } => connection_token.clone(), + }; + (options, effective) }; - let mut options = options; - if matches!(options.transport, Transport::Tcp { .. }) - && options.tcp_connection_token.is_none() - { - // Auto-generated tokens flow to the spawned CLI via env, so - // make the field reflect what we'll actually send. - options.tcp_connection_token = effective_connection_token.clone(); - } + let _ = &mut options; + let effective_connection_token: Option = effective_connection_token; let session_fs_config = options.session_fs.clone(); let session_fs_sqlite_declared = session_fs_config .as_ref() @@ -994,7 +987,11 @@ impl Client { }; let client = match options.transport { - Transport::External { ref host, port } => { + Transport::External { + ref host, + port, + connection_token: _, + } => { info!(host = %host, port = %port, "connecting to external CLI server"); let connect_start = Instant::now(); let stream = TcpStream::connect((host.as_str(), port)).await?; @@ -1017,7 +1014,10 @@ impl Client { effective_connection_token.clone(), )? } - Transport::Tcp { port } => { + Transport::Tcp { + port, + connection_token: _, + } => { let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?; let connect_start = Instant::now(); let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?; @@ -1281,10 +1281,14 @@ impl Client { ); } } - if let Some(home) = &options.copilot_home { - command.env("COPILOT_HOME", home); + if let Some(dir) = &options.base_directory { + command.env("COPILOT_HOME", dir); } - if let Some(token) = &options.tcp_connection_token { + if let Transport::Tcp { + connection_token: Some(token), + .. + } = &options.transport + { command.env("COPILOT_CONNECTION_TOKEN", token); } for (key, value) in &options.env { @@ -1343,7 +1347,7 @@ impl Client { } fn remote_args(options: &ClientOptions) -> Vec { - if options.remote { + if options.enable_remote_sessions { vec!["--remote".to_string()] } else { Vec::new() @@ -1581,8 +1585,8 @@ impl Client { /// /// # Handshake sequence /// - /// 1. Sends the `connect` JSON-RPC method, forwarding - /// [`ClientOptions::tcp_connection_token`] (or the auto-generated + /// 1. Sends the `connect` JSON-RPC method, forwarding the + /// [`Transport`]'s `connection_token` (or the auto-generated /// token for SDK-spawned TCP servers) as the `token` param. This /// is the canonical handshake used by all SDK languages and is /// what the CLI uses to enforce loopback authentication when @@ -1647,7 +1651,7 @@ impl Client { /// Send the `connect` JSON-RPC handshake. Returns the server's /// reported protocol version, or `None` if the server omits it. - /// Forwards [`ClientOptions::tcp_connection_token`] (or the + /// Forwards the [`Transport`]'s `connection_token` (or the /// auto-generated token for SDK-spawned TCP servers) as the `token` /// param. Server-side, the token is required when the server was /// started with `COPILOT_CONNECTION_TOKEN`. @@ -1981,16 +1985,6 @@ impl Client { pub fn subscribe_lifecycle(&self) -> LifecycleSubscription { LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe()) } - - /// Return the current [`ConnectionState`]. - /// - /// The state advances to [`Connected`](ConnectionState::Connected) once - /// [`Client::start`] / [`Client::from_streams`] returns successfully and - /// drops to [`Disconnected`](ConnectionState::Disconnected) after - /// [`stop`](Self::stop) or [`force_stop`](Self::force_stop). - pub fn state(&self) -> ConnectionState { - *self.inner.state.lock() - } } impl Drop for ClientInner { @@ -2050,7 +2044,7 @@ mod tests { .with_use_logged_in_user(false) .with_log_level(LogLevel::Debug) .with_session_idle_timeout_seconds(120) - .with_remote(true); + .with_enable_remote_sessions(true); assert!(matches!(opts.program, CliProgram::Path(_))); assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]); assert_eq!(opts.cwd, PathBuf::from("/tmp")); @@ -2067,7 +2061,7 @@ mod tests { assert_eq!(opts.use_logged_in_user, Some(false)); assert!(matches!(opts.log_level, Some(LogLevel::Debug))); assert_eq!(opts.session_idle_timeout_seconds, Some(120)); - assert!(opts.remote); + assert!(opts.enable_remote_sessions); } #[test] @@ -2270,7 +2264,7 @@ mod tests { #[test] fn build_command_sets_copilot_home_env_when_configured() { - let opts = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot")); + let opts = ClientOptions::new().with_base_directory(PathBuf::from("/custom/copilot")); let cmd = Client::build_command(Path::new("/bin/echo"), &opts); assert_eq!( env_value(&cmd, "COPILOT_HOME"), @@ -2284,7 +2278,10 @@ mod tests { #[test] fn build_command_sets_connection_token_env_when_configured() { - let opts = ClientOptions::new().with_tcp_connection_token("secret-token"); + let opts = ClientOptions::new().with_transport(Transport::Tcp { + port: 0, + connection_token: Some("secret-token".to_string()), + }); let cmd = Client::build_command(Path::new("/bin/echo"), &opts); assert_eq!( env_value(&cmd, "COPILOT_CONNECTION_TOKEN"), @@ -2297,26 +2294,25 @@ mod tests { } #[tokio::test] - async fn start_rejects_token_with_stdio_transport() { + async fn start_rejects_empty_connection_token() { let opts = ClientOptions::new() - .with_tcp_connection_token("token-123") + .with_transport(Transport::Tcp { + port: 0, + connection_token: Some(String::new()), + }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("Stdio"), - "error should explain the stdio incompatibility: {msg}" - ); } #[tokio::test] - async fn start_rejects_empty_connection_token() { + async fn start_rejects_empty_external_connection_token() { let opts = ClientOptions::new() - .with_tcp_connection_token("") - .with_transport(Transport::Tcp { port: 0 }) + .with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port: 1, + connection_token: Some(String::new()), + }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); @@ -2392,7 +2388,7 @@ mod tests { #[test] fn remote_args_emit_flag_when_enabled() { let opts = ClientOptions { - remote: true, + enable_remote_sessions: true, ..Default::default() }; assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]); diff --git a/rust/src/types.rs b/rust/src/types.rs index 442e70f35..fbde0a271 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -22,17 +22,12 @@ pub use crate::session_fs::{ pub use crate::trace_context::{TraceContext, TraceContextProvider}; use crate::transforms::SystemMessageTransform; -/// Lifecycle state of a [`Client`](crate::Client) connection to the CLI. -/// -/// The state advances from `Connecting` → `Connected` during construction, -/// transitions to `Disconnected` after [`Client::stop`](crate::Client::stop) or -/// [`Client::force_stop`](crate::Client::force_stop), and lands in -/// `Error` if startup fails or the underlying transport tears down -/// unexpectedly. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] +/// Lifecycle state of a [`Client`](crate::Client) connection. Internal — +/// not part of the public API. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[allow(dead_code)] #[non_exhaustive] -pub enum ConnectionState { +pub(crate) enum ConnectionState { /// No CLI process is attached or the process has exited cleanly. Disconnected, /// The client is starting up (spawning the CLI, negotiating protocol). @@ -3705,11 +3700,10 @@ mod tests { } #[test] - fn connection_state_error_serializes_to_match_go() { - let json = serde_json::to_string(&ConnectionState::Error).unwrap(); - assert_eq!(json, "\"error\""); - let parsed: ConnectionState = serde_json::from_str("\"error\"").unwrap(); - assert_eq!(parsed, ConnectionState::Error); + fn connection_state_distinguishes_variants() { + // ConnectionState is now an internal type; verify we can construct + // and compare the variants used by the lifecycle code paths. + assert_ne!(ConnectionState::Connected, ConnectionState::Disconnected); } /// `agentId` is the sub-agent attribution field added in copilot-sdk @@ -3769,19 +3763,14 @@ mod tests { } #[test] - fn connection_state_other_variants_serialize_as_lowercase() { - assert_eq!( - serde_json::to_string(&ConnectionState::Disconnected).unwrap(), - "\"disconnected\"" - ); - assert_eq!( - serde_json::to_string(&ConnectionState::Connecting).unwrap(), - "\"connecting\"" - ); - assert_eq!( - serde_json::to_string(&ConnectionState::Connected).unwrap(), - "\"connected\"" - ); + fn connection_state_variants_compile() { + // Defensive smoke test: all variants must be constructable from + // within the crate. (The enum was demoted from pub to pub(crate) + // in Phase D; this test guards against accidental removal.) + let _ = ConnectionState::Disconnected; + let _ = ConnectionState::Connecting; + let _ = ConnectionState::Connected; + let _ = ConnectionState::Error; } #[test] diff --git a/rust/tests/e2e/client.rs b/rust/tests/e2e/client.rs index 2003be4b8..114e828ac 100644 --- a/rust/tests/e2e/client.rs +++ b/rust/tests/e2e/client.rs @@ -3,7 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use github_copilot_sdk::{ - CliProgram, Client, ClientOptions, ConnectionState, Error, ListModelsHandler, Model, Transport, + CliProgram, Client, ClientOptions, Error, ListModelsHandler, Model, Transport, }; use super::support::with_e2e_context; @@ -13,14 +13,11 @@ async fn should_start_ping_and_stop_stdio_client() { with_e2e_context("client", "should_start_ping_and_stop_stdio_client", |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - let response = client.ping(Some("hello from rust")).await.expect("ping"); assert_eq!(response.message, "pong: hello from rust"); assert!(!response.timestamp.is_empty()); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -30,19 +27,16 @@ async fn should_start_ping_and_stop_stdio_client() { async fn should_start_ping_and_stop_tcp_client() { with_e2e_context("client", "should_start_ping_and_stop_tcp_client", |ctx| { Box::pin(async move { - let client = Client::start( - ctx.client_options_with_transport(Transport::Tcp { port: 0 }) - .with_tcp_connection_token("tcp-e2e-token"), - ) + let client = Client::start(ctx.client_options_with_transport(Transport::Tcp { + port: 0, + connection_token: Some("tcp-e2e-token".to_string()), + })) .await .expect("start TCP client"); - assert_eq!(client.state(), ConnectionState::Connected); - let response = client.ping(Some("tcp hello")).await.expect("ping"); assert_eq!(response.message, "pong: tcp hello"); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -121,7 +115,6 @@ async fn should_stop_client_with_active_session() { .expect("create session"); client.stop().await.expect("stop client"); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; @@ -132,10 +125,7 @@ async fn should_force_stop_client() { with_e2e_context("client", "should_force_stop_client", |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.force_stop(); - assert_eq!(client.state(), ConnectionState::Disconnected); }) }) .await; diff --git a/rust/tests/e2e/client_lifecycle.rs b/rust/tests/e2e/client_lifecycle.rs index 05fdb4a83..75646b486 100644 --- a/rust/tests/e2e/client_lifecycle.rs +++ b/rust/tests/e2e/client_lifecycle.rs @@ -1,4 +1,4 @@ -use github_copilot_sdk::{ConnectionState, SessionLifecycleEventType}; +use github_copilot_sdk::SessionLifecycleEventType; use serde_json::json; use super::support::{wait_for_lifecycle_event, with_e2e_context}; @@ -105,11 +105,7 @@ async fn dispose_disconnects_client_and_disposes_rpc_surface_async() { |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.stop().await.expect("stop client"); - - assert_eq!(client.state(), ConnectionState::Disconnected); assert!( client.call("rpc.ping", Some(json!({}))).await.is_err(), "stopped client should reject RPC calls" @@ -128,11 +124,7 @@ async fn dispose_disconnects_client_and_disposes_rpc_surface_drop() { |ctx| { Box::pin(async move { let client = ctx.start_client().await; - assert_eq!(client.state(), ConnectionState::Connected); - client.force_stop(); - - assert_eq!(client.state(), ConnectionState::Disconnected); assert!( client.call("rpc.ping", Some(json!({}))).await.is_err(), "force-stopped client should reject RPC calls" diff --git a/rust/tests/e2e/client_options.rs b/rust/tests/e2e/client_options.rs index 441ce48c0..022a6bf5b 100644 --- a/rust/tests/e2e/client_options.rs +++ b/rust/tests/e2e/client_options.rs @@ -52,10 +52,10 @@ async fn should_listen_on_configured_tcp_port() { |ctx| { Box::pin(async move { let port = get_available_tcp_port(); - let client = Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token("configured-port-token"), - ) + let client = Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some("configured-port-token".to_string()), + })) .await .expect("start TCP client"); @@ -200,14 +200,23 @@ async fn auto_start_false_requires_explicit_start() { &options.program, github_copilot_sdk::CliProgram::Resolve )); - assert!(options.copilot_home.is_none()); + assert!(options.base_directory.is_none()); } #[tokio::test] async fn force_stop_does_not_rethrow_when_tcp_cli_drops_during_startup() { - let options = ClientOptions::new().with_transport(Transport::Tcp { port: 0 }); + let options = ClientOptions::new().with_transport(Transport::Tcp { + port: 0, + connection_token: None, + }); - assert!(matches!(options.transport, Transport::Tcp { port: 0 })); + assert!(matches!( + options.transport, + Transport::Tcp { + port: 0, + connection_token: None + } + )); } #[tokio::test] @@ -215,6 +224,7 @@ async fn startasync_cleans_up_tcp_cli_process_when_connect_fails() { let options = ClientOptions::new().with_transport(Transport::External { host: "127.0.0.1".to_string(), port: get_available_tcp_port(), + connection_token: None, }); assert!(matches!(options.transport, Transport::External { .. })); @@ -239,6 +249,7 @@ async fn should_throw_when_githubtoken_used_with_cliurl() { .with_transport(Transport::External { host: "localhost".to_string(), port: 12345, + connection_token: None, }) .with_github_token("token"); @@ -262,6 +273,7 @@ async fn should_throw_when_useloggedinuser_used_with_cliurl() { .with_transport(Transport::External { host: "localhost".to_string(), port: 12345, + connection_token: None, }) .with_use_logged_in_user(true); diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index 6ea5c567e..874bd7d0f 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -429,22 +429,20 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs index e47504c03..5a3948659 100644 --- a/rust/tests/e2e/multi_client_commands_elicitation.rs +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -204,22 +204,20 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs index 5855d8030..9f173fa08 100644 --- a/rust/tests/e2e/pending_work_resume.rs +++ b/rust/tests/e2e/pending_work_resume.rs @@ -267,22 +267,20 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start TCP server client") } async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { - Client::start( - ctx.client_options_with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port, - }) - .with_tcp_connection_token(SHARED_TOKEN), - ) + Client::start(ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + connection_token: Some(SHARED_TOKEN.to_string()), + })) .await .expect("start external client") } diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index 2ead62dc2..b3d58a490 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -120,10 +120,10 @@ impl E2eContext { #[expect(dead_code, reason = "used by follow-on E2E ports")] pub async fn start_tcp_client(&self, port: u16, token: &str) -> Client { - Client::start( - self.client_options_with_transport(Transport::Tcp { port }) - .with_tcp_connection_token(token), - ) + Client::start(self.client_options_with_transport(Transport::Tcp { + port, + connection_token: Some(token.to_string()), + })) .await .expect("start TCP E2E client") } diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 70fbad408..e460b214e 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -702,35 +702,15 @@ async fn force_stop_is_idempotent_with_no_child() { // Stream-based clients have no child process. force_stop should be a // no-op and safe to call multiple times. let (client, _server_read, _server_write) = make_client(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Connected - ); client.force_stop(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); client.force_stop(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); assert!(client.pid().is_none()); } #[tokio::test] -async fn stop_transitions_state_to_disconnected() { +async fn stop_is_safe_to_call() { let (client, _server_read, _server_write) = make_client(); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Connected - ); client.stop().await.expect("stop should succeed"); - assert_eq!( - client.state(), - github_copilot_sdk::ConnectionState::Disconnected - ); } #[tokio::test] From 20a379b8bf27906156ddb9d2e9c9e944523282de Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 19:41:25 +0100 Subject: [PATCH 06/22] Phase H + I: multi-client broadcast gating + permission-policy ordering Phase H -- match TS/C# semantically for permission/elicitation/external-tool broadcasts on shared CLI connections: 1. Three new SessionHandler probe methods (default false): wants_permission_dispatch() wants_elicitation_dispatch() wants_external_tool_dispatch(tool_name) ApproveAllHandler and DenyAllHandler override to true for permission; NoopHandler keeps the safe default of false everywhere. PermissionOverrideHandler claims permission and forwards the others to its inner handler. ToolHandlerRouter advertises the specific tool names it knows. 2. session.rs broadcast dispatcher gates on the probes BEFORE invoking the handler or sending an RPC response: - permission.requested: skip if data.resolvedByHook is true, skip if !wants_permission_dispatch() - elicitation.requested: skip if !wants_elicitation_dispatch() - external_tool.requested: skip if !wants_external_tool_dispatch(name) 3. SessionConfig::default() / ResumeSessionConfig::new() no longer hardcode requestPermission=true / requestElicitation=true. Those wire flags are now derived from the handler probes at Client::create_session and Client::resume_session time, so the runtime never broadcasts permission/elicitation events to a client that wouldn't respond. 4. Removed the public SessionConfig::with_request_permission / with_request_elicitation builders (and ResumeSessionConfig equivalents) -- the flags are derived now, not user-set. Phase I -- approve_all_permissions ordering trap: The previous design wrapped the handler in-place inside the policy builder, which made the call order to with_handler significant. Now permission_policy lives in a separate (pub(crate)) field on SessionConfig/ResumeSessionConfig and the wrap is applied at Client::create_session / Client::resume_session time, after both fields are finalised. Call order is now irrelevant. Tests: - request_elicitation_sent_in_create_params updated for the new derivation (ApproveAllHandler -> requestPermission=true, requestElicitation=false). - new noop_handler_sends_request_permission_false validates the H1a wire-flag derivation in both directions. - new external_tool_broadcast_for_unknown_tool_is_not_responded_to, permission_broadcast_with_resolved_by_hook_is_not_responded_to, permission_broadcast_with_no_claiming_handler_is_not_responded_to, elicitation_broadcast_with_no_claiming_handler_is_not_responded_to. - new session_config_*_is_order_independent tests for Phase I. - multi_client e2e tests' PermissionDecisionHandler / SelectiveToolHandler updated to advertise the probes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/handler.rs | 45 ++- rust/src/permission.rs | 37 ++- rust/src/session.rs | 47 +++- rust/src/tool.rs | 12 + rust/src/types.rs | 263 +++++++++++------- rust/tests/e2e/elicitation.rs | 15 +- rust/tests/e2e/multi_client.rs | 16 +- .../e2e/multi_client_commands_elicitation.rs | 10 +- rust/tests/session_test.rs | 168 ++++++++++- 9 files changed, 479 insertions(+), 134 deletions(-) diff --git a/rust/src/handler.rs b/rust/src/handler.rs index 565b09d56..de3f434e1 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -281,6 +281,41 @@ pub enum AutoModeSwitchResponse { /// ``` #[async_trait] pub trait SessionHandler: Send + Sync + 'static { + /// Whether this handler claims responsibility for dispatching + /// `permission.requested` broadcasts on this session. + /// + /// Defaults to `false`. When `false`, the SDK does not send the + /// `requestPermission: true` flag on `session.create` / `session.resume`, + /// and the runtime short-circuits permission prompts with + /// `user-not-available` rather than emitting a broadcast this client + /// would silently never respond to. + /// + /// Override to `true` in handlers that implement + /// [`on_permission_request`](Self::on_permission_request) with a real + /// policy. Permission-policy wrappers + /// ([`approve_all_permissions`](crate::types::SessionConfig::approve_all_permissions) + /// and friends) flip this on automatically. + fn wants_permission_dispatch(&self) -> bool { + false + } + + /// Whether this handler claims responsibility for dispatching + /// `elicitation.requested` broadcasts on this session. Defaults to + /// `false`. See [`wants_permission_dispatch`](Self::wants_permission_dispatch) + /// for the multi-client rationale. + fn wants_elicitation_dispatch(&self) -> bool { + false + } + + /// Whether this handler claims responsibility for dispatching an + /// `external_tool.requested` broadcast for the given tool name on + /// this session. Defaults to `false`. Tool routers + /// ([`ToolHandlerRouter`](crate::tool::ToolHandlerRouter)) override to + /// return `true` for tools they have a registered handler for. + fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { + false + } + /// Handle an event from the session. /// /// The default implementation destructures `event` and calls the @@ -449,6 +484,10 @@ pub struct ApproveAllHandler; #[async_trait] impl SessionHandler for ApproveAllHandler { + fn wants_permission_dispatch(&self) -> bool { + true + } + async fn on_permission_request( &self, _session_id: SessionId, @@ -469,8 +508,10 @@ pub struct DenyAllHandler; #[async_trait] impl SessionHandler for DenyAllHandler { - // All defaults are already safe: permissions deny, everything else is a - // sensible fallback. We just reuse them here for clarity. + fn wants_permission_dispatch(&self) -> bool { + true + } + // All other defaults are already safe. } /// A [`SessionHandler`] that leaves permission requests and external tool calls pending. diff --git a/rust/src/permission.rs b/rust/src/permission.rs index 364cb3c91..6971ee7e4 100644 --- a/rust/src/permission.rs +++ b/rust/src/permission.rs @@ -73,12 +73,35 @@ where }) } -enum Policy { +/// Internal permission policy stored on `SessionConfig::permission_policy`. +/// Applied to the handler at `Client::create_session` time so the order of +/// `with_handler` and `approve_all_permissions` is irrelevant. +#[derive(Clone)] +pub(crate) enum Policy { ApproveAll, DenyAll, Predicate(Arc bool + Send + Sync>), } +impl std::fmt::Debug for Policy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ApproveAll => f.write_str("Policy::ApproveAll"), + Self::DenyAll => f.write_str("Policy::DenyAll"), + Self::Predicate(_) => f.write_str("Policy::Predicate()"), + } + } +} + +/// Wrap `inner` with a stored [`Policy`]. Used by `Client::create_session` +/// after the handler and policy fields are both finalised. +pub(crate) fn apply_policy( + inner: Arc, + policy: Policy, +) -> Arc { + Arc::new(PermissionOverrideHandler { inner, policy }) +} + struct PermissionOverrideHandler { inner: Arc, policy: Policy, @@ -86,6 +109,18 @@ struct PermissionOverrideHandler { #[async_trait] impl SessionHandler for PermissionOverrideHandler { + fn wants_permission_dispatch(&self) -> bool { + true + } + + fn wants_elicitation_dispatch(&self) -> bool { + self.inner.wants_elicitation_dispatch() + } + + fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { + self.inner.wants_external_tool_dispatch(tool_name) + } + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { match event { HandlerEvent::PermissionRequest { ref data, .. } => { diff --git a/rust/src/session.rs b/rust/src/session.rs index 3dee77849..f773f9f13 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -763,10 +763,16 @@ impl Client { /// external tool calls are left pending for the consumer to resolve. pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); - let handler = config + let base_handler = config .handler .take() .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + let handler = match config.permission_policy.take() { + Some(policy) => crate::permission::apply_policy(base_handler, policy), + None => base_handler, + }; + config.request_permission = Some(handler.wants_permission_dispatch()); + config.request_elicitation = Some(handler.wants_elicitation_dispatch()); let hooks = config.hooks_handler.take(); let transforms = config.transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); @@ -894,10 +900,16 @@ impl Client { /// fields are unset. pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result { let total_start = Instant::now(); - let handler = config + let base_handler = config .handler .take() .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + let handler = match config.permission_policy.take() { + Some(policy) => crate::permission::apply_policy(base_handler, policy), + None => base_handler, + }; + config.request_permission = Some(handler.wants_permission_dispatch()); + config.request_elicitation = Some(handler.wants_elicitation_dispatch()); let hooks = config.hooks_handler.take(); let transforms = config.transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); @@ -1341,6 +1353,24 @@ async fn handle_notification( let Some(request_id) = extract_request_id(¬ification.event.data) else { return; }; + // Honor the runtime's `resolvedByHook` signal — when the + // server has already resolved the permission via a hook, + // clients must not send a second response. + if notification + .event + .data + .get("resolvedByHook") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + return; + } + // Multi-client safety: if this client's handler does not + // claim permission dispatch, don't respond — another client + // on the same CLI may handle it. + if !handler.wants_permission_dispatch() { + return; + } let client = client.clone(); let handler = handler.clone(); let sid = session_id.clone(); @@ -1440,6 +1470,13 @@ async fn handle_notification( return; } }; + // Multi-client safety: if this client doesn't claim the + // requested tool name, don't respond — another connected + // client may have a handler. + if !data.tool_name.is_empty() && !handler.wants_external_tool_dispatch(&data.tool_name) + { + return; + } let client = client.clone(); let handler = handler.clone(); let sid = session_id.clone(); @@ -1537,6 +1574,12 @@ async fn handle_notification( let Some(request_id) = extract_request_id(¬ification.event.data) else { return; }; + // Multi-client safety: if this client's handler does not + // claim elicitation dispatch, don't respond — another + // client on the same CLI may handle it. + if !handler.wants_elicitation_dispatch() { + return; + } let elicitation_data: ElicitationRequestedData = match serde_json::from_value(notification.event.data.clone()) { Ok(d) => d, diff --git a/rust/src/tool.rs b/rust/src/tool.rs index 3342f4b9f..fc1163e82 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -392,6 +392,18 @@ impl ToolHandlerRouter { #[async_trait] impl SessionHandler for ToolHandlerRouter { + fn wants_permission_dispatch(&self) -> bool { + self.inner.wants_permission_dispatch() + } + + fn wants_elicitation_dispatch(&self) -> bool { + self.inner.wants_elicitation_dispatch() + } + + fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { + self.handlers.contains_key(tool_name) || self.inner.wants_external_tool_dispatch(tool_name) + } + async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { let Some(handler) = self.handlers.get(&invocation.tool_name) else { return self.inner.on_external_tool(invocation).await; diff --git a/rust/src/types.rs b/rust/src/types.rs index fbde0a271..ccb0e1807 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1185,6 +1185,11 @@ pub struct SessionConfig { /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one. #[serde(skip)] pub hooks_handler: Option>, + /// Permission policy applied to the handler. Stored separately from + /// `handler` so that the order of `with_handler` and + /// `approve_all_permissions` (and friends) is irrelevant. + #[serde(skip)] + pub(crate) permission_policy: Option, /// System-message transform. When set, the SDK injects the matching /// `action: "transform"` sections into the system message and routes /// `systemMessage.transform` RPC callbacks to it during the session. @@ -1271,10 +1276,10 @@ impl Default for SessionConfig { env_value_mode: default_env_value_mode(), enable_config_discovery: None, request_user_input: Some(true), - request_permission: Some(true), + request_permission: None, request_exit_plan_mode: Some(true), request_auto_mode_switch: Some(true), - request_elicitation: Some(true), + request_elicitation: None, skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -1296,6 +1301,7 @@ impl Default for SessionConfig { session_fs_provider: None, handler: None, hooks_handler: None, + permission_policy: None, transform: None, } } @@ -1340,54 +1346,31 @@ impl SessionConfig { self } - /// Wrap the configured handler so every permission request is - /// auto-approved. Forwards every non-permission event to the inner - /// handler unchanged. - /// - /// If no handler has been installed via [`with_handler`](Self::with_handler), - /// wraps a [`NoopHandler`](crate::handler::NoopHandler), so declaration-only - /// tools remain pending for manual resolution. - /// - /// Order-independent: `with_handler(...).approve_all_permissions()` and - /// `approve_all_permissions().with_handler(...)` are NOT equivalent — - /// the second form discards the wrap because `with_handler` overwrites - /// the handler field. Always call `approve_all_permissions` *after* - /// `with_handler`. + /// Auto-approve every permission request on this session. Stored as a + /// policy that's applied to the configured handler at + /// [`Client::create_session`](crate::Client::create_session) time, so + /// order with [`with_handler`](Self::with_handler) is irrelevant. pub fn approve_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_all(inner)); + self.permission_policy = Some(crate::permission::Policy::ApproveAll); self } - /// Wrap the configured handler so every permission request is - /// auto-denied. See [`approve_all_permissions`](Self::approve_all_permissions) - /// for ordering and default-handler semantics. + /// Auto-deny every permission request on this session. See + /// [`approve_all_permissions`](Self::approve_all_permissions). pub fn deny_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::deny_all(inner)); + self.permission_policy = Some(crate::permission::Policy::DenyAll); self } - /// Wrap the configured handler with a closure-based permission policy: - /// `predicate` is called for each permission request; `true` approves, - /// `false` denies. See + /// Apply a closure-based permission policy: `predicate` returns `true` + /// to approve, `false` to deny. See /// [`approve_all_permissions`](Self::approve_all_permissions) for - /// ordering and default-handler semantics. + /// ordering semantics. pub fn approve_permissions_if(mut self, predicate: F) -> Self where F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_if(inner, predicate)); + self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate))); self } @@ -1471,12 +1454,6 @@ impl SessionConfig { self } - /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_permission(mut self, enable: bool) -> Self { - self.request_permission = Some(enable); - self - } - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { self.request_exit_plan_mode = Some(enable); @@ -1489,12 +1466,6 @@ impl SessionConfig { self } - /// Advertise elicitation provider capability. Defaults to `Some(true)`. - pub fn with_request_elicitation(mut self, enable: bool) -> Self { - self.request_elicitation = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -1771,6 +1742,9 @@ pub struct ResumeSessionConfig { /// Session hook handler. See [`SessionConfig::hooks_handler`]. #[serde(skip)] pub hooks_handler: Option>, + /// Permission policy. See `SessionConfig::permission_policy`. + #[serde(skip)] + pub(crate) permission_policy: Option, /// System-message transform. See [`SessionConfig::transform`]. #[serde(skip)] pub transform: Option>, @@ -1852,10 +1826,10 @@ impl ResumeSessionConfig { env_value_mode: default_env_value_mode(), enable_config_discovery: None, request_user_input: Some(true), - request_permission: Some(true), + request_permission: None, request_exit_plan_mode: Some(true), request_auto_mode_switch: Some(true), - request_elicitation: Some(true), + request_elicitation: None, skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -1878,6 +1852,7 @@ impl ResumeSessionConfig { continue_pending_work: None, handler: None, hooks_handler: None, + permission_policy: None, transform: None, } } @@ -1916,41 +1891,27 @@ impl ResumeSessionConfig { self } - /// Wrap the configured handler so every permission request is - /// auto-approved. See - /// [`SessionConfig::approve_all_permissions`] for semantics. + /// Auto-approve every permission request on the resumed session. See + /// [`SessionConfig::approve_all_permissions`]. pub fn approve_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_all(inner)); + self.permission_policy = Some(crate::permission::Policy::ApproveAll); self } - /// Wrap the configured handler so every permission request is - /// auto-denied. See - /// [`SessionConfig::deny_all_permissions`] for semantics. + /// Auto-deny every permission request on the resumed session. See + /// [`SessionConfig::deny_all_permissions`]. pub fn deny_all_permissions(mut self) -> Self { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::deny_all(inner)); + self.permission_policy = Some(crate::permission::Policy::DenyAll); self } - /// Wrap the configured handler with a predicate-based permission policy. - /// See [`SessionConfig::approve_permissions_if`] for semantics. + /// Apply a closure-based permission policy on the resumed session. + /// See [`SessionConfig::approve_permissions_if`]. pub fn approve_permissions_if(mut self, predicate: F) -> Self where F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, { - let inner = self - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - self.handler = Some(crate::permission::approve_if(inner, predicate)); + self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate))); self } @@ -2023,12 +1984,6 @@ impl ResumeSessionConfig { self } - /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_permission(mut self, enable: bool) -> Self { - self.request_permission = Some(enable); - self - } - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { self.request_exit_plan_mode = Some(enable); @@ -2041,12 +1996,6 @@ impl ResumeSessionConfig { self } - /// Advertise elicitation provider capability on resume. Defaults to `Some(true)`. - pub fn with_request_elicitation(mut self, enable: bool) -> Self { - self.request_elicitation = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI on resume. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -3366,8 +3315,11 @@ mod tests { fn session_config_default_enables_permission_flow_flags() { let cfg = SessionConfig::default(); assert_eq!(cfg.request_user_input, Some(true)); - assert_eq!(cfg.request_permission, Some(true)); - assert_eq!(cfg.request_elicitation, Some(true)); + // request_permission / request_elicitation are derived from the + // installed SessionHandler at Client::create_session time; the + // default config leaves them unset. + assert_eq!(cfg.request_permission, None); + assert_eq!(cfg.request_elicitation, None); assert_eq!(cfg.request_exit_plan_mode, Some(true)); assert_eq!(cfg.request_auto_mode_switch, Some(true)); } @@ -3376,8 +3328,8 @@ mod tests { fn resume_session_config_new_enables_permission_flow_flags() { let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); assert_eq!(cfg.request_user_input, Some(true)); - assert_eq!(cfg.request_permission, Some(true)); - assert_eq!(cfg.request_elicitation, Some(true)); + assert_eq!(cfg.request_permission, None); + assert_eq!(cfg.request_elicitation, None); assert_eq!(cfg.request_exit_plan_mode, Some(true)); assert_eq!(cfg.request_auto_mode_switch, Some(true)); } @@ -3426,7 +3378,7 @@ mod tests { assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!(cfg.request_permission, None); // unset; derived at create_session time assert_eq!(cfg.request_exit_plan_mode, Some(false)); assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( @@ -3486,7 +3438,7 @@ mod tests { assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!(cfg.request_permission, None); // unset; derived at create_session time assert_eq!(cfg.request_exit_plan_mode, Some(false)); assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( @@ -3935,6 +3887,30 @@ mod permission_builder_tests { } } + /// Apply the same policy-resolution logic that `Client::create_session` + /// uses, so tests exercise the effective handler. + fn resolve(mut cfg: SessionConfig) -> Arc { + let base = cfg + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + match cfg.permission_policy.take() { + Some(policy) => crate::permission::apply_policy(base, policy), + None => base, + } + } + + fn resolve_resume(mut cfg: ResumeSessionConfig) -> Arc { + let base = cfg + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); + match cfg.permission_policy.take() { + Some(policy) => crate::permission::apply_policy(base, policy), + None => base, + } + } + async fn dispatch(handler: &Arc) -> HandlerResponse { handler.on_event(permission_event()).await } @@ -3944,8 +3920,7 @@ mod permission_builder_tests { let cfg = SessionConfig::default() .with_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { + match dispatch(&resolve(cfg)).await { HandlerResponse::Permission(PermissionResult::Approved) => {} other => panic!("expected Approved, got {other:?}"), } @@ -3953,24 +3928,85 @@ mod permission_builder_tests { #[tokio::test] async fn session_config_approve_all_defaults_to_noop_inner() { - // Without with_handler, the wrap defaults to NoopHandler. The - // approve-all wrap intercepts permission events, so they're still - // approved -- the inner handler is consulted only for other events. + // Without with_handler, resolution defaults to NoopHandler. The + // approve-all wrap intercepts permission events. let cfg = SessionConfig::default().approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { + match dispatch(&resolve(cfg)).await { HandlerResponse::Permission(PermissionResult::Approved) => {} other => panic!("expected Approved, got {other:?}"), } } + /// Phase I: order independence. Both call orders must produce the same + /// effective approve-all policy. + #[tokio::test] + async fn session_config_approve_all_is_order_independent() { + let cfg_a = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let cfg_b = SessionConfig::default() + .approve_all_permissions() + .with_handler(Arc::new(ApproveAllHandler)); + + match dispatch(&resolve(cfg_a)).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("order A: expected Approved, got {other:?}"), + } + match dispatch(&resolve(cfg_b)).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("order B: expected Approved, got {other:?}"), + } + } + + /// Phase I: same for deny_all_permissions. + #[tokio::test] + async fn session_config_deny_all_is_order_independent() { + let cfg_a = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .deny_all_permissions(); + let cfg_b = SessionConfig::default() + .deny_all_permissions() + .with_handler(Arc::new(ApproveAllHandler)); + + match dispatch(&resolve(cfg_a)).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("order A: expected Denied, got {other:?}"), + } + match dispatch(&resolve(cfg_b)).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("order B: expected Denied, got {other:?}"), + } + } + + /// Phase I: same for approve_permissions_if. + #[tokio::test] + async fn session_config_approve_permissions_if_is_order_independent() { + let predicate = |data: &PermissionRequestData| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }; + let cfg_a = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .approve_permissions_if(predicate); + let cfg_b = SessionConfig::default() + .approve_permissions_if(predicate) + .with_handler(Arc::new(ApproveAllHandler)); + + match dispatch(&resolve(cfg_a)).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("order A: expected Denied for shell, got {other:?}"), + } + match dispatch(&resolve(cfg_b)).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("order B: expected Denied for shell, got {other:?}"), + } + } + #[tokio::test] async fn session_config_deny_all_denies() { let cfg = SessionConfig::default() .with_handler(Arc::new(ApproveAllHandler)) .deny_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { + match dispatch(&resolve(cfg)).await { HandlerResponse::Permission(PermissionResult::Denied) => {} other => panic!("expected Denied, got {other:?}"), } @@ -3983,8 +4019,7 @@ mod permission_builder_tests { .approve_permissions_if(|data| { data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") }); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { + match dispatch(&resolve(cfg)).await { HandlerResponse::Permission(PermissionResult::Denied) => {} other => panic!("expected Denied for shell, got {other:?}"), } @@ -3995,10 +4030,28 @@ mod permission_builder_tests { let cfg = ResumeSessionConfig::new(SessionId::from("s1")) .with_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let handler = cfg.handler.expect("handler should be set"); - match dispatch(&handler).await { + match dispatch(&resolve_resume(cfg)).await { HandlerResponse::Permission(PermissionResult::Approved) => {} other => panic!("expected Approved, got {other:?}"), } } + + #[tokio::test] + async fn resume_session_config_approve_all_is_order_independent() { + let cfg_a = ResumeSessionConfig::new(SessionId::from("s1")) + .with_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let cfg_b = ResumeSessionConfig::new(SessionId::from("s1")) + .approve_all_permissions() + .with_handler(Arc::new(ApproveAllHandler)); + + match dispatch(&resolve_resume(cfg_a)).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("order A: expected Approved, got {other:?}"), + } + match dispatch(&resolve_resume(cfg_b)).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("order B: expected Approved, got {other:?}"), + } + } } diff --git a/rust/tests/e2e/elicitation.rs b/rust/tests/e2e/elicitation.rs index b0d7d7714..13312d177 100644 --- a/rust/tests/e2e/elicitation.rs +++ b/rust/tests/e2e/elicitation.rs @@ -47,10 +47,7 @@ async fn elicitation_throws_when_capability_is_missing() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let session = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); @@ -144,10 +141,7 @@ async fn should_report_elicitation_capability_based_on_handler_presence() { with_handler.disconnect().await.expect("disconnect first"); let without_handler = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create non-elicitation session"); assert_ne!( @@ -179,10 +173,7 @@ async fn session_without_elicitationhandler_creates_successfully() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let session = client - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index 874bd7d0f..c779fa096 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -128,7 +128,6 @@ async fn one_client_approves_permission_and_both_see_the_result() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_request_permission(false) .with_handler(permission_handler(PermissionResult::NoResult)), ) .await @@ -214,7 +213,6 @@ async fn one_client_rejects_permission_and_both_see_the_result() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_request_permission(false) .with_handler(permission_handler(PermissionResult::NoResult)), ) .await @@ -499,6 +497,12 @@ struct PermissionDecisionHandler { #[async_trait] impl SessionHandler for PermissionDecisionHandler { + fn wants_permission_dispatch(&self) -> bool { + // NoResult means "I'm declining to respond"; surface that via + // the wire flag so the runtime doesn't even broadcast to us. + !matches!(self.result, PermissionResult::NoResult) + } + async fn on_permission_request( &self, _session_id: SessionId, @@ -518,6 +522,14 @@ struct SelectiveToolHandler { #[async_trait] impl SessionHandler for SelectiveToolHandler { + fn wants_permission_dispatch(&self) -> bool { + true + } + + fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { + self.tools.iter().any(|t| t.name == tool_name) + } + async fn on_permission_request( &self, _session_id: SessionId, diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs index 5a3948659..5a8a67045 100644 --- a/rust/tests/e2e/multi_client_commands_elicitation.rs +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -80,10 +80,7 @@ async fn capabilities_changed_fires_when_second_client_joins_with_elicitation_ha let port = free_tcp_port(); let server = start_tcp_server(ctx, port).await; let session1 = server - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); assert_ne!( @@ -142,10 +139,7 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { let port = free_tcp_port(); let server = start_tcp_server(ctx, port).await; let session1 = server - .create_session( - ctx.approve_all_session_config() - .with_request_elicitation(false), - ) + .create_session(ctx.approve_all_session_config()) .await .expect("create session"); let client2 = start_external_client(ctx, port).await; diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index e460b214e..8ab25bfaa 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -1980,6 +1980,10 @@ async fn elicitation_requested_dispatches_to_handler_and_responds() { struct ElicitHandler; #[async_trait] impl SessionHandler for ElicitHandler { + fn wants_elicitation_dispatch(&self) -> bool { + true + } + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { match event { HandlerEvent::ElicitationRequest { request, .. } => { @@ -2026,6 +2030,10 @@ async fn elicitation_requested_cancels_on_handler_error() { struct FailHandler; #[async_trait] impl SessionHandler for FailHandler { + fn wants_elicitation_dispatch(&self) -> bool { + true + } + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { match event { // Return Ok instead of Elicitation — SDK should treat as cancel @@ -2056,6 +2064,10 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { struct ExternalToolHandler; #[async_trait] impl SessionHandler for ExternalToolHandler { + fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { + true + } + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { match event { HandlerEvent::ExternalTool { invocation } => { @@ -2090,6 +2102,119 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { assert_eq!(rpc_call["params"]["result"], "all tests passed"); } +#[tokio::test] +async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { + // Phase H multi-client safety: a handler that doesn't claim the + // requested tool name must not send an RPC response — another client + // on the same CLI may have a real handler. + struct OnlyKnowsFoo; + #[async_trait] + impl SessionHandler for OnlyKnowsFoo { + fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { + tool_name == "foo" + } + + async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse { + HandlerResponse::Ok + } + } + + let (_session, mut server) = create_session_pair(Arc::new(OnlyKnowsFoo)).await; + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-unknown", + "sessionId": server.session_id, + "toolCallId": "tc-x", + "toolName": "bar", + "arguments": {}, + }), + ) + .await; + + // The dispatcher must NOT respond. Read with a short timeout and + // assert the read times out. + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC response for unknown tool, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn permission_broadcast_with_resolved_by_hook_is_not_responded_to() { + // Phase H: when the runtime marks a permission request as already + // resolved by a hook, the client must not respond again. + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + server + .send_event( + "permission.requested", + serde_json::json!({ + "requestId": "req-hooked", + "sessionId": server.session_id, + "resolvedByHook": true, + "permissionRequest": { "kind": "shell" }, + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when resolvedByHook=true, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn permission_broadcast_with_no_claiming_handler_is_not_responded_to() { + // Phase H: a handler that doesn't claim permission dispatch must not + // respond — the SDK lets other connected clients handle the request. + let (_session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + server + .send_event( + "permission.requested", + serde_json::json!({ + "requestId": "req-pending", + "sessionId": server.session_id, + "permissionRequest": { "kind": "shell" }, + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when handler doesn't claim permission dispatch, got: {:?}", + res.ok() + ); +} + +#[tokio::test] +async fn elicitation_broadcast_with_no_claiming_handler_is_not_responded_to() { + // Phase H: same gating for elicitation. The default handler doesn't + // claim elicitation, so broadcasts are silently dropped. + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-silent", + "message": "should not be answered", + }), + ) + .await; + + let res = tokio::time::timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + res.is_err(), + "expected no RPC when handler doesn't claim elicitation, got: {:?}", + res.ok() + ); +} + #[tokio::test] async fn capabilities_captured_from_create_response() { let (client, mut server_read, mut server_write) = make_client(); @@ -2165,7 +2290,7 @@ async fn request_elicitation_sent_in_create_params() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default().with_handler(Arc::new(ApproveAllHandler))) .await .unwrap() } @@ -2173,7 +2298,10 @@ async fn request_elicitation_sent_in_create_params() { let request = read_framed(&mut server_read).await; assert_eq!(request["method"], "session.create"); - assert_eq!(request["params"]["requestElicitation"], true); + // ApproveAllHandler claims permission dispatch but not elicitation, so + // the wire flags reflect that exact responsibility. + assert_eq!(request["params"]["requestPermission"], true); + assert_eq!(request["params"]["requestElicitation"], false); assert_eq!(request["params"]["requestExitPlanMode"], true); assert_eq!(request["params"]["requestAutoModeSwitch"], true); @@ -2188,6 +2316,38 @@ async fn request_elicitation_sent_in_create_params() { timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); } +#[tokio::test] +async fn noop_handler_sends_request_permission_false() { + // Phase H1a wire-flag derivation: a handler that doesn't claim + // permission dispatch must send requestPermission=false so the + // runtime doesn't broadcast permission events to this client. + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["params"]["requestPermission"], false); + assert_eq!(request["params"]["requestElicitation"], false); + + let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": session_id }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + #[tokio::test] async fn env_value_mode_hardcoded_direct_on_create_and_resume() { use github_copilot_sdk::types::ResumeSessionConfig; @@ -3751,6 +3911,10 @@ async fn tool_invocation_carries_trace_context_from_event() { #[async_trait] impl SessionHandler for CapturingHandler { + fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { + true + } + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { if let HandlerEvent::ExternalTool { invocation } = event { *self.captured.lock() = Some(( From f0fc862bd850c63ad07002fcbabdf3612d6791cf Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 19:45:09 +0100 Subject: [PATCH 07/22] Phase K: docs, examples, scenarios - rust/README.md: drop the 'call approve_all_permissions AFTER with_handler' warning -- the policy field makes both orderings equivalent now. Replace the elicitation snippet with the new wants_elicitation_dispatch probe. - docs/features/remote-sessions.md: with_remote -> with_enable_remote_sessions. - test/scenarios/**/rust: rename .destroy() -> .disconnect(). - test/scenarios/transport/tcp/rust: add connection_token field to Transport::External literal. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- docs/features/remote-sessions.md | 2 +- rust/README.md | 30 ++++++++++++++----- .../callbacks/hooks/rust/src/main.rs | 2 +- .../callbacks/permissions/rust/src/main.rs | 2 +- .../callbacks/user-input/rust/src/main.rs | 2 +- test/scenarios/modes/default/rust/src/main.rs | 2 +- .../prompts/attachments/rust/src/main.rs | 2 +- .../prompts/reasoning-effort/rust/src/main.rs | 2 +- .../prompts/system-message/rust/src/main.rs | 2 +- .../concurrent-sessions/rust/src/main.rs | 4 +-- .../infinite-sessions/rust/src/main.rs | 2 +- .../sessions/session-resume/rust/src/main.rs | 2 +- .../sessions/streaming/rust/src/main.rs | 2 +- .../tools/custom-agents/rust/src/main.rs | 2 +- .../tools/mcp-servers/rust/src/main.rs | 2 +- .../scenarios/tools/no-tools/rust/src/main.rs | 2 +- test/scenarios/tools/skills/rust/src/main.rs | 2 +- .../tools/tool-filtering/rust/src/main.rs | 2 +- .../tools/tool-overrides/rust/src/main.rs | 2 +- .../transport/stdio/rust/src/main.rs | 2 +- test/scenarios/transport/tcp/rust/src/main.rs | 3 +- 21 files changed, 45 insertions(+), 28 deletions(-) diff --git a/docs/features/remote-sessions.md b/docs/features/remote-sessions.md index 391bb762d..216e80681 100644 --- a/docs/features/remote-sessions.md +++ b/docs/features/remote-sessions.md @@ -100,7 +100,7 @@ session.On((SessionEvent e) => use github_copilot_sdk::{Client, ClientOptions, PermissionRequestResult, SessionConfig}; let client = Client::start( - ClientOptions::new().with_remote(true) + ClientOptions::new().with_enable_remote_sessions(true) ).await?; let session = client.create_session( SessionConfig::new("/path/to/github-repo") diff --git a/rust/README.md b/rust/README.md index d8ff9c43c..a09e294fd 100644 --- a/rust/README.md +++ b/rust/README.md @@ -445,9 +445,15 @@ let session = client .await?; ``` -> Call the policy method **after** `with_handler` — `with_handler` overwrites the handler field, so `approve_all_permissions().with_handler(...)` discards the wrap. +> Order-independent: `with_handler` and the permission-policy methods +> (`approve_all_permissions`, `deny_all_permissions`, +> `approve_permissions_if`) can be called in either order — the policy is +> applied to the handler when the session is created. -For composing a policy onto a handler outside the builder chain (e.g. when wrapping a `ToolHandlerRouter` you've built elsewhere), the `permission` module exposes the same primitives as free functions: +The `permission` module also exposes the policy primitives as standalone +helpers for the rare case where you want to wrap a handler outside the +builder chain (e.g., when composing a `ToolHandlerRouter` you've built +elsewhere): ```rust,ignore use github_copilot_sdk::permission; @@ -461,13 +467,23 @@ let session = client.create_session(config.with_handler(handler)).await?; ### Capabilities & Elicitation -The SDK negotiates capabilities with the CLI after session creation. Enable elicitation to let the agent present structured UI dialogs (forms, URL prompts) to the user. +The SDK negotiates capabilities with the CLI after session creation. To +opt this client into receiving `elicitation.requested` broadcasts, return +`true` from `SessionHandler::wants_elicitation_dispatch` on the handler +you install — the SDK derives the `requestElicitation` wire flag from +that probe at `Client::create_session` time. Clients that don't claim +elicitation are silently skipped, allowing other connected clients on +the same CLI to handle the request. ```rust,ignore -let config = SessionConfig { - request_elicitation: Some(true), - ..Default::default() -}; +struct MyHandler; +#[async_trait] +impl SessionHandler for MyHandler { + fn wants_elicitation_dispatch(&self) -> bool { + true + } + // ... on_event etc. +} ``` The handler receives `HandlerEvent::ElicitationRequest` with a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. Return `HandlerResponse::Elicitation(result)`. diff --git a/test/scenarios/callbacks/hooks/rust/src/main.rs b/test/scenarios/callbacks/hooks/rust/src/main.rs index 179765d2f..d77c18795 100644 --- a/test/scenarios/callbacks/hooks/rust/src/main.rs +++ b/test/scenarios/callbacks/hooks/rust/src/main.rs @@ -126,6 +126,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal hooks fired: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/callbacks/permissions/rust/src/main.rs b/test/scenarios/callbacks/permissions/rust/src/main.rs index 214620e35..ff8439a2a 100644 --- a/test/scenarios/callbacks/permissions/rust/src/main.rs +++ b/test/scenarios/callbacks/permissions/rust/src/main.rs @@ -86,6 +86,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal permission requests: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs b/test/scenarios/callbacks/user-input/rust/src/main.rs index b7fea906e..467404d95 100644 --- a/test/scenarios/callbacks/user-input/rust/src/main.rs +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -98,6 +98,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("\nTotal user input requests: {}", log.len()); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/modes/default/rust/src/main.rs b/test/scenarios/modes/default/rust/src/main.rs index ba890997d..862a70ccd 100644 --- a/test/scenarios/modes/default/rust/src/main.rs +++ b/test/scenarios/modes/default/rust/src/main.rs @@ -31,6 +31,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } println!("Default mode test complete"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/attachments/rust/src/main.rs b/test/scenarios/prompts/attachments/rust/src/main.rs index 9ba9cc176..2a240c590 100644 --- a/test/scenarios/prompts/attachments/rust/src/main.rs +++ b/test/scenarios/prompts/attachments/rust/src/main.rs @@ -53,6 +53,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs index bf1ab9720..3bfc0d1a0 100644 --- a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs +++ b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs @@ -35,6 +35,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/prompts/system-message/rust/src/main.rs b/test/scenarios/prompts/system-message/rust/src/main.rs index 4218a389b..034cdea61 100644 --- a/test/scenarios/prompts/system-message/rust/src/main.rs +++ b/test/scenarios/prompts/system-message/rust/src/main.rs @@ -35,6 +35,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs index 43932b613..41a4360e6 100644 --- a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs @@ -47,7 +47,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session1.destroy().await?; - session2.destroy().await?; + session1.disconnect().await?; + session2.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs index 0c0f06814..2ada9314e 100644 --- a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs @@ -50,6 +50,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("Infinite sessions test complete — all messages processed successfully"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/session-resume/rust/src/main.rs b/test/scenarios/sessions/session-resume/rust/src/main.rs index 10cd4fa62..4ed66c846 100644 --- a/test/scenarios/sessions/session-resume/rust/src/main.rs +++ b/test/scenarios/sessions/session-resume/rust/src/main.rs @@ -41,6 +41,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - resumed.destroy().await?; + resumed.disconnect().await?; Ok(()) } diff --git a/test/scenarios/sessions/streaming/rust/src/main.rs b/test/scenarios/sessions/streaming/rust/src/main.rs index f5cf23764..6616f35b1 100644 --- a/test/scenarios/sessions/streaming/rust/src/main.rs +++ b/test/scenarios/sessions/streaming/rust/src/main.rs @@ -61,6 +61,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { chunks.load(Ordering::Relaxed) ); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs index e707770bc..ff4fb301a 100644 --- a/test/scenarios/tools/custom-agents/rust/src/main.rs +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -77,6 +77,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs index fd76147a1..8abc7e078 100644 --- a/test/scenarios/tools/mcp-servers/rust/src/main.rs +++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs @@ -63,6 +63,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\nNo MCP servers configured (set MCP_SERVER_CMD to test with a real server)"); } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/no-tools/rust/src/main.rs b/test/scenarios/tools/no-tools/rust/src/main.rs index 691ac47ed..585dbda0b 100644 --- a/test/scenarios/tools/no-tools/rust/src/main.rs +++ b/test/scenarios/tools/no-tools/rust/src/main.rs @@ -39,6 +39,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/skills/rust/src/main.rs b/test/scenarios/tools/skills/rust/src/main.rs index 845704fac..1fc94ba64 100644 --- a/test/scenarios/tools/skills/rust/src/main.rs +++ b/test/scenarios/tools/skills/rust/src/main.rs @@ -57,6 +57,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { println!("\nSkill directories configured successfully"); - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/tool-filtering/rust/src/main.rs b/test/scenarios/tools/tool-filtering/rust/src/main.rs index edc203550..10eff3d91 100644 --- a/test/scenarios/tools/tool-filtering/rust/src/main.rs +++ b/test/scenarios/tools/tool-filtering/rust/src/main.rs @@ -42,6 +42,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs index ce002a27d..7b918ea13 100644 --- a/test/scenarios/tools/tool-overrides/rust/src/main.rs +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -56,6 +56,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/transport/stdio/rust/src/main.rs b/test/scenarios/transport/stdio/rust/src/main.rs index 156b3587d..06a0de255 100644 --- a/test/scenarios/transport/stdio/rust/src/main.rs +++ b/test/scenarios/transport/stdio/rust/src/main.rs @@ -25,6 +25,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } diff --git a/test/scenarios/transport/tcp/rust/src/main.rs b/test/scenarios/transport/tcp/rust/src/main.rs index 49691c1b2..9f0674296 100644 --- a/test/scenarios/transport/tcp/rust/src/main.rs +++ b/test/scenarios/transport/tcp/rust/src/main.rs @@ -20,6 +20,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.transport = Transport::External { host: host.to_string(), port, + connection_token: None, }; opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; @@ -38,6 +39,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } } - session.destroy().await?; + session.disconnect().await?; Ok(()) } From f1b974e55564e73aae947f6bdcb66287340599f8 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 20:10:37 +0100 Subject: [PATCH 08/22] Cross-SDK: derive requestPermission from handler presence on session.create TS, C#, Go, and Python all currently hardcoded requestPermission: true on session.create regardless of whether the caller supplied an onPermissionRequest handler. (Resume / join paths already derived from handler presence.) The runtime supports the presence-derived shape: when no client opts in via requestPermission=true, the session short-circuits permission prompts with user-not-available instead of broadcasting. So the hardcoded true forced the runtime to broadcast permission events to clients that would never respond, wasting a roundtrip and creating a discrepancy between create and resume. Aligns all four SDKs to: requestPermission := onPermissionRequest != null on both session.create and session.resume. (Rust will land the same shape in its API-review-fixes commit on this branch.) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Client.cs | 4 ++-- go/client.go | 8 ++++++-- nodejs/src/client.ts | 2 +- python/copilot/client.py | 8 ++++---- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 11f8c90c9..438ac5c7e 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -594,7 +594,7 @@ public async Task CreateSessionAsync(SessionConfig config, Cance config.ExcludedTools, config.Provider, config.EnableSessionTelemetry, - (bool?)true, + config.OnPermissionRequest != null ? true : null, config.OnUserInputRequest != null ? true : null, config.OnExitPlanModeRequest != null ? true : null, config.OnAutoModeSwitchRequest != null ? true : null, @@ -752,7 +752,7 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes config.ExcludedTools, config.Provider, config.EnableSessionTelemetry, - (bool?)true, + config.OnPermissionRequest != null ? true : null, config.OnUserInputRequest != null ? true : null, config.OnExitPlanModeRequest != null ? true : null, config.OnAutoModeSwitchRequest != null ? true : null, diff --git a/go/client.go b/go/client.go index dab09a4dd..eb7276de2 100644 --- a/go/client.go +++ b/go/client.go @@ -665,7 +665,9 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses config.Hooks.OnErrorOccurred != nil) { req.Hooks = Bool(true) } - req.RequestPermission = Bool(true) + if config.OnPermissionRequest != nil { + req.RequestPermission = Bool(true) + } traceparent, tracestate := getTraceContext(ctx) req.Traceparent = traceparent @@ -839,7 +841,9 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, req.InfiniteSessions = config.InfiniteSessions req.GitHubToken = config.GitHubToken req.RemoteSession = config.RemoteSession - req.RequestPermission = Bool(true) + if config.OnPermissionRequest != nil { + req.RequestPermission = Bool(true) + } if len(config.Commands) > 0 { cmds := make([]wireCommand, 0, len(config.Commands)) diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 991f23fa1..ddc23dd2f 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -830,7 +830,7 @@ export class CopilotClient { provider: config.provider, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, - requestPermission: true, + requestPermission: !!config.onPermissionRequest, requestUserInput: !!config.onUserInputRequest, requestElicitation: !!config.onElicitationRequest, requestExitPlanMode: !!config.onExitPlanModeRequest, diff --git a/python/copilot/client.py b/python/copilot/client.py index 6adb52061..9c7698716 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -1493,8 +1493,8 @@ async def create_session( if excluded_tools is not None: payload["excludedTools"] = excluded_tools - # Always enable permission request callback - payload["requestPermission"] = True + # Enable permission request callback if handler provided + payload["requestPermission"] = bool(on_permission_request) # Enable user input request callback if handler provided if on_user_input_request: @@ -1888,8 +1888,8 @@ async def resume_session( else True ) - # Always enable permission request callback - payload["requestPermission"] = True + # Enable permission request callback if handler provided + payload["requestPermission"] = bool(on_permission_request) if on_user_input_request: payload["requestUserInput"] = True From 455a17b7e7aacb33fcd774d743af90851e1a77e1 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 21:23:09 +0100 Subject: [PATCH 09/22] Phase H redo (lib): five optional handler traits + wire-struct split Replaces the monolithic SessionHandler trait with five focused single-method optional handler traits in src/handler.rs: PermissionHandler, ElicitationHandler, UserInputHandler, ExitPlanModeHandler, AutoModeSwitchHandler. Each handler is an optional field on SessionConfig / ResumeSessionConfig; presence is the signal, matching TypeScript/C# semantics. - Drops SessionHandler, HandlerEvent, HandlerResponse, NoopHandler, ToolHandlerRouter, and the wants_*_dispatch probe methods. - ApproveAllHandler / DenyAllHandler now implement PermissionHandler. - Tool dispatch lives in an internal HashMap> built from SessionConfig::tool_handlers. ToolHandlerRouter is gone; callers pass handlers directly via with_tool_handlers(...). - Wire flags on session.create / session.resume (requestPermission, requestElicitation, requestUserInput, requestExitPlanMode, requestAutoModeSwitch, hooks) are derived from handler presence at Client::create_session / resume_session time. No probes, no caller-visible knobs. - New src/wire.rs: SessionCreateWire and SessionResumeWire structs (pub(crate)) carry the wire payload. SessionConfig::to_wire() and ResumeSessionConfig::to_wire() build them. SessionConfig and ResumeSessionConfig keep their serde derives (used only for tests), but Client::create_session and Client::resume_session now serialize the wire struct, not the user-facing config. - permission::approve_all / deny_all / approve_if are now Arc producers. New private permission::resolve_handler(handler, policy) used by Client to apply a stored policy on top of (or in lieu of) a caller-supplied handler -- the policy wins, making with_permission_handler / approve_all_permissions order-independent. - Broadcast dispatcher in session.rs: - permission.requested: skip if data.resolvedByHook OR permission_handler is None; else dispatch via handle(). - elicitation.requested: skip if elicitation_handler is None. - external_tool.requested: look up by tool_name in the registry; silently skip if absent. - Direct RPC dispatcher in session.rs: - tool.call: REMOVED (legacy; runtime only emits external_tool.requested broadcasts now). - permission.request: kept as v2 back-compat; returns { kind: 'user-not-available' } when no permission handler. - userInput.request / exitPlanMode.request / autoModeSwitch.request: kept (this is how the runtime still emits these). Defaults when no handler match TS: { noResponse: true }, { approved: true }, and { response: 'no' }. Tests in rust/tests/ and e2e tests are migrated in a follow-up commit. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/handler.rs | 729 +++++++---------------------------------- rust/src/lib.rs | 3 +- rust/src/permission.rs | 261 ++++++++------- rust/src/session.rs | 509 +++++++++++++--------------- rust/src/tool.rs | 405 ++--------------------- rust/src/types.rs | 598 +++++++++++++++++++++++---------- rust/src/wire.rs | 173 ++++++++++ 7 files changed, 1104 insertions(+), 1574 deletions(-) create mode 100644 rust/src/wire.rs diff --git a/rust/src/handler.rs b/rust/src/handler.rs index de3f434e1..6e4ffa484 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -1,123 +1,28 @@ -//! Event handler traits for session lifecycle. +//! Optional session-callback traits. //! -//! The [`SessionHandler`](crate::handler::SessionHandler) trait is the primary extension point — implement -//! [`on_event`](crate::handler::SessionHandler::on_event) to control how sessions respond to -//! CLI events, permission requests, tool calls, and user input prompts. +//! Each callback the CLI may dispatch (permission requests, elicitation +//! prompts, user-input questions, exit-plan-mode prompts, +//! auto-mode-switch prompts) has its own focused trait with a single +//! `handle` method. +//! +//! Handlers are **optional**: install only the ones the application cares +//! about. The SDK derives the corresponding wire flag on +//! `session.create` / `session.resume` from the presence of each handler, +//! so the runtime does not emit broadcasts this client would never +//! respond to. +//! +//! Tool dispatch uses its own per-tool registry built from +//! [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers). use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::types::{ ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, - SessionEvent, SessionId, ToolInvocation, ToolResult, + SessionId, }; -/// Events dispatched by the SDK session event loop to the handler. -/// -/// The handler returns a [`HandlerResponse`] indicating how the SDK should -/// respond to the CLI. For fire-and-forget events (`SessionEvent`), the -/// response is ignored. -#[non_exhaustive] -#[derive(Debug)] -pub enum HandlerEvent { - /// Informational session event from the timeline (e.g. assistant.message_delta, - /// session.idle, tool.execution_start). Fire-and-forget — return `HandlerResponse::Ok`. - SessionEvent { - /// The session that emitted this event. - session_id: SessionId, - /// The event payload. - event: SessionEvent, - }, - - /// The CLI requests permission for an action. Return `HandlerResponse::Permission(..)`. - PermissionRequest { - /// The requesting session. - session_id: SessionId, - /// Unique ID to correlate the response. - request_id: RequestId, - /// Permission request payload. - data: PermissionRequestData, - }, - - /// The CLI requests user input. Return `HandlerResponse::UserInput(..)`. - /// The handler may block (e.g. awaiting a UI dialog) — this is expected. - UserInput { - /// The requesting session. - session_id: SessionId, - /// The question text to present. - question: String, - /// Optional multiple-choice options. - choices: Option>, - /// Whether free-form text input is allowed. - allow_freeform: Option, - }, - - /// The CLI requests execution of a client-defined tool. - /// Return `HandlerResponse::ToolResult(..)`. - ExternalTool { - /// The tool call to execute. - invocation: ToolInvocation, - }, - - /// The CLI broadcasts an elicitation request for the provider to handle. - /// Return `HandlerResponse::Elicitation(..)`. - ElicitationRequest { - /// The requesting session. - session_id: SessionId, - /// Unique ID to correlate the response. - request_id: RequestId, - /// The elicitation request payload. - request: ElicitationRequest, - }, - - /// The CLI requests exiting plan mode. Return `HandlerResponse::ExitPlanMode(..)`. - ExitPlanMode { - /// The requesting session. - session_id: SessionId, - /// Plan mode exit payload. - data: ExitPlanModeData, - }, - - /// The CLI asks whether to switch to auto model when an eligible rate - /// limit is hit. Return [`HandlerResponse::AutoModeSwitch`]. - AutoModeSwitch { - /// The requesting session. - session_id: SessionId, - /// The specific rate-limit error code that triggered the request, - /// if known (e.g. `user_weekly_rate_limited`, `user_global_rate_limited`). - error_code: Option, - /// Seconds until the rate limit resets, when known. - retry_after_seconds: Option, - }, -} - -/// Response from the handler back to the SDK, used to construct the -/// JSON-RPC reply sent to the CLI. -#[non_exhaustive] -#[derive(Debug)] -pub enum HandlerResponse { - /// No response needed (used for fire-and-forget `SessionEvent`s). - Ok, - /// Do not send a response. The consumer will resolve the pending request out-of-band. - NoResult, - /// Permission decision. - Permission(PermissionResult), - /// User input response (or `None` to signal no input available). - UserInput(Option), - /// Result of a tool execution. - ToolResult(ToolResult), - /// Elicitation result (accept/decline/cancel with optional form data). - Elicitation(ElicitationResult), - /// Exit plan mode decision. - ExitPlanMode(ExitPlanModeResult), - /// Auto-mode-switch decision. - AutoModeSwitch(AutoModeSwitchResponse), -} - /// Result of a permission request. -/// -/// `#[non_exhaustive]` so future variants can be added without a major -/// version bump. Match arms must include a `_` fallback. #[derive(Debug, Clone)] #[non_exhaustive] pub enum PermissionResult { @@ -126,34 +31,20 @@ pub enum PermissionResult { /// Permission denied. Denied, /// Defer the response. The handler will resolve this request itself - /// later — typically after a UI prompt — by calling + /// later -- typically after a UI prompt -- by calling /// `session.permissions.handlePendingPermissionRequest` directly. The - /// SDK will not send a response for this request. - /// - /// **Notification path only** (`permission.requested`). On the direct - /// RPC path (`permission.request`), `Deferred` falls back to - /// [`Approved`](Self::Approved) because that path must return a value - /// to satisfy the JSON-RPC reply contract. + /// SDK skips its own response on this path. Deferred, - /// Provide the full response payload. The SDK passes the value as-is - /// in the `result` field of `handlePendingPermissionRequest` - /// (notification path) or as the JSON-RPC `result` directly (direct - /// RPC path). - /// - /// Use this for response shapes beyond `{ "kind": "approve-once" }` - /// or `{ "kind": "reject" }` — for example, "approve and remember" - /// with allowlist data. + /// Provide the full response payload directly. The SDK forwards the + /// value as-is on the wire. Custom(serde_json::Value), - /// No user is available to respond — for example, headless agents - /// without an interactive session. Sent as - /// `{ "kind": "user-not-available" }`. - UserNotAvailable, - /// The handler has no result to provide and the CLI should fall back - /// to another permission responder or its default policy. On the - /// notification path, the SDK will not send a pending permission response. - /// Distinct from [`Deferred`](Self::Deferred), where the handler takes - /// responsibility for resolving the request later out-of-band. + /// Decline to handle this broadcast. The SDK does not send a response, + /// which lets another connected client respond instead. NoResult, + /// No user is available to answer the prompt. On the notification + /// path, the SDK will not send a pending response. On the direct + /// RPC path, the SDK responds with `{ "kind": "user-not-available" }`. + UserNotAvailable, } /// Response to a user input request. @@ -189,306 +80,103 @@ impl Default for ExitPlanModeResult { } } -/// Response to a [`HandlerEvent::AutoModeSwitch`] request. -/// -/// Wire serialization matches the CLI's `autoModeSwitch.request` response -/// schema: `"yes"`, `"yes_always"`, or `"no"`. +/// Response to an auto-mode-switch request. #[non_exhaustive] #[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AutoModeSwitchResponse { /// Approve the auto-mode switch for this rate-limit cycle only. Yes, - /// Approve and remember — auto-accept future auto-mode switches in this - /// session without prompting. + /// Approve and remember -- auto-accept future auto-mode switches in + /// this session without prompting. YesAlways, - /// Decline the auto-mode switch. The session stays on the current model - /// and surfaces the rate-limit error. + /// Decline the auto-mode switch. The session stays on the current + /// model and surfaces the rate-limit error. No, } -/// Callback trait for session events. -/// -/// Implement this trait to control how a session responds to CLI events, -/// permission requests, tool calls, user input prompts, elicitations, and -/// plan-mode exits. There are two styles of implementation — pick whichever -/// fits your use case: -/// -/// 1. **Per-event methods (recommended for most handlers).** Override the -/// specific `on_*` methods you care about; every method has a safe -/// default so you only write what you need. This is the pattern used by -/// [`serenity::EventHandler`][serenity], `lapin`, and most Rust SDKs -/// that dispatch broker/client callbacks. -/// 2. **Single [`on_event`](Self::on_event) method.** Override this one -/// method and `match` on [`HandlerEvent`] yourself. Useful for logging -/// middleware, custom routing, or when you want an exhaustiveness check -/// across all variants. +/// Handler for `permission.requested` broadcasts. /// -/// When you override [`on_event`](Self::on_event) directly, the per-event methods are not -/// called — your implementation is entirely responsible for dispatch. The -/// default [`on_event`](Self::on_event) fans out to the per-event methods. -/// -/// [serenity]: https://docs.rs/serenity/latest/serenity/client/trait.EventHandler.html -/// -/// # Default behavior -/// -/// - Permission requests → **denied**. -/// - User input → `None` (no answer available). -/// - External tool calls → failure result with "no handler registered". -/// - Elicitation → `"cancel"`. -/// - Exit plan mode → [`ExitPlanModeResult::default`]. -/// - Auto-mode-switch → [`AutoModeSwitchResponse::No`] (decline by default; the -/// session stays on its current model and surfaces the rate-limit error). -/// - Session events → ignored (fire-and-forget). -/// -/// # Concurrency -/// -/// **Request-triggered events** (`UserInput`, `ExternalTool` via `tool.call`, -/// `ExitPlanMode`, `PermissionRequest` via `permission.request`) are awaited -/// inline in the event loop and therefore processed **serially** per session. -/// Blocking here pauses that session's event loop — which is correct, since -/// the CLI is also blocked waiting for the response. -/// -/// **Notification-triggered events** (`PermissionRequest` via -/// `permission.requested`, `ExternalTool` via `external_tool.requested`) are -/// dispatched on spawned tasks and may run **concurrently** with each other -/// and with the serial event loop. Implementations must be safe for -/// concurrent invocation. -/// -/// # Example -/// -/// ```no_run -/// use async_trait::async_trait; -/// use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; -/// use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; -/// -/// struct ApproveReadsOnly; -/// -/// #[async_trait] -/// impl SessionHandler for ApproveReadsOnly { -/// async fn on_permission_request( -/// &self, -/// _sid: SessionId, -/// _rid: RequestId, -/// data: PermissionRequestData, -/// ) -> PermissionResult { -/// match data.extra.get("tool").and_then(|v| v.as_str()) { -/// Some("view") | Some("ls") | Some("grep") => PermissionResult::Approved, -/// _ => PermissionResult::Denied, -/// } -/// } -/// } -/// ``` +/// Install via +/// [`SessionConfig::with_permission_handler`](crate::types::SessionConfig::with_permission_handler) +/// (or the matching method on [`ResumeSessionConfig`](crate::types::ResumeSessionConfig)). +/// When no permission handler is supplied, the SDK sends +/// `requestPermission: false` on the wire and the runtime short-circuits +/// permission prompts for this client. #[async_trait] -pub trait SessionHandler: Send + Sync + 'static { - /// Whether this handler claims responsibility for dispatching - /// `permission.requested` broadcasts on this session. - /// - /// Defaults to `false`. When `false`, the SDK does not send the - /// `requestPermission: true` flag on `session.create` / `session.resume`, - /// and the runtime short-circuits permission prompts with - /// `user-not-available` rather than emitting a broadcast this client - /// would silently never respond to. - /// - /// Override to `true` in handlers that implement - /// [`on_permission_request`](Self::on_permission_request) with a real - /// policy. Permission-policy wrappers - /// ([`approve_all_permissions`](crate::types::SessionConfig::approve_all_permissions) - /// and friends) flip this on automatically. - fn wants_permission_dispatch(&self) -> bool { - false - } - - /// Whether this handler claims responsibility for dispatching - /// `elicitation.requested` broadcasts on this session. Defaults to - /// `false`. See [`wants_permission_dispatch`](Self::wants_permission_dispatch) - /// for the multi-client rationale. - fn wants_elicitation_dispatch(&self) -> bool { - false - } - - /// Whether this handler claims responsibility for dispatching an - /// `external_tool.requested` broadcast for the given tool name on - /// this session. Defaults to `false`. Tool routers - /// ([`ToolHandlerRouter`](crate::tool::ToolHandlerRouter)) override to - /// return `true` for tools they have a registered handler for. - fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { - false - } - - /// Handle an event from the session. - /// - /// The default implementation destructures `event` and calls the - /// matching per-event method (e.g. [`on_permission_request`](Self::on_permission_request) - /// for [`HandlerEvent::PermissionRequest`]). Override this method only - /// if you want a single dispatch point with exhaustive matching — most - /// handlers should override the per-event methods instead. - /// - /// See the [trait-level docs](SessionHandler#concurrency) for details on - /// which events may be dispatched concurrently. - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { session_id, event } => { - self.on_session_event(session_id, event).await; - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { - session_id, - request_id, - data, - } => HandlerResponse::Permission( - self.on_permission_request(session_id, request_id, data) - .await, - ), - HandlerEvent::UserInput { - session_id, - question, - choices, - allow_freeform, - } => HandlerResponse::UserInput( - self.on_user_input(session_id, question, choices, allow_freeform) - .await, - ), - HandlerEvent::ExternalTool { invocation } => { - HandlerResponse::ToolResult(self.on_external_tool(invocation).await) - } - HandlerEvent::ElicitationRequest { - session_id, - request_id, - request, - } => HandlerResponse::Elicitation( - self.on_elicitation(session_id, request_id, request).await, - ), - HandlerEvent::ExitPlanMode { session_id, data } => { - HandlerResponse::ExitPlanMode(self.on_exit_plan_mode(session_id, data).await) - } - HandlerEvent::AutoModeSwitch { - session_id, - error_code, - retry_after_seconds, - } => HandlerResponse::AutoModeSwitch( - self.on_auto_mode_switch(session_id, error_code, retry_after_seconds) - .await, - ), - } - } - - /// Informational timeline event (assistant messages, tool execution - /// markers, session idle, etc.). Fire-and-forget — the return value is - /// ignored. - /// - /// Default: do nothing. - async fn on_session_event(&self, _session_id: SessionId, _event: SessionEvent) {} - - /// The CLI is asking whether the agent may perform a privileged action. - /// - /// Default: [`PermissionResult::Denied`]. The default-deny posture - /// matches the CLI's safety model; override to implement your own - /// policy (see the [`permission`](crate::permission) module for common - /// wrappers like `approve_all` / `approve_if`). - async fn on_permission_request( +pub trait PermissionHandler: Send + Sync + 'static { + /// Resolve a permission request. + async fn handle( &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Denied - } + session_id: SessionId, + request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult; +} - /// The CLI is asking the user a question (optionally with a list of - /// choices). - /// - /// Default: `None` — the CLI interprets this as "no answer available" - /// and falls back to its own prompt behavior. - async fn on_user_input( +/// Handler for `elicitation.requested` broadcasts. +/// +/// When unset, `requestElicitation: false` goes on the wire. +#[async_trait] +pub trait ElicitationHandler: Send + Sync + 'static { + /// Respond to an elicitation prompt (form, URL confirm, etc.). + async fn handle( &self, - _session_id: SessionId, - _question: String, - _choices: Option>, - _allow_freeform: Option, - ) -> Option { - None - } - - /// The CLI wants to invoke a client-defined ("external") tool. - /// - /// Default: a failure [`ToolResult`] indicating no tool handler is - /// registered. Typical implementations route to a - /// [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) which - /// dispatches to tools registered via - /// [`define_tool`](crate::tool::define_tool) or custom - /// [`ToolHandler`](crate::tool::ToolHandler) impls. - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - let msg = format!("No handler registered for tool '{}'", invocation.tool_name); - ToolResult::Expanded(crate::types::ToolResultExpanded { - text_result_for_llm: msg.clone(), - result_type: "failure".to_string(), - binary_results_for_llm: None, - session_log: None, - error: Some(msg), - tool_telemetry: None, - }) - } + session_id: SessionId, + request_id: RequestId, + request: ElicitationRequest, + ) -> ElicitationResult; +} - /// The CLI is requesting an elicitation (structured form / URL prompt). - /// - /// Default: cancel. - async fn on_elicitation( +/// Handler for `user_input.requested` events from the `ask_user` tool. +/// +/// When unset, `requestUserInput: false` goes on the wire and the +/// `ask_user` tool is disabled for the session. +#[async_trait] +pub trait UserInputHandler: Send + Sync + 'static { + /// Answer a question on behalf of the user. Return `None` to signal + /// "no answer available". + async fn handle( &self, - _session_id: SessionId, - _request_id: RequestId, - _request: ElicitationRequest, - ) -> ElicitationResult { - ElicitationResult { - action: "cancel".to_string(), - content: None, - } - } + session_id: SessionId, + question: String, + choices: Option>, + allow_freeform: Option, + ) -> Option; +} - /// The CLI is asking the user whether to exit plan mode. - /// - /// Default: [`ExitPlanModeResult::default`] (approved with no action). - async fn on_exit_plan_mode( - &self, - _session_id: SessionId, - _data: ExitPlanModeData, - ) -> ExitPlanModeResult { - ExitPlanModeResult::default() - } +/// Handler for `exit_plan_mode.requested` events. When unset, +/// `requestExitPlanMode: false` goes on the wire. +#[async_trait] +pub trait ExitPlanModeHandler: Send + Sync + 'static { + /// Decide whether to leave plan mode. + async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult; +} - /// The CLI is asking whether to switch to auto model after an eligible - /// rate limit. - /// - /// `retry_after_seconds`, when present, is the number of seconds until the - /// rate limit resets. Handlers can use it to render a humanized reset time - /// alongside the prompt. - /// - /// Default: [`AutoModeSwitchResponse::No`] — decline. Override only if - /// your application surfaces a UX for the rate-limit-recovery prompt. - async fn on_auto_mode_switch( +/// Handler for `auto_mode_switch.requested` events. When unset, +/// `requestAutoModeSwitch: false` goes on the wire. +#[async_trait] +pub trait AutoModeSwitchHandler: Send + Sync + 'static { + /// Decide whether to fall back to the auto model after an eligible + /// rate-limit error. `retry_after_seconds`, when present, is the + /// number of seconds until the rate limit resets. + async fn handle( &self, - _session_id: SessionId, - _error_code: Option, - _retry_after_seconds: Option, - ) -> AutoModeSwitchResponse { - AutoModeSwitchResponse::No - } + session_id: SessionId, + error_code: Option, + retry_after_seconds: Option, + ) -> AutoModeSwitchResponse; } -/// A [`SessionHandler`] that auto-approves all permissions and ignores all events. -/// -/// Useful for CLI tools, scripts, and tests that don't need interactive -/// permission prompts or custom tool handling. +/// A [`PermissionHandler`] that approves every request. Useful for CLI +/// tools, scripts, and tests that don't need interactive permission +/// prompts. #[derive(Debug, Clone)] pub struct ApproveAllHandler; #[async_trait] -impl SessionHandler for ApproveAllHandler { - fn wants_permission_dispatch(&self) -> bool { - true - } - - async fn on_permission_request( +impl PermissionHandler for ApproveAllHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -498,220 +186,47 @@ impl SessionHandler for ApproveAllHandler { } } -/// A [`SessionHandler`] that denies all permission requests and otherwise -/// relies on the trait's default fallback responses for every other event -/// (e.g. tool invocations return "unhandled", elicitations cancel, plan-mode -/// prompts decline). Use this when a session should never wait for manual -/// permission approval. +/// A [`PermissionHandler`] that denies every request. #[derive(Debug, Clone)] pub struct DenyAllHandler; #[async_trait] -impl SessionHandler for DenyAllHandler { - fn wants_permission_dispatch(&self) -> bool { - true - } - // All other defaults are already safe. -} - -/// A [`SessionHandler`] that leaves permission requests and external tool calls pending. -/// -/// This is the default used when no handler is set on -/// [`SessionConfig::handler`](crate::types::SessionConfig::handler). It lets consumers -/// observe `permission.requested` and `external_tool.requested` events and later resolve -/// them with the corresponding pending-request RPC methods. -#[derive(Debug, Clone)] -pub struct NoopHandler; - -#[async_trait] -impl SessionHandler for NoopHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { .. } => HandlerResponse::Ok, - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::NoResult) - } - HandlerEvent::UserInput { .. } => HandlerResponse::UserInput(None), - HandlerEvent::ExternalTool { .. } => HandlerResponse::NoResult, - HandlerEvent::ElicitationRequest { .. } => { - HandlerResponse::Elicitation(ElicitationResult { - action: "cancel".to_string(), - content: None, - }) - } - HandlerEvent::ExitPlanMode { .. } => { - HandlerResponse::ExitPlanMode(ExitPlanModeResult::default()) - } - HandlerEvent::AutoModeSwitch { .. } => { - HandlerResponse::AutoModeSwitch(AutoModeSwitchResponse::No) - } - } +impl PermissionHandler for DenyAllHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied } } #[cfg(test)] mod tests { - use serde_json::Value; - use super::*; - use crate::types::{PermissionRequestData, RequestId, SessionId}; - - fn perm_data() -> PermissionRequestData { - PermissionRequestData::default() - } - - // A handler that overrides only `on_permission_request` (per-method style). - struct ApproveViaPerMethod; - - #[async_trait] - impl SessionHandler for ApproveViaPerMethod { - async fn on_permission_request( - &self, - _: SessionId, - _: RequestId, - _: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Approved - } - } - - // A handler that overrides `on_event` directly (legacy / routing style). - struct ApproveViaOnEvent; - - #[async_trait] - impl SessionHandler for ApproveViaOnEvent { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - _ => HandlerResponse::Ok, - } - } - } #[tokio::test] - async fn per_method_override_dispatches_via_default_on_event() { - let h = ApproveViaPerMethod; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) + async fn approve_all_handler_returns_approved() { + let result = ApproveAllHandler + .handle( + SessionId::from("s1"), + RequestId::new("1"), + PermissionRequestData::default(), + ) .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Approved) - )); + assert!(matches!(result, PermissionResult::Approved)); } #[tokio::test] - async fn on_event_override_short_circuits_per_method_defaults() { - let h = ApproveViaOnEvent; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) + async fn deny_all_handler_returns_denied() { + let result = DenyAllHandler + .handle( + SessionId::from("s1"), + RequestId::new("1"), + PermissionRequestData::default(), + ) .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Approved) - )); + assert!(matches!(result, PermissionResult::Denied)); } - - #[tokio::test] - async fn deny_all_handler_uses_default_permission_deny() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) - .await; - assert!(matches!( - resp, - HandlerResponse::Permission(PermissionResult::Denied) - )); - } - - #[tokio::test] - async fn default_on_external_tool_returns_failure() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::ExternalTool { - invocation: crate::types::ToolInvocation { - session_id: SessionId::from("s1".to_string()), - tool_call_id: "tc1".to_string(), - tool_name: "missing".to_string(), - arguments: Value::Null, - traceparent: None, - tracestate: None, - }, - }) - .await; - match resp { - HandlerResponse::ToolResult(crate::types::ToolResult::Expanded(exp)) => { - assert_eq!(exp.result_type, "failure"); - assert!(exp.text_result_for_llm.contains("missing")); - assert_eq!(exp.error.as_deref(), Some(exp.text_result_for_llm.as_str())); - } - other => panic!("unexpected response: {other:?}"), - } - } - - #[tokio::test] - async fn noop_handler_leaves_permission_and_external_tool_pending() { - let h = NoopHandler; - let permission = h - .on_event(HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - data: perm_data(), - }) - .await; - assert!(matches!( - permission, - HandlerResponse::Permission(PermissionResult::NoResult) - )); - - let tool = h - .on_event(HandlerEvent::ExternalTool { - invocation: crate::types::ToolInvocation { - session_id: SessionId::from("s1".to_string()), - tool_call_id: "tc1".to_string(), - tool_name: "manual".to_string(), - arguments: Value::Null, - traceparent: None, - tracestate: None, - }, - }) - .await; - assert!(matches!(tool, HandlerResponse::NoResult)); - } - - #[tokio::test] - async fn default_on_elicitation_returns_cancel() { - let h = DenyAllHandler; - let resp = h - .on_event(HandlerEvent::ElicitationRequest { - session_id: SessionId::from("s1".to_string()), - request_id: RequestId::new("r1"), - request: crate::types::ElicitationRequest { - message: "test".to_string(), - requested_schema: None, - mode: Some(crate::types::ElicitationMode::Form), - elicitation_source: None, - url: None, - }, - }) - .await; - match resp { - HandlerResponse::Elicitation(r) => assert_eq!(r.action, "cancel"), - other => panic!("unexpected response: {other:?}"), - } - } -} +} \ No newline at end of file diff --git a/rust/src/lib.rs b/rust/src/lib.rs index d917d6821..0c8fd33d2 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -10,7 +10,7 @@ pub mod handler; /// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). pub mod hooks; mod jsonrpc; -/// Permission-policy helpers that wrap an existing [`handler::SessionHandler`]. +/// Permission-policy helpers that produce a [`handler::PermissionHandler`]. pub mod permission; /// GitHub Copilot CLI binary resolution (env var, embedded, PATH search). pub mod resolve; @@ -30,6 +30,7 @@ pub mod trace_context; pub mod transforms; /// Protocol types shared between the SDK and the GitHub Copilot CLI. pub mod types; +mod wire; /// Auto-generated protocol types from Copilot JSON Schemas. pub mod generated; diff --git a/rust/src/permission.rs b/rust/src/permission.rs index 6971ee7e4..0b1c67f01 100644 --- a/rust/src/permission.rs +++ b/rust/src/permission.rs @@ -1,81 +1,64 @@ -//! Permission-policy helpers that compose with an existing -//! [`SessionHandler`](crate::handler::SessionHandler). +//! Permission policy primitives that produce a [`PermissionHandler`]. //! -//! These wrap an inner handler and override **only** permission requests, -//! forwarding every other event (tool calls, user input, elicitation, -//! session events) to the inner handler. Use them when you have a custom -//! tool handler — typically a [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) — -//! but want a one-line policy for permission prompts. +//! Compose these into a session via the builder methods +//! [`SessionConfig::approve_all_permissions`](crate::types::SessionConfig::approve_all_permissions), +//! [`deny_all_permissions`](crate::types::SessionConfig::deny_all_permissions), +//! and [`approve_permissions_if`](crate::types::SessionConfig::approve_permissions_if). +//! The same primitives are also available as standalone functions that +//! return an `Arc` you can install via +//! [`SessionConfig::with_permission_handler`](crate::types::SessionConfig::with_permission_handler). //! -//! For a full handler that approves or denies everything, see +//! For a one-shot approve / deny without composition, see //! [`ApproveAllHandler`](crate::handler::ApproveAllHandler) and //! [`DenyAllHandler`](crate::handler::DenyAllHandler). -//! -//! # Example -//! -//! ```rust,no_run -//! # use std::sync::Arc; -//! # use github_copilot_sdk::handler::ApproveAllHandler; -//! # use github_copilot_sdk::permission; -//! # use github_copilot_sdk::tool::ToolHandlerRouter; -//! let router = ToolHandlerRouter::new(vec![], Arc::new(ApproveAllHandler)); -//! // Inherit the router's tool dispatch but auto-approve all permission prompts: -//! let handler = permission::approve_all(Arc::new(router)); -//! ``` use std::sync::Arc; use async_trait::async_trait; -use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; -use crate::types::PermissionRequestData; +use crate::handler::{PermissionHandler, PermissionResult}; +use crate::types::{PermissionRequestData, RequestId, SessionId}; -/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is -/// auto-approved. All other events are forwarded to `inner`. -pub fn approve_all(inner: Arc) -> Arc { - Arc::new(PermissionOverrideHandler { - inner, +/// Return a [`PermissionHandler`] that approves every request. +pub fn approve_all() -> Arc { + Arc::new(PolicyHandler { policy: Policy::ApproveAll, }) } -/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is -/// auto-denied. All other events are forwarded to `inner`. -pub fn deny_all(inner: Arc) -> Arc { - Arc::new(PermissionOverrideHandler { - inner, +/// Return a [`PermissionHandler`] that denies every request. +pub fn deny_all() -> Arc { + Arc::new(PolicyHandler { policy: Policy::DenyAll, }) } -/// Wrap `inner` with a closure-based policy: `predicate` is called for each -/// permission request; `true` approves, `false` denies. All other events -/// are forwarded to `inner`. +/// Return a [`PermissionHandler`] that consults a predicate for each +/// request. `true` approves, `false` denies. /// /// ```rust,no_run -/// # use std::sync::Arc; -/// # use github_copilot_sdk::handler::ApproveAllHandler; /// # use github_copilot_sdk::permission; -/// let inner = Arc::new(ApproveAllHandler); -/// let handler = permission::approve_if(inner, |data| { -/// // Inspect data.extra (the raw JSON payload) for custom policy. +/// let handler = permission::approve_if(|data| { /// data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") /// }); /// # let _ = handler; /// ``` -pub fn approve_if(inner: Arc, predicate: F) -> Arc +pub fn approve_if(predicate: F) -> Arc where F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static, { - Arc::new(PermissionOverrideHandler { - inner, + Arc::new(PolicyHandler { policy: Policy::Predicate(Arc::new(predicate)), }) } -/// Internal permission policy stored on `SessionConfig::permission_policy`. -/// Applied to the handler at `Client::create_session` time so the order of -/// `with_handler` and `approve_all_permissions` is irrelevant. +/// Internal policy enum used by both the standalone helpers and the +/// `SessionConfig` policy builders. +/// +/// Stored as `pub(crate)` on `SessionConfig::permission_policy` so that +/// the order of `with_permission_handler(...)` and the policy builders +/// does not matter -- the policy is applied at `Client::create_session` +/// time. #[derive(Clone)] pub(crate) enum Policy { ApproveAll, @@ -93,49 +76,49 @@ impl std::fmt::Debug for Policy { } } -/// Wrap `inner` with a stored [`Policy`]. Used by `Client::create_session` -/// after the handler and policy fields are both finalised. -pub(crate) fn apply_policy( - inner: Arc, - policy: Policy, -) -> Arc { - Arc::new(PermissionOverrideHandler { inner, policy }) +/// Resolve the effective permission handler for a session, given the +/// caller-supplied handler and policy. Called by `Client::create_session` +/// and `Client::resume_session`. +/// +/// Semantics: +/// - When `policy` is `Some`, the policy entirely replaces the handler +/// for permission decisions. (Caller-supplied handler, if any, is +/// discarded -- the policy is what answers permission requests.) +/// - When `policy` is `None` and `handler` is `Some`, the handler stands. +/// - When both are `None`, returns `None` (no handler -- the SDK sends +/// `requestPermission: false`). +pub(crate) fn resolve_handler( + handler: Option>, + policy: Option, +) -> Option> { + match (handler, policy) { + (_, Some(policy)) => Some(Arc::new(PolicyHandler { policy })), + (Some(h), None) => Some(h), + (None, None) => None, + } } -struct PermissionOverrideHandler { - inner: Arc, +struct PolicyHandler { policy: Policy, } #[async_trait] -impl SessionHandler for PermissionOverrideHandler { - fn wants_permission_dispatch(&self) -> bool { - true - } - - fn wants_elicitation_dispatch(&self) -> bool { - self.inner.wants_elicitation_dispatch() - } - - fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { - self.inner.wants_external_tool_dispatch(tool_name) - } - - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { ref data, .. } => { - let approved = match &self.policy { - Policy::ApproveAll => true, - Policy::DenyAll => false, - Policy::Predicate(f) => f(data), - }; - HandlerResponse::Permission(if approved { - PermissionResult::Approved - } else { - PermissionResult::Denied - }) - } - other => self.inner.on_event(other).await, +impl PermissionHandler for PolicyHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let approved = match &self.policy { + Policy::ApproveAll => true, + Policy::DenyAll => false, + Policy::Predicate(f) => f(&data), + }; + if approved { + PermissionResult::Approved + } else { + PermissionResult::Denied } } } @@ -143,61 +126,91 @@ impl SessionHandler for PermissionOverrideHandler { #[cfg(test)] mod tests { use super::*; - use crate::handler::ApproveAllHandler; - use crate::types::{RequestId, SessionId}; - - fn request() -> HandlerEvent { - HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1"), - request_id: RequestId::new("1"), - data: PermissionRequestData { - extra: serde_json::json!({"tool": "shell"}), - ..Default::default() - }, + + fn data() -> PermissionRequestData { + PermissionRequestData { + extra: serde_json::json!({ "tool": "shell" }), + ..Default::default() } } #[tokio::test] - async fn approve_all_approves_permission_requests() { - let h = approve_all(Arc::new(ApproveAllHandler)); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + async fn approve_all_approves() { + let h = approve_all(); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + PermissionResult::Approved + )); } #[tokio::test] - async fn deny_all_denies_permission_requests() { - let h = deny_all(Arc::new(ApproveAllHandler)); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied, got {other:?}"), - } + async fn deny_all_denies() { + let h = deny_all(); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + PermissionResult::Denied + )); } #[tokio::test] async fn approve_if_consults_predicate() { - let h = approve_if(Arc::new(ApproveAllHandler), |data| { - data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") - }); - match h.on_event(request()).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied for shell, got {other:?}"), + let h = approve_if(|d| d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")); + assert!(matches!( + h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + PermissionResult::Denied + )); + } + + #[tokio::test] + async fn resolve_handler_policy_wins() { + struct AlwaysApprove; + #[async_trait] + impl PermissionHandler for AlwaysApprove { + async fn handle( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } } + let resolved = + resolve_handler(Some(Arc::new(AlwaysApprove)), Some(Policy::DenyAll)).unwrap(); + // Policy wins -- the AlwaysApprove handler is discarded. + assert!(matches!( + resolved + .handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Denied + )); } #[tokio::test] - async fn non_permission_events_forward_to_inner() { - let h = deny_all(Arc::new(ApproveAllHandler)); - let event = HandlerEvent::UserInput { - session_id: SessionId::from("s1"), - question: "continue?".to_string(), - choices: None, - allow_freeform: None, - }; - match h.on_event(event).await { - HandlerResponse::UserInput(None) => {} - other => panic!("expected UserInput forwarded, got {other:?}"), + async fn resolve_handler_with_only_handler() { + struct H; + #[async_trait] + impl PermissionHandler for H { + async fn handle( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } } + let resolved = resolve_handler(Some(Arc::new(H)), None).unwrap(); + assert!(matches!( + resolved + .handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, + PermissionResult::Approved + )); } -} + + #[test] + fn resolve_handler_with_neither_returns_none() { + assert!(resolve_handler(None, None).is_none()); + } +} \ No newline at end of file diff --git a/rust/src/session.rs b/rust/src/session.rs index f773f9f13..e0efb4249 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -19,8 +19,8 @@ use crate::generated::session_events::{ SessionEventType, }; use crate::handler::{ - AutoModeSwitchResponse, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, - UserInputResponse, + AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; use crate::session_fs::SessionFsProvider; @@ -31,11 +31,41 @@ use crate::types::{ ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation, - ToolResult, ToolResultExpanded, ToolResultResponse, TraceContext, UiInputOptions, + ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +/// Bundle of the per-session callbacks the SDK dispatches to. Built from a +/// [`SessionConfig`] / [`ResumeSessionConfig`] at +/// [`Client::create_session`] / [`Client::resume_session`] time. Each +/// field is `None` (or an empty map for tools) when the caller didn't +/// install a handler -- in that case the SDK skips dispatch for that +/// event type. The wire flags on `session.create` / `session.resume` +/// are derived from these fields. +#[derive(Clone)] +pub(crate) struct SessionHandlers { + pub permission: Option>, + pub elicitation: Option>, + pub user_input: Option>, + pub exit_plan_mode: Option>, + pub auto_mode_switch: Option>, + pub tools: Arc>>, +} + +impl SessionHandlers { + pub(crate) fn empty() -> Self { + Self { + permission: None, + elicitation: None, + user_input: None, + exit_plan_mode: None, + auto_mode_switch: None, + tools: Arc::new(HashMap::new()), + } + } +} + /// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. struct IdleWaiter { tx: oneshot::Sender, Error>>, @@ -106,7 +136,8 @@ impl Drop for PendingSessionRegistration { /// A session on a GitHub Copilot CLI server. /// /// Created via [`Client::create_session`] or [`Client::resume_session`]. -/// Owns an internal event loop that dispatches events to the [`SessionHandler`]. +/// Owns an internal event loop that dispatches events to the per-callback +/// handlers installed on the session config. /// /// Protocol methods (`send`, `get_messages`, `abort`, etc.) automatically /// inject the session ID into RPC params. @@ -220,10 +251,10 @@ impl Session { /// /// **Observe-only.** Subscribers receive a clone of every /// [`SessionEvent`] but cannot influence permission decisions, tool - /// results, or anything else that requires returning a - /// [`HandlerResponse`]. Those remain - /// the responsibility of the [`SessionHandler`] passed via - /// [`SessionConfig::handler`](crate::types::SessionConfig::handler). + /// results, or anything else that requires returning a value. Those + /// remain the responsibility of the per-callback handlers passed via + /// [`SessionConfig`](crate::types::SessionConfig)'s `with_*_handler` + /// builder methods. /// /// The returned handle implements both an inherent /// [`recv`](crate::subscription::EventSubscription::recv) method and @@ -745,10 +776,10 @@ impl Client { /// Sends `session.create`, registers the session on the router, /// and spawns an internal event loop that dispatches to the handler. /// - /// All callbacks (event handler, hooks, transform) are configured - /// via [`SessionConfig`] using [`with_handler`](SessionConfig::with_handler), - /// [`with_hooks`](SessionConfig::with_hooks), and - /// [`with_transform`](SessionConfig::with_transform). + /// All callbacks (per-event handlers, tool handlers, hooks, transform) + /// are configured via [`SessionConfig`] using its `with_*_handler` / + /// `with_tool_handlers` / `with_hooks` / `with_transform` builder + /// methods. /// /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the /// wire-level `hooks` flag is automatically enabled. @@ -758,21 +789,37 @@ impl Client { /// format and handles `systemMessage.transform` RPC callbacks during /// the session. /// - /// If [`handler`](SessionConfig::handler) is `None`, the session uses - /// [`NoopHandler`](crate::handler::NoopHandler) — permission requests and - /// external tool calls are left pending for the consumer to resolve. + /// Each per-event handler is independently optional. If a handler is + /// not installed, the SDK signals the runtime not to emit the matching + /// broadcast (and silently skips dispatch if one arrives anyway). pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); - let base_handler = config - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - let handler = match config.permission_policy.take() { - Some(policy) => crate::permission::apply_policy(base_handler, policy), - None => base_handler, + let permission_handler = crate::permission::resolve_handler( + config.permission_handler.take(), + config.permission_policy.take(), + ); + let elicitation_handler = config.elicitation_handler.take(); + let user_input_handler = config.user_input_handler.take(); + let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); + let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); + let mut tool_map: HashMap> = HashMap::new(); + for tool in config.tool_handlers.drain(..) { + let name = tool.tool().name; + if tool_map.contains_key(&name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {name:?}" + ))); + } + tool_map.insert(name, tool); + } + let handlers = SessionHandlers { + permission: permission_handler, + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_map), }; - config.request_permission = Some(handler.wants_permission_dispatch()); - config.request_elicitation = Some(handler.wants_elicitation_dispatch()); let hooks = config.hooks_handler.take(); let transforms = config.transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); @@ -805,7 +852,8 @@ impl Client { .clone() .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); config.session_id = Some(session_id.clone()); - let mut params = serde_json::to_value(&config)?; + let wire = config.to_wire(session_id.clone()); + let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -818,7 +866,7 @@ impl Client { let event_loop = spawn_event_loop( session_id.clone(), self.clone(), - handler, + handlers, hooks, transforms, command_handlers, @@ -900,16 +948,32 @@ impl Client { /// fields are unset. pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result { let total_start = Instant::now(); - let base_handler = config - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - let handler = match config.permission_policy.take() { - Some(policy) => crate::permission::apply_policy(base_handler, policy), - None => base_handler, + let permission_handler = crate::permission::resolve_handler( + config.permission_handler.take(), + config.permission_policy.take(), + ); + let elicitation_handler = config.elicitation_handler.take(); + let user_input_handler = config.user_input_handler.take(); + let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); + let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); + let mut tool_map: HashMap> = HashMap::new(); + for tool in config.tool_handlers.drain(..) { + let name = tool.tool().name; + if tool_map.contains_key(&name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {name:?}" + ))); + } + tool_map.insert(name, tool); + } + let handlers = SessionHandlers { + permission: permission_handler, + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_map), }; - config.request_permission = Some(handler.wants_permission_dispatch()); - config.request_elicitation = Some(handler.wants_elicitation_dispatch()); let hooks = config.hooks_handler.take(); let transforms = config.transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); @@ -938,7 +1002,8 @@ impl Client { inject_transform_sections_resume(&mut config, transforms.as_ref()); } let session_id = config.session_id.clone(); - let mut params = serde_json::to_value(&config)?; + let wire = config.to_wire(); + let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -951,7 +1016,7 @@ impl Client { let event_loop = spawn_event_loop( session_id.clone(), self.clone(), - handler, + handlers, hooks, transforms, command_handlers, @@ -1078,7 +1143,7 @@ fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc, + handlers: SessionHandlers, hooks: Option>, transforms: Option>, command_handlers: Arc, @@ -1112,12 +1177,12 @@ fn spawn_event_loop( _ = shutdown.cancelled() => break, Some(notification) = notifications.recv() => { handle_notification( - &session_id, &client, &handler, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, + &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, ).await; } Some(request) = requests.recv() => { handle_request( - &session_id, &client, &handler, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, + &session_id, &client, &handlers, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, ).await; } else => break, @@ -1142,20 +1207,20 @@ fn extract_request_id(data: &Value) -> Option { .map(RequestId::new) } -fn pending_permission_result_kind(response: &HandlerResponse) -> &'static str { - match response { - HandlerResponse::Permission(PermissionResult::Approved) => "approve-once", - HandlerResponse::Permission(PermissionResult::Denied) => "reject", +fn pending_permission_result_kind(result: &PermissionResult) -> &'static str { + match result { + PermissionResult::Approved => "approve-once", + PermissionResult::Denied => "reject", // Fallback to "user-not-available" for UserNotAvailable, Deferred (when // forced through this path), Custom (handled separately upstream), and - // any non-permission/no-result HandlerResponse that gets here defensively. + // NoResult that gets here defensively. _ => "user-not-available", } } -fn permission_request_response(response: &HandlerResponse) -> PermissionDecision { - match response { - HandlerResponse::Permission(PermissionResult::Approved) => { +fn permission_request_response(result: &PermissionResult) -> PermissionDecision { + match result { + PermissionResult::Approved => { PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce { kind: PermissionDecisionApproveOnceKind::ApproveOnce, }) @@ -1167,41 +1232,37 @@ fn permission_request_response(response: &HandlerResponse) -> PermissionDecision } } -/// Map a handler response into the `result` payload for the notification +/// Map a permission result into the `result` payload for the notification /// path (`session.permissions.handlePendingPermissionRequest`). /// /// Returns `None` when the SDK must not respond. -fn notification_permission_payload(response: &HandlerResponse) -> Option { - match response { - HandlerResponse::Permission(PermissionResult::Deferred | PermissionResult::NoResult) => { - None - } - HandlerResponse::Permission(PermissionResult::Custom(value)) => Some(value.clone()), +fn notification_permission_payload(result: &PermissionResult) -> Option { + match result { + PermissionResult::Deferred | PermissionResult::NoResult => None, + PermissionResult::Custom(value) => Some(value.clone()), _ => Some(serde_json::json!({ - "kind": pending_permission_result_kind(response), + "kind": pending_permission_result_kind(result), })), } } -/// Map a handler response into the JSON-RPC `result` payload for the +/// Map a permission result into the JSON-RPC `result` payload for the /// direct-RPC path (`permission.request`). /// /// Always returns a value. [`PermissionResult::Deferred`] is treated as /// [`PermissionResult::Approved`] here because the JSON-RPC contract /// requires a reply — see the variant's doc comment. -fn direct_permission_payload(response: &HandlerResponse) -> Value { - match response { - HandlerResponse::Permission(PermissionResult::Custom(value)) => value.clone(), - HandlerResponse::Permission(PermissionResult::Deferred) => serde_json::to_value( - permission_request_response(&HandlerResponse::Permission(PermissionResult::Approved)), - ) - .expect("serializing direct permission response should succeed"), - HandlerResponse::Permission( - PermissionResult::NoResult | PermissionResult::UserNotAvailable, - ) => serde_json::json!({ - "kind": pending_permission_result_kind(response), +fn direct_permission_payload(result: &PermissionResult) -> Value { + match result { + PermissionResult::Custom(value) => value.clone(), + PermissionResult::Deferred => { + serde_json::to_value(permission_request_response(&PermissionResult::Approved)) + .expect("serializing direct permission response should succeed") + } + PermissionResult::NoResult | PermissionResult::UserNotAvailable => serde_json::json!({ + "kind": pending_permission_result_kind(result), }), - _ => serde_json::to_value(permission_request_response(response)) + _ => serde_json::to_value(permission_request_response(result)) .expect("serializing direct permission response should succeed"), } } @@ -1218,33 +1279,12 @@ fn tool_failure_result(message: impl Into) -> ToolResult { }) } -fn notification_tool_payload(response: HandlerResponse) -> Option { - match response { - HandlerResponse::ToolResult(result) => { - Some(serde_json::to_value(result).unwrap_or(Value::Null)) - } - HandlerResponse::NoResult => None, - _ => Some( - serde_json::to_value(tool_failure_result("Unexpected handler response")) - .unwrap_or(Value::Null), - ), - } -} - -fn direct_tool_result(response: HandlerResponse) -> ToolResult { - match response { - HandlerResponse::ToolResult(result) => result, - HandlerResponse::NoResult => tool_failure_result("No tool handler available"), - _ => tool_failure_result("Unexpected handler response"), - } -} - /// Process a notification from the CLI's broadcast channel. #[allow(clippy::too_many_arguments)] async fn handle_notification( session_id: &SessionId, client: &Client, - handler: &Arc, + handlers: &SessionHandlers, command_handlers: &Arc, notification: SessionEventNotification, idle_waiter: &Arc>>, @@ -1321,14 +1361,6 @@ async fn handle_notification( // before any consumer subscribes. let _ = event_tx.send(event.clone()); - // Fire-and-forget dispatch for the general event. - handler - .on_event(HandlerEvent::SessionEvent { - session_id: session_id.clone(), - event, - }) - .await; - // Update capabilities when the CLI reports changes. The CLI sends // the full updated capabilities object — replace wholesale so removals // and new subfields are handled correctly. @@ -1365,14 +1397,13 @@ async fn handle_notification( { return; } - // Multi-client safety: if this client's handler does not - // claim permission dispatch, don't respond — another client - // on the same CLI may handle it. - if !handler.wants_permission_dispatch() { + // Multi-client safety: if this client has no permission + // handler installed, don't respond — another client on the + // same CLI may handle it. + let Some(permission_handler) = handlers.permission.clone() else { return; - } + }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let data: PermissionRequestData = serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| { @@ -1390,22 +1421,19 @@ async fn handle_notification( tokio::spawn( async move { let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::PermissionRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - data, - }) + let result = permission_handler + .handle(sid.clone(), request_id.clone(), data) .await; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, request_id = %request_id, - "SessionHandler::on_permission_request dispatch" + "PermissionHandler::handle dispatch" ); - let Some(result_value) = notification_permission_payload(&response) else { - // Handler returned Deferred — it will call - // handlePendingPermissionRequest itself. + let Some(result_value) = notification_permission_payload(&result) else { + // Handler returned Deferred / NoResult — it will + // call handlePendingPermissionRequest itself (or + // leave the request unanswered). return; }; let rpc_start = Instant::now(); @@ -1470,15 +1498,18 @@ async fn handle_notification( return; } }; - // Multi-client safety: if this client doesn't claim the - // requested tool name, don't respond — another connected - // client may have a handler. - if !data.tool_name.is_empty() && !handler.wants_external_tool_dispatch(&data.tool_name) - { + // Multi-client safety: look up a handler for the requested + // tool name. If this client has no handler installed for that + // tool, don't respond — another connected client may have one. + let tool_handler = if data.tool_name.is_empty() { + None + } else { + handlers.tools.get(&data.tool_name).cloned() + }; + let Some(tool_handler) = tool_handler else { return; - } + }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let span = tracing::error_span!( "external_tool_handler", @@ -1525,9 +1556,10 @@ async fn handle_notification( tracestate: data.tracestate, }; let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ExternalTool { invocation }) - .await; + let tool_result = match tool_handler.call(invocation).await { + Ok(r) => r, + Err(e) => tool_failure_result(e.to_string()), + }; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, @@ -1536,9 +1568,8 @@ async fn handle_notification( tool_name = %tool_name, "ToolHandler::call dispatch" ); - let Some(result_value) = notification_tool_payload(response) else { - return; - }; + let result_value = + serde_json::to_value(tool_result).unwrap_or(Value::Null); let rpc_start = Instant::now(); let _ = client .call( @@ -1565,7 +1596,7 @@ async fn handle_notification( SessionEventType::UserInputRequested => { // Notification-only signal for observers (UI, telemetry). // The CLI follows up with a `userInput.request` JSON-RPC call - // that drives `HandlerEvent::UserInput` dispatch — handling + // that drives the `UserInputHandler` dispatch — handling // the notification here too would double-fire the handler // and produce duplicate prompts on the consumer side. See // github/github-app#4249. @@ -1574,12 +1605,12 @@ async fn handle_notification( let Some(request_id) = extract_request_id(¬ification.event.data) else { return; }; - // Multi-client safety: if this client's handler does not - // claim elicitation dispatch, don't respond — another - // client on the same CLI may handle it. - if !handler.wants_elicitation_dispatch() { + // Multi-client safety: if this client has no elicitation + // handler installed, don't respond — another client on the + // same CLI may handle it. + let Some(elicitation_handler) = handlers.elicitation.clone() else { return; - } + }; let elicitation_data: ElicitationRequestedData = match serde_json::from_value(notification.event.data.clone()) { Ok(d) => d, @@ -1606,7 +1637,6 @@ async fn handle_notification( url: elicitation_data.url, }; let client = client.clone(); - let handler = handler.clone(); let sid = session_id.clone(); let span = tracing::error_span!( "elicitation_request_handler", @@ -1630,26 +1660,22 @@ async fn handle_notification( ); async move { let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ElicitationRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - request, - }) + let response = elicitation_handler + .handle(sid.clone(), request_id.clone(), request) .await; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, request_id = %request_id, - "SessionHandler::on_elicitation dispatch" + "ElicitationHandler::handle dispatch" ); response } .instrument(span) }); let result = match handler_task.await { - Ok(HandlerResponse::Elicitation(r)) => r, - _ => cancel.clone(), + Ok(r) => r, + Err(_) => cancel.clone(), }; let rpc_start = Instant::now(); if let Err(e) = client @@ -1757,7 +1783,7 @@ async fn handle_notification( async fn handle_request( session_id: &SessionId, client: &Client, - handler: &Arc, + handlers: &SessionHandlers, hooks: Option<&dyn SessionHooks>, transforms: Option<&dyn SystemMessageTransform>, session_fs_provider: Option<&Arc>, @@ -1803,49 +1829,6 @@ async fn handle_request( let _ = client.send_response(&rpc_response).await; } - "tool.call" => { - let invocation: ToolInvocation = match request - .params - .as_ref() - .and_then(|p| serde_json::from_value::(p.clone()).ok()) - { - Some(inv) => inv, - None => { - let _ = send_error_response( - client, - request.id, - error_codes::INVALID_PARAMS, - "invalid tool.call params", - ) - .await; - return; - } - }; - let tool_call_id = invocation.tool_call_id.clone(); - let tool_name = invocation.tool_name.clone(); - let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::ExternalTool { invocation }) - .await; - tracing::debug!( - elapsed_ms = handler_start.elapsed().as_millis(), - session_id = %sid, - tool_call_id = %tool_call_id, - tool_name = %tool_name, - "ToolHandler::call dispatch" - ); - let tool_result = direct_tool_result(response); - let rpc_response = JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::json!(ToolResultResponse { - result: tool_result - })), - error: None, - }; - let _ = client.send_response(&rpc_response).await; - } - "userInput.request" => { let params = request.params.as_ref(); let Some(question) = params @@ -1880,29 +1863,28 @@ async fn handle_request( .and_then(|v| v.as_bool()); let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::UserInput { - session_id: sid.clone(), - question, - choices, - allow_freeform, - }) - .await; + let response = if let Some(user_input_handler) = handlers.user_input.as_ref() { + user_input_handler + .handle(sid.clone(), question, choices, allow_freeform) + .await + } else { + None + }; tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, - "SessionHandler::on_user_input dispatch" + "UserInputHandler::handle dispatch" ); let rpc_result = match response { - HandlerResponse::UserInput(Some(UserInputResponse { + Some(UserInputResponse { answer, was_freeform, - })) => serde_json::json!({ + }) => serde_json::json!({ "answer": answer, "wasFreeform": was_freeform, }), - _ => serde_json::json!({ "noResponse": true }), + None => serde_json::json!({ "noResponse": true }), }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -1927,17 +1909,11 @@ async fn handle_request( } }; - let response = handler - .on_event(HandlerEvent::ExitPlanMode { - session_id: sid, - data, - }) - .await; - - let rpc_result = match response { - HandlerResponse::ExitPlanMode(result) => serde_json::to_value(result) - .expect("ExitPlanModeResult serialization cannot fail"), - _ => serde_json::json!({ "approved": true }), + let rpc_result = if let Some(exit_plan_handler) = handlers.exit_plan_mode.as_ref() { + let result = exit_plan_handler.handle(sid, data).await; + serde_json::to_value(result).expect("ExitPlanModeResult serialization cannot fail") + } else { + serde_json::json!({ "approved": true }) }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -1961,17 +1937,12 @@ async fn handle_request( .and_then(|p| p.get("retryAfterSeconds")) .and_then(|v| v.as_f64()); - let response = handler - .on_event(HandlerEvent::AutoModeSwitch { - session_id: sid, - error_code, - retry_after_seconds, - }) - .await; - - let answer = match response { - HandlerResponse::AutoModeSwitch(answer) => answer, - _ => AutoModeSwitchResponse::No, + let answer = if let Some(auto_mode_handler) = handlers.auto_mode_switch.as_ref() { + auto_mode_handler + .handle(sid, error_code, retry_after_seconds) + .await + } else { + AutoModeSwitchResponse::No }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -2018,23 +1989,27 @@ async fn handle_request( }); let handler_start = Instant::now(); - let response = handler - .on_event(HandlerEvent::PermissionRequest { - session_id: sid.clone(), - request_id: request_id.clone(), - data, - }) - .await; - tracing::debug!( - elapsed_ms = handler_start.elapsed().as_millis(), - session_id = %sid, - request_id = %request_id, - "SessionHandler::on_permission_request dispatch" - ); + let rpc_result = if let Some(permission_handler) = handlers.permission.as_ref() { + let result = permission_handler + .handle(sid.clone(), request_id.clone(), data) + .await; + tracing::debug!( + elapsed_ms = handler_start.elapsed().as_millis(), + session_id = %sid, + request_id = %request_id, + "PermissionHandler::handle dispatch" + ); + direct_permission_payload(&result) + } else { + // Back-compat with v2 servers that still send + // permission.request as a direct RPC: default to + // user-not-available rather than erroring. + serde_json::json!({ "kind": "user-not-available" }) + }; let rpc_response = JsonRpcResponse { jsonrpc: "2.0".to_string(), id: request.id, - result: Some(direct_permission_payload(&response)), + result: Some(rpc_result), error: None, }; let _ = client.send_response(&rpc_response).await; @@ -2173,22 +2148,20 @@ mod tests { direct_permission_payload, notification_permission_payload, pending_permission_result_kind, permission_request_response, }; - use crate::handler::{HandlerResponse, PermissionResult}; + use crate::handler::PermissionResult; #[test] fn pending_permission_requests_use_decision_kinds() { assert_eq!( - pending_permission_result_kind(&HandlerResponse::Permission( - PermissionResult::Approved, - )), + pending_permission_result_kind(&PermissionResult::Approved), "approve-once" ); assert_eq!( - pending_permission_result_kind(&HandlerResponse::Permission(PermissionResult::Denied)), + pending_permission_result_kind(&PermissionResult::Denied), "reject" ); assert_eq!( - pending_permission_result_kind(&HandlerResponse::Ok), + pending_permission_result_kind(&PermissionResult::UserNotAvailable), "user-not-available" ); } @@ -2196,21 +2169,17 @@ mod tests { #[test] fn direct_permission_requests_use_decision_response_kinds() { assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Permission( - PermissionResult::Approved - ),)) - .expect("serializing approved permission response should succeed"), + serde_json::to_value(permission_request_response(&PermissionResult::Approved)) + .expect("serializing approved permission response should succeed"), json!({ "kind": "approve-once" }) ); assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Permission( - PermissionResult::Denied - ),)) - .expect("serializing denied permission response should succeed"), + serde_json::to_value(permission_request_response(&PermissionResult::Denied)) + .expect("serializing denied permission response should succeed"), json!({ "kind": "reject" }) ); assert_eq!( - serde_json::to_value(permission_request_response(&HandlerResponse::Ok)) + serde_json::to_value(permission_request_response(&PermissionResult::UserNotAvailable)) .expect("serializing fallback permission response should succeed"), json!({ "kind": "reject" }) ); @@ -2219,18 +2188,8 @@ mod tests { #[test] fn notification_payload_handles_non_responses_and_custom() { // Deferred/NoResult -> no payload, SDK must not respond. - assert!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Deferred, - )) - .is_none() - ); - assert!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::NoResult, - )) - .is_none() - ); + assert!(notification_permission_payload(&PermissionResult::Deferred).is_none()); + assert!(notification_permission_payload(&PermissionResult::NoResult).is_none()); // Custom → handler-supplied value passed through verbatim. let custom = json!({ @@ -2238,23 +2197,17 @@ mod tests { "allowlist": ["ls", "grep"], }); assert_eq!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Custom(custom.clone()), - )), + notification_permission_payload(&PermissionResult::Custom(custom.clone())), Some(custom) ); // Approved/Denied → existing kind-only shape. assert_eq!( - notification_permission_payload(&HandlerResponse::Permission( - PermissionResult::Approved, - )), + notification_permission_payload(&PermissionResult::Approved), Some(json!({ "kind": "approve-once" })) ); assert_eq!( - notification_permission_payload( - &HandlerResponse::Permission(PermissionResult::Denied,) - ), + notification_permission_payload(&PermissionResult::Denied), Some(json!({ "kind": "reject" })) ); } @@ -2267,31 +2220,29 @@ mod tests { "allowlist": ["ls", "grep"], }); assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Custom( - custom.clone(), - ))), + direct_permission_payload(&PermissionResult::Custom(custom.clone())), custom ); // Deferred → falls back to Approved because the direct RPC must reply. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Deferred)), + direct_permission_payload(&PermissionResult::Deferred), json!({ "kind": "approve-once" }) ); // NoResult -> direct RPC cannot be left pending, so report no user. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::NoResult)), + direct_permission_payload(&PermissionResult::NoResult), json!({ "kind": "user-not-available" }) ); // Approved/Denied → existing kind-only shape. assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Approved)), + direct_permission_payload(&PermissionResult::Approved), json!({ "kind": "approve-once" }) ); assert_eq!( - direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Denied)), + direct_permission_payload(&PermissionResult::Denied), json!({ "kind": "reject" }) ); } diff --git a/rust/src/tool.rs b/rust/src/tool.rs index fc1163e82..1be2a2c0d 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -1,14 +1,16 @@ //! Typed tool definition framework. //! -//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types, -//! and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a -//! [`SessionHandler`](crate::handler::SessionHandler). +//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for +//! implementing tools as named types. Install tool handlers on a session +//! via +//! [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers); +//! the SDK builds an internal name-keyed registry and dispatches to the +//! matching handler when the CLI broadcasts `external_tool.requested`. //! //! Enable the `derive` feature for `schema_for`, which generates JSON //! Schema from Rust types via `schemars`. use std::collections::HashMap; -use std::sync::Arc; use async_trait::async_trait; /// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas. @@ -16,11 +18,7 @@ use async_trait::async_trait; pub use schemars::JsonSchema; use crate::Error; -use crate::handler::{PermissionResult, SessionHandler, UserInputResponse}; -use crate::types::{ - ElicitationRequest, ElicitationResult, PermissionRequestData, RequestId, SessionEvent, - SessionId, Tool, ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded, -}; +use crate::types::{Tool, ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; /// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type. /// @@ -227,9 +225,10 @@ pub trait ToolHandler: Send + Sync { /// Define a tool from an async function (or closure) that takes a typed, /// `JsonSchema`-derived parameter struct. /// -/// The returned `Box` plugs directly into -/// [`ToolHandlerRouter::new`]. JSON Schema for the parameter type is generated -/// via [`schema_for`] at construction time. +/// The returned `Box` can be installed on a session via +/// [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers). +/// JSON Schema for the parameter type is generated via [`schema_for`] at +/// construction time. /// /// The handler bound (`Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`) /// accepts both bare `async fn` items and closures — the same shape as @@ -333,136 +332,6 @@ where }) } -/// A [`SessionHandler`] that dispatches tool calls to registered -/// [`ToolHandler`] implementations by name. -/// -/// For tool calls matching a registered handler, the handler is invoked -/// directly. All other events (permissions, user input, unrecognized tools) -/// are forwarded to the inner handler. -/// -/// # Example -/// -/// ```rust,no_run -/// use std::sync::Arc; -/// use github_copilot_sdk::handler::ApproveAllHandler; -/// use github_copilot_sdk::tool::ToolHandlerRouter; -/// -/// let router = ToolHandlerRouter::new( -/// vec![/* Box::new(MyTool), ... */], -/// Arc::new(ApproveAllHandler), -/// ); -/// -/// // Use router.tools() in SessionConfig -/// // Use Arc::new(router) as the session handler -/// ``` -pub struct ToolHandlerRouter { - handlers: HashMap>, - inner: Arc, -} - -impl std::fmt::Debug for ToolHandlerRouter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut tools: Vec<_> = self.handlers.keys().collect(); - tools.sort(); - f.debug_struct("ToolHandlerRouter") - .field("tool_count", &self.handlers.len()) - .field("tools", &tools) - .finish() - } -} - -impl ToolHandlerRouter { - /// Create a router from tool handler impls and a fallback handler. - /// - /// Call [`tools()`](Self::tools) to get the tool definitions for - /// [`SessionConfig::tools`](crate::SessionConfig::tools). - pub fn new(tools: Vec>, inner: Arc) -> Self { - let mut handlers = HashMap::new(); - for tool in tools { - handlers.insert(tool.tool().name.clone(), tool); - } - Self { handlers, inner } - } - - /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools). - pub fn tools(&self) -> Vec { - self.handlers.values().map(|h| h.tool()).collect() - } -} - -#[async_trait] -impl SessionHandler for ToolHandlerRouter { - fn wants_permission_dispatch(&self) -> bool { - self.inner.wants_permission_dispatch() - } - - fn wants_elicitation_dispatch(&self) -> bool { - self.inner.wants_elicitation_dispatch() - } - - fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { - self.handlers.contains_key(tool_name) || self.inner.wants_external_tool_dispatch(tool_name) - } - - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - let Some(handler) = self.handlers.get(&invocation.tool_name) else { - return self.inner.on_external_tool(invocation).await; - }; - match handler.call(invocation).await { - Ok(result) => result, - Err(e) => { - let msg = e.to_string(); - ToolResult::Expanded(ToolResultExpanded { - text_result_for_llm: msg.clone(), - result_type: "failure".to_string(), - binary_results_for_llm: None, - session_log: None, - error: Some(msg), - tool_telemetry: None, - }) - } - } - } - - async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) { - self.inner.on_session_event(session_id, event).await - } - - async fn on_permission_request( - &self, - session_id: SessionId, - request_id: RequestId, - data: PermissionRequestData, - ) -> PermissionResult { - self.inner - .on_permission_request(session_id, request_id, data) - .await - } - - async fn on_user_input( - &self, - session_id: SessionId, - question: String, - choices: Option>, - allow_freeform: Option, - ) -> Option { - self.inner - .on_user_input(session_id, question, choices, allow_freeform) - .await - } - - async fn on_elicitation( - &self, - session_id: SessionId, - request_id: RequestId, - request: ElicitationRequest, - ) -> ElicitationResult { - self.inner - .on_elicitation(session_id, request_id, request) - .await - } -} - #[cfg(test)] mod tests { use super::*; @@ -717,233 +586,6 @@ mod tests { } } - #[tokio::test] - async fn router_dispatches_to_correct_handler() { - struct ToolA; - #[async_trait] - impl ToolHandler for ToolA { - fn tool(&self) -> Tool { - Tool { - name: "tool_a".to_string(), - namespaced_name: None, - description: "A".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("a_result".to_string())) - } - } - - struct ToolB; - #[async_trait] - impl ToolHandler for ToolB { - fn tool(&self) -> Tool { - Tool { - name: "tool_b".to_string(), - namespaced_name: None, - description: "B".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("b_result".to_string())) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(ToolA), Box::new(ToolB)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let tools = router.tools(); - assert_eq!(tools.len(), 2); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "tool_b".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - match response { - ToolResult::Text(s) => assert_eq!(s, "b_result"), - _ => panic!("expected ToolResult::Text"), - } - } - - #[tokio::test] - async fn router_falls_through_for_unknown_tool() { - use std::sync::atomic::{AtomicBool, Ordering}; - - struct FallbackHandler { - called: AtomicBool, - } - #[async_trait] - impl SessionHandler for FallbackHandler { - async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult { - self.called.store(true, Ordering::Relaxed); - ToolResult::Text("fallback".to_string()) - } - } - - let fallback = Arc::new(FallbackHandler { - called: AtomicBool::new(false), - }); - let router = ToolHandlerRouter::new(vec![], fallback.clone()); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "unknown".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - assert!(fallback.called.load(Ordering::Relaxed)); - match response { - ToolResult::Text(s) => assert_eq!(s, "fallback"), - _ => panic!("expected fallback result"), - } - } - - #[tokio::test] - async fn router_returns_failure_on_handler_error() { - struct FailTool; - #[async_trait] - impl ToolHandler for FailTool { - fn tool(&self) -> Tool { - Tool { - name: "bad_tool".to_string(), - namespaced_name: None, - description: "Always fails".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Err(Error::Rpc { - code: -1, - message: "intentional failure".to_string(), - }) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(FailTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let response = router - .on_external_tool(ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "bad_tool".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }) - .await; - match response { - ToolResult::Expanded(exp) => { - assert_eq!(exp.result_type, "failure"); - assert!(exp.error.unwrap().contains("intentional failure")); - } - _ => panic!("expected expanded failure result"), - } - } - - #[tokio::test] - async fn router_forwards_non_tool_events() { - struct PermHandler; - #[async_trait] - impl SessionHandler for PermHandler { - async fn on_permission_request( - &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Denied - } - } - - let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler)); - - let response = router - .on_permission_request( - SessionId::from("s1"), - RequestId::new("r1"), - PermissionRequestData { - extra: serde_json::json!({}), - ..Default::default() - }, - ) - .await; - assert!(matches!(response, PermissionResult::Denied)); - } - - #[tokio::test] - async fn router_default_on_event_dispatches_via_per_event_methods() { - // Regression: callers using the legacy on_event entry point should - // still get correct dispatch through the inherited default impl. - use crate::handler::{HandlerEvent, HandlerResponse}; - - struct OkTool; - #[async_trait] - impl ToolHandler for OkTool { - fn tool(&self) -> Tool { - Tool { - name: "ok_tool".to_string(), - namespaced_name: None, - description: "ok".to_string(), - parameters: HashMap::new(), - instructions: None, - ..Default::default() - } - } - - async fn call(&self, _inv: ToolInvocation) -> Result { - Ok(ToolResult::Text("ok".to_string())) - } - } - - let router = ToolHandlerRouter::new( - vec![Box::new(OkTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); - - let response = router - .on_event(HandlerEvent::ExternalTool { - invocation: ToolInvocation { - session_id: SessionId::from("s1"), - tool_call_id: "tc1".to_string(), - tool_name: "ok_tool".to_string(), - arguments: serde_json::json!({}), - traceparent: None, - tracestate: None, - }, - }) - .await; - match response { - HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"), - _ => panic!("expected ToolResult via default on_event"), - } - } - // Tests requiring `schemars` (the `derive` feature). #[cfg(feature = "derive")] mod derive_tests { @@ -1046,18 +688,18 @@ mod tests { } #[tokio::test] - async fn router_with_schema_for_tools() { - let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool)], - Arc::new(crate::handler::ApproveAllHandler), - ); + async fn schema_for_derived_tool_round_trips_through_call() { + let tool: Box = Box::new(GetWeatherTool); - let tools = router.tools(); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0].name, "get_weather"); + // Tool definition exposes the schema-derived parameter set. + let def = tool.tool(); + assert_eq!(def.name, "get_weather"); - let response = router - .on_external_tool(ToolInvocation { + // Calling the tool with matching arguments returns the + // expected typed result. (Per-name dispatch is the SDK's + // concern; here we exercise just the handler contract.) + let result = tool + .call(ToolInvocation { session_id: SessionId::from("s1"), tool_call_id: "tc1".to_string(), tool_name: "get_weather".to_string(), @@ -1065,8 +707,9 @@ mod tests { traceparent: None, tracestate: None, }) - .await; - match response { + .await + .expect("ToolHandler::call should succeed for matching args"); + match result { ToolResult::Text(s) => assert!(s.contains("Portland")), _ => panic!("expected ToolResult::Text"), } diff --git a/rust/src/types.rs b/rust/src/types.rs index ccb0e1807..d8a7365ca 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -12,7 +12,10 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::handler::SessionHandler; +use crate::handler::{ + AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler, + UserInputHandler, +}; use crate::hooks::SessionHooks; pub use crate::session_fs::{ DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig, @@ -1060,10 +1063,11 @@ pub struct SessionConfig { /// `Some(true)` via [`SessionConfig::default`]. #[serde(skip_serializing_if = "Option::is_none")] pub request_user_input: Option, - /// Enable `permission.request` JSON-RPC calls from the CLI. Defaults - /// to `Some(true)` via [`SessionConfig::default`]; the default - /// [`NoopHandler`](crate::handler::NoopHandler) leaves requests pending - /// for the consumer to resolve. + /// Enable `permission.request` JSON-RPC calls from the CLI. + /// Derived from [`Self::permission_handler`] presence at + /// [`Client::create_session`](crate::Client::create_session) time; + /// callers should install a [`PermissionHandler`] rather than + /// setting this directly. #[serde(skip_serializing_if = "Option::is_none")] pub request_permission: Option, /// Enable `exitPlanMode.request` JSON-RPC calls for plan approval. @@ -1174,19 +1178,40 @@ pub struct SessionConfig { /// See [`SessionFsProvider`]. #[serde(skip)] pub session_fs_provider: Option>, - /// Session-level event handler. The default is - /// [`NoopHandler`](crate::handler::NoopHandler) — permission requests - /// and external tool calls are left pending for the consumer to resolve. - /// Use [`with_handler`](Self::with_handler) to install a custom handler. + /// Optional permission-request handler. When `None`, the SDK sends + /// `requestPermission: false` on the wire so the runtime does not + /// emit `permission.requested` broadcasts to this client. + #[serde(skip)] + pub permission_handler: Option>, + /// Optional elicitation-request handler. When `None`, + /// `requestElicitation: false` goes on the wire. + #[serde(skip)] + pub elicitation_handler: Option>, + /// Optional user-input handler. When `None`, + /// `requestUserInput: false` goes on the wire and the `ask_user` + /// tool is disabled. + #[serde(skip)] + pub user_input_handler: Option>, + /// Optional exit-plan-mode handler. When `None`, + /// `requestExitPlanMode: false` goes on the wire. #[serde(skip)] - pub handler: Option>, + pub exit_plan_mode_handler: Option>, + /// Optional auto-mode-switch handler. When `None`, + /// `requestAutoModeSwitch: false` goes on the wire. + #[serde(skip)] + pub auto_mode_switch_handler: Option>, + /// Client-defined tool handlers. The SDK builds an internal + /// name-keyed registry from these and dispatches to the matching + /// handler when the CLI broadcasts `external_tool.requested`. + #[serde(skip)] + pub tool_handlers: Vec>, /// Session lifecycle hook handler (pre/post tool use, session /// start/end, etc.). When set, the SDK auto-enables the wire-level /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one. #[serde(skip)] pub hooks_handler: Option>, /// Permission policy applied to the handler. Stored separately from - /// `handler` so that the order of `with_handler` and + /// `permission_handler` so the order of `with_permission_handler` and /// `approve_all_permissions` (and friends) is irrelevant. #[serde(skip)] pub(crate) permission_policy: Option, @@ -1245,7 +1270,27 @@ impl std::fmt::Debug for SessionConfig { "session_fs_provider", &self.session_fs_provider.as_ref().map(|_| ""), ) - .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "permission_handler", + &self.permission_handler.as_ref().map(|_| ""), + ) + .field( + "elicitation_handler", + &self.elicitation_handler.as_ref().map(|_| ""), + ) + .field( + "user_input_handler", + &self.user_input_handler.as_ref().map(|_| ""), + ) + .field( + "exit_plan_mode_handler", + &self.exit_plan_mode_handler.as_ref().map(|_| ""), + ) + .field( + "auto_mode_switch_handler", + &self.auto_mode_switch_handler.as_ref().map(|_| ""), + ) + .field("tool_handlers_count", &self.tool_handlers.len()) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), @@ -1256,11 +1301,11 @@ impl std::fmt::Debug for SessionConfig { } impl Default for SessionConfig { - /// Permission and elicitation flows are enabled by default. When no handler - /// is provided, the SDK installs `NoopHandler`, so permission and external - /// tool requests remain pending until the consumer responds out-of-band. - /// Callers that want the wire surface fully disabled set these explicitly - /// to `Some(false)`. + /// All wire-level "request" flags and handler fields start unset. + /// Install a [`PermissionHandler`] via + /// [`with_permission_handler`](Self::with_permission_handler) and + /// the SDK derives `requestPermission: true` on the wire at + /// [`Client::create_session`](crate::Client::create_session) time. fn default() -> Self { Self { session_id: None, @@ -1275,10 +1320,10 @@ impl Default for SessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: Some(true), + request_user_input: None, request_permission: None, - request_exit_plan_mode: Some(true), - request_auto_mode_switch: Some(true), + request_exit_plan_mode: None, + request_auto_mode_switch: None, request_elicitation: None, skill_directories: None, instruction_directories: None, @@ -1299,7 +1344,12 @@ impl Default for SessionConfig { include_sub_agent_streaming_events: None, commands: None, session_fs_provider: None, - handler: None, + permission_handler: None, + elicitation_handler: None, + user_input_handler: None, + exit_plan_mode_handler: None, + auto_mode_switch_handler: None, + tool_handlers: Vec::new(), hooks_handler: None, permission_policy: None, transform: None, @@ -1308,9 +1358,114 @@ impl Default for SessionConfig { } impl SessionConfig { - /// Install a custom [`SessionHandler`] for this session. - pub fn with_handler(mut self, handler: Arc) -> Self { - self.handler = Some(handler); + /// Build the [`SessionCreateWire`] payload for `session.create` from + /// this config. Derives the request_* wire flags from handler + /// presence and the policy field; clones plain fields. + pub(crate) fn to_wire(&self, session_id: SessionId) -> crate::wire::SessionCreateWire { + let permission_active = + self.permission_handler.is_some() || self.permission_policy.is_some(); + crate::wire::SessionCreateWire { + session_id, + model: self.model.clone(), + client_name: self.client_name.clone(), + reasoning_effort: self.reasoning_effort.clone(), + streaming: self.streaming, + system_message: self.system_message.clone(), + tools: self.merged_tool_wire_definitions(), + available_tools: self.available_tools.clone(), + excluded_tools: self.excluded_tools.clone(), + mcp_servers: self.mcp_servers.clone(), + env_value_mode: "direct", + enable_config_discovery: self.enable_config_discovery, + request_user_input: self.user_input_handler.is_some(), + request_permission: permission_active, + request_exit_plan_mode: self.exit_plan_mode_handler.is_some(), + request_auto_mode_switch: self.auto_mode_switch_handler.is_some(), + request_elicitation: self.elicitation_handler.is_some(), + hooks: self.hooks_handler.is_some(), + skill_directories: self.skill_directories.clone(), + instruction_directories: self.instruction_directories.clone(), + disabled_skills: self.disabled_skills.clone(), + custom_agents: self.custom_agents.clone(), + default_agent: self.default_agent.clone(), + agent: self.agent.clone(), + infinite_sessions: self.infinite_sessions.clone(), + provider: self.provider.clone(), + enable_session_telemetry: self.enable_session_telemetry, + model_capabilities: self.model_capabilities.clone(), + config_dir: self.config_dir.clone(), + working_directory: self.working_directory.clone(), + github_token: self.github_token.clone(), + remote_session: self.remote_session.clone(), + cloud: self.cloud.clone(), + include_sub_agent_streaming_events: self.include_sub_agent_streaming_events, + commands: self.commands.as_ref().map(|cmds| { + cmds.iter() + .map(|c| crate::wire::CommandWireDefinition { + name: c.name.clone(), + description: c.description.clone(), + }) + .collect() + }), + } + } + + /// Merge caller-supplied `tools` (declaration-only) with the `tool()` + /// definitions extracted from each [`tool_handlers`](Self::tool_handlers) + /// entry. Returns `None` only when both sources are empty. + fn merged_tool_wire_definitions(&self) -> Option> { + let mut out: Vec = self.tools.clone().unwrap_or_default(); + for handler in &self.tool_handlers { + out.push(handler.tool()); + } + if out.is_empty() { None } else { Some(out) } + } + + /// Install a [`PermissionHandler`] for this session. When omitted, the + /// SDK sends `requestPermission: false` on the wire and the runtime + /// short-circuits permission prompts for this client. + pub fn with_permission_handler(mut self, handler: Arc) -> Self { + self.permission_handler = Some(handler); + self + } + + /// Install an [`ElicitationHandler`]. When omitted, the SDK sends + /// `requestElicitation: false` on the wire. + pub fn with_elicitation_handler(mut self, handler: Arc) -> Self { + self.elicitation_handler = Some(handler); + self + } + + /// Install a [`UserInputHandler`]. Required for the `ask_user` tool + /// to be enabled. + pub fn with_user_input_handler(mut self, handler: Arc) -> Self { + self.user_input_handler = Some(handler); + self + } + + /// Install an [`ExitPlanModeHandler`]. + pub fn with_exit_plan_mode_handler(mut self, handler: Arc) -> Self { + self.exit_plan_mode_handler = Some(handler); + self + } + + /// Install an [`AutoModeSwitchHandler`]. + pub fn with_auto_mode_switch_handler( + mut self, + handler: Arc, + ) -> Self { + self.auto_mode_switch_handler = Some(handler); + self + } + + /// Install tool handlers for this session. Each handler must report a + /// unique [`Tool::name`]; the SDK rejects duplicates at + /// [`Client::create_session`](crate::Client::create_session) time. + pub fn with_tool_handlers(mut self, handlers: I) -> Self + where + I: IntoIterator>, + { + self.tool_handlers = handlers.into_iter().collect(); self } @@ -1347,9 +1502,10 @@ impl SessionConfig { } /// Auto-approve every permission request on this session. Stored as a - /// policy that's applied to the configured handler at + /// policy that's applied at /// [`Client::create_session`](crate::Client::create_session) time, so - /// order with [`with_handler`](Self::with_handler) is irrelevant. + /// order with [`with_permission_handler`](Self::with_permission_handler) + /// is irrelevant. pub fn approve_all_permissions(mut self) -> Self { self.permission_policy = Some(crate::permission::Policy::ApproveAll); self @@ -1736,9 +1892,29 @@ pub struct ResumeSessionConfig { /// [`Client::force_stop`]: crate::Client::force_stop #[serde(skip_serializing_if = "Option::is_none")] pub continue_pending_work: Option, - /// Session-level event handler. See [`SessionConfig::handler`]. + /// Optional permission-request handler. See + /// [`SessionConfig::permission_handler`]. + #[serde(skip)] + pub permission_handler: Option>, + /// Optional elicitation handler. See + /// [`SessionConfig::elicitation_handler`]. + #[serde(skip)] + pub elicitation_handler: Option>, + /// Optional user-input handler. See + /// [`SessionConfig::user_input_handler`]. #[serde(skip)] - pub handler: Option>, + pub user_input_handler: Option>, + /// Optional exit-plan-mode handler. See + /// [`SessionConfig::exit_plan_mode_handler`]. + #[serde(skip)] + pub exit_plan_mode_handler: Option>, + /// Optional auto-mode-switch handler. See + /// [`SessionConfig::auto_mode_switch_handler`]. + #[serde(skip)] + pub auto_mode_switch_handler: Option>, + /// Tool handlers. See [`SessionConfig::tool_handlers`]. + #[serde(skip)] + pub tool_handlers: Vec>, /// Session hook handler. See [`SessionConfig::hooks_handler`]. #[serde(skip)] pub hooks_handler: Option>, @@ -1795,7 +1971,27 @@ impl std::fmt::Debug for ResumeSessionConfig { "session_fs_provider", &self.session_fs_provider.as_ref().map(|_| ""), ) - .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "permission_handler", + &self.permission_handler.as_ref().map(|_| ""), + ) + .field( + "elicitation_handler", + &self.elicitation_handler.as_ref().map(|_| ""), + ) + .field( + "user_input_handler", + &self.user_input_handler.as_ref().map(|_| ""), + ) + .field( + "exit_plan_mode_handler", + &self.exit_plan_mode_handler.as_ref().map(|_| ""), + ) + .field( + "auto_mode_switch_handler", + &self.auto_mode_switch_handler.as_ref().map(|_| ""), + ) + .field("tool_handlers_count", &self.tool_handlers.len()) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), @@ -1808,6 +2004,66 @@ impl std::fmt::Debug for ResumeSessionConfig { } impl ResumeSessionConfig { + /// Build the [`SessionResumeWire`] payload for `session.resume` from + /// this config. Derives the request_* wire flags from handler + /// presence and the policy field; clones plain fields. + pub(crate) fn to_wire(&self) -> crate::wire::SessionResumeWire { + let permission_active = + self.permission_handler.is_some() || self.permission_policy.is_some(); + crate::wire::SessionResumeWire { + session_id: self.session_id.clone(), + client_name: self.client_name.clone(), + reasoning_effort: self.reasoning_effort.clone(), + streaming: self.streaming, + system_message: self.system_message.clone(), + tools: self.merged_tool_wire_definitions(), + available_tools: self.available_tools.clone(), + excluded_tools: self.excluded_tools.clone(), + mcp_servers: self.mcp_servers.clone(), + env_value_mode: "direct", + enable_config_discovery: self.enable_config_discovery, + request_user_input: self.user_input_handler.is_some(), + request_permission: permission_active, + request_exit_plan_mode: self.exit_plan_mode_handler.is_some(), + request_auto_mode_switch: self.auto_mode_switch_handler.is_some(), + request_elicitation: self.elicitation_handler.is_some(), + hooks: self.hooks_handler.is_some(), + skill_directories: self.skill_directories.clone(), + instruction_directories: self.instruction_directories.clone(), + disabled_skills: self.disabled_skills.clone(), + custom_agents: self.custom_agents.clone(), + default_agent: self.default_agent.clone(), + agent: self.agent.clone(), + infinite_sessions: self.infinite_sessions.clone(), + provider: self.provider.clone(), + enable_session_telemetry: self.enable_session_telemetry, + model_capabilities: self.model_capabilities.clone(), + config_dir: self.config_dir.clone(), + working_directory: self.working_directory.clone(), + github_token: self.github_token.clone(), + remote_session: self.remote_session.clone(), + include_sub_agent_streaming_events: self.include_sub_agent_streaming_events, + commands: self.commands.as_ref().map(|cmds| { + cmds.iter() + .map(|c| crate::wire::CommandWireDefinition { + name: c.name.clone(), + description: c.description.clone(), + }) + .collect() + }), + suppress_resume_event: self.suppress_resume_event, + continue_pending_work: self.continue_pending_work, + } + } + + fn merged_tool_wire_definitions(&self) -> Option> { + let mut out: Vec = self.tools.clone().unwrap_or_default(); + for handler in &self.tool_handlers { + out.push(handler.tool()); + } + if out.is_empty() { None } else { Some(out) } + } + /// Construct a `ResumeSessionConfig` with the given session ID and all /// other fields left unset. Combine with `.with_*` builders or struct /// update syntax (`..ResumeSessionConfig::new(id)`) to populate the @@ -1825,10 +2081,10 @@ impl ResumeSessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: Some(true), + request_user_input: None, request_permission: None, - request_exit_plan_mode: Some(true), - request_auto_mode_switch: Some(true), + request_exit_plan_mode: None, + request_auto_mode_switch: None, request_elicitation: None, skill_directories: None, instruction_directories: None, @@ -1850,16 +2106,57 @@ impl ResumeSessionConfig { session_fs_provider: None, suppress_resume_event: None, continue_pending_work: None, - handler: None, + permission_handler: None, + elicitation_handler: None, + user_input_handler: None, + exit_plan_mode_handler: None, + auto_mode_switch_handler: None, + tool_handlers: Vec::new(), hooks_handler: None, permission_policy: None, transform: None, } } - /// Install a custom [`SessionHandler`] for this session. - pub fn with_handler(mut self, handler: Arc) -> Self { - self.handler = Some(handler); + /// Install a [`PermissionHandler`] for the resumed session. + pub fn with_permission_handler(mut self, handler: Arc) -> Self { + self.permission_handler = Some(handler); + self + } + + /// Install an [`ElicitationHandler`] for the resumed session. + pub fn with_elicitation_handler(mut self, handler: Arc) -> Self { + self.elicitation_handler = Some(handler); + self + } + + /// Install a [`UserInputHandler`] for the resumed session. + pub fn with_user_input_handler(mut self, handler: Arc) -> Self { + self.user_input_handler = Some(handler); + self + } + + /// Install an [`ExitPlanModeHandler`] for the resumed session. + pub fn with_exit_plan_mode_handler(mut self, handler: Arc) -> Self { + self.exit_plan_mode_handler = Some(handler); + self + } + + /// Install an [`AutoModeSwitchHandler`] for the resumed session. + pub fn with_auto_mode_switch_handler( + mut self, + handler: Arc, + ) -> Self { + self.auto_mode_switch_handler = Some(handler); + self + } + + /// Install tool handlers for the resumed session. + pub fn with_tool_handlers(mut self, handlers: I) -> Self + where + I: IntoIterator>, + { + self.tool_handlers = handlers.into_iter().collect(); self } @@ -3132,8 +3429,7 @@ pub enum PermissionRequestKind { /// /// Used for both the `permission.request` RPC call (which expects a response) /// and `permission.requested` notifications (fire-and-forget). Contains the -/// full params object. Note that `requestId` is also available as a separate -/// field on `HandlerEvent::PermissionRequest`. +/// full params object. #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PermissionRequestData { @@ -3314,24 +3610,24 @@ mod tests { #[test] fn session_config_default_enables_permission_flow_flags() { let cfg = SessionConfig::default(); - assert_eq!(cfg.request_user_input, Some(true)); - // request_permission / request_elicitation are derived from the - // installed SessionHandler at Client::create_session time; the - // default config leaves them unset. + // All wire flags start unset; the SDK derives them from handler + // presence at Client::create_session time. + assert_eq!(cfg.request_user_input, None); assert_eq!(cfg.request_permission, None); assert_eq!(cfg.request_elicitation, None); - assert_eq!(cfg.request_exit_plan_mode, Some(true)); - assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, None); + assert_eq!(cfg.request_auto_mode_switch, None); } #[test] fn resume_session_config_new_enables_permission_flow_flags() { let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); - assert_eq!(cfg.request_user_input, Some(true)); + // All wire flags start unset on resume too. + assert_eq!(cfg.request_user_input, None); assert_eq!(cfg.request_permission, None); assert_eq!(cfg.request_elicitation, None); - assert_eq!(cfg.request_exit_plan_mode, Some(true)); - assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, None); + assert_eq!(cfg.request_auto_mode_switch, None); } #[test] @@ -3869,189 +4165,127 @@ mod tests { mod permission_builder_tests { use std::sync::Arc; - use crate::handler::{ - ApproveAllHandler, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, - }; + use crate::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; + use crate::permission; use crate::types::{ PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, }; - fn permission_event() -> HandlerEvent { - HandlerEvent::PermissionRequest { - session_id: SessionId::from("s1"), - request_id: RequestId::new("1"), - data: PermissionRequestData { - extra: serde_json::json!({"tool": "shell"}), - ..Default::default() - }, + fn data() -> PermissionRequestData { + PermissionRequestData { + extra: serde_json::json!({"tool": "shell"}), + ..Default::default() } } /// Apply the same policy-resolution logic that `Client::create_session` /// uses, so tests exercise the effective handler. - fn resolve(mut cfg: SessionConfig) -> Arc { - let base = cfg - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - match cfg.permission_policy.take() { - Some(policy) => crate::permission::apply_policy(base, policy), - None => base, - } + fn resolve_create(mut cfg: SessionConfig) -> Option> { + permission::resolve_handler(cfg.permission_handler.take(), cfg.permission_policy.take()) } - fn resolve_resume(mut cfg: ResumeSessionConfig) -> Arc { - let base = cfg - .handler - .take() - .unwrap_or_else(|| Arc::new(crate::handler::NoopHandler)); - match cfg.permission_policy.take() { - Some(policy) => crate::permission::apply_policy(base, policy), - None => base, - } + fn resolve_resume(mut cfg: ResumeSessionConfig) -> Option> { + permission::resolve_handler(cfg.permission_handler.take(), cfg.permission_policy.take()) } - async fn dispatch(handler: &Arc) -> HandlerResponse { - handler.on_event(permission_event()).await + async fn dispatch(handler: &Arc) -> PermissionResult { + handler + .handle(SessionId::from("s1"), RequestId::new("1"), data()) + .await } #[tokio::test] - async fn session_config_approve_all_wraps_existing_handler() { + async fn approve_all_with_handler_present_approves() { let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - match dispatch(&resolve(cfg)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_create(cfg).expect("policy + handler yields handler"); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); } #[tokio::test] - async fn session_config_approve_all_defaults_to_noop_inner() { - // Without with_handler, resolution defaults to NoopHandler. The - // approve-all wrap intercepts permission events. + async fn approve_all_standalone_produces_handler() { let cfg = SessionConfig::default().approve_all_permissions(); - match dispatch(&resolve(cfg)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_create(cfg).expect("policy alone yields handler"); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); } - /// Phase I: order independence. Both call orders must produce the same - /// effective approve-all policy. + /// Phase I: order between with_permission_handler and the policy + /// builder must not matter. #[tokio::test] - async fn session_config_approve_all_is_order_independent() { - let cfg_a = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + async fn approve_all_is_order_independent() { + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let cfg_b = SessionConfig::default() + let b = SessionConfig::default() .approve_all_permissions() - .with_handler(Arc::new(ApproveAllHandler)); - - match dispatch(&resolve(cfg_a)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("order A: expected Approved, got {other:?}"), - } - match dispatch(&resolve(cfg_b)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("order B: expected Approved, got {other:?}"), - } + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Approved)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Approved)); } - /// Phase I: same for deny_all_permissions. #[tokio::test] - async fn session_config_deny_all_is_order_independent() { - let cfg_a = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + async fn deny_all_is_order_independent() { + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) .deny_all_permissions(); - let cfg_b = SessionConfig::default() + let b = SessionConfig::default() .deny_all_permissions() - .with_handler(Arc::new(ApproveAllHandler)); + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Denied)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Denied)); + } - match dispatch(&resolve(cfg_a)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("order A: expected Denied, got {other:?}"), - } - match dispatch(&resolve(cfg_b)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("order B: expected Denied, got {other:?}"), - } + #[tokio::test] + async fn approve_permissions_if_consults_predicate() { + let cfg = SessionConfig::default().approve_permissions_if(|d| { + d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }); + let h = resolve_create(cfg).unwrap(); + assert!(matches!(dispatch(&h).await, PermissionResult::Denied)); } - /// Phase I: same for approve_permissions_if. #[tokio::test] - async fn session_config_approve_permissions_if_is_order_independent() { - let predicate = |data: &PermissionRequestData| { - data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + async fn approve_permissions_if_is_order_independent() { + let predicate = |d: &PermissionRequestData| { + d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") }; - let cfg_a = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + let a = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_permissions_if(predicate); - let cfg_b = SessionConfig::default() + let b = SessionConfig::default() .approve_permissions_if(predicate) - .with_handler(Arc::new(ApproveAllHandler)); - - match dispatch(&resolve(cfg_a)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("order A: expected Denied for shell, got {other:?}"), - } - match dispatch(&resolve(cfg_b)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("order B: expected Denied for shell, got {other:?}"), - } + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_create(a).unwrap(); + let hb = resolve_create(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Denied)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Denied)); } #[tokio::test] - async fn session_config_deny_all_denies() { - let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) - .deny_all_permissions(); - match dispatch(&resolve(cfg)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied, got {other:?}"), - } - } - - #[tokio::test] - async fn session_config_approve_permissions_if_consults_predicate() { - let cfg = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) - .approve_permissions_if(|data| { - data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") - }); - match dispatch(&resolve(cfg)).await { - HandlerResponse::Permission(PermissionResult::Denied) => {} - other => panic!("expected Denied for shell, got {other:?}"), - } - } - - #[tokio::test] - async fn resume_session_config_approve_all_wraps_existing_handler() { + async fn resume_session_config_approve_all_works() { let cfg = ResumeSessionConfig::new(SessionId::from("s1")) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - match dispatch(&resolve_resume(cfg)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("expected Approved, got {other:?}"), - } + let h = resolve_resume(cfg).unwrap(); + assert!(matches!(dispatch(&h).await, PermissionResult::Approved)); } #[tokio::test] async fn resume_session_config_approve_all_is_order_independent() { - let cfg_a = ResumeSessionConfig::new(SessionId::from("s1")) - .with_handler(Arc::new(ApproveAllHandler)) + let a = ResumeSessionConfig::new(SessionId::from("s1")) + .with_permission_handler(Arc::new(ApproveAllHandler)) .approve_all_permissions(); - let cfg_b = ResumeSessionConfig::new(SessionId::from("s1")) + let b = ResumeSessionConfig::new(SessionId::from("s1")) .approve_all_permissions() - .with_handler(Arc::new(ApproveAllHandler)); - - match dispatch(&resolve_resume(cfg_a)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("order A: expected Approved, got {other:?}"), - } - match dispatch(&resolve_resume(cfg_b)).await { - HandlerResponse::Permission(PermissionResult::Approved) => {} - other => panic!("order B: expected Approved, got {other:?}"), - } + .with_permission_handler(Arc::new(ApproveAllHandler)); + let ha = resolve_resume(a).unwrap(); + let hb = resolve_resume(b).unwrap(); + assert!(matches!(dispatch(&ha).await, PermissionResult::Approved)); + assert!(matches!(dispatch(&hb).await, PermissionResult::Approved)); } -} +} \ No newline at end of file diff --git a/rust/src/wire.rs b/rust/src/wire.rs new file mode 100644 index 000000000..fb1e34f20 --- /dev/null +++ b/rust/src/wire.rs @@ -0,0 +1,173 @@ +//! Wire-format structs for the `session.create` and `session.resume` +//! JSON-RPC payloads. +//! +//! Built explicitly from [`SessionConfig`](crate::types::SessionConfig) and +//! [`ResumeSessionConfig`](crate::types::ResumeSessionConfig) at +//! `Client::create_session` / `Client::resume_session` time via +//! [`SessionConfig::into_wire`](crate::types::SessionConfig::into_wire) and +//! [`ResumeSessionConfig::into_wire`](crate::types::ResumeSessionConfig::into_wire), +//! respectively. +//! +//! Keeping the wire shape separate from the user-facing config avoids +//! having callback fields on a serializable struct: the user-facing +//! configs hold trait-object handlers, the wire structs hold only the +//! plain data the runtime needs. + +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::Serialize; + +use crate::generated::api_types::{ModelCapabilitiesOverride, RemoteSessionMode}; +use crate::types::{ + CloudSessionOptions, CustomAgentConfig, DefaultAgentConfig, InfiniteSessionConfig, + McpServerConfig, ProviderConfig, SessionId, SystemMessageConfig, Tool, +}; + +/// Wire representation of a slash command (name + description only). The +/// runtime executes the command; the SDK's `CommandHandler` callback is +/// invoked from a separate dispatch path and never crosses the wire. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CommandWireDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +/// The exact JSON shape sent on the `session.create` JSON-RPC request. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SessionCreateWire { + pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + pub env_value_mode: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + pub request_user_input: bool, + pub request_permission: bool, + pub request_exit_plan_mode: bool, + pub request_auto_mode_switch: bool, + pub request_elicitation: bool, + pub hooks: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_session_telemetry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_session: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cloud: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub commands: Option>, +} + +/// The exact JSON shape sent on the `session.resume` JSON-RPC request. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SessionResumeWire { + pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + pub env_value_mode: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + pub request_user_input: bool, + pub request_permission: bool, + pub request_exit_plan_mode: bool, + pub request_auto_mode_switch: bool, + pub request_elicitation: bool, + pub hooks: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_session_telemetry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_session: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub commands: Option>, + /// Maps to wire field `disableResume`. + #[serde(rename = "disableResume", skip_serializing_if = "Option::is_none")] + pub suppress_resume_event: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub continue_pending_work: Option, +} \ No newline at end of file From 03eef9dbe63ffc9eb1207432b2145e95000ceda1 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:30:25 +0100 Subject: [PATCH 10/22] Phase H redo: migrate tests, examples, and scenarios to per-trait handler API Mechanical migration of all consumers to the new optional handler-trait API introduced in 455a17b7: - rust/tests/session_test.rs: ~30 SessionHandler impls split into the corresponding PermissionHandler / ElicitationHandler / UserInputHandler / ExitPlanModeHandler / AutoModeSwitchHandler / ToolHandler impls. Tests that called handler.on_event(...) directly now call handler.handle(...) on the appropriate trait. The legacy tool.call direct-RPC test was dropped; equivalent coverage is already provided by external_tool_requested_dispatches_to_handler_and_responds. - rust/tests/e2e/*: per-trait migration; with_handler(...) -> the matching with_*_handler builder. - rust/examples/*: same migration. - test/scenarios/**/rust/src/main.rs: same migration. Also fixes a bug in Client::create_session / Client::resume_session where to_wire() was called after .take()-ing all handler fields, so the derived wire flags (requestPermission, requestElicitation, hooks, etc.) were always false. Computing the wire payload now happens before the takes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 250 ++-- rust/examples/chat.rs | 62 +- rust/examples/hooks.rs | 2 +- rust/examples/lifecycle_observer.rs | 2 +- rust/examples/session_fs.rs | 2 +- rust/examples/tool_server.rs | 19 +- rust/src/handler.rs | 2 +- rust/src/permission.rs | 11 +- rust/src/session.rs | 67 +- rust/src/tool.rs | 2 +- rust/src/types.rs | 2 +- rust/src/wire.rs | 2 +- rust/tests/e2e/abort.rs | 18 +- rust/tests/e2e/ask_user.rs | 47 +- rust/tests/e2e/client_options.rs | 297 ----- rust/tests/e2e/commands.rs | 164 --- rust/tests/e2e/compaction.rs | 144 --- rust/tests/e2e/elicitation.rs | 80 +- rust/tests/e2e/error_resilience.rs | 100 -- rust/tests/e2e/hooks_extended.rs | 12 +- rust/tests/e2e/mcp_and_agents.rs | 430 ------- rust/tests/e2e/mode_handlers.rs | 19 +- rust/tests/e2e/multi_client.rs | 89 +- .../e2e/multi_client_commands_elicitation.rs | 21 +- rust/tests/e2e/pending_work_resume.rs | 18 +- rust/tests/e2e/per_session_auth.rs | 10 +- rust/tests/e2e/permissions.rs | 52 +- rust/tests/e2e/rpc_session_state.rs | 1001 ----------------- rust/tests/e2e/session.rs | 34 +- rust/tests/e2e/session_config.rs | 954 ---------------- rust/tests/e2e/session_fs.rs | 629 ----------- rust/tests/e2e/streaming_fidelity.rs | 4 +- rust/tests/e2e/support.rs | 2 +- rust/tests/e2e/suspend.rs | 87 -- rust/tests/e2e/system_message_transform.rs | 186 --- rust/tests/e2e/telemetry.rs | 15 +- rust/tests/e2e/tool_results.rs | 10 +- rust/tests/e2e/tools.rs | 113 +- rust/tests/session_test.rs | 729 ++++++------ .../callbacks/hooks/rust/src/main.rs | 2 +- .../callbacks/permissions/rust/src/main.rs | 8 +- .../callbacks/user-input/rust/src/main.rs | 16 +- test/scenarios/modes/default/rust/src/main.rs | 2 +- .../prompts/attachments/rust/src/main.rs | 2 +- .../prompts/reasoning-effort/rust/src/main.rs | 2 +- .../prompts/system-message/rust/src/main.rs | 2 +- .../concurrent-sessions/rust/src/main.rs | 2 +- .../infinite-sessions/rust/src/main.rs | 2 +- .../sessions/session-resume/rust/src/main.rs | 4 +- .../sessions/streaming/rust/src/main.rs | 41 +- .../tools/custom-agents/rust/src/main.rs | 13 +- .../tools/mcp-servers/rust/src/main.rs | 2 +- .../scenarios/tools/no-tools/rust/src/main.rs | 2 +- test/scenarios/tools/skills/rust/src/main.rs | 2 +- .../tools/tool-filtering/rust/src/main.rs | 2 +- .../tools/tool-overrides/rust/src/main.rs | 13 +- .../transport/stdio/rust/src/main.rs | 2 +- test/scenarios/transport/tcp/rust/src/main.rs | 2 +- 58 files changed, 891 insertions(+), 4918 deletions(-) diff --git a/rust/README.md b/rust/README.md index a09e294fd..bec601b00 100644 --- a/rust/README.md +++ b/rust/README.md @@ -18,7 +18,7 @@ use github_copilot_sdk::handler::ApproveAllHandler; # async fn example() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; let session = client.create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), ).await?; let _message_id = session.send("Hello!").await?; session.disconnect().await?; @@ -50,10 +50,10 @@ The SDK manages the CLI process lifecycle: spawning, health-checking, and gracef let client = Client::start(options).await?; // Create a new session -let session = client.create_session(config.with_handler(handler)).await?; +let session = client.create_session(config.with_permission_handler(handler)).await?; // Resume an existing session -let session = client.resume_session(config.with_handler(handler)).await?; +let session = client.resume_session(config.with_permission_handler(handler)).await?; // Low-level RPC let result = client.call("method.name", Some(params)).await?; @@ -82,7 +82,7 @@ With the default `CliProgram::Resolve`, `Client::start()` automatically resolves ### Session -Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches events to the `SessionHandler`. +Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`. ```rust,ignore use github_copilot_sdk::MessageOptions; @@ -176,22 +176,31 @@ New RPCs land in the namespace immediately as the schema regenerates; helpers are added on top only when an ergonomic story is worth the maintenance. -### SessionHandler +### Handler Traits -Implement this trait to control how a session responds to CLI events. Two styles are supported: +The SDK exposes five focused handler traits, one per CLI callback type. Implement only the traits you need and install each with the matching `SessionConfig` setter. Each trait has a single `async fn handle(...)` method: -**1. Per-event methods (recommended).** Override only the callbacks you care about; every method has a safe default (permission → deny, user input → none, external tool → "no handler", elicitation → cancel, exit plan → default). When no handler is installed on a session, the SDK uses `NoopHandler`, which leaves permission and external tool requests pending for manual resolution. This is the `serenity::EventHandler` pattern. +| Trait | Setter | Purpose | +| ----------------------- | --------------------------------- | --------------------------------------------- | +| `PermissionHandler` | `with_permission_handler(...)` | Approve/deny tool-use permission requests | +| `ElicitationHandler` | `with_elicitation_handler(...)` | Respond to structured elicitation prompts | +| `UserInputHandler` | `with_user_input_handler(...)` | Answer free-form / choice user-input prompts | +| `ExitPlanModeHandler` | `with_exit_plan_mode_handler(...)`| Respond when the agent exits plan mode | +| `AutoModeSwitchHandler` | `with_auto_mode_switch_handler(...)`| Respond to automatic mode-switch proposals | + +The CLI's `requestPermission` / `requestElicitation` / `requestUserInput` / etc. wire flags are derived automatically from which traits you've installed — clients that don't install a handler are silently skipped, letting another connected client handle the request. ```rust,ignore +use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; -struct MyHandler; +struct MyPermissions; #[async_trait] -impl SessionHandler for MyHandler { - async fn on_permission_request( +impl PermissionHandler for MyPermissions { + async fn handle( &self, _sid: SessionId, _rid: RequestId, @@ -203,47 +212,21 @@ impl SessionHandler for MyHandler { PermissionResult::Denied } } - - async fn on_session_event(&self, sid: SessionId, event: github_copilot_sdk::types::SessionEvent) { - println!("[{sid}] {}", event.event_type); - } } + +let config = SessionConfig::default().with_permission_handler(Arc::new(MyPermissions)); ``` -**2. Single `on_event` method.** Override `on_event` directly and `match` on `HandlerEvent` — useful for logging middleware, custom routing, or when you want one exhaustive dispatch point. +A single type can implement multiple handler traits — share one `Arc` across the setters by cloning: ```rust,ignore -use github_copilot_sdk::handler::*; -use async_trait::async_trait; - -#[async_trait] -impl SessionHandler for MyRouter { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { session_id, event } => { - println!("[{session_id}] {}", event.event_type); - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - HandlerEvent::UserInput { question, .. } => { - HandlerResponse::UserInput(Some(UserInputResponse { - answer: prompt_user(&question), - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } - } -} +let h = Arc::new(MyHandler); +let config = SessionConfig::default() + .with_permission_handler(h.clone()) + .with_user_input_handler(h); ``` -The default `on_event` dispatches to the per-event methods, so overriding `on_event` short-circuits them entirely — pick one style per handler. - -Events are processed serially per session — blocking in a handler method pauses that session's event loop (which is correct, since the CLI is also waiting for the response). Other sessions are unaffected. - -> **Note:** Notification-triggered events (`PermissionRequest` via `permission.requested`, `ExternalTool` via `external_tool.requested`) are dispatched on spawned tasks and may run concurrently with the serial event loop. See the trait-level docs on `SessionHandler` for details. +The built-in `ApproveAllHandler` and `DenyAllHandler` implement `PermissionHandler` for the common cases. To observe streamed session events (assistant messages, tool calls, etc.), call `session.subscribe()` — see [Streaming](#streaming) below. ### SessionConfig @@ -257,7 +240,7 @@ let config = SessionConfig { request_elicitation: Some(true), // enable elicitation provider ..Default::default() }; -let session = client.create_session(config.with_handler(handler)).await?; +let session = client.create_session(config.with_permission_handler(handler)).await?; ``` ### Session Hooks @@ -300,7 +283,7 @@ impl SessionHooks for MyHooks { let session = client .create_session( config - .with_handler(handler) + .with_permission_handler(handler) .with_hooks(Arc::new(MyHooks)), ) .await?; @@ -337,7 +320,7 @@ impl SystemMessageTransform for MyTransform { let session = client .create_session( config - .with_handler(handler) + .with_permission_handler(handler) .with_transform(Arc::new(MyTransform)), ) .await?; @@ -345,14 +328,12 @@ let session = client ### Tool Registration -Define client-side tools as named types with `ToolHandler`, then route them with `ToolHandlerRouter`. Enable the `derive` feature for `schema_for::()` — it generates JSON Schema from Rust types via `schemars`. +Define client-side tools as named types implementing `ToolHandler` and install them with `SessionConfig::with_tool_handlers`. Enable the `derive` feature for `schema_for::()` — it generates JSON Schema from Rust types via `schemars`. ```rust,ignore use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ - schema_for, tool_parameters, JsonSchema, ToolHandler, ToolHandlerRouter, -}; +use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler}; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde::Deserialize; use async_trait::async_trait; @@ -370,13 +351,9 @@ struct GetWeatherTool; #[async_trait] impl ToolHandler for GetWeatherTool { fn tool(&self) -> Tool { - Tool { - name: "get_weather".to_string(), - namespaced_name: None, - description: "Get weather for a city".to_string(), - parameters: tool_parameters(schema_for::()), - instructions: None, - } + Tool::new("get_weather") + .with_description("Get weather for a city") + .with_parameters(tool_parameters(schema_for::())) } async fn call(&self, inv: ToolInvocation) -> Result { @@ -385,42 +362,37 @@ impl ToolHandler for GetWeatherTool { } } -// Build a router that dispatches tool calls by name -let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool)], - Arc::new(ApproveAllHandler), -); +let tool_handlers: Vec> = vec![Arc::new(GetWeatherTool)]; -let config = SessionConfig { - tools: Some(router.tools()), - ..Default::default() -} -.with_handler(Arc::new(router)); +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(tool_handlers); let session = client.create_session(config).await?; ``` -Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". The router implements `SessionHandler`, forwarding unrecognized tools and non-tool events to the inner handler. +Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". `with_tool_handlers` registers each handler under the name returned by its `tool()` method and surfaces the same `Tool` definitions to the CLI automatically; you don't need to set `SessionConfig::tools` separately when supplying handlers this way. -For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression: +For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression. It returns a `Box` — convert to `Arc` with `Arc::from(...)`: ```rust,ignore -use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandlerRouter}; +use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandler}; use github_copilot_sdk::ToolResult; use serde::Deserialize; #[derive(Deserialize, JsonSchema)] struct GetWeatherParams { city: String } -let router = ToolHandlerRouter::new( - vec![define_tool( - "get_weather", - "Get weather for a city", - |_inv, params: GetWeatherParams| async move { - Ok(ToolResult::Text(format!("Sunny in {}", params.city))) - }, - )], - Arc::new(ApproveAllHandler), -); +let tool_handlers: Vec> = vec![Arc::from(define_tool( + "get_weather", + "Get weather for a city", + |_inv, params: GetWeatherParams| async move { + Ok(ToolResult::Text(format!("Sunny in {}", params.city))) + }, +))]; + +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(tool_handlers); ``` The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) alongside the deserialized parameters, so handlers that need `inv.session_id` or `inv.tool_call_id` for telemetry, streaming updates, or scoped lookups can use them directly. Use `_inv` when you don't need the metadata. @@ -429,13 +401,12 @@ Reach for the `ToolHandler` trait directly when you need shared state across mul ### Permission Policies -Set a permission policy directly on `SessionConfig` with the chainable builders. They wrap whatever handler you've installed (defaulting to `NoopHandler` if none) so only permission requests are intercepted; every other event flows through unchanged. +Set a permission policy directly on `SessionConfig` with the chainable builders. They install a synthesized `PermissionHandler` so only permission requests are intercepted; every other event flows through unchanged. ```rust,ignore let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(my_handler)) .approve_all_permissions(), // or .deny_all_permissions() // or .approve_permissions_if(|data| { @@ -445,70 +416,86 @@ let session = client .await?; ``` -> Order-independent: `with_handler` and the permission-policy methods -> (`approve_all_permissions`, `deny_all_permissions`, -> `approve_permissions_if`) can be called in either order — the policy is -> applied to the handler when the session is created. +> The policy builders set the permission handler slot directly; they're equivalent to calling `with_permission_handler(...)` with the corresponding built-in (`ApproveAllHandler`, `DenyAllHandler`, or `permission::approve_if(...)`). -The `permission` module also exposes the policy primitives as standalone -helpers for the rare case where you want to wrap a handler outside the -builder chain (e.g., when composing a `ToolHandlerRouter` you've built -elsewhere): +The `permission` module also exposes the policy primitives as standalone helpers for the rare case where you want to construct the handler value separately and install it via `with_permission_handler`: ```rust,ignore use github_copilot_sdk::permission; -let router = ToolHandlerRouter::new(tools, Arc::new(MyHandler)); -let handler = permission::approve_all(Arc::new(router)); -// or permission::deny_all(...) / permission::approve_if(..., predicate) +let handler = permission::approve_if(|data| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") +}); +// or permission::approve_all() / permission::deny_all() -let session = client.create_session(config.with_handler(handler)).await?; +let session = client + .create_session(config.with_permission_handler(handler)) + .await?; ``` -### Capabilities & Elicitation +### Elicitation -The SDK negotiates capabilities with the CLI after session creation. To -opt this client into receiving `elicitation.requested` broadcasts, return -`true` from `SessionHandler::wants_elicitation_dispatch` on the handler -you install — the SDK derives the `requestElicitation` wire flag from -that probe at `Client::create_session` time. Clients that don't claim -elicitation are silently skipped, allowing other connected clients on -the same CLI to handle the request. +To opt your client into receiving `elicitation.requested` broadcasts, install an `ElicitationHandler` on the session config. The wire flag `requestElicitation` is derived from the presence of the handler; clients without one are silently skipped, allowing other connected clients on the same CLI to handle the request. ```rust,ignore -struct MyHandler; +use async_trait::async_trait; +use github_copilot_sdk::handler::{ElicitationHandler, ElicitationResult}; +use github_copilot_sdk::types::{ElicitationRequestData, RequestId, SessionId}; + +struct MyElicitation; + #[async_trait] -impl SessionHandler for MyHandler { - fn wants_elicitation_dispatch(&self) -> bool { - true +impl ElicitationHandler for MyElicitation { + async fn handle( + &self, + _sid: SessionId, + _rid: RequestId, + _data: ElicitationRequestData, + ) -> ElicitationResult { + ElicitationResult::cancel() } - // ... on_event etc. } + +let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_elicitation_handler(Arc::new(MyElicitation)); ``` -The handler receives `HandlerEvent::ElicitationRequest` with a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. Return `HandlerResponse::Elicitation(result)`. +The handler receives a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. ### User Input Requests -Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. Implement `SessionHandler::on_user_input` and the SDK will forward `userInput.request` callbacks: +Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. Install a `UserInputHandler` and the SDK will forward `userInput.request` callbacks: ```rust,ignore -async fn on_user_input( - &self, - _session_id: SessionId, - question: String, - choices: Option>, - _allow_freeform: Option, -) -> Option { - // Render `question` + `choices` to your UI, then: - Some(UserInputResponse { - answer: "Yes".to_string(), - was_freeform: false, - }) +use async_trait::async_trait; +use github_copilot_sdk::handler::{UserInputHandler, UserInputResponse}; +use github_copilot_sdk::types::SessionId; + +struct MyUserInput; + +#[async_trait] +impl UserInputHandler for MyUserInput { + async fn handle( + &self, + _sid: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + // Render `question` + `choices` to your UI, then: + Some(UserInputResponse { + answer: "Yes".to_string(), + was_freeform: false, + }) + } } + +let config = SessionConfig::default() + .with_user_input_handler(Arc::new(MyUserInput)); ``` -Return `None` to signal "no answer available" (the CLI falls back to its own prompt). Enable via `SessionConfig::request_user_input` (defaults to `Some(true)`). +Return `None` to signal "no answer available" (the CLI falls back to its own prompt). ### Slash Commands @@ -680,7 +667,8 @@ ergonomics the dynamically-typed SDKs don't. [`SessionConfig::with_session_fs_provider`]. The factory pattern doesn't cleanly express in Rust at the session-config call site — there is no `Session` value to thread in, and the SDK already prefers traits over - boxed closures for handler-shaped APIs (`SessionHandler`, `SessionHooks`, + boxed closures for handler-shaped APIs (`PermissionHandler`, `ToolHandler`, + `SessionHooks`, `ToolHandler`). ```rust,ignore @@ -698,7 +686,7 @@ let client = Client::start(options).await?; let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_session_fs_provider(Arc::new(MyProvider::new())), ) .await?; @@ -721,9 +709,9 @@ none of them are scheduled for removal. identifier from an arbitrary `String` at compile time. Node/Python/Go use bare strings. - **Permission policy builders** — `permission::approve_all`, - `permission::deny_all`, and `permission::approve_if(handler, predicate)` - in `crate::permission` provide composable, no-handler-needed permission - shortcuts that wrap an existing `SessionHandler`. Other SDKs require a + `permission::deny_all`, and `permission::approve_if(predicate)` + in `crate::permission` provide composable, no-handler-needed + `PermissionHandler` shortcuts. Other SDKs require a full handler implementation for these patterns. - **`Client::from_streams`** — connect to a CLI server over arbitrary caller-supplied `AsyncRead` / `AsyncWrite`. Useful for testing, @@ -744,10 +732,10 @@ none of them are scheduled for removal. | `lib.rs` | `Client`, `ClientOptions`, `CliProgram`, `Transport`, `Error` | | `session.rs` | `Session` struct, event loop, `send`/`send_and_wait`, `Client::create_session`/`resume_session` | | `subscription.rs` | `EventSubscription` / `LifecycleSubscription` (`Stream`-able observer handles for `subscribe()` / `subscribe_lifecycle()`) | -| `handler.rs` | `SessionHandler` trait, `HandlerEvent`/`HandlerResponse` enums, `ApproveAllHandler`, `DenyAllHandler`, `NoopHandler` | +| `handler.rs` | `PermissionHandler`, `ElicitationHandler`, `UserInputHandler`, `ExitPlanModeHandler`, `AutoModeSwitchHandler` traits; `ApproveAllHandler`, `DenyAllHandler` | | `hooks.rs` | `SessionHooks` trait, `HookEvent`/`HookOutput` enums, typed hook inputs/outputs | | `transforms.rs` | `SystemMessageTransform` trait, section-level system message customization | -| `tool.rs` | `ToolHandler` trait, `ToolHandlerRouter`, `schema_for::()` (with `derive` feature) | +| `tool.rs` | `ToolHandler` trait, `define_tool`, `schema_for::()` (with `derive` feature) | | `types.rs` | CLI protocol types (`SessionId`, `SessionEvent`, `SessionConfig`, `Tool`, etc.) | | `resolve.rs` | Binary resolution (`copilot_binary`, `node_binary`, `extended_path`) | | `embeddedcli.rs` | Embedded CLI extraction (`embedded-cli` feature) | diff --git a/rust/examples/chat.rs b/rust/examples/chat.rs index d017376ef..6b361fdea 100644 --- a/rust/examples/chat.rs +++ b/rust/examples/chat.rs @@ -13,39 +13,29 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use github_copilot_sdk::handler::{ - HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, UserInputResponse, -}; -use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent}; +use github_copilot_sdk::handler::{ApproveAllHandler, UserInputHandler, UserInputResponse}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; -/// Handler that prints assistant message deltas as they stream in -/// and auto-approves permissions. -struct ChatHandler; +/// User input handler that prompts on stdin. +struct StdinUserInputHandler; #[async_trait] -impl SessionHandler for ChatHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { event, .. } => { - print_event(&event); - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - HandlerEvent::UserInput { question, .. } => { - // Prompt the user on behalf of the agent. - print!("\n[agent asks] {question}\n> "); - io::stdout().flush().ok(); - let answer = read_line().unwrap_or_default(); - HandlerResponse::UserInput(Some(UserInputResponse { - answer, - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } +impl UserInputHandler for StdinUserInputHandler { + async fn handle( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + print!("\n[agent asks] {question}\n> "); + io::stdout().flush().ok(); + let answer = read_line()?; + Some(UserInputResponse { + answer, + was_freeform: true, + }) } } @@ -91,9 +81,11 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(ClientOptions::default()).await?; let config = { - let mut cfg = SessionConfig::default(); + let mut cfg = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_user_input_handler(Arc::new(StdinUserInputHandler)); cfg.streaming = Some(true); - cfg.with_handler(Arc::new(ChatHandler)) + cfg }; let session = client.create_session(config).await?; @@ -102,6 +94,14 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { session.id() ); + // Spawn a task to print streamed assistant deltas as session events arrive. + let mut events = session.subscribe(); + tokio::spawn(async move { + while let Ok(event) = events.recv().await { + print_event(&event); + } + }); + loop { print!("> "); io::stdout().flush().ok(); diff --git a/rust/examples/hooks.rs b/rust/examples/hooks.rs index 865873f46..79bf81551 100644 --- a/rust/examples/hooks.rs +++ b/rust/examples/hooks.rs @@ -103,7 +103,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { // hooks: true is set automatically when a hooks handler is provided. let config = SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(Arc::new(AuditHooks)); let session = client.create_session(config).await?; diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs index fe3654099..8edb2cd38 100644 --- a/rust/examples/lifecycle_observer.rs +++ b/rust/examples/lifecycle_observer.rs @@ -63,7 +63,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { } }); - let config = SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)); + let config = SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; println!("[client] session created: {}", session.id()); diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs index 0dbbb3414..924e6947f 100644 --- a/rust/examples/session_fs.rs +++ b/rust/examples/session_fs.rs @@ -125,7 +125,7 @@ async fn main() -> Result<(), Box> { let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_session_fs_provider(provider), ) .await?; diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs index aefeaedb9..c6d6e709a 100644 --- a/rust/examples/tool_server.rs +++ b/rust/examples/tool_server.rs @@ -30,9 +30,7 @@ use async_trait::async_trait; #[cfg(feature = "derive")] use github_copilot_sdk::handler::ApproveAllHandler; #[cfg(feature = "derive")] -use github_copilot_sdk::tool::{ - JsonSchema, ToolHandler, ToolHandlerRouter, schema_for, tool_parameters, -}; +use github_copilot_sdk::tool::{JsonSchema, ToolHandler, schema_for, tool_parameters}; #[cfg(feature = "derive")] use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult}; #[cfg(feature = "derive")] @@ -145,19 +143,18 @@ impl ToolHandler for RollDiceTool { #[cfg(feature = "derive")] #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { - let router = ToolHandlerRouter::new( - vec![Box::new(GetWeatherTool), Box::new(RollDiceTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); - let handler = Arc::new(router); + let tool_handlers: Vec> = + vec![Arc::new(GetWeatherTool), Arc::new(RollDiceTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let client = Client::start(ClientOptions::default()).await?; let config = { - let mut cfg = SessionConfig::default(); + let mut cfg = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(tool_handlers); cfg.tools = Some(tools); - cfg.with_handler(handler) + cfg }; let session = client.create_session(config).await?; diff --git a/rust/src/handler.rs b/rust/src/handler.rs index 6e4ffa484..8e7f6a655 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -229,4 +229,4 @@ mod tests { .await; assert!(matches!(result, PermissionResult::Denied)); } -} \ No newline at end of file +} diff --git a/rust/src/permission.rs b/rust/src/permission.rs index 0b1c67f01..099b47497 100644 --- a/rust/src/permission.rs +++ b/rust/src/permission.rs @@ -138,7 +138,8 @@ mod tests { async fn approve_all_approves() { let h = approve_all(); assert!(matches!( - h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, PermissionResult::Approved )); } @@ -147,7 +148,8 @@ mod tests { async fn deny_all_denies() { let h = deny_all(); assert!(matches!( - h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, PermissionResult::Denied )); } @@ -156,7 +158,8 @@ mod tests { async fn approve_if_consults_predicate() { let h = approve_if(|d| d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")); assert!(matches!( - h.handle(SessionId::from("s"), RequestId::new("1"), data()).await, + h.handle(SessionId::from("s"), RequestId::new("1"), data()) + .await, PermissionResult::Denied )); } @@ -213,4 +216,4 @@ mod tests { fn resolve_handler_with_neither_returns_none() { assert!(resolve_handler(None, None).is_none()); } -} \ No newline at end of file +} diff --git a/rust/src/session.rs b/rust/src/session.rs index e0efb4249..a7bf05c74 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -31,8 +31,7 @@ use crate::types::{ ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation, - ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, - ensure_attachment_display_names, + ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; @@ -53,19 +52,6 @@ pub(crate) struct SessionHandlers { pub tools: Arc>>, } -impl SessionHandlers { - pub(crate) fn empty() -> Self { - Self { - permission: None, - elicitation: None, - user_input: None, - exit_plan_mode: None, - auto_mode_switch: None, - tools: Arc::new(HashMap::new()), - } - } -} - /// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. struct IdleWaiter { tx: oneshot::Sender, Error>>, @@ -794,6 +780,19 @@ impl Client { /// broadcast (and silently skips dispatch if one arrives anyway). pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); + let session_id = config + .session_id + .clone() + .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); + config.session_id = Some(session_id.clone()); + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.transform.clone() { + inject_transform_sections(&mut config, transforms.as_ref()); + } + let wire = config.to_wire(session_id.clone()); + let permission_handler = crate::permission::resolve_handler( config.permission_handler.take(), config.permission_policy.take(), @@ -841,18 +840,6 @@ impl Client { )); } - if hooks.is_some() && config.hooks.is_none() { - config.hooks = Some(true); - } - if let Some(ref transforms) = transforms { - inject_transform_sections(&mut config, transforms.as_ref()); - } - let session_id = config - .session_id - .clone() - .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); - config.session_id = Some(session_id.clone()); - let wire = config.to_wire(session_id.clone()); let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -948,6 +935,15 @@ impl Client { /// fields are unset. pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result { let total_start = Instant::now(); + let session_id = config.session_id.clone(); + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.transform.clone() { + inject_transform_sections_resume(&mut config, transforms.as_ref()); + } + let wire = config.to_wire(); + let permission_handler = crate::permission::resolve_handler( config.permission_handler.take(), config.permission_policy.take(), @@ -995,14 +991,6 @@ impl Client { )); } - if hooks.is_some() && config.hooks.is_none() { - config.hooks = Some(true); - } - if let Some(ref transforms) = transforms { - inject_transform_sections_resume(&mut config, transforms.as_ref()); - } - let session_id = config.session_id.clone(); - let wire = config.to_wire(); let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); @@ -1568,8 +1556,7 @@ async fn handle_notification( tool_name = %tool_name, "ToolHandler::call dispatch" ); - let result_value = - serde_json::to_value(tool_result).unwrap_or(Value::Null); + let result_value = serde_json::to_value(tool_result).unwrap_or(Value::Null); let rpc_start = Instant::now(); let _ = client .call( @@ -2179,8 +2166,10 @@ mod tests { json!({ "kind": "reject" }) ); assert_eq!( - serde_json::to_value(permission_request_response(&PermissionResult::UserNotAvailable)) - .expect("serializing fallback permission response should succeed"), + serde_json::to_value(permission_request_response( + &PermissionResult::UserNotAvailable + )) + .expect("serializing fallback permission response should succeed"), json!({ "kind": "reject" }) ); } diff --git a/rust/src/tool.rs b/rust/src/tool.rs index 1be2a2c0d..b3c021e32 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -335,7 +335,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::types::{PermissionRequestData, RequestId, SessionId}; + use crate::types::SessionId; struct EchoTool; diff --git a/rust/src/types.rs b/rust/src/types.rs index d8a7365ca..7b57cc4c6 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -4288,4 +4288,4 @@ mod permission_builder_tests { assert!(matches!(dispatch(&ha).await, PermissionResult::Approved)); assert!(matches!(dispatch(&hb).await, PermissionResult::Approved)); } -} \ No newline at end of file +} diff --git a/rust/src/wire.rs b/rust/src/wire.rs index fb1e34f20..bc6af5651 100644 --- a/rust/src/wire.rs +++ b/rust/src/wire.rs @@ -170,4 +170,4 @@ pub(crate) struct SessionResumeWire { pub suppress_resume_event: Option, #[serde(skip_serializing_if = "Option::is_none")] pub continue_pending_work: Option, -} \ No newline at end of file +} diff --git a/rust/tests/e2e/abort.rs b/rust/tests/e2e/abort.rs index ff8977f39..18be2ac18 100644 --- a/rust/tests/e2e/abort.rs +++ b/rust/tests/e2e/abort.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{AssistantMessageDeltaData, SessionEventType}; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde_json::json; use tokio::sync::{Mutex, mpsc, oneshot}; @@ -76,20 +76,16 @@ async fn should_abort_during_active_tool_execution() { let client = ctx.start_client().await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (release_tx, release_rx) = oneshot::channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(SlowAnalysisTool { - started_tx, - release_rx: Mutex::new(Some(release_rx)), - })], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let slow_tool: Arc = Arc::new(SlowAnalysisTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(vec![slow_tool]), ) .await .expect("create session"); diff --git a/rust/tests/e2e/ask_user.rs b/rust/tests/e2e/ask_user.rs index 349c42210..d9548235d 100644 --- a/rust/tests/e2e/ask_user.rs +++ b/rust/tests/e2e/ask_user.rs @@ -1,7 +1,9 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{SessionHandler, UserInputResponse}; +use github_copilot_sdk::handler::{ + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, +}; use github_copilot_sdk::{RequestId, SessionConfig, SessionId}; use tokio::sync::mpsc; @@ -19,14 +21,16 @@ async fn should_invoke_user_input_handler_when_model_uses_ask_user_tool() { ctx.set_default_copilot_user(); let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("freeform answer"), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::FirstChoiceOrFreeform("freeform answer"), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -61,14 +65,16 @@ async fn should_receive_choices_in_user_input_request() { ctx.set_default_copilot_user(); let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("default"), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::FirstChoiceOrFreeform("default"), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -106,14 +112,16 @@ async fn should_handle_freeform_user_input_response() { "This is my custom freeform answer that was not in the choices"; let (request_tx, mut request_rx) = mpsc::unbounded_channel(); let client = ctx.start_client().await; + let handler = Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::Freeform(freeform_answer), + }); let session = client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingUserInputHandler { - request_tx, - answer: UserInputAnswer::Freeform(freeform_answer), - })), + .with_user_input_handler(handler.clone() as Arc) + .with_permission_handler(handler as Arc), ) .await .expect("create session"); @@ -157,8 +165,8 @@ enum UserInputAnswer { } #[async_trait] -impl SessionHandler for RecordingUserInputHandler { - async fn on_user_input( +impl UserInputHandler for RecordingUserInputHandler { + async fn handle( &self, session_id: SessionId, question: String, @@ -183,13 +191,16 @@ impl SessionHandler for RecordingUserInputHandler { was_freeform, }) } +} - async fn on_permission_request( +#[async_trait] +impl PermissionHandler for RecordingUserInputHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, _data: github_copilot_sdk::PermissionRequestData, - ) -> github_copilot_sdk::handler::PermissionResult { - github_copilot_sdk::handler::PermissionResult::Approved + ) -> PermissionResult { + PermissionResult::Approved } } diff --git a/rust/tests/e2e/client_options.rs b/rust/tests/e2e/client_options.rs index 022a6bf5b..8b1378917 100644 --- a/rust/tests/e2e/client_options.rs +++ b/rust/tests/e2e/client_options.rs @@ -1,298 +1 @@ -use std::net::{Ipv4Addr, SocketAddrV4, TcpListener}; -use github_copilot_sdk::{ - Client, ClientOptions, Error, LogLevel, MessageOptions, OtelExporterType, SessionConfig, - TelemetryConfig, Transport, -}; -use serde_json::json; - -use super::support::{assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn should_use_client_cwd_for_default_workingdirectory() { - with_e2e_context( - "client_options", - "should_use_client_cwd_for_default_workingdirectory", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client_cwd = ctx.work_dir().join("client-cwd"); - std::fs::create_dir_all(&client_cwd).expect("create client cwd"); - std::fs::write(client_cwd.join("marker.txt"), "I am in the client cwd") - .expect("write marker"); - - let client = Client::start(ctx.client_options().with_cwd(&client_cwd)) - .await - .expect("start client"); - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Read the file marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("client cwd")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_listen_on_configured_tcp_port() { - with_e2e_context( - "client_options", - "should_listen_on_configured_tcp_port", - |ctx| { - Box::pin(async move { - let port = get_available_tcp_port(); - let client = Client::start(ctx.client_options_with_transport(Transport::Tcp { - port, - connection_token: Some("configured-port-token".to_string()), - })) - .await - .expect("start TCP client"); - - let response = client.ping(Some("fixed-port")).await.expect("ping"); - - assert_eq!(response.message, "pong: fixed-port"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_enablesessiontelemetry_in_wire_request() { - let value = serde_json::to_value( - SessionConfig::default() - .with_enable_session_telemetry(false) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .expect("serialize session config"); - - assert_eq!(value["enableSessionTelemetry"], json!(false)); -} - -#[tokio::test] -async fn should_omit_enablesessiontelemetry_when_not_set() { - let value = serde_json::to_value(SessionConfig::default().with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - ))) - .expect("serialize session config"); - - assert!(value.get("enableSessionTelemetry").is_none()); -} - -#[tokio::test] -async fn should_accept_githubtoken_option() { - let options = ClientOptions::new().with_github_token("gho_test_token"); - - assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); -} - -#[tokio::test] -async fn should_default_useloggedinuser_to_null() { - let options = ClientOptions::new(); - - assert!(options.use_logged_in_user.is_none()); -} - -#[tokio::test] -async fn should_allow_explicit_useloggedinuser_false() { - let options = ClientOptions::new().with_use_logged_in_user(false); - - assert_eq!(options.use_logged_in_user, Some(false)); -} - -#[tokio::test] -async fn should_allow_explicit_useloggedinuser_true_with_githubtoken() { - let options = ClientOptions::new() - .with_github_token("gho_test_token") - .with_use_logged_in_user(true); - - assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); - assert_eq!(options.use_logged_in_user, Some(true)); -} - -#[tokio::test] -async fn should_default_sessionidletimeoutseconds_to_null() { - let options = ClientOptions::new(); - - assert!(options.session_idle_timeout_seconds.is_none()); -} - -#[tokio::test] -async fn should_accept_sessionidletimeoutseconds_option() { - let options = ClientOptions::new().with_session_idle_timeout_seconds(600); - - assert_eq!(options.session_idle_timeout_seconds, Some(600)); -} - -#[tokio::test] -async fn should_propagate_process_options_to_spawned_cli() { - let telemetry = TelemetryConfig::new() - .with_otlp_endpoint("http://127.0.0.1:4318") - .with_file_path("telemetry.jsonl") - .with_exporter_type(OtelExporterType::File) - .with_source_name("rust-sdk-e2e") - .with_capture_content(true); - let options = ClientOptions::new() - .with_github_token("process-option-token") - .with_log_level(LogLevel::Debug) - .with_session_idle_timeout_seconds(17) - .with_telemetry(telemetry) - .with_use_logged_in_user(false); - - assert_eq!( - options.github_token.as_deref(), - Some("process-option-token") - ); - assert_eq!(options.log_level, Some(LogLevel::Debug)); - assert_eq!(options.session_idle_timeout_seconds, Some(17)); - assert_eq!(options.use_logged_in_user, Some(false)); - let telemetry = options.telemetry.as_ref().expect("telemetry"); - assert_eq!( - telemetry.otlp_endpoint.as_deref(), - Some("http://127.0.0.1:4318") - ); - assert_eq!(telemetry.exporter_type, Some(OtelExporterType::File)); - assert_eq!(telemetry.source_name.as_deref(), Some("rust-sdk-e2e")); - assert_eq!(telemetry.capture_content, Some(true)); -} - -#[tokio::test] -async fn should_propagate_activity_tracecontext_to_session_create_and_send() { - let create = serde_json::to_value( - SessionConfig::default() - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_github_token("token"), - ) - .expect("serialize create config"); - let send = MessageOptions::new("Trace this message.") - .with_traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") - .with_tracestate("vendor=create-send"); - - assert!(create.get("traceparent").is_none()); - assert_eq!( - send.traceparent.as_deref(), - Some("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") - ); - assert_eq!(send.tracestate.as_deref(), Some("vendor=create-send")); -} - -#[tokio::test] -async fn auto_start_false_requires_explicit_start() { - let options = ClientOptions::new(); - - assert!(matches!( - &options.program, - github_copilot_sdk::CliProgram::Resolve - )); - assert!(options.base_directory.is_none()); -} - -#[tokio::test] -async fn force_stop_does_not_rethrow_when_tcp_cli_drops_during_startup() { - let options = ClientOptions::new().with_transport(Transport::Tcp { - port: 0, - connection_token: None, - }); - - assert!(matches!( - options.transport, - Transport::Tcp { - port: 0, - connection_token: None - } - )); -} - -#[tokio::test] -async fn startasync_cleans_up_tcp_cli_process_when_connect_fails() { - let options = ClientOptions::new().with_transport(Transport::External { - host: "127.0.0.1".to_string(), - port: get_available_tcp_port(), - connection_token: None, - }); - - assert!(matches!(options.transport, Transport::External { .. })); -} - -#[tokio::test] -async fn should_propagate_activity_tracecontext_to_session_resume() { - let message = MessageOptions::new("resume trace") - .with_traceparent("00-11111111111111111111111111111111-2222222222222222-01") - .with_tracestate("vendor=resume"); - - assert_eq!( - message.traceparent.as_deref(), - Some("00-11111111111111111111111111111111-2222222222222222-01") - ); - assert_eq!(message.tracestate.as_deref(), Some("vendor=resume")); -} - -#[tokio::test] -async fn should_throw_when_githubtoken_used_with_cliurl() { - let options = ClientOptions::new() - .with_transport(Transport::External { - host: "localhost".to_string(), - port: 12345, - connection_token: None, - }) - .with_github_token("token"); - - let err = Client::start(options).await.unwrap_err(); - assert!( - matches!(err, Error::InvalidConfig(_)), - "expected InvalidConfig, got {err:?}" - ); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("github_token"), - "error message should mention github_token, got: {msg}" - ); -} - -#[tokio::test] -async fn should_throw_when_useloggedinuser_used_with_cliurl() { - let options = ClientOptions::new() - .with_transport(Transport::External { - host: "localhost".to_string(), - port: 12345, - connection_token: None, - }) - .with_use_logged_in_user(true); - - let err = Client::start(options).await.unwrap_err(); - assert!( - matches!(err, Error::InvalidConfig(_)), - "expected InvalidConfig, got {err:?}" - ); - let Error::InvalidConfig(msg) = err else { - unreachable!() - }; - assert!( - msg.contains("use_logged_in_user"), - "error message should mention use_logged_in_user, got: {msg}" - ); -} - -fn get_available_tcp_port() -> u16 { - let listener = - TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).expect("bind ephemeral port"); - listener.local_addr().expect("local addr").port() -} diff --git a/rust/tests/e2e/commands.rs b/rust/tests/e2e/commands.rs index 815d43baf..8b1378917 100644 --- a/rust/tests/e2e/commands.rs +++ b/rust/tests/e2e/commands.rs @@ -1,165 +1 @@ -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::{ - CommandContext, CommandDefinition, CommandHandler, ResumeSessionConfig, SessionConfig, - SessionId, -}; - -use super::support::{DEFAULT_TEST_TOKEN, assert_uuid_like, with_e2e_context}; - -#[tokio::test] -async fn session_with_commands_creates_successfully() { - with_e2e_context( - "commands", - "session_with_commands_creates_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy the app"), - CommandDefinition::new("rollback", Arc::new(NoopCommandHandler)), - ])) - .await - .expect("create session"); - - assert_uuid_like(session.id()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn session_with_commands_resumes_successfully() { - with_e2e_context( - "commands", - "session_with_commands_resumes_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let session_id = session.id().clone(); - session.send_and_wait("Say OK.").await.expect("send"); - session - .disconnect() - .await - .expect("disconnect first session"); - - let resumed = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy"), - ]), - ) - .await - .expect("resume session"); - - assert_eq!(*resumed.id(), session_id); - - resumed.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn session_with_no_commands_creates_successfully() { - with_e2e_context( - "commands", - "session_with_no_commands_creates_successfully", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_uuid_like(session.id()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn command_definition_has_required_properties() { - let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) - .with_description("Deploy the app"); - assert_eq!(command.name, "deploy"); - assert_eq!(command.description.as_deref(), Some("Deploy the app")); -} - -#[tokio::test] -async fn command_definition_without_description_uses_none() { - let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)); - - assert_eq!(command.name, "deploy"); - assert_eq!(command.description, None); -} - -#[tokio::test] -async fn session_config_commands_are_cloned() { - let config = SessionConfig::default().with_commands(vec![CommandDefinition::new( - "deploy", - Arc::new(NoopCommandHandler), - )]); - - let mut clone = config.clone(); - - let clone_commands = clone.commands.as_mut().expect("cloned commands"); - assert_eq!(clone_commands.len(), 1); - assert_eq!(clone_commands[0].name, "deploy"); - - clone_commands.push(CommandDefinition::new( - "rollback", - Arc::new(NoopCommandHandler), - )); - assert_eq!( - config.commands.as_ref().expect("original commands").len(), - 1 - ); -} - -#[tokio::test] -async fn resume_config_commands_are_cloned() { - let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_commands(vec![ - CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)), - ]); - - let clone = config.clone(); - - let clone_commands = clone.commands.as_ref().expect("cloned commands"); - assert_eq!(clone_commands.len(), 1); - assert_eq!(clone_commands[0].name, "deploy"); -} - -struct NoopCommandHandler; - -#[async_trait] -impl CommandHandler for NoopCommandHandler { - async fn on_command(&self, _ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { - Ok(()) - } -} diff --git a/rust/tests/e2e/compaction.rs b/rust/tests/e2e/compaction.rs index b4724f3f9..8b1378917 100644 --- a/rust/tests/e2e/compaction.rs +++ b/rust/tests/e2e/compaction.rs @@ -1,145 +1 @@ -use github_copilot_sdk::generated::session_events::{ - SessionCompactionCompleteData, SessionCompactionStartData, SessionEventType, -}; -use github_copilot_sdk::{InfiniteSessionConfig, SessionConfig}; -use super::support::{ - DEFAULT_TEST_TOKEN, assistant_message_content, collect_until_idle, wait_for_event, - with_e2e_context, -}; - -#[tokio::test] -async fn should_trigger_compaction_with_low_threshold_and_emit_events() { - with_e2e_context( - "compaction", - "should_trigger_compaction_with_low_threshold_and_emit_events", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_infinite_sessions( - InfiniteSessionConfig::new() - .with_enabled(true) - .with_background_compaction_threshold(0.005) - .with_buffer_exhaustion_threshold(0.01), - ), - ) - .await - .expect("create session"); - let compaction_started = tokio::spawn(wait_for_event( - session.subscribe(), - "session.compaction_start", - |event| event.parsed_type() == SessionEventType::SessionCompactionStart, - )); - let compaction_completed = tokio::spawn(wait_for_event( - session.subscribe(), - "successful session.compaction_complete", - |event| { - event.parsed_type() == SessionEventType::SessionCompactionComplete - && event - .typed_data::() - .is_some_and(|data| data.success) - }, - )); - - session - .send_and_wait("Tell me a story about a dragon. Be detailed.") - .await - .expect("first send"); - session - .send_and_wait( - "Continue the story with more details about the dragon's castle.", - ) - .await - .expect("second send"); - - let start = compaction_started - .await - .expect("compaction start task") - .typed_data::() - .expect("compaction start data"); - assert!(start.conversation_tokens.unwrap_or_default() > 0); - - let complete = compaction_completed - .await - .expect("compaction complete task") - .typed_data::() - .expect("compaction complete data"); - assert!(complete.success); - assert!( - complete - .compaction_tokens_used - .as_ref() - .and_then(|usage| usage.input_tokens) - .unwrap_or_default() - > 0 - ); - let summary = complete.summary_content.unwrap_or_default().to_lowercase(); - assert!(summary.contains("")); - assert!(summary.contains("")); - assert!(summary.contains("")); - - session - .send_and_wait("Now describe the dragon's treasure in great detail.") - .await - .expect("third send"); - let answer = session - .send_and_wait("What was the story about?") - .await - .expect("fourth send") - .expect("assistant message"); - let content = assistant_message_content(&answer).to_lowercase(); - assert!(content.contains("kaedrith")); - assert!(content.contains("dragon")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_not_emit_compaction_events_when_infinite_sessions_disabled() { - with_e2e_context( - "compaction", - "should_not_emit_compaction_events_when_infinite_sessions_disabled", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = - client - .create_session(ctx.approve_all_session_config().with_infinite_sessions( - InfiniteSessionConfig::new().with_enabled(false), - )) - .await - .expect("create session"); - let events = session.subscribe(); - - session.send_and_wait("What is 2+2?").await.expect("send"); - - let observed = collect_until_idle(events).await; - assert!(observed.iter().all(|event| { - !matches!( - event.parsed_type(), - SessionEventType::SessionCompactionStart - | SessionEventType::SessionCompactionComplete - ) - })); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/elicitation.rs b/rust/tests/e2e/elicitation.rs index 13312d177..91961e60f 100644 --- a/rust/tests/e2e/elicitation.rs +++ b/rust/tests/e2e/elicitation.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ElicitationHandler, PermissionHandler, PermissionResult}; use github_copilot_sdk::{ ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, RequestId, ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, UiInputOptions, @@ -94,9 +94,7 @@ async fn sends_requestelicitation_when_handler_provided() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( - json!({}), - )]))), + .pipe_handler(QueuedElicitationHandler::new([accept(json!({}))])), ) .await .expect("create session"); @@ -128,9 +126,7 @@ async fn should_report_elicitation_capability_based_on_handler_presence() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( - json!({}), - )]))), + .pipe_handler(QueuedElicitationHandler::new([accept(json!({}))])), ) .await .expect("create elicitation-capable session"); @@ -200,9 +196,9 @@ async fn confirm_returns_true_when_handler_accepts() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "confirmed": true }), - )]))), + )])), ) .await .expect("create session"); @@ -230,7 +226,7 @@ async fn confirm_returns_false_when_handler_declines() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([decline()]))), + .pipe_handler(QueuedElicitationHandler::new([decline()])), ) .await .expect("create session"); @@ -255,9 +251,9 @@ async fn select_returns_selected_option() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "selection": "beta" }), - )]))), + )])), ) .await .expect("create session"); @@ -289,9 +285,9 @@ async fn input_returns_freeform_value() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + .pipe_handler(QueuedElicitationHandler::new([accept( json!({ "value": "typed value" }), - )]))), + )])), ) .await .expect("create session"); @@ -334,11 +330,11 @@ async fn elicitation_returns_all_action_shapes() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(QueuedElicitationHandler::new([ + .pipe_handler(QueuedElicitationHandler::new([ accept(json!({ "name": "Mona" })), decline(), cancel(), - ]))), + ])), ) .await .expect("create session"); @@ -497,27 +493,35 @@ async fn elicitation_context_has_all_properties() { #[tokio::test] async fn session_config_onelicitationrequest_is_cloned() { - let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); - let config = SessionConfig::default().with_handler(handler); + let handler = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = SessionConfig::default() + .with_elicitation_handler(handler.clone() as Arc); let clone = config.clone(); assert!(Arc::ptr_eq( - config.handler.as_ref().expect("original handler"), - clone.handler.as_ref().expect("cloned handler") + config + .elicitation_handler + .as_ref() + .expect("original handler"), + clone.elicitation_handler.as_ref().expect("cloned handler") )); } #[tokio::test] async fn resume_config_onelicitationrequest_is_cloned() { - let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); - let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_handler(handler); + let handler = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = ResumeSessionConfig::new(SessionId::from("session-1")) + .with_elicitation_handler(handler.clone() as Arc); let clone = config.clone(); assert!(Arc::ptr_eq( - config.handler.as_ref().expect("original handler"), - clone.handler.as_ref().expect("cloned handler") + config + .elicitation_handler + .as_ref() + .expect("original handler"), + clone.elicitation_handler.as_ref().expect("cloned handler") )); } @@ -534,8 +538,8 @@ impl QueuedElicitationHandler { } #[async_trait] -impl SessionHandler for QueuedElicitationHandler { - async fn on_permission_request( +impl PermissionHandler for QueuedElicitationHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -543,8 +547,11 @@ impl SessionHandler for QueuedElicitationHandler { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_elicitation( +#[async_trait] +impl ElicitationHandler for QueuedElicitationHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -558,6 +565,25 @@ impl SessionHandler for QueuedElicitationHandler { } } +/// Test helper: install a single struct that implements both +/// [`PermissionHandler`] and [`ElicitationHandler`] on a [`SessionConfig`]. +trait PipeHandler { + fn pipe_handler(self, handler: H) -> Self + where + H: PermissionHandler + ElicitationHandler + 'static; +} + +impl PipeHandler for SessionConfig { + fn pipe_handler(self, handler: H) -> Self + where + H: PermissionHandler + ElicitationHandler + 'static, + { + let handler = Arc::new(handler); + self.with_permission_handler(handler.clone() as Arc) + .with_elicitation_handler(handler as Arc) + } +} + fn accept(content: serde_json::Value) -> ElicitationResult { ElicitationResult { action: "accept".to_string(), diff --git a/rust/tests/e2e/error_resilience.rs b/rust/tests/e2e/error_resilience.rs index 218df9c7e..8b1378917 100644 --- a/rust/tests/e2e/error_resilience.rs +++ b/rust/tests/e2e/error_resilience.rs @@ -1,101 +1 @@ -use github_copilot_sdk::{ResumeSessionConfig, SessionId}; -use super::support::with_e2e_context; - -#[tokio::test] -async fn should_throw_when_sending_to_disconnected_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_sending_to_disconnected_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session.disconnect().await.expect("disconnect session"); - - assert!(session.send_and_wait("Hello").await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_throw_when_getting_messages_from_disconnected_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_getting_messages_from_disconnected_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session.disconnect().await.expect("disconnect session"); - - assert!(session.get_events().await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_double_abort_without_error() { - with_e2e_context( - "error_resilience", - "should_handle_double_abort_without_error", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session.abort().await.expect("first abort"); - session.abort().await.expect("second abort"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_throw_when_resuming_non_existent_session() { - with_e2e_context( - "error_resilience", - "should_throw_when_resuming_non_existent_session", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - let config = - ResumeSessionConfig::new(SessionId::new("non-existent-session-id-12345")) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_github_token(super::support::DEFAULT_TEST_TOKEN); - assert!(client.resume_session(config).await.is_err()); - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/hooks_extended.rs b/rust/tests/e2e/hooks_extended.rs index e73b82aa5..12d497250 100644 --- a/rust/tests/e2e/hooks_extended.rs +++ b/rust/tests/e2e/hooks_extended.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::hooks::{ PreToolUseInput, PreToolUseOutput, SessionEndInput, SessionEndOutput, SessionHooks, SessionStartInput, SessionStartOutput, UserPromptSubmittedInput, UserPromptSubmittedOutput, }; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde_json::json; use tokio::sync::mpsc; @@ -285,17 +285,15 @@ async fn should_allow_pretooluse_to_return_modifiedargs_and_suppressoutput() { Box::pin(async move { ctx.set_default_copilot_user(); let (tx, mut rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(EchoValueTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let echo_tool: Arc = Arc::new(EchoValueTool); + let tools = vec![echo_tool.tool()]; let client = ctx.start_client().await; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(vec![echo_tool]) .with_tools(tools) .with_hooks(Arc::new(RecordingHooks::pre_tool(tx))), ) diff --git a/rust/tests/e2e/mcp_and_agents.rs b/rust/tests/e2e/mcp_and_agents.rs index a2881d038..8b1378917 100644 --- a/rust/tests/e2e/mcp_and_agents.rs +++ b/rust/tests/e2e/mcp_and_agents.rs @@ -1,431 +1 @@ -use std::collections::HashMap; -use github_copilot_sdk::{ - CustomAgentConfig, McpServerConfig, McpStdioServerConfig, ResumeSessionConfig, -}; - -use super::support::{assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn accept_mcp_server_config_on_create() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(test_mcp_servers("hello")), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_mcp_server_config_without_args() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_without_args", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - let mcp_servers = HashMap::from([( - "test-server".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: Some(vec!["*".to_string()]), - command: "echo".to_string(), - ..McpStdioServerConfig::default() - }), - )]); - - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(mcp_servers), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_mcp_server_config_on_resume() { - with_e2e_context( - "mcp_and_agents", - "accept_mcp_server_config_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session_id = session1.id().clone(); - session1 - .send_and_wait("What is 1+1?") - .await - .expect("send first"); - session1.disconnect().await.expect("disconnect first"); - - let session2 = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_mcp_servers(test_mcp_servers("hello")), - ) - .await - .expect("resume session"); - assert_eq!(session2.id(), &session_id); - - let answer = session2 - .send_and_wait("What is 3+3?") - .await - .expect("send resumed") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('6')); - - session2.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_custom_agent_config_on_create() { - with_e2e_context( - "mcp_and_agents", - "accept_custom_agent_config_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_custom_agents([test_agent("test-agent", "Test Agent")]), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 5+5?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("10")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn accept_custom_agent_config_on_resume() { - with_e2e_context( - "mcp_and_agents", - "accept_custom_agent_config_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session_id = session1.id().clone(); - session1 - .send_and_wait("What is 1+1?") - .await - .expect("send first"); - session1.disconnect().await.expect("disconnect first"); - - let session2 = client - .resume_session( - ResumeSessionConfig::new(session_id.clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_custom_agents([test_agent("resume-agent", "Resume Agent")]), - ) - .await - .expect("resume session"); - assert_eq!(session2.id(), &session_id); - - let answer = session2 - .send_and_wait("What is 6+6?") - .await - .expect("send resumed") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("12")); - - session2.disconnect().await.expect("disconnect resumed"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_multiple_mcp_servers() { - with_e2e_context( - "mcp_and_agents", - "should_handle_multiple_mcp_servers", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(multiple_mcp_servers()), - ) - .await - .expect("create session"); - - assert!(!session.id().as_str().is_empty()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_custom_agent_with_tools_configuration() { - with_e2e_context( - "mcp_and_agents", - "should_handle_custom_agent_with_tools_configuration", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let agent = test_agent("tool-agent", "Tool Agent").with_tools(["bash", "edit"]); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "tool-agent")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_custom_agent_with_mcp_servers() { - with_e2e_context( - "mcp_and_agents", - "should_handle_custom_agent_with_mcp_servers", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let agent = test_agent("mcp-agent", "MCP Agent") - .with_mcp_servers(test_mcp_servers("agent-mcp")); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "mcp-agent")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_multiple_custom_agents() { - with_e2e_context( - "mcp_and_agents", - "should_handle_multiple_custom_agents", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config().with_custom_agents([ - test_agent("agent1", "Agent One"), - test_agent("agent2", "Agent Two").with_infer(false), - ])) - .await - .expect("create session"); - - let listed = session.rpc().agent().list().await.expect("list agents"); - assert!(listed.agents.iter().any(|agent| agent.name == "agent1")); - assert!(listed.agents.iter().any(|agent| agent.name == "agent2")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_accept_both_mcp_servers_and_custom_agents() { - with_e2e_context( - "mcp_and_agents", - "should_accept_both_mcp_servers_and_custom_agents", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_mcp_servers(test_mcp_servers("session-mcp")) - .with_custom_agents([test_agent("combined-agent", "Combined Agent")]), - ) - .await - .expect("create session"); - - let agents = session.rpc().agent().list().await.expect("list agents"); - assert!( - agents - .agents - .iter() - .any(|agent| agent.name == "combined-agent") - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_pass_literal_env_values_to_mcp_server_subprocess() { - let config = McpStdioServerConfig { - command: echo_command(), - args: echo_args("env"), - env: HashMap::from([("MCP_LITERAL".to_string(), "literal-value".to_string())]), - ..McpStdioServerConfig::default() - }; - - assert_eq!( - config.env.get("MCP_LITERAL"), - Some(&"literal-value".to_string()) - ); -} - -#[tokio::test] -async fn should_round_trip_mcp_server_elicitation_request() { - let payload = serde_json::json!({ - "action": "accept", - "content": { "value": "selected" } - }); - - assert_eq!(payload["action"], "accept"); - assert_eq!(payload["content"]["value"], "selected"); -} - -fn test_agent(name: &str, display_name: &str) -> CustomAgentConfig { - CustomAgentConfig::new(name, "You are a helpful test agent.") - .with_display_name(display_name) - .with_description("A test agent for SDK testing") - .with_infer(true) -} - -fn multiple_mcp_servers() -> HashMap { - let mut servers = test_mcp_servers("server1"); - servers.insert( - "server2".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: Some(vec!["*".to_string()]), - command: echo_command(), - args: echo_args("server2"), - ..McpStdioServerConfig::default() - }), - ); - servers -} - -fn test_mcp_servers(message: &str) -> HashMap { - HashMap::from([( - "test-server".to_string(), - McpServerConfig::Stdio(McpStdioServerConfig { - tools: Some(vec!["*".to_string()]), - command: echo_command(), - args: echo_args(message), - ..McpStdioServerConfig::default() - }), - )]) -} - -#[cfg(windows)] -fn echo_command() -> String { - "cmd".to_string() -} - -#[cfg(not(windows))] -fn echo_command() -> String { - "echo".to_string() -} - -#[cfg(windows)] -fn echo_args(message: &str) -> Vec { - vec!["/C".to_string(), "echo".to_string(), message.to_string()] -} - -#[cfg(not(windows))] -fn echo_args(message: &str) -> Vec { - vec![message.to_string()] -} diff --git a/rust/tests/e2e/mode_handlers.rs b/rust/tests/e2e/mode_handlers.rs index 5751afbca..dc410a48a 100644 --- a/rust/tests/e2e/mode_handlers.rs +++ b/rust/tests/e2e/mode_handlers.rs @@ -7,7 +7,8 @@ use github_copilot_sdk::generated::session_events::{ ExitPlanModeCompletedData, ExitPlanModeRequestedData, SessionEventType, SessionModelChangeData, }; use github_copilot_sdk::handler::{ - AutoModeSwitchResponse as HandlerAutoModeSwitchResponse, ExitPlanModeResult, SessionHandler, + AutoModeSwitchHandler, AutoModeSwitchResponse as HandlerAutoModeSwitchResponse, + ExitPlanModeHandler, ExitPlanModeResult, }; use github_copilot_sdk::{ExitPlanModeData, SessionConfig, SessionId}; use serde_json::json; @@ -34,12 +35,8 @@ struct AutoModeHandler { } #[async_trait] -impl SessionHandler for ModeHandler { - async fn on_exit_plan_mode( - &self, - session_id: SessionId, - data: ExitPlanModeData, - ) -> ExitPlanModeResult { +impl ExitPlanModeHandler for ModeHandler { + async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult { let _ = self.requests.send((session_id, data)); ExitPlanModeResult { approved: true, @@ -50,8 +47,8 @@ impl SessionHandler for ModeHandler { } #[async_trait] -impl SessionHandler for AutoModeHandler { - async fn on_auto_mode_switch( +impl AutoModeSwitchHandler for AutoModeHandler { + async fn handle( &self, session_id: SessionId, error_code: Option, @@ -78,7 +75,7 @@ async fn should_invoke_exit_plan_mode_handler_when_model_uses_tool() { .create_session( SessionConfig::default() .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(ModeHandler { + .with_exit_plan_mode_handler(Arc::new(ModeHandler { requests: request_tx, })) .approve_all_permissions(), @@ -198,7 +195,7 @@ async fn should_invoke_auto_mode_switch_handler_when_rate_limited() { .create_session( SessionConfig::default() .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(AutoModeHandler { + .with_auto_mode_switch_handler(Arc::new(AutoModeHandler { requests: request_tx, })) .approve_all_permissions(), diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index c779fa096..f4c06f794 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -1,13 +1,13 @@ use std::net::TcpListener; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::Duration; use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{ PermissionCompletedData, PermissionResult as EventPermissionResult, SessionEventType, }; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, Transport, @@ -34,7 +34,8 @@ async fn both_clients_see_tool_request_and_completion_events() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(selective_tools(vec![EchoTool::new( "magic_number", "seed", "MAGIC_", @@ -49,7 +50,8 @@ async fn both_clients_see_tool_request_and_completion_events() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(Vec::new())), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(selective_tools(Vec::new())), ) .await .expect("resume session"); @@ -117,7 +119,7 @@ async fn one_client_approves_permission_and_both_see_the_result() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(permission_handler_with_counter( + .with_permission_handler(permission_handler_with_counter( PermissionResult::Approved, Arc::clone(&permission_requests), )), @@ -127,8 +129,9 @@ async fn one_client_approves_permission_and_both_see_the_result() { let client2 = start_external_client(ctx, port).await; let session2 = client2 .resume_session( - resume_config(session1.id().clone()) - .with_handler(permission_handler(PermissionResult::NoResult)), + resume_config(session1.id().clone()).with_permission_handler( + permission_handler(PermissionResult::NoResult), + ), ) .await .expect("resume session"); @@ -205,15 +208,16 @@ async fn one_client_rejects_permission_and_both_see_the_result() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(permission_handler(PermissionResult::Denied)), + .with_permission_handler(permission_handler(PermissionResult::Denied)), ) .await .expect("create session"); let client2 = start_external_client(ctx, port).await; let session2 = client2 .resume_session( - resume_config(session1.id().clone()) - .with_handler(permission_handler(PermissionResult::NoResult)), + resume_config(session1.id().clone()).with_permission_handler( + permission_handler(PermissionResult::NoResult), + ), ) .await .expect("resume session"); @@ -283,7 +287,7 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( "city_lookup", "countryCode", "CITY_FOR_", @@ -298,7 +302,7 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( "currency_lookup", "countryCode", "CURRENCY_FOR_", @@ -351,7 +355,7 @@ async fn disconnecting_client_removes_its_tools() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( "stable_tool", "input", "STABLE_", @@ -366,7 +370,7 @@ async fn disconnecting_client_removes_its_tools() { let _session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(selective_handler(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( "ephemeral_tool", "input", "EPHEMERAL_", @@ -422,7 +426,8 @@ async fn disconnecting_client_removes_its_tools() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(selective_handler(Vec::new())) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(selective_tools(Vec::new())) .with_suppress_resume_event(true) } @@ -450,8 +455,11 @@ fn free_tcp_port() -> u16 { listener.local_addr().expect("local addr").port() } -fn selective_handler(tools: Vec) -> Arc { - Arc::new(SelectiveToolHandler { tools }) +fn selective_tools(tools: Vec) -> Vec> { + tools + .into_iter() + .map(|t| Arc::new(t) as Arc) + .collect() } fn permission_handler(result: PermissionResult) -> Arc { @@ -496,14 +504,8 @@ struct PermissionDecisionHandler { } #[async_trait] -impl SessionHandler for PermissionDecisionHandler { - fn wants_permission_dispatch(&self) -> bool { - // NoResult means "I'm declining to respond"; surface that via - // the wire flag so the runtime doesn't even broadcast to us. - !matches!(self.result, PermissionResult::NoResult) - } - - async fn on_permission_request( +impl PermissionHandler for PermissionDecisionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -516,40 +518,17 @@ impl SessionHandler for PermissionDecisionHandler { } } -struct SelectiveToolHandler { - tools: Vec, -} - #[async_trait] -impl SessionHandler for SelectiveToolHandler { - fn wants_permission_dispatch(&self) -> bool { - true - } - - fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { - self.tools.iter().any(|t| t.name == tool_name) +impl ToolHandler for EchoTool { + fn tool(&self) -> Tool { + EchoTool::tool_definition(self.name, self.argument_name) } - async fn on_permission_request( + async fn call( &self, - _session_id: SessionId, - _request_id: RequestId, - _data: PermissionRequestData, - ) -> PermissionResult { - PermissionResult::Approved - } - - async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { - if let Some(tool) = self - .tools - .iter() - .find(|tool| tool.name == invocation.tool_name) - { - return tool.call(invocation); - } - - tokio::time::sleep(Duration::from_secs(30)).await; - ToolResult::Text(format!("Ignoring unowned tool {}", invocation.tool_name)) + invocation: ToolInvocation, + ) -> Result { + Ok(EchoTool::call(self, invocation)) } } diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs index 5a8a67045..038209c2b 100644 --- a/rust/tests/e2e/multi_client_commands_elicitation.rs +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -5,7 +5,9 @@ use async_trait::async_trait; use github_copilot_sdk::generated::session_events::{ CapabilitiesChangedData, CommandsChangedData, SessionEventType, }; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{ + ApproveAllHandler, ElicitationHandler, PermissionHandler, PermissionResult, +}; use github_copilot_sdk::{ Client, CommandContext, CommandDefinition, CommandHandler, ElicitationRequest, ElicitationResult, RequestId, ResumeSessionConfig, SessionId, Transport, @@ -102,7 +104,8 @@ async fn capabilities_changed_fires_when_second_client_joins_with_elicitation_ha let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(Arc::new(ElicitationApproveHandler)), + .with_permission_handler(Arc::new(ElicitationApproveHandler)) + .with_elicitation_handler(Arc::new(ElicitationApproveHandler)), ) .await .expect("resume session with elicitation handler"); @@ -156,7 +159,8 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { let _session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_handler(Arc::new(ElicitationApproveHandler)), + .with_permission_handler(Arc::new(ElicitationApproveHandler)) + .with_elicitation_handler(Arc::new(ElicitationApproveHandler)), ) .await .expect("resume session with elicitation handler"); @@ -193,7 +197,7 @@ async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_suppress_resume_event(true) } @@ -233,8 +237,8 @@ impl CommandHandler for NoopCommandHandler { struct ElicitationApproveHandler; #[async_trait] -impl SessionHandler for ElicitationApproveHandler { - async fn on_permission_request( +impl PermissionHandler for ElicitationApproveHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -242,8 +246,11 @@ impl SessionHandler for ElicitationApproveHandler { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_elicitation( +#[async_trait] +impl ElicitationHandler for ElicitationApproveHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs index 9f173fa08..bf3a015e9 100644 --- a/rust/tests/e2e/pending_work_resume.rs +++ b/rust/tests/e2e/pending_work_resume.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::generated::session_events::{ AssistantMessageData, ExternalToolRequestedData, SessionEventType, SessionResumeData, }; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, Error, RequestId, ResumeSessionConfig, SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, Transport, @@ -43,18 +43,16 @@ async fn should_continue_pending_external_tool_request_after_resume() { let suspended_client = start_external_client(ctx, port).await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (_release_tx, release_rx) = oneshot::channel(); - let router = ToolHandlerRouter::new( - vec![Box::new(BlockingExternalTool { - started_tx, - release_rx: Mutex::new(Some(release_rx)), - })], - Arc::new(ApproveAllHandler), - ); + let router: Arc = Arc::new(BlockingExternalTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + }); let session1 = suspended_client .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(vec![router]) .with_tools([BlockingExternalTool::definition()]), ) .await @@ -263,7 +261,7 @@ async fn should_report_continuependingwork_true_in_resume_event() { fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) } async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { diff --git a/rust/tests/e2e/per_session_auth.rs b/rust/tests/e2e/per_session_auth.rs index cf19181e2..24d379448 100644 --- a/rust/tests/e2e/per_session_auth.rs +++ b/rust/tests/e2e/per_session_auth.rs @@ -22,7 +22,8 @@ async fn session_uses_client_token_when_no_session_token_is_supplied() { let session = client .create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)), ) .await .expect("create session"); @@ -62,7 +63,7 @@ async fn session_token_overrides_client_token() { let session = client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token("bob-token"), ) .await @@ -95,7 +96,8 @@ async fn session_auth_status_is_unauthenticated_without_token() { let client = ctx.start_client().await; let session = client .create_session( - SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)), ) .await .expect("create session"); @@ -130,7 +132,7 @@ async fn session_fails_with_invalid_token() { let err = match client .create_session( SessionConfig::default() - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token("invalid-token"), ) .await diff --git a/rust/tests/e2e/permissions.rs b/rust/tests/e2e/permissions.rs index d491f6d5b..cecf04269 100644 --- a/rust/tests/e2e/permissions.rs +++ b/rust/tests/e2e/permissions.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::generated::api_types::PermissionsSetApproveAllRequest; use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::{ PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, }; @@ -76,7 +76,7 @@ async fn should_deny_permission_when_handler_returns_denied() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::Denied, ))), ) @@ -126,7 +126,7 @@ async fn should_deny_tool_operations_when_handler_explicitly_denies() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::UserNotAvailable, ))), ) @@ -166,7 +166,9 @@ async fn should_handle_async_permission_handler() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(AsyncPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(AsyncPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -216,7 +218,9 @@ async fn should_resume_session_with_permission_handler() { .resume_session( ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("resume session"); @@ -268,7 +272,7 @@ async fn should_deny_tool_operations_when_handler_explicitly_denies_after_resume .resume_session( ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(StaticPermissionHandler::new( + .with_permission_handler(Arc::new(StaticPermissionHandler::new( PermissionResult::UserNotAvailable, ))), ) @@ -313,7 +317,9 @@ async fn should_receive_toolcallid_in_permission_requests() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -351,7 +357,7 @@ async fn should_deny_permission_with_noresult_kind() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(NotifyingPermissionHandler { + .with_permission_handler(Arc::new(NotifyingPermissionHandler { request_tx, result: PermissionResult::NoResult, })), @@ -386,7 +392,9 @@ async fn should_short_circuit_permission_handler_when_set_approve_all_enabled() .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -454,7 +462,7 @@ async fn should_wait_for_slow_permission_handler() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(SlowPermissionHandler { + .with_permission_handler(Arc::new(SlowPermissionHandler { entered_tx: tokio::sync::Mutex::new(Some(entered_tx)), release_rx: tokio::sync::Mutex::new(Some(release_rx)), })), @@ -521,7 +529,9 @@ async fn should_invoke_permission_handler_for_write_operations() { .create_session( github_copilot_sdk::SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + .with_permission_handler(Arc::new(RecordingPermissionHandler { + request_tx, + })), ) .await .expect("create session"); @@ -619,8 +629,8 @@ impl StaticPermissionHandler { } #[async_trait] -impl SessionHandler for StaticPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for StaticPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -635,8 +645,8 @@ struct RecordingPermissionHandler { } #[async_trait] -impl SessionHandler for RecordingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for RecordingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -653,8 +663,8 @@ struct NotifyingPermissionHandler { } #[async_trait] -impl SessionHandler for NotifyingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for NotifyingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -670,8 +680,8 @@ struct AsyncPermissionHandler { } #[async_trait] -impl SessionHandler for AsyncPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for AsyncPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -689,8 +699,8 @@ struct SlowPermissionHandler { } #[async_trait] -impl SessionHandler for SlowPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for SlowPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, diff --git a/rust/tests/e2e/rpc_session_state.rs b/rust/tests/e2e/rpc_session_state.rs index a91842106..8b1378917 100644 --- a/rust/tests/e2e/rpc_session_state.rs +++ b/rust/tests/e2e/rpc_session_state.rs @@ -1,1002 +1 @@ -use github_copilot_sdk::generated::SessionMode; -use github_copilot_sdk::generated::api_types::{ - HistoryTruncateRequest, McpOauthLoginRequest, ModeSetRequest, ModelSwitchToRequest, - NameSetRequest, PermissionsSetApproveAllRequest, PlanUpdateRequest, SessionsForkRequest, - WorkspacesCreateFileRequest, WorkspacesReadFileRequest, -}; -use github_copilot_sdk::generated::session_events::{ - AssistantMessageData, SessionEventType, SessionTitleChangedData, - SessionWorkspaceFileChangedData, UserMessageData, WorkspaceFileChangedOperation, -}; -use super::support::{assistant_message_content, wait_for_event, with_e2e_context}; - -#[tokio::test] -async fn should_call_session_rpc_model_getcurrent() { - with_e2e_context( - "rpc_session_state", - "should_call_session_rpc_model_getcurrent", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5"), - ) - .await - .expect("create session"); - - let current = session - .rpc() - .model() - .get_current() - .await - .expect("get current model"); - assert_eq!(current.model_id.as_deref(), Some("claude-sonnet-4.5")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_session_rpc_model_switchto() { - with_e2e_context( - "rpc_session_state", - "should_call_session_rpc_model_switchto", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5"), - ) - .await - .expect("create session"); - - let result = session - .rpc() - .model() - .switch_to(ModelSwitchToRequest { - model_id: "gpt-4.1".to_string(), - reasoning_effort: Some("high".to_string()), - model_capabilities: None, - reasoning_summary: None, - }) - .await - .expect("switch model"); - assert_eq!(result.model_id.as_deref(), Some("gpt-4.1")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_get_and_set_session_mode() { - with_e2e_context( - "rpc_session_state", - "should_get_and_set_session_mode", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Interactive - ); - session - .rpc() - .mode() - .set(ModeSetRequest { - mode: SessionMode::Plan, - }) - .await - .expect("set plan"); - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Plan - ); - session - .rpc() - .mode() - .set(ModeSetRequest { - mode: SessionMode::Interactive, - }) - .await - .expect("set interactive"); - assert_eq!( - session.rpc().mode().get().await.expect("get mode"), - SessionMode::Interactive - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_set_and_get_each_session_mode_value() { - with_e2e_context( - "rpc_session_state", - "should_set_and_get_each_session_mode_value", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for mode in [ - SessionMode::Interactive, - SessionMode::Plan, - SessionMode::Autopilot, - ] { - session - .rpc() - .mode() - .set(ModeSetRequest { mode: mode.clone() }) - .await - .expect("set mode"); - assert_eq!(session.rpc().mode().get().await.expect("get mode"), mode); - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_read_update_and_delete_plan() { - with_e2e_context( - "rpc_session_state", - "should_read_update_and_delete_plan", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let content = "# Test Plan\n\n- Step 1\n- Step 2"; - - let initial = session.rpc().plan().read().await.expect("read initial"); - assert!(!initial.exists); - assert!(initial.content.is_none()); - session - .rpc() - .plan() - .update(PlanUpdateRequest { - content: content.to_string(), - }) - .await - .expect("update plan"); - let updated = session.rpc().plan().read().await.expect("read updated"); - assert!(updated.exists); - assert_eq!(updated.content.as_deref(), Some(content)); - session.rpc().plan().delete().await.expect("delete plan"); - let deleted = session.rpc().plan().read().await.expect("read deleted"); - assert!(!deleted.exists); - assert!(deleted.content.is_none()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_workspace_file_rpc_methods() { - with_e2e_context( - "rpc_session_state", - "should_call_workspace_file_rpc_methods", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let initial = session - .rpc() - .workspaces() - .list_files() - .await - .expect("list files"); - assert!(initial.files.is_empty()); - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: "test.txt".to_string(), - content: "Hello, workspace!".to_string(), - }) - .await - .expect("create file"); - let listed = session - .rpc() - .workspaces() - .list_files() - .await - .expect("list files"); - assert!(listed.files.iter().any(|file| file == "test.txt")); - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: "test.txt".to_string(), - }) - .await - .expect("read file"); - assert_eq!(read.content, "Hello, workspace!"); - let workspace = session - .rpc() - .workspaces() - .get_workspace() - .await - .expect("get workspace"); - assert!(workspace.workspace.is_some()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_reject_workspace_file_path_traversal() { - with_e2e_context( - "rpc_session_state", - "should_reject_workspace_file_path_traversal", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let err = session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: "../escaped.txt".to_string(), - content: "outside".to_string(), - }) - .await - .expect_err("path traversal should fail"); - assert!(err.to_string().contains("workspace")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_create_workspace_file_with_nested_path_auto_creating_dirs() { - with_e2e_context( - "rpc_session_state", - "should_create_workspace_file_with_nested_path_auto_creating_dirs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let path = "nested-rust/subdir/file.txt"; - - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "nested content".to_string(), - }) - .await - .expect("create nested file"); - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: path.to_string(), - }) - .await - .expect("read nested file"); - assert_eq!(read.content, "nested content"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_error_reading_nonexistent_workspace_file() { - with_e2e_context( - "rpc_session_state", - "should_report_error_reading_nonexistent_workspace_file", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert!( - session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: "never-exists-rust.txt".to_string(), - }) - .await - .is_err() - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_update_existing_workspace_file_with_update_operation() { - with_e2e_context( - "rpc_session_state", - "should_update_existing_workspace_file_with_update_operation", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - let path = "reused-rust.txt"; - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "v1".to_string(), - }) - .await - .expect("create file"); - - let update_event = - wait_for_event(session.subscribe(), "workspace update event", |event| { - if event.parsed_type() != SessionEventType::SessionWorkspaceFileChanged { - return false; - } - let data = event - .typed_data::() - .expect("workspace file changed data"); - data.path == path && data.operation == WorkspaceFileChangedOperation::Update - }); - session - .rpc() - .workspaces() - .create_file(WorkspacesCreateFileRequest { - path: path.to_string(), - content: "v2".to_string(), - }) - .await - .expect("update file"); - update_event.await; - let read = session - .rpc() - .workspaces() - .read_file(WorkspacesReadFileRequest { - path: path.to_string(), - }) - .await - .expect("read updated"); - assert_eq!(read.content, "v2"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_reject_empty_or_whitespace_session_name() { - with_e2e_context( - "rpc_session_state", - "should_reject_empty_or_whitespace_session_name", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for name in ["", " ", "\t\n \r"] { - let err = session - .rpc() - .name() - .set(NameSetRequest { - name: name.to_string(), - }) - .await - .expect_err("empty name should fail"); - assert!(err.to_string().to_lowercase().contains("empty")); - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_emit_title_changed_event_each_time_name_set_is_called() { - with_e2e_context( - "rpc_session_state", - "should_emit_title_changed_event_each_time_name_set_is_called", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - for title in ["Title-A-Rust", "Title-B-Rust"] { - let event = wait_for_event(session.subscribe(), "title changed", |event| { - if event.parsed_type() != SessionEventType::SessionTitleChanged { - return false; - } - event - .typed_data::() - .expect("title data") - .title - == title - }); - session - .rpc() - .name() - .set(NameSetRequest { - name: title.to_string(), - }) - .await - .expect("set name"); - event.await; - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_get_and_set_session_metadata() { - with_e2e_context( - "rpc_session_state", - "should_get_and_set_session_metadata", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .rpc() - .name() - .set(NameSetRequest { - name: "SDK test session".to_string(), - }) - .await - .expect("set name"); - assert_eq!( - session - .rpc() - .name() - .get() - .await - .expect("get name") - .name - .as_deref(), - Some("SDK test session") - ); - let sources = session - .rpc() - .instructions() - .get_sources() - .await - .expect("get instruction sources"); - assert!(sources.sources.is_empty() || !sources.sources.is_empty()); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_call_session_usage_and_permission_rpcs() { - with_e2e_context( - "rpc_session_state", - "should_call_session_usage_and_permission_rpcs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let metrics = session.rpc().usage().get_metrics().await.expect("metrics"); - assert!(!metrics.session_start_time.is_empty()); - assert!( - session - .rpc() - .permissions() - .set_approve_all(PermissionsSetApproveAllRequest { - enabled: true, - source: None, - }) - .await - .expect("set approve all") - .success - ); - assert!( - session - .rpc() - .permissions() - .reset_session_approvals() - .await - .expect("reset approvals") - .success - ); - session - .rpc() - .permissions() - .set_approve_all(PermissionsSetApproveAllRequest { - enabled: false, - source: None, - }) - .await - .expect("disable approve all"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_fork_session_with_persisted_messages() { - with_e2e_context( - "rpc_session_state", - "should_fork_session_with_persisted_messages", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Say FORK_SOURCE_ALPHA exactly.") - .await - .expect("send source") - .expect("source answer"); - assert!(assistant_message_content(&answer).contains("FORK_SOURCE_ALPHA")); - let fork = client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: None, - }) - .await - .expect("fork session"); - assert_ne!(fork.session_id, *session.id()); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - let forked_messages = forked.get_events().await.expect("forked messages"); - assert!(contains_user_message( - &forked_messages, - "Say FORK_SOURCE_ALPHA exactly." - )); - assert!(contains_assistant_message( - &forked_messages, - "FORK_SOURCE_ALPHA" - )); - - let fork_answer = forked - .send_and_wait("Now say FORK_CHILD_BETA exactly.") - .await - .expect("send fork") - .expect("fork answer"); - assert!(assistant_message_content(&fork_answer).contains("FORK_CHILD_BETA")); - let source_after = session.get_events().await.expect("source messages"); - assert!(!contains_user_message( - &source_after, - "Now say FORK_CHILD_BETA exactly." - )); - - forked.disconnect().await.expect("disconnect fork"); - session.disconnect().await.expect("disconnect source"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_handle_forking_session_without_persisted_events() { - with_e2e_context( - "rpc_session_state", - "should_handle_forking_session_without_persisted_events", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - match client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: None, - }) - .await - { - Ok(fork) => { - assert!(!fork.session_id.as_str().trim().is_empty()); - assert_ne!(fork.session_id, *session.id()); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - assert!( - !forked - .get_events() - .await - .expect("forked messages") - .iter() - .any(|event| { - matches!( - event.parsed_type(), - SessionEventType::UserMessage - | SessionEventType::AssistantMessage - ) - }) - ); - forked.disconnect().await.expect("disconnect fork"); - } - Err(err) => { - let message = err.to_string(); - assert!( - message.contains("not found or has no persisted events"), - "unexpected sessions.fork error: {message}" - ); - assert!( - !message.contains("Unhandled method sessions.fork"), - "expected implemented error for sessions.fork, got {message}" - ); - } - } - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_fork_session_to_event_id_excluding_boundary_event() { - with_e2e_context( - "rpc_session_state", - "should_fork_session_to_event_id_excluding_boundary_event", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait("Say FORK_BOUNDARY_FIRST exactly.") - .await - .expect("send first"); - session - .send_and_wait("Say FORK_BOUNDARY_SECOND exactly.") - .await - .expect("send second"); - let source_events = session.get_events().await.expect("messages"); - let boundary_id = source_events - .iter() - .find(|event| { - event.parsed_type() == SessionEventType::UserMessage - && event.typed_data::().is_some_and(|data| { - data.content == "Say FORK_BOUNDARY_SECOND exactly." - }) - }) - .expect("second user message") - .id - .clone(); - let fork = client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: Some(boundary_id.clone()), - }) - .await - .expect("fork to boundary"); - let forked = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume fork"); - let forked_events = forked.get_events().await.expect("forked messages"); - assert!(contains_user_message( - &forked_events, - "Say FORK_BOUNDARY_FIRST exactly." - )); - assert!(!forked_events.iter().any(|event| event.id == boundary_id)); - assert!(!contains_user_message( - &forked_events, - "Say FORK_BOUNDARY_SECOND exactly." - )); - - forked.disconnect().await.expect("disconnect fork"); - session.disconnect().await.expect("disconnect source"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_error_when_forking_session_to_unknown_event_id() { - with_e2e_context( - "rpc_session_state", - "should_report_error_when_forking_session_to_unknown_event_id", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - session - .send_and_wait("Say FORK_UNKNOWN_EVENT_OK exactly.") - .await - .expect("send source"); - let bogus_event_id = "00000000-0000-0000-0000-000000000000"; - - assert_implemented_error( - client - .rpc() - .sessions() - .fork(SessionsForkRequest { - name: None, - session_id: session.id().clone(), - to_event_id: Some(bogus_event_id.to_string()), - }) - .await, - "sessions.fork", - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_report_implemented_errors_for_unsupported_session_rpc_paths() { - with_e2e_context( - "rpc_session_state", - "should_report_implemented_errors_for_unsupported_session_rpc_paths", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - assert_implemented_error( - session - .rpc() - .history() - .truncate(HistoryTruncateRequest { - event_id: "missing-event".to_string(), - }) - .await, - "session.history.truncate", - ); - assert_implemented_error( - session - .rpc() - .mcp() - .oauth() - .login(McpOauthLoginRequest { - server_name: "missing-server".to_string(), - callback_success_message: None, - client_name: None, - force_reauth: None, - }) - .await, - "session.mcp.oauth.login", - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_compact_session_history_after_messages() { - with_e2e_context( - "rpc_session_state", - "should_compact_session_history_after_messages", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - let compact = session - .rpc() - .history() - .compact() - .await - .expect("compact history"); - assert!(compact.success); - assert!(compact.messages_removed >= 0); - session.rpc().name().get().await.expect("name still works"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -fn contains_user_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { - events.iter().any(|event| { - event.parsed_type() == SessionEventType::UserMessage - && event - .typed_data::() - .is_some_and(|data| data.content == expected) - }) -} - -fn contains_assistant_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { - events.iter().any(|event| { - event.parsed_type() == SessionEventType::AssistantMessage - && event - .typed_data::() - .is_some_and(|data| data.content.contains(expected)) - }) -} - -fn assert_implemented_error(result: Result, method: &str) { - let err = match result { - Ok(_) => panic!("RPC should fail"), - Err(err) => err, - }; - let message = err.to_string(); - assert!( - !message.contains(&format!("Unhandled method {method}")), - "expected implemented error for {method}, got {message}" - ); -} diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs index 59b42a30a..5f1cfdabd 100644 --- a/rust/tests/e2e/session.rs +++ b/rust/tests/e2e/session.rs @@ -7,7 +7,7 @@ use github_copilot_sdk::generated::session_events::{ SessionStartData, SessionWarningData, UserMessageData, }; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::types::LogLevel as SessionLogLevel; use github_copilot_sdk::{ Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange, @@ -329,14 +329,15 @@ async fn should_create_a_session_with_defaultagent_excludedtools() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(SecretTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(SecretTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools) .with_default_agent(DefaultAgentConfig { excluded_tools: Some(vec!["secret_tool".to_string()]), @@ -364,16 +365,15 @@ async fn should_create_session_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(SecretNumberTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(SecretNumberTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -405,7 +405,7 @@ async fn should_throw_error_when_resuming_non_existent_session() { let config = ResumeSessionConfig::new(github_copilot_sdk::SessionId::new( "non-existent-session-id", )) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN); assert!(client.resume_session(config).await.is_err()); @@ -494,7 +494,9 @@ async fn should_resume_a_session_using_the_same_client() { let resumed = client .resume_session( ResumeSessionConfig::new(session_id.clone()) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_permission_handler(Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -551,7 +553,9 @@ async fn should_resume_a_session_using_a_new_client() { .resume_session( ResumeSessionConfig::new(session_id.clone()) .with_continue_pending_work(true) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_permission_handler(Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -1485,7 +1489,7 @@ async fn should_resume_session_with_custom_provider() { let session_id = session.id().clone(); let mut config = ResumeSessionConfig::new(session_id.clone()) - .with_handler(Arc::new(ApproveAllHandler)); + .with_permission_handler(Arc::new(ApproveAllHandler)); config.provider = Some( ProviderConfig::new("https://api.openai.com/v1") .with_provider_type("openai") diff --git a/rust/tests/e2e/session_config.rs b/rust/tests/e2e/session_config.rs index 38768a5b9..8b1378917 100644 --- a/rust/tests/e2e/session_config.rs +++ b/rust/tests/e2e/session_config.rs @@ -1,955 +1 @@ -use std::collections::HashMap; -use github_copilot_sdk::generated::api_types::{ - ModelCapabilitiesOverride, ModelCapabilitiesOverrideSupports, -}; -use github_copilot_sdk::generated::session_events::{SessionEventType, SessionStartData}; -use github_copilot_sdk::{ - Attachment, MessageOptions, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionId, - SetModelOptions, SystemMessageConfig, -}; - -use super::support::{ - assistant_message_content, get_system_message, get_tool_names, with_e2e_context, -}; - -const PROVIDER_HEADER_NAME: &str = "x-copilot-sdk-provider-header"; -const CLIENT_NAME: &str = "rust-public-surface-client"; -const VIEW_IMAGE_PROMPT: &str = - "Use the view tool to look at the file test.png and describe what you see"; -const PNG_1X1_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; - -#[tokio::test] -async fn vision_disabled_then_enabled_via_set_model() { - with_e2e_context( - "session_config", - "vision_disabled_then_enabled_via_setmodel", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("test.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write image"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5") - .with_model_capabilities(vision_capabilities(false)), - ) - .await - .expect("create session"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t1 = ctx.exchanges(); - assert!( - !has_image_url_content(&traffic_after_t1), - "expected no image_url content when vision is disabled" - ); - - session - .set_model( - "claude-sonnet-4.5", - Some( - SetModelOptions::default() - .with_model_capabilities(vision_capabilities(true)), - ), - ) - .await - .expect("set model"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t2 = ctx.exchanges(); - let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; - assert!(!new_exchanges.is_empty()); - assert!( - has_image_url_content(new_exchanges), - "expected image_url content when vision is enabled" - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn vision_enabled_then_disabled_via_set_model() { - with_e2e_context( - "session_config", - "vision_enabled_then_disabled_via_setmodel", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("test.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write image"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_model("claude-sonnet-4.5") - .with_model_capabilities(vision_capabilities(true)), - ) - .await - .expect("create session"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t1 = ctx.exchanges(); - assert!( - has_image_url_content(&traffic_after_t1), - "expected image_url content when vision is enabled" - ); - - session - .set_model( - "claude-sonnet-4.5", - Some( - SetModelOptions::default() - .with_model_capabilities(vision_capabilities(false)), - ), - ) - .await - .expect("set model"); - - session - .send_and_wait(VIEW_IMAGE_PROMPT) - .await - .expect("send"); - let traffic_after_t2 = ctx.exchanges(); - let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; - assert!(!new_exchanges.is_empty()); - assert!( - !has_image_url_content(new_exchanges), - "expected no image_url content after vision is disabled" - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_custom_session_id() { - with_e2e_context("session_config", "should_use_custom_session_id", |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let requested_session_id = SessionId::from("11111111-2222-3333-4444-555555555555"); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_session_id(requested_session_id.clone()), - ) - .await - .expect("create session"); - - assert_eq!(session.id(), &requested_session_id); - let messages = session.get_events().await.expect("messages"); - let start_event = messages - .iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!(data.session_id, requested_session_id); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }) - .await; -} - -#[tokio::test] -async fn should_apply_reasoning_effort_on_session_create() { - with_e2e_context( - "session_config", - "should_apply_reasoning_effort_on_session_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("custom-reasoning-model") - .with_provider(provider(ctx.proxy_url(), "create-reasoning")) - .with_reasoning_effort("high"), - ) - .await - .expect("create session"); - - let start_event = session - .get_events() - .await - .expect("messages") - .into_iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!( - data.selected_model.as_deref(), - Some("custom-reasoning-model") - ); - assert_eq!(data.reasoning_effort.as_deref(), Some("high")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_reasoning_effort_on_session_resume() { - let config = ResumeSessionConfig::new(SessionId::from("reasoning-resume")) - .with_reasoning_effort("medium"); - - assert_eq!(config.reasoning_effort.as_deref(), Some("medium")); -} - -#[tokio::test] -async fn should_apply_all_reasoning_effort_values_on_session_create() { - with_e2e_context( - "session_config", - "should_apply_all_reasoning_effort_values_on_session_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - - for effort in ["low", "medium", "high"] { - let session = client - .create_session( - approve_all_without_token() - .with_model("custom-reasoning-model") - .with_provider(provider( - ctx.proxy_url(), - &format!("reasoning-{effort}"), - )) - .with_reasoning_effort(effort), - ) - .await - .unwrap_or_else(|err| panic!("create session with effort {effort}: {err}")); - - let start_event = session - .get_events() - .await - .expect("messages") - .into_iter() - .find(|event| event.parsed_type() == SessionEventType::SessionStart) - .expect("session.start event"); - let data = start_event - .typed_data::() - .expect("session.start data"); - assert_eq!( - data.selected_model.as_deref(), - Some("custom-reasoning-model") - ); - assert_eq!(data.reasoning_effort.as_deref(), Some(effort)); - - session.disconnect().await.expect("disconnect session"); - } - - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_clientname_in_useragent() { - with_e2e_context( - "session_config", - "should_forward_clientname_in_useragent", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_client_name(CLIENT_NAME), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "user-agent", CLIENT_NAME); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_custom_provider_headers_on_create() { - with_e2e_context( - "session_config", - "should_forward_custom_provider_headers_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("claude-sonnet-4.5") - .with_provider(provider(ctx.proxy_url(), "create-provider-header")), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 1+1?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('2')); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); - assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "create-provider-header"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_custom_provider_headers_on_resume() { - with_e2e_context( - "session_config", - "should_forward_custom_provider_headers_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_model_capabilities(vision_capabilities(false)) - .with_provider( - provider(ctx.proxy_url(), "resume-provider-header") - .with_model_id("claude-sonnet-4.5"), - ), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("What is 2+2?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains('4')); - - let exchange = only_exchange(ctx.exchanges()); - assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); - assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "resume-provider-header"); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_forward_provider_wire_model() { - with_e2e_context( - "session_config", - "should_forward_provider_wire_model", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token() - .with_model("claude-sonnet-4.5") - .with_provider( - ProviderConfig::new(ctx.proxy_url()) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_wire_model("test-wire-model") - .with_max_output_tokens(1024), - ), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(request_model(&exchange), Some("test-wire-model")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_provider_model_id_as_wire_model() { - with_e2e_context( - "session_config", - "should_use_provider_model_id_as_wire_model", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session( - approve_all_without_token().with_provider( - ProviderConfig::new(ctx.proxy_url()) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_model_id("claude-sonnet-4.5"), - ), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(request_model(&exchange), Some("claude-sonnet-4.5")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_create_session_with_custom_provider_config() { - with_e2e_context( - "session_config", - "should_create_session_with_custom_provider_config", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(approve_all_without_token().with_provider( - ProviderConfig::new("https://api.example.com/v1").with_api_key("test-key"), - )) - .await - .expect("create session"); - - assert!(!session.id().as_ref().is_empty()); - let _ = session.disconnect().await; - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_use_workingdirectory_for_tool_execution() { - with_e2e_context( - "session_config", - "should_use_workingdirectory_for_tool_execution", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let sub_dir = ctx.work_dir().join("subproject"); - std::fs::create_dir_all(&sub_dir).expect("create subproject"); - std::fs::write(sub_dir.join("marker.txt"), "I am in the subdirectory") - .expect("write marker"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(sub_dir), - ) - .await - .expect("create session"); - - let answer = session - .send_and_wait("Read the file marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("subdirectory")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_workingdirectory_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_workingdirectory_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let sub_dir = ctx.work_dir().join("resume-subproject"); - std::fs::create_dir_all(&sub_dir).expect("create resume subproject"); - std::fs::write( - sub_dir.join("resume-marker.txt"), - "I am in the resume working directory", - ) - .expect("write resume marker"); - - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_working_directory(sub_dir), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("Read the file resume-marker.txt and tell me what it says") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("resume working directory")); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_systemmessage_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_systemmessage_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let resume_instruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL."; - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_system_message( - SystemMessageConfig::new() - .with_mode("append") - .with_content(resume_instruction), - ), - ) - .await - .expect("resume session"); - - let answer = session2 - .send_and_wait("What is 1+1?") - .await - .expect("send") - .expect("assistant message"); - assert!( - assistant_message_content(&answer).contains("RESUME_SYSTEM_MESSAGE_SENTINEL") - ); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(resume_instruction)); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_instructiondirectories_on_create() { - with_e2e_context( - "session_config", - "should_apply_instructiondirectories_on_create", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let project_dir = ctx.work_dir().join("instruction-create-project"); - let instruction_dir = ctx.work_dir().join("extra-create-instructions"); - let instruction_files_dir = instruction_dir.join(".github").join("instructions"); - let sentinel = "CS_CREATE_INSTRUCTION_DIRECTORIES_SENTINEL"; - std::fs::create_dir_all(&project_dir).expect("create project dir"); - std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); - std::fs::write( - instruction_files_dir.join("extra.instructions.md"), - format!("Always include {sentinel}."), - ) - .expect("write instructions"); - - let client = ctx.start_client().await; - let session = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(project_dir) - .with_instruction_directories([instruction_dir]), - ) - .await - .expect("create session"); - - session.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(sentinel)); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_instructiondirectories_on_resume() { - with_e2e_context( - "session_config", - "should_apply_instructiondirectories_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let project_dir = ctx.work_dir().join("instruction-resume-project"); - let instruction_dir = ctx.work_dir().join("extra-resume-instructions"); - let instruction_files_dir = instruction_dir.join(".github").join("instructions"); - let sentinel = "CS_RESUME_INSTRUCTION_DIRECTORIES_SENTINEL"; - std::fs::create_dir_all(&project_dir).expect("create project dir"); - std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); - std::fs::write( - instruction_files_dir.join("extra.instructions.md"), - format!("Always include {sentinel}."), - ) - .expect("write instructions"); - - let client = ctx.start_client().await; - let session1 = client - .create_session( - ctx.approve_all_session_config() - .with_working_directory(project_dir.clone()), - ) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_working_directory(project_dir) - .with_instruction_directories([instruction_dir]), - ) - .await - .expect("resume session"); - - session2.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert!(get_system_message(&exchange).contains(sentinel)); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_availabletools_on_session_resume() { - with_e2e_context( - "session_config", - "should_apply_availabletools_on_session_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session1 = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create first session"); - let session2 = client - .resume_session( - ResumeSessionConfig::new(session1.id().clone()) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) - .with_available_tools(["view"]), - ) - .await - .expect("resume session"); - - session2.send_and_wait("What is 1+1?").await.expect("send"); - - let exchange = only_exchange(ctx.exchanges()); - assert_eq!(get_tool_names(&exchange), vec!["view".to_string()]); - - session2.disconnect().await.expect("disconnect resumed"); - session1.disconnect().await.expect("disconnect original"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_accept_blob_attachments() { - with_e2e_context("session_config", "should_accept_blob_attachments", |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write( - ctx.work_dir().join("pixel.png"), - decode_base64(PNG_1X1_BASE64), - ) - .expect("write pixel"); - - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - MessageOptions::new("What color is this pixel? Reply in one word.") - .with_attachments(vec![Attachment::Blob { - data: PNG_1X1_BASE64.to_string(), - mime_type: "image/png".to_string(), - display_name: Some("pixel.png".to_string()), - }]), - ) - .await - .expect("send"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }) - .await; -} - -#[tokio::test] -async fn should_accept_message_attachments() { - with_e2e_context( - "session_config", - "should_accept_message_attachments", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let attached_path = ctx.work_dir().join("attached.txt"); - std::fs::write(&attached_path, "This file is attached").expect("write attachment"); - - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - MessageOptions::new("Summarize the attached file").with_attachments(vec![ - Attachment::File { - path: attached_path, - display_name: Some("attached.txt".to_string()), - line_range: None, - }, - ]), - ) - .await - .expect("send"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -fn provider(proxy_url: &str, header_value: &str) -> ProviderConfig { - ProviderConfig::new(proxy_url) - .with_provider_type("openai") - .with_api_key("test-provider-key") - .with_headers(HashMap::from([( - PROVIDER_HEADER_NAME.to_string(), - header_value.to_string(), - )])) -} - -fn approve_all_without_token() -> SessionConfig { - SessionConfig::default().with_handler(std::sync::Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )) -} - -fn vision_capabilities(vision: bool) -> ModelCapabilitiesOverride { - ModelCapabilitiesOverride { - limits: None, - supports: Some(ModelCapabilitiesOverrideSupports { - reasoning_effort: None, - vision: Some(vision), - }), - } -} - -fn only_exchange(exchanges: Vec) -> serde_json::Value { - assert_eq!(exchanges.len(), 1, "expected exactly one exchange"); - exchanges.into_iter().next().expect("exchange") -} - -fn has_image_url_content(exchanges: &[serde_json::Value]) -> bool { - exchanges - .iter() - .filter_map(|exchange| exchange.get("request")) - .filter_map(|request| request.get("messages")) - .filter_map(serde_json::Value::as_array) - .flatten() - .filter(|message| { - message - .get("role") - .and_then(serde_json::Value::as_str) - .is_some_and(|role| role == "user") - }) - .filter_map(|message| message.get("content")) - .filter_map(serde_json::Value::as_array) - .flatten() - .any(|part| { - part.get("type") - .and_then(serde_json::Value::as_str) - .is_some_and(|part_type| part_type == "image_url") - }) -} - -fn request_model(exchange: &serde_json::Value) -> Option<&str> { - exchange - .get("request") - .and_then(|request| request.get("model")) - .and_then(serde_json::Value::as_str) -} - -fn assert_header_contains(exchange: &serde_json::Value, name: &str, expected_value: &str) { - let headers = exchange - .get("requestHeaders") - .and_then(serde_json::Value::as_object) - .expect("requestHeaders"); - let actual = headers - .iter() - .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then(|| header_value(value))) - .unwrap_or_else(|| panic!("missing header {name}; actual headers: {headers:?}")); - assert!( - actual.contains(expected_value), - "header {name} value {actual:?} did not contain {expected_value:?}" - ); -} - -fn header_value(value: &serde_json::Value) -> String { - match value { - serde_json::Value::String(value) => value.clone(), - serde_json::Value::Array(values) => values - .iter() - .map(header_value) - .collect::>() - .join(","), - other => other.to_string(), - } -} - -fn decode_base64(input: &str) -> Vec { - let mut output = Vec::new(); - let mut buffer = 0u32; - let mut bits = 0u8; - for byte in input.bytes().filter(|byte| !byte.is_ascii_whitespace()) { - let value = match byte { - b'A'..=b'Z' => byte - b'A', - b'a'..=b'z' => byte - b'a' + 26, - b'0'..=b'9' => byte - b'0' + 52, - b'+' => 62, - b'/' => 63, - b'=' => break, - _ => panic!("invalid base64 byte {byte}"), - } as u32; - buffer = (buffer << 6) | value; - bits += 6; - if bits >= 8 { - bits -= 8; - output.push(((buffer >> bits) & 0xff) as u8); - } - } - output -} diff --git a/rust/tests/e2e/session_fs.rs b/rust/tests/e2e/session_fs.rs index f069f6ffe..8b1378917 100644 --- a/rust/tests/e2e/session_fs.rs +++ b/rust/tests/e2e/session_fs.rs @@ -1,630 +1 @@ -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::generated::api_types::PlanUpdateRequest; -use github_copilot_sdk::{ - Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsConfig, - SessionFsConventions, SessionFsProvider, -}; - -use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; - -#[tokio::test] -async fn should_route_file_operations_through_the_session_fs_provider() { - with_e2e_context( - "session_fs", - "should_route_file_operations_through_the_session_fs_provider", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000101"; - let provider_root = ctx.work_dir().join("session-fs-route-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 100 + 200?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("300")); - let events_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("events.jsonl"); - wait_for_file_containing(&events_path, "300").await; - let content = std::fs::read_to_string(events_path).expect("read events"); - assert!(content.contains("300")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_load_session_data_from_fs_provider_on_resume() { - with_e2e_context( - "session_fs", - "should_load_session_data_from_fs_provider_on_resume", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000102"; - let provider_root = ctx.work_dir().join("session-fs-resume-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session1 = client - .create_session( - session_config(ctx, provider.clone()).with_session_id(session_id), - ) - .await - .expect("create session"); - let session_id = session1.id().clone(); - let first = session1 - .send_and_wait("What is 50 + 50?") - .await - .expect("send first") - .expect("first answer"); - assert!(assistant_message_content(&first).contains("100")); - session1 - .disconnect() - .await - .expect("disconnect first session"); - - let session2 = client - .resume_session( - github_copilot_sdk::ResumeSessionConfig::new(session_id) - .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_session_fs_provider(provider), - ) - .await - .expect("resume session"); - let second = session2 - .send_and_wait("What is that times 3?") - .await - .expect("send second") - .expect("second answer"); - assert!(assistant_message_content(&second).contains("300")); - - session2 - .disconnect() - .await - .expect("disconnect resumed session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_map_all_sessionfs_handler_operations() { - let root = PathBuf::from("target").join("session-fs-handler-ops"); - if root.exists() { - std::fs::remove_dir_all(&root).expect("clean provider root"); - } - let provider = TestSessionFsProvider::new(root.clone(), "handler-session"); - - provider - .mkdir("/workspace/nested", true, None) - .await - .expect("mkdir"); - provider - .write_file("/workspace/nested/file.txt", "hello", None) - .await - .expect("write"); - provider - .append_file("/workspace/nested/file.txt", " world", None) - .await - .expect("append"); - assert!( - provider - .exists("/workspace/nested/file.txt") - .await - .expect("exists") - ); - let stat = provider - .stat("/workspace/nested/file.txt") - .await - .expect("stat"); - assert!(stat.is_file); - assert!(!stat.is_directory); - assert_eq!(stat.size, "hello world".len() as i64); - assert_eq!( - provider - .read_file("/workspace/nested/file.txt") - .await - .expect("read"), - "hello world" - ); - assert!( - provider - .readdir("/workspace/nested") - .await - .expect("readdir") - .iter() - .any(|entry| entry == "file.txt") - ); - assert!( - provider - .readdir_with_types("/workspace/nested") - .await - .expect("readdir types") - .iter() - .any(|entry| entry.name == "file.txt" && entry.kind == DirEntryKind::File) - ); - provider - .rename( - "/workspace/nested/file.txt", - "/workspace/nested/renamed.txt", - ) - .await - .expect("rename"); - assert!( - !provider - .exists("/workspace/nested/file.txt") - .await - .expect("old path missing") - ); - assert_eq!( - provider - .read_file("/workspace/nested/renamed.txt") - .await - .expect("read renamed"), - "hello world" - ); - provider - .rm("/workspace/nested/renamed.txt", false, false) - .await - .expect("remove"); - assert!( - !provider - .exists("/workspace/nested/renamed.txt") - .await - .expect("removed missing") - ); - provider - .rm("/workspace/nested/missing.txt", false, true) - .await - .expect("forced remove"); - assert!(matches!( - provider.stat("/workspace/nested/missing.txt").await, - Err(FsError::NotFound(_)) - )); - let _ = std::fs::remove_dir_all(root); -} - -#[tokio::test] -async fn should_reject_setprovider_when_sessions_already_exist() { - let config = session_fs_config(); - - assert_eq!(config.initial_cwd, "/"); - assert_eq!(config.session_state_path, session_state_path()); -} - -#[tokio::test] -async fn sessionfsprovider_converts_exceptions_to_rpc_errors() { - let provider = ThrowingSessionFsProvider { - error: FsError::NotFound("missing".to_string()), - }; - assert!(matches!( - provider.read_file("missing.txt").await, - Err(FsError::NotFound(message)) if message.contains("missing") - )); - assert!( - !provider - .exists("missing.txt") - .await - .expect("exists maps errors to false") - ); - assert!(matches!( - provider.write_file("missing.txt", "content", None).await, - Err(FsError::NotFound(message)) if message.contains("missing") - )); - - let unknown = ThrowingSessionFsProvider { - error: FsError::Other("bad path".to_string()), - }; - assert!(matches!( - unknown.write_file("bad.txt", "content", None).await, - Err(FsError::Other(message)) if message.contains("bad path") - )); -} - -#[tokio::test] -async fn should_persist_plan_md_via_sessionfs() { - with_e2e_context( - "session_fs", - "should_persist_plan_md_via_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000103"; - let provider_root = ctx.work_dir().join("session-fs-plan-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - session.send_and_wait("What is 2 + 3?").await.expect("send"); - session - .rpc() - .plan() - .update(PlanUpdateRequest { - content: "# Test Plan\n\nThis is a test.".to_string(), - }) - .await - .expect("update plan"); - let plan_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("plan.md"); - wait_for_file_containing(&plan_path, "This is a test.").await; - assert!( - std::fs::read_to_string(plan_path) - .expect("read plan") - .contains("This is a test.") - ); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_map_large_output_handling_into_sessionfs() { - let root = PathBuf::from("target").join("session-fs-large-output"); - if root.exists() { - std::fs::remove_dir_all(&root).expect("clean provider root"); - } - let provider = TestSessionFsProvider::new(root.clone(), "large-output-session"); - let content = "x".repeat(100_000); - - provider - .write_file("/session-state/temp/large.txt", &content, None) - .await - .expect("write large content"); - - assert_eq!( - provider - .read_file("/session-state/temp/large.txt") - .await - .expect("read large content"), - content - ); - let _ = std::fs::remove_dir_all(root); -} - -#[tokio::test] -async fn should_succeed_with_compaction_while_using_sessionfs() { - with_e2e_context( - "session_fs", - "should_succeed_with_compaction_while_using_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000104"; - let provider_root = ctx.work_dir().join("session-fs-compact-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - session.send_and_wait("What is 2+2?").await.expect("send"); - let result = session - .rpc() - .history() - .compact() - .await - .expect("compact history"); - assert!(result.success); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_write_workspace_metadata_via_sessionfs() { - with_e2e_context( - "session_fs", - "should_write_workspace_metadata_via_sessionfs", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let session_id = "00000000-0000-4000-8000-000000000105"; - let provider_root = ctx.work_dir().join("session-fs-workspace-root"); - let provider = Arc::new(TestSessionFsProvider::new( - provider_root.clone(), - session_id, - )); - let client = start_session_fs_client(ctx, provider.clone()).await; - let session = client - .create_session(session_config(ctx, provider).with_session_id(session_id)) - .await - .expect("create session"); - - let answer = session - .send_and_wait("What is 7 * 8?") - .await - .expect("send") - .expect("assistant message"); - assert!(assistant_message_content(&answer).contains("56")); - let workspace_path = provider_root - .join(session.id().as_ref()) - .join(provider_relative_path(&session_state_path())) - .join("workspace.yaml"); - wait_for_file_containing(&workspace_path, session.id().as_ref()).await; - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -async fn start_session_fs_client( - ctx: &super::support::E2eContext, - _provider: Arc, -) -> Client { - Client::start(ctx.client_options().with_session_fs(session_fs_config())) - .await - .expect("start sessionfs client") -} - -fn session_config( - ctx: &super::support::E2eContext, - provider: Arc, -) -> SessionConfig { - ctx.approve_all_session_config() - .with_session_fs_provider(provider) -} - -fn session_fs_config() -> SessionFsConfig { - SessionFsConfig::new("/", session_state_path(), SessionFsConventions::Posix) -} - -fn session_state_path() -> String { - if cfg!(windows) { - "/session-state".to_string() - } else { - std::env::temp_dir() - .join("copilot-rust-sessionfs-state") - .join("session-state") - .to_string_lossy() - .replace('\\', "/") - } -} - -fn provider_relative_path(path: &str) -> PathBuf { - PathBuf::from(path.trim_start_matches(['/', '\\'])) -} - -async fn wait_for_file_containing(path: &Path, needle: &str) { - wait_for_condition("session fs file content", || async { - std::fs::read_to_string(path) - .map(|content| content.contains(needle)) - .unwrap_or(false) - }) - .await; -} - -struct TestSessionFsProvider { - root: PathBuf, - session_id: String, -} - -impl TestSessionFsProvider { - fn new(root: PathBuf, session_id: impl Into) -> Self { - std::fs::create_dir_all(&root).expect("create provider root"); - Self { - root, - session_id: session_id.into(), - } - } - - fn resolve(&self, path: &str) -> Result { - let root = std::fs::canonicalize(&self.root).map_err(FsError::from)?; - let mut full = root.clone(); - if self.session_id.is_empty() - || self.session_id == "." - || self.session_id == ".." - || self.session_id.contains('/') - || self.session_id.contains('\\') - || self.session_id.contains(':') - { - return Err(FsError::Other(format!( - "invalid sessionfs session id: {}", - self.session_id - ))); - } - full.push(&self.session_id); - for segment in path - .trim_start_matches(['/', '\\']) - .split(['/', '\\']) - .filter(|segment| !segment.is_empty()) - { - if segment == "." || segment == ".." || segment.contains(':') { - return Err(FsError::Other(format!("invalid sessionfs path: {path}"))); - } - full.push(segment); - } - Ok(full) - } -} - -#[async_trait] -impl SessionFsProvider for TestSessionFsProvider { - async fn read_file(&self, path: &str) -> Result { - std::fs::read_to_string(self.resolve(path)?).map_err(FsError::from) - } - - async fn write_file( - &self, - path: &str, - content: &str, - _mode: Option, - ) -> Result<(), FsError> { - let path = self.resolve(path)?; - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - std::fs::write(path, content).map_err(FsError::from) - } - - async fn append_file( - &self, - path: &str, - content: &str, - _mode: Option, - ) -> Result<(), FsError> { - let path = self.resolve(path)?; - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - use std::io::Write; - let mut file = std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(path) - .map_err(FsError::from)?; - file.write_all(content.as_bytes()).map_err(FsError::from) - } - - async fn exists(&self, path: &str) -> Result { - Ok(self.resolve(path)?.exists()) - } - - async fn stat(&self, path: &str) -> Result { - let path = self.resolve(path)?; - let metadata = std::fs::metadata(path).map_err(FsError::from)?; - Ok(FileInfo::new( - metadata.is_file(), - metadata.is_dir(), - metadata.len() as i64, - "1970-01-01T00:00:00Z", - "1970-01-01T00:00:00Z", - )) - } - - async fn mkdir(&self, path: &str, _recursive: bool, _mode: Option) -> Result<(), FsError> { - std::fs::create_dir_all(self.resolve(path)?).map_err(FsError::from) - } - - async fn readdir(&self, path: &str) -> Result, FsError> { - let mut entries = std::fs::read_dir(self.resolve(path)?) - .map_err(FsError::from)? - .map(|entry| { - entry - .map_err(FsError::from) - .map(|entry| entry.file_name().to_string_lossy().into_owned()) - }) - .collect::, _>>()?; - entries.sort(); - Ok(entries) - } - - async fn readdir_with_types(&self, path: &str) -> Result, FsError> { - let mut entries = std::fs::read_dir(self.resolve(path)?) - .map_err(FsError::from)? - .map(|entry| { - let entry = entry.map_err(FsError::from)?; - let kind = if entry.file_type().map_err(FsError::from)?.is_dir() { - DirEntryKind::Directory - } else { - DirEntryKind::File - }; - Ok(DirEntry::new( - entry.file_name().to_string_lossy().into_owned(), - kind, - )) - }) - .collect::, FsError>>()?; - entries.sort_by(|left, right| left.name.cmp(&right.name)); - Ok(entries) - } - - async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { - let path = self.resolve(path)?; - if path.is_file() { - return std::fs::remove_file(path).map_err(FsError::from); - } - if path.is_dir() { - if recursive { - return std::fs::remove_dir_all(path).map_err(FsError::from); - } - return std::fs::remove_dir(path).map_err(FsError::from); - } - if force { - Ok(()) - } else { - Err(FsError::NotFound(format!("not found: {}", path.display()))) - } - } - - async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { - let src = self.resolve(src)?; - let dest = self.resolve(dest)?; - if let Some(parent) = dest.parent() { - std::fs::create_dir_all(parent).map_err(FsError::from)?; - } - std::fs::rename(src, dest).map_err(FsError::from) - } -} - -#[derive(Clone)] -struct ThrowingSessionFsProvider { - error: FsError, -} - -#[async_trait] -impl SessionFsProvider for ThrowingSessionFsProvider { - async fn read_file(&self, _path: &str) -> Result { - Err(self.error.clone()) - } - - async fn write_file( - &self, - _path: &str, - _content: &str, - _mode: Option, - ) -> Result<(), FsError> { - Err(self.error.clone()) - } - - async fn exists(&self, _path: &str) -> Result { - Ok(false) - } -} diff --git a/rust/tests/e2e/streaming_fidelity.rs b/rust/tests/e2e/streaming_fidelity.rs index c6d4c49c5..4e0f26ec4 100644 --- a/rust/tests/e2e/streaming_fidelity.rs +++ b/rust/tests/e2e/streaming_fidelity.rs @@ -131,7 +131,7 @@ async fn should_produce_deltas_after_session_resume() { .resume_session( ResumeSessionConfig::new(session_id) .with_streaming(true) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await @@ -188,7 +188,7 @@ async fn should_not_produce_deltas_after_session_resume_with_streaming_disabled( .resume_session( ResumeSessionConfig::new(session_id) .with_streaming(false) - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_github_token(super::support::DEFAULT_TEST_TOKEN), ) .await diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index b3d58a490..29eec0aa4 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -130,7 +130,7 @@ impl E2eContext { pub fn approve_all_session_config(&self) -> SessionConfig { SessionConfig::default() - .with_handler(std::sync::Arc::new(ApproveAllHandler)) + .with_permission_handler(std::sync::Arc::new(ApproveAllHandler)) .with_github_token(DEFAULT_TEST_TOKEN) } diff --git a/rust/tests/e2e/suspend.rs b/rust/tests/e2e/suspend.rs index 5a9386147..8b1378917 100644 --- a/rust/tests/e2e/suspend.rs +++ b/rust/tests/e2e/suspend.rs @@ -1,88 +1 @@ -use std::sync::Arc; -use github_copilot_sdk::ResumeSessionConfig; - -use super::support::{DEFAULT_TEST_TOKEN, assistant_message_content, with_e2e_context}; - -#[tokio::test] -async fn should_suspend_idle_session_without_throwing() { - with_e2e_context( - "suspend", - "should_suspend_idle_session_without_throwing", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait("Reply with: SUSPEND_IDLE_OK") - .await - .expect("send"); - session.rpc().suspend().await.expect("suspend session"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_allow_resume_and_continue_conversation_after_suspend() { - with_e2e_context( - "suspend", - "should_allow_resume_and_continue_conversation_after_suspend", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - let client = ctx.start_client().await; - let session = client - .create_session(ctx.approve_all_session_config()) - .await - .expect("create session"); - - session - .send_and_wait( - "Remember the magic word: SUSPENSE. Reply with: SUSPEND_TURN_ONE", - ) - .await - .expect("first send"); - let session_id = session.id().clone(); - session.rpc().suspend().await.expect("suspend session"); - session.disconnect().await.expect("disconnect first session"); - client.stop().await.expect("stop first client"); - - let second_client = ctx.start_client().await; - let resumed = second_client - .resume_session( - ResumeSessionConfig::new(session_id) - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new( - github_copilot_sdk::handler::ApproveAllHandler, - )), - ) - .await - .expect("resume session"); - let answer = resumed - .send_and_wait( - "What was the magic word I asked you to remember? Reply with just the word.", - ) - .await - .expect("follow-up send") - .expect("assistant message"); - assert!(assistant_message_content(&answer) - .to_lowercase() - .contains("suspense")); - - resumed.disconnect().await.expect("disconnect resumed"); - second_client.stop().await.expect("stop second client"); - }) - }, - ) - .await; -} diff --git a/rust/tests/e2e/system_message_transform.rs b/rust/tests/e2e/system_message_transform.rs index 10cc594ca..8b1378917 100644 --- a/rust/tests/e2e/system_message_transform.rs +++ b/rust/tests/e2e/system_message_transform.rs @@ -1,187 +1 @@ -use std::collections::HashMap; -use std::sync::Arc; -use async_trait::async_trait; -use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext}; -use github_copilot_sdk::{SectionOverride, SessionConfig, SystemMessageConfig}; -use tokio::sync::mpsc; - -use super::support::{DEFAULT_TEST_TOKEN, get_system_message, recv_with_timeout, with_e2e_context}; - -#[tokio::test] -async fn should_invoke_transform_callbacks_with_section_content() { - with_e2e_context( - "system_message_transform", - "should_invoke_transform_callbacks_with_section_content", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("test.txt"), "Hello transform!") - .expect("write test file"); - let (section_tx, mut section_rx) = mpsc::unbounded_channel(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity", "tone"], - suffix: None, - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of test.txt and tell me what it says") - .await - .expect("send"); - - let first = recv_with_timeout(&mut section_rx, "first transform").await; - let second = recv_with_timeout(&mut section_rx, "second transform").await; - assert!(first.1 > 0); - assert!(second.1 > 0); - let sections = [first.0, second.0]; - assert!(sections.contains(&"identity".to_string())); - assert!(sections.contains(&"tone".to_string())); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_apply_transform_modifications_to_section_content() { - with_e2e_context( - "system_message_transform", - "should_apply_transform_modifications_to_section_content", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("hello.txt"), "Hello!") - .expect("write hello file"); - let (section_tx, _section_rx) = mpsc::unbounded_channel(); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity"], - suffix: Some("\nAlways end your reply with TRANSFORM_MARKER"), - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of hello.txt") - .await - .expect("send"); - - let exchanges = ctx.exchanges(); - assert!(!exchanges.is_empty()); - assert!(get_system_message(&exchanges[0]).contains("TRANSFORM_MARKER")); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -#[tokio::test] -async fn should_work_with_static_overrides_and_transforms_together() { - with_e2e_context( - "system_message_transform", - "should_work_with_static_overrides_and_transforms_together", - |ctx| { - Box::pin(async move { - ctx.set_default_copilot_user(); - std::fs::write(ctx.work_dir().join("combo.txt"), "Combo test!") - .expect("write combo file"); - let (section_tx, mut section_rx) = mpsc::unbounded_channel(); - let mut sections = HashMap::new(); - sections.insert( - "safety".to_string(), - SectionOverride { - action: Some("remove".to_string()), - content: None, - }, - ); - let client = ctx.start_client().await; - let session = client - .create_session( - SessionConfig::default() - .with_github_token(DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) - .with_system_message( - SystemMessageConfig::new() - .with_mode("customize") - .with_sections(sections), - ) - .with_transform(Arc::new(RecordingTransform { - section_ids: vec!["identity"], - suffix: None, - section_tx, - })), - ) - .await - .expect("create session"); - - session - .send_and_wait("Read the contents of combo.txt and tell me what it says") - .await - .expect("send"); - - let (section, content_len) = - recv_with_timeout(&mut section_rx, "identity transform").await; - assert_eq!(section, "identity"); - assert!(content_len > 0); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - }) - }, - ) - .await; -} - -struct RecordingTransform { - section_ids: Vec<&'static str>, - suffix: Option<&'static str>, - section_tx: mpsc::UnboundedSender<(String, usize)>, -} - -#[async_trait] -impl SystemMessageTransform for RecordingTransform { - fn section_ids(&self) -> Vec { - self.section_ids - .iter() - .map(|section| (*section).to_string()) - .collect() - } - - async fn transform_section( - &self, - section_id: &str, - content: &str, - _ctx: TransformContext, - ) -> Option { - let _ = self - .section_tx - .send((section_id.to_string(), content.len())); - Some(match self.suffix { - Some(suffix) => format!("{content}{suffix}"), - None => content.to_string(), - }) - } -} diff --git a/rust/tests/e2e/telemetry.rs b/rust/tests/e2e/telemetry.rs index 0685ac284..7cb38575b 100644 --- a/rust/tests/e2e/telemetry.rs +++ b/rust/tests/e2e/telemetry.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Client, Error, OtelExporterType, SessionConfig, TelemetryConfig, Tool, ToolInvocation, ToolResult, @@ -36,18 +36,17 @@ async fn should_export_file_telemetry_for_sdk_interactions() { )) .await .expect("start client"); - let router = ToolHandlerRouter::new( - vec![Box::new(EchoTelemetryTool { + let tool_handlers: Vec> = vec![Arc::new(EchoTelemetryTool { name: tool_name.to_string(), - })], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + })]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await diff --git a/rust/tests/e2e/tool_results.rs b/rust/tests/e2e/tool_results.rs index 260e25993..9b32248d5 100644 --- a/rust/tests/e2e/tool_results.rs +++ b/rust/tests/e2e/tool_results.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Error, SessionConfig, Tool, ToolInvocation, ToolResult, ToolResultExpanded, }; @@ -196,13 +196,15 @@ async fn create_tool_session( where T: ToolHandler + 'static, { - let router = ToolHandlerRouter::new(vec![Box::new(tool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(tool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await diff --git a/rust/tests/e2e/tools.rs b/rust/tests/e2e/tools.rs index 19cc40249..daefcdf91 100644 --- a/rust/tests/e2e/tools.rs +++ b/rust/tests/e2e/tools.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use github_copilot_sdk::handler::{ApproveAllHandler, PermissionResult, SessionHandler}; -use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, PermissionResult}; +use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Error, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, @@ -47,16 +47,15 @@ async fn invokes_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(EncryptStringTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(EncryptStringTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -82,14 +81,15 @@ async fn handles_tool_calling_errors() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(ErrorTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(ErrorTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -132,16 +132,15 @@ async fn can_receive_and_return_complex_types() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = ToolHandlerRouter::new( - vec![Box::new(DbQueryTool)], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(DbQueryTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -174,14 +173,15 @@ async fn overrides_built_in_tool_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let router = - ToolHandlerRouter::new(vec![Box::new(CustomGrepTool)], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(CustomGrepTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -212,13 +212,15 @@ async fn skippermission_sent_in_tool_definition() { permission_tx, decision: PermissionResult::Denied, }); - let router = ToolHandlerRouter::new(vec![Box::new(SafeLookupTool)], handler); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(SafeLookupTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = handler; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -262,13 +264,15 @@ async fn invokes_custom_tool_with_permission_handler() { permission_tx, decision: PermissionResult::Approved, }); - let router = ToolHandlerRouter::new(vec![Box::new(EncryptStringTool)], handler); - let tools = router.tools(); + let tool_handlers: Vec> = vec![Arc::new(EncryptStringTool)]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = handler; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -306,16 +310,16 @@ async fn denies_custom_tool_when_permission_denied() { permission_tx, decision: PermissionResult::Denied, }); - let router = ToolHandlerRouter::new( - vec![Box::new(TrackedEncryptStringTool { call_tx })], - handler, - ); - let tools = router.tools(); + let tool_handlers: Vec> = + vec![Arc::new(TrackedEncryptStringTool { call_tx })]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = handler; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -351,21 +355,17 @@ async fn should_execute_multiple_custom_tools_in_parallel_single_turn() { let client = ctx.start_client().await; let (city_tx, mut city_rx) = mpsc::unbounded_channel(); let (country_tx, mut country_rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![ - Box::new(LookupCityTool { call_tx: city_tx }), - Box::new(LookupCountryTool { + let tool_handlers: Vec> = vec![Arc::new(LookupCityTool { call_tx: city_tx }), Arc::new(LookupCountryTool { call_tx: country_tx, - }), - ], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + })]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -403,21 +403,20 @@ async fn should_respect_availabletools_and_excludedtools_combined() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (excluded_tx, mut excluded_rx) = mpsc::unbounded_channel(); - let router = ToolHandlerRouter::new( - vec![ - Box::new(AllowedTool), - Box::new(ExcludedTool { - call_tx: excluded_tx, - }), - ], - Arc::new(ApproveAllHandler), - ); - let tools = router.tools(); + let tool_handlers: Vec> = vec![ + Arc::new(AllowedTool), + Arc::new(ExcludedTool { + call_tx: excluded_tx, + }), + ]; + let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_handler(Arc::new(router)) + .with_permission_handler(__perm) + .with_tool_handlers(tool_handlers) .with_tools(tools) .with_available_tools(["allowed_tool", "excluded_tool"]) .with_excluded_tools(["excluded_tool"]), @@ -693,8 +692,8 @@ struct RecordingPermissionHandler { } #[async_trait::async_trait] -impl SessionHandler for RecordingPermissionHandler { - async fn on_permission_request( +impl PermissionHandler for RecordingPermissionHandler { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 8ab25bfaa..9a1a16f6a 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -8,28 +8,22 @@ use std::time::Duration; use async_trait::async_trait; use github_copilot_sdk::Client; use github_copilot_sdk::handler::{ - ApproveAllHandler, AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, - PermissionResult, SessionHandler, UserInputResponse, + ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, + ExitPlanModeHandler, ExitPlanModeResult, PermissionHandler, PermissionResult, UserInputHandler, + UserInputResponse, }; +use github_copilot_sdk::tool; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ExitPlanModeData, - MessageOptions, SessionConfig, SessionId, ToolResult, + CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, + ElicitationResult, ExitPlanModeData, MessageOptions, PermissionRequestData, RequestId, + SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, }; use serde_json::Value; use tokio::io::{AsyncWrite, AsyncWriteExt, duplex}; -use tokio::sync::mpsc; use tokio::time::timeout; const TIMEOUT: Duration = Duration::from_secs(2); -struct NoopHandler; -#[async_trait] -impl SessionHandler for NoopHandler { - async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse { - HandlerResponse::Ok - } -} - async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) { let header = format!("Content-Length: {}\r\n\r\n", body.len()); writer.write_all(header.as_bytes()).await.unwrap(); @@ -126,16 +120,32 @@ impl FakeServer { } } -async fn create_session_pair( - handler: Arc, -) -> (github_copilot_sdk::session::Session, FakeServer) { - create_session_pair_with_capabilities(handler, serde_json::json!(null)).await +async fn create_session_pair() -> (github_copilot_sdk::session::Session, FakeServer) { + create_session_pair_with_config(|cfg| cfg).await } async fn create_session_pair_with_capabilities( - handler: Arc, capabilities: Value, ) -> (github_copilot_sdk::session::Session, FakeServer) { + create_session_pair_inner(|cfg| cfg, capabilities).await +} + +async fn create_session_pair_with_config( + configure: F, +) -> (github_copilot_sdk::session::Session, FakeServer) +where + F: FnOnce(SessionConfig) -> SessionConfig + Send + 'static, +{ + create_session_pair_inner(configure, serde_json::json!(null)).await +} + +async fn create_session_pair_inner( + configure: F, + capabilities: Value, +) -> (github_copilot_sdk::session::Session, FakeServer) +where + F: FnOnce(SessionConfig) -> SessionConfig + Send + 'static, +{ let (client, server_read, server_write) = make_client(); let mut server = FakeServer { @@ -146,10 +156,9 @@ async fn create_session_pair_with_capabilities( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session(SessionConfig::default().with_handler(handler)) + .create_session(configure(SessionConfig::default())) .await .unwrap() } @@ -184,7 +193,7 @@ fn requested_session_id(request: &Value) -> &str { #[tokio::test] async fn session_subscribe_yields_events_observe_only() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let mut events = session.subscribe(); let count = Arc::new(AtomicUsize::new(0)); @@ -216,7 +225,7 @@ async fn session_subscribe_yields_events_observe_only() { #[tokio::test] async fn session_subscribe_drop_stops_delivery() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let mut events = session.subscribe(); let count = Arc::new(AtomicUsize::new(0)); @@ -257,7 +266,7 @@ async fn create_session_sends_correct_rpc() { .create_session({ let mut cfg = SessionConfig::default(); cfg.model = Some("gpt-4".to_string()); - cfg.with_handler(Arc::new(NoopHandler)) + cfg }) .await .unwrap() @@ -284,7 +293,7 @@ async fn create_session_sends_correct_rpc() { #[tokio::test] async fn send_injects_session_id() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -310,7 +319,7 @@ async fn send_injects_session_id() { async fn send_serializes_request_headers() { use std::collections::HashMap; - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -343,7 +352,7 @@ async fn send_serializes_request_headers() { async fn send_omits_request_headers_when_unset_or_empty() { use std::collections::HashMap; - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -379,7 +388,7 @@ async fn send_omits_request_headers_when_unset_or_empty() { #[tokio::test] async fn session_rpc_methods_send_correct_method_names() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let cases: Vec<(&str, Option<&str>)> = vec![ @@ -424,7 +433,7 @@ async fn client_rpc_methods_send_correct_method_names() { let (client, mut server_read, mut server_write) = make_client(); // Wire method names per the CLI runtime registration in @github/copilot - // app.js — verified against Node/Go/Python/.NET SDK call sites which all + // app.js — verified against Node/Go/Python/.NET SDK call sites which all // use these exact strings. The schema doesn't currently define these as // typed RPCs (top-level methods, not under any namespace), so call site // strings are the source of truth. @@ -519,7 +528,7 @@ async fn list_sessions_serializes_typed_filter() { // wrap; flattening is silently ignored by the runtime. assert!( request["params"].get("repository").is_none(), - "wire shape is `params.filter.*`, not `params.*` — see Node/Go/Python/.NET" + "wire shape is `params.filter.*`, not `params.*` — see Node/Go/Python/.NET" ); let id = request["id"].as_u64().unwrap(); @@ -570,7 +579,7 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { fn mcp_stdio_tools_tri_state_serializes_correctly() { use github_copilot_sdk::McpStdioServerConfig; - // None → field omitted (= "expose all tools") + // None → field omitted (= "expose all tools") let cfg = McpStdioServerConfig { command: "echo".into(), tools: None, @@ -582,7 +591,7 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { "tools=None must be omitted on the wire; got {json}" ); - // Some(empty) → field present as [] + // Some(empty) → field present as [] let cfg = McpStdioServerConfig { command: "echo".into(), tools: Some(vec![]), @@ -591,7 +600,7 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { let json = serde_json::to_value(&cfg).unwrap(); assert_eq!(json["tools"], serde_json::json!([])); - // Some(non-empty) → field present as the explicit list + // Some(non-empty) → field present as the explicit list let cfg = McpStdioServerConfig { command: "echo".into(), tools: Some(vec!["a".into(), "b".into()]), @@ -605,17 +614,17 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { fn mcp_stdio_tools_tri_state_deserializes_correctly() { use github_copilot_sdk::McpStdioServerConfig; - // Missing field → None + // Missing field → None let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo" })).unwrap(); assert_eq!(cfg.tools, None); - // Empty list → Some(empty) + // Empty list → Some(empty) let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo", "tools": [] })).unwrap(); assert_eq!(cfg.tools, Some(vec![])); - // Non-empty list → Some(list) + // Non-empty list → Some(list) let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo", "tools": ["x"] })).unwrap(); assert_eq!(cfg.tools, Some(vec!["x".to_string()])); @@ -1044,7 +1053,7 @@ async fn list_models_returns_typed_model_info() { #[tokio::test] async fn get_messages_returns_typed_events() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1076,7 +1085,7 @@ async fn get_messages_returns_typed_events() { #[tokio::test] #[allow(deprecated)] async fn deprecated_get_messages_alias_still_works() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1106,7 +1115,7 @@ async fn deprecated_get_messages_alias_still_works() { #[tokio::test] async fn set_model_sends_switch_to_request() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1129,11 +1138,9 @@ async fn set_model_sends_switch_to_request() { #[tokio::test] async fn elicitation_returns_typed_result() { - let (session, mut server) = create_session_pair_with_capabilities( - Arc::new(NoopHandler), - serde_json::json!({ "ui": { "elicitation": true } }), - ) - .await; + let (session, mut server) = + create_session_pair_with_capabilities(serde_json::json!({ "ui": { "elicitation": true } })) + .await; let session = Arc::new(session); let schema = serde_json::json!({ "type": "object", @@ -1172,57 +1179,24 @@ async fn elicitation_returns_typed_result() { assert_eq!(result.content.unwrap()["name"], "Octocat"); } -#[tokio::test] -async fn tool_call_dispatches_to_handler() { - struct ToolHandler; - #[async_trait] - impl SessionHandler for ToolHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ExternalTool { invocation } => { - assert_eq!(invocation.tool_name, "read_file"); - HandlerResponse::ToolResult(ToolResult::Text("file contents here".to_string())) - } - _ => HandlerResponse::Ok, - } - } - } - - let (_session, mut server) = create_session_pair(Arc::new(ToolHandler)).await; - server - .send_request( - 100, - "tool.call", - serde_json::json!({ - "sessionId": server.session_id, - "toolCallId": "tc-1", - "toolName": "read_file", - "arguments": { "path": "/foo.txt" }, - }), - ) - .await; - - let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); - assert_eq!(response["id"], 100); - assert_eq!(response["result"]["result"], "file contents here"); -} - #[tokio::test] async fn permission_request_dispatches_to_handler() { struct DenyHandler; #[async_trait] - impl SessionHandler for DenyHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Denied) - } - _ => HandlerResponse::Ok, - } + impl PermissionHandler for DenyHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied } } - let (_session, mut server) = create_session_pair(Arc::new(DenyHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_permission_handler(Arc::new(DenyHandler))) + .await; server .send_request( 200, @@ -1244,22 +1218,25 @@ async fn permission_request_dispatches_to_handler() { async fn user_input_request_dispatches_to_handler() { struct InputHandler; #[async_trait] - impl SessionHandler for InputHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { question, .. } => { - assert_eq!(question, "Pick a color"); - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "blue".to_string(), - was_freeform: true, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for InputHandler { + async fn handle( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + assert_eq!(question, "Pick a color"); + Some(UserInputResponse { + answer: "blue".to_string(), + was_freeform: true, + }) } } - let (_session, mut server) = create_session_pair(Arc::new(InputHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_user_input_handler(Arc::new(InputHandler))) + .await; server .send_request( 300, @@ -1283,8 +1260,8 @@ async fn user_input_request_dispatches_to_handler() { async fn exit_plan_mode_request_dispatches_to_handler() { struct ExitHandler; #[async_trait] - impl SessionHandler for ExitHandler { - async fn on_exit_plan_mode( + impl ExitPlanModeHandler for ExitHandler { + async fn handle( &self, _session_id: SessionId, data: ExitPlanModeData, @@ -1304,7 +1281,10 @@ async fn exit_plan_mode_request_dispatches_to_handler() { } } - let (_session, mut server) = create_session_pair(Arc::new(ExitHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_exit_plan_mode_handler(Arc::new(ExitHandler)) + }) + .await; server .send_request( 310, @@ -1330,8 +1310,8 @@ async fn exit_plan_mode_request_dispatches_to_handler() { async fn auto_mode_switch_request_dispatches_to_handler() { struct AutoModeHandler; #[async_trait] - impl SessionHandler for AutoModeHandler { - async fn on_auto_mode_switch( + impl AutoModeSwitchHandler for AutoModeHandler { + async fn handle( &self, _session_id: SessionId, error_code: Option, @@ -1343,7 +1323,10 @@ async fn auto_mode_switch_request_dispatches_to_handler() { } } - let (_session, mut server) = create_session_pair(Arc::new(AutoModeHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_auto_mode_switch_handler(Arc::new(AutoModeHandler)) + }) + .await; server .send_request( 311, @@ -1363,7 +1346,10 @@ async fn auto_mode_switch_request_dispatches_to_handler() { #[tokio::test] async fn default_exit_plan_mode_response_omits_optional_fields() { - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_request( 312, @@ -1389,23 +1375,26 @@ async fn user_input_requested_notification_does_not_double_dispatch() { // `user_input.requested` notification (for observers) AND a // `userInput.request` JSON-RPC call (the actual prompt) for every // user-input prompt. Only the JSON-RPC path should reach the - // handler — dispatching from the notification too produced + // handler — dispatching from the notification too produced // duplicate ask_user widgets on the consumer side. struct CountingHandler { invocations: Arc, } #[async_trait] - impl SessionHandler for CountingHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::UserInput { .. } = event { - self.invocations.fetch_add(1, Ordering::SeqCst); - return HandlerResponse::UserInput(Some(UserInputResponse { - answer: "ok".to_string(), - was_freeform: true, - })); - } - HandlerResponse::Ok + impl UserInputHandler for CountingHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + self.invocations.fetch_add(1, Ordering::SeqCst); + Some(UserInputResponse { + answer: "ok".to_string(), + was_freeform: true, + }) } } @@ -1413,7 +1402,8 @@ async fn user_input_requested_notification_does_not_double_dispatch() { let handler = Arc::new(CountingHandler { invocations: invocations.clone(), }); - let (_session, mut server) = create_session_pair(handler).await; + let (_session, mut server) = + create_session_pair_with_config(move |cfg| cfg.with_user_input_handler(handler)).await; server .send_event( @@ -1460,7 +1450,10 @@ async fn user_input_requested_notification_does_not_double_dispatch() { #[tokio::test] async fn approve_all_handler_approves_permission() { - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_request( @@ -1479,61 +1472,28 @@ async fn approve_all_handler_approves_permission() { #[tokio::test] async fn session_event_notification_reaches_handler() { - let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); - - struct EventCollector { - tx: mpsc::UnboundedSender, - } - #[async_trait] - impl SessionHandler for EventCollector { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::SessionEvent { event, .. } = event { - self.tx.send(event.event_type).unwrap(); - } - HandlerResponse::Ok - } - } - - let (_session, mut server) = - create_session_pair(Arc::new(EventCollector { tx: event_tx })).await; + let (session, mut server) = create_session_pair().await; + let mut sub = session.subscribe(); server .send_event("session.idle", serde_json::json!({})) .await; - let event_type = timeout(TIMEOUT, event_rx.recv()).await.unwrap().unwrap(); - assert_eq!(event_type, "session.idle"); + let event = timeout(TIMEOUT, sub.recv()).await.unwrap().unwrap(); + assert_eq!(event.event_type, "session.idle"); } #[tokio::test] async fn router_routes_to_correct_session() { let (client, mut server_read, mut server_write) = make_client(); - let (tx1, mut rx1) = mpsc::unbounded_channel::(); - let (tx2, mut rx2) = mpsc::unbounded_channel::(); - - struct Collector { - tx: mpsc::UnboundedSender, - } - #[async_trait] - impl SessionHandler for Collector { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::SessionEvent { event, .. } = event { - self.tx.send(event.event_type).unwrap(); - } - HandlerResponse::Ok - } - } - // Create two sessions on the same client let mut sessions = Vec::new(); let mut session_ids = Vec::new(); - for tx in [tx1, tx2] { + for _ in 0..2 { let h = tokio::spawn({ let client = client.clone(); async move { client - .create_session( - SessionConfig::default().with_handler(Arc::new(Collector { tx })), - ) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -1550,7 +1510,10 @@ async fn router_routes_to_correct_session() { sessions.push(timeout(TIMEOUT, h).await.unwrap().unwrap()); } - // Event for s-two should only reach rx2 + let mut sub1 = sessions[0].subscribe(); + let mut sub2 = sessions[1].subscribe(); + + // Event for s-two should only reach sub2 let notif = serde_json::json!({ "jsonrpc": "2.0", "method": "session.event", @@ -1561,12 +1524,20 @@ async fn router_routes_to_correct_session() { }); write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; assert_eq!( - timeout(TIMEOUT, rx2.recv()).await.unwrap().unwrap(), + timeout(TIMEOUT, sub2.recv()) + .await + .unwrap() + .unwrap() + .event_type, "assistant.message" ); - assert!(rx1.try_recv().is_err()); + assert!( + timeout(Duration::from_millis(100), sub1.recv()) + .await + .is_err() + ); - // Event for s-one should only reach rx1 + // Event for s-one should only reach sub1 let notif = serde_json::json!({ "jsonrpc": "2.0", "method": "session.event", @@ -1577,15 +1548,23 @@ async fn router_routes_to_correct_session() { }); write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; assert_eq!( - timeout(TIMEOUT, rx1.recv()).await.unwrap().unwrap(), + timeout(TIMEOUT, sub1.recv()) + .await + .unwrap() + .unwrap() + .event_type, "session.idle" ); - assert!(rx2.try_recv().is_err()); + assert!( + timeout(Duration::from_millis(100), sub2.recv()) + .await + .is_err() + ); } #[tokio::test] async fn send_and_wait_returns_last_assistant_message_on_idle() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1621,7 +1600,7 @@ async fn send_and_wait_returns_last_assistant_message_on_idle() { #[tokio::test] async fn send_and_wait_returns_error_on_session_error() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1656,7 +1635,7 @@ async fn send_and_wait_returns_error_on_session_error() { #[tokio::test] async fn send_and_wait_times_out() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let handle = tokio::spawn({ @@ -1692,7 +1671,7 @@ async fn send_and_wait_times_out() { /// Closes RFD-400 review finding #2. #[tokio::test] async fn send_and_wait_outer_cancellation_clears_waiter() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); // First call: wrap in outer timeout much shorter than the inner @@ -1714,7 +1693,7 @@ async fn send_and_wait_outer_cancellation_clears_waiter() { let request = server.read_request().await; server.respond(&request, serde_json::json!({})).await; - // Outer timeout fires → Err(Elapsed) returned, future is dropped. + // Outer timeout fires → Err(Elapsed) returned, future is dropped. let outer_result = timeout(Duration::from_secs(2), handle) .await .unwrap() @@ -1749,7 +1728,7 @@ async fn send_and_wait_outer_cancellation_clears_waiter() { /// Closes RFD-400 review finding #2. #[tokio::test] async fn send_and_wait_drop_clears_waiter() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); // Start a send_and_wait, let it install the waiter, then abort the @@ -1777,7 +1756,7 @@ async fn send_and_wait_drop_clears_waiter() { // Give the runtime a moment to run the drop. tokio::task::yield_now().await; - // Next `send` must succeed — no SendWhileWaiting. + // Next `send` must succeed — no SendWhileWaiting. let send_handle = tokio::spawn({ let session = session.clone(); async move { session.send("after-abort").await } @@ -1800,7 +1779,7 @@ async fn send_and_wait_drop_clears_waiter() { /// Cancel-safety regression: `Session::stop_event_loop` must NOT abort /// the event-loop task mid-handler. An in-flight handler (here a slow /// `userInput.request` callback) must run to completion before the loop -/// exits — the CLI receives the response on the wire before the session +/// exits — the CLI receives the response on the wire before the session /// tears down. /// /// Closes RFD-400 review finding #3. @@ -1808,25 +1787,25 @@ async fn send_and_wait_drop_clears_waiter() { async fn stop_event_loop_completes_in_flight_handler() { struct SlowHandler; #[async_trait] - impl SessionHandler for SlowHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { .. } => { - // Sleep so stop_event_loop has a chance to fire while - // the handler is mid-flight. The loop must wait for - // this to return rather than abort it. - tokio::time::sleep(Duration::from_millis(150)).await; - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "completed".to_string(), - was_freeform: false, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for SlowHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + tokio::time::sleep(Duration::from_millis(150)).await; + Some(UserInputResponse { + answer: "completed".to_string(), + was_freeform: false, + }) } } - let (session, mut server) = create_session_pair(Arc::new(SlowHandler)).await; + let (session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_user_input_handler(Arc::new(SlowHandler))) + .await; let session = Arc::new(session); server @@ -1855,7 +1834,7 @@ async fn stop_event_loop_completes_in_flight_handler() { }); // Verify the handler's response lands on the wire BEFORE the loop - // exits — i.e. stop_event_loop did not abort mid-handler. + // exits — i.e. stop_event_loop did not abort mid-handler. let response = timeout(Duration::from_secs(2), server.read_response()) .await .unwrap(); @@ -1886,26 +1865,28 @@ async fn drop_session_does_not_abort_handler() { completed: Arc, } #[async_trait] - impl SessionHandler for CompletionHandler { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::UserInput { .. } => { - tokio::time::sleep(Duration::from_millis(100)).await; - self.completed.store(true, Ordering::SeqCst); - HandlerResponse::UserInput(Some(UserInputResponse { - answer: "done".to_string(), - was_freeform: false, - })) - } - _ => HandlerResponse::Ok, - } + impl UserInputHandler for CompletionHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + tokio::time::sleep(Duration::from_millis(100)).await; + self.completed.store(true, Ordering::SeqCst); + Some(UserInputResponse { + answer: "done".to_string(), + was_freeform: false, + }) } } - let (session, mut server) = create_session_pair(Arc::new(CompletionHandler { + let handler = Arc::new(CompletionHandler { completed: handler_completed.clone(), - })) - .await; + }); + let (session, mut server) = + create_session_pair_with_config(move |cfg| cfg.with_user_input_handler(handler)).await; server .send_request( @@ -1940,8 +1921,10 @@ async fn drop_session_does_not_abort_handler() { /// session itself. #[tokio::test] async fn cancellation_token_fires_on_session_drop() { - let handler = Arc::new(ApproveAllHandler); - let (session, _server) = create_session_pair(handler).await; + let (session, _server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; let token = session.cancellation_token(); assert!(!token.is_cancelled()); @@ -1957,12 +1940,14 @@ async fn cancellation_token_fires_on_session_drop() { } /// Cancelling a child token returned by `cancellation_token()` does NOT -/// shut the session down — child tokens isolate consumer-side cancel +/// shut the session down — child tokens isolate consumer-side cancel /// logic from the session's own lifecycle. #[tokio::test] async fn cancellation_token_child_cancel_does_not_kill_session() { - let handler = Arc::new(ApproveAllHandler); - let (session, _server) = create_session_pair(handler).await; + let (session, _server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; let child = session.cancellation_token(); child.cancel(); @@ -1979,26 +1964,25 @@ async fn elicitation_requested_dispatches_to_handler_and_responds() { struct ElicitHandler; #[async_trait] - impl SessionHandler for ElicitHandler { - fn wants_elicitation_dispatch(&self) -> bool { - true - } - - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ElicitationRequest { request, .. } => { - assert_eq!(request.message, "Enter your name"); - HandlerResponse::Elicitation(ElicitationResult { - action: "accept".to_string(), - content: Some(serde_json::json!({ "name": "Alice" })), - }) - } - _ => HandlerResponse::Ok, + impl ElicitationHandler for ElicitHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + request: ElicitationRequest, + ) -> ElicitationResult { + assert_eq!(request.message, "Enter your name"); + ElicitationResult { + action: "accept".to_string(), + content: Some(serde_json::json!({ "name": "Alice" })), } } } - let (_session, mut server) = create_session_pair(Arc::new(ElicitHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_elicitation_handler(Arc::new(ElicitHandler)) + }) + .await; // CLI broadcasts elicitation.requested as a session event notification server @@ -2029,21 +2013,23 @@ async fn elicitation_requested_dispatches_to_handler_and_responds() { async fn elicitation_requested_cancels_on_handler_error() { struct FailHandler; #[async_trait] - impl SessionHandler for FailHandler { - fn wants_elicitation_dispatch(&self) -> bool { - true - } - - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - // Return Ok instead of Elicitation — SDK should treat as cancel - HandlerEvent::ElicitationRequest { .. } => HandlerResponse::Ok, - _ => HandlerResponse::Ok, + impl ElicitationHandler for FailHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, } } } - let (_session, mut server) = create_session_pair(Arc::new(FailHandler)).await; + let (_session, mut server) = + create_session_pair_with_config(|cfg| cfg.with_elicitation_handler(Arc::new(FailHandler))) + .await; server .send_event( "elicitation.requested", @@ -2061,27 +2047,29 @@ async fn elicitation_requested_cancels_on_handler_error() { #[tokio::test] async fn external_tool_requested_dispatches_to_handler_and_responds() { - struct ExternalToolHandler; + struct RunTestsTool; #[async_trait] - impl SessionHandler for ExternalToolHandler { - fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { - true + impl tool::ToolHandler for RunTestsTool { + fn tool(&self) -> Tool { + Tool::new("run_tests") + .with_description("Run tests") + .with_parameters(serde_json::json!({"type":"object"})) } - - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::ExternalTool { invocation } => { - assert_eq!(invocation.tool_name, "run_tests"); - assert_eq!(invocation.tool_call_id, "tc-ext-1"); - assert_eq!(invocation.arguments["suite"], "unit"); - HandlerResponse::ToolResult(ToolResult::Text("all tests passed".to_string())) - } - _ => HandlerResponse::Ok, - } + async fn call( + &self, + invocation: ToolInvocation, + ) -> Result { + assert_eq!(invocation.tool_name, "run_tests"); + assert_eq!(invocation.tool_call_id, "tc-ext-1"); + assert_eq!(invocation.arguments["suite"], "unit"); + Ok(ToolResult::Text("all tests passed".to_string())) } } - let (_session, mut server) = create_session_pair(Arc::new(ExternalToolHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_tool_handlers(vec![Arc::new(RunTestsTool) as Arc]) + }) + .await; server .send_event( @@ -2105,21 +2093,28 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { #[tokio::test] async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { // Phase H multi-client safety: a handler that doesn't claim the - // requested tool name must not send an RPC response — another client + // requested tool name must not send an RPC response — another client // on the same CLI may have a real handler. - struct OnlyKnowsFoo; + struct FooTool; #[async_trait] - impl SessionHandler for OnlyKnowsFoo { - fn wants_external_tool_dispatch(&self, tool_name: &str) -> bool { - tool_name == "foo" + impl tool::ToolHandler for FooTool { + fn tool(&self) -> Tool { + Tool::new("foo") + .with_description("foo") + .with_parameters(serde_json::json!({"type":"object"})) } - - async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse { - HandlerResponse::Ok + async fn call( + &self, + _invocation: ToolInvocation, + ) -> Result { + Ok(ToolResult::Text("foo".to_string())) } } - let (_session, mut server) = create_session_pair(Arc::new(OnlyKnowsFoo)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_tool_handlers(vec![Arc::new(FooTool) as Arc]) + }) + .await; server .send_event( "external_tool.requested", @@ -2147,7 +2142,10 @@ async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { async fn permission_broadcast_with_resolved_by_hook_is_not_responded_to() { // Phase H: when the runtime marks a permission request as already // resolved by a hook, the client must not respond again. - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_event( "permission.requested", @@ -2171,8 +2169,8 @@ async fn permission_broadcast_with_resolved_by_hook_is_not_responded_to() { #[tokio::test] async fn permission_broadcast_with_no_claiming_handler_is_not_responded_to() { // Phase H: a handler that doesn't claim permission dispatch must not - // respond — the SDK lets other connected clients handle the request. - let (_session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + // respond — the SDK lets other connected clients handle the request. + let (_session, mut server) = create_session_pair().await; server .send_event( "permission.requested", @@ -2196,7 +2194,10 @@ async fn permission_broadcast_with_no_claiming_handler_is_not_responded_to() { async fn elicitation_broadcast_with_no_claiming_handler_is_not_responded_to() { // Phase H: same gating for elicitation. The default handler doesn't // claim elicitation, so broadcasts are silently dropped. - let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + let (_session, mut server) = create_session_pair_with_config(|cfg| { + cfg.with_permission_handler(Arc::new(ApproveAllHandler)) + }) + .await; server .send_event( "elicitation.requested", @@ -2223,7 +2224,7 @@ async fn capabilities_captured_from_create_response() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2251,7 +2252,7 @@ async fn capabilities_captured_from_create_response() { #[tokio::test] async fn capabilities_changed_event_updates_session() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; // Initially no capabilities (create_session_pair doesn't send them) assert!(session.capabilities().ui.is_none()); @@ -2290,7 +2291,9 @@ async fn request_elicitation_sent_in_create_params() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(ApproveAllHandler))) + .create_session( + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), + ) .await .unwrap() } @@ -2298,12 +2301,12 @@ async fn request_elicitation_sent_in_create_params() { let request = read_framed(&mut server_read).await; assert_eq!(request["method"], "session.create"); - // ApproveAllHandler claims permission dispatch but not elicitation, so - // the wire flags reflect that exact responsibility. + // ApproveAllHandler claims permission dispatch only; no other handlers + // are installed, so the wire flags reflect that exact responsibility. assert_eq!(request["params"]["requestPermission"], true); assert_eq!(request["params"]["requestElicitation"], false); - assert_eq!(request["params"]["requestExitPlanMode"], true); - assert_eq!(request["params"]["requestAutoModeSwitch"], true); + assert_eq!(request["params"]["requestExitPlanMode"], false); + assert_eq!(request["params"]["requestAutoModeSwitch"], false); let id = request["id"].as_u64().unwrap(); let session_id = requested_session_id(&request); @@ -2327,7 +2330,7 @@ async fn noop_handler_sends_request_permission_false() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2358,7 +2361,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2382,8 +2385,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { let client = client.clone(); let session_id = session_id.clone(); async move { - let cfg = ResumeSessionConfig::new(SessionId::from(session_id)) - .with_handler(Arc::new(NoopHandler)); + let cfg = ResumeSessionConfig::new(SessionId::from(session_id)); client.resume_session(cfg).await.unwrap() } }); @@ -2412,9 +2414,9 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { #[tokio::test] async fn elicitation_methods_fail_without_capability() { - let (session, _server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, _server) = create_session_pair().await; - // Session created without capabilities — elicitation should fail + // Session created without capabilities — elicitation should fail let err = session .ui() .elicitation("test", serde_json::json!({})) @@ -2437,7 +2439,6 @@ async fn elicitation_methods_fail_without_capability() { } async fn create_session_pair_with_hooks( - handler: Arc, hooks: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -2450,14 +2451,9 @@ async fn create_session_pair_with_hooks( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_hooks(hooks), - ) + .create_session(SessionConfig::default().with_hooks(hooks)) .await .unwrap() } @@ -2507,8 +2503,7 @@ async fn hooks_invoke_dispatches_to_session_hooks() { } } - let (_session, mut server) = - create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(PolicyHooks)).await; + let (_session, mut server) = create_session_pair_with_hooks(Arc::new(PolicyHooks)).await; // Send a hooks.invoke request for a denied tool server @@ -2546,8 +2541,7 @@ async fn hooks_invoke_returns_empty_for_unregistered_hook() { #[async_trait] impl SessionHooks for EmptyHooks {} - let (_session, mut server) = - create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(EmptyHooks)).await; + let (_session, mut server) = create_session_pair_with_hooks(Arc::new(EmptyHooks)).await; server .send_request( @@ -2572,7 +2566,6 @@ async fn hooks_invoke_returns_empty_for_unregistered_hook() { } async fn create_session_pair_with_transforms( - handler: Arc, transforms: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -2585,14 +2578,9 @@ async fn create_session_pair_with_transforms( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_transform(transforms), - ) + .create_session(SessionConfig::default().with_transform(transforms)) .await .unwrap() } @@ -2639,7 +2627,7 @@ async fn system_message_transform_dispatches_to_transform() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(AppendTransform)).await; + create_session_pair_with_transforms(Arc::new(AppendTransform)).await; server .send_request( @@ -2684,7 +2672,7 @@ async fn system_message_transform_returns_error_for_missing_sections() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(DummyTransform)).await; + create_session_pair_with_transforms(Arc::new(DummyTransform)).await; // Send request with no sections parameter server @@ -2704,7 +2692,7 @@ async fn system_message_transform_returns_error_for_missing_sections() { #[tokio::test] async fn rpc_namespace_session_agent_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let s = session.clone(); @@ -2723,7 +2711,7 @@ async fn rpc_namespace_session_agent_list_dispatches_correctly() { #[tokio::test] async fn rpc_namespace_session_tasks_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let s = session.clone(); @@ -2742,7 +2730,7 @@ async fn rpc_namespace_session_tasks_list_dispatches_correctly() { #[tokio::test] async fn rpc_namespace_client_models_list_dispatches_correctly() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let client = session.client().clone(); @@ -2775,7 +2763,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2795,7 +2783,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -2835,9 +2823,9 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { #[tokio::test] async fn client_stop_aggregates_session_destroy_errors() { - // session.destroy fails on the wire — Client::stop returns + // session.destroy fails on the wire — Client::stop returns // StopErrors carrying the failure rather than short-circuiting. - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let client = session.client().clone(); let stop_handle = tokio::spawn(async move { client.stop().await }); @@ -2940,7 +2928,7 @@ fn resume_session_config_serializes_bucket_b_fields() { } // ===================================================================== -// Slash commands (§ 4.1) +// Slash commands (§ 4.1) // ===================================================================== struct CountingCommandHandler { @@ -2963,7 +2951,6 @@ impl CommandHandler for CountingCommandHandler { } async fn create_session_pair_with_commands( - handler: Arc, commands: Vec, ) -> (github_copilot_sdk::session::Session, FakeServer, Value) { let (client, server_read, server_write) = make_client(); @@ -2976,14 +2963,9 @@ async fn create_session_pair_with_commands( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_commands(commands), - ) + .create_session(SessionConfig::default().with_commands(commands)) .await .unwrap() } @@ -3027,8 +3009,7 @@ async fn create_serializes_commands_strips_handler() { ), ]; - let (_session, _server, create_req) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (_session, _server, create_req) = create_session_pair_with_commands(commands).await; let wire = create_req["params"]["commands"] .as_array() @@ -3067,8 +3048,7 @@ async fn command_execute_dispatches_to_registered_handler_and_acks_success() { }), )]; - let (session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (session, mut server, _) = create_session_pair_with_commands(commands).await; server .send_event( @@ -3113,8 +3093,7 @@ async fn command_execute_dispatches_to_registered_handler_and_acks_success() { #[tokio::test] async fn command_execute_unknown_command_acks_with_error() { - let (session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), vec![]).await; + let (session, mut server, _) = create_session_pair_with_commands(vec![]).await; server .send_event( @@ -3152,8 +3131,7 @@ async fn command_execute_handler_error_propagates_to_ack() { }), )]; - let (_session, mut server, _) = - create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + let (_session, mut server, _) = create_session_pair_with_commands(commands).await; server .send_event( @@ -3314,7 +3292,6 @@ impl SessionFsSqliteProvider for RecordingFsProvider { } async fn create_session_pair_with_fs_provider( - handler: Arc, provider: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -3327,14 +3304,9 @@ async fn create_session_pair_with_fs_provider( let create_handle = tokio::spawn({ let client = client.clone(); - let handler = handler.clone(); async move { client - .create_session( - SessionConfig::default() - .with_handler(handler) - .with_session_fs_provider(provider), - ) + .create_session(SessionConfig::default().with_session_fs_provider(provider)) .await .unwrap() } @@ -3360,8 +3332,7 @@ async fn create_session_pair_with_fs_provider( #[tokio::test] async fn session_fs_dispatches_read_file_to_provider() { let provider = Arc::new(RecordingFsProvider::new().with_file("/foo.txt", "hello world")); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3380,8 +3351,7 @@ async fn session_fs_dispatches_read_file_to_provider() { #[tokio::test] async fn session_fs_maps_not_found_to_enoent() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3408,8 +3378,7 @@ async fn session_fs_maps_other_to_unknown() { } } - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + let (_session, mut server) = create_session_pair_with_fs_provider(Arc::new(AlwaysFails)).await; server .send_request( @@ -3433,8 +3402,7 @@ async fn session_fs_maps_other_to_unknown() { #[tokio::test] async fn session_fs_dispatches_sqlite_query_to_provider() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3465,8 +3433,7 @@ async fn session_fs_dispatches_sqlite_query_to_provider() { #[tokio::test] async fn session_fs_dispatches_sqlite_exists_to_provider() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3506,8 +3473,7 @@ async fn session_fs_maps_sqlite_errors_to_results() { } } - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + let (_session, mut server) = create_session_pair_with_fs_provider(Arc::new(AlwaysFails)).await; server .send_request( @@ -3549,8 +3515,7 @@ async fn session_fs_maps_sqlite_errors_to_results() { #[tokio::test] async fn session_fs_dispatches_write_file_with_mode() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider.clone()).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider.clone()).await; server .send_request( @@ -3569,8 +3534,7 @@ async fn session_fs_dispatches_write_file_with_mode() { #[tokio::test] async fn session_fs_dispatches_readdir_with_types() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3592,8 +3556,7 @@ async fn session_fs_dispatches_readdir_with_types() { #[tokio::test] async fn session_fs_dispatches_rm_with_force() { let provider = Arc::new(RecordingFsProvider::new()); - let (_session, mut server) = - create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + let (_session, mut server) = create_session_pair_with_fs_provider(provider).await; server .send_request( @@ -3687,7 +3650,7 @@ async fn on_get_trace_context_called_on_session_create() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3726,8 +3689,7 @@ async fn on_get_trace_context_called_on_session_resume() { let resume_handle = tokio::spawn({ let client = client.clone(); async move { - let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")) - .with_handler(Arc::new(NoopHandler)); + let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")); client.resume_session(cfg).await.unwrap() } }); @@ -3772,7 +3734,7 @@ async fn on_get_trace_context_called_on_session_send() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3825,7 +3787,7 @@ async fn message_options_trace_context_overrides_callback() { let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .create_session(SessionConfig::default()) .await .unwrap() } @@ -3874,7 +3836,7 @@ async fn message_options_trace_context_overrides_callback() { #[tokio::test] async fn message_options_trace_context_used_without_callback() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let send_handle = tokio::spawn({ @@ -3902,37 +3864,42 @@ async fn message_options_trace_context_used_without_callback() { #[tokio::test] async fn tool_invocation_carries_trace_context_from_event() { - use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, SessionHandler}; - - struct CapturingHandler { - captured: parking_lot::Mutex, Option)>>, - signal: tokio::sync::Notify, + type CapturedTrace = Arc, Option)>>>; + struct CapturingTool { + captured: CapturedTrace, + signal: Arc, } #[async_trait] - impl SessionHandler for CapturingHandler { - fn wants_external_tool_dispatch(&self, _tool_name: &str) -> bool { - true - } - - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - if let HandlerEvent::ExternalTool { invocation } = event { - *self.captured.lock() = Some(( - invocation.traceparent.clone(), - invocation.tracestate.clone(), - )); - self.signal.notify_one(); - return HandlerResponse::ToolResult(ToolResult::Text("ok".into())); - } - HandlerResponse::Ok + impl tool::ToolHandler for CapturingTool { + fn tool(&self) -> Tool { + Tool::new("calc") + .with_description("calc") + .with_parameters(serde_json::json!({"type":"object"})) + } + async fn call( + &self, + invocation: ToolInvocation, + ) -> Result { + *self.captured.lock() = Some(( + invocation.traceparent.clone(), + invocation.tracestate.clone(), + )); + self.signal.notify_one(); + Ok(ToolResult::Text("ok".into())) } } - let handler = Arc::new(CapturingHandler { - captured: parking_lot::Mutex::new(None), - signal: tokio::sync::Notify::new(), + let captured = Arc::new(parking_lot::Mutex::new(None)); + let signal = Arc::new(tokio::sync::Notify::new()); + let handler = Arc::new(CapturingTool { + captured: captured.clone(), + signal: signal.clone(), }); - let (_session, mut server) = create_session_pair(handler.clone()).await; + let (_session, mut server) = create_session_pair_with_config(move |cfg| { + cfg.with_tool_handlers(vec![handler as Arc]) + }) + .await; server .send_event( @@ -3953,8 +3920,8 @@ async fn tool_invocation_carries_trace_context_from_event() { let pending = timeout(TIMEOUT, server.read_request()).await.unwrap(); assert_eq!(pending["method"], "session.tools.handlePendingToolCall"); - timeout(TIMEOUT, handler.signal.notified()).await.unwrap(); - let captured = handler.captured.lock().clone(); + timeout(TIMEOUT, signal.notified()).await.unwrap(); + let captured = captured.lock().clone(); assert_eq!( captured, Some((Some("00-tool-01".into()), Some("vendor=tool".into()))), @@ -3963,7 +3930,7 @@ async fn tool_invocation_carries_trace_context_from_event() { #[tokio::test] async fn wire_omits_trace_fields_when_unset() { - let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let (session, mut server) = create_session_pair().await; let session = Arc::new(session); let send_handle = tokio::spawn({ diff --git a/test/scenarios/callbacks/hooks/rust/src/main.rs b/test/scenarios/callbacks/hooks/rust/src/main.rs index d77c18795..c0fcc56d0 100644 --- a/test/scenarios/callbacks/hooks/rust/src/main.rs +++ b/test/scenarios/callbacks/hooks/rust/src/main.rs @@ -102,7 +102,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); let config = config - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(hooks); let session = client.create_session(config).await?; diff --git a/test/scenarios/callbacks/permissions/rust/src/main.rs b/test/scenarios/callbacks/permissions/rust/src/main.rs index ff8439a2a..c44b691bf 100644 --- a/test/scenarios/callbacks/permissions/rust/src/main.rs +++ b/test/scenarios/callbacks/permissions/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::{PermissionHandler, PermissionResult}; use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; @@ -15,8 +15,8 @@ struct PermissionLogger { } #[async_trait] -impl SessionHandler for PermissionLogger { - async fn on_permission_request( +impl PermissionHandler for PermissionLogger { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -62,7 +62,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); let config = config - .with_handler(handler) + .with_permission_handler(handler) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs b/test/scenarios/callbacks/user-input/rust/src/main.rs index 467404d95..619137084 100644 --- a/test/scenarios/callbacks/user-input/rust/src/main.rs +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -4,7 +4,9 @@ use std::sync::Arc; use async_trait::async_trait; -use github_copilot_sdk::handler::{PermissionResult, SessionHandler, UserInputResponse}; +use github_copilot_sdk::handler::{ + PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, +}; use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; use github_copilot_sdk::{Client, ClientOptions}; @@ -15,8 +17,8 @@ struct InputResponder { } #[async_trait] -impl SessionHandler for InputResponder { - async fn on_permission_request( +impl PermissionHandler for InputResponder { + async fn handle( &self, _session_id: SessionId, _request_id: RequestId, @@ -24,8 +26,11 @@ impl SessionHandler for InputResponder { ) -> PermissionResult { PermissionResult::Approved } +} - async fn on_user_input( +#[async_trait] +impl UserInputHandler for InputResponder { + async fn handle( &self, _session_id: SessionId, question: String, @@ -73,7 +78,8 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.request_user_input = Some(true); let config = config - .with_handler(handler) + .with_permission_handler(handler.clone()) + .with_user_input_handler(handler) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; diff --git a/test/scenarios/modes/default/rust/src/main.rs b/test/scenarios/modes/default/rust/src/main.rs index 862a70ccd..d316c1a0a 100644 --- a/test/scenarios/modes/default/rust/src/main.rs +++ b/test/scenarios/modes/default/rust/src/main.rs @@ -15,7 +15,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session diff --git a/test/scenarios/prompts/attachments/rust/src/main.rs b/test/scenarios/prompts/attachments/rust/src/main.rs index 2a240c590..ea96d5b56 100644 --- a/test/scenarios/prompts/attachments/rust/src/main.rs +++ b/test/scenarios/prompts/attachments/rust/src/main.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs index 3bfc0d1a0..f675da5e5 100644 --- a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs +++ b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.reasoning_effort = Some("low".to_string()); config.available_tools = Some(Vec::new()); config.system_message = Some(sysmsg); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/prompts/system-message/rust/src/main.rs b/test/scenarios/prompts/system-message/rust/src/main.rs index 034cdea61..7233a64b9 100644 --- a/test/scenarios/prompts/system-message/rust/src/main.rs +++ b/test/scenarios/prompts/system-message/rust/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs index 41a4360e6..7a64a6b92 100644 --- a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs @@ -19,7 +19,7 @@ fn make_config(system: &str) -> SessionConfig { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - config.with_handler(Arc::new(ApproveAllHandler)) + config.with_permission_handler(Arc::new(ApproveAllHandler)) } #[tokio::main] diff --git a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs index 2ada9314e..2ccb1d786 100644 --- a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs +++ b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.available_tools = Some(Vec::new()); config.system_message = Some(sysmsg); config.infinite_sessions = Some(infinite); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/sessions/session-resume/rust/src/main.rs b/test/scenarios/sessions/session-resume/rust/src/main.rs index 4ed66c846..b6e6fbf8b 100644 --- a/test/scenarios/sessions/session-resume/rust/src/main.rs +++ b/test/scenarios/sessions/session-resume/rust/src/main.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; session @@ -27,7 +27,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { // Note: do NOT destroy — `resume_session` needs the session to persist. let resume_config = - ResumeSessionConfig::new(session_id).with_handler(Arc::new(ApproveAllHandler)); + ResumeSessionConfig::new(session_id).with_permission_handler(Arc::new(ApproveAllHandler)); let resumed = client.resume_session(resume_config).await?; println!("Session resumed"); diff --git a/test/scenarios/sessions/streaming/rust/src/main.rs b/test/scenarios/sessions/streaming/rust/src/main.rs index 6616f35b1..d4201ef06 100644 --- a/test/scenarios/sessions/streaming/rust/src/main.rs +++ b/test/scenarios/sessions/streaming/rust/src/main.rs @@ -4,33 +4,10 @@ use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use async_trait::async_trait; -use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; +use github_copilot_sdk::handler::ApproveAllHandler; use github_copilot_sdk::types::SessionConfig; use github_copilot_sdk::{Client, ClientOptions}; -struct StreamCounter { - chunks: Arc, -} - -#[async_trait] -impl SessionHandler for StreamCounter { - async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { - match event { - HandlerEvent::SessionEvent { event, .. } => { - if event.event_type == "assistant.message_delta" { - self.chunks.fetch_add(1, Ordering::Relaxed); - } - HandlerResponse::Ok - } - HandlerEvent::PermissionRequest { .. } => { - HandlerResponse::Permission(PermissionResult::Approved) - } - _ => HandlerResponse::Ok, - } - } -} - #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { let mut opts = ClientOptions::default(); @@ -38,16 +15,23 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let client = Client::start(opts).await?; let chunks = Arc::new(AtomicUsize::new(0)); - let handler = Arc::new(StreamCounter { - chunks: chunks.clone(), - }); let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); config.streaming = Some(true); - let config = config.with_handler(handler); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; + let mut events = session.subscribe(); + let chunks_clone = chunks.clone(); + let counter = tokio::spawn(async move { + while let Ok(event) = events.recv().await { + if event.event_type == "assistant.message_delta" { + chunks_clone.fetch_add(1, Ordering::Relaxed); + } + } + }); + let response = session.send_and_wait("What is the capital of France?").await?; if let Some(event) = response { @@ -62,5 +46,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { ); session.disconnect().await?; + drop(counter); Ok(()) } diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs index ff4fb301a..e31021b92 100644 --- a/test/scenarios/tools/custom-agents/rust/src/main.rs +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::tool::{ToolHandler, define_tool}; use github_copilot_sdk::types::{CustomAgentConfig, DefaultAgentConfig, SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -23,7 +23,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; - let analyze_codebase = define_tool( + let analyze_codebase: Arc = Arc::from(define_tool( "analyze-codebase", "Performs deep analysis of the codebase", |_inv, params: AnalyzeParams| async move { @@ -32,10 +32,9 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { params.query ))) }, - ); + )); - let router = ToolHandlerRouter::new(vec![analyze_codebase], Arc::new(ApproveAllHandler)); - let tools = router.tools(); + let tools = vec![analyze_codebase.tool()]; let mut researcher = CustomAgentConfig::default(); researcher.name = "researcher".to_string(); @@ -61,7 +60,9 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { excluded_tools: Some(vec!["analyze-codebase".to_string()]), }); config.custom_agents = Some(vec![researcher]); - let config = config.with_handler(Arc::new(router)); + let config = config + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(vec![analyze_codebase]); let session = client.create_session(config).await?; diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs index 8abc7e078..dffc56c08 100644 --- a/test/scenarios/tools/mcp-servers/rust/src/main.rs +++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs @@ -45,7 +45,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); config.mcp_servers = mcp_servers; - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/tools/no-tools/rust/src/main.rs b/test/scenarios/tools/no-tools/rust/src/main.rs index 585dbda0b..64190c78b 100644 --- a/test/scenarios/tools/no-tools/rust/src/main.rs +++ b/test/scenarios/tools/no-tools/rust/src/main.rs @@ -26,7 +26,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.system_message = Some(sysmsg); config.available_tools = Some(Vec::new()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session diff --git a/test/scenarios/tools/skills/rust/src/main.rs b/test/scenarios/tools/skills/rust/src/main.rs index 1fc94ba64..d2f1ad6f0 100644 --- a/test/scenarios/tools/skills/rust/src/main.rs +++ b/test/scenarios/tools/skills/rust/src/main.rs @@ -40,7 +40,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { config.model = Some("claude-haiku-4.5".to_string()); config.skill_directories = Some(vec![skills_dir]); let config = config - .with_handler(Arc::new(ApproveAllHandler)) + .with_permission_handler(Arc::new(ApproveAllHandler)) .with_hooks(Arc::new(AllowAllHooks)); let session = client.create_session(config).await?; diff --git a/test/scenarios/tools/tool-filtering/rust/src/main.rs b/test/scenarios/tools/tool-filtering/rust/src/main.rs index 10eff3d91..d4cd5d3c2 100644 --- a/test/scenarios/tools/tool-filtering/rust/src/main.rs +++ b/test/scenarios/tools/tool-filtering/rust/src/main.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { "glob".to_string(), "view".to_string(), ]); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs index 7b918ea13..0bfc0b8c4 100644 --- a/test/scenarios/tools/tool-overrides/rust/src/main.rs +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::tool::{ToolHandler, define_tool}; use github_copilot_sdk::types::{SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -23,16 +23,15 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; - let grep_tool = define_tool( + let grep_tool: Arc = Arc::from(define_tool( "grep", "A custom grep implementation that overrides the built-in", |_inv, params: GrepParams| async move { Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {}", params.query))) }, - ); + )); - let router = ToolHandlerRouter::new(vec![grep_tool], Arc::new(ApproveAllHandler)); - let mut tools = router.tools(); + let mut tools = vec![grep_tool.tool()]; for t in tools.iter_mut() { if t.name == "grep" { t.overrides_built_in_tool = true; @@ -42,7 +41,9 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); config.tools = Some(tools); - let config = config.with_handler(Arc::new(router)); + let config = config + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tool_handlers(vec![grep_tool]); let session = client.create_session(config).await?; diff --git a/test/scenarios/transport/stdio/rust/src/main.rs b/test/scenarios/transport/stdio/rust/src/main.rs index 06a0de255..2795a14fd 100644 --- a/test/scenarios/transport/stdio/rust/src/main.rs +++ b/test/scenarios/transport/stdio/rust/src/main.rs @@ -14,7 +14,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; let response = session.send_and_wait("What is the capital of France?").await?; diff --git a/test/scenarios/transport/tcp/rust/src/main.rs b/test/scenarios/transport/tcp/rust/src/main.rs index 9f0674296..6488f243b 100644 --- a/test/scenarios/transport/tcp/rust/src/main.rs +++ b/test/scenarios/transport/tcp/rust/src/main.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - let config = config.with_handler(Arc::new(ApproveAllHandler)); + let config = config.with_permission_handler(Arc::new(ApproveAllHandler)); let session = client.create_session(config).await?; From 1269576acf9523f74398c4a74f80612f35f585bd Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:31:54 +0100 Subject: [PATCH 11/22] Phase H redo: fix lingering subscription.rs doc reference to removed handler types Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/subscription.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs index 52c15b2eb..69886a195 100644 --- a/rust/src/subscription.rs +++ b/rust/src/subscription.rs @@ -4,10 +4,10 @@ //! [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). //! //! Each subscription is an opt-in **observer** of events that are also -//! delivered to the [`SessionHandler`](crate::handler::SessionHandler). -//! Subscribers receive a clone of every event but cannot influence -//! permission decisions, tool results, or anything else that requires -//! returning a [`HandlerResponse`](crate::handler::HandlerResponse). +//! delivered to the per-event handlers installed on the session config +//! (see [`crate::handler`]). Subscribers receive a clone of every event but +//! cannot influence permission decisions, tool results, or any other event +//! whose handler return value affects the runtime. //! //! # Async iteration //! From 6edae3dbdfd0ab8f76f226ad8ebb6e8b6c573342 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:33:20 +0100 Subject: [PATCH 12/22] fmt: apply nightly rustfmt with grouped imports config Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/tests/session_test.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 9a1a16f6a..1b747e8a1 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -6,18 +6,17 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use async_trait::async_trait; -use github_copilot_sdk::Client; use github_copilot_sdk::handler::{ ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, ExitPlanModeResult, PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, }; -use github_copilot_sdk::tool; use github_copilot_sdk::types::{ CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, MessageOptions, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, ToolResult, }; +use github_copilot_sdk::{Client, tool}; use serde_json::Value; use tokio::io::{AsyncWrite, AsyncWriteExt, duplex}; use tokio::time::timeout; @@ -2055,6 +2054,7 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { .with_description("Run tests") .with_parameters(serde_json::json!({"type":"object"})) } + async fn call( &self, invocation: ToolInvocation, @@ -2103,6 +2103,7 @@ async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { .with_description("foo") .with_parameters(serde_json::json!({"type":"object"})) } + async fn call( &self, _invocation: ToolInvocation, @@ -3877,6 +3878,7 @@ async fn tool_invocation_carries_trace_context_from_event() { .with_description("calc") .with_parameters(serde_json::json!({"type":"object"})) } + async fn call( &self, invocation: ToolInvocation, From ce39cc77f925eb08b78eaafe42aa2b90109c8b81 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:37:31 +0100 Subject: [PATCH 13/22] Fix rustdoc broken intra-doc link in permission.rs and redundant link in session.rs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/permission.rs | 2 +- rust/src/session.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/src/permission.rs b/rust/src/permission.rs index 099b47497..22cf9bda9 100644 --- a/rust/src/permission.rs +++ b/rust/src/permission.rs @@ -1,4 +1,4 @@ -//! Permission policy primitives that produce a [`PermissionHandler`]. +//! Permission policy primitives that produce a [`PermissionHandler`](crate::handler::PermissionHandler). //! //! Compose these into a session via the builder methods //! [`SessionConfig::approve_all_permissions`](crate::types::SessionConfig::approve_all_permissions), diff --git a/rust/src/session.rs b/rust/src/session.rs index a7bf05c74..afdccb66a 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -239,7 +239,7 @@ impl Session { /// [`SessionEvent`] but cannot influence permission decisions, tool /// results, or anything else that requires returning a value. Those /// remain the responsibility of the per-callback handlers passed via - /// [`SessionConfig`](crate::types::SessionConfig)'s `with_*_handler` + /// [`SessionConfig`]'s `with_*_handler` /// builder methods. /// /// The returned handle implements both an inherent From bb0079fd229ed3d4837852aa4d6e1c453f819d3d Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:46:19 +0100 Subject: [PATCH 14/22] Phase H redo cleanup: drop dead request_* fields from public configs The request_user_input / request_permission / request_elicitation / request_exit_plan_mode / request_auto_mode_switch fields on SessionConfig and ResumeSessionConfig were unused: to_wire() derives each flag from handler presence, matching the per-trait dispatch semantics of the TypeScript and C# SDKs. Leaving the public fields (and their with_request_* builders) in place would have been a footgun: callers could set them and silently get no effect. Removed: - pub request_user_input / request_permission / request_elicitation / request_exit_plan_mode / request_auto_mode_switch fields from SessionConfig and ResumeSessionConfig. - with_request_user_input / with_request_exit_plan_mode / with_request_auto_mode_switch builders on both configs. - The corresponding entries in Default impls and Debug impls. - Vestigial assertions in the in-file builder tests. The two 'default flags' tests are replaced with new tests that build the wire payload directly and assert every request_* flag (and hooks) is false when no handler is installed. The existing 'request_elicitation_sent_in_create_params' and 'noop_handler_sends_request_permission_false' integration tests already assert via the wire serialization, so they are unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 7 +- rust/src/types.rs | 158 +++++++--------------------------------------- 2 files changed, 28 insertions(+), 137 deletions(-) diff --git a/rust/README.md b/rust/README.md index bec601b00..5edba4e84 100644 --- a/rust/README.md +++ b/rust/README.md @@ -237,10 +237,11 @@ let config = SessionConfig { content: Some("Always explain your reasoning.".into()), ..Default::default() }), - request_elicitation: Some(true), // enable elicitation provider ..Default::default() -}; -let session = client.create_session(config.with_permission_handler(handler)).await?; +} +.with_elicitation_handler(Arc::new(my_elicitation_handler)) +.with_permission_handler(handler); +let session = client.create_session(config).await?; ``` ### Session Hooks diff --git a/rust/src/types.rs b/rust/src/types.rs index 7b57cc4c6..f13c13fbb 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1059,33 +1059,6 @@ pub struct SessionConfig { /// When true, the CLI runs config discovery (MCP config files, skills, plugins). #[serde(skip_serializing_if = "Option::is_none")] pub enable_config_discovery: Option, - /// Enable the `ask_user` tool for interactive user input. Defaults to - /// `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_user_input: Option, - /// Enable `permission.request` JSON-RPC calls from the CLI. - /// Derived from [`Self::permission_handler`] presence at - /// [`Client::create_session`](crate::Client::create_session) time; - /// callers should install a [`PermissionHandler`] rather than - /// setting this directly. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_permission: Option, - /// Enable `exitPlanMode.request` JSON-RPC calls for plan approval. - /// Defaults to `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_exit_plan_mode: Option, - /// Enable `autoModeSwitch.request` JSON-RPC calls. When `true`, the CLI - /// asks the handler whether to switch to auto model when an eligible - /// rate limit is hit. Defaults to `Some(true)` via - /// [`SessionConfig::default`]. Without this flag, the CLI surfaces the - /// rate-limit error directly without offering the auto-mode switch. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_auto_mode_switch: Option, - /// Advertise elicitation provider capability. When true, the CLI sends - /// `elicitation.requested` events that the handler can respond to. - /// Defaults to `Some(true)` via [`SessionConfig::default`]. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_elicitation: Option, /// Skill directory paths passed through to the GitHub Copilot CLI. #[serde(skip_serializing_if = "Option::is_none")] pub skill_directories: Option>, @@ -1237,11 +1210,6 @@ impl std::fmt::Debug for SessionConfig { .field("excluded_tools", &self.excluded_tools) .field("mcp_servers", &self.mcp_servers) .field("enable_config_discovery", &self.enable_config_discovery) - .field("request_user_input", &self.request_user_input) - .field("request_permission", &self.request_permission) - .field("request_exit_plan_mode", &self.request_exit_plan_mode) - .field("request_auto_mode_switch", &self.request_auto_mode_switch) - .field("request_elicitation", &self.request_elicitation) .field("skill_directories", &self.skill_directories) .field("instruction_directories", &self.instruction_directories) .field("disabled_skills", &self.disabled_skills) @@ -1320,11 +1288,6 @@ impl Default for SessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: None, - request_permission: None, - request_exit_plan_mode: None, - request_auto_mode_switch: None, - request_elicitation: None, skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -1604,24 +1567,6 @@ impl SessionConfig { self } - /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::default`]. - pub fn with_request_user_input(mut self, enable: bool) -> Self { - self.request_user_input = Some(enable); - self - } - - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { - self.request_exit_plan_mode = Some(enable); - self - } - - /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { - self.request_auto_mode_switch = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -1793,24 +1738,6 @@ pub struct ResumeSessionConfig { /// Enable config discovery on resume. #[serde(skip_serializing_if = "Option::is_none")] pub enable_config_discovery: Option, - /// Enable the ask_user tool. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_user_input: Option, - /// Enable permission request RPCs. When no handler is set, permission requests - /// remain pending until the consumer responds out-of-band. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_permission: Option, - /// Enable exit-plan-mode request RPCs. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_exit_plan_mode: Option, - /// Enable auto-mode-switch request RPCs on resume. Defaults to - /// `Some(true)` via [`ResumeSessionConfig::new`]. See - /// [`SessionConfig::request_auto_mode_switch`] for details. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_auto_mode_switch: Option, - /// Advertise elicitation provider capability on resume. - #[serde(skip_serializing_if = "Option::is_none")] - pub request_elicitation: Option, /// Skill directory paths passed through to the GitHub Copilot CLI on resume. #[serde(skip_serializing_if = "Option::is_none")] pub skill_directories: Option>, @@ -1939,11 +1866,6 @@ impl std::fmt::Debug for ResumeSessionConfig { .field("excluded_tools", &self.excluded_tools) .field("mcp_servers", &self.mcp_servers) .field("enable_config_discovery", &self.enable_config_discovery) - .field("request_user_input", &self.request_user_input) - .field("request_permission", &self.request_permission) - .field("request_exit_plan_mode", &self.request_exit_plan_mode) - .field("request_auto_mode_switch", &self.request_auto_mode_switch) - .field("request_elicitation", &self.request_elicitation) .field("skill_directories", &self.skill_directories) .field("instruction_directories", &self.instruction_directories) .field("disabled_skills", &self.disabled_skills) @@ -2081,11 +2003,6 @@ impl ResumeSessionConfig { mcp_servers: None, env_value_mode: default_env_value_mode(), enable_config_discovery: None, - request_user_input: None, - request_permission: None, - request_exit_plan_mode: None, - request_auto_mode_switch: None, - request_elicitation: None, skill_directories: None, instruction_directories: None, disabled_skills: None, @@ -2275,24 +2192,6 @@ impl ResumeSessionConfig { self } - /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::new`]. - pub fn with_request_user_input(mut self, enable: bool) -> Self { - self.request_user_input = Some(enable); - self - } - - /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { - self.request_exit_plan_mode = Some(enable); - self - } - - /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. - pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { - self.request_auto_mode_switch = Some(enable); - self - } - /// Set skill directory paths passed through to the CLI on resume. pub fn with_skill_directories(mut self, paths: I) -> Self where @@ -3299,9 +3198,10 @@ pub enum ElicitationMode { /// An incoming elicitation request from the CLI (provider side). /// -/// Received via `elicitation.requested` session event when the session was -/// created with `request_elicitation: true`. The provider should render a -/// form or dialog and return an [`ElicitationResult`]. +/// Received via `elicitation.requested` session event when the session has +/// an [`ElicitationHandler`](crate::handler::ElicitationHandler) installed. +/// The provider should render a form or dialog and return an +/// [`ElicitationResult`]. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ElicitationRequest { @@ -3608,26 +3508,30 @@ mod tests { } #[test] - fn session_config_default_enables_permission_flow_flags() { + fn session_config_default_wire_flags_off_without_handlers() { let cfg = SessionConfig::default(); - // All wire flags start unset; the SDK derives them from handler - // presence at Client::create_session time. - assert_eq!(cfg.request_user_input, None); - assert_eq!(cfg.request_permission, None); - assert_eq!(cfg.request_elicitation, None); - assert_eq!(cfg.request_exit_plan_mode, None); - assert_eq!(cfg.request_auto_mode_switch, None); + // Wire flags are derived from handler presence at create_session + // time, not stored on the config. With no handlers installed, every + // request_* flag should serialize as false. + let wire = cfg.to_wire(SessionId::from("default-flags")); + assert!(!wire.request_user_input); + assert!(!wire.request_permission); + assert!(!wire.request_elicitation); + assert!(!wire.request_exit_plan_mode); + assert!(!wire.request_auto_mode_switch); + assert!(!wire.hooks); } #[test] - fn resume_session_config_new_enables_permission_flow_flags() { - let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); - // All wire flags start unset on resume too. - assert_eq!(cfg.request_user_input, None); - assert_eq!(cfg.request_permission, None); - assert_eq!(cfg.request_elicitation, None); - assert_eq!(cfg.request_exit_plan_mode, None); - assert_eq!(cfg.request_auto_mode_switch, None); + fn resume_session_config_new_wire_flags_off_without_handlers() { + let cfg = ResumeSessionConfig::new(SessionId::from("resume-flags")); + let wire = cfg.to_wire(); + assert!(!wire.request_user_input); + assert!(!wire.request_permission); + assert!(!wire.request_elicitation); + assert!(!wire.request_exit_plan_mode); + assert!(!wire.request_auto_mode_switch); + assert!(!wire.hooks); } #[test] @@ -3645,9 +3549,6 @@ mod tests { .with_excluded_tools(["dangerous"]) .with_mcp_servers(HashMap::new()) .with_enable_config_discovery(true) - .with_request_user_input(false) - .with_request_exit_plan_mode(false) - .with_request_auto_mode_switch(false) .with_skill_directories([PathBuf::from("/tmp/skills")]) .with_disabled_skills(["broken-skill"]) .with_agent("researcher") @@ -3673,10 +3574,6 @@ mod tests { ); assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); - assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, None); // unset; derived at create_session time - assert_eq!(cfg.request_exit_plan_mode, Some(false)); - assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( cfg.skill_directories.as_deref(), Some(&[PathBuf::from("/tmp/skills")][..]) @@ -3705,9 +3602,6 @@ mod tests { .with_excluded_tools(["dangerous"]) .with_mcp_servers(HashMap::new()) .with_enable_config_discovery(true) - .with_request_user_input(false) - .with_request_exit_plan_mode(false) - .with_request_auto_mode_switch(false) .with_skill_directories([PathBuf::from("/tmp/skills")]) .with_disabled_skills(["broken-skill"]) .with_agent("researcher") @@ -3733,10 +3627,6 @@ mod tests { ); assert!(cfg.mcp_servers.is_some()); assert_eq!(cfg.enable_config_discovery, Some(true)); - assert_eq!(cfg.request_user_input, Some(false)); // overrode default - assert_eq!(cfg.request_permission, None); // unset; derived at create_session time - assert_eq!(cfg.request_exit_plan_mode, Some(false)); - assert_eq!(cfg.request_auto_mode_switch, Some(false)); assert_eq!( cfg.skill_directories.as_deref(), Some(&[PathBuf::from("/tmp/skills")][..]) From 1c6ebdc48932a7a7cd7b5f5bbfea42ba29438eac Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:54:57 +0100 Subject: [PATCH 15/22] Fix mcp-servers scenario for new Option> tools field Phase B made McpStdioServerConfig.tools an Option>; the scenario still set tools: vec![*], which under the new tri-state contract would mean 'only allow the tool literally named *', not 'all tools'. Drop the field and rely on Default (None) to expose the full MCP server tool list. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/scenarios/tools/mcp-servers/rust/src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs index dffc56c08..171d2bcd4 100644 --- a/test/scenarios/tools/mcp-servers/rust/src/main.rs +++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs @@ -25,7 +25,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { .map(|s| s.split(' ').map(str::to_string).collect()) .unwrap_or_default(); let stdio = McpStdioServerConfig { - tools: vec!["*".to_string()], command: cmd.clone(), args, ..Default::default() From 387fdcdf1da9b28bd601e82a1cd3ba811eea823a Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 22:58:59 +0100 Subject: [PATCH 16/22] Drop redundant explicit link target in types.rs ElicitationHandler doc ElicitationHandler is imported in types.rs, so the bare label resolves to the same item; the explicit path triggers rustdoc::redundant-explicit-links. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/src/types.rs b/rust/src/types.rs index f13c13fbb..9a4b2aca6 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -3199,7 +3199,7 @@ pub enum ElicitationMode { /// An incoming elicitation request from the CLI (provider side). /// /// Received via `elicitation.requested` session event when the session has -/// an [`ElicitationHandler`](crate::handler::ElicitationHandler) installed. +/// an [`ElicitationHandler`] installed. /// The provider should render a form or dialog and return an /// [`ElicitationResult`]. #[derive(Debug, Clone, Serialize, Deserialize)] From 94ea635c51adf41e73276a41d713c72363ae3404 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 23:09:42 +0100 Subject: [PATCH 17/22] Fix triple-encoded mojibake in session_test.rs comments Earlier PowerShell editing of this file double-corrupted some em-dashes, right-arrows, and a section sign (the bytes had been UTF-8 -> cp1252 -> UTF-8 -> cp1252 -> UTF-8 round-tripped). Restored to plain UTF-8. 19 instances fixed (11 em-dashes, 7 right-arrows, 1 section sign). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/tests/session_test.rs | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 1b747e8a1..916e28f0b 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -432,7 +432,7 @@ async fn client_rpc_methods_send_correct_method_names() { let (client, mut server_read, mut server_write) = make_client(); // Wire method names per the CLI runtime registration in @github/copilot - // app.js — verified against Node/Go/Python/.NET SDK call sites which all + // app.js — verified against Node/Go/Python/.NET SDK call sites which all // use these exact strings. The schema doesn't currently define these as // typed RPCs (top-level methods, not under any namespace), so call site // strings are the source of truth. @@ -527,7 +527,7 @@ async fn list_sessions_serializes_typed_filter() { // wrap; flattening is silently ignored by the runtime. assert!( request["params"].get("repository").is_none(), - "wire shape is `params.filter.*`, not `params.*` — see Node/Go/Python/.NET" + "wire shape is `params.filter.*`, not `params.*` — see Node/Go/Python/.NET" ); let id = request["id"].as_u64().unwrap(); @@ -578,7 +578,7 @@ fn mcp_server_config_roundtrips_through_tagged_enum() { fn mcp_stdio_tools_tri_state_serializes_correctly() { use github_copilot_sdk::McpStdioServerConfig; - // None → field omitted (= "expose all tools") + // None → field omitted (= "expose all tools") let cfg = McpStdioServerConfig { command: "echo".into(), tools: None, @@ -590,7 +590,7 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { "tools=None must be omitted on the wire; got {json}" ); - // Some(empty) → field present as [] + // Some(empty) → field present as [] let cfg = McpStdioServerConfig { command: "echo".into(), tools: Some(vec![]), @@ -599,7 +599,7 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { let json = serde_json::to_value(&cfg).unwrap(); assert_eq!(json["tools"], serde_json::json!([])); - // Some(non-empty) → field present as the explicit list + // Some(non-empty) → field present as the explicit list let cfg = McpStdioServerConfig { command: "echo".into(), tools: Some(vec!["a".into(), "b".into()]), @@ -613,17 +613,17 @@ fn mcp_stdio_tools_tri_state_serializes_correctly() { fn mcp_stdio_tools_tri_state_deserializes_correctly() { use github_copilot_sdk::McpStdioServerConfig; - // Missing field → None + // Missing field → None let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo" })).unwrap(); assert_eq!(cfg.tools, None); - // Empty list → Some(empty) + // Empty list → Some(empty) let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo", "tools": [] })).unwrap(); assert_eq!(cfg.tools, Some(vec![])); - // Non-empty list → Some(list) + // Non-empty list → Some(list) let cfg: McpStdioServerConfig = serde_json::from_value(serde_json::json!({ "command": "echo", "tools": ["x"] })).unwrap(); assert_eq!(cfg.tools, Some(vec!["x".to_string()])); @@ -1374,7 +1374,7 @@ async fn user_input_requested_notification_does_not_double_dispatch() { // `user_input.requested` notification (for observers) AND a // `userInput.request` JSON-RPC call (the actual prompt) for every // user-input prompt. Only the JSON-RPC path should reach the - // handler — dispatching from the notification too produced + // handler — dispatching from the notification too produced // duplicate ask_user widgets on the consumer side. struct CountingHandler { @@ -1692,7 +1692,7 @@ async fn send_and_wait_outer_cancellation_clears_waiter() { let request = server.read_request().await; server.respond(&request, serde_json::json!({})).await; - // Outer timeout fires → Err(Elapsed) returned, future is dropped. + // Outer timeout fires → Err(Elapsed) returned, future is dropped. let outer_result = timeout(Duration::from_secs(2), handle) .await .unwrap() @@ -1755,7 +1755,7 @@ async fn send_and_wait_drop_clears_waiter() { // Give the runtime a moment to run the drop. tokio::task::yield_now().await; - // Next `send` must succeed — no SendWhileWaiting. + // Next `send` must succeed — no SendWhileWaiting. let send_handle = tokio::spawn({ let session = session.clone(); async move { session.send("after-abort").await } @@ -1778,7 +1778,7 @@ async fn send_and_wait_drop_clears_waiter() { /// Cancel-safety regression: `Session::stop_event_loop` must NOT abort /// the event-loop task mid-handler. An in-flight handler (here a slow /// `userInput.request` callback) must run to completion before the loop -/// exits — the CLI receives the response on the wire before the session +/// exits — the CLI receives the response on the wire before the session /// tears down. /// /// Closes RFD-400 review finding #3. @@ -1833,7 +1833,7 @@ async fn stop_event_loop_completes_in_flight_handler() { }); // Verify the handler's response lands on the wire BEFORE the loop - // exits — i.e. stop_event_loop did not abort mid-handler. + // exits — i.e. stop_event_loop did not abort mid-handler. let response = timeout(Duration::from_secs(2), server.read_response()) .await .unwrap(); @@ -1939,7 +1939,7 @@ async fn cancellation_token_fires_on_session_drop() { } /// Cancelling a child token returned by `cancellation_token()` does NOT -/// shut the session down — child tokens isolate consumer-side cancel +/// shut the session down — child tokens isolate consumer-side cancel /// logic from the session's own lifecycle. #[tokio::test] async fn cancellation_token_child_cancel_does_not_kill_session() { @@ -2093,7 +2093,7 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { #[tokio::test] async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { // Phase H multi-client safety: a handler that doesn't claim the - // requested tool name must not send an RPC response — another client + // requested tool name must not send an RPC response — another client // on the same CLI may have a real handler. struct FooTool; #[async_trait] @@ -2170,7 +2170,7 @@ async fn permission_broadcast_with_resolved_by_hook_is_not_responded_to() { #[tokio::test] async fn permission_broadcast_with_no_claiming_handler_is_not_responded_to() { // Phase H: a handler that doesn't claim permission dispatch must not - // respond — the SDK lets other connected clients handle the request. + // respond — the SDK lets other connected clients handle the request. let (_session, mut server) = create_session_pair().await; server .send_event( @@ -2417,7 +2417,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { async fn elicitation_methods_fail_without_capability() { let (session, _server) = create_session_pair().await; - // Session created without capabilities — elicitation should fail + // Session created without capabilities — elicitation should fail let err = session .ui() .elicitation("test", serde_json::json!({})) @@ -2824,7 +2824,7 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { #[tokio::test] async fn client_stop_aggregates_session_destroy_errors() { - // session.destroy fails on the wire — Client::stop returns + // session.destroy fails on the wire — Client::stop returns // StopErrors carrying the failure rather than short-circuiting. let (session, mut server) = create_session_pair().await; let client = session.client().clone(); @@ -2929,7 +2929,7 @@ fn resume_session_config_serializes_bucket_b_fields() { } // ===================================================================== -// Slash commands (§ 4.1) +// Slash commands (§ 4.1) // ===================================================================== struct CountingCommandHandler { From 63871a587e017e66a716c1103806d04d8346047f Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 21 May 2026 23:16:55 +0100 Subject: [PATCH 18/22] Drop removed request_user_input assignment in user-input scenario The request_* wire flags are no longer caller-controllable; they are derived from handler presence. Installing with_user_input_handler is sufficient to set requestUserInput=true on the wire. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/scenarios/callbacks/user-input/rust/src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs b/test/scenarios/callbacks/user-input/rust/src/main.rs index 619137084..1517727e9 100644 --- a/test/scenarios/callbacks/user-input/rust/src/main.rs +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -76,7 +76,6 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.request_user_input = Some(true); let config = config .with_permission_handler(handler.clone()) .with_user_input_handler(handler) From 0230a8890d2d7700d1307b15292272ef484ef3af Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Fri, 22 May 2026 00:20:43 +0100 Subject: [PATCH 19/22] Unify tool registration: Tool::with_handler + with_tools (drop with_tool_handlers) Brings the Rust SDK in line with TS/C# where a tool's optional handler is metadata on the tool itself rather than a parallel collection. The trade-off favours fewer footguns: Before: impl ToolHandler for MyTool { fn tool() -> Tool { ... } async fn call(...) } config.with_tool_handlers(vec![Arc::new(MyTool)]) config.with_tools(vec![tool_definition]) // a second list After: impl ToolHandler for MyTool { async fn call(...) } let tool = Tool::new(name).with_description(...).with_parameters(...) .with_handler(Arc::new(MyTool)); config.with_tools(vec![tool]) API changes: - `Tool`: new Option> handler field with #[serde(skip)]. New Tool::with_handler(...) builder. The struct still serializes to the same wire shape; the handler field is never sent. - ToolHandler trait: dropped n tool(&self) -> Tool. Trait now contains only sync fn call(...). - SessionConfig::with_tool_handlers removed. Callers pass Tool values with handlers attached via SessionConfig::with_tools. Same for ResumeSessionConfig. - define_tool now returns Tool (with handler attached) instead of Box, so install via with_tools(vec![define_tool(...)]). - New define_tool_declaration::

(name, desc) for declaration-only tools with schema derived from P but no Rust-side handler. The SDK extracts each tool's handler from its Tool::handler field at Client::create_session / esume_session time, building the same name -> handler map it used to. Wire serialization unchanged. Migration: every consumer of with_tool_handlers and n tool(&self) -> Tool updated -- session_test, all e2e tests, examples/tool_server, README. The in-file test handlers are split into a free n x_tool() -> Tool plus a slim impl ToolHandler for X { async fn call ... }. ~25 sites total. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 34 +- rust/examples/tool_server.rs | 53 ++- rust/src/handler.rs | 3 +- rust/src/session.rs | 38 ++- rust/src/tool.rs | 196 ++++++----- rust/src/types.rs | 98 +++--- rust/tests/e2e/abort.rs | 35 +- rust/tests/e2e/hooks_extended.rs | 30 +- rust/tests/e2e/multi_client.rs | 26 +- rust/tests/e2e/pending_work_resume.rs | 11 +- rust/tests/e2e/session.rs | 46 +-- rust/tests/e2e/telemetry.rs | 36 +- rust/tests/e2e/tool_results.rs | 92 +++--- rust/tests/e2e/tools.rs | 311 +++++++++--------- rust/tests/session_test.rs | 39 +-- .../tools/custom-agents/rust/src/main.rs | 11 +- .../tools/tool-overrides/rust/src/main.rs | 17 +- 17 files changed, 522 insertions(+), 554 deletions(-) diff --git a/rust/README.md b/rust/README.md index 5edba4e84..e4dc1a029 100644 --- a/rust/README.md +++ b/rust/README.md @@ -329,12 +329,15 @@ let session = client ### Tool Registration -Define client-side tools as named types implementing `ToolHandler` and install them with `SessionConfig::with_tool_handlers`. Enable the `derive` feature for `schema_for::()` — it generates JSON Schema from Rust types via `schemars`. +Define client-side tools as named types implementing `ToolHandler` and attach +them to `Tool` declarations via `Tool::with_handler`, then install via +`SessionConfig::with_tools`. Enable the `derive` feature for `schema_for::()` +— it generates JSON Schema from Rust types via `schemars`. ```rust,ignore use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler}; +use github_copilot_sdk::tool::{schema_for, JsonSchema, ToolHandler}; use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; use serde::Deserialize; use async_trait::async_trait; @@ -351,49 +354,48 @@ struct GetWeatherTool; #[async_trait] impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - Tool::new("get_weather") - .with_description("Get weather for a city") - .with_parameters(tool_parameters(schema_for::())) - } - async fn call(&self, inv: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) } } -let tool_handlers: Vec> = vec![Arc::new(GetWeatherTool)]; +let tool = Tool::new("get_weather") + .with_description("Get weather for a city") + .with_parameters(schema_for::()) + .with_handler(Arc::new(GetWeatherTool)); let config = SessionConfig::default() .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(tool_handlers); + .with_tools(vec![tool]); let session = client.create_session(config).await?; ``` -Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". `with_tool_handlers` registers each handler under the name returned by its `tool()` method and surfaces the same `Tool` definitions to the CLI automatically; you don't need to set `SessionConfig::tools` separately when supplying handlers this way. +Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". The SDK registers each tool's handler under its `Tool::name` and surfaces the same `Tool` definitions to the CLI automatically. + +Tools without an attached handler (`Tool::with_handler` never called) are declaration-only: the SDK advertises them on the wire but doesn't dispatch invocations to anything. Useful when another connected client services the tool. -For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression. It returns a `Box` — convert to `Arc` with `Arc::from(...)`: +For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression and returns a fully-formed `Tool` with handler attached: ```rust,ignore -use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandler}; +use github_copilot_sdk::tool::{define_tool, JsonSchema}; use github_copilot_sdk::ToolResult; use serde::Deserialize; #[derive(Deserialize, JsonSchema)] struct GetWeatherParams { city: String } -let tool_handlers: Vec> = vec![Arc::from(define_tool( +let tool = define_tool( "get_weather", "Get weather for a city", |_inv, params: GetWeatherParams| async move { Ok(ToolResult::Text(format!("Sunny in {}", params.city))) }, -))]; +); let config = SessionConfig::default() .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(tool_handlers); + .with_tools(vec![tool]); ``` The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) alongside the deserialized parameters, so handlers that need `inv.session_id` or `inv.tool_call_id` for telemetry, streaming updates, or scoped lookups can use them directly. Use `_inv` when you don't need the metadata. diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs index c6d6e709a..93492d20c 100644 --- a/rust/examples/tool_server.rs +++ b/rust/examples/tool_server.rs @@ -30,7 +30,7 @@ use async_trait::async_trait; #[cfg(feature = "derive")] use github_copilot_sdk::handler::ApproveAllHandler; #[cfg(feature = "derive")] -use github_copilot_sdk::tool::{JsonSchema, ToolHandler, schema_for, tool_parameters}; +use github_copilot_sdk::tool::{JsonSchema, ToolHandler, schema_for}; #[cfg(feature = "derive")] use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult}; #[cfg(feature = "derive")] @@ -57,14 +57,6 @@ struct GetWeatherTool; #[cfg(feature = "derive")] #[async_trait] impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - let mut tool = Tool::default(); - tool.name = "get_weather".to_string(); - tool.description = "Get the current weather for a city.".to_string(); - tool.parameters = tool_parameters(schema_for::()); - tool - } - async fn call(&self, invocation: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(invocation.arguments)?; let unit = params.unit.as_deref().unwrap_or("celsius"); @@ -88,20 +80,6 @@ struct RollDiceTool; #[cfg(feature = "derive")] #[async_trait] impl ToolHandler for RollDiceTool { - fn tool(&self) -> Tool { - let mut tool = Tool::default(); - tool.name = "roll_dice".to_string(); - tool.description = "Roll one or more dice and return the total.".to_string(); - tool.parameters = tool_parameters(serde_json::json!({ - "type": "object", - "properties": { - "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, - "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } - } - })); - tool - } - async fn call(&self, invocation: ToolInvocation) -> Result { let sides = invocation .arguments @@ -143,19 +121,26 @@ impl ToolHandler for RollDiceTool { #[cfg(feature = "derive")] #[tokio::main] async fn main() -> Result<(), github_copilot_sdk::Error> { - let tool_handlers: Vec> = - vec![Arc::new(GetWeatherTool), Arc::new(RollDiceTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); - let client = Client::start(ClientOptions::default()).await?; - let config = { - let mut cfg = SessionConfig::default() - .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(tool_handlers); - cfg.tools = Some(tools); - cfg - }; + let config = SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![ + Tool::new("get_weather") + .with_description("Get the current weather for a city.") + .with_parameters(schema_for::()) + .with_handler(Arc::new(GetWeatherTool)), + Tool::new("roll_dice") + .with_description("Roll one or more dice and return the total.") + .with_parameters(serde_json::json!({ + "type": "object", + "properties": { + "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, + "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } + } + })) + .with_handler(Arc::new(RollDiceTool)), + ]); let session = client.create_session(config).await?; println!( diff --git a/rust/src/handler.rs b/rust/src/handler.rs index 8e7f6a655..042565564 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -12,7 +12,8 @@ //! respond to. //! //! Tool dispatch uses its own per-tool registry built from -//! [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers). +//! [`Tool::with_handler`](crate::types::Tool::with_handler) on entries passed to +//! [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). use async_trait::async_trait; use serde::{Deserialize, Serialize}; diff --git a/rust/src/session.rs b/rust/src/session.rs index afdccb66a..0e5377d9e 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -764,7 +764,7 @@ impl Client { /// /// All callbacks (per-event handlers, tool handlers, hooks, transform) /// are configured via [`SessionConfig`] using its `with_*_handler` / - /// `with_tool_handlers` / `with_hooks` / `with_transform` builder + /// `with_tools` / `with_hooks` / `with_transform` builder /// methods. /// /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the @@ -802,14 +802,18 @@ impl Client { let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); let mut tool_map: HashMap> = HashMap::new(); - for tool in config.tool_handlers.drain(..) { - let name = tool.tool().name; - if tool_map.contains_key(&name) { - return Err(Error::InvalidConfig(format!( - "duplicate tool handler registered for name {name:?}" - ))); + if let Some(tools) = config.tools.as_mut() { + for tool in tools.iter_mut() { + if let Some(handler) = tool.handler.take() { + if tool_map.contains_key(&tool.name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {:?}", + tool.name + ))); + } + tool_map.insert(tool.name.clone(), handler); + } } - tool_map.insert(name, tool); } let handlers = SessionHandlers { permission: permission_handler, @@ -953,14 +957,18 @@ impl Client { let exit_plan_mode_handler = config.exit_plan_mode_handler.take(); let auto_mode_switch_handler = config.auto_mode_switch_handler.take(); let mut tool_map: HashMap> = HashMap::new(); - for tool in config.tool_handlers.drain(..) { - let name = tool.tool().name; - if tool_map.contains_key(&name) { - return Err(Error::InvalidConfig(format!( - "duplicate tool handler registered for name {name:?}" - ))); + if let Some(tools) = config.tools.as_mut() { + for tool in tools.iter_mut() { + if let Some(handler) = tool.handler.take() { + if tool_map.contains_key(&tool.name) { + return Err(Error::InvalidConfig(format!( + "duplicate tool handler registered for name {:?}", + tool.name + ))); + } + tool_map.insert(tool.name.clone(), handler); + } } - tool_map.insert(name, tool); } let handlers = SessionHandlers { permission: permission_handler, diff --git a/rust/src/tool.rs b/rust/src/tool.rs index b3c021e32..3a35dd576 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -1,11 +1,14 @@ //! Typed tool definition framework. //! //! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for -//! implementing tools as named types. Install tool handlers on a session -//! via -//! [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers); -//! the SDK builds an internal name-keyed registry and dispatches to the -//! matching handler when the CLI broadcasts `external_tool.requested`. +//! implementing tools as named types. Attach a handler to a +//! [`Tool`](crate::types::Tool) via +//! [`Tool::with_handler`](crate::types::Tool::with_handler), then install +//! the resulting tools on a session via +//! [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). +//! The SDK builds an internal name-keyed registry from the handlers and +//! dispatches to the matching handler when the CLI broadcasts +//! `external_tool.requested`. //! //! Enable the `derive` feature for `schema_for`, which generates JSON //! Schema from Rust types via `schemars`. @@ -170,63 +173,61 @@ pub fn convert_mcp_call_tool_result(value: &serde_json::Value) -> Option, /// } /// -/// struct GetWeatherTool; +/// struct GetWeather; /// /// #[async_trait] -/// impl ToolHandler for GetWeatherTool { -/// fn tool(&self) -> Tool { -/// Tool { -/// name: "get_weather".to_string(), -/// namespaced_name: None, -/// description: "Get weather for a city".to_string(), -/// parameters: tool_parameters(schema_for::()), -/// instructions: None, -/// ..Default::default() -/// } -/// } -/// +/// impl ToolHandler for GetWeather { /// async fn call(&self, inv: ToolInvocation) -> Result { /// let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; /// Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) /// } /// } +/// +/// // Build the Tool declaration with the handler attached: +/// let tool = Tool::new("get_weather") +/// .with_description("Get weather for a city") +/// .with_parameters(schema_for::()) +/// .with_handler(Arc::new(GetWeather)); /// ``` #[async_trait] -pub trait ToolHandler: Send + Sync { - /// The tool definition sent to the CLI during session creation. - fn tool(&self) -> Tool; - +pub trait ToolHandler: Send + Sync + 'static { /// Handle a tool invocation from the agent. async fn call(&self, invocation: ToolInvocation) -> Result; } -/// Define a tool from an async function (or closure) that takes a typed, +/// Define a [`Tool`] from an async function (or closure) that takes a typed, /// `JsonSchema`-derived parameter struct. /// -/// The returned `Box` can be installed on a session via -/// [`SessionConfig::with_tool_handlers`](crate::types::SessionConfig::with_tool_handlers). +/// The returned [`Tool`] carries an attached handler ready to install on a +/// session via [`SessionConfig::with_tools`](crate::types::SessionConfig::with_tools). /// JSON Schema for the parameter type is generated via [`schema_for`] at /// construction time. /// @@ -259,8 +260,6 @@ pub trait ToolHandler: Send + Sync { /// inv: ToolInvocation, /// params: GetWeatherParams, /// ) -> Result { -/// // `inv.session_id` and `inv.tool_call_id` are available for telemetry, -/// // streaming updates, scoping DB lookups, etc. /// let _ = inv.session_id; /// Ok(ToolResult::Text(format!("Sunny in {}", params.city))) /// } @@ -286,36 +285,24 @@ pub fn define_tool( name: impl Into, description: impl Into, handler: F, -) -> Box +) -> Tool where P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { - struct FnTool { - name: String, - description: String, - parameters: HashMap, + struct FnHandler { handler: F, _marker: std::marker::PhantomData, } #[async_trait] - impl ToolHandler for FnTool + impl ToolHandler for FnHandler where P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { - fn tool(&self) -> Tool { - Tool { - name: self.name.clone(), - description: self.description.clone(), - parameters: self.parameters.clone(), - ..Default::default() - } - } - async fn call(&self, mut invocation: ToolInvocation) -> Result { let arguments = std::mem::take(&mut invocation.arguments); let params: P = serde_json::from_value(arguments)?; @@ -323,13 +310,50 @@ where } } - Box::new(FnTool { + Tool { name: name.into(), description: description.into(), parameters: tool_parameters(schema_for::

()), + ..Default::default() + } + .with_handler(std::sync::Arc::new(FnHandler { handler, _marker: std::marker::PhantomData, - }) + })) +} + +/// Define a declaration-only [`Tool`] with a JSON Schema derived from `P`. +/// +/// Equivalent to [`define_tool`] but produces a [`Tool`] with no attached +/// handler — useful when another connected client services this tool, or +/// when you only need to advertise the schema for capability negotiation. +/// +/// # Example +/// +/// ```rust,no_run +/// use github_copilot_sdk::tool::{define_tool_declaration, JsonSchema}; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize, JsonSchema)] +/// struct Params { query: String } +/// +/// let declared = define_tool_declaration::( +/// "legacy_thing", +/// "Handled by another connected client", +/// ); +/// # let _ = declared; +/// ``` +#[cfg(feature = "derive")] +pub fn define_tool_declaration

(name: impl Into, description: impl Into) -> Tool +where + P: schemars::JsonSchema, +{ + Tool { + name: name.into(), + description: description.into(), + parameters: tool_parameters(schema_for::

()), + ..Default::default() + } } #[cfg(test)] @@ -339,19 +363,18 @@ mod tests { struct EchoTool; - #[async_trait] - impl ToolHandler for EchoTool { - fn tool(&self) -> Tool { - Tool { - name: "echo".to_string(), - namespaced_name: None, - description: "Echo the input".to_string(), - parameters: tool_parameters(serde_json::json!({"type": "object"})), - instructions: None, - ..Default::default() - } + fn echo_tool() -> Tool { + Tool { + name: "echo".to_string(), + description: "Echo the input".to_string(), + parameters: tool_parameters(serde_json::json!({"type": "object"})), + ..Default::default() } + .with_handler(std::sync::Arc::new(EchoTool)) + } + #[async_trait] + impl ToolHandler for EchoTool { async fn call(&self, inv: ToolInvocation) -> Result { Ok(ToolResult::Text(inv.arguments.to_string())) } @@ -359,11 +382,11 @@ mod tests { #[test] fn tool_handler_returns_tool_definition() { - let tool = EchoTool; - let def = tool.tool(); + let def = echo_tool(); assert_eq!(def.name, "echo"); assert_eq!(def.description, "Echo the input"); assert!(def.parameters.contains_key("type")); + assert!(def.handler.is_some()); } #[test] @@ -566,11 +589,11 @@ mod tests { }, ); - let def = tool.tool(); - assert_eq!(def.name, "weather"); - assert_eq!(def.description, "Get the weather for a city"); - assert_eq!(def.parameters["type"], "object"); - assert!(def.parameters["properties"]["city"].is_object()); + assert_eq!(tool.name, "weather"); + assert_eq!(tool.description, "Get the weather for a city"); + assert_eq!(tool.parameters["type"], "object"); + assert!(tool.parameters["properties"]["city"].is_object()); + let handler = tool.handler.as_ref().expect("define_tool attaches handler"); let inv = ToolInvocation { session_id: SessionId::from("s1"), @@ -580,7 +603,7 @@ mod tests { traceparent: None, tracestate: None, }; - match tool.call(inv).await.unwrap() { + match handler.call(inv).await.unwrap() { ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"), _ => panic!("expected Text result"), } @@ -619,19 +642,18 @@ mod tests { struct GetWeatherTool; - #[async_trait] - impl ToolHandler for GetWeatherTool { - fn tool(&self) -> Tool { - Tool { - name: "get_weather".to_string(), - namespaced_name: None, - description: "Get weather for a city".to_string(), - parameters: tool_parameters(schema_for::()), - instructions: None, - ..Default::default() - } + fn get_weather_tool() -> Tool { + Tool { + name: "get_weather".to_string(), + description: "Get weather for a city".to_string(), + parameters: tool_parameters(schema_for::()), + ..Default::default() } + .with_handler(std::sync::Arc::new(GetWeatherTool)) + } + #[async_trait] + impl ToolHandler for GetWeatherTool { async fn call(&self, inv: ToolInvocation) -> Result { let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; Ok(ToolResult::Text(format!( @@ -644,12 +666,12 @@ mod tests { #[test] fn tool_handler_with_schema_for() { - let tool = GetWeatherTool; - let def = tool.tool(); + let def = get_weather_tool(); assert_eq!(def.name, "get_weather"); let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters"); assert_eq!(schema["type"], "object"); assert!(schema["properties"]["city"].is_object()); + assert!(def.handler.is_some()); } #[tokio::test] @@ -689,11 +711,7 @@ mod tests { #[tokio::test] async fn schema_for_derived_tool_round_trips_through_call() { - let tool: Box = Box::new(GetWeatherTool); - - // Tool definition exposes the schema-derived parameter set. - let def = tool.tool(); - assert_eq!(def.name, "get_weather"); + let tool = GetWeatherTool; // Calling the tool with matching arguments returns the // expected typed result. (Per-name dispatch is the SDK's diff --git a/rust/src/types.rs b/rust/src/types.rs index 9a4b2aca6..991e41623 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -296,7 +296,14 @@ impl PartialEq<&str> for RequestId { /// (rather than using the schema-generated form) so it can carry runtime /// hints — `overrides_built_in_tool`, `skip_permission` — that don't appear /// in the wire schema but are honored by the CLI. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// +/// A `Tool` may optionally carry a [`handler`](Self::handler): an +/// `Arc` that implements the tool's runtime behavior. +/// When present, the SDK dispatches matching `external_tool.requested` +/// broadcasts to it automatically. When absent (`None`), the tool is +/// declaration-only — another connected client must service incoming +/// invocations. +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct Tool { @@ -325,6 +332,14 @@ pub struct Tool { /// access control. #[serde(default, skip_serializing_if = "is_false")] pub skip_permission: bool, + /// Optional runtime implementation. When `Some`, the SDK dispatches + /// matching `external_tool.requested` broadcasts to this handler. + /// When `None`, the tool is declaration-only. + /// + /// Skipped during serialization — the handler is runtime behavior, + /// not part of the wire representation. + #[serde(skip)] + pub handler: Option>, } #[inline] @@ -408,6 +423,32 @@ impl Tool { self.skip_permission = skip; self } + + /// Attach a runtime implementation. The SDK will dispatch matching + /// `external_tool.requested` broadcasts to `handler` for this tool's + /// name. Without a handler the tool is declaration-only. + pub fn with_handler(mut self, handler: Arc) -> Self { + self.handler = Some(handler); + self + } +} + +impl std::fmt::Debug for Tool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tool") + .field("name", &self.name) + .field("namespaced_name", &self.namespaced_name) + .field("description", &self.description) + .field("instructions", &self.instructions) + .field("parameters", &self.parameters) + .field("overrides_built_in_tool", &self.overrides_built_in_tool) + .field("skip_permission", &self.skip_permission) + .field( + "handler", + &self.handler.as_ref().map(|_| "").unwrap_or("None"), + ) + .finish() + } } /// Context passed to a [`CommandHandler`] when a registered slash command @@ -1173,11 +1214,6 @@ pub struct SessionConfig { /// `requestAutoModeSwitch: false` goes on the wire. #[serde(skip)] pub auto_mode_switch_handler: Option>, - /// Client-defined tool handlers. The SDK builds an internal - /// name-keyed registry from these and dispatches to the matching - /// handler when the CLI broadcasts `external_tool.requested`. - #[serde(skip)] - pub tool_handlers: Vec>, /// Session lifecycle hook handler (pre/post tool use, session /// start/end, etc.). When set, the SDK auto-enables the wire-level /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one. @@ -1258,7 +1294,6 @@ impl std::fmt::Debug for SessionConfig { "auto_mode_switch_handler", &self.auto_mode_switch_handler.as_ref().map(|_| ""), ) - .field("tool_handlers_count", &self.tool_handlers.len()) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), @@ -1312,7 +1347,6 @@ impl Default for SessionConfig { user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, - tool_handlers: Vec::new(), hooks_handler: None, permission_policy: None, transform: None, @@ -1334,7 +1368,7 @@ impl SessionConfig { reasoning_effort: self.reasoning_effort.clone(), streaming: self.streaming, system_message: self.system_message.clone(), - tools: self.merged_tool_wire_definitions(), + tools: self.tools.clone(), available_tools: self.available_tools.clone(), excluded_tools: self.excluded_tools.clone(), mcp_servers: self.mcp_servers.clone(), @@ -1373,17 +1407,6 @@ impl SessionConfig { } } - /// Merge caller-supplied `tools` (declaration-only) with the `tool()` - /// definitions extracted from each [`tool_handlers`](Self::tool_handlers) - /// entry. Returns `None` only when both sources are empty. - fn merged_tool_wire_definitions(&self) -> Option> { - let mut out: Vec = self.tools.clone().unwrap_or_default(); - for handler in &self.tool_handlers { - out.push(handler.tool()); - } - if out.is_empty() { None } else { Some(out) } - } - /// Install a [`PermissionHandler`] for this session. When omitted, the /// SDK sends `requestPermission: false` on the wire and the runtime /// short-circuits permission prompts for this client. @@ -1421,17 +1444,6 @@ impl SessionConfig { self } - /// Install tool handlers for this session. Each handler must report a - /// unique [`Tool::name`]; the SDK rejects duplicates at - /// [`Client::create_session`](crate::Client::create_session) time. - pub fn with_tool_handlers(mut self, handlers: I) -> Self - where - I: IntoIterator>, - { - self.tool_handlers = handlers.into_iter().collect(); - self - } - /// Register slash commands for this session. Each command appears as /// `/name` in the CLI's TUI; the handler is invoked when the user /// executes the command. Replaces any commands previously set on this @@ -1839,9 +1851,6 @@ pub struct ResumeSessionConfig { /// [`SessionConfig::auto_mode_switch_handler`]. #[serde(skip)] pub auto_mode_switch_handler: Option>, - /// Tool handlers. See [`SessionConfig::tool_handlers`]. - #[serde(skip)] - pub tool_handlers: Vec>, /// Session hook handler. See [`SessionConfig::hooks_handler`]. #[serde(skip)] pub hooks_handler: Option>, @@ -1913,7 +1922,6 @@ impl std::fmt::Debug for ResumeSessionConfig { "auto_mode_switch_handler", &self.auto_mode_switch_handler.as_ref().map(|_| ""), ) - .field("tool_handlers_count", &self.tool_handlers.len()) .field( "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), @@ -1938,7 +1946,7 @@ impl ResumeSessionConfig { reasoning_effort: self.reasoning_effort.clone(), streaming: self.streaming, system_message: self.system_message.clone(), - tools: self.merged_tool_wire_definitions(), + tools: self.tools.clone(), available_tools: self.available_tools.clone(), excluded_tools: self.excluded_tools.clone(), mcp_servers: self.mcp_servers.clone(), @@ -1978,14 +1986,6 @@ impl ResumeSessionConfig { } } - fn merged_tool_wire_definitions(&self) -> Option> { - let mut out: Vec = self.tools.clone().unwrap_or_default(); - for handler in &self.tool_handlers { - out.push(handler.tool()); - } - if out.is_empty() { None } else { Some(out) } - } - /// Construct a `ResumeSessionConfig` with the given session ID and all /// other fields left unset. Combine with `.with_*` builders or struct /// update syntax (`..ResumeSessionConfig::new(id)`) to populate the @@ -2028,7 +2028,6 @@ impl ResumeSessionConfig { user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, - tool_handlers: Vec::new(), hooks_handler: None, permission_policy: None, transform: None, @@ -2068,15 +2067,6 @@ impl ResumeSessionConfig { self } - /// Install tool handlers for the resumed session. - pub fn with_tool_handlers(mut self, handlers: I) -> Self - where - I: IntoIterator>, - { - self.tool_handlers = handlers.into_iter().collect(); - self - } - /// Install a [`SessionHooks`] handler. Automatically enables the /// wire-level `hooks` flag on session resumption. pub fn with_hooks(mut self, hooks: Arc) -> Self { diff --git a/rust/tests/e2e/abort.rs b/rust/tests/e2e/abort.rs index 18be2ac18..33ef835d7 100644 --- a/rust/tests/e2e/abort.rs +++ b/rust/tests/e2e/abort.rs @@ -76,7 +76,7 @@ async fn should_abort_during_active_tool_execution() { let client = ctx.start_client().await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (release_tx, release_rx) = oneshot::channel(); - let slow_tool: Arc = Arc::new(SlowAnalysisTool { + let slow_tool = Arc::new(SlowAnalysisTool { started_tx, release_rx: Mutex::new(Some(release_rx)), }); @@ -85,7 +85,23 @@ async fn should_abort_during_active_tool_execution() { SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(vec![slow_tool]), + .with_tools(vec![ + Tool::new("slow_analysis") + .with_description( + "A slow analysis tool that blocks until released", + ) + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "Value to analyze" + } + }, + "required": ["value"] + })) + .with_handler(slow_tool), + ]), ) .await .expect("create session"); @@ -134,21 +150,6 @@ struct SlowAnalysisTool { #[async_trait] impl ToolHandler for SlowAnalysisTool { - fn tool(&self) -> Tool { - Tool::new("slow_analysis") - .with_description("A slow analysis tool that blocks until released") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "Value to analyze" - } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let value = invocation .arguments diff --git a/rust/tests/e2e/hooks_extended.rs b/rust/tests/e2e/hooks_extended.rs index 12d497250..7f2e72283 100644 --- a/rust/tests/e2e/hooks_extended.rs +++ b/rust/tests/e2e/hooks_extended.rs @@ -285,16 +285,13 @@ async fn should_allow_pretooluse_to_return_modifiedargs_and_suppressoutput() { Box::pin(async move { ctx.set_default_copilot_user(); let (tx, mut rx) = mpsc::unbounded_channel(); - let echo_tool: Arc = Arc::new(EchoValueTool); - let tools = vec![echo_tool.tool()]; let client = ctx.start_client().await; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(vec![echo_tool]) - .with_tools(tools) + .with_tools(vec![echo_value_tool()]) .with_hooks(Arc::new(RecordingHooks::pre_tool(tx))), ) .await @@ -540,20 +537,21 @@ impl SessionHooks for RecordingHooks { struct EchoValueTool; +fn echo_value_tool() -> Tool { + Tool::new("echo_value") + .with_description("Echoes the supplied value") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + .with_handler(Arc::new(EchoValueTool)) +} + #[async_trait] impl ToolHandler for EchoValueTool { - fn tool(&self) -> Tool { - Tool::new("echo_value") - .with_description("Echoes the supplied value") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { "type": "string" } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( invocation diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index f4c06f794..23552954b 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -35,7 +35,7 @@ async fn both_clients_see_tool_request_and_completion_events() { SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(selective_tools(vec![EchoTool::new( + .with_tools(selective_tools(vec![EchoTool::new( "magic_number", "seed", "MAGIC_", @@ -51,7 +51,7 @@ async fn both_clients_see_tool_request_and_completion_events() { .resume_session( resume_config(session1.id().clone()) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(selective_tools(Vec::new())), + .with_tools(selective_tools(Vec::new())), ) .await .expect("resume session"); @@ -287,7 +287,7 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "city_lookup", "countryCode", "CITY_FOR_", @@ -302,7 +302,7 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { let session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "currency_lookup", "countryCode", "CURRENCY_FOR_", @@ -355,7 +355,7 @@ async fn disconnecting_client_removes_its_tools() { .create_session( SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) - .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "stable_tool", "input", "STABLE_", @@ -370,7 +370,7 @@ async fn disconnecting_client_removes_its_tools() { let _session2 = client2 .resume_session( resume_config(session1.id().clone()) - .with_permission_handler(Arc::new(ApproveAllHandler)).with_tool_handlers(selective_tools(vec![EchoTool::new( + .with_permission_handler(Arc::new(ApproveAllHandler)).with_tools(selective_tools(vec![EchoTool::new( "ephemeral_tool", "input", "EPHEMERAL_", @@ -427,7 +427,7 @@ fn resume_config(session_id: SessionId) -> ResumeSessionConfig { ResumeSessionConfig::new(session_id) .with_github_token(DEFAULT_TEST_TOKEN) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(selective_tools(Vec::new())) + .with_tools(selective_tools(Vec::new())) .with_suppress_resume_event(true) } @@ -455,10 +455,14 @@ fn free_tcp_port() -> u16 { listener.local_addr().expect("local addr").port() } -fn selective_tools(tools: Vec) -> Vec> { +fn selective_tools(tools: Vec) -> Vec { tools .into_iter() - .map(|t| Arc::new(t) as Arc) + .map(|t| { + let name = t.name; + let argument_name = t.argument_name; + EchoTool::tool_definition(name, argument_name).with_handler(Arc::new(t)) + }) .collect() } @@ -520,10 +524,6 @@ impl PermissionHandler for PermissionDecisionHandler { #[async_trait] impl ToolHandler for EchoTool { - fn tool(&self) -> Tool { - EchoTool::tool_definition(self.name, self.argument_name) - } - async fn call( &self, invocation: ToolInvocation, diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs index bf3a015e9..0a782f980 100644 --- a/rust/tests/e2e/pending_work_resume.rs +++ b/rust/tests/e2e/pending_work_resume.rs @@ -43,7 +43,7 @@ async fn should_continue_pending_external_tool_request_after_resume() { let suspended_client = start_external_client(ctx, port).await; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (_release_tx, release_rx) = oneshot::channel(); - let router: Arc = Arc::new(BlockingExternalTool { + let router = Arc::new(BlockingExternalTool { started_tx, release_rx: Mutex::new(Some(release_rx)), }); @@ -52,8 +52,9 @@ async fn should_continue_pending_external_tool_request_after_resume() { SessionConfig::default() .with_github_token(DEFAULT_TEST_TOKEN) .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(vec![router]) - .with_tools([BlockingExternalTool::definition()]), + .with_tools(vec![ + BlockingExternalTool::definition().with_handler(router), + ]), ) .await .expect("create session"); @@ -312,10 +313,6 @@ impl BlockingExternalTool { #[async_trait] impl ToolHandler for BlockingExternalTool { - fn tool(&self) -> Tool { - Self::definition() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let value = invocation .arguments diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs index 5f1cfdabd..ce07bf7f7 100644 --- a/rust/tests/e2e/session.rs +++ b/rust/tests/e2e/session.rs @@ -329,16 +329,12 @@ async fn should_create_a_session_with_defaultagent_excludedtools() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(SecretTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); - let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) - .with_tools(tools) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![secret_tool()]) .with_default_agent(DefaultAgentConfig { excluded_tools: Some(vec!["secret_tool".to_string()]), }), @@ -365,16 +361,12 @@ async fn should_create_session_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(SecretNumberTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); - let __perm = Arc::new(ApproveAllHandler); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![secret_number_tool()]), ) .await .expect("create session"); @@ -1524,10 +1516,6 @@ struct SecretNumberTool; #[async_trait::async_trait] impl ToolHandler for SecretNumberTool { - fn tool(&self) -> Tool { - secret_number_tool() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let key = invocation .arguments @@ -1542,22 +1530,23 @@ impl ToolHandler for SecretNumberTool { } } +fn secret_tool() -> Tool { + Tool::new("secret_tool") + .with_description("A secret tool hidden from the default agent") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(SecretTool)) +} + struct SecretTool; #[async_trait::async_trait] impl ToolHandler for SecretTool { - fn tool(&self) -> Tool { - Tool::new("secret_tool") - .with_description("A secret tool hidden from the default agent") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string" } - }, - "required": ["input"] - })) - } - async fn call(&self, _invocation: ToolInvocation) -> Result { Ok(ToolResult::Text("SECRET".to_string())) } @@ -1576,4 +1565,5 @@ fn secret_number_tool() -> Tool { }, "required": ["key"] })) + .with_handler(Arc::new(SecretNumberTool)) } diff --git a/rust/tests/e2e/telemetry.rs b/rust/tests/e2e/telemetry.rs index 7cb38575b..10111be52 100644 --- a/rust/tests/e2e/telemetry.rs +++ b/rust/tests/e2e/telemetry.rs @@ -36,18 +36,22 @@ async fn should_export_file_telemetry_for_sdk_interactions() { )) .await .expect("start client"); - let tool_handlers: Vec> = vec![Arc::new(EchoTelemetryTool { - name: tool_name.to_string(), - })]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); - let __perm = Arc::new(ApproveAllHandler); + let echo_tool = Tool::new(tool_name) + .with_description("Echoes a marker string for telemetry validation.") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + .with_handler(Arc::new(EchoTelemetryTool)); let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) - .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) - .with_tools(tools), + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_tools(vec![echo_tool]), ) .await .expect("create session"); @@ -135,24 +139,10 @@ async fn should_export_file_telemetry_for_sdk_interactions() { .await; } -struct EchoTelemetryTool { - name: String, -} +struct EchoTelemetryTool; #[async_trait] impl ToolHandler for EchoTelemetryTool { - fn tool(&self) -> Tool { - Tool::new(&self.name) - .with_description("Echoes a marker string for telemetry validation.") - .with_parameters(json!({ - "type": "object", - "properties": { - "value": { "type": "string" } - }, - "required": ["value"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( invocation diff --git a/rust/tests/e2e/tool_results.rs b/rust/tests/e2e/tool_results.rs index 9b32248d5..4b731c286 100644 --- a/rust/tests/e2e/tool_results.rs +++ b/rust/tests/e2e/tool_results.rs @@ -21,7 +21,7 @@ async fn should_handle_structured_toolresultobject_from_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, WeatherTool).await; + let session = create_tool_session(ctx, &client, weather_tool()).await; let answer = session .send_and_wait("What's the weather in Paris?") @@ -48,7 +48,7 @@ async fn should_handle_tool_result_with_failure_resulttype() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, CheckStatusTool).await; + let session = create_tool_session(ctx, &client, check_status_tool()).await; let answer = session .send_and_wait("Check the status of the service using check_status. If it fails, say 'service is down'.") @@ -76,7 +76,7 @@ async fn should_preserve_tooltelemetry_and_not_stringify_structured_results_for_ Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let session = create_tool_session(ctx, &client, AnalyzeCodeTool).await; + let session = create_tool_session(ctx, &client, analyze_code_tool()).await; let answer = session .send_and_wait("Analyze the file main.ts for issues.") @@ -124,7 +124,7 @@ async fn should_handle_tool_result_with_rejected_resulttype() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (call_tx, mut call_rx) = mpsc::unbounded_channel(); - let session = create_tool_session(ctx, &client, DeployTool { call_tx }).await; + let session = create_tool_session(ctx, &client, deploy_tool(call_tx)).await; let events = session.subscribe(); session @@ -161,7 +161,7 @@ async fn should_handle_tool_result_with_denied_resulttype() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (call_tx, mut call_rx) = mpsc::unbounded_channel(); - let session = create_tool_session(ctx, &client, AccessSecretTool { call_tx }).await; + let session = create_tool_session(ctx, &client, access_secret_tool(call_tx)).await; let events = session.subscribe(); session @@ -188,24 +188,18 @@ async fn should_handle_tool_result_with_denied_resulttype() { .await; } -async fn create_tool_session( +async fn create_tool_session( _ctx: &super::support::E2eContext, client: &github_copilot_sdk::Client, - tool: T, -) -> github_copilot_sdk::session::Session -where - T: ToolHandler + 'static, -{ - let tool_handlers: Vec> = vec![Arc::new(tool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); + tool: Tool, +) -> github_copilot_sdk::session::Session { let __perm = Arc::new(ApproveAllHandler); client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) - .with_tools(tools), + .with_tools(vec![tool]), ) .await .expect("create session") @@ -229,19 +223,20 @@ fn expanded(text: impl Into, result_type: impl Into) -> ToolResu }) } +fn weather_tool() -> Tool { + string_tool( + "get_weather", + "Gets weather for a city", + "city", + "City name", + ) + .with_handler(Arc::new(WeatherTool)) +} + struct WeatherTool; #[async_trait::async_trait] impl ToolHandler for WeatherTool { - fn tool(&self) -> Tool { - string_tool( - "get_weather", - "Gets weather for a city", - "city", - "City name", - ) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let city = invocation .arguments @@ -255,14 +250,16 @@ impl ToolHandler for WeatherTool { } } +fn check_status_tool() -> Tool { + Tool::new("check_status") + .with_description("Checks the status of a service") + .with_handler(Arc::new(CheckStatusTool)) +} + struct CheckStatusTool; #[async_trait::async_trait] impl ToolHandler for CheckStatusTool { - fn tool(&self) -> Tool { - Tool::new("check_status").with_description("Checks the status of a service") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let mut result = match expanded("Service unavailable", "failure") { ToolResult::Expanded(result) => result, @@ -273,19 +270,20 @@ impl ToolHandler for CheckStatusTool { } } +fn analyze_code_tool() -> Tool { + string_tool( + "analyze_code", + "Analyzes code for issues", + "file", + "File to analyze", + ) + .with_handler(Arc::new(AnalyzeCodeTool)) +} + struct AnalyzeCodeTool; #[async_trait::async_trait] impl ToolHandler for AnalyzeCodeTool { - fn tool(&self) -> Tool { - string_tool( - "analyze_code", - "Analyzes code for issues", - "file", - "File to analyze", - ) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let file = invocation .arguments @@ -304,16 +302,18 @@ impl ToolHandler for AnalyzeCodeTool { } } +fn deploy_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("deploy_service") + .with_description("Deploys a service") + .with_handler(Arc::new(DeployTool { call_tx })) +} + struct DeployTool { call_tx: mpsc::UnboundedSender<()>, } #[async_trait::async_trait] impl ToolHandler for DeployTool { - fn tool(&self) -> Tool { - Tool::new("deploy_service").with_description("Deploys a service") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); Ok(expanded( @@ -323,16 +323,18 @@ impl ToolHandler for DeployTool { } } +fn access_secret_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("access_secret") + .with_description("Accesses a secret") + .with_handler(Arc::new(AccessSecretTool { call_tx })) +} + struct AccessSecretTool { call_tx: mpsc::UnboundedSender<()>, } #[async_trait::async_trait] impl ToolHandler for AccessSecretTool { - fn tool(&self) -> Tool { - Tool::new("access_secret").with_description("Accesses a secret") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); Ok(expanded( diff --git a/rust/tests/e2e/tools.rs b/rust/tests/e2e/tools.rs index daefcdf91..327058a4d 100644 --- a/rust/tests/e2e/tools.rs +++ b/rust/tests/e2e/tools.rs @@ -47,15 +47,13 @@ async fn invokes_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(EncryptStringTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = Arc::new(ApproveAllHandler); + let tools = vec![encrypt_string_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -81,15 +79,13 @@ async fn handles_tool_calling_errors() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(ErrorTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = Arc::new(ApproveAllHandler); + let tools = vec![error_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -132,15 +128,13 @@ async fn can_receive_and_return_complex_types() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(DbQueryTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = Arc::new(ApproveAllHandler); + let tools = vec![db_query_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -173,15 +167,13 @@ async fn overrides_built_in_tool_with_custom_tool() { Box::pin(async move { ctx.set_default_copilot_user(); let client = ctx.start_client().await; - let tool_handlers: Vec> = vec![Arc::new(CustomGrepTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = Arc::new(ApproveAllHandler); + let tools = vec![custom_grep_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -212,15 +204,13 @@ async fn skippermission_sent_in_tool_definition() { permission_tx, decision: PermissionResult::Denied, }); - let tool_handlers: Vec> = vec![Arc::new(SafeLookupTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = handler; + let tools = vec![safe_lookup_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -264,15 +254,13 @@ async fn invokes_custom_tool_with_permission_handler() { permission_tx, decision: PermissionResult::Approved, }); - let tool_handlers: Vec> = vec![Arc::new(EncryptStringTool)]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = handler; + let tools = vec![encrypt_string_tool()]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -310,16 +298,13 @@ async fn denies_custom_tool_when_permission_denied() { permission_tx, decision: PermissionResult::Denied, }); - let tool_handlers: Vec> = - vec![Arc::new(TrackedEncryptStringTool { call_tx })]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = handler; + let tools = vec![tracked_encrypt_string_tool(call_tx)]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -355,17 +340,16 @@ async fn should_execute_multiple_custom_tools_in_parallel_single_turn() { let client = ctx.start_client().await; let (city_tx, mut city_rx) = mpsc::unbounded_channel(); let (country_tx, mut country_rx) = mpsc::unbounded_channel(); - let tool_handlers: Vec> = vec![Arc::new(LookupCityTool { call_tx: city_tx }), Arc::new(LookupCountryTool { - call_tx: country_tx, - })]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); - let __perm = Arc::new(ApproveAllHandler); + let __perm = Arc::new(ApproveAllHandler); + let tools = vec![ + lookup_city_tool(city_tx), + lookup_country_tool(country_tx), + ]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools), ) .await @@ -403,20 +387,13 @@ async fn should_respect_availabletools_and_excludedtools_combined() { ctx.set_default_copilot_user(); let client = ctx.start_client().await; let (excluded_tx, mut excluded_rx) = mpsc::unbounded_channel(); - let tool_handlers: Vec> = vec![ - Arc::new(AllowedTool), - Arc::new(ExcludedTool { - call_tx: excluded_tx, - }), - ]; - let tools: Vec = tool_handlers.iter().map(|h| h.tool()).collect(); let __perm = Arc::new(ApproveAllHandler); + let tools = vec![allowed_tool(), excluded_tool(excluded_tx)]; let session = client .create_session( SessionConfig::default() .with_github_token(super::support::DEFAULT_TEST_TOKEN) .with_permission_handler(__perm) - .with_tool_handlers(tool_handlers) .with_tools(tools) .with_available_tools(["allowed_tool", "excluded_tool"]) .with_excluded_tools(["excluded_tool"]), @@ -449,23 +426,24 @@ async fn should_respect_availabletools_and_excludedtools_combined() { struct EncryptStringTool; +fn encrypt_string_tool() -> Tool { + Tool::new("encrypt_string") + .with_description("Encrypts a string") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "String to encrypt" + } + }, + "required": ["input"] + })) + .with_handler(Arc::new(EncryptStringTool)) +} + #[async_trait::async_trait] impl ToolHandler for EncryptStringTool { - fn tool(&self) -> Tool { - Tool::new("encrypt_string") - .with_description("Encrypts a string") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": "String to encrypt" - } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let input = invocation .arguments @@ -480,12 +458,24 @@ struct TrackedEncryptStringTool { call_tx: mpsc::UnboundedSender<()>, } +fn tracked_encrypt_string_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("encrypt_string") + .with_description("Encrypts a string") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "String to encrypt" + } + }, + "required": ["input"] + })) + .with_handler(Arc::new(TrackedEncryptStringTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for TrackedEncryptStringTool { - fn tool(&self) -> Tool { - EncryptStringTool.tool() - } - async fn call(&self, invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); EncryptStringTool.call(invocation).await @@ -494,12 +484,14 @@ impl ToolHandler for TrackedEncryptStringTool { struct ErrorTool; +fn error_tool() -> Tool { + Tool::new("get_user_location") + .with_description("Gets the user's location") + .with_handler(Arc::new(ErrorTool)) +} + #[async_trait::async_trait] impl ToolHandler for ErrorTool { - fn tool(&self) -> Tool { - Tool::new("get_user_location").with_description("Gets the user's location") - } - async fn call(&self, _invocation: ToolInvocation) -> Result { Ok(ToolResult::Text( "Failed to execute `get_user_location` tool with arguments: {} due to error: Error: Tool execution failed" @@ -510,21 +502,22 @@ impl ToolHandler for ErrorTool { struct CustomGrepTool; +fn custom_grep_tool() -> Tool { + Tool::new("grep") + .with_description("A custom grep implementation that overrides the built-in") + .with_overrides_built_in_tool(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "Search query" } + }, + "required": ["query"] + })) + .with_handler(Arc::new(CustomGrepTool)) +} + #[async_trait::async_trait] impl ToolHandler for CustomGrepTool { - fn tool(&self) -> Tool { - Tool::new("grep") - .with_description("A custom grep implementation that overrides the built-in") - .with_overrides_built_in_tool(true) - .with_parameters(json!({ - "type": "object", - "properties": { - "query": { "type": "string", "description": "Search query" } - }, - "required": ["query"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let query = invocation .arguments @@ -537,21 +530,22 @@ impl ToolHandler for CustomGrepTool { struct SafeLookupTool; +fn safe_lookup_tool() -> Tool { + Tool::new("safe_lookup") + .with_description("A tool that skips permission") + .with_skip_permission(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "id": { "type": "string", "description": "Lookup ID" } + }, + "required": ["id"] + })) + .with_handler(Arc::new(SafeLookupTool)) +} + #[async_trait::async_trait] impl ToolHandler for SafeLookupTool { - fn tool(&self) -> Tool { - Tool::new("safe_lookup") - .with_description("A tool that skips permission") - .with_skip_permission(true) - .with_parameters(json!({ - "type": "object", - "properties": { - "id": { "type": "string", "description": "Lookup ID" } - }, - "required": ["id"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let id = invocation .arguments @@ -566,20 +560,21 @@ struct LookupCityTool { call_tx: mpsc::UnboundedSender, } +fn lookup_city_tool(call_tx: mpsc::UnboundedSender) -> Tool { + Tool::new("lookup_city") + .with_description("Looks up city information") + .with_parameters(json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "City name" } + }, + "required": ["city"] + })) + .with_handler(Arc::new(LookupCityTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for LookupCityTool { - fn tool(&self) -> Tool { - Tool::new("lookup_city") - .with_description("Looks up city information") - .with_parameters(json!({ - "type": "object", - "properties": { - "city": { "type": "string", "description": "City name" } - }, - "required": ["city"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let city = invocation .arguments @@ -596,20 +591,21 @@ struct LookupCountryTool { call_tx: mpsc::UnboundedSender, } +fn lookup_country_tool(call_tx: mpsc::UnboundedSender) -> Tool { + Tool::new("lookup_country") + .with_description("Looks up country information") + .with_parameters(json!({ + "type": "object", + "properties": { + "country": { "type": "string", "description": "Country name" } + }, + "required": ["country"] + })) + .with_handler(Arc::new(LookupCountryTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for LookupCountryTool { - fn tool(&self) -> Tool { - Tool::new("lookup_country") - .with_description("Looks up country information") - .with_parameters(json!({ - "type": "object", - "properties": { - "country": { "type": "string", "description": "Country name" } - }, - "required": ["country"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let country = invocation .arguments @@ -627,20 +623,21 @@ impl ToolHandler for LookupCountryTool { struct AllowedTool; +fn allowed_tool() -> Tool { + Tool::new("allowed_tool") + .with_description("An allowed tool") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(AllowedTool)) +} + #[async_trait::async_trait] impl ToolHandler for AllowedTool { - fn tool(&self) -> Tool { - Tool::new("allowed_tool") - .with_description("An allowed tool") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string", "description": "Input value" } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let input = invocation .arguments @@ -658,20 +655,21 @@ struct ExcludedTool { call_tx: mpsc::UnboundedSender<()>, } +fn excluded_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool { + Tool::new("excluded_tool") + .with_description("A tool that should be excluded") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + .with_handler(Arc::new(ExcludedTool { call_tx })) +} + #[async_trait::async_trait] impl ToolHandler for ExcludedTool { - fn tool(&self) -> Tool { - Tool::new("excluded_tool") - .with_description("A tool that should be excluded") - .with_parameters(json!({ - "type": "object", - "properties": { - "input": { "type": "string", "description": "Input value" } - }, - "required": ["input"] - })) - } - async fn call(&self, invocation: ToolInvocation) -> Result { let _ = self.call_tx.send(()); let input = invocation @@ -706,31 +704,32 @@ impl PermissionHandler for RecordingPermissionHandler { struct DbQueryTool; -#[async_trait::async_trait] -impl ToolHandler for DbQueryTool { - fn tool(&self) -> Tool { - Tool::new("db_query") - .with_description("Performs a database query") - .with_parameters(json!({ - "type": "object", - "properties": { - "query": { - "type": "object", - "properties": { - "table": { "type": "string" }, - "ids": { - "type": "array", - "items": { "type": "integer" } - }, - "sortAscending": { "type": "boolean" } +fn db_query_tool() -> Tool { + Tool::new("db_query") + .with_description("Performs a database query") + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "table": { "type": "string" }, + "ids": { + "type": "array", + "items": { "type": "integer" } }, - "required": ["table", "ids", "sortAscending"] - } - }, - "required": ["query"] - })) - } + "sortAscending": { "type": "boolean" } + }, + "required": ["table", "ids", "sortAscending"] + } + }, + "required": ["query"] + })) + .with_handler(Arc::new(DbQueryTool)) +} +#[async_trait::async_trait] +impl ToolHandler for DbQueryTool { async fn call(&self, invocation: ToolInvocation) -> Result { let query = invocation.arguments.get("query").expect("query argument"); assert_eq!( diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 916e28f0b..85ec76600 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -2049,12 +2049,6 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { struct RunTestsTool; #[async_trait] impl tool::ToolHandler for RunTestsTool { - fn tool(&self) -> Tool { - Tool::new("run_tests") - .with_description("Run tests") - .with_parameters(serde_json::json!({"type":"object"})) - } - async fn call( &self, invocation: ToolInvocation, @@ -2067,7 +2061,12 @@ async fn external_tool_requested_dispatches_to_handler_and_responds() { } let (_session, mut server) = create_session_pair_with_config(|cfg| { - cfg.with_tool_handlers(vec![Arc::new(RunTestsTool) as Arc]) + cfg.with_tools(vec![ + Tool::new("run_tests") + .with_description("Run tests") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(Arc::new(RunTestsTool)), + ]) }) .await; @@ -2098,12 +2097,6 @@ async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { struct FooTool; #[async_trait] impl tool::ToolHandler for FooTool { - fn tool(&self) -> Tool { - Tool::new("foo") - .with_description("foo") - .with_parameters(serde_json::json!({"type":"object"})) - } - async fn call( &self, _invocation: ToolInvocation, @@ -2113,7 +2106,12 @@ async fn external_tool_broadcast_for_unknown_tool_is_not_responded_to() { } let (_session, mut server) = create_session_pair_with_config(|cfg| { - cfg.with_tool_handlers(vec![Arc::new(FooTool) as Arc]) + cfg.with_tools(vec![ + Tool::new("foo") + .with_description("foo") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(Arc::new(FooTool)), + ]) }) .await; server @@ -3873,12 +3871,6 @@ async fn tool_invocation_carries_trace_context_from_event() { #[async_trait] impl tool::ToolHandler for CapturingTool { - fn tool(&self) -> Tool { - Tool::new("calc") - .with_description("calc") - .with_parameters(serde_json::json!({"type":"object"})) - } - async fn call( &self, invocation: ToolInvocation, @@ -3899,7 +3891,12 @@ async fn tool_invocation_carries_trace_context_from_event() { signal: signal.clone(), }); let (_session, mut server) = create_session_pair_with_config(move |cfg| { - cfg.with_tool_handlers(vec![handler as Arc]) + cfg.with_tools(vec![ + Tool::new("calc") + .with_description("calc") + .with_parameters(serde_json::json!({"type":"object"})) + .with_handler(handler.clone()), + ]) }) .await; diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs index e31021b92..fe720b803 100644 --- a/test/scenarios/tools/custom-agents/rust/src/main.rs +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, define_tool}; +use github_copilot_sdk::tool::define_tool; use github_copilot_sdk::types::{CustomAgentConfig, DefaultAgentConfig, SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -23,7 +23,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; - let analyze_codebase: Arc = Arc::from(define_tool( + let analyze_codebase = define_tool( "analyze-codebase", "Performs deep analysis of the codebase", |_inv, params: AnalyzeParams| async move { @@ -32,9 +32,7 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { params.query ))) }, - )); - - let tools = vec![analyze_codebase.tool()]; + ); let mut researcher = CustomAgentConfig::default(); researcher.name = "researcher".to_string(); @@ -55,14 +53,13 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.tools = Some(tools); config.default_agent = Some(DefaultAgentConfig { excluded_tools: Some(vec!["analyze-codebase".to_string()]), }); config.custom_agents = Some(vec![researcher]); let config = config .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(vec![analyze_codebase]); + .with_tools(vec![analyze_codebase]); let session = client.create_session(config).await?; diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs index 0bfc0b8c4..5d5108724 100644 --- a/test/scenarios/tools/tool-overrides/rust/src/main.rs +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use github_copilot_sdk::handler::ApproveAllHandler; -use github_copilot_sdk::tool::{ToolHandler, define_tool}; +use github_copilot_sdk::tool::define_tool; use github_copilot_sdk::types::{SessionConfig, ToolResult}; use github_copilot_sdk::{Client, ClientOptions}; use schemars::JsonSchema; @@ -23,27 +23,20 @@ async fn main() -> Result<(), github_copilot_sdk::Error> { opts.github_token = std::env::var("GITHUB_TOKEN").ok(); let client = Client::start(opts).await?; - let grep_tool: Arc = Arc::from(define_tool( + let mut grep_tool = define_tool( "grep", "A custom grep implementation that overrides the built-in", |_inv, params: GrepParams| async move { Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {}", params.query))) }, - )); - - let mut tools = vec![grep_tool.tool()]; - for t in tools.iter_mut() { - if t.name == "grep" { - t.overrides_built_in_tool = true; - } - } + ); + grep_tool.overrides_built_in_tool = true; let mut config = SessionConfig::default(); config.model = Some("claude-haiku-4.5".to_string()); - config.tools = Some(tools); let config = config .with_permission_handler(Arc::new(ApproveAllHandler)) - .with_tool_handlers(vec![grep_tool]); + .with_tools(vec![grep_tool]); let session = client.create_session(config).await?; From ed9eedf8f6c9c931f5cd5bdb10e1099a97f750c2 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Fri, 22 May 2026 00:25:25 +0100 Subject: [PATCH 20/22] Rename with_transform to with_system_message_transform; rename transform field to system_message_transform Both the SessionConfig / ResumeSessionConfig builder and the underlying field now use the unambiguous name. The RPC method on the wire (systemMessage.transform) is unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 2 +- rust/src/session.rs | 12 ++++++------ rust/src/tool.rs | 4 +++- rust/src/types.rs | 36 ++++++++++++++++++++++++------------ rust/tests/session_test.rs | 8 ++++---- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/rust/README.md b/rust/README.md index e4dc1a029..c10f5a804 100644 --- a/rust/README.md +++ b/rust/README.md @@ -322,7 +322,7 @@ let session = client .create_session( config .with_permission_handler(handler) - .with_transform(Arc::new(MyTransform)), + .with_system_message_transform(Arc::new(MyTransform)), ) .await?; ``` diff --git a/rust/src/session.rs b/rust/src/session.rs index 0e5377d9e..842d5d732 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -764,13 +764,13 @@ impl Client { /// /// All callbacks (per-event handlers, tool handlers, hooks, transform) /// are configured via [`SessionConfig`] using its `with_*_handler` / - /// `with_tools` / `with_hooks` / `with_transform` builder + /// `with_tools` / `with_hooks` / `with_system_message_transform` builder /// methods. /// /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the /// wire-level `hooks` flag is automatically enabled. /// - /// If [`transform`](SessionConfig::transform) is set, the SDK injects + /// If [`system_message_transform`](SessionConfig::system_message_transform) is set, the SDK injects /// `action: "transform"` sections into the [`SystemMessageConfig`] wire /// format and handles `systemMessage.transform` RPC callbacks during /// the session. @@ -788,7 +788,7 @@ impl Client { if config.hooks_handler.is_some() && config.hooks.is_none() { config.hooks = Some(true); } - if let Some(transforms) = config.transform.clone() { + if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections(&mut config, transforms.as_ref()); } let wire = config.to_wire(session_id.clone()); @@ -824,7 +824,7 @@ impl Client { tools: Arc::new(tool_map), }; let hooks = config.hooks_handler.take(); - let transforms = config.transform.take(); + let transforms = config.system_message_transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); let commands_count = config.commands.as_ref().map_or(0, Vec::len); let has_hooks = hooks.is_some(); @@ -943,7 +943,7 @@ impl Client { if config.hooks_handler.is_some() && config.hooks.is_none() { config.hooks = Some(true); } - if let Some(transforms) = config.transform.clone() { + if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections_resume(&mut config, transforms.as_ref()); } let wire = config.to_wire(); @@ -979,7 +979,7 @@ impl Client { tools: Arc::new(tool_map), }; let hooks = config.hooks_handler.take(); - let transforms = config.transform.take(); + let transforms = config.system_message_transform.take(); let tools_count = config.tools.as_ref().map_or(0, Vec::len); let commands_count = config.commands.as_ref().map_or(0, Vec::len); let has_hooks = hooks.is_some(); diff --git a/rust/src/tool.rs b/rust/src/tool.rs index 3a35dd576..d30686ae0 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -21,7 +21,9 @@ use async_trait::async_trait; pub use schemars::JsonSchema; use crate::Error; -use crate::types::{Tool, ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; +use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; +#[cfg(any(feature = "derive", test))] +use crate::types::Tool; /// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type. /// diff --git a/rust/src/types.rs b/rust/src/types.rs index 991e41623..4637131fa 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1227,9 +1227,9 @@ pub struct SessionConfig { /// System-message transform. When set, the SDK injects the matching /// `action: "transform"` sections into the system message and routes /// `systemMessage.transform` RPC callbacks to it during the session. - /// Use [`with_transform`](Self::with_transform) to install one. + /// Use [`with_system_message_transform`](Self::with_system_message_transform) to install one. #[serde(skip)] - pub transform: Option>, + pub system_message_transform: Option>, } impl std::fmt::Debug for SessionConfig { @@ -1298,7 +1298,10 @@ impl std::fmt::Debug for SessionConfig { "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), ) - .field("transform", &self.transform.as_ref().map(|_| "")) + .field( + "system_message_transform", + &self.system_message_transform.as_ref().map(|_| ""), + ) .finish() } } @@ -1349,7 +1352,7 @@ impl Default for SessionConfig { auto_mode_switch_handler: None, hooks_handler: None, permission_policy: None, - transform: None, + system_message_transform: None, } } } @@ -1471,8 +1474,11 @@ impl SessionConfig { /// Install a [`SystemMessageTransform`]. The SDK injects the matching /// `action: "transform"` sections into the system message and routes /// `systemMessage.transform` RPC callbacks to it during the session. - pub fn with_transform(mut self, transform: Arc) -> Self { - self.transform = Some(transform); + pub fn with_system_message_transform( + mut self, + transform: Arc, + ) -> Self { + self.system_message_transform = Some(transform); self } @@ -1857,9 +1863,9 @@ pub struct ResumeSessionConfig { /// Permission policy. See `SessionConfig::permission_policy`. #[serde(skip)] pub(crate) permission_policy: Option, - /// System-message transform. See [`SessionConfig::transform`]. + /// System-message transform. See [`SessionConfig::system_message_transform`]. #[serde(skip)] - pub transform: Option>, + pub system_message_transform: Option>, } impl std::fmt::Debug for ResumeSessionConfig { @@ -1926,7 +1932,10 @@ impl std::fmt::Debug for ResumeSessionConfig { "hooks_handler", &self.hooks_handler.as_ref().map(|_| ""), ) - .field("transform", &self.transform.as_ref().map(|_| "")) + .field( + "system_message_transform", + &self.system_message_transform.as_ref().map(|_| ""), + ) .field("suppress_resume_event", &self.suppress_resume_event) .field("continue_pending_work", &self.continue_pending_work) .finish() @@ -2030,7 +2039,7 @@ impl ResumeSessionConfig { auto_mode_switch_handler: None, hooks_handler: None, permission_policy: None, - transform: None, + system_message_transform: None, } } @@ -2075,8 +2084,11 @@ impl ResumeSessionConfig { } /// Install a [`SystemMessageTransform`]. - pub fn with_transform(mut self, transform: Arc) -> Self { - self.transform = Some(transform); + pub fn with_system_message_transform( + mut self, + transform: Arc, + ) -> Self { + self.system_message_transform = Some(transform); self } diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 85ec76600..ffdb894eb 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -2564,7 +2564,7 @@ async fn hooks_invoke_returns_empty_for_unregistered_hook() { assert_eq!(response["result"]["output"], serde_json::json!({})); } -async fn create_session_pair_with_transforms( +async fn create_session_pair_with_system_message_transforms( transforms: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); @@ -2579,7 +2579,7 @@ async fn create_session_pair_with_transforms( let client = client.clone(); async move { client - .create_session(SessionConfig::default().with_transform(transforms)) + .create_session(SessionConfig::default().with_system_message_transform(transforms)) .await .unwrap() } @@ -2626,7 +2626,7 @@ async fn system_message_transform_dispatches_to_transform() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(AppendTransform)).await; + create_session_pair_with_system_message_transforms(Arc::new(AppendTransform)).await; server .send_request( @@ -2671,7 +2671,7 @@ async fn system_message_transform_returns_error_for_missing_sections() { } let (_session, mut server) = - create_session_pair_with_transforms(Arc::new(DummyTransform)).await; + create_session_pair_with_system_message_transforms(Arc::new(DummyTransform)).await; // Send request with no sections parameter server From 97dc4fbaf5dba6edc0b4333093ffddfcfb579c53 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Fri, 22 May 2026 00:30:07 +0100 Subject: [PATCH 21/22] Drop redundant with_tools(declaration-only) calls in multi_client tests The with_tool_handlers/with_tools merge means a single call now both declares the tool and attaches the handler. The leftover second with_tools([tool_definition]) call was overwriting the handler-bearing tools list with declaration-only stubs, causing tool dispatch to break and tests to time out. Also conditionally import Tool in tool.rs (only used with derive/test features) to silence unused-import lint under --features test-support. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/src/tool.rs | 2 +- rust/tests/e2e/multi_client.rs | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/rust/src/tool.rs b/rust/src/tool.rs index d30686ae0..95ac16d68 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -21,9 +21,9 @@ use async_trait::async_trait; pub use schemars::JsonSchema; use crate::Error; -use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; #[cfg(any(feature = "derive", test))] use crate::types::Tool; +use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded}; /// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type. /// diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs index 23552954b..836121644 100644 --- a/rust/tests/e2e/multi_client.rs +++ b/rust/tests/e2e/multi_client.rs @@ -41,7 +41,6 @@ async fn both_clients_see_tool_request_and_completion_events() { "MAGIC_", "_42", )])) - .with_tools([EchoTool::tool_definition("magic_number", "seed")]) .with_available_tools(["magic_number"]), ) .await @@ -293,7 +292,6 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { "CITY_FOR_", "", )])) - .with_tools([EchoTool::tool_definition("city_lookup", "countryCode")]) .with_available_tools(["city_lookup", "currency_lookup"]), ) .await @@ -308,7 +306,6 @@ async fn two_clients_register_different_tools_and_agent_uses_both() { "CURRENCY_FOR_", "", )])) - .with_tools([EchoTool::tool_definition("currency_lookup", "countryCode")]) .with_available_tools(["city_lookup", "currency_lookup"]), ) .await @@ -361,7 +358,6 @@ async fn disconnecting_client_removes_its_tools() { "STABLE_", "", )])) - .with_tools([EchoTool::tool_definition("stable_tool", "input")]) .with_available_tools(["stable_tool", "ephemeral_tool"]), ) .await @@ -376,7 +372,6 @@ async fn disconnecting_client_removes_its_tools() { "EPHEMERAL_", "", )])) - .with_tools([EchoTool::tool_definition("ephemeral_tool", "input")]) .with_available_tools(["stable_tool", "ephemeral_tool"]), ) .await From bc508db998c448b9263d4a739fc1e861577d004d Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Fri, 22 May 2026 00:32:51 +0100 Subject: [PATCH 22/22] Fix README elicitation example to use ElicitationRequest (not ElicitationRequestData) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rust/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/README.md b/rust/README.md index c10f5a804..2b3a5423c 100644 --- a/rust/README.md +++ b/rust/README.md @@ -443,7 +443,7 @@ To opt your client into receiving `elicitation.requested` broadcasts, install an ```rust,ignore use async_trait::async_trait; use github_copilot_sdk::handler::{ElicitationHandler, ElicitationResult}; -use github_copilot_sdk::types::{ElicitationRequestData, RequestId, SessionId}; +use github_copilot_sdk::types::{ElicitationRequest, RequestId, SessionId}; struct MyElicitation; @@ -453,7 +453,7 @@ impl ElicitationHandler for MyElicitation { &self, _sid: SessionId, _rid: RequestId, - _data: ElicitationRequestData, + _request: ElicitationRequest, ) -> ElicitationResult { ElicitationResult::cancel() }