# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

# Pick the accelerator once, at import time; the rest of the module reads this
# global when placing tensors and choosing the default tensor type.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str
    destination: str  # required for model responses


class InfillingPrediction(TypedDict, total=False):
    generation: str
    full_text: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

# Chat markup: user turns are wrapped in [INST] ... [/INST], and an optional
# system prompt is wrapped in <<SYS>> ... <</SYS>> inside the first user turn.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# Markup that user-supplied message content must not contain. Every entry must
# be a non-empty string: `"" in text` is always True, so an empty tag would
# flag every request as unsafe.
SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>", "<step>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
        """
        Load a (possibly model-parallel) checkpoint and its tokenizer.

        Initializes torch.distributed and fairscale model parallelism if they
        are not already set up. It then loads the `*.pth` shard for this
        process's model-parallel rank, plus `params.json`, from `ckpt_dir`.

        Args:
            ckpt_dir: Directory containing `*.pth` shards and `params.json`.
            tokenizer_path: Path passed to `Tokenizer(model_path=...)`.
            max_seq_len: Maximum total sequence length (prompt + generation).
            max_batch_size: Maximum number of prompts per `generate` call.
            model_parallel_size: Number of model-parallel shards. When None,
                it defaults to the `WORLD_SIZE` environment variable (or 1) if
                this call initializes model parallelism. Otherwise it is the
                size of the already-initialized model-parallel group.

        Returns:
            A `Llama` wrapping the loaded `Transformer` and `Tokenizer`.

        Raises:
            AssertionError: if no checkpoint is found, or if the number of
                shards differs from `model_parallel_size`.
        """
        if not torch.distributed.is_initialized():
            if device == "cuda":
                torch.distributed.init_process_group("nccl")
            else:
                torch.distributed.init_process_group("gloo")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)
        elif model_parallel_size is None:
            # Model parallelism was set up by the caller. Read its actual size
            # so the shard-count check below does not compare against None.
            from fairscale.nn.model_parallel.initialize import (
                get_model_parallel_world_size,
            )

            model_parallel_size = get_model_parallel_world_size()

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device == "cuda":
            torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(1)

        if local_rank > 0:
            # Silence prints from non-zero local ranks; the handle stays open
            # deliberately for the life of the process.
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        # support for mac
        if device == "cuda":
            if torch.cuda.is_bf16_supported():
                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            else:
                torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        model.to(device)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
        stop_token: Optional[int] = None,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Autoregressively extend a batch of tokenized prompts.

        Args:
            prompt_tokens: One list of token ids per prompt.
            max_gen_len: Maximum number of tokens to generate per prompt.
            temperature: Softmax temperature. Values <= 0 select greedy
                (argmax) decoding.
            top_p: Nucleus-sampling threshold, used when temperature > 0.
            logprobs: If True, also return the log-probability of every
                returned token.
            echo: If True, include the prompt tokens in the returned sequences.
            stop_token: Token id that ends a sequence. Defaults to the
                tokenizer's EOS id.

        Returns:
            A tuple `(tokens, logprobs)`. `tokens` holds one token list per
            prompt, truncated before the first stop token. `logprobs` holds the
            matching log-probabilities, or is None unless `logprobs` is True.

        Raises:
            AssertionError: if the batch exceeds `max_batch_size`, or a prompt
                exceeds `max_seq_len`.
        """
        if stop_token is None:
            stop_token = self.tokenizer.eos_id
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

        prev_pos = 0
        stop_reached = torch.tensor([False] * bsz, device=device)
        # True where a position holds a prompt token rather than padding.
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            if logprobs:
                # Score only after tokens[:, cur_pos] has been written. Before
                # the write, generated positions still hold pad_id, which
                # `ignore_index` turns into a logprob of 0. The next iteration
                # starts at cur_pos + 1, so this is the only chance to score
                # position cur_pos.
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            stop_reached |= (~input_text_mask[:, cur_pos]) & (next_token == stop_token)
            prev_pos = cur_pos
            if stop_reached.all():
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to stop token if present
            if stop_token in toks:
                stop_idx = toks.index(stop_token)
                toks = toks[:stop_idx]
                probs = probs[:stop_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Complete plain-text prompts.

        Each prompt is encoded with BOS and without EOS, then passed to
        `generate`. `max_gen_len` defaults to `max_seq_len - 1`. When
        `logprobs` is True, each result also carries the token pieces and
        their log-probabilities.
        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            assert generation_logprobs is not None
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.token_piece(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def text_infilling(
        self,
        prefixes: List[str],
        suffixes: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        suffix_first: bool = False,
    ) -> List[InfillingPrediction]:
        """
        Generate the text between each (prefix, suffix) pair.

        Prompts are built with `infilling_prompt_tokens`. Generation stops at
        the tokenizer's end-of-infill token (`eot_id`), and the output is
        decoded with `decode_infilling`. `full_text` is
        prefix + generation + suffix.

        Raises:
            AssertionError: if the tokenizer has no `eot_id`.
        """
        assert self.tokenizer.eot_id is not None
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [
            infilling_prompt_tokens(
                self.tokenizer, prefix, suffix, suffix_first=suffix_first
            )
            for prefix, suffix in zip(prefixes, suffixes)
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=False,
            stop_token=self.tokenizer.eot_id,
        )

        generations = [self.tokenizer.decode_infilling(t) for t in generation_tokens]

        if logprobs:
            assert generation_logprobs is not None
            return [
                {
                    "generation": generation,
                    "logprobs": logprobs_i,
                    "tokens": [self.tokenizer.token_piece(x) for x in t],
                    "full_text": prefix + generation + suffix,
                }
                for prefix, suffix, generation, t, logprobs_i in zip(
                    prefixes,
                    suffixes,
                    generations,
                    generation_tokens,
                    generation_logprobs,
                )
            ]
        else:
            return [
                {
                    "generation": generation,
                    "full_text": prefix + generation + suffix,
                }
                for prefix, suffix, generation in zip(prefixes, suffixes, generations)
            ]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate the next assistant message for each dialog.

        If the tokenizer defines a step token, this delegates to
        `_chat_completion_turns`. Otherwise each dialog is rendered in the
        [INST]/[/INST] format, with an optional leading system message folded
        into the first user turn.

        Dialogs whose message content contains any of `SPECIAL_TAGS` are still
        run through the model, but their reply content is replaced by
        `UNSAFE_ERROR`.

        Raises:
            AssertionError: if roles do not alternate user/assistant, or if the
                last message is not from the user.
        """
        if self.tokenizer.step_id is not None:
            return self._chat_completion_turns(
                dialogs=dialogs,
                temperature=temperature,
                top_p=top_p,
                max_gen_len=max_gen_len,
                logprobs=logprobs,
            )

        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        unsafe_requests = []
        for dialog in dialogs:
            unsafe_requests.append(
                any(tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog)
            )
            if dialog[0]["role"] == "system":
                dialog = [  # type: ignore
                    {
                        "role": dialog[1]["role"],
                        "content": B_SYS
                        + dialog[0]["content"]
                        + E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            # Each completed (user, assistant) exchange is its own BOS...EOS
            # sequence; the pieces are concatenated in order.
            dialog_tokens: List[int] = [
                token
                for prompt, answer in zip(dialog[::2], dialog[1::2])
                for token in self.tokenizer.encode(
                    f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()} ",
                    bos=True,
                    eos=True,
                )
            ]
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {dialog[-1]['content'].strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            assert generation_logprobs is not None
            return [
                {
                    "generation": {  # type: ignore
                        "role": "assistant",
                        "content": self.tokenizer.decode(t)
                        if not unsafe
                        else UNSAFE_ERROR,
                    },
                    "tokens": [self.tokenizer.token_piece(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i, unsafe in zip(
                    generation_tokens, generation_logprobs, unsafe_requests
                )
            ]
        return [
            {
                "generation": {  # type: ignore
                    "role": "assistant",
                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]

    def _chat_completion_turns(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Chat completion for tokenizers that define a step token.

        A blank system message is inserted when a dialog does not start with
        one. Each dialog is rendered with `dialog_prompt_tokens`, and
        generation stops at `step_id`. Unsafe dialogs, as in `chat_completion`,
        get `UNSAFE_ERROR` as their reply content.

        Raises:
            RuntimeError: if the tokenizer has no step token.
        """
        if self.tokenizer.step_id is None:
            raise RuntimeError("Model not suitable for chat_completion_step()")
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1

        prompt_tokens = []
        unsafe_requests = []
        for dialog in dialogs:
            unsafe_requests.append(
                any(tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog)
            )

            # Insert system message if not provided
            if dialog[0]["role"] != "system":
                dialog = [{"role": "system", "content": ""}] + dialog  # type: ignore

            dialog_tokens = dialog_prompt_tokens(self.tokenizer, dialog)
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            stop_token=self.tokenizer.step_id,
        )
        if logprobs:
            assert generation_logprobs is not None
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "destination": "user",
                        "content": self.tokenizer.decode(t)
                        if not unsafe
                        else UNSAFE_ERROR,
                    },
                    "tokens": [self.tokenizer.token_piece(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i, unsafe in zip(
                    generation_tokens, generation_logprobs, unsafe_requests
                )
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "destination": "user",
                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]


def sample_top_p(probs, p):
    """
    Nucleus (top-p) sampling over the last dimension of `probs`.

    Tokens are sorted by descending probability. A token is kept when the
    cumulative mass strictly before it is <= p, so the most likely token is
    always kept. The kept mass is renormalized, one index is drawn with
    `torch.multinomial`, and it is mapped back to the original vocabulary
    index.

    NOTE(review): `generate` passes a 2-D (batch, vocab) tensor, and the result
    keeps a trailing dimension of size 1.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def infilling_prompt_tokens(
    tokenizer: Tokenizer,
    pre: str,
    suf: str,
    suffix_first: bool = False,
) -> List[int]:
    """
    Format and encode an infilling problem.

    If `suffix_first` is set, format in suffix-prefix-middle format. The suffix
    is encoded with `encode_infilling` and the prefix with plain `encode`
    (no BOS/EOS). A single BOS leads the whole prompt.
    """
    assert tokenizer.prefix_id is not None
    assert tokenizer.middle_id is not None
    assert tokenizer.suffix_id is not None
    if suffix_first:
        # format as "<PRE> <SUF>{suf} <MID> {pre}"
        return (
            [tokenizer.bos_id, tokenizer.prefix_id, tokenizer.suffix_id]
            + tokenizer.encode_infilling(suf)
            + [tokenizer.middle_id]
            + tokenizer.encode(pre, bos=False, eos=False)
        )
    else:
        # format as "<PRE> {pre} <SUF>{suf} <MID>"
        return (
            [tokenizer.bos_id, tokenizer.prefix_id]
            + tokenizer.encode(pre, bos=False, eos=False)
            + [tokenizer.suffix_id]
            + tokenizer.encode_infilling(suf)
            + [tokenizer.middle_id]
        )


def dialog_prompt_tokens(tokenizer: Tokenizer, dialog: Dialog) -> List[int]:
    """
    Encode a multi-turn dialog for a tokenizer that has a step token.

    The dialog is expected to start with a system message and then alternate
    between user and assistant messages, ending on a user message.

    Each message is rendered as follows:
    - a header " Source: <role>", plus a "Destination: <dest>" line when the
      message has a non-None destination;
    - if the content is non-empty, a blank line and the stripped content;
    - the step token.

    The prompt finishes with an open assistant -> user header followed by a
    blank line, so the model writes the reply.

    Args:
        tokenizer: Tokenizer providing `bos_id`, `step_id` and `encode`.
        dialog: Messages ordered system, user, assistant, user, ...

    Returns:
        Token ids, starting with a single BOS.

    Raises:
        AssertionError: if the tokenizer has no step token, the roles do not
            alternate as described, or the last message is not from the user.
    """
    assert tokenizer.step_id is not None
    assert all([turn["role"] == "user" for turn in dialog[1::2]]) and all(
        [turn["role"] == "assistant" for turn in dialog[2::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"

    def encode_plain(text: str) -> List[int]:
        # BOS is emitted exactly once, up front, so no piece adds BOS/EOS.
        return tokenizer.encode(text, bos=False, eos=False)

    def render_header(source: str, destination: Optional[str]) -> str:
        header_lines = [f"Source: {source}"]
        if destination is not None:
            header_lines.append(f"Destination: {destination}")
        return " " + "\n".join(header_lines)

    prompt: List[int] = [tokenizer.bos_id]
    for turn in dialog:
        raw_destination = turn.get("destination")
        prompt.extend(
            encode_plain(
                render_header(
                    turn["role"].strip(),
                    None if raw_destination is None else raw_destination.strip(),
                )
            )
        )

        text = turn["content"]
        if text:
            prompt.extend(encode_plain("\n\n " + text.strip()))

        prompt.append(tokenizer.step_id)

    # Open the assistant's reply and leave its body for the model to write.
    prompt.extend(encode_plain(render_header("assistant", "user")))
    prompt.extend(encode_plain("\n\n "))
    return prompt