diff --git a/pyproject.toml b/pyproject.toml index 65df37a..342d4d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,17 @@ build-backend = "setuptools.build_meta" [project.scripts] mcp-proxy = "mcp_proxy.__main__:main" + +[tool.uv] +dev-dependencies = [ + "pytest>=8.3.3", + "pytest-asyncio>=0.25.0", +] + +[tool.pytest.ini_options] +pythonpath = "src" +addopts = [ + "--import-mode=importlib", +] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/src/mcp_proxy/__init__.py b/src/mcp_proxy/__init__.py index 0b98f49..8c95f43 100644 --- a/src/mcp_proxy/__init__.py +++ b/src/mcp_proxy/__init__.py @@ -7,79 +7,89 @@ logger = logging.getLogger(__name__) -async def confugure_app(name: str, remote_app: ClientSession): - app = server.Server(name) +async def create_proxy_server(remote_app: ClientSession): + """Create a server instance from a remote app.""" - async def _list_prompts(_: t.Any) -> types.ServerResult: - result = await remote_app.list_prompts() - return types.ServerResult(result) + response = await remote_app.initialize() + capabilities = response.capabilities - app.request_handlers[types.ListPromptsRequest] = _list_prompts + app = server.Server(response.serverInfo.name) - async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult: - result = await remote_app.get_prompt(req.params.name, req.params.arguments) - return types.ServerResult(result) + if capabilities.prompts: + async def _list_prompts(_: t.Any) -> types.ServerResult: + result = await remote_app.list_prompts() + return types.ServerResult(result) - app.request_handlers[types.GetPromptRequest] = _get_prompt + app.request_handlers[types.ListPromptsRequest] = _list_prompts - async def _list_resources(_: t.Any) -> types.ServerResult: - result = await remote_app.list_resources() - return types.ServerResult(result) + async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult: + result = await remote_app.get_prompt(req.params.name, req.params.arguments) + return types.ServerResult(result) - app.request_handlers[types.ListResourcesRequest] = _list_resources + app.request_handlers[types.GetPromptRequest] = _get_prompt - # list_resource_templates() is not implemented in the client - # async def _list_resource_templates(_: t.Any) -> types.ServerResult: - # result = await remote_app.list_resource_templates() - # return types.ServerResult(result) + if capabilities.resources: + async def _list_resources(_: t.Any) -> types.ServerResult: + result = await remote_app.list_resources() + return types.ServerResult(result) - # app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates + app.request_handlers[types.ListResourcesRequest] = _list_resources - async def _read_resource(req: types.ReadResourceRequest): - result = await remote_app.read_resource(req.params.uri) - return types.ServerResult(result) + # list_resource_templates() is not implemented in the client + # async def _list_resource_templates(_: t.Any) -> types.ServerResult: + # result = await remote_app.list_resource_templates() + # return types.ServerResult(result) - app.request_handlers[types.ReadResourceRequest] = _read_resource + # app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates - async def _set_logging_level(req: types.SetLevelRequest): - await remote_app.set_logging_level(req.params.level) - return types.ServerResult(types.EmptyResult()) + async def _read_resource(req: types.ReadResourceRequest): + result = await remote_app.read_resource(req.params.uri) + return types.ServerResult(result) - app.request_handlers[types.SetLevelRequest] = _set_logging_level + app.request_handlers[types.ReadResourceRequest] = _read_resource - async def _subscribe_resource(req: types.SubscribeRequest): - await remote_app.subscribe_resource(req.params.uri) - return types.ServerResult(types.EmptyResult()) + if capabilities.logging: + async def _set_logging_level(req: types.SetLevelRequest): + await remote_app.set_logging_level(req.params.level) + return types.ServerResult(types.EmptyResult()) - app.request_handlers[types.SubscribeRequest] = _subscribe_resource + app.request_handlers[types.SetLevelRequest] = _set_logging_level - async def _unsubscribe_resource(req: types.UnsubscribeRequest): - await remote_app.unsubscribe_resource(req.params.uri) - return types.ServerResult(types.EmptyResult()) + if capabilities.resources: + async def _subscribe_resource(req: types.SubscribeRequest): + await remote_app.subscribe_resource(req.params.uri) + return types.ServerResult(types.EmptyResult()) - app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource + app.request_handlers[types.SubscribeRequest] = _subscribe_resource - async def _list_tools(_: t.Any): - tools = await remote_app.list_tools() - return types.ServerResult(tools) + async def _unsubscribe_resource(req: types.UnsubscribeRequest): + await remote_app.unsubscribe_resource(req.params.uri) + return types.ServerResult(types.EmptyResult()) - app.request_handlers[types.ListToolsRequest] = _list_tools + app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource - async def _call_tool(req: types.CallToolRequest) -> types.ServerResult: - try: - result = await remote_app.call_tool( - req.params.name, (req.params.arguments or {}) - ) - return types.ServerResult(result) - except Exception as e: - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=str(e))], - isError=True, + if capabilities.tools: + async def _list_tools(_: t.Any): + tools = await remote_app.list_tools() + return types.ServerResult(tools) + + app.request_handlers[types.ListToolsRequest] = _list_tools + + async def _call_tool(req: types.CallToolRequest) -> types.ServerResult: + try: + result = await remote_app.call_tool( + req.params.name, (req.params.arguments or {}) + ) + return types.ServerResult(result) + except Exception as e: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], + isError=True, + ) ) - ) - app.request_handlers[types.CallToolRequest] = _call_tool + app.request_handlers[types.CallToolRequest] = _call_tool async def _send_progress_notification(req: types.ProgressNotification): await remote_app.send_progress_notification( @@ -96,12 +106,7 @@ async def _complete(req: types.CompleteRequest): app.request_handlers[types.CompleteRequest] = _complete - async with server.stdio_server() as (read_stream, write_stream): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) + return app async def run_sse_client(url: str): @@ -109,6 +114,10 @@ async def run_sse_client(url: str): async with sse_client(url=url) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: - response = await session.initialize() - - await confugure_app(response.serverInfo.name, session) + app = await create_proxy_server(session) + async with server.stdio_server() as (read_stream, write_stream): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..56648f1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for mcp-proxy.""" diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..95d0924 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,109 @@ +"""Tests for the mcp-proxy module. + +Tests are running in two modes: +- One where the server is exercised directly though an in memory client, just to + set a baseline for the expected behavior. +- Another where the server is exercised through a proxy server, which forwards + the requests to the original server. + +The same test code is run on both to ensure parity. +""" + +from typing import Any +from collections.abc import AsyncGenerator, Callable +from contextlib import asynccontextmanager, AbstractAsyncContextManager + +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.shared.exceptions import McpError +from mcp.shared.memory import create_connected_server_and_client_session + +from mcp_proxy import create_proxy_server + +TOOL_INPUT_SCHEMA = { + "type": "object", + "properties": { + "input1": {"type": "string"} + } +} + +SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]] + +# Direct server connection +in_memory: SessionContextManager = create_connected_server_and_client_session + +@asynccontextmanager +async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]: + """Create a connection to the server through the proxy server.""" + async with in_memory(server) as session: + wrapped_server = await create_proxy_server(session) + async with in_memory(wrapped_server) as wrapped_session: + yield wrapped_session + + +@pytest.fixture(params=["server", "proxy"], scope="function") +def session_generator(request: Any) -> SessionContextManager: + """Fixture that returns a client creation strategy either direct or using the proxy.""" + if request.param == "server": + return in_memory + return proxy + + +async def test_list_prompts(session_generator: SessionContextManager): + """Test list_prompts.""" + + server = Server("prompt-server") + + @server.list_prompts() + async def list_prompts() -> list[types.Prompt]: + return [types.Prompt(name="prompt1")] + + async with session_generator(server) as session: + result = await session.initialize() + assert result.serverInfo.name == "prompt-server" + assert result.capabilities + assert result.capabilities.prompts + assert not result.capabilities.tools + assert not result.capabilities.resources + assert not result.capabilities.logging + + result = await session.list_prompts() + assert result.prompts == [types.Prompt(name="prompt1")] + + with pytest.raises(McpError, match="Method not found"): + await session.list_tools() + + +async def test_list_tools(session_generator: SessionContextManager): + """Test list_tools.""" + + server = Server("tools-server") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool( + name="tool-name", + description="tool-description", + inputSchema=TOOL_INPUT_SCHEMA + )] + + async with session_generator(server) as session: + result = await session.initialize() + assert result.serverInfo.name == "tools-server" + assert result.capabilities + assert result.capabilities.tools + assert not result.capabilities.prompts + assert not result.capabilities.resources + assert not result.capabilities.logging + + result = await session.list_tools() + assert len(result.tools) == 1 + assert result.tools[0].name == "tool-name" + assert result.tools[0].description == "tool-description" + assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA + + with pytest.raises(McpError, match="Method not found"): + await session.list_prompts() diff --git a/uv.lock b/uv.lock index 5c0054f..d1b7e87 100644 --- a/uv.lock +++ b/uv.lock @@ -33,6 +33,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, ] +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -88,6 +97,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, ] +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + [[package]] name = "mcp" version = "1.1.2" @@ -113,9 +131,39 @@ dependencies = [ { name = "mcp" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + [package.metadata] requires-dist = [{ name = "mcp" }] +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.25.0" }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + [[package]] name = "pydantic" version = "2.10.4" @@ -183,6 +231,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/b2/b2b50d5ecf21acf870190ae5d093602d95f66c9c31f9d5de6062eb329ad1/pydantic_core-2.27.2-cp313-cp313-win_arm64.whl", hash = "sha256:ac4dbfd1691affb8f48c2c13241a2e3b60ff23247cbcf981759c768b6633cf8b", size = 1885186 }, ] +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + +[[package]] +name = "pytest-asyncio" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/18/82fcb4ee47d66d99f6cd1efc0b11b2a25029f303c599a5afda7c1bca4254/pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609", size = 53298 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/56/2ee0cab25c11d4e38738a2a98c645a8f002e2ecf7b5ed774c70d53b92bb1/pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3", size = 19245 }, +] + [[package]] name = "sniffio" version = "1.3.1"