★本文首发于:https://zhuanlan.zhihu.com/p/710467126
目前 LLM(Large Language Model)从文本补全到内容创作,都展示出了强大的生成能力。然而通过 LLM 生成结构化的数据如 JSON 格式的输出,却仍然是一个有挑战性的任务。
生成结构化的数据不仅要求模型输出符合特定的语法规则,还需要确保数据的正确性和一致性。
虽然通过 prompt 工程可能可以实现指定格式的结构化数据生成,但是这也很大程度取决于模型的能力。
本文将探讨如何结合人工规则让 LLM 输出符合 JSON 格式的数据。
本文主要是结合 lm-format-enforcer
( https://github.com/noamgat/lm-format-enforcer ) 这个库来讲解如何让 LLM 生成指定格式的 JSON 数据。
目前该库也是被 vllm 作为 JSON 格式输出的后端之一:https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
结构化数据生成的原理用一句话概括就是:
每个 step 拿到当前 model 给出的 logits 之后,在采样下一个 token 之前,通过人工设定的规则可以得到当前 step 只允许采样的 token 集合,接着通过加 bias 的方式压制其他不允许采样的 token,从而实现指定的结构化数据生成。
那么怎么得到当前 step 可允许采样的 token 集合,就是本文重点讲解的内容了。
lm-format-enforcer
这个库包含两个核心模块,分别是 tokenizer 前缀树 和 字符级别的解析器,通过这两个模块就可以实现上述的功能。
lm-format-enforcer
这个库在初始化阶段,首先会根据 tokenizer 给出的词表,初始化一个字符级别的前缀树,这个前缀树怎么理解呢?
通过 tokenizer 给出的词表,我们可以得到一个词表中的 字符串 和 对应 token id 的映射。通过这些映射,就可以来构造这个前缀树。
树上每个节点对应词表中某个字符串的其中一个字符,每个节点的子节点就是连着的下一个字符,当字符串中的字符已经遍历完了,这时候就是填入该字符串对应的 token id。
现在通过具体的例子解释一下,这个前缀树是如何构造的。
我们用 llama2 模型的词表来解读,假设就取词表中的一个小子集:
{
"a": 29874,
"ar" : 279,
"are" : 598,
"Y": 29979,
"You" : 3492,
"O": 29949,
"OK": 8949,
}
下面用图展示树的构造过程:
遍历第1个映射:
遍历第2个映射:
遍历第3个映射:
遍历第4个映射:
遍历第5个映射:
遍历第6个映射:
遍历第7个映射:
通过上面图示,展示了如何通过词表子集构造前缀树,实际的前缀树比这个大多了,整个词表中的 字符串 和 token id 的映射都会通过这样的方式插入到前缀树中。
构造好前缀树之后,接下来就是讲解怎么得到每个 step 可允许采样的 token 集合。
lm-format-enforcer
还有另一个重要的模块就是 字符级别的解析器。
这个解析器的作用简单来理解就是,在初始化的时候,会接收用户指定的 json schema,接着在后续每一步生成过程中,会根据之前生成的内容,判断目前处于什么状态,然后根据当前所处的状态直接给出限定的字符集合。
下面举个简单的例子,比如用户指定的 json schema 是:
JSON_SCHEMA = {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "Name of the city."
}
},
"required": ["city"]
}
想要 LLM 生成一个 JSON object ,内容是包含一个 city
属性,该属性的内容是一个字符串,表示一个城市的名字,同时该 city
必须要在结果中出现。
解析器的作用就是,比如目前已经生成好的内容是 :
{
"
那么下一步一定是要生成 city
这个字符串,解析器的作用就是根据目前的状态,会给出限定的字符集合 ['c', 'i', 't', 'y']
。
然后接下来比如生成到了:
{
"city": "
那么接下就是要 LLM 生成一个城市的名字,但是其实对于解析器来说,他只知道接下来要生成的内容是字符串,而且内容只需要符合 JSON 格式就行了,所以这时候给出的限定字符集合就非常大了,词表中的 token 对应的字符串只要符合 JSON 格式的都可以。
最后具体能生成什么城市名字,还有这个城市是否真实存在,就得看 LLM 的能力了。
下面用一个具体的例子讲解一下,怎么结合 前缀树 和 解析器,获取每个 step 限定的 token 集合。
假设用户的输入 prompt 和指定的 json schema 是:
prompt = "Please output a JSON segment representing the name of a city, including fields for city name."
JSON_SCHEMA = {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "Name of the city."
}
},
"required": ["city"]
}
有一点需要注意,获取可允许采样 token 集合在 lm-format-enforcer
库中是通过递归的方式实现的,下面为了讲解方便,会给每一层递归编个号:
第 0 层递归
首先解析器给出的限定字符集合就是
[' ', '\t', '\n', '\r', '{']
包括空格和大括号在内的5个字符。
然后将这个5个字符和前缀树根节点的所有第一层子节点对应的字符集合做一个交集。
获取得到的字符交集还是这 5 个字符:
[' ', '\t', '\n', '\r', '{']
接着遍历这个字符交集。
遍历每个字符的时候会假设目前已经生成了该字符,比如一开始遍历空格字符 ' '
,会将空格当作已经生成的内容加入到解析器中,这时候解析器内部状态会变化,同时取前缀树中空格字符节点对应的所有子节点,进入下一轮递归。
下一轮递归开始的时候,首先将会该子节点包含的所有 token id 加入到当前 step 的候选 token 列表中,然后继续重复上述流程。
第 1 层递归
首先看目前遍历到的前缀树节点包含的 token id 集合是
[35, 29871]
分别对应 llama2 词表中的字符串
"<0x20>"
"▁"
其中, 0x20
表示 ASCII 编码表中的空格字符,所以 在 llam2 的词表中,空格对应的 token 有两个。
接着继续看第 1 层的递归,解析器在上一层添加了空格字符之后,给出的限定字符集合仍然是
[' ', '\t', '\n', '\r', '{']
因为假设前面生成的是空格的情况下,接下来的可生成的字符其实还是可以是之前的 5 个中选一个。
然后前缀树当前节点下的所有第一层子节点的字符集合:
[' ', 't', 'a', 's', 'd', 'c', 'w', 'p', 'f', 'm', 'o', 'b', 'i', 'h', 'l', 'n', 'I', '(', 'C', 'S', 'u', 'A', '\\', 'e', 'T', 'v', 'g', '*', 'r', 'M', 'y', 'P', 'B', '=', 'D', 'L', '"', 'H', 'E', 'F', 'R', '$', '#', 'W', 'G', 'N', 'k', '`', '{', 'j', 'J', 'O', 'q', '-', 'п', 'K', 'V', 'в', '}', 'U', 'z', '[', "'", '<', 'с', ':', 'и', 'Y', 'о', 'Q', 'д', 'н', '&', '+', '@', 'з', 'м', '–', 'Z', '—', 'à', 'б', '/', 'С', '«', 'у', '.', '|', '_', 'é', 'x', 'В', 'П', 'к', 'X', 'К', 'г', 'а', 'М', '%', 'А', 'р', '“', 'Б', 'Н', '>', 'Д', 'Р', '?', 'ф', 'Г', 'О', 'е', 'Т', 'т', ')', '!', '„', 'Л', 'і', ',', 'У', '»', ';', 'è', 'И', 'ä', 'я', 'э', 'З', 'ч', 'ü', 'Ф', 'ј', '·', 'î', 'Х', 'É', 'Е', 'ш', 'č', 'л', 'Ч', '~', 'ц', 'ú', 'ö', 'á', 'Ш', 'ș', 'х', 'ж', ']', 'Э', '‘', 'І', 'Ц', 'щ', 'Я', 'ž', 'ś', '^', 'Ö', 'š', '†', '°', '\r', 'Ю', 'Ж', 'Ü', 'Á', 'й', 'Č', 'ê', 'ю', 'À', '№', 'Š', 'å', 'є', '•', '→', 'Ś', 'Å', 'ї', 'Ä', 'Î', '│', '×', 'ż', 'Ž', '−', 'È', 'Ł', 'Є', 'í', 'Ż', 'Й', '£', 'Ј', '…', '’', '§', 'ó', 'Ú', '¿', 'ř', 'â', 'α', '\xa0', 'ő', 'њ', 'ا', '€', '”', 'Ó', 'Щ', 'ł', 'Í', '¡']
其实对应的都是词表中起始字符是空格的 token ,然后两者的交集是:
[' ', '\r', '{']
其实就是对应词表中以空格起始的三个 token :
"▁▁": 259
"▁\r": 6756
"▁{": 426
接着遍历交集 [' ', '\r', '{']
,进入第 2 层递归。
由于 llama2 词表中包含连续空格的 token 最长的有15个连续空格 token :
"▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁": 462
但是递归最多只会深入到 12 层,因为 lm-format-enforcer
库中默认限定了最长连续的空格数量是 12 个,所以连续探索空格达到 12 层递归之后就会终止探索,接着回溯到第 1 层,继续那一层其他剩下还没探索的交集字符的递归过程。
一直重复直到所有层 前缀树 和 解析器 的所有字符交集都探索完毕。
最终第一个 step 得到的可允许采样的 token 集合是:
"<0x20>": 35 # 对应 ASCII 表中的空格字符
"▁": 29871
"▁▁": 259
"▁▁▁": 1678
"▁▁▁▁": 268
"▁▁▁▁▁": 418
"▁▁▁▁▁▁": 539
"▁▁▁▁▁▁▁": 4706
"▁▁▁▁▁▁▁▁": 308
"▁▁▁▁▁▁▁▁▁": 3986
"▁▁▁▁▁▁▁▁▁▁": 965
"▁▁▁▁▁▁▁▁▁▁▁": 9651
"▁▁▁▁▁▁▁▁▁▁▁▁": 632
"▁\r": 6756
"▁{": 426
"▁{\r": 3336
"▁{\"": 8853
"<0x09>": 12 # 对应 ASCII 表中的 \t 字符
"<0x0A>": 13 # 对应 ASCII 表中的 \r 字符
"<0x0D>": 16 # 对应 ASCII 表中的 \n 字符
"\r": 30004
"<0x7B>": 126 # 对应 ASCII 表中的 { 字符
"{": 29912
"{\r": 14626
"{\"": 6377
然后我们直接跳到第 6 个 step,假设目前 LLM 已经生成的内容是,
{
"
前面每个 step 生成的内容按顺序是 ['\n', '\n', '\n', '{', '\n', '"']
:
然后根据用户设定的 json schema,接下来其实就是要限制采样必须生成 city
这个字符串,我们来看下递归的过程。
第 0 层递归
首先解析器给出的限定字符集合就是 ['c']
然后前缀树根节点所有第一层子节点的交集就只有 'c'
字符,然后将 c
加入解析器,同时取根节点下 c
对应的所有子节点进入
第 1 层递归
而由于上一层生成了字符 c
,那么对于解析器来说,接下来的字符肯定要是 i
,所以给出的限定字符集合就是 ['i']
,和当前树节点的第一层子节点的交集自然也就是只有字符 'i'
,然后继续递归。
以此类推,可得当前 step 的限定 token 集合为:
"<0x63>": 102 # 对应 ASCII 表中小写字符 c
"c": 29883
"ci": 455
"cit": 20752
"city": 12690
接着跳到第 9 个 step,假设到目前为止已经生成了:
{
"city": "
那么这时候,根据解析器的判断,接下来其实就是可以自由生成任意符合 json 格式的字符,所以这时候返回的 token 集合会非常大,接近词表大小。
lm-format-enforcer
中对这个情况做了优化,就是这些 token 集合是可以在生成前缀树的过程中拿到。
所以如果当前是自由生成字符模式,则不会进入递归过程,直接返回这些 token 集合即可。
在拿到可允许采样的 token 集合之后,接下来的操作就简单了,只需要给 logits tensor 加一个偏置即可,伪代码实现:
allow_tokens = [xx, yy, zz, ....]
bias = torch.full((vocab_size,), -math.inf)
bias[allow_tokens] = 0
logit += bias
通过给不允许采样的 token 加一个负无穷的方式来压制这些 token 不会被采样得到。
其实除了 lm-format-enforcer
的实现方式之外,还有其他人工规则的结构化生成库比如 github 上 star 更多的 outlines
库。感兴趣的读者可以进一步对比两者的实现有什么不同。
[1] https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
[2] https://github.com/noamgat/lm-format-enforcer?tab=readme-ov-file#how-does-it-work
[3] https://github.com/outlines-dev/outlines