# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import json
from collections import defaultdict
from dataclasses import dataclass
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Tuple
from tenacity import retry
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential
from camel.agents import BaseAgent
from camel.configs import BaseConfig, ChatGPTConfig
from camel.functions import OpenAIFunction
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
from camel.models import BaseModelBackend, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import get_model_encoding, openai_api_key_required
@dataclass(frozen=True)
class ChatAgentResponse:
r"""Response of a ChatAgent.
Attributes:
msgs (List[BaseMessage]): A list of zero, one, or several messages.
If the list is empty, an error occurred during message generation.
A single message is the normal mode of operation; several
messages correspond to the critic mode.
terminated (bool): A boolean indicating whether the agent decided
to terminate the chat session.
info (Dict[str, Any]): Extra information about the chat message.
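Example:
    A sketch (``some_msg`` is assumed to be an existing
    :obj:`BaseMessage`; real responses come from :obj:`ChatAgent.step`):

    >>> response = ChatAgentResponse([some_msg], False, {})
    >>> response.msg  # valid only when len(msgs) == 1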
"""
msgs: List[BaseMessage]
terminated: bool
info: Dict[str, Any]
@property
def msg(self) -> BaseMessage:
if len(self.msgs) != 1:
raise RuntimeError("Property msg is only available "
"for a single message in msgs.")
return self.msgs[0]
@dataclass(frozen=True)
class ChatRecord:
r"""Historical records of who made what message.
Attributes:
role_at_backend (str): Role of the message that mirrors OpenAI
message role that may be `system` or `user` or `assistant`.
message (BaseMessage): Message payload.
"""
role_at_backend: str
message: BaseMessage
def to_openai_message(self):
r"""Converts the payload message to OpenAI-compatible format.
Returns:
OpenAIMessage: OpenAI-compatible message
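Example:
    A sketch, assuming ``msg`` is an existing :obj:`BaseMessage`:

    >>> record = ChatRecord('user', msg)
    >>> record.to_openai_message()  # e.g. {'role': 'user', 'content': '...'}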
"""
return self.message.to_openai_message(self.role_at_backend)
@dataclass(frozen=True)
class FunctionCallingRecord:
r"""Historical records of functions called in the conversation.
Attributes:
func_name (str): The name of the function being called.
args (Dict[str, Any]): The dictionary of arguments passed to
the function.
result (Any): The execution result of calling this function.
"""
func_name: str
args: Dict[str, Any]
result: Any
def __str__(self) -> str:
r"""Overridden version of the string function.
Returns:
str: Modified string to represent the function calling.
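Example:
    Following the format string above, a hypothetical record for a
    function ``add`` called with ``{'a': 1, 'b': 2}`` prints as:

    Function Execution: add
        Args: {'a': 1, 'b': 2}
        Result: 3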
"""
return (f"Function Execution: {self.func_name}\n"
f"\tArgs: {self.args}\n"
f"\tResult: {self.result}")
class ChatAgent(BaseAgent):
r"""Class for managing conversations of CAMEL Chat Agents.
Args:
system_message (BaseMessage): The system message for the chat agent.
model (ModelType, optional): The LLM model to use for generating
responses. (default: :obj:`ModelType.GPT_3_5_TURBO`)
model_config (BaseConfig, optional): Configuration options for the
LLM model. (default: :obj:`None`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
function_list (Optional[List[OpenAIFunction]]): List of available
:obj:`OpenAIFunction`. (default: :obj:`None`)
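Example:
    A minimal usage sketch, assuming a valid OpenAI API key is
    configured in the environment (role names and contents below
    are illustrative):

    >>> sys_msg = BaseMessage(role_name="Assistant",
    ...                       role_type=RoleType.ASSISTANT,
    ...                       meta_dict=None,
    ...                       content="You are a helpful assistant.")
    >>> agent = ChatAgent(sys_msg)
    >>> user_msg = BaseMessage(role_name="User",
    ...                        role_type=RoleType.USER,
    ...                        meta_dict=None,
    ...                        content="Hello!")
    >>> response = agent.step(user_msg)
    >>> print(response.msg.content)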
"""
def __init__(
self,
system_message: BaseMessage,
model: Optional[ModelType] = None,
model_config: Optional[BaseConfig] = None,
message_window_size: Optional[int] = None,
output_language: Optional[str] = None,
function_list: Optional[List[OpenAIFunction]] = None,
) -> None:
self.orig_sys_message: BaseMessage = system_message
self.system_message = system_message
self.role_name: str = system_message.role_name
self.role_type: RoleType = system_message.role_type
self.output_language: Optional[str] = output_language
if self.output_language is not None:
self.set_output_language(self.output_language)
self.model: ModelType = (model if model is not None else
ModelType.GPT_3_5_TURBO)
self.message_window_size: Optional[int] = message_window_size
self.func_dict: Dict[str, Callable] = {}
if function_list is not None:
for func in function_list:
self.func_dict[func.name] = func.func
self.model_config = model_config or ChatGPTConfig()
self.model_backend: BaseModelBackend = ModelFactory.create(
self.model, self.model_config.__dict__)
self.model_token_limit: int = self.model_backend.token_limit
self.terminated: bool = False
self.stored_messages: List[ChatRecord]
self.init_messages()
def reset(self):
r"""Resets the :obj:`ChatAgent` to its initial state and returns the
stored messages.
Returns:
List[ChatRecord]: The stored messages.
"""
self.terminated = False
self.init_messages()
return self.stored_messages
@property
def system_message(self) -> BaseMessage:
r"""The getter method for the property :obj:`system_message`.
Returns:
BaseMessage: The system message of this agent.
"""
return self._system_message
@system_message.setter
def system_message(self, message: BaseMessage):
r"""The setter method for the property :obj:`system_message`.
Args:
message (BaseMessage): The message to be set as the
new system message of this agent.
"""
self._system_message = message
def is_function_calling_enabled(self) -> bool:
r"""Whether OpenAI function calling is enabled for this agent.
Returns:
bool: Whether OpenAI function calling is enabled for this
agent, determined by whether the dictionary of functions
is empty.
"""
return len(self.func_dict) > 0
def set_output_language(self, output_language: str) -> BaseMessage:
r"""Sets the output language for the system message. This method
updates the output language for the system message. The output
language determines the language in which the output text should be
generated.
Args:
output_language (str): The desired output language.
Returns:
BaseMessage: The updated system message object.
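Example:
    A sketch; after the call below, the system message ends with
    "Regardless of the input language, you must output text in
    French.":

    >>> agent.set_output_language("French")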
"""
self.output_language = output_language
content = (self.orig_sys_message.content +
("\nRegardless of the input language, "
f"you must output text in {output_language}."))
self.system_message = self.system_message.create_new_instance(content)
return self.system_message
def get_info(self, id: Optional[str], usage: Optional[Dict[str, int]],
termination_reasons: List[str], num_tokens: int,
called_funcs: List[FunctionCallingRecord]) -> Dict[str, Any]:
r"""Returns a dictionary containing information about the chat session.
Args:
id (str, optional): The ID of the chat session.
usage (Dict[str, int], optional): Information about the usage of
the LLM model.
termination_reasons (List[str]): The reasons for the termination
of the chat session.
num_tokens (int): The number of tokens used in the chat session.
called_funcs (List[FunctionCallingRecord]): The list of function
calling records, containing the information of called
functions.
Returns:
Dict[str, Any]: The chat session information.
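Example:
    The returned dictionary has the shape:

    {"id": ..., "usage": ..., "termination_reasons": [...],
     "num_tokens": ..., "called_functions": [...]}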
"""
return {
"id": id,
"usage": usage,
"termination_reasons": termination_reasons,
"num_tokens": num_tokens,
"called_functions": called_funcs,
}
def init_messages(self) -> None:
r"""Initializes the stored messages list with the initial system
message.
"""
self.stored_messages = [ChatRecord('system', self.system_message)]
def update_messages(self, role: str,
message: BaseMessage) -> List[ChatRecord]:
r"""Updates the stored messages list with a new message.
Args:
role (str): Role of the message at the backend, which must be one
of `system`, `user`, `assistant`, or `function`.
message (BaseMessage): The new message to add to the stored
messages.
Returns:
List[ChatRecord]: The updated stored messages.
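Example:
    A sketch, assuming ``msg`` is an existing :obj:`BaseMessage`:

    >>> records = agent.update_messages('user', msg)
    >>> records[-1].role_at_backend
    'user'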
"""
if role not in {'system', 'user', 'assistant', 'function'}:
raise ValueError(f"Unsupported role {role}")
self.stored_messages.append(ChatRecord(role, message))
return self.stored_messages
def submit_message(self, message: BaseMessage) -> None:
r"""Submits the externally provided message as if it were an answer of
the chat LLM from the backend. Currently, the choice of the critic is
submitted with this method.
Args:
message (BaseMessage): An external message to be added as an
assistant response.
"""
self.stored_messages.append(ChatRecord('assistant', message))
@retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
@openai_api_key_required
def step(
self,
input_message: BaseMessage,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message.
Args:
input_message (BaseMessage): The input message to the agent.
Its `role` field, which specifies the role at the backend, may be
either `user` or `assistant`, but it will be set to `user`
regardless, since any incoming message is external to this agent.
Returns:
ChatAgentResponse: A struct containing the output messages,
a boolean indicating whether the chat session has terminated,
and information about the chat session.
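Example:
    A sketch of inspecting the response (``user_msg`` is assumed
    to be an existing :obj:`BaseMessage`):

    >>> response = agent.step(user_msg)
    >>> response.terminated, response.info["num_tokens"]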
"""
messages = self.update_messages('user', input_message)
output_messages: List[BaseMessage]
info: Dict[str, Any]
called_funcs: List[FunctionCallingRecord] = []
while True:
# Format messages and get the token number
openai_messages: Optional[List[OpenAIMessage]]
num_tokens: int
openai_messages, num_tokens = self.preprocess_messages(messages)
# Terminate when number of tokens exceeds the limit
if num_tokens >= self.model_token_limit:
return self.step_token_exceed(num_tokens, called_funcs)
# Obtain LLM's response and validate it
response = self.model_backend.run(openai_messages)
self.validate_model_response(response)
if not self.model_backend.stream:
output_messages, finish_reasons, usage_dict, response_id = (
self.handle_batch_response(response))
else:
output_messages, finish_reasons, usage_dict, response_id = (
self.handle_stream_response(response, num_tokens))
if (self.is_function_calling_enabled()
        and finish_reasons[0] == 'function_call'):
# Do function calling
func_assistant_msg, func_result_msg, func_record = (
self.step_function_call(response))
# Update the messages
messages = self.update_messages('assistant',
func_assistant_msg)
messages = self.update_messages('function', func_result_msg)
called_funcs.append(func_record)
else:
# Function calling disabled or chat stopped
info = self.get_info(
response_id,
usage_dict,
finish_reasons,
num_tokens,
called_funcs,
)
break
return ChatAgentResponse(output_messages, self.terminated, info)
def preprocess_messages(
self,
messages: List[ChatRecord]) -> Tuple[List[OpenAIMessage], int]:
r"""Truncate the list of messages if message window is defined and
the current length of message list is beyond the window size. Then
convert the list of messages to OpenAI's input format and calculate
the number of tokens.
Args:
messages (List[ChatRecord]): The list of structs containing
information about previous chat messages.
Returns:
tuple: A tuple containing the truncated list of messages in
OpenAI's input format and the number of tokens.
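Example:
    With ``message_window_size=2``, a history of
    ``[system, msg1, msg2, msg3]`` is truncated to
    ``[system, msg2, msg3]`` before conversion.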
"""
if (self.message_window_size
is not None) and (len(messages) > self.message_window_size):
messages = [ChatRecord('system', self.system_message)
] + messages[-self.message_window_size:]
openai_messages: List[OpenAIMessage]
openai_messages = [record.to_openai_message() for record in messages]
num_tokens = self.model_backend.count_tokens_from_messages(
openai_messages)
return openai_messages, num_tokens
def validate_model_response(self, response: Any) -> None:
r"""Validate the type of the response returned by the model.
Args:
response (Any): The response returned by the model.
"""
if not self.model_backend.stream:
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected batch struct")
else:
if not isinstance(response, GeneratorType):
raise RuntimeError("OpenAI returned unexpected stream struct")
def handle_batch_response(
self, response: Dict[str, Any]
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
r"""
Args:
response (dict): Model response.
Returns:
tuple: A tuple of list of output `ChatMessage`, list of
finish reasons, usage dictionary, and response id.
"""
output_messages: List[BaseMessage] = []
for choice in response["choices"]:
chat_message = BaseMessage(role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=choice["message"]['content'])
output_messages.append(chat_message)
finish_reasons = [
str(choice["finish_reason"]) for choice in response["choices"]
]
return output_messages, finish_reasons, dict(
response["usage"]), response["id"]
def handle_stream_response(
self,
response: Any,
prompt_tokens: int,
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
r"""
Args:
response (dict): Model response.
prompt_tokens (int): Number of input prompt tokens.
Returns:
tuple: A tuple of list of output `ChatMessage`, list of
finish reasons, usage dictionary, and response id.
"""
content_dict: defaultdict = defaultdict(str)
finish_reasons_dict: defaultdict = defaultdict(str)
output_messages: List[BaseMessage] = []
response_id: str = ""
# All choices in one response share one role
role: str = ""
for chunk in response:
response_id = chunk["id"]
for choice in chunk["choices"]:
index: int = choice["index"]
delta: Dict = choice["delta"]
if len(delta) != 0:
# When response has not been stopped
# Notice that only the first chunk has the "role"
role = delta.get("role", role)
delta_content = delta.get("content", "")
content_dict[index] += delta_content
else:
finish_reasons_dict[index] = choice["finish_reason"]
chat_message = BaseMessage(role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=content_dict[index])
output_messages.append(chat_message)
finish_reasons = [
finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
]
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
return output_messages, finish_reasons, usage_dict, response_id
def step_token_exceed(
self, num_tokens: int,
called_funcs: List[FunctionCallingRecord]) -> ChatAgentResponse:
r"""Return trivial response containing number of tokens and information
of called functions when the number of tokens exceeds.
Args:
num_tokens (int): Number of tokens in the messages.
called_funcs (List[FunctionCallingRecord]): List of information
objects of functions called in the current step.
Returns:
ChatAgentResponse: The struct containing trivial outputs and
information about token number and called functions.
"""
self.terminated = True
output_messages: List[BaseMessage] = []
info = self.get_info(
None,
None,
["max_tokens_exceeded"],
num_tokens,
called_funcs,
)
return ChatAgentResponse(
output_messages,
self.terminated,
info,
)
def step_function_call(
self, response: Dict[str, Any]
) -> Tuple[FunctionCallingMessage, FunctionCallingMessage,
FunctionCallingRecord]:
r"""Execute the function with arguments following the model's response.
Args:
response (Dict[str, Any]): the response obtained by calling the
model.
Returns:
tuple: A tuple consisting of two :obj:`FunctionCallingMessage`,
one about the arguments and the other about the execution
result, and a struct for logging information about this
function call.
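Example:
    The relevant part of a (hypothetical) model response:

    {"choices": [{"message": {"function_call": {
        "name": "add", "arguments": "{\"a\": 1, \"b\": 2}"}}}]}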
"""
# Note that when function calling is enabled, `n` is set to 1.
choice = response["choices"][0]
func_name = choice["message"]["function_call"]["name"]
func = self.func_dict[func_name]
args_str: str = choice["message"]["function_call"]["arguments"]
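# Some models emit single-quoted pseudo-JSON arguments; swap the
# quotes before parsing (fragile if values contain apostrophes).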
args = json.loads(args_str.replace("\'", "\""))
# Pass the extracted arguments to the indicated function
try:
result = func(**args)
except Exception as e:
raise ValueError(
f"Execution of function {func.__name__} failed with "
f"arguments being {args}.") from e
assist_msg = FunctionCallingMessage(
role_name=self.role_name,
role_type=self.role_type,
meta_dict=None,
content="",
func_name=func_name,
args=args,
)
func_msg = FunctionCallingMessage(
role_name=self.role_name,
role_type=self.role_type,
meta_dict=None,
content="",
func_name=func_name,
result=result,
)
# Record information about this function call
func_record = FunctionCallingRecord(func_name, args, result)
return assist_msg, func_msg, func_record
def get_usage_dict(self, output_messages: List[BaseMessage],
prompt_tokens: int) -> Dict[str, int]:
r"""Get usage dictionary when using the stream mode.
Args:
output_messages (list): List of output messages.
prompt_tokens (int): Number of input prompt tokens.
Returns:
dict: Usage dictionary.
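Example:
    The returned dictionary has the shape:

    {"completion_tokens": ..., "prompt_tokens": ...,
     "total_tokens": ...}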
"""
encoding = get_model_encoding(self.model.value_for_tiktoken)
completion_tokens = 0
for message in output_messages:
completion_tokens += len(encoding.encode(message.content))
usage_dict = dict(completion_tokens=completion_tokens,
prompt_tokens=prompt_tokens,
total_tokens=completion_tokens + prompt_tokens)
return usage_dict
def __repr__(self) -> str:
r"""Returns a string representation of the :obj:`ChatAgent`.
Returns:
str: The string representation of the :obj:`ChatAgent`.
"""
return f"ChatAgent({self.role_name}, {self.role_type}, {self.model})"