Skip to content

利用大模型高效生成训练数据[含代码]

在大模型出现之前,训练数据的构建主要依赖于人工标注、开源数据集,以及从线上数据中提取合适的监督数据。然而,开源数据通常难以完全满足特定业务需求(大部分情况下无法直接使用),而现有的线上数据也可能无法抽取出符合要求的监督数据。这时,人工标注似乎成为唯一的选择,但这种方法不仅耗费大量人力和时间,还难以保证足够的效率。
随着大模型的出现,数据生成的方式开始发生变化。通过构建高质量的 prompt,我们可以利用大模型自动生成训练数据。这一过程的原理相对简单,因此本次分享的重点不在于原理,而是向大家展示一套本人在实际工作中经常使用的代码,以便大家可以直接应用于自己的项目。

整体流程概述

当接收到一个业务需求时,与产品团队对齐细节后便可以开始编写 prompt。通常,我会使用 vLLM 部署效果较好的大模型,便于通过 OpenAI 的 SDK 进行调用。经过反复调试和迭代 prompt,当效果接近预期后,即可开始生成用于训练小模型的数据(假设现有硬件资源无法支持如 72B 这样的超大型模型,且 1.5B 或 7B 模型的效果也无法满足需求)。为提高数据生产的效率,通常会采用多进程加速生成过程。接下来,我将从代码层面详细讲解这一过程。

vLLM 部署大模型

vLLM 提供了非常便捷的命令行部署方式,更多参数选择可以参考vLLM文档

bash
CUDA_VISIBLE_DEVICES="0,1" python -m vllm.entrypoints.openai.api_server --served-model-name model_name --model model_path --tensor-parallel-size 2 --port 8002

假设启动过程顺利,你将在终端中看到访问地址,通常为 http://0.0.0.0:8002。在浏览器中输入 http://0.0.0.0:8002/docs,你将访问到一个可交互的文档界面,便于测试服务是否正常运行。

调试 Prompt

vLLM 提供了一个基于 Gradio 的示例代码,极大地方便了 prompt 的调试。我在原有的基础上做了一些修改,允许在界面上直接编辑 prompt,从而避免了每次修改后都需重启服务的不便。以下是修改后的代码示例:

python
import argparse
from collections.abc import Generator
import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(description="Chatbot Interface with Customizable Parameters")
parser.add_argument("--model-url", type=str, default="http://localhost:8000/v1", help="Model URL")
parser.add_argument("-m", "--model", type=str, default="gpt-3.5-turbo", help="Model name for the chatbot")
parser.add_argument("--temp", type=float, default=0.8, help="Temperature for text generation")
parser.add_argument("--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs")
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

def predict(message: str, history: list[tuple[str, str]], system_message: str) -> Generator[str, None, None]:
    history_openai_format = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    stream = client.chat.completions.create(
        model=args.model,
        messages=history_openai_format,
        stream=True,
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": [int(id.strip()) for id in args.stop_token_ids.split(",") if id.strip()] if args.stop_token_ids else [],
        },
        max_tokens=2048,
    )

    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message

# Create and launch a chat interface with Gradio
gr.ChatInterface(
    predict,
    additional_inputs=[gr.Textbox("you are a helpful assistant", label="System Prompt")],
    additional_inputs_accordion=gr.Accordion(open=True),
).queue().launch(server_name=args.host, server_port=args.port, share=True)

在服务启动过程中,可能会出现以下信息提示:

shell
Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:
...

尽管问题不解决也能在本地访问,但如果需要与他人共享服务,建议按照提示进行修复。具体步骤如下:

shell
wget https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_darwin_arm64
mv frpc_darwin_arm64 frpc_darwin_arm64_v0.2
chmod +x frpc_darwin_arm64_v0.2
mv frpc_darwin_arm64_v0.2 you_gradio_path_in_env

修复完成后,重新启动服务,你将看到如下输出:

shell
Running on local URL:  http://127.0.0.1:8001
Running on public URL: https://24e925b09b9a9c337d.gradio.live

第二个地址可以分享给他人访问。

大规模数据蒸馏

在获得满意的 prompt 效果后,我们可以开始大规模生产训练数据。这一步的核心是将之前的对话界面代码修改为读取待标注的数据,并使用多进程调用 vLLM 启动的服务。以下是主要代码示例:

python
# 读取输入数据
df = pd.read_json(CONFIG["INPUT_FILE"], lines=True)

# 并行处理数据
with ProcessPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
    list(tqdm(executor.map(process_row, df.to_dict(orient="records")), total=len(df)))

def process_row(row):
    try:
        user_input = USER_INPUT_TEMPLATE.format(**row)
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_input},
        ]
        response = client.chat.completions.create(
            model=CONFIG["MODEL_NAME"],
            messages=messages,
        ).choices[0].message.content

        post_process(row, response)
    except Exception as e:
        print(f"处理数据时出错: {e}")
        print(f"跳过数据: {row.get('id', 'unknown')}")

def post_process(row, response):
    row["model_response"] = response
    unique_id = str(uuid.uuid4())
    filename = f"{CONFIG['PROCESSED_DIR']}/{unique_id}.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(row, f, ensure_ascii=False, indent=4)

上述代码通过多进程加速数据生产,能有效提高效率。一般情况下,几十万条数据的生产时间在两到三天左右,具体时间还取决于任务数据的长度以及部署大模型的硬件资源条件等。

总结

本文分享了本人在工作中常用的一套蒸馏数据的代码,希望能为大家在实际项目中提供帮助,完整代码见数据蒸馏。本文重点在于高效地生成训练数据,未详细讨论 prompt 的迭代优化以及不同场景下如何生成高质量数据。这些内容将在后续文章中进一步探讨。