Merged
18 changes: 12 additions & 6 deletions rdagent/oai/backend/base.py
@@ -11,7 +11,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, cast
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast

import pytz
from pydantic import BaseModel, TypeAdapter
@@ -528,12 +528,16 @@ def _create_chat_completion_auto_continue(
seed: Optional[int] = None,
json_target_type: Optional[str] = None,
add_json_in_prompt: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
**kwargs: Any,
) -> str:
"""
Call the chat completion function and automatically continue the conversation if the finish_reason is length.
"""

if response_format is None and json_mode:
response_format = {"type": "json_object"}

# 0) return directly if cache is hit
if seed is None and LLM_SETTINGS.use_auto_chat_cache_seed_gen:
seed = LLM_CACHE_SEED_GEN.get_next_seed()
@@ -555,11 +559,11 @@ def _create_chat_completion_auto_continue(
# Loop to get a full response
try_n = 6
for _ in range(try_n): # for some long code, 3 attempts may not be enough for reasoning models
if json_mode and add_json_in_prompt:
if response_format and add_json_in_prompt:
self._add_json_in_prompt(new_messages)
response, finish_reason = self._create_chat_completion_inner_function(
messages=new_messages,
json_mode=json_mode,
response_format=response_format,
**kwargs,
)
all_response += response
@@ -575,17 +579,19 @@ def _create_chat_completion_auto_continue(
_, all_response = match.groups() if match else ("", all_response)

# 3) format checking
if json_mode or json_target_type:
if response_format or json_target_type:
parser = JSONParser()
all_response = parser.parse(all_response)
if json_target_type:
# deepseek will enter this branch
TypeAdapter(json_target_type).validate_json(all_response)

if (response_format := kwargs.get("response_format")) is not None:
if response_format is not None:
if not isinstance(response_format, dict) and issubclass(response_format, BaseModel):
# It may raise TypeError if initialization fails
response_format(**json.loads(all_response))
elif response_format == {"type": "json_object"}:
logger.info(f"Using OpenAI response format: {response_format}")
else:
logger.warning(f"Unknown response_format: {response_format}, skipping validation.")
if self.dump_chat_cache:
@@ -642,7 +648,7 @@ def _create_embedding_inner_function( # type: ignore[no-untyped-def]
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
*args,
**kwargs,
) -> tuple[str, str | None]:
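Reviewer note: with this change, `_create_chat_completion_auto_continue` accepts `response_format` as either the OpenAI-style dict `{"type": "json_object"}` or a Pydantic `BaseModel` subclass, defaults it from the legacy `json_mode` flag, and validates the assembled response against it. A minimal sketch of that validation step in isolation; `ExperimentPlan` is a hypothetical schema used only for illustration, not part of this PR:

```python
import json
from typing import Optional, Type, Union

from pydantic import BaseModel


class ExperimentPlan(BaseModel):  # hypothetical schema, only for illustration
    hypothesis: str
    code: str


def check_response(all_response: str, response_format: Optional[Union[dict, Type[BaseModel]]]) -> None:
    """Mirrors the post-hoc check added in base.py (sketch only)."""
    if response_format is None:
        return
    if not isinstance(response_format, dict) and issubclass(response_format, BaseModel):
        # Instantiating the model raises if the JSON does not match the schema.
        response_format(**json.loads(all_response))
    elif response_format != {"type": "json_object"}:
        print(f"Unknown response_format: {response_format}, skipping validation.")


check_response('{"hypothesis": "momentum works", "code": "..."}', ExperimentPlan)
check_response('{"anything": 1}', {"type": "json_object"})
```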
9 changes: 6 additions & 3 deletions rdagent/oai/backend/deprec.py
@@ -12,12 +12,13 @@
import uuid
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional, Type, Union, cast

import numpy as np
import openai
import tiktoken
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass, import_class
from rdagent.log import LogColors
@@ -294,7 +295,7 @@ def _create_embedding_inner_function( # type: ignore[no-untyped-def]
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
add_json_in_prompt: bool = False,
*args,
**kwargs,
@@ -414,7 +415,9 @@ def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # no
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
if json_mode:

# FIXME: what if the model does not support response_schema
if response_format:
if add_json_in_prompt:
for message in messages[::-1]:
message["content"] = message["content"] + "\nPlease respond in json format."
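Reviewer note: the deprecated backend keeps its prompt-side fallback: when `response_format` is set and the caller also passes `add_json_in_prompt`, a JSON instruction is appended to the messages (the rest of the loop is cut off by the diff). A minimal sketch under the assumption that only the last message needs the hint; the example messages are made up:

```python
from typing import Optional, Type, Union

from pydantic import BaseModel

messages = [
    {"role": "system", "content": "You are a factor-research assistant."},
    {"role": "user", "content": "Summarize the factor."},
]
response_format: Optional[Union[dict, Type[BaseModel]]] = {"type": "json_object"}
add_json_in_prompt = True

if response_format:
    if add_json_in_prompt:
        # Walk the messages from last to first and append the JSON hint;
        # assumption: stopping after the last message is enough.
        for message in messages[::-1]:
            message["content"] = message["content"] + "\nPlease respond in json format."
            break

print(messages[-1]["content"])
```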
14 changes: 8 additions & 6 deletions rdagent/oai/backend/litellm.py
@@ -1,5 +1,5 @@
import copyreg
from typing import Any, Literal, cast
from typing import Any, Literal, Optional, Type, Union, cast

import numpy as np
from litellm import (
@@ -11,6 +11,7 @@
supports_response_schema,
token_counter,
)
from pydantic import BaseModel

from rdagent.log import LogColors
from rdagent.log import rdagent_logger as logger
@@ -86,23 +87,24 @@ def _create_embedding_inner_function(
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
*args,
**kwargs,
) -> tuple[str, str | None]:
"""
Call the chat completion function
"""
if json_mode and supports_response_schema(model=LITELLM_SETTINGS.chat_model):
kwargs["response_format"] = {"type": "json_object"}

elif not supports_response_schema(model=LITELLM_SETTINGS.chat_model) and "response_format" in kwargs:
if response_format and not supports_response_schema(model=LITELLM_SETTINGS.chat_model):
# Deepseek will enter this branch
logger.warning(
f"{LogColors.RED}Model {LITELLM_SETTINGS.chat_model} does not support response schema, ignoring response_format argument.{LogColors.END}",
tag="llm_messages",
)
kwargs.pop("response_format")
response_format = None

if response_format:
kwargs["response_format"] = response_format

if LITELLM_SETTINGS.log_llm_chat_content:
logger.info(self._build_log_messages(messages), tag="llm_messages")
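Reviewer note: the LiteLLM backend now forwards `response_format` as-is and drops it when `supports_response_schema` reports the model cannot honor it (e.g. DeepSeek). A minimal sketch of the same decision outside the backend; `supports_response_schema` and `completion` are LiteLLM APIs imported above, while the model name and schema are placeholders, and passing a Pydantic class as `response_format` assumes a LiteLLM version with structured-output support:

```python
from litellm import completion, supports_response_schema
from pydantic import BaseModel


class FactorIdea(BaseModel):  # placeholder schema, not part of this PR
    name: str
    rationale: str


chat_model = "gpt-4o"  # stands in for LITELLM_SETTINGS.chat_model
response_format: type[BaseModel] | dict | None = FactorIdea
kwargs: dict = {}

if response_format and not supports_response_schema(model=chat_model):
    # DeepSeek-style models land here: fall back to plain text and let the
    # caller re-parse or validate the response instead.
    response_format = None

if response_format:
    kwargs["response_format"] = response_format

resp = completion(
    model=chat_model,
    messages=[{"role": "user", "content": "Propose one factor idea as JSON."}],
    **kwargs,
)
print(resp.choices[0].message.content)
```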