import torch
import torch.nn as nn
import torch.nn.functional as F
class GRPO:
    def __init__(self, policy, ref_policy, lr=1e-5, beta=0.02, eps_clip=0.2):
        self.policy = policy
        self.ref_policy = ref_policy  # frozen reference model used for the KL penalty
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        self.beta = beta              # KL penalty coefficient
        self.eps_clip = eps_clip      # PPO clipping range
    def compute_loss(self, input_ids, old_logp, rewards, advantages):
        """
        input_ids:  [B, T] token ids of the sampled responses
        old_logp:   [B, T] per-token log-probs under the old (sampling) policy
        rewards:    [B]    reward-model scores, one per response
        advantages: [B]    per-response advantages before group normalization
                           (in GRPO these are typically just the raw rewards)
        """
        new_logp = self.policy.log_prob(input_ids)  # [B, T]
        ratio = torch.exp(new_logp - old_logp)      # [B, T]
        # GRPO: normalize advantages within each group of responses sampled
        # from the same prompt (here group_size = 4)
        B = advantages.size(0)
        group_size = 4
        grouped = advantages.view(B // group_size, group_size)
        grouped = (grouped - grouped.mean(dim=1, keepdim=True)) / (grouped.std(dim=1, keepdim=True) + 1e-8)
        advantages = grouped.view(B, 1)             # [B, 1], broadcasts over tokens
        # PPO clipped surrogate objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        # KL penalty against the frozen reference policy (k3 estimator, always >= 0)
        with torch.no_grad():
            ref_logp = self.ref_policy.log_prob(input_ids)  # [B, T]
        log_ratio = ref_logp - new_logp
        kl = (torch.exp(log_ratio) - log_ratio - 1).mean()
        loss = policy_loss + self.beta * kl
        return loss
    def step(self, input_ids, old_logp, rewards, advantages):
        loss = self.compute_loss(input_ids, old_logp, rewards, advantages)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
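# --- Usage sketch (not part of the original class) ---
# Minimal smoke test showing how GRPO expects to be called. `ToyPolicy` is a
# hypothetical stand-in that only exposes the `log_prob(input_ids) -> [B, T]`
# interface the class above relies on; a real setup would wrap a causal LM
# and compute next-token log-probs from its sampling pass.
if __name__ == "__main__":
    class ToyPolicy(nn.Module):
        def __init__(self, vocab_size=32, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.head = nn.Linear(hidden, vocab_size)

        def log_prob(self, input_ids):
            # Log-prob of each observed token under the model: [B, T]
            logp = F.log_softmax(self.head(self.embed(input_ids)), dim=-1)
            return logp.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

    torch.manual_seed(0)
    policy, ref_policy = ToyPolicy(), ToyPolicy()
    ref_policy.load_state_dict(policy.state_dict())  # reference starts as a copy of the policy

    B, T = 8, 5                                      # B must be a multiple of group_size (4)
    input_ids = torch.randint(0, 32, (B, T))
    with torch.no_grad():
        old_logp = policy.log_prob(input_ids)        # stand-in for log-probs from the sampling pass
    rewards = torch.randn(B)
    advantages = rewards.clone()                     # GRPO: advantages start from the raw rewards

    trainer = GRPO(policy, ref_policy)
    print("loss:", trainer.step(input_ids, old_logp, rewards, advantages))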