Example #1
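The examples below omit their imports. A minimal preamble, assuming standard PyTorch plus the CBoW, MASC, NoMASC, and ControlStep helpers defined elsewhere in this codebase (their import path is an assumption), would look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper modules defined elsewhere in the codebase; the module path is assumed here.
from modules import CBoW, MASC, NoMASC, ControlStep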
class GuideLanguage(nn.Module):
    def __init__(self,
                 inp_emb_sz,
                 hidden_sz,
                 num_tokens,
                 apply_masc=True,
                 T=1):
        super(GuideLanguage, self).__init__()
        self.hidden_sz = hidden_sz
        self.inp_emb_sz = inp_emb_sz
        self.num_tokens = num_tokens
        self.apply_masc = apply_masc
        self.T = T

        self.embed_fn = nn.Embedding(num_tokens, inp_emb_sz, padding_idx=0)
        self.encoder_fn = nn.LSTM(inp_emb_sz,
                                  hidden_sz // 2,
                                  batch_first=True,
                                  bidirectional=True)
        self.cbow_fn = CBoW(11, hidden_sz)

        self.T_prediction_fn = nn.Linear(hidden_sz, T + 1)

        self.feat_control_emb = nn.Parameter(
            torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
        self.feat_control_step_fn = ControlStep(hidden_sz)

        if apply_masc:
            self.act_control_emb = nn.Parameter(
                torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
            self.act_control_step_fn = ControlStep(hidden_sz)
            self.action_linear_fn = nn.Linear(hidden_sz, 9)

        self.landmark_write_gate = nn.ParameterList()
        self.obs_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz, 1, 1).normal_(0, 0.1)))
            self.obs_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(self.hidden_sz)
        else:
            self.masc_fn = NoMASC(self.hidden_sz)

        self.loss = nn.CrossEntropyLoss()
Example #2
class GuideContinuous(nn.Module):
    def __init__(self, in_vocab_sz, num_landmarks, T=2, apply_masc=True):
        super(GuideContinuous, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.apply_masc = apply_masc
        self.T = T

        self.landmark_write_gate = nn.ParameterList()
        for _ in range(self.T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))

        self.cbow_fn = CBoW(num_landmarks, in_vocab_sz, init_std=0.01)

        if self.apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.extract_fns = nn.ModuleList()
            for _ in range(T):
                self.extract_fns.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')
Example #3
class GuideDiscrete(nn.Module):
    def __init__(self, in_vocab_sz, num_landmarks, apply_masc=True, T=2):
        super(GuideDiscrete, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.T = T
        self.apply_masc = apply_masc
        self.emb_map = CBoW(num_landmarks, in_vocab_sz, init_std=0.1)
        self.obs_emb_fn = nn.Linear(in_vocab_sz, in_vocab_sz)
        self.landmark_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.action_emb = nn.ModuleList()
            for i in range(T):
                self.action_emb.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')
Example #4
class GuideLanguage(nn.Module):
    def __init__(self,
                 inp_emb_sz,
                 hidden_sz,
                 num_tokens,
                 apply_masc=True,
                 T=1):
        super(GuideLanguage, self).__init__()
        self.hidden_sz = hidden_sz
        self.inp_emb_sz = inp_emb_sz
        self.num_tokens = num_tokens
        self.apply_masc = apply_masc
        self.T = T

        self.embed_fn = nn.Embedding(num_tokens, inp_emb_sz, padding_idx=0)
        self.encoder_fn = nn.LSTM(inp_emb_sz,
                                  hidden_sz // 2,
                                  batch_first=True,
                                  bidirectional=True)
        self.cbow_fn = CBoW(11, hidden_sz)

        self.T_prediction_fn = nn.Linear(hidden_sz, T + 1)

        self.feat_control_emb = nn.Parameter(
            torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
        self.feat_control_step_fn = ControlStep(hidden_sz)

        if apply_masc:
            self.act_control_emb = nn.Parameter(
                torch.FloatTensor(hidden_sz).normal_(0.0, 0.1))
            self.act_control_step_fn = ControlStep(hidden_sz)
            self.action_linear_fn = nn.Linear(hidden_sz, 9)

        self.landmark_write_gate = nn.ParameterList()
        self.obs_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz, 1, 1).normal_(0, 0.1)))
            self.obs_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, hidden_sz).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(self.hidden_sz)
        else:
            self.masc_fn = NoMASC(self.hidden_sz)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, batch, add_rl_loss=False):
        batch_size = batch['utterance'].size(0)
        input_emb = self.embed_fn(batch['utterance'])
        hidden_states, _ = self.encoder_fn(input_emb)

        last_state_indices = batch['utterance_mask'].sum(1).long() - 1

        last_hidden_states = hidden_states[torch.arange(batch_size).long(),
                                           last_state_indices, :]
        T_dist = F.softmax(self.T_prediction_fn(last_hidden_states), dim=-1)
        sampled_Ts = T_dist.multinomial(1).squeeze(-1)

        obs_msgs = list()
        feat_controller = self.feat_control_emb.unsqueeze(0).repeat(
            batch_size, 1)
        for step in range(self.T + 1):
            extracted_msg, feat_controller = self.feat_control_step_fn(
                hidden_states, batch['utterance_mask'], feat_controller)
            obs_msgs.append(extracted_msg)

        tourist_obs_msg = []
        for i, (gate, emb) in enumerate(zip(self.obs_write_gate, obs_msgs)):
            include = (i <= sampled_Ts).float().unsqueeze(-1)
            tourist_obs_msg.append(include * torch.sigmoid(gate) * emb)
        tourist_obs_msg = sum(tourist_obs_msg)

        landmark_emb = self.cbow_fn(batch['landmarks']).permute(0, 3, 1, 2)
        landmark_embs = [landmark_emb]

        if self.apply_masc:
            act_controller = self.act_control_emb.unsqueeze(0).repeat(
                batch_size, 1)
            for step in range(self.T):
                extracted_msg, act_controller = self.act_control_step_fn(
                    hidden_states, batch['utterance_mask'], act_controller)
                action_out = self.action_linear_fn(extracted_msg)
                out = self.masc_fn.forward(landmark_embs[-1],
                                           action_out,
                                           current_step=step,
                                           Ts=sampled_Ts)
                landmark_embs.append(out)
        else:
            for step in range(self.T):
                landmark_embs.append(self.masc_fn.forward(landmark_embs[-1]))

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, landmark_embs)
        ])

        landmarks = landmarks.view(batch_size, landmarks.size(1),
                                   16).transpose(1, 2)

        out = dict()
        logits = torch.bmm(landmarks,
                           tourist_obs_msg.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)
        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])

        out['sl_loss'] = -torch.log(
            torch.gather(out['prob'], 1, y_true.unsqueeze(-1)) + 1e-8)

        # add RL loss
        if add_rl_loss:
            advantage = -(out['sl_loss'] - out['sl_loss'].mean()).detach()

            log_prob = torch.log(
                torch.gather(T_dist, 1, sampled_Ts.unsqueeze(-1)) + 1e-8)
            out['rl_loss'] = log_prob * advantage

        out['acc'] = sum([
            1.0
            for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                    y_true.data.cpu().numpy())
            if pred == target
        ]) / batch_size
        return out

    def save(self, path):
        state = dict()
        state['hidden_sz'] = self.hidden_sz
        state['embed_sz'] = self.inp_emb_sz
        state['num_tokens'] = self.num_tokens
        state['apply_masc'] = self.apply_masc
        state['T'] = self.T
        state['parameters'] = self.state_dict()
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['embed_sz'],
                    state['hidden_sz'],
                    state['num_tokens'],
                    T=state['T'],
                    apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
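A minimal usage sketch for the class above; the hyperparameter values and checkpoint path are illustrative, not taken from the original code:

guide = GuideLanguage(inp_emb_sz=32, hidden_sz=256, num_tokens=500,
                      apply_masc=True, T=2)
guide.save('/tmp/guide_language.pt')          # serializes hyperparameters and weights
restored = GuideLanguage.load('/tmp/guide_language.pt')
assert restored.hidden_sz == guide.hidden_sz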
Example #5
class GuideContinuous(nn.Module):

    def __init__(self, in_vocab_sz, num_landmarks, T=2, apply_masc=True):
        super(GuideContinuous, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.apply_masc = apply_masc
        self.T = T

        self.landmark_write_gate = nn.ParameterList()
        for _ in range(self.T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))

        self.cbow_fn = CBoW(num_landmarks, in_vocab_sz, init_std=0.01)

        if self.apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.extract_fns = nn.ModuleList()
            for _ in range(T):
                self.extract_fns.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, msg, batch):
        obs_msg, act_msg = msg['obs'], msg['act']

        l_emb = self.cbow_fn.forward(batch['landmarks']).permute(0, 3, 1, 2)
        l_embs = [l_emb]

        if self.apply_masc:
            for j in range(self.T):
                act_mask = self.extract_fns[j](act_msg)
                out = self.masc_fn.forward(l_embs[-1], act_mask)
                l_embs.append(out)
        else:
            for j in range(self.T):
                # Apply NoMASC to the latest landmark map, matching the other guide classes.
                out = self.masc_fn.forward(l_embs[-1])
                l_embs.append(out)

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, l_embs)
        ])
        landmarks = landmarks.view(l_emb.size(0), landmarks.size(1),
                                   16).transpose(1, 2)

        out = dict()
        logits = torch.bmm(landmarks, obs_msg.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)

        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])

        out['loss'] = self.loss(logits, y_true)
        out['acc'] = sum([
            1.0
            for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                    y_true.data.cpu().numpy())
            if pred == target
        ]) / y_true.size(0)
        return out

    def save(self, path):
        state = dict()
        state['in_vocab_sz'] = self.in_vocab_sz
        state['num_landmarks'] = self.num_landmarks
        state['parameters'] = self.state_dict()
        state['T'] = self.T
        state['apply_masc'] = self.apply_masc
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['in_vocab_sz'], state['num_landmarks'], T=state['T'],
                    apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
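All three guide models score 16 candidate locations, so the supervision target is a flattened cell on a 4x4 grid. A small sketch of the index encoding used for y_true above (the coordinates are illustrative):

row, col = 2, 3                        # batch['target'][:, 0] and batch['target'][:, 1]
cell_index = row * 4 + col             # 11 -> class index for the cross-entropy loss
assert 0 <= cell_index < 16            # matches the 16 cells in the reshaped landmark map
row_back, col_back = divmod(cell_index, 4)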
Example #6
class GuideDiscrete(nn.Module):
    def __init__(self, in_vocab_sz, num_landmarks, apply_masc=True, T=2):
        super(GuideDiscrete, self).__init__()
        self.in_vocab_sz = in_vocab_sz
        self.num_landmarks = num_landmarks
        self.T = T
        self.apply_masc = apply_masc
        self.emb_map = CBoW(num_landmarks, in_vocab_sz, init_std=0.1)
        self.obs_emb_fn = nn.Linear(in_vocab_sz, in_vocab_sz)
        self.landmark_write_gate = nn.ParameterList()
        for _ in range(T + 1):
            self.landmark_write_gate.append(
                nn.Parameter(
                    torch.FloatTensor(1, in_vocab_sz, 1, 1).normal_(0.0, 0.1)))

        if apply_masc:
            self.masc_fn = MASC(in_vocab_sz)
            self.action_emb = nn.ModuleList()
            for i in range(T):
                self.action_emb.append(nn.Linear(in_vocab_sz, 9))
        else:
            self.masc_fn = NoMASC(in_vocab_sz)

        self.loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, message, batch):
        msg_obs = self.obs_emb_fn(message[0])
        batch_size = message[0].size(0)

        landmark_emb = self.emb_map.forward(batch['landmarks']).permute(
            0, 3, 1, 2)
        landmark_embs = [landmark_emb]

        if self.apply_masc:
            for j in range(self.T):
                act_msg = message[1]
                action_out = self.action_emb[j](act_msg)
                out = self.masc_fn.forward(landmark_embs[-1],
                                           action_out,
                                           current_step=j)
                landmark_embs.append(out)
        else:
            for j in range(self.T):
                out = self.masc_fn.forward(landmark_embs[-1])
                landmark_embs.append(out)

        landmarks = sum([
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, landmark_embs)
        ])
        landmarks = landmarks.view(batch_size, landmarks.size(1),
                                   16).transpose(1, 2)

        out = dict()
        logits = torch.bmm(landmarks, msg_obs.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)
        y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])

        out['loss'] = self.loss(logits, y_true)
        out['acc'] = sum([
            1.0
            for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                    y_true.data.cpu().numpy())
            if pred == target
        ]) / y_true.size(0)
        return out

    def save(self, path):
        state = dict()
        state['in_vocab_sz'] = self.in_vocab_sz
        state['num_landmarks'] = self.num_landmarks
        state['parameters'] = self.state_dict()
        state['T'] = self.T
        state['apply_masc'] = self.apply_masc
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        state = torch.load(path)
        guide = cls(state['in_vocab_sz'],
                    state['num_landmarks'],
                    T=state['T'],
                    apply_masc=state['apply_masc'])
        guide.load_state_dict(state['parameters'])
        return guide
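GuideDiscrete.forward expects a two-element message: message[0] is the observation message fed through obs_emb_fn, and message[1] is the action message consumed by the per-step action_emb layers when MASC is enabled; both have width in_vocab_sz. A shape-only sketch with dummy values (in the actual protocol these vectors come from the tourist model):

batch_size, in_vocab_sz = 8, 64
obs_msg = torch.randn(batch_size, in_vocab_sz)   # message[0]
act_msg = torch.randn(batch_size, in_vocab_sz)   # message[1], only read when apply_masc=True
message = (obs_msg, act_msg)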