diff --git a/giskard/llm/client/mistral.py b/giskard/llm/client/mistral.py
index 3aa0529610..04b190280c 100644
--- a/giskard/llm/client/mistral.py
+++ b/giskard/llm/client/mistral.py
@@ -1,5 +1,6 @@
 from typing import Optional, Sequence
 
+import os
 from dataclasses import asdict
 from logging import warning
 
@@ -9,8 +10,7 @@
 from .base import ChatMessage
 
 try:
-    from mistralai.client import MistralClient as _MistralClient
-    from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
+    from mistralai import Mistral
 except ImportError as err:
     raise LLMImportError(
         flavor="llm", msg="To use Mistral models, please install the `mistralai` package with `pip install mistralai`"
@@ -18,9 +18,9 @@
 
 
 class MistralClient(LLMClient):
-    def __init__(self, model: str = "mistral-large-latest", client: _MistralClient = None):
+    def __init__(self, model: str = "mistral-large-latest", client: Mistral = None):
         self.model = model
-        self._client = client or _MistralClient()
+        self._client = client or Mistral(api_key=os.getenv("MISTRAL_API_KEY", ""))
 
     def complete(
         self,
@@ -43,9 +43,9 @@ def complete(
             extra_params["response_format"] = {"type": "json_object"}
 
         try:
-            completion = self._client.chat(
+            completion = self._client.chat.complete(
                 model=self.model,
-                messages=[MistralChatMessage(**asdict(m)) for m in messages],
+                messages=[asdict(m) for m in messages],
                 temperature=temperature,
                 max_tokens=max_tokens,
                 **extra_params,
diff --git a/pdm.lock b/pdm.lock
index 4b14431235..688a1ed257 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev", "doc", "llm", "ml_runtime", "talk", "tensorflow", "test"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:f1f61e224e92ce1baaff8986cbbc1ba4ec6850a28672045c1d333aaceb2153ec"
+content_hash = "sha256:22fea39c3c4fb7d259beaf5e9868b34ccbd9201689b2f1b6d5bb57e2bea6cc69"
 
 [[metadata.targets]]
 requires_python = ">=3.9,<3.13"
@@ -2642,6 +2642,17 @@ files = [
     {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
 ]
 
+[[package]]
+name = "jsonpath-python"
+version = "1.0.6"
+requires_python = ">=3.6"
+summary = "A more powerful JSONPath implementation in modern python"
+groups = ["dev"]
+files = [
+    {file = "jsonpath-python-1.0.6.tar.gz", hash = "sha256:dd5be4a72d8a2995c3f583cf82bf3cd1a9544cfdabf2d22595b67aff07349666"},
+    {file = "jsonpath_python-1.0.6-py3-none-any.whl", hash = "sha256:1e3b78df579f5efc23565293612decee04214609208a2335884b3ee3f786b575"},
+]
+
 [[package]]
 name = "jsonpointer"
 version = "3.0.0"
@@ -3624,18 +3635,20 @@ files = [
 
 [[package]]
 name = "mistralai"
-version = "0.4.2"
-requires_python = "<4.0,>=3.9"
-summary = ""
+version = "1.0.3"
+requires_python = "<4.0,>=3.8"
+summary = "Python Client SDK for the Mistral AI API."
 groups = ["dev"]
 dependencies = [
-    "httpx<1,>=0.25",
-    "orjson<3.11,>=3.9.10",
-    "pydantic<3,>=2.5.2",
+    "httpx<0.28.0,>=0.27.0",
+    "jsonpath-python<2.0.0,>=1.0.6",
+    "pydantic<2.9.0,>=2.8.2",
+    "python-dateutil<3.0.0,>=2.9.0.post0",
+    "typing-inspect<0.10.0,>=0.9.0",
 ]
 files = [
-    {file = "mistralai-0.4.2-py3-none-any.whl", hash = "sha256:63c98eea139585f0a3b2c4c6c09c453738bac3958055e6f2362d3866e96b0168"},
-    {file = "mistralai-0.4.2.tar.gz", hash = "sha256:5eb656710517168ae053f9847b0bb7f617eda07f1f93f946ad6c91a4d407fd93"},
+    {file = "mistralai-1.0.3-py3-none-any.whl", hash = "sha256:64af7c9192e64dc66b2da6d1c4d54a1324a881c21665a2f93d6b35d9de9f87c8"},
+    {file = "mistralai-1.0.3.tar.gz", hash = "sha256:84f1a217666c76fec9d477ae266399b813c3ac32a4a348d2ecd5fe1c039b0667"},
 ]
 
 [[package]]
@@ -4690,7 +4703,7 @@ name = "orjson"
 version = "3.10.7"
 requires_python = ">=3.8"
 summary = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
-groups = ["dev", "ml_runtime", "test"]
+groups = ["ml_runtime", "test"]
 files = [
     {file = "orjson-3.10.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:74f4544f5a6405b90da8ea724d15ac9c36da4d72a738c64685003337401f5c12"},
     {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34a566f22c28222b08875b18b0dfbf8a947e69df21a9ed5c51a6bf91cfb944ac"},
@@ -7652,7 +7665,7 @@ files = [
 name = "typing-inspect"
 version = "0.9.0"
 summary = "Runtime inspection utilities for typing module."
-groups = ["ml_runtime", "test"]
+groups = ["dev", "ml_runtime", "test"]
 dependencies = [
     "mypy-extensions>=0.3.0",
     "typing-extensions>=3.7.4",
diff --git a/pyproject.toml b/pyproject.toml
index 40a5801850..9c7e07c235 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,7 @@ dev = [
     "pytest-asyncio>=0.21.1",
     "pydantic>=2",
     "avidtools",
-    "mistralai>=0.1.8, <1",
+    "mistralai>=1",
     "boto3>=1.34.88",
     "scikit-learn==1.4.2",
 ]
diff --git a/tests/llm/test_llm_client.py b/tests/llm/test_llm_client.py
index 86ea12020e..428c0d4dbe 100644
--- a/tests/llm/test_llm_client.py
+++ b/tests/llm/test_llm_client.py
@@ -4,9 +4,6 @@
 import pydantic
 import pytest
 from google.generativeai.types import ContentDict
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionResponseChoice
-from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
-from mistralai.models.chat_completion import FinishReason, UsageInfo
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
@@ -35,22 +32,6 @@
 )
 
 
-DEMO_MISTRAL_RESPONSE = ChatCompletionResponse(
-    id="2d62260a7a354e02922a4f6ad36930d3",
-    object="chat.completion",
-    created=1630000000,
-    model="mistral-large",
-    choices=[
-        ChatCompletionResponseChoice(
-            index=0,
-            message=MistralChatMessage(role="assistant", content="This is a test!", name=None, tool_calls=None),
-            finish_reason=FinishReason.stop,
-        )
-    ],
-    usage=UsageInfo(prompt_tokens=9, total_tokens=89, completion_tokens=80),
-)
-
-
 def test_llm_complete_message():
     client = Mock()
     client.chat.completions.create.return_value = DEMO_OPENAI_RESPONSE
@@ -69,8 +50,25 @@
 
 @pytest.mark.skipif(not PYDANTIC_V2, reason="Mistral raise an error with pydantic < 2")
 def test_mistral_client():
+    from mistralai.models import ChatCompletionChoice, ChatCompletionResponse, UsageInfo
+
+    demo_response = ChatCompletionResponse(
+        id="2d62260a7a354e02922a4f6ad36930d3",
+        object="chat.completion",
+        created=1630000000,
+        model="mistral-large",
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message={"role": "assistant", "content": "This is a test!"},
+                finish_reason="stop",
+            )
+        ],
+        usage=UsageInfo(prompt_tokens=9, total_tokens=89, completion_tokens=80),
+    )
+
     client = Mock()
-    client.chat.return_value = DEMO_MISTRAL_RESPONSE
+    client.chat.complete.return_value = demo_response
 
     from giskard.llm.client.mistral import MistralClient
 
@@ -78,10 +76,10 @@ def test_mistral_client():
         [ChatMessage(role="user", content="Hello")], temperature=0.11, max_tokens=12
     )
 
-    client.chat.assert_called_once()
-    assert client.chat.call_args[1]["messages"] == [MistralChatMessage(role="user", content="Hello")]
-    assert client.chat.call_args[1]["temperature"] == 0.11
-    assert client.chat.call_args[1]["max_tokens"] == 12
+    client.chat.complete.assert_called_once()
+    assert client.chat.complete.call_args[1]["messages"] == [{"role": "user", "content": "Hello"}]
+    assert client.chat.complete.call_args[1]["temperature"] == 0.11
+    assert client.chat.complete.call_args[1]["max_tokens"] == 12
 
     assert isinstance(res, ChatMessage)
     assert res.content == "This is a test!"