class GuideLanguage(nn.Module):

    def __init__(self, inp_emb_sz, hidden_sz, num_tokens, apply_masc=True, T=1):
        super(GuideLanguage, self).__init__()
        self.hidden_sz = hidden_sz
        self.inp_emb_sz = inp_emb_sz
        self.num_tokens = num_tokens
        self.apply_masc = apply_masc
        self.T = T

        self.embed_fn = nn.Embedding(num_tokens, inp_emb_sz, padding_idx=0)
        self.encoder_fn = nn.LSTM(inp_emb_sz, hidden_sz // 2, batch_first=True,
                                  bidirectional=True)
        self.cbow_fn = CBoW(11, hidden_sz)

        self.T_prediction_fn = nn.Linear(hidden_sz, T + 1)

        self.feat_control_emb = nn.Parameter(
            torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
        self.feat_control_step_fn = ControlStep(hidden_sz)

        if apply_masc:
            self.act_control_emb = nn.Parameter(
                torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
            self.act_control_step_fn = ControlStep(hidden_sz)
            self.action_linear_fn = nn.Linear(hidden_sz, 9)

        self.landmark_write_gate = nn.ParameterList()
        self.obs_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz, 1, 1).normal_(0.0, 0.1)))
            self.obs_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(self.hidden_sz)
        else:
            self.masc_fn = NoMASC(self.hidden_sz)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, batch, add_rl_loss=False):
        batch_size = batch['utterance'].size(0)

        # Encode the tourist utterance with a bidirectional LSTM.
        input_emb = self.embed_fn(batch['utterance'])
        hidden_states, _ = self.encoder_fn(input_emb)

        # Take the hidden state at the last non-padded position of each sequence.
        last_state_indices = batch['utterance_mask'].sum(1).long() - 1
        last_hidden_states = hidden_states[torch.arange(batch_size).long(),
                                           last_state_indices, :]

        # Predict, and sample, how many MASC steps to apply.
        T_dist = F.softmax(self.T_prediction_fn(last_hidden_states), dim=-1)
        sampled_Ts = T_dist.multinomial(1).squeeze(-1)

        # Extract T+1 observation messages from the utterance encoding.
        obs_msgs = list()
        feat_controller = self.feat_control_emb.unsqueeze(0).repeat(batch_size, 1)
        for step in range(self.T + 1):
            extracted_msg, feat_controller = self.feat_control_step_fn(
                hidden_states, batch['utterance_mask'], feat_controller)
            obs_msgs.append(extracted_msg)

        # Gate the observation messages, keeping only the first sampled_Ts + 1.
        tourist_obs_msg = []
        for i, (gate, emb) in enumerate(zip(self.obs_write_gate, obs_msgs)):
            include = (i <= sampled_Ts).float().unsqueeze(-1)
            tourist_obs_msg.append(include * torch.sigmoid(gate) * emb)
        tourist_obs_msg = sum(tourist_obs_msg)

        # Embed the landmark map and (optionally) apply MASC steps driven by
        # action messages extracted from the utterance.
        landmark_emb = self.cbow_fn(batch['landmarks']).permute(0, 3, 1, 2)
        landmark_embs = [landmark_emb]

        if self.apply_masc:
            act_controller = self.act_control_emb.unsqueeze(0).repeat(batch_size, 1)
            for step in range(self.T):
                extracted_msg, act_controller = self.act_control_step_fn(
                    hidden_states, batch['utterance_mask'], act_controller)
                action_out = self.action_linear_fn(extracted_msg)
                out = self.masc_fn.forward(landmark_embs[-1], action_out,
                                           current_step=step, Ts=sampled_Ts)
                landmark_embs.append(out)
        else:
            for step in range(self.T):
                landmark_embs.append(self.masc_fn.forward(landmark_embs[-1]))

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, landmark_embs)
        ])
        landmarks = landmarks.contiguous().view(batch_size, landmarks.size(1),
                                                16).transpose(1, 2)

        # Score each of the 16 map cells against the tourist observation message.
        out = dict()
        logits = torch.bmm(landmarks, tourist_obs_msg.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)

        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])
        out['sl_loss'] = -torch.log(
            torch.gather(out['prob'], 1, y_true.unsqueeze(-1)) + 1e-8)

        # Add REINFORCE loss for the sampled number of MASC steps.
        if add_rl_loss:
            advantage = -(out['sl_loss'] - out['sl_loss'].mean()).detach()
            log_prob = torch.log(
                torch.gather(T_dist, 1, sampled_Ts.unsqueeze(-1)) + 1e-8)
            out['rl_loss'] = log_prob * advantage

        out['acc'] = sum([
            1.0 for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                        y_true.data.cpu().numpy())
            if pred == target
        ]) / batch_size
        return out

    def save(self, path):
        state = dict()
        state['hidden_sz'] = self.hidden_sz
        state['embed_sz'] = self.inp_emb_sz
        state['num_tokens'] = self.num_tokens
        state['apply_masc'] = self.apply_masc
        state['T'] = self.T
        state['parameters'] = self.state_dict()
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['embed_sz'], state['hidden_sz'], state['num_tokens'],
                    T=state['T'], apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
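# Minimal usage sketch (not part of the original source). It builds a dummy batch
# with the keys GuideLanguage.forward reads ('utterance', 'utterance_mask',
# 'landmarks', 'target') and shapes inferred from the code above: a padded token
# sequence, a 4x4 landmark grid, and an (x, y) target cell. The concrete sizes and
# the assumption that CBoW embeds and sums the trailing index dimension of
# 'landmarks' are illustrative, not taken from the repository.
def _guide_language_example():
    batch_size, seq_len, num_tokens = 2, 10, 50
    guide = GuideLanguage(inp_emb_sz=32, hidden_sz=64, num_tokens=num_tokens,
                          apply_masc=True, T=1)
    batch = {
        'utterance': torch.randint(1, num_tokens, (batch_size, seq_len)),
        'utterance_mask': torch.ones(batch_size, seq_len),
        # assumed layout: up to 3 landmark indices (< 11) per grid cell
        'landmarks': torch.randint(0, 11, (batch_size, 4, 4, 3)),
        'target': torch.randint(0, 4, (batch_size, 2)),
    }
    out = guide.forward(batch, add_rl_loss=True)
    return out['prob'].shape, float(out['sl_loss'].mean()), out['acc']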
class GuideContinuous(nn.Module):

    def __init__(self, in_vocab_sz, num_landmarks, T=2, apply_masc=True):
        super(GuideContinuous, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.apply_masc = apply_masc
        self.T = T

        self.landmark_write_gate = nn.ParameterList()
        for _ in range(self.T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))
        self.cbow_fn = CBoW(num_landmarks, in_vocab_sz, init_std=0.01)

        if self.apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.extract_fns = nn.ModuleList()
            for _ in range(T):
                self.extract_fns.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, msg, batch):
        obs_msg, act_msg = msg['obs'], msg['act']

        # Embed the landmark map and (optionally) apply MASC steps driven by the
        # tourist's continuous action message.
        l_emb = self.cbow_fn.forward(batch['landmarks']).permute(0, 3, 1, 2)
        l_embs = [l_emb]
        if self.apply_masc:
            for j in range(self.T):
                act_mask = self.extract_fns[j](act_msg)
                out = self.masc_fn.forward(l_embs[-1], act_mask)
                l_embs.append(out)
        else:
            for j in range(self.T):
                out = self.masc_fn.forward(l_emb)
                l_embs.append(out)

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, l_embs)
        ])
        landmarks = landmarks.contiguous().view(l_emb.size(0), landmarks.size(1),
                                                16).transpose(1, 2)

        # Score each of the 16 map cells against the observation message.
        out = dict()
        logits = torch.bmm(landmarks, obs_msg.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)

        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])
        out['loss'] = self.loss(logits, y_true)
        out['acc'] = sum([
            1.0 for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                        y_true.data.cpu().numpy())
            if pred == target
        ]) / y_true.size(0)
        return out

    def save(self, path):
        state = dict()
        state['in_vocab_sz'] = self.in_vocab_sz
        state['num_landmarks'] = self.num_landmarks
        state['parameters'] = self.state_dict()
        state['T'] = self.T
        state['apply_masc'] = self.apply_masc
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['in_vocab_sz'], state['num_landmarks'],
                    T=state['T'], apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
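# Minimal usage sketch (not part of the original source). GuideContinuous.forward
# expects a continuous tourist message dict with 'obs' and 'act' vectors of size
# in_vocab_sz, plus a batch carrying 'landmarks' and 'target'. The random message
# vectors below stand in for the continuous tourist's output; the concrete sizes
# and the per-cell landmark layout are illustrative assumptions.
def _guide_continuous_example():
    batch_size, in_vocab_sz, num_landmarks = 2, 64, 11
    guide = GuideContinuous(in_vocab_sz, num_landmarks, T=2, apply_masc=True)
    msg = {'obs': torch.randn(batch_size, in_vocab_sz),
           'act': torch.randn(batch_size, in_vocab_sz)}
    batch = {
        # assumed layout: up to 3 landmark indices (< num_landmarks) per grid cell
        'landmarks': torch.randint(0, num_landmarks, (batch_size, 4, 4, 3)),
        'target': torch.randint(0, 4, (batch_size, 2)),
    }
    out = guide.forward(msg, batch)
    return out['prob'].shape, float(out['loss'].mean()), out['acc']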
class GuideDiscrete(nn.Module):

    def __init__(self, in_vocab_sz, num_landmarks, apply_masc=True, T=2):
        super(GuideDiscrete, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.T = T
        self.apply_masc = apply_masc

        self.emb_map = CBoW(num_landmarks, in_vocab_sz, init_std=0.1)
        self.obs_emb_fn = nn.Linear(in_vocab_sz, in_vocab_sz)
        self.landmark_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.action_emb = nn.ModuleList()
            for i in range(T):
                self.action_emb.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, message, batch):
        msg_obs = self.obs_emb_fn(message[0])
        batch_size = message[0].size(0)

        # Embed the landmark map and (optionally) apply MASC steps driven by the
        # tourist's action message.
        landmark_emb = self.emb_map.forward(batch['landmarks']).permute(0, 3, 1, 2)
        landmark_embs = [landmark_emb]
        if self.apply_masc:
            for j in range(self.T):
                act_msg = message[1]
                action_out = self.action_emb[j](act_msg)
                out = self.masc_fn.forward(landmark_embs[-1], action_out,
                                           current_step=j)
                landmark_embs.append(out)
        else:
            for j in range(self.T):
                out = self.masc_fn.forward(landmark_embs[-1])
                landmark_embs.append(out)

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, landmark_embs)
        ])
        landmarks = landmarks.contiguous().view(batch_size, landmarks.size(1),
                                                16).transpose(1, 2)

        # Score each of the 16 map cells against the observation message.
        out = dict()
        logits = torch.bmm(landmarks, msg_obs.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)

        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])
        out['loss'] = self.loss(logits, y_true)
        out['acc'] = sum([
            1.0 for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                        y_true.data.cpu().numpy())
            if pred == target
        ]) / y_true.size(0)
        return out

    def save(self, path):
        state = dict()
        state['in_vocab_sz'] = self.in_vocab_sz
        state['num_landmarks'] = self.num_landmarks
        state['parameters'] = self.state_dict()
        state['T'] = self.T
        state['apply_masc'] = self.apply_masc
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['in_vocab_sz'], state['num_landmarks'],
                    T=state['T'], apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
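# Minimal usage sketch (not part of the original source). GuideDiscrete.forward
# takes the tourist message as an indexable pair: message[0] is the observation
# message and message[1] the action message, both of size in_vocab_sz. Random
# tensors stand in for the discrete tourist's emitted vectors here, purely to
# exercise the shapes; the concrete sizes and the per-cell landmark layout are
# illustrative assumptions.
def _guide_discrete_example():
    batch_size, in_vocab_sz, num_landmarks = 2, 64, 11
    guide = GuideDiscrete(in_vocab_sz, num_landmarks, apply_masc=True, T=2)
    message = (torch.randn(batch_size, in_vocab_sz),
               torch.randn(batch_size, in_vocab_sz))
    batch = {
        # assumed layout: up to 3 landmark indices (< num_landmarks) per grid cell
        'landmarks': torch.randint(0, num_landmarks, (batch_size, 4, 4, 3)),
        'target': torch.randint(0, 4, (batch_size, 2)),
    }
    out = guide.forward(message, batch)
    return out['prob'].shape, float(out['loss'].mean()), out['acc']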