LLM Upper-Layer Infrastructure: Augmented Generation

In the previous parts we used fine-tuning, RAG, and multi-path voting in turn to improve the performance of our generation model. Apart from fine-tuning, RAG and multi-path voting are both examples of Augmented Generation (AG) techniques: improving a generation model's performance by adding extra information or machinery on top of it. In this part we introduce and apply several common AG techniques, including the RAG and multi-path voting we have already used, plus two new ones: resampling and data augmentation.

Data Augmentation

Data augmentation means generating new data by applying transformations to the original data. It can improve a generation model's performance because the model sees more data during training and therefore generalizes better. Common techniques include synonym replacement, generating data with a large model, and generating data via machine translation (back-translation).

Synonym replacement, as the name suggests, replaces certain words in the original data with synonyms. This improves generalization because the model sees more varied data during training.
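As an illustration, here is a minimal sketch of synonym replacement. The SYNONYMS table and the synonym_replace helper are made up for this example; a real pipeline would draw candidates from a thesaurus or an embedding-based lookup rather than a hand-written dictionary.

import random

# Toy synonym table for illustration only.
SYNONYMS = {
    "提升": ["提高", "增强"],
    "生成": ["产生", "构造"],
}

def synonym_replace(text: str, prob: float = 0.3) -> str:
    """Replace known words with a random synonym with probability `prob`."""
    out = text
    for word, candidates in SYNONYMS.items():
        if word in out and random.random() < prob:
            out = out.replace(word, random.choice(candidates))
    return out

print(synonym_replace("通过数据增强提升生成模型的性能"))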

Generating data with a large model means having a large model produce new samples from the existing data, which are then used to train a smaller model (a smaller model could also do the generating, but the results are worse). This works because the large model can produce a lot of extra data and thus improve generalization. However, because of hallucinations, the generated data may not be what we want: it needs cleaning, and its correctness is always in doubt.
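A minimal sketch of this idea, reusing the qianfan client that appears later in this post; the paraphrasing prompt and the choice of Yi-34B-Chat here are only illustrative.

import qianfan

def paraphrase_problem(problem: str) -> str:
    # Ask a larger model to rewrite an existing problem into a new sample.
    client = qianfan.ChatCompletion(model="Yi-34B-Chat")
    resp = client.do([{
        "role": "user",
        "content": f"请改写下面的逻辑推理题目,保持题意和答案不变,只改变表述:\n{problem}"
    }])
    # The output still needs cleaning and answer verification before it
    # can be added to the training set (hallucination risk).
    return resp.body["result"]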

Back-translation is similar in spirit to synonym replacement, except that the data is first machine-translated into another language and then translated back. Again the model gets more varied training data and generalizes better, and this kind of data is more reliable than data generated by a large model.

Here we use the Tencent Translate API.

def translate(text: list[str], trans_back=False) -> list[str]:
    import json
    import os
    from tencentcloud.common import credential
    from tencentcloud.common.profile.client_profile import ClientProfile
    from tencentcloud.common.profile.http_profile import HttpProfile
    from tencentcloud.tmt.v20180321 import tmt_client, models

    # Credentials are read from environment variables.
    cred = credential.Credential(
        os.getenv("TENCENT_SECRET_ID"),
        os.getenv("TENCENT_SECRET_KEY")
    )
    http_profile = HttpProfile()
    http_profile.endpoint = "tmt.tencentcloudapi.com"
    client_profile = ClientProfile()
    client_profile.httpProfile = http_profile
    client = tmt_client.TmtClient(cred, "ap-beijing", client_profile)

    # zh -> ja on the first pass, ja -> zh when translating back.
    request = models.TextTranslateBatchRequest()
    params = {
        "Source": "zh" if not trans_back else "ja",
        "Target": "ja" if not trans_back else "zh",
        "ProjectId": 0,
        "SourceTextList": text
    }
    request.from_json_string(json.dumps(params))
    response = client.TextTranslateBatch(request)
    return json.loads(response.to_json_string())["TargetTextList"]

def translate_entries_inplace(entries: list[Entry], trans_back=False):
    problems = [entry.problem for entry in entries]
    translated_problems = translate(problems, trans_back)
    for entry, translated_problem in zip(entries, translated_problems):
        entry.problem = translated_problem
    reasoning_flat = [
        question.reasoning
        for entry in entries
        for question in entry.questions
    ]
    translated_reasoning = translate(reasoning_flat, trans_back)
    reasoning_idx = 0
    for entry in entries:
        for question in entry.questions:
            question.reasoning = translated_reasoning[reasoning_idx]
            reasoning_idx += 1
    options_flat = [
        option
        for entry in entries
        for question in entry.questions
        for option in question.options
    ]
    translated_options = translate(options_flat, trans_back)
    options_idx = 0
    for entry in entries:
        for question in entry.questions:
            question.options = translated_options[options_idx:options_idx+len(question.options)]
            options_idx += len(question.options)
    questions_flat = [
        question.question
        for entry in entries
        for question in entry.questions
    ]
    translated_questions = translate(questions_flat, trans_back)
    questions_idx = 0
    for entry in entries:
        for question in entry.questions:
            question.question = translated_questions[questions_idx]
            questions_idx += 1

batch = 5  # translate a few entries per API call
for i in tqdm(range(0, len(export_entries_translated), batch)):
    from_idx = i
    to_idx = min(i + batch, len(export_entries_translated))
    # Round-trip: zh -> ja, then ja -> zh, applied in place to the same slice.
    translate_entries_inplace(export_entries_translated[from_idx:to_idx])
    translate_entries_inplace(export_entries_translated[from_idx:to_idx], True)

This doubles the amount of data. Two caveats: translation-style tasks should be excluded from back-translation, because such tasks can become ambiguous after a round trip; and the augmented data should not be indexed for RAG, otherwise similarity search may return several copies of essentially the same question.
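A minimal filtering sketch under those assumptions; is_translation_task is a hypothetical keyword heuristic, and export_entries is assumed to hold the original cleaned entries.

import copy

def is_translation_task(entry: Entry) -> bool:
    # Hypothetical keyword heuristic; adjust to how such tasks appear in your data.
    keywords = ("翻译", "译成")
    text = entry.problem + " ".join(q.question for q in entry.questions)
    return any(k in text for k in keywords)

# Deep-copy so the originals stay untouched, then round-trip translate the copies.
export_entries_translated = [
    copy.deepcopy(entry)
    for entry in export_entries
    if not is_translation_task(entry)
]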

Also, the instruction part of the prompt does not need translating; only the problems, questions, options, and reasoning are translated.

Resampling

Resampling, also called rejection sampling, is a sampling technique for generation models: when the generated result does not meet expectations, it is rejected and regenerated. A common way to make that decision is a scorer, and multi-path voting can also be seen as a form of resampling.

A scorer is a component that evaluates generated results. After the generation model produces a result, the scorer rates it; if the result does not pass, it is rejected and regenerated. Designing the scorer is the important part: a good scorer can noticeably improve the generation model's output quality.
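The control flow itself is simple. Here is a minimal sketch of scorer-based rejection sampling, assuming generate produces one candidate string and score returns 1 for accept and 0 for reject:

from typing import Callable

def rejection_sample(generate: Callable[[], str],
                     score: Callable[[str], int],
                     max_tries: int = 5) -> str:
    candidate = generate()
    for _ in range(max_tries - 1):
        if score(candidate) == 1:
            break               # accepted by the scorer
        candidate = generate()  # rejected: sample again
    return candidate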

Here we fully fine-tune a small model to act as the scorer.

Concretely, we do full-parameter training of a Qwen2 0.5B model: the scorer outputs 1 if the input reasoning is correct and 0 otherwise.

Data Preparation

We already generated correct reasoning data earlier; now we also need some incorrect reasoning data.

This is easy: take the prompt we used to generate the reasoning data, replace the answer in it with a wrong one, and regenerate. The model will then hallucinate a wrong chain of reasoning to match.
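For example, a small sketch of picking a deliberately wrong answer to feed into the same prompt template; pick_wrong_answer is a hypothetical helper.

import random

def pick_wrong_answer(correct: str, num_options: int) -> str:
    # Choose any option letter other than the correct one.
    letters = [chr(ord("A") + i) for i in range(num_options)]
    return random.choice([l for l in letters if l != correct])

print(pick_wrong_answer("B", 4))  # e.g. "A", "C" or "D"

The wrong answer then goes into the answer slot of the reasoning prompt shown below, and the model produces a matching (wrong) explanation.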

Since the data we want here is wrong anyway, a small model is enough; the free qwen1.5-0.5b-chat model will do.

Although we already generated correct reasoning data earlier, to raise its quality we regenerate it with Yi-34B-Chat (free on Baidu Qianfan, but with a daily quota) and Qianfan-Chinese-Llama-2-13B. These models have more parameters, so the reasoning they produce is of higher quality.

def generate_reasoning_inplace(entries: list[Entry]) -> None:
    template = """你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。每个问题都保证能通过一系列基于形式逻辑的推理(包括同一律,矛盾律,排中律的使用等)得到确定的唯一答案。我会向你提供答案,而你要给出逐步的解析来教会我如何得到答案。我是一个小学生,所以每一步的推理不要太复杂。
{problem}

### 问题
{question}
{options}

### 答案
{answer}

### 分析过程"""
    import time
    import qianfan
    client = qianfan.ChatCompletion(model="Qianfan-Chinese-Llama-2-13B")
    progress_q = tqdm(total=sum([len(entry.questions) for entry in entries]))
    import re
    answer_regex = re.compile(r"答案.*?([A-Z])")
    def process_question(entry, question):
        if question.reasoning is not None:
            matches = answer_regex.findall(question.reasoning)
            answer = matches[-1] if len(matches) > 0 else None
            if answer is None:
                with open("./a.txt", "a") as f:
                    f.write("1\n")
                    f.write(question.reasoning)
                progress_q.update(1)
                return
            if answer != question.answer:
                print(f"答案不匹配: {answer} != {question.answer}")
            else:
                progress_q.update(1)
                return
        max_retries = 3
        for attempt in range(max_retries):
            try:
                reasoning_prompt = template.format(**{
                    "problem": entry.problem,
                    "question": question.question,
                    "options": format_options(question.options),
                    "answer": question.answer
                })
                from qianfan import QfResponse
                resp = client.do([{
                    "role": "user",
                    "content": reasoning_prompt
                }])
                question.reasoning = resp.body["result"]
                # Extract the final answer letter ([A-Z]) from the generated reasoning.
                matches = answer_regex.findall(question.reasoning)
                answer = matches[-1] if len(matches) > 0 else None
                if answer is None:
                    break
                if answer != question.answer:
                    print(f"答案不匹配: {answer} != {question.answer}")
                    # Retry
                    continue
                with open("reasoning.txt", "a") as f:
                    f.write(f"{reasoning_prompt}\n\n{question.reasoning}\n\n")
                break  # Exit the loop if successful
            except Exception as e:
                if attempt < max_retries - 1:
                    time.sleep(1)  # Optional: wait a bit before retrying
                else:
                    print(f"Failed to process question.")
                    print(str(e))
        progress_q.update(1)

    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = []
        for entry in tqdm(entries):
            for question in entry.questions:
                futures.append(executor.submit(process_question, entry, question))

        for future in futures:
            future.result()
    progress_q.close()

One more thing to watch out for: sometimes the answer reached in the model's reasoning does not match the answer we supplied, and such bad samples are fairly common. In that case we force the model to regenerate until the answers agree, and if we still cannot get a correct reasoning we drop the sample. We use a simple regular expression to extract the answer.

import re
answer_regex = re.compile(r"答案.*?([A-Z])")
export_entries = []
for entry in entries:
    export_questions = []
    for question in entry.questions:
        if question.reasoning is None:
            continue
        matches = answer_regex.findall(question.reasoning)
        answer = matches[-1] if len(matches) > 0 else None
        if answer is None:
            continue
        if answer != question.answer:
            print(f"答案不匹配: {answer} != {question.answer}")
            continue
        export_questions.append(QuestionItem(
            question=question.question, options=question.options, reasoning=question.reasoning, answer=answer
        ))
    export_entries.append(Entry(problem=entry.problem, questions=export_questions))

After that, just export the data.
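A minimal export sketch, assuming the correct entries go to entries.pkl; the wrong-reasoning and augmented sets are dumped the same way, matching the file names the training script below loads.

import pickle

with open("./entries.pkl", "wb") as f:
    pickle.dump(export_entries, f)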

Training

This time we use full-parameter training; the model is small, so the VRAM usage is acceptable. We also use the LOMO optimizer (see its documentation for details).

#!/usr/bin/env python
# coding: utf-8

# In[1]:


from transformers import Qwen2ForSequenceClassification
from transformers import Qwen2TokenizerFast
import torch
from datasets import Dataset
from dataclasses import dataclass, field


# In[2]:


model_path = "./Qwen2-0.5B"
max_length = 512
output_dir = "./scorer_output"


# In[3]:


def get_model_and_tokenizer() -> tuple[Qwen2ForSequenceClassification, Qwen2TokenizerFast]:
    model = Qwen2ForSequenceClassification.from_pretrained(
        model_path,
        num_labels=2,
    )
    tokenizer: Qwen2TokenizerFast = Qwen2TokenizerFast.from_pretrained(model_path)
    model.config.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer


# In[4]:


@dataclass
class QuestionItem:
    question: str
    options: list[str]
    reasoning: str | None = field(default=None)
    answer: str | None = field(default=None)

@dataclass
class Entry:
    problem: str = field(default="")
    questions: list[QuestionItem] = field(default_factory=list)


# In[5]:


def get_train_dataset() -> Dataset:
    import pickle
    entries = pickle.load(open("./entries.pkl", "rb"))
    entries_false = pickle.load(open("./entries_false.pkl", "rb"))
    entries_aug = pickle.load(open("./entries_aug.pkl", "rb"))
    entries_aug_false = pickle.load(open("./entries_aug_false.pkl", "rb"))
    dataset = []
    for entry in (entries + entries_aug):
        for question in entry.questions:
            label = 1
            dataset.append({
                "reasoning": question.reasoning,
                "labels": label
            })
    for entry in (entries_false + entries_aug_false):
        for question in entry.questions:
            label = 0
            dataset.append({
                "reasoning": question.reasoning,
                "labels": label
            })
    return Dataset.from_list(dataset).shuffle(seed=42)


# In[6]:


def encode_dataset(dataset: Dataset, tokenizer: Qwen2TokenizerFast) -> Dataset:
    def encode(examples):
        encoded = tokenizer(
            examples["reasoning"],
            pad_to_multiple_of=8,
            max_length=max_length,
            truncation=True
        )
        return encoded
    return dataset.map(encode, batched=True).remove_columns(["reasoning"])


# In[7]:


import os
os.environ["WANDB_PROJECT"] = "fine-tune-logic-inference-scorer"


# In[8]:


from trl.trainer import SFTTrainer, SFTConfig
from transformers import TrainingArguments

def get_config() -> SFTConfig:
    return SFTConfig(
        output_dir=output_dir,
        num_train_epochs=3,
        learning_rate=1e-4,
        logging_steps=5,
        seed=42,
        optim="lomo",
        lr_scheduler_type="cosine",
        report_to="wandb",
    )


# In[9]:


def train():
    from accelerate import Accelerator
    acc = Accelerator()
    model, tokenizer = get_model_and_tokenizer()
    train_dataset = get_train_dataset()
    train_dataset = encode_dataset(train_dataset, tokenizer)
    model, tokenizer, train_dataset = acc.prepare(model, tokenizer, train_dataset)
    config = get_config()
    from transformers import DataCollatorWithPadding
    trainer = SFTTrainer(
        model=model,
        args=config,
        train_dataset=train_dataset,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()
    trainer.save_model("model")


# In[10]:


if __name__ == "__main__":
    train()


During training it is normal for grad_norm to show as 0: the LOMO optimizer discards gradients as it fuses the gradient computation with the parameter update, so the gradient norm is never computed. As long as the loss drops below about 0.6 (for a binary classification problem, the cross-entropy of a random guess is ln 2 ≈ 0.693), training is working.
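That baseline is easy to check:

import math

# Cross-entropy of random guessing on a balanced binary task: -ln(1/2) = ln 2
print(math.log(2))  # ~0.693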

Scorer training logs

Using the Scorer

Using the scorer is also simple: load the model and feed it the reasoning text.

def score_reasoning(reasoning: list[str], scorer: Qwen2ForSequenceClassification, tokenizer: Qwen2TokenizerFast) -> list[int]:
    encoded = tokenizer(reasoning, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits = scorer(**encoded).logits
    return logits.argmax(dim=-1).detach().cpu().numpy().tolist()
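A minimal usage sketch, assuming the fine-tuned scorer was saved to ./scorer (as in the inference script below) and the tokenizer is loaded from the scorer's base model directory:

scorer = Qwen2ForSequenceClassification.from_pretrained("./scorer", device_map="auto", num_labels=2)
tokenizer = Qwen2TokenizerFast.from_pretrained("./Qwen2-0.5B")
print(score_reasoning(["根据条件一和条件二……所以答案是:A"], scorer, tokenizer))  # e.g. [1]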

Inference Code

The code below contains a fair amount of monkey patching, but it runs, and that is what matters here.

The overall flow is shown below.

Inference flow chart

Re-running the fine-tuning and inference pipeline on the new data gives us the new results.

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import warnings

warnings.simplefilter("ignore")


# In[2]:


from transformers import AutoModelForCausalLM, Qwen2TokenizerFast, Qwen2ForSequenceClassification
from dataclasses import dataclass, field
from langchain_core.prompts import ChatPromptTemplate
from tqdm import tqdm
from datasets import Dataset
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

embedding_path = "./Dmeta-embedding-zh"
def get_dataset() -> Dataset:
    train_dataset = Dataset.load_from_disk("./train_dataset")
    return train_dataset
def to_doc(row: dict) -> Document:
    return Document(
        row["question"], metadata=row
    )
def get_embedding() -> HuggingFaceEmbeddings:
    return HuggingFaceEmbeddings(model_name=embedding_path)
def get_db() -> Chroma:
    db = Chroma.from_documents(
        [to_doc(row) for row in get_dataset()],
        get_embedding()
    )
    return db


def get_scorer() -> Qwen2ForSequenceClassification:
    m = Qwen2ForSequenceClassification.from_pretrained("./scorer", device_map="auto", num_labels=2)
    return m



# In[3]:


model_path = "./qwen2-7b-instruct"
adapter_path = "./output"
max_new_tokens = 1024
max_length = max_new_tokens
batch_size = 16
voters = 5

def score_reasoning(reasoning: list[str], scorer: Qwen2ForSequenceClassification, tokenizer: Qwen2TokenizerFast) -> list[int]:
    encoded = tokenizer(reasoning, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits = scorer(**encoded).logits
    return logits.argmax(dim=-1).detach().cpu().numpy().tolist()

# In[4]:

# In[5]:


import json

class EntryEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, Entry):
            return obj.__dict__
        if isinstance(obj, QuestionItem):
            return obj.__dict__
        if isinstance(obj, str):
            return obj
        return super().default(obj)

@dataclass
class QuestionItem:
    question: str
    options: list[str]
    answer: str | None = field(default=None)

@dataclass
class Entry:
    problem: str = field(default="")
    questions: list[QuestionItem] = field(default_factory=list)
    id: str = field(default="")


# In[6]:


def parse_file(file_path: str) -> list[Entry]:
    with open(file_path, "r") as f:
        lines = f.readlines()
    entries = []
    import json
    for line in lines:
        entry = json.loads(line)
        questions = []
        for question in entry["questions"]:
            questions.append(QuestionItem(**question))
        entries.append(Entry(problem=entry["problem"], questions=questions, id=entry["id"]))
    return entries


# In[7]:


def format_options(options):
    return '\n'.join(
        [
            f'{chr(ord("A") + i)}: {option}'
            for i, option in enumerate(options)
        ]
    )


# In[8]:


# In[9]:


def dump_entries(entries: list[Entry], file_path: str):
    import json
    with open(file_path, "w") as f:
        for entry in entries:
            f.write(json.dumps(entry, cls=EntryEncoder, ensure_ascii=False) + "\n")

def build_prompt_no_header(x: dict) -> str:
    problem = x["problem"]
    question = x["question"]
    reasoning = x["reasoning"]
    answer = x["answer"]
    options = x["options"]
    full_text = f"""### 题目        
{problem}

### 问题
{question}
{options}

### 分析过程
{reasoning}

### 答案
答案是:{answer}"""
    return full_text

def build_prompt(x: dict, db: Chroma) -> str:
    head = r"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。每个问题都保证能通过一系列基于形式逻辑的推理(包括同一律,矛盾律,排中律的使用等)得到确定的答案。请逐步分析问题,写出思考过程,并在最后一行输出答案,最后一行的格式为"答案是:A"或"答案是:B"或"答案是:C"或"答案是:D"等等。如果你做对了这个题目,你会获得的一亿奖金。"""
    tops = db.similarity_search(
        x["question"],
        k=2,
    )
    first_top = tops[0]
    second_top = tops[1]
    full = f"""{head}
这是一个例子:
{build_prompt_no_header(first_top.metadata)}

这是另一个例子:
{build_prompt_no_header(second_top.metadata)}

现在,你需要解决这个问题:
### 题目
{x["problem"]}

### 问题
{x["question"]}
{x["options"]}

### 分析过程"""
    return full

def get_prompts_of_entry(entry: Entry, db: Chroma) -> list[str]:
    prompts = []
    for question in entry.questions:
        prompt = build_prompt({
            "problem": entry.problem,
            "question": question.question,
            "options": format_options(question.options)
        }, db)
        prompts.append(prompt)
    return prompts

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import torch
def get_model_and_sampling() -> tuple[LLM, SamplingParams, LoRARequest]:
    llm = LLM(model_path, model_path, enable_lora=True, quantization="fp8", gpu_memory_utilization=0.7)
    sampling = SamplingParams(
        max_tokens=max_new_tokens,
        n=voters,
    )
    lora = LoRARequest(
        "default",
        1,
        adapter_path
    )
    return llm, sampling, lora

def get_answer_from_raw(raw) -> str:
    raw = raw.split("---")[0]
    raw = raw.split("---")[0]
    raw = raw.split("### 题目")[0]
    raw = raw.split("### 问题")[0]
    remove_chars = ["*", "-"]
    for char in remove_chars:
        raw = raw.replace(char, "")
    import re
    match = re.search(r"答案是:[A-Z]", raw)
    if match:
        answer = match.group()
        answer = answer.split(":")[-1]
    else:
        answer = None
    return answer

def log_file(log_content: str):
    with open("./a.txt", "a") as f:
        f.write(log_content + "\n")

def generate_answers_batch(prompts: list[str], model: LLM, sampling: SamplingParams, lora: LoRARequest, scorer: Qwen2ForSequenceClassification, tokenizer: Qwen2TokenizerFast) -> list[str]:
    answers_unchecked = [[""] * voters for _ in range(len(prompts))]
    scores = [0] * len(prompts)
    resample_times = 0
    max_resample = 5
    while not all([s == 1 for s in scores]) and resample_times < max_resample:
        resample_times += 1
        need_resample = [
            i for i in range(len(prompts)) if scores[i] == 0
        ]
        resample_prompts = [prompts[i] for i in need_resample]
        resampled_responses = model.generate(
            resample_prompts,
            sampling,
            lora_request=lora
        )
        resampled_answers = [
            [
                o.text for o in r.outputs
            ] 
            for r in resampled_responses
        ]
        resampled_scores = [
            score_reasoning(ans, scorer, tokenizer)
            for ans in resampled_answers
        ]
        for i in range(len(resampled_answers)):
            for j in range(len(resampled_answers[i])):
                answers_unchecked[need_resample[i]][j] = resampled_answers[i][j]
        # remove answers that are not valid
        for i in range(len(resampled_scores)):
            for j in range(len(resampled_scores[i])):
                if resampled_scores[i][j] != 1:
                    answers_unchecked[need_resample[i]][j] = ""
        # choose the max score of each set of reasoning
        resampled_scores = [max(s) for s in resampled_scores]
        scores = [1] * len(prompts)
        for i in range(len(need_resample)):
            scores[need_resample[i]] = resampled_scores[i]



    answers = []
    for idx, one_output in enumerate(answers_unchecked):
        answer = [get_answer_from_raw(r) for r in one_output]
        log_file(
            "=" * 20 + "\n" + \
            "Prompt: \n" + prompts[idx] + "\n" + \
            "=" * 20 + "\n" + \
            ("=" * 20 + "\n").join(
                [f"Output: \n{o}" for o in one_output]
            )
        )
        # Majority vote over the candidates that yielded a parsable answer.
        filtered_answers = [ans for ans in answer if ans is not None]
        if not filtered_answers:
            most_common_answer = None
        else:
            most_common_answer = max(set(filtered_answers), key=filtered_answers.count)
        answer = most_common_answer
        if answer is None:
            print("Failed to get answer, default to A.")
            answer = "A"
        answers.append(answer)
    return answers


def generate_answers_inplace(entries: list[Entry], prompts: list[list[str]], model: LLM, sampling: SamplingParams, lora: LoRARequest, scorer: Qwen2ForSequenceClassification, tokenizer: Qwen2TokenizerFast):
    next_entry_index = 0
    next_question_index = 0
    total = sum([len(entry.questions) for entry in entries])



    progress = tqdm(total=(total // batch_size + (1 if total % batch_size != 0 else 0)))

    while next_entry_index < len(entries):
        batch_entries = []
        batch_prompts = []
        start_index_of_first_entry_question = next_question_index
        start_index_of_first_entry = next_entry_index
        while len(batch_prompts) < batch_size and next_entry_index < len(entries):
            entry = entries[next_entry_index]
            prompt = prompts[next_entry_index][next_question_index]
            batch_entries.append(entry)
            batch_prompts.append(prompt)
            next_question_index += 1
            if next_question_index >= len(entry.questions):
                next_question_index = 0
                next_entry_index += 1
        answers = generate_answers_batch(batch_prompts, model, sampling, lora, scorer, tokenizer)

        if len(answers) != len(batch_prompts):
            print("Answers length not equal to batch prompts length.")
            print("This should not happen.")

        i = 0
        next_entry_to_set = start_index_of_first_entry
        next_question_to_set = start_index_of_first_entry_question
        while i < len(answers):
            entry = entries[next_entry_to_set]
            entry.questions[next_question_to_set].answer = answers[i]
            i += 1
            next_question_to_set += 1
            if next_question_to_set >= len(entry.questions):
                next_question_to_set = 0
                next_entry_to_set += 1
        progress.update()
    progress.close()


# In[10]:

def get_tokenizer() -> Qwen2TokenizerFast:
    return Qwen2TokenizerFast.from_pretrained(model_path)

def main():
    entries = parse_file("./round1_test_data.jsonl")
    print("building db...")
    db = get_db()
    print("db built")
    print("building prompts...")
    prompts = [get_prompts_of_entry(entry, db) for entry in entries]
    # unload the embedding
    del db
    print(prompts[0])
    generator, sampling, lora = get_model_and_sampling()
    scorer = get_scorer()
    tokenizer = get_tokenizer()
    generate_answers_inplace(entries, prompts, generator, sampling, lora, scorer, tokenizer)
    dump_entries(entries, "./upload.jsonl")


# In[11]:


if __name__ == "__main__":
    main()

Summary

In this part we introduced data augmentation and resampling, and used both to improve the generation model. Data augmentation improves the model's generalization; resampling improves the quality of its outputs. In practice, the right augmentation technique depends on the task and the dataset.

That said, the scorer did not work especially well here; it would need more training data or a larger model. The mechanism itself is still useful in the right scenarios.

For logic-reasoning problems it would also help to generate the reasoning as an explicit chain of thought, which makes the reasoning process easier to follow. However, since we used an instruct model rather than a chat model, we did not generate such chains.