159 changes: 117 additions & 42 deletions giskard/llm/client/bedrock.py
@@ -1,6 +1,7 @@
-from typing import List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import json
+from abc import ABC, abstractmethod
 
 from ..config import LLMConfigurationError
 from ..errors import LLMImportError
@@ -15,84 +16,158 @@
     ) from err
 
 
-def _format_messages_claude(messages: Sequence[ChatMessage]):
-    input_msg_prompt: List = []
-    system_prompts = []
+class BaseBedrockClient(LLMClient, ABC):
+    def __init__(self, bedrock_runtime_client, model: str):
+        self._client = bedrock_runtime_client
+        self.model = model
 
-    for msg in messages:
-        # System prompt is a specific parameter in Claude
-        if msg.role.lower() == "system":
-            system_prompts.append(msg.content)
-            continue
+    @abstractmethod
+    def _format_body(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> Dict:
+        ...
 
-        # Only role user and assistant are allowed
-        role = msg.role.lower()
-        role = role if role in ["assistant", "user"] else "user"
+    @abstractmethod
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
+        ...
 
-        # Consecutive messages need to be grouped
-        last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
-        if last_message is not None and last_message["role"] == role:
-            last_message["content"].append({"type": "text", "text": msg.content})
-            continue
+    def complete(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> ChatMessage:
+        # create the json body to send to the API
+        body = self._format_body(messages, temperature, max_tokens, caller_id, seed, format)
 
-        input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+        # invoke the model and get the response
+        try:
+            accept = "application/json"
+            contentType = "application/json"
+            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
+            completion = json.loads(response.get("body").read())
+        except RuntimeError as err:
+            raise LLMConfigurationError("Could not get response from Bedrock API") from err
 
-    return input_msg_prompt, "\n".join(system_prompts)
+        return self._parse_completion(completion, caller_id)


-class ClaudeBedrockClient(LLMClient):
+class ClaudeBedrockClient(BaseBedrockClient):
     def __init__(
         self,
         bedrock_runtime_client,
         model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
         anthropic_version: str = "bedrock-2023-05-31",
     ):
-        self._client = bedrock_runtime_client
-        self.model = model
+        # only supporting claude 3
+        if "claude-3" not in model:
+            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {model}")
+
+        super().__init__(bedrock_runtime_client, model)
         self.anthropic_version = anthropic_version

-    def complete(
+    def _format_body(
         self,
         messages: Sequence[ChatMessage],
         temperature: float = 1,
         max_tokens: Optional[int] = 1000,
         caller_id: Optional[str] = None,
         seed: Optional[int] = None,
         format=None,
-    ) -> ChatMessage:
-        # only supporting claude 3 to start
-        if "claude-3" not in self.model:
-            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")
+    ) -> Dict:
+        input_msg_prompt: List = []
+        system_prompts = []
 
-        messages, system = _format_messages_claude(messages)
+        for msg in messages:
+            # System prompt is a specific parameter in Claude
+            if msg.role.lower() == "system":
+                system_prompts.append(msg.content)
+                continue
 
-        # create the json body to send to the API
-        body = json.dumps(
+            # Only role user and assistant are allowed
+            role = msg.role.lower()
+            role = role if role in ["assistant", "user"] else "user"
+
+            # Consecutive messages need to be grouped
+            last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
+            if last_message is not None and last_message["role"] == role:
+                last_message["content"].append({"type": "text", "text": msg.content})
+                continue
+
+            input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+
+        return json.dumps(
             {
                 "anthropic_version": "bedrock-2023-05-31",
                 "max_tokens": max_tokens,
                 "temperature": temperature,
-                "system": system,
-                "messages": messages,
+                "system": "\n".join(system_prompts),
+                "messages": input_msg_prompt,
             }
         )

-        # invoke the model and get the response
-        try:
-            accept = "application/json"
-            contentType = "application/json"
-            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
-            completion = json.loads(response.get("body").read())
-        except RuntimeError as err:
-            raise LLMConfigurationError("Could not get response from Bedrock API") from err
-
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
         self.logger.log_call(
             prompt_tokens=completion["usage"]["input_tokens"],
-            sampled_tokens=completion["usage"]["input_tokens"],
+            sampled_tokens=completion["usage"]["output_tokens"],
             model=self.model,
             client_class=self.__class__.__name__,
             caller_id=caller_id,
         )
 
         msg = completion["content"][0]["text"]
         return ChatMessage(role="assistant", content=msg)
+
+
+class LLamaBedrockClient(BaseBedrockClient):
+    def __init__(self, bedrock_runtime_client, model: str = "meta.llama3-8b-instruct-v1:0"):
+        # only supporting llama
+        if "llama" not in model:
+            raise LLMConfigurationError(f"Only Llama models are supported as of now, got {model}")
+
+        super().__init__(bedrock_runtime_client, model)
+
+    def _format_body(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> Dict:
+        # Create the messages format needed for llama bedrock specifically
+        prompts = []
+        for msg in messages:
+            prompts.append(f"# {msg.role}:\n{msg.content}\n")
+
+        # create the json body to send to the API
+        messages = "\n".join(prompts)
+        return json.dumps(
+            {
+                "max_gen_len": max_tokens,
+                "temperature": temperature,
+                "prompt": f"{messages}\n# assistant:\n",
+            }
+        )
+
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
+        self.logger.log_call(
+            prompt_tokens=completion["prompt_token_count"],
+            sampled_tokens=completion["generation_token_count"],
+            model=self.model,
+            client_class=self.__class__.__name__,
+            caller_id=caller_id,
+        )
+
+        msg = completion["generation"]
+        return ChatMessage(role="assistant", content=msg)