diff --git a/giskard/llm/client/bedrock.py b/giskard/llm/client/bedrock.py
index 0295aade9d..6ad052e832 100644
--- a/giskard/llm/client/bedrock.py
+++ b/giskard/llm/client/bedrock.py
@@ -1,6 +1,7 @@
-from typing import List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import json
+from abc import ABC, abstractmethod
 
 from ..config import LLMConfigurationError
 from ..errors import LLMImportError
@@ -15,43 +16,66 @@
     ) from err
 
 
-def _format_messages_claude(messages: Sequence[ChatMessage]):
-    input_msg_prompt: List = []
-    system_prompts = []
+class BaseBedrockClient(LLMClient, ABC):
+    def __init__(self, bedrock_runtime_client, model: str):
+        self._client = bedrock_runtime_client
+        self.model = model
 
-    for msg in messages:
-        # System prompt is a specific parameter in Claude
-        if msg.role.lower() == "system":
-            system_prompts.append(msg.content)
-            continue
+    @abstractmethod
+    def _format_body(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> Dict:
+        ...
 
-        # Only role user and assistant are allowed
-        role = msg.role.lower()
-        role = role if role in ["assistant", "user"] else "user"
+    @abstractmethod
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
+        ...
 
-        # Consecutive messages need to be grouped
-        last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
-        if last_message is not None and last_message["role"] == role:
-            last_message["content"].append({"type": "text", "text": msg.content})
-            continue
+    def complete(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> ChatMessage:
+        # create the json body to send to the API
+        body = self._format_body(messages, temperature, max_tokens, caller_id, seed, format)
 
-        input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+        # invoke the model and get the response
+        try:
+            accept = "application/json"
+            contentType = "application/json"
+            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
+            completion = json.loads(response.get("body").read())
+        except RuntimeError as err:
+            raise LLMConfigurationError("Could not get response from Bedrock API") from err
 
-    return input_msg_prompt, "\n".join(system_prompts)
+        return self._parse_completion(completion, caller_id)
 
 
-class ClaudeBedrockClient(LLMClient):
+class ClaudeBedrockClient(BaseBedrockClient):
     def __init__(
         self,
         bedrock_runtime_client,
         model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
         anthropic_version: str = "bedrock-2023-05-31",
     ):
-        self._client = bedrock_runtime_client
-        self.model = model
+        # only supporting claude 3
+        if "claude-3" not in model:
+            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {model}")
+
+        super().__init__(bedrock_runtime_client, model)
         self.anthropic_version = anthropic_version
 
-    def complete(
+    def _format_body(
         self,
         messages: Sequence[ChatMessage],
         temperature: float = 1,
@@ -59,36 +83,42 @@ def complete(
         caller_id: Optional[str] = None,
         seed: Optional[int] = None,
         format=None,
-    ) -> ChatMessage:
-        # only supporting claude 3 to start
-        if "claude-3" not in self.model:
-            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")
+    ) -> Dict:
+        input_msg_prompt: List = []
+        system_prompts = []
 
-        messages, system = _format_messages_claude(messages)
+        for msg in messages:
+            # System prompt is a specific parameter in Claude
+            if msg.role.lower() == "system":
+                system_prompts.append(msg.content)
+                continue
 
-        # create the json body to send to the API
-        body = json.dumps(
+            # Only role user and assistant are allowed
+            role = msg.role.lower()
+            role = role if role in ["assistant", "user"] else "user"
+
+            # Consecutive messages need to be grouped
+            last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
+            if last_message is not None and last_message["role"] == role:
+                last_message["content"].append({"type": "text", "text": msg.content})
+                continue
+
+            input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+
+        return json.dumps(
             {
                 "anthropic_version": "bedrock-2023-05-31",
                 "max_tokens": max_tokens,
                 "temperature": temperature,
-                "system": system,
-                "messages": messages,
+                "system": "\n".join(system_prompts),
+                "messages": input_msg_prompt,
             }
         )
 
-        # invoke the model and get the response
-        try:
-            accept = "application/json"
-            contentType = "application/json"
-            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
-            completion = json.loads(response.get("body").read())
-        except RuntimeError as err:
-            raise LLMConfigurationError("Could not get response from Bedrock API") from err
-
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
         self.logger.log_call(
             prompt_tokens=completion["usage"]["input_tokens"],
-            sampled_tokens=completion["usage"]["input_tokens"],
+            sampled_tokens=completion["usage"]["output_tokens"],
             model=self.model,
             client_class=self.__class__.__name__,
             caller_id=caller_id,
@@ -96,3 +126,48 @@ def complete(
 
         msg = completion["content"][0]["text"]
         return ChatMessage(role="assistant", content=msg)
+
+
+class LLamaBedrockClient(BaseBedrockClient):
+    def __init__(self, bedrock_runtime_client, model: str = "meta.llama3-8b-instruct-v1:0"):
+        # only supporting llama
+        if "llama" not in model:
+            raise LLMConfigurationError(f"Only Llama models are supported as of now, got {model}")
+
+        super().__init__(bedrock_runtime_client, model)
+
+    def _format_body(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> Dict:
+        # Create the messages format needed for llama bedrock specifically
+        prompts = []
+        for msg in messages:
+            prompts.append(f"# {msg.role}:\n{msg.content}\n")
+
+        # create the json body to send to the API
+        messages = "\n".join(prompts)
+        return json.dumps(
+            {
+                "max_gen_len": max_tokens,
+                "temperature": temperature,
+                "prompt": f"{messages}\n# assistant:\n",
+            }
+        )
+
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
+        self.logger.log_call(
+            prompt_tokens=completion["prompt_token_count"],
+            sampled_tokens=completion["generation_token_count"],
+            model=self.model,
+            client_class=self.__class__.__name__,
+            caller_id=caller_id,
+        )
+
+        msg = completion["generation"]
+        return ChatMessage(role="assistant", content=msg)
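
For reference, a minimal usage sketch of the two clients touched by this diff. It is illustrative only: it assumes `boto3` is installed and configured with credentials that have Bedrock access, the region name is a placeholder, and the `ChatMessage` import path is inferred from the package layout rather than taken from this patch.

```python
import boto3

from giskard.llm.client import ChatMessage  # assumed export location of ChatMessage
from giskard.llm.client.bedrock import ClaudeBedrockClient, LLamaBedrockClient

# boto3 Bedrock runtime client; the region is illustrative and must host the chosen models
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

# Claude 3 client (default model: anthropic.claude-3-sonnet-20240229-v1:0)
claude = ClaudeBedrockClient(bedrock_runtime)

# Llama client (default model: meta.llama3-8b-instruct-v1:0)
llama = LLamaBedrockClient(bedrock_runtime)

# Both clients share the complete() interface implemented once in BaseBedrockClient
answer = llama.complete([ChatMessage(role="user", content="Say hello.")], temperature=0.7)
print(answer.content)
```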