class RND_DA(nn.Module):
    def __init__(self, a_dim, rnd_h_dim=524, out_dim=340):
        super(RND_DA, self).__init__()
        # Multiwoz vector
        voc_file = os.path.join(path_convlab, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(path_convlab, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # net
        self.net = nn.Sequential(
            nn.Linear(2 * a_dim, rnd_h_dim), nn.ReLU(),
            nn.Linear(rnd_h_dim, rnd_h_dim), nn.ReLU(),
            nn.Linear(rnd_h_dim, out_dim))
        self.net = self.net.float().to(device=DEVICE)

    def da_tensorize(self, user_da, sys_da):
        # vectorize both dialogue acts
        user_da_vec = self.vector.action_vectorize(user_da)
        sys_da_vec = self.vector.action_vectorize(sys_da)
        user_da_tensor = torch.tensor(user_da_vec).float().to(device=DEVICE)
        sys_da_tensor = torch.tensor(sys_da_vec).float().to(device=DEVICE)
        # concat das
        das_tensor = torch.cat((user_da_tensor, sys_da_tensor), 0)
        return das_tensor

    def forward(self, user_da, sys_da):
        # vectorize das and concatenate
        das_tensor = self.da_tensorize(user_da, sys_da)
        # pass through net
        x = self.net(das_tensor)
        return x
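# Usage sketch (not part of the original module; names and the learning rate are
# illustrative): in the standard RND setup this class is instantiated twice -- a
# frozen, randomly initialised target and a trainable predictor -- and the
# intrinsic reward is the predictor's error on the target's output for the same
# dialogue-act pair.
rnd_target = RND_DA(a_dim)
for p in rnd_target.parameters():
    p.requires_grad = False  # the target network stays fixed
rnd_predictor = RND_DA(a_dim)
rnd_optim = optim.AdamW(rnd_predictor.parameters(), lr=1e-4)

def rnd_intrinsic_reward(user_da, sys_da):
    # the prediction error serves both as the intrinsic reward and as the
    # predictor's training loss
    with torch.no_grad():
        target_out = rnd_target(user_da, sys_da)
    pred_out = rnd_predictor(user_da, sys_da)
    error = torch.mean((pred_out - target_out) ** 2)
    rnd_optim.zero_grad()
    error.backward()
    rnd_optim.step()
    return error.item()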
class IC_DA(nn.Module):
    def __init__(self):
        super(IC_DA, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.inv_h_dim = inv_h_dim
        self.fwd_h_dim = fwd_h_dim
        self.s_enc_out_dim = s_enc_out_dim
        self.icm_beta = icm_beta
        # Multiwoz vector
        voc_file = os.path.join(path_convlab, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(path_convlab, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # state encoder
        self.state_encoder = nn.Sequential(
            nn.Linear(self.a_dim, self.s_enc_out_dim), nn.ReLU(),
            nn.Linear(self.s_enc_out_dim, self.s_enc_out_dim)).to(device=DEVICE)
        # inverse and forward model heads
        self.inv_head = nn.Sequential(
            nn.Linear(2 * self.s_enc_out_dim, self.inv_h_dim), nn.ReLU(),
            nn.Linear(self.inv_h_dim, self.inv_h_dim), nn.ReLU(),
            nn.Linear(self.inv_h_dim, self.a_dim)).to(device=DEVICE)
        self.fwd_head = nn.Sequential(
            nn.Linear(self.a_dim + self.s_enc_out_dim, self.fwd_h_dim), nn.ReLU(),
            nn.Linear(self.fwd_h_dim, self.fwd_h_dim), nn.ReLU(),
            nn.Linear(self.fwd_h_dim, self.s_enc_out_dim)).to(device=DEVICE)

    def da_tensorize(self, da):
        da_vec = self.vector.action_vectorize(da)
        da_tensor = torch.tensor(da_vec).float().to(device=DEVICE)
        return da_tensor

    def forward(self, state_da, action, next_state_da):
        action_vec = self.vector.action_vectorize(action)
        # tensorize inputs
        state_tensor = self.da_tensorize(state_da)
        action_tensor = self.da_tensorize(action)
        next_state_tensor = self.da_tensorize(next_state_da)
        # get encodings
        phi_state = self.state_encoder(state_tensor)
        phi_next_state = self.state_encoder(next_state_tensor)
        # --- Inverse model pass ---
        # concat both encodings
        phi_concat = torch.cat((phi_state, phi_next_state), 0)
        # pass through inverse model
        action_vec_est = torch.sigmoid(self.inv_head(phi_concat))
        # --- Forward model pass ---
        # concat state encoding with action tensor
        phi_s_a_concat = torch.cat((phi_state, action_tensor), 0)
        # pass through forward model
        phi_next_state_est = self.fwd_head(phi_s_a_concat)
        return action_vec, action_vec_est, phi_next_state, phi_next_state_est

    def get_intrinsic_rewards(self, user_das, sys_das, mask, eta=0.01):
        intrinsic_rewards = []
        # iterate over batch and form tuples of (state, action, next_state)
        for i in range(len(user_das) - 1):
            # skip if end of dialogue
            if mask[i].item() == 0.0:
                continue
            # get tuple
            state_da = user_das[i]
            action = sys_das[i + 1]
            next_state_da = user_das[i + 1]
            # get estimates from inverse and forward models
            _, _, phi_next_state, phi_next_state_est = self.forward(
                state_da, action, next_state_da)
            # intrinsic reward: scaled forward-model prediction error
            intrinsic_reward = (eta / 2.0) * torch.mean(
                (phi_next_state_est - phi_next_state)**2)
            # append to list
            intrinsic_rewards.append(intrinsic_reward.item())
        return intrinsic_rewards

    def compute_loss(self, user_das, sys_das, mask):
        loss = torch.tensor(0.).to(device=DEVICE)
        # iterate over batch and form tuples of (state, action, next_state)
        for i in range(len(user_das) - 1):
            # skip if end of dialogue
            if mask[i].item() == 0.0:
                continue
            # get tuple
            state_da = user_das[i]
            action = sys_das[i + 1]
            next_state_da = user_das[i + 1]
            # get estimates from inverse and forward models
            action_vec, action_vec_est, phi_next_state, phi_next_state_est = self.forward(
                state_da, action, next_state_da)
            # inverse model loss
            loss_inv = torch.mean((torch.Tensor(action_vec).to(device=DEVICE) -
                                   action_vec_est)**2)
            # forward model loss
            loss_fwd = torch.mean((phi_next_state_est - phi_next_state)**2)
            # total ICM loss weighted by icm_beta
            loss_step = (1 - self.icm_beta) * loss_inv + self.icm_beta * loss_fwd
            loss += loss_step
        return loss / len(user_das)
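# Training sketch (an assumption about how IC_DA is wired into the RL loop; the
# optimizer and learning rate are illustrative): the ICM is optimised on the same
# sampled dialogues that train the policy, and its intrinsic rewards are then
# added to the extrinsic reward signal.
icm = IC_DA()
icm_optim = optim.AdamW(icm.parameters(), lr=1e-4)

def icm_training_step(user_das, sys_das, mask):
    # update the state encoder and the inverse/forward heads
    icm_optim.zero_grad()
    icm_loss = icm.compute_loss(user_das, sys_das, mask)
    icm_loss.backward()
    icm_optim.step()
    # curiosity bonus for the same transitions
    return icm.get_intrinsic_rewards(user_das, sys_das, mask)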
class ICM_UTT(nn.Module):
    def __init__(self):
        super(ICM_UTT, self).__init__()
        self.user_nlg = user_nlg
        self.sys_nlg = sys_nlg
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.inv_h_dim = inv_h_dim
        self.fwd_h_dim = fwd_h_dim
        self.s_enc_out_dim = s_enc_out_dim
        self.max_len = max_len
        self.classifier_only = classifier_only
        self.icm_beta = icm_beta
        # Multiwoz vector
        voc_file = os.path.join(path_convlab, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(path_convlab, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # --- state encoder ---
        self.state_encoder = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=self.s_enc_out_dim).to(device=DEVICE)
        # inverse and forward model heads
        self.inv_head = nn.Sequential(
            nn.Linear(2 * self.s_enc_out_dim, self.inv_h_dim), nn.ReLU(),
            nn.Linear(self.inv_h_dim, self.inv_h_dim), nn.ReLU(),
            nn.Linear(self.inv_h_dim, self.a_dim)).to(device=DEVICE)
        self.fwd_head = nn.Sequential(
            nn.Linear(self.max_len + self.s_enc_out_dim, self.fwd_h_dim), nn.ReLU(),
            nn.Linear(self.fwd_h_dim, self.fwd_h_dim), nn.ReLU(),
            nn.Linear(self.fwd_h_dim, self.s_enc_out_dim)).to(device=DEVICE)
        # freeze the state encoder except for its classification head
        if classifier_only:
            for name, param in self.state_encoder.named_parameters():
                if 'classifier' not in name:  # keep the classifier layer trainable
                    param.requires_grad = False

    def state_utt_tensorize(self, state_utt):
        state_tokens = self.tokenizer.tokenize(state_utt)
        # truncate or pad to max_len, covering the corner cases
        if len(state_tokens) >= self.max_len:
            tokens = ["[CLS]"] + state_tokens[:(self.max_len - 2)] + ["[SEP]"] + \
                ["[PAD]"] * (self.max_len - len(state_tokens) - 2)
        elif len(state_tokens) == self.max_len - 1:
            tokens = ["[CLS]"] + state_tokens[:(self.max_len - 2)] + ["[SEP]"]
        else:
            tokens = ["[CLS]"] + state_tokens + ["[SEP]"] + \
                ["[PAD]"] * (self.max_len - len(state_tokens) - 2)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        # build the mask for the same corner cases
        if len(state_tokens) >= self.max_len:
            mask_ids = [0] * (len(state_tokens[:(self.max_len - 2)]) + 2) + \
                [1] * (self.max_len - len(state_tokens) - 2)
        elif len(state_tokens) == self.max_len - 1:
            mask_ids = [0] * (len(state_tokens)) + [1]
        else:
            mask_ids = [0] * (len(state_tokens) + 2) + \
                [1] * (self.max_len - len(state_tokens) - 2)
        state_tokens_tensor = torch.tensor([indexed_tokens]).to(device=DEVICE)
        state_mask_tensor = torch.tensor([mask_ids]).to(device=DEVICE)
        return state_tokens_tensor, state_mask_tensor

    def action_utt_tensorize(self, action_utt):
        action_tokens = self.tokenizer.tokenize(action_utt)
        # truncate or pad to max_len, covering the corner cases
        if len(action_tokens) >= self.max_len:
            tokens = ["[CLS]"] + action_tokens[:(self.max_len - 2)] + ["[SEP]"] + \
                ["[PAD]"] * (self.max_len - len(action_tokens) - 2)
        elif len(action_tokens) == self.max_len - 1:
            tokens = ["[CLS]"] + action_tokens[:(self.max_len - 2)] + ["[SEP]"]
        else:
            tokens = ["[CLS]"] + action_tokens + ["[SEP]"] + \
                ["[PAD]"] * (self.max_len - len(action_tokens) - 2)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        # build the mask for the same corner cases
        if len(action_tokens) >= self.max_len:
            mask_ids = [0] * (len(action_tokens[:(self.max_len - 2)]) + 2) + \
                [1] * (self.max_len - len(action_tokens) - 2)
        elif len(action_tokens) == self.max_len - 1:
            mask_ids = [0] * (len(action_tokens)) + [1]
        else:
            mask_ids = [0] * (len(action_tokens) + 2) + \
                [1] * (self.max_len - len(action_tokens) - 2)
        action_tokens_tensor = torch.tensor([indexed_tokens]).to(device=DEVICE)
        action_mask_tensor = torch.tensor([mask_ids]).to(device=DEVICE)
        return action_tokens_tensor, action_mask_tensor

    def forward(self, state_utt, action_da, next_state_utt):
        # realize the action dialogue act as an utterance
        action_utt = self.user_nlg.generate(action_da)
        # vectorize action
        action_vec = self.vector.action_vectorize(action_da)
        # tensorize
        state_tokens_tensor, state_mask_tensor = self.state_utt_tensorize(state_utt)
        action_tokens_tensor, action_mask_tensor = self.action_utt_tensorize(action_utt)
        next_state_tokens_tensor, next_state_mask_tensor = self.state_utt_tensorize(next_state_utt)
        # get encodings
        phi_state = self.state_encoder(state_tokens_tensor,
                                       state_mask_tensor)[0].squeeze(0)
        phi_next_state = self.state_encoder(next_state_tokens_tensor,
                                            next_state_mask_tensor)[0].squeeze(0)
        # --- Inverse model pass ---
        # concat both encodings
        phi_concat = torch.cat((phi_state, phi_next_state), 0)
        # pass through inverse model
        action_vec_est = torch.sigmoid(self.inv_head(phi_concat))
        # --- Forward model pass ---
        # concat state encoding with action token tensor
        phi_s_a_concat = torch.cat((phi_state, action_tokens_tensor.squeeze(0)), 0)
        # pass through forward model
        phi_next_state_est = self.fwd_head(phi_s_a_concat)
        return action_vec, action_vec_est, phi_next_state, phi_next_state_est

    def get_intrinsic_rewards(self, user_das, sys_das, mask, eta=0.01):
        intrinsic_rewards = []
        # iterate over batch and form tuples of (state, action, next_state)
        for i in range(len(user_das) - 1):
            # zero reward and skip if end of dialogue
            if mask[i] == 0.0:
                intrinsic_rewards.append(0.0)
                continue
            # get tuple
            state_utt = self.user_nlg.generate(user_das[i])
            action_da = sys_das[i + 1]
            next_state_utt = self.user_nlg.generate(user_das[i + 1])
            # get estimates from inverse and forward models
            _, _, phi_next_state, phi_next_state_est = self.forward(
                state_utt, action_da, next_state_utt)
            # intrinsic reward: scaled forward-model prediction error
            intrinsic_reward = (eta / 2.0) * torch.mean(
                (phi_next_state_est - phi_next_state)**2)
            # append to list
            intrinsic_rewards.append(intrinsic_reward.item())
        return intrinsic_rewards

    def compute_loss(self, user_das, sys_das, mask):
        loss = torch.tensor(0.).to(device=DEVICE)
        # iterate over batch and form tuples of (state, action, next_state)
        for i in range(len(user_das) - 1):
            # skip if end of dialogue
            if mask[i] == 0.0:
                continue
            # get tuple
            state_utt = self.user_nlg.generate(user_das[i])
            action_da = sys_das[i + 1]
            next_state_utt = self.user_nlg.generate(user_das[i + 1])
            # get estimates from inverse and forward models
            action_vec, action_vec_est, phi_next_state, phi_next_state_est = self.forward(
                state_utt, action_da, next_state_utt)
            # inverse model loss
            loss_inv = torch.mean((torch.Tensor(action_vec).to(device=DEVICE) -
                                   action_vec_est)**2)
            # forward model loss
            loss_fwd = torch.mean((phi_next_state_est - phi_next_state)**2)
            # total ICM loss weighted by icm_beta
            loss_step = (1 - self.icm_beta) * loss_inv + self.icm_beta * loss_fwd
            loss += loss_step
        return loss / len(user_das)
class PPO(Policy):
    def __init__(self):
        self.update_round = update_round
        self.optim_batchsz = optim_batchsz
        self.gamma = gamma
        self.epsilon = surrogate_clip
        self.tau = tau
        self.policy_lr = policy_lr
        self.value_lr = value_lr
        # featurizer
        voc_file = os.path.join(path_convlab, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(path_convlab, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # construct policy and value networks
        self.policy = MultiDiscretePolicy(self.vector.state_dim, h_dim,
                                          self.vector.da_dim).to(device=DEVICE)
        self.value = Value(self.vector.state_dim, hv_dim).to(device=DEVICE)
        # optimizers
        self.policy_optim = optim.AdamW(self.policy.parameters(), lr=self.policy_lr)
        self.value_optim = optim.AdamW(self.value.parameters(), lr=self.value_lr)
        # load pre-trained policy net
        load_mle(self.policy)

    def predict(self, state):
        """Predict a system action given the dialogue state.

        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action: System act, in the form of
                (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu()
        action = self.vector.action_devectorize(a.numpy())
        state['system_action'] = action
        return action

    def init_session(self):
        """Restore after one session."""
        pass

    def est_adv(self, r, v, mask):
        """Estimate the advantage and value target for a saved trajectory.

        A trajectory is saved in continuous space; the current trajectory ends
        whenever mask = 0.

        :param r: reward, Tensor, [b]
        :param v: estimated value, Tensor, [b]
        :param mask: 0 indicates the end of a trajectory, otherwise 1, Tensor, [b]
        :return: A(s, a), V-target(s), both Tensor
        """
        batchsz = v.size(0)
        # v_target is worked out by the Bellman equation.
        v_target = torch.Tensor(batchsz).to(device=DEVICE)
        delta = torch.Tensor(batchsz).to(device=DEVICE)
        A_sa = torch.Tensor(batchsz).to(device=DEVICE)

        prev_v_target = 0
        prev_v = 0
        prev_A_sa = 0
        for t in reversed(range(batchsz)):
            # mask = 0 marks the end of a trajectory, so the immediate reward is
            # the real V(s); this value is the target for the value network.
            # formula: V(s_t) = r_t + gamma * V(s_t+1)
            v_target[t] = r[t] + self.gamma * prev_v_target * mask[t]
            # generalized advantage estimation, see https://arxiv.org/abs/1506.02438
            # formula: delta(s_t) = r_t + gamma * V(s_t+1) - V(s_t)
            delta[t] = r[t] + self.gamma * prev_v * mask[t] - v[t]
            # formula: A(s_t, a_t) = delta(s_t) + gamma * lambda * A(s_t+1, a_t+1)
            # the symbol tau is used here for what the original paper calls lambda
            A_sa[t] = delta[t] + self.gamma * self.tau * prev_A_sa * mask[t]
            # update previous
            prev_v_target = v_target[t]
            prev_v = v[t]
            prev_A_sa = A_sa[t]

        # normalize A_sa
        A_sa = (A_sa - A_sa.mean()) / A_sa.std()
        return A_sa, v_target

    def update(self, epoch, batchsz, s, a, r, mask):
        # get estimated V(s) and PI_old(s, a)
        # PI_old(s, a) could also be saved while interacting with the env,
        # saving one forward pass here
        # v: [b, 1] => [b]
        v = self.value(s).squeeze(-1).detach()
        log_pi_old_sa = self.policy.get_log_prob(s, a).detach()

        # estimate advantage and v_target according to GAE and the Bellman equation
        A_sa, v_target = self.est_adv(r, v, mask)

        for i in range(self.update_round):
            # 1. shuffle the current batch
            perm = torch.randperm(batchsz)
            # shuffle the variables for multiple optimization rounds
            v_target_shuf, A_sa_shuf, s_shuf, a_shuf, log_pi_old_sa_shuf = \
                v_target[perm], A_sa[perm], s[perm], a[perm], log_pi_old_sa[perm]

            # 2. split into mini-batches for optimization
            optim_chunk_num = int(np.ceil(batchsz / self.optim_batchsz))
            # chunk the total batch into optimization batches
            v_target_shuf, A_sa_shuf, s_shuf, a_shuf, log_pi_old_sa_shuf = \
                torch.chunk(v_target_shuf, optim_chunk_num), \
                torch.chunk(A_sa_shuf, optim_chunk_num), \
                torch.chunk(s_shuf, optim_chunk_num), \
                torch.chunk(a_shuf, optim_chunk_num), \
                torch.chunk(log_pi_old_sa_shuf, optim_chunk_num)

            # 3. iterate over all mini-batches to optimize
            policy_loss, value_loss = 0., 0.
            for v_target_b, A_sa_b, s_b, a_b, log_pi_old_sa_b in zip(
                    v_target_shuf, A_sa_shuf, s_shuf, a_shuf, log_pi_old_sa_shuf):
                # 1. update the value network
                self.value_optim.zero_grad()
                v_b = self.value(s_b).squeeze(-1)
                loss = (v_b - v_target_b).pow(2).mean()
                value_loss += loss.item()
                # backprop
                loss.backward()
                # nn.utils.clip_grad_norm(self.value.parameters(), 4)
                self.value_optim.step()

                # 2. update the policy network with the clipped surrogate objective
                self.policy_optim.zero_grad()
                # [b, 1]
                log_pi_sa = self.policy.get_log_prob(s_b, a_b)
                # ratio = exp(log_Pi(a|s) - log_Pi_old(a|s)) = Pi(a|s) / Pi_old(a|s)
                # log probabilities are used for numerical stability
                # [b, 1] => [b]
                ratio = (log_pi_sa - log_pi_old_sa_b).exp().squeeze(-1)
                # the joint action prob is the product of the probs of each da,
                # so it may become extremely small and the ratio may become inf,
                # which causes the gradient to be nan; clamp the ratio to avoid this
                ratio = torch.clamp(ratio, 0, 10)
                surrogate1 = ratio * A_sa_b
                surrogate2 = torch.clamp(ratio, 1 - self.epsilon,
                                         1 + self.epsilon) * A_sa_b
                # element-wise comparison; the negative sign converts
                # gradient ascent into gradient descent
                surrogate = -torch.min(surrogate1, surrogate2).mean()
                policy_loss += surrogate.item()
                # backprop
                surrogate.backward()
                # although the ratio is clamped, the grad may still contain nan due to 0 * inf;
                # set any nan in the gradient to 0
                for p in self.policy.parameters():
                    p.grad[p.grad != p.grad] = 0.0
                # gradient clipping, for stability
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), grad_clip)
                # optim step
                self.policy_optim.step()

            value_loss /= optim_chunk_num
            policy_loss /= optim_chunk_num
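# End-to-end sketch (an assumption, not part of the original module:
# `sample_batch` is a hypothetical helper, and the curiosity bonus is assumed to
# be padded/aligned one-to-one with the batch rewards): intrinsic rewards from a
# curiosity module are added to the environment reward before the PPO update.
def train_epoch(ppo, curiosity, epoch, batchsz):
    s, a, r, mask, user_das, sys_das = sample_batch(batchsz)  # hypothetical sampler
    r_int = curiosity.get_intrinsic_rewards(user_das, sys_das, mask)
    r_int = torch.Tensor(r_int).to(device=DEVICE)
    # augment the extrinsic reward with the curiosity bonus
    r_total = r + r_int
    ppo.update(epoch, batchsz, s, a, r_total, mask)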