Source code for camel.utils.token_counting
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import List

from camel.messages import OpenAIMessage
from camel.typing import ModelType


def messages_to_prompt(messages: List[OpenAIMessage],
                       model: ModelType) -> str:
r"""Parse the message list into a single prompt following model-specifc
formats.
Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.
model (ModelType): Model type for which messages will be parsed.
Returns:
str: A single prompt summarizing all the messages.
"""
    system_message = messages[0]["content"]

    ret: str
    if model == ModelType.LLAMA_2:
        # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
        seps = [" ", " </s><s>"]
        role_map = {"user": "[INST]", "assistant": "[/INST]"}

        system_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
        ret = ""
        for i, msg in enumerate(messages[1:]):
            role = role_map[msg["role"]]
            message = msg["content"]
            if message:
                if i == 0:
                    # The system prompt is folded into the first user turn.
                    ret += system_prompt + message
                else:
                    ret += role + " " + message + seps[i % 2]
            else:
                ret += role
        return ret
    elif model in (ModelType.VICUNA, ModelType.VICUNA_16K):
        seps = [" ", "</s>"]
        role_map = {"user": "USER", "assistant": "ASSISTANT"}

        system_prompt = f"{system_message}"
        ret = system_prompt + seps[0]
        for i, msg in enumerate(messages[1:]):
            role = role_map[msg["role"]]
            message = msg["content"]
            if message:
                ret += role + ": " + message + seps[i % 2]
            else:
                ret += role + ":"
        return ret
    else:
        raise ValueError(f"Invalid model type: {model}")


def get_model_encoding(value_for_tiktoken: str):
    r"""Get model encoding from tiktoken.

    Args:
        value_for_tiktoken: Model value for tiktoken.

    Returns:
        tiktoken.Encoding: Model encoding.
    """
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(value_for_tiktoken)
    except KeyError:
        # Fall back to the encoding used by recent OpenAI chat models.
        print("Model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    return encoding
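
# Example (illustrative): the returned tiktoken encoding counts tokens in a
# raw string; "gpt-3.5-turbo" is just a sample model value. Under
# cl100k_base, "Hello, world!" splits into 4 tokens.
#
#     >>> enc = get_model_encoding("gpt-3.5-turbo")
#     >>> len(enc.encode("Hello, world!"))
#     4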


class BaseTokenCounter(ABC):
    r"""Base class for token counters of different kinds of models."""

    @abstractmethod
    def count_tokens_from_messages(self,
                                   messages: List[OpenAIMessage]) -> int:
        r"""Count number of tokens in the provided message list.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            int: Number of tokens in the messages.
        """
        pass
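
# A counter for any other backend only needs to subclass BaseTokenCounter
# and implement count_tokens_from_messages. A minimal, purely illustrative
# sketch (WhitespaceTokenCounter is hypothetical, not part of CAMEL):
#
#     class WhitespaceTokenCounter(BaseTokenCounter):
#         def count_tokens_from_messages(
#                 self, messages: List[OpenAIMessage]) -> int:
#             # Approximate tokens by whitespace-separated words.
#             return sum(len(str(m["content"]).split()) for m in messages)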


class OpenSourceTokenCounter(BaseTokenCounter):

    def __init__(self, model_type: ModelType, model_path: str):
        r"""Constructor for the token counter for open-source models.

        Args:
            model_type (ModelType): Model type for which tokens will be
                counted.
            model_path (str): The path to the model files, where the
                tokenizer model should be located.
        """
        # Use a fast Rust-based tokenizer if it is supported for a given
        # model. If a fast tokenizer is not available for a given model,
        # a normal Python-based tokenizer is returned instead.
        from transformers import AutoTokenizer
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=True,
            )
        except TypeError:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=False,
            )
        except Exception as exc:
            raise ValueError(
                f"Invalid `model_path` ({model_path}) is provided. "
                "Tokenizer loading failed.") from exc

        self.tokenizer = tokenizer
        self.model_type = model_type
    def count_tokens_from_messages(self,
                                   messages: List[OpenAIMessage]) -> int:
        r"""Count number of tokens in the provided message list using
        the loaded tokenizer specific to this type of model.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            int: Number of tokens in the messages.
        """
        prompt = messages_to_prompt(messages, self.model_type)
        input_ids = self.tokenizer(prompt).input_ids
        return len(input_ids)
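
# Example (illustrative sketch): counting tokens for a chat history with a
# locally stored Vicuna checkpoint. The path is hypothetical; it only needs
# to contain the tokenizer files that AutoTokenizer.from_pretrained expects.
#
#     >>> counter = OpenSourceTokenCounter(ModelType.VICUNA,
#     ...                                  "/path/to/vicuna-7b")
#     >>> counter.count_tokens_from_messages(msgs)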


class OpenAITokenCounter(BaseTokenCounter):

    def __init__(self, model: ModelType):
        r"""Constructor for the token counter for OpenAI models.

        Args:
            model (ModelType): Model type for which tokens will be
                counted.
        """
        self.model: str = model.value_for_tiktoken

        self.tokens_per_message: int
        self.tokens_per_name: int

        if self.model == "gpt-3.5-turbo-0301":
            # Every message follows <|start|>{role/name}\n{content}<|end|>\n
            self.tokens_per_message = 4
            # If there's a name, the role is omitted
            self.tokens_per_name = -1
        elif ("gpt-3.5-turbo" in self.model) or ("gpt-4" in self.model):
            self.tokens_per_message = 3
            self.tokens_per_name = 1
        else:
            # flake8: noqa :E501
            raise NotImplementedError(
                "Token counting for OpenAI Models is not presently "
                f"implemented for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md "
                "for information on how messages are converted to tokens. "
                "See https://platform.openai.com/docs/models/gpt-4 "
                "or https://platform.openai.com/docs/models/gpt-3-5 "
                "for information about OpenAI chat models.")

        self.encoding = get_model_encoding(self.model)
    def count_tokens_from_messages(self,
                                   messages: List[OpenAIMessage]) -> int:
        r"""Count number of tokens in the provided message list with the
        help of the tiktoken package.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            int: Number of tokens in the messages.
        """
        num_tokens = 0
        for message in messages:
            num_tokens += self.tokens_per_message
            for key, value in message.items():
                # Both the role string and the content are encoded.
                num_tokens += len(self.encoding.encode(str(value)))
                if key == "name":
                    num_tokens += self.tokens_per_name
        # every reply is primed with <|start|>assistant<|message|>
        num_tokens += 3
        return num_tokens
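

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): counts the tokens of a
    # two-message chat offline via tiktoken. ModelType.GPT_3_5_TURBO is
    # assumed to be a member of camel.typing.ModelType. For gpt-3.5-turbo
    # the total is 3 tokens per message, plus the encoded role and content
    # values, plus the 3-token reply primer.
    demo_messages: List[OpenAIMessage] = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count my tokens, please."},
    ]
    counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
    print(counter.count_tokens_from_messages(demo_messages))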