与 LLM 引擎的集成¶
XGrammar 实现了高效的结构化生成。在本教程中,将探讨 XGrammar 的关键组件,以及如何将 XGrammar 集成到 LLM 引擎中。
首先在 高级流程 中阐述了相关概念。接着,展示了 XGrammar 如何实现 批量推理的结构化生成。
以下代码片段是实际可运行的代码,因为模拟了 LLM 的生成过程。
安装 XGrammar¶
XGrammar 可以通过 pip 安装。建议始终在独立的 conda 虚拟环境中安装它。
高级流程¶
在本节中,将探讨将 XGrammar 集成到 LLM 引擎以实现结构化生成时的关键组件。
首先,为本教程导入必要的库。
import xgrammar as xgr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig
xgr.TokenizerInfo¶
xgr.TokenizerInfo
是针对每个模型的构造,它封装了分词器的信息,包括其所有词汇。有几种实例化它的方法,最便捷的方式是使用 AutoTokenizer
。请注意,对于某些模型,由于填充的存在, AutoConfig.vocab_size
可能会大于 AutoTokenizer.vocab_size
,前者是模型逻辑单元(logits)的形状。为了安全起见,在实例化 xgr.TokenizerInfo
时,始终传入前者。
# Get tokenizer info
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)
# This can be larger than tokenizer.vocab_size due to paddings
full_vocab_size = config.vocab_size
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
xgr.GrammarCompiler¶
有了 xgr.TokenizerInfo
,就可以实例化 xgr.GrammarCompiler
。这是根据模型的分词器信息编译语法的构造。因此,对于每个模型,你可以持久地使用同一个 xgr.GrammarCompiler
,因为它可以为相同的 xgr.TokenizerInfo
编译不同的语法。请注意,compiler
的行为可以通过 max_threads
进行多线程配置,以及通过 enable_cache
(默认为true)来启用或禁用编译语法的缓存。
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
xgr.CompiledGrammar¶
接下来,使用 xgr.GrammarCompiler
,可以编译语法,结果是得到 xgr.CompiledGrammar
。这里我们使用了内置的 JSON 语法。对于其他语法,请参阅 JSON 生成 和 EBNF引导的生成。到目前为止,所看到的一切都是针对每个模型的(而不是每次生成)。
compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar()
xgr.GrammarMatcher¶
有了编译好的语法,就可以实例化 xgr.GrammarMatcher
。它是 LLM 引擎与之交互的主要构造,用于维护结构化生成的状态。请注意,每个请求都应该有自己的 xgr.GrammarMatcher
,因为每个请求都有不同的生成状态,正如将在 批量推理的结构化生成 中看到的那样。
# Instantiate grammar matcher with the compiled grammar
matcher = xgr.GrammarMatcher(compiled_grammar)
自回归生成中的位掩码逻辑单元(Logits)¶
现在模拟单请求的自回归生成过程。关于批量推理,请参阅后面的 批量推理的结构化生成 部分。
首先,使用 xgr.allocate_token_bitmask()
预分配令牌位掩码,它本质上是形状为 (batch_size, vocab_size)
的 torch.Tensor
。你也可以使用自己的实现来分配位掩码。
在每个自回归步骤中,根据匹配器的当前状态使用 xgr.GrammarMatcher.fill_next_token_bitmask()
填充令牌位掩码。然后,使用 xgr.apply_token_bitmask_inplace()
将位掩码应用到模型的逻辑单元(logits)中,如果 logits
在CUDA上(推荐),则调用 CUDA 内核,否则使用 CPU 实现。
掩码处理后,非法令牌的逻辑单元(logits)被设置为负无穷大,这样就永远不会采样到它们。采样到令牌后,使用 xgr.GrammarMatcher.accept_token()
更新 xgr.GrammarMatcher
的状态。最后,使用 xgr.GrammarMatcher.reset()
为下一次生成做准备。
# Here we simulate a valid sampled response
sim_sampled_response = '{ "library": "xgrammar" }<|end_of_text|>'
sim_sampled_token_ids = tokenizer.encode(sim_sampled_response, add_special_tokens=False)
# Allocate a token bitmask
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
# Each loop iteration is a simulated auto-regressive step
for i, sim_token_id in enumerate(sim_sampled_token_ids):
# LLM inference to get logits, here we use randn to simulate.
# logits is a tensor of shape (full_vocab_size,) on GPU
# logits = LLM.inference()
logits = torch.randn(full_vocab_size).cuda()
# Apply bitmask to logits to mask invalid tokens
matcher.fill_next_token_bitmask(token_bitmask)
xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
# Sample next token
probs = torch.softmax(logits, dim=-1).cpu().numpy()
next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs)
# Accept token from matcher to update its state, so that the next bitmask
# generated will enforce the next token to be generated. Assert to make
# sure the token is indeed valid. Here we accept the simulated response
# assert matcher.accept_token(next_token_id)
assert matcher.accept_token(sim_token_id)
# Since we accepted a stop token `<|end_of_text|>`, we have terminated
assert matcher.is_terminated()
# Reset to be ready for the next auto-regressive generation
matcher.reset()
批量推理的结构化生成¶
上面的代码片段假设了单请求生成。本节将展示相同的概念如何应用于批量生成。
首先,按照上述完全相同的步骤为每个模型构造 xgr.TokenizerInfo
和 xgr.GrammarCompiler
。假设每个请求需要生成有效的 JSON。
import xgrammar as xgr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig
# Get tokenizer info
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)
# This can be larger than tokenizer.vocab_size due to paddings
full_vocab_size = config.vocab_size
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
# Compile a JSON grammar
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar()
现在,需要为批次中的每个请求维护 xgr.GrammarMatcher
,因为每个请求都有不同的生成状态。请注意,批次中的每个请求可以遵循不同的 xgr.CompiledGrammar
,但这里为了简单起见,它们都遵循通用的 JSON 语法。
batch_size = 2
matchers = [
xgr.GrammarMatcher(compiled_grammar)
for i in range(batch_size)
]
token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)
模拟了批量推理的自回归生成过程。请注意,这里假设两个请求的生成长度相同以简化问题。但根据你的引擎如何支持批量推理,应该很容易进行推广。与单请求生成的关键区别在于,在批量请求生成中,每个请求都有自己的 xgr.GrammarMatcher
需要维护。
sim_sampled_responses = ['{"name": "a"}<|end_of_text|>', '{"name": "b"}<|end_of_text|>']
sim_sampled_token_ids = [
tokenizer.encode(response, add_special_tokens=False)
for response in sim_sampled_responses
]
# Each loop iteration is a simulated auto-regressive step
for loop_iter in range(len(sim_sampled_token_ids[0])):
# LLM batched inference to get logits, here we use randn to simulate
# Now, logits is a tensor of shape (batch_size, full_vocab_size) on GPU
# logits = LLM.inference()
logits = torch.randn(batch_size, full_vocab_size).cuda()
# This for loop is parallelizable using threading.Thread. But estimate
# the overhead in your engine.
for i in range(batch_size):
matchers[i].fill_next_token_bitmask(token_bitmask, i)
xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
# Sample next token
probs = torch.softmax(logits, dim=-1).cpu().numpy()
next_token_ids = [
np.random.choice(list(range(full_vocab_size)), p=probs[i])
for i in range(batch_size)
]
# Update the matcher for each request
for i in range(batch_size):
# Here we accept the simulated response
# assert matchers[i].accept_token(next_token_ids[i])
matchers[i].accept_token(sim_sampled_token_ids[i][loop_iter])
# In our simulated case, all requests should have terminated since we accepted
# a stop token `<|end_of_text|>`
for i in range(batch_size):
assert matchers[i].is_terminated()
# Reset to be ready for the next generation
matchers[i].reset()