OpenClaw is an information extraction model driven by natural-language queries. Below is a concise PyTorch implementation that keeps the core architecture and basic functionality.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

class OpenClawSimplified(nn.Module):
    """
    Simplified OpenClaw model.
    Performs information extraction by combining a natural-language query
    with the document content.
    """
    def __init__(self, bert_model_name='bert-base-uncased', hidden_size=768, num_labels=3):
        """
        Initialize the model.
        Args:
            bert_model_name: name of the pretrained BERT model
            hidden_size: hidden dimension (must match the BERT model)
            num_labels: number of labels (BIO tagging)
        """
        super(OpenClawSimplified, self).__init__()
        # BERT encoder (shared between document and query)
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        # Query attention: lets document tokens attend to the query
        self.query_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        # Conditional layer normalization
        self.conditional_layer_norm = ConditionalLayerNorm(hidden_size)
        # Span prediction heads
        self.span_start_predictor = nn.Linear(hidden_size, 1)
        self.span_end_predictor = nn.Linear(hidden_size, 1)
        # Sequence labeling head
        self.sequence_labeler = nn.Linear(hidden_size, num_labels)
        # Fusion layer
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
    def forward(self, input_ids, attention_mask, query_input_ids, query_attention_mask):
        """
        Forward pass.
        Args:
            input_ids: document token IDs [batch_size, seq_len]
            attention_mask: document attention mask
            query_input_ids: query token IDs [batch_size, query_len]
            query_attention_mask: query attention mask
        Returns:
            span_start_logits: span start scores [batch_size, seq_len]
            span_end_logits: span end scores [batch_size, seq_len]
            sequence_logits: sequence labeling scores [batch_size, seq_len, num_labels]
        """
        # Encode the document
        document_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        document_embeddings = document_outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        # Encode the query
        query_outputs = self.bert(
            input_ids=query_input_ids,
            attention_mask=query_attention_mask
        )
        query_embeddings = query_outputs.last_hidden_state  # [batch_size, query_len, hidden_size]
        # Query-aware document representation:
        # document tokens attend over the query tokens
        query_aware_doc, _ = self.query_attention(
            query=document_embeddings,
            key=query_embeddings,
            value=query_embeddings,
            key_padding_mask=~query_attention_mask.bool() if query_attention_mask is not None else None
        )
        # Conditional layer normalization
        normalized_doc = self.conditional_layer_norm(document_embeddings, query_aware_doc)
        # Fuse the original document representation with the query-aware one
        fused_representation = self.fusion_layer(
            torch.cat([document_embeddings, normalized_doc], dim=-1)
        )
        # Span prediction
        span_start_logits = self.span_start_predictor(fused_representation).squeeze(-1)
        span_end_logits = self.span_end_predictor(fused_representation).squeeze(-1)
        # Sequence labeling
        sequence_logits = self.sequence_labeler(fused_representation)
        return span_start_logits, span_end_logits, sequence_logits
    @torch.no_grad()
    def extract_spans(self, input_ids, attention_mask, query_input_ids, query_attention_mask, threshold=0.5):
        """
        Extract text spans from the document.
        Args:
            threshold: probability threshold for start/end predictions
        """
        span_start_logits, span_end_logits, _ = self.forward(
            input_ids, attention_mask, query_input_ids, query_attention_mask
        )
        # Apply sigmoid to obtain probabilities
        start_probs = torch.sigmoid(span_start_logits)
        end_probs = torch.sigmoid(span_end_logits)
        # Extract spans
        batch_spans = []
        for i in range(start_probs.size(0)):
            spans = []
            start_positions = torch.where(start_probs[i] > threshold)[0].tolist()
            end_positions = torch.where(end_probs[i] > threshold)[0].tolist()
            # Greedily pair each start with the nearest end at or after it
            for start in start_positions:
                possible_ends = [end for end in end_positions if end >= start]
                if possible_ends:
                    end = min(possible_ends)
                    spans.append((start, end))  # already Python ints after tolist()
            batch_spans.append(spans)
        return batch_spans
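    # Worked example of the greedy pairing above (illustrative values):
    #   start_positions = [2, 7], end_positions = [4, 9]
    #   start 2 pairs with the nearest end >= 2, giving span (2, 4)
    #   start 7 pairs with the nearest end >= 7, giving span (7, 9)
    # Note that two starts may share the same end, so extracted spans can overlap.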
class ConditionalLayerNorm(nn.Module):
    """
    Conditional layer normalization.
    Modulates the normalized activations with a condition signal.
    """
    def __init__(self, hidden_size, eps=1e-12):
        super(ConditionalLayerNorm, self).__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Learnable scale and shift parameters
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        # Condition transforms
        self.condition_weight = nn.Linear(hidden_size, hidden_size)
        self.condition_bias = nn.Linear(hidden_size, hidden_size)
    def forward(self, x, condition):
        """
        Args:
            x: input tensor [batch_size, seq_len, hidden_size]
            condition: condition tensor [batch_size, seq_len, hidden_size]
        """
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True, unbiased=False)  # biased std, matching LayerNorm
        # Standard normalization
        normalized = (x - mean) / (std + self.eps)
        # Condition-dependent scale and shift
        condition_scale = self.condition_weight(condition)
        condition_shift = self.condition_bias(condition)
        # Conditional normalization
        output = normalized * self.weight + self.bias
        output = output * condition_scale + condition_shift
        return output
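# Quick shape check for ConditionalLayerNorm (illustrative addition, not part of the model):
# cln = ConditionalLayerNorm(hidden_size=8)
# x = torch.randn(2, 4, 8)      # [batch, seq, hidden]
# cond = torch.randn(2, 4, 8)   # condition has the same shape as x
# assert cln(x, cond).shape == x.shape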
class OpenClawTrainer:
    """
    Trainer for OpenClaw.
    """
    def __init__(self, model, learning_rate=2e-5):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        # Loss functions
        self.span_loss_fn = nn.BCEWithLogitsLoss()
        self.sequence_loss_fn = nn.CrossEntropyLoss()
    def train_step(self, batch):
        """
        Run a single training step.
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Unpack the batch
        (input_ids, attention_mask,
         query_input_ids, query_attention_mask,
         span_start_labels, span_end_labels, sequence_labels) = batch
        # Forward pass
        span_start_logits, span_end_logits, sequence_logits = self.model(
            input_ids, attention_mask,
            query_input_ids, query_attention_mask
        )
        # Span losses (binary classification per position)
        span_start_loss = self.span_loss_fn(
            span_start_logits, span_start_labels.float()
        )
        span_end_loss = self.span_loss_fn(
            span_end_logits, span_end_labels.float()
        )
        # Sequence labeling loss
        batch_size, seq_len, num_labels = sequence_logits.shape
        sequence_loss = self.sequence_loss_fn(
            sequence_logits.view(-1, num_labels),
            sequence_labels.view(-1)
        )
        # Total loss
        total_loss = span_start_loss + span_end_loss + sequence_loss
        # Backward pass
        total_loss.backward()
        self.optimizer.step()
        return {
            'total_loss': total_loss.item(),
            'span_start_loss': span_start_loss.item(),
            'span_end_loss': span_end_loss.item(),
            'sequence_loss': sequence_loss.item()
        }
    def evaluate(self, dataloader):
        """
        Evaluate the model on a dataloader.
        """
        self.model.eval()
        total_loss = 0
        total_samples = 0
        with torch.no_grad():
            for batch in dataloader:
                (input_ids, attention_mask,
                 query_input_ids, query_attention_mask,
                 span_start_labels, span_end_labels, sequence_labels) = batch
                # Forward pass
                span_start_logits, span_end_logits, sequence_logits = self.model(
                    input_ids, attention_mask,
                    query_input_ids, query_attention_mask
                )
                # Compute losses
                span_start_loss = self.span_loss_fn(
                    span_start_logits, span_start_labels.float()
                )
                span_end_loss = self.span_loss_fn(
                    span_end_logits, span_end_labels.float()
                )
                batch_size, seq_len, num_labels = sequence_logits.shape
                sequence_loss = self.sequence_loss_fn(
                    sequence_logits.view(-1, num_labels),
                    sequence_labels.view(-1)
                )
                total_loss += (span_start_loss + span_end_loss + sequence_loss).item()
                total_samples += batch_size
        return total_loss / total_samples if total_samples > 0 else 0
def create_sample_batch():
    """
    Create a batch of synthetic sample data.
    """
    batch_size = 2
    seq_len = 10
    query_len = 5
    num_labels = 3  # BIO labels
    # Document token IDs
    input_ids = torch.randint(100, 1000, (batch_size, seq_len))
    attention_mask = torch.ones((batch_size, seq_len))
    # Query token IDs
    query_input_ids = torch.randint(100, 1000, (batch_size, query_len))
    query_attention_mask = torch.ones((batch_size, query_len))
    # Span labels (binary per position)
    span_start_labels = torch.randint(0, 2, (batch_size, seq_len)).float()
    span_end_labels = torch.randint(0, 2, (batch_size, seq_len)).float()
    # Sequence labels
    sequence_labels = torch.randint(0, num_labels, (batch_size, seq_len))
    return (input_ids, attention_mask,
            query_input_ids, query_attention_mask,
            span_start_labels, span_end_labels, sequence_labels)
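# A small companion helper (illustrative addition, not part of the original code):
# builds a DataLoader that OpenClawTrainer.evaluate can consume, using the same
# synthetic tensor shapes as create_sample_batch.
def create_sample_dataloader(batch_size=2, num_samples=8):
    from torch.utils.data import TensorDataset, DataLoader
    seq_len, query_len, num_labels = 10, 5, 3
    dataset = TensorDataset(
        torch.randint(100, 1000, (num_samples, seq_len)),      # input_ids
        torch.ones((num_samples, seq_len)),                    # attention_mask
        torch.randint(100, 1000, (num_samples, query_len)),    # query_input_ids
        torch.ones((num_samples, query_len)),                  # query_attention_mask
        torch.randint(0, 2, (num_samples, seq_len)).float(),   # span_start_labels
        torch.randint(0, 2, (num_samples, seq_len)).float(),   # span_end_labels
        torch.randint(0, num_labels, (num_samples, seq_len)),  # sequence_labels
    )
    return DataLoader(dataset, batch_size=batch_size)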
# Usage example
if __name__ == "__main__":
    # Initialize the model
    model = OpenClawSimplified(num_labels=3)
    # Create the trainer
    trainer = OpenClawTrainer(model)
    # Create sample data
    batch = create_sample_batch()
    # Run one training step
    losses = trainer.train_step(batch)
    print("Training losses:", losses)
    # Span extraction example
    (input_ids, attention_mask,
     query_input_ids, query_attention_mask, _, _, _) = batch
    model.eval()  # disable dropout for inference
    spans = model.extract_spans(
        input_ids, attention_mask,
        query_input_ids, query_attention_mask
    )
    print("Extracted spans:", spans)
    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
Key features
- Dual-encoder architecture: encodes both the document and the natural-language query (see the cross-attention sketch below)
- Query attention: multi-head attention lets the document representation focus on query-relevant information
- Conditional layer normalization: modulates the document representation based on the query
- Multi-task learning: performs span extraction and sequence labeling jointly
- Lightweight design: drops the complex modules of the full OpenClaw while keeping the core functionality
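To make the query-attention step concrete, here is a minimal standalone sketch (illustrative shapes only, independent of the model above) of how nn.MultiheadAttention lets document states attend over query states:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)
doc = torch.randn(2, 10, 768)   # [batch, seq_len, hidden]
qry = torch.randn(2, 5, 768)    # [batch, query_len, hidden]
out, weights = attn(query=doc, key=qry, value=qry)
print(out.shape)      # torch.Size([2, 10, 768]) -- one query-aware vector per document token
print(weights.shape)  # torch.Size([2, 10, 5])   -- attention over query tokens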
Usage
# 1. Initialize the model
model = OpenClawSimplified(num_labels=3)  # 3 = number of BIO labels
# 2. Prepare the data
# Requires tokenized inputs for both the document and the query (see the sketch below)
# 3. Train the model (num_epochs and dataloader defined elsewhere)
trainer = OpenClawTrainer(model)
for epoch in range(num_epochs):
    for batch in dataloader:
        losses = trainer.train_step(batch)
# 4. Extract information (note: attention masks are required)
spans = model.extract_spans(
    input_ids, attention_mask,
    query_input_ids, query_attention_mask,
    threshold=0.5
)
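For step 2, the tokenized inputs can be produced with the tokenizer already attached to the model. A minimal sketch, assuming the bert-base-uncased tokenizer and illustrative text:

document = "Acme Corp announced its quarterly results on May 3."
query = "When were the results announced?"
doc_enc = model.tokenizer(document, return_tensors='pt', padding=True, truncation=True)
qry_enc = model.tokenizer(query, return_tensors='pt', padding=True, truncation=True)
spans = model.extract_spans(
    doc_enc['input_ids'], doc_enc['attention_mask'],
    qry_enc['input_ids'], qry_enc['attention_mask'],
    threshold=0.5
)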
This concise version of OpenClaw keeps the core ideas of the original model while greatly reducing implementation complexity, making it suitable for learning, experimentation, and small-scale deployment.