2 changes: 1 addition & 1 deletion rdagent/oai/backend/base.py
@@ -617,7 +617,7 @@ def _create_embedding_with_cache(
         return [content_to_embedding_dict[content] for content in input_content_list]  # type: ignore[misc]

     @abstractmethod
-    def support_function_calling(self) -> bool:
+    def supports_response_schema(self) -> bool:
         """
         Check if the backend supports function calling
         """

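For orientation, the method renamed here is the abstract capability hook that every backend must implement. A minimal self-contained sketch of that contract, with simplified class names that are not RD-Agent's actual classes:

from abc import ABC, abstractmethod


class BackendSketch(ABC):
    # Simplified stand-in for rdagent's APIBackend base class.

    @abstractmethod
    def supports_response_schema(self) -> bool:
        """Return True if this backend can enforce a structured response schema."""


class DeprecBackendSketch(BackendSketch):
    # Mirrors the deprec backend below, which never supports structured output.
    def supports_response_schema(self) -> bool:
        return False
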
2 changes: 1 addition & 1 deletion rdagent/oai/backend/deprec.py
@@ -261,7 +261,7 @@ def _azure_patch(model: str) -> str:
             raise
         return encoding

-    def support_function_calling(self) -> bool:
+    def supports_response_schema(self) -> bool:
         """
         Check if the backend supports function calling.
         Currently, deprec backend does not support function calling so it returns False. #FIXME: maybe a mapping to the backend class is needed.

4 changes: 2 additions & 2 deletions rdagent/oai/backend/litellm.py
@@ -192,8 +192,8 @@ def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def]  # no
         )
         return content, finish_reason

-    def support_function_calling(self) -> bool:
+    def supports_response_schema(self) -> bool:
         """
         Check if the backend supports function calling
         """
-        return supports_function_calling(model=LITELLM_SETTINGS.chat_model) and LITELLM_SETTINGS.enable_function_call
+        return supports_response_schema(model=LITELLM_SETTINGS.chat_model) and LITELLM_SETTINGS.enable_response_schema
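The litellm backend now gates structured output on litellm's model-capability lookup instead of the function-calling one. A minimal sketch of the new check, assuming a recent litellm release that exports supports_response_schema; the chat_model argument and enable flag stand in for LITELLM_SETTINGS.chat_model and LITELLM_SETTINGS.enable_response_schema:

from litellm import supports_response_schema


def backend_supports_response_schema(chat_model: str, enable_response_schema: bool = True) -> bool:
    # True only when litellm reports the model can honor a response schema
    # AND the feature has not been disabled in configuration.
    return supports_response_schema(model=chat_model) and enable_response_schema


# Example: gate the request format on the capability check.
use_schema = backend_supports_response_schema("gpt-4o")
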
5 changes: 2 additions & 3 deletions rdagent/oai/llm_conf.py
@@ -16,9 +16,8 @@ class LLMSettings(ExtendedBaseSettings):
     embedding_model: str = "text-embedding-3-small"

     reasoning_effort: Literal["low", "medium", "high"] | None = None
-    enable_function_call: bool = (
-        True  # Whether to enable function calling in chat models. may not work for models that do not support it.
-    )
+    enable_response_schema: bool = True
+    # Whether to enable response_schema in chat models. may not work for models that do not support it.

     # Handling format
     reasoning_think_rm: bool = False

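The renamed flag is a plain pydantic-settings field, so it can be flipped from the environment without code changes. A rough sketch of the idea, using pydantic_settings.BaseSettings in place of RD-Agent's ExtendedBaseSettings (an assumption made only to keep the example self-contained):

from pydantic_settings import BaseSettings


class LLMSettingsSketch(BaseSettings):
    # Hypothetical subset of rdagent/oai/llm_conf.py.
    embedding_model: str = "text-embedding-3-small"
    enable_response_schema: bool = True
    # Whether to enable response_schema in chat models; may not work for
    # models that do not support it.


# With default pydantic-settings behavior, ENABLE_RESPONSE_SCHEMA=false in
# the environment would override the flag at load time.
settings = LLMSettingsSketch()
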
34 changes: 17 additions & 17 deletions rdagent/scenarios/data_science/proposal/exp_gen/proposal.py
@@ -457,12 +457,12 @@ def _f(user_prompt):
 class DSProposalV2ExpGen(ExpGen):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.support_function_calling = APIBackend().support_function_calling()
+        self.supports_response_schema = APIBackend().supports_response_schema()

     def identify_scenario_problem(self, scenario_desc: str, sota_exp_desc: str) -> Dict:
         sys_prompt = T(".prompts_v2:scenario_problem.system").r(
             problem_output_format=(
-                T(".prompts_v2:output_format.problem").r() if not self.support_function_calling else None
+                T(".prompts_v2:output_format.problem").r() if not self.supports_response_schema else None
             ),
         )
         user_prompt = T(".prompts_v2:scenario_problem.user").r(
@@ -472,10 +472,10 @@ def identify_scenario_problem(self, scenario_desc: str, sota_exp_desc: str) -> D
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
-            response_format=ScenarioChallenges if self.support_function_calling else {"type": "json_object"},
-            json_target_type=Dict[str, Dict[str, str]] if not self.support_function_calling else None,
+            response_format=ScenarioChallenges if self.supports_response_schema else {"type": "json_object"},
+            json_target_type=Dict[str, Dict[str, str]] if not self.supports_response_schema else None,
         )
-        if self.support_function_calling:
+        if self.supports_response_schema:
             challenges = ScenarioChallenges(**json.loads(response))
             # Translate to problems
             problems = {o.caption: {"problem": o.statement, "reason": o.reasoning} for o in challenges.challenges}
@@ -490,7 +490,7 @@ def identify_feedback_problem(
     ) -> Dict:
         sys_prompt = T(".prompts_v2:feedback_problem.system").r(
             problem_output_format=(
-                T(".prompts_v2:output_format.problem").r() if not self.support_function_calling else None
+                T(".prompts_v2:output_format.problem").r() if not self.supports_response_schema else None
             ),
             inject_diverse=inject_diverse,
         )
@@ -502,10 +502,10 @@
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
-            response_format=TraceChallenges if self.support_function_calling else {"type": "json_object"},
-            json_target_type=Dict[str, Dict[str, str]] if not self.support_function_calling else None,
+            response_format=TraceChallenges if self.supports_response_schema else {"type": "json_object"},
+            json_target_type=Dict[str, Dict[str, str]] if not self.supports_response_schema else None,
         )
-        if self.support_function_calling:
+        if self.supports_response_schema:
             challenges = TraceChallenges(**json.loads(response))
             # Translate to problems
             problems = {o.caption: {"problem": o.statement, "reason": o.reasoning} for o in challenges.challenges}
@@ -569,7 +569,7 @@ def hypothesis_gen(
         sys_prompt = T(".prompts_v2:hypothesis_gen.system").r(
             hypothesis_output_format=(
                 T(".prompts_v2:output_format.hypothesis").r(pipeline=pipeline, enable_idea_pool=enable_idea_pool)
-                if not self.support_function_calling
+                if not self.supports_response_schema
                 else None
             ),
             pipeline=pipeline,
@@ -586,12 +586,12 @@
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
-            response_format=HypothesisList if self.support_function_calling else {"type": "json_object"},
+            response_format=HypothesisList if self.supports_response_schema else {"type": "json_object"},
             json_target_type=(
-                Dict[str, Dict[str, str | Dict[str, str | int]]] if not self.support_function_calling else None
+                Dict[str, Dict[str, str | Dict[str, str | int]]] if not self.supports_response_schema else None
             ),
         )
-        if self.support_function_calling:
+        if self.supports_response_schema:
             hypotheses = HypothesisList(**json.loads(response))
             resp_dict = {
                 h.caption: {
@@ -728,7 +728,7 @@ def task_gen(
         component_info = get_component(hypotheses[0].component)
         data_folder_info = self.scen.processed_data_folder_description
         sys_prompt = T(".prompts_v2:task_gen.system").r(
-            task_output_format=component_info["task_output_format"] if not self.support_function_calling else None,
+            task_output_format=component_info["task_output_format"] if not self.supports_response_schema else None,
             # task_output_format=component_info["task_output_format"],
             component_desc=component_desc,
             workflow_check=not pipeline and hypotheses[0].component != "Workflow",
@@ -744,12 +744,12 @@
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
-            response_format=CodingSketch if self.support_function_calling else {"type": "json_object"},
-            json_target_type=Dict[str, str | Dict[str, str]] if not self.support_function_calling else None,
+            response_format=CodingSketch if self.supports_response_schema else {"type": "json_object"},
+            json_target_type=Dict[str, str | Dict[str, str]] if not self.supports_response_schema else None,
         )
         task_dict = json.loads(response)
         task_design = (
-            task_dict.get("task_design", {}) if not self.support_function_calling else task_dict.get("sketch", {})
+            task_dict.get("task_design", {}) if not self.supports_response_schema else task_dict.get("sketch", {})
         )
         logger.info(f"Task design:\n{task_design}")
         task_name = hypotheses[0].component
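Across proposal.py the pattern is the same: when the backend supports a response schema, a pydantic model is passed as response_format and the reply is validated against it; otherwise the call falls back to {"type": "json_object"} with a json_target_type hint. A condensed sketch of that branch, using hypothetical Challenge/Challenges models and a placeholder call_llm in place of APIBackend().build_messages_and_create_chat_completion:

import json
from typing import Dict, List

from pydantic import BaseModel


class ChallengeSketch(BaseModel):
    caption: str
    statement: str
    reasoning: str


class ChallengesSketch(BaseModel):
    # Hypothetical stand-in for ScenarioChallenges / TraceChallenges.
    challenges: List[ChallengeSketch]


def call_llm(user_prompt, system_prompt, response_format, json_target_type=None) -> str:
    # Placeholder for APIBackend().build_messages_and_create_chat_completion(...).
    raise NotImplementedError


def identify_problems(user_prompt: str, sys_prompt: str, supports_response_schema: bool) -> Dict:
    response = call_llm(
        user_prompt=user_prompt,
        system_prompt=sys_prompt,
        response_format=ChallengesSketch if supports_response_schema else {"type": "json_object"},
        json_target_type=None if supports_response_schema else Dict[str, Dict[str, str]],
    )
    if supports_response_schema:
        parsed = ChallengesSketch(**json.loads(response))
        return {c.caption: {"problem": c.statement, "reason": c.reasoning} for c in parsed.challenges}
    return json.loads(response)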