Example #1
0
    def command_generation_teacher_force(self, observation_strings, task_desc_strings, target_strings):
        """Take one teacher-forced optimisation step on a batch of text observations.

        The decoder is fed every target prefixed with ``[CLS]`` and is scored
        against the same target suffixed with ``[SEP]``.  The mean loss is
        back-propagated, gradients are clipped to ``self.clip_grad_norm`` and
        the optimizer is stepped.

        Returns:
            The mean batch loss converted with ``to_np``.
        """
        decoder_in_strings = [" ".join(["[CLS]", *target.split()]) for target in target_strings]
        decoder_out_strings = [" ".join([*target.split(), "[SEP]"]) for target in target_strings]

        obs_word_ids = self.get_word_input(observation_strings)
        obs_encoding, obs_mask = self.encode(observation_strings, use_model="online")
        task_encoding, task_mask = self.encode(task_desc_strings, use_model="online")

        # batch x obs_length x hid
        fused_obs = self.online_net.aggretate_information(obs_encoding, obs_mask, task_encoding, task_mask)

        decoder_in = self.get_word_input(decoder_in_strings)
        expected = self.get_word_input(decoder_out_strings)  # batch x target_length
        decoder_mask = compute_mask(decoder_in)  # mask of ground truth should be the same
        # batch x target_length x vocab
        token_probs = self.online_net.decode(decoder_in, decoder_mask, fused_obs, obs_mask, None, obs_word_ids)

        masked_probs = token_probs * decoder_mask.unsqueeze(-1)
        per_item_loss = NegativeLogLoss(masked_probs, expected, decoder_mask, smoothing_eps=self.smoothing_eps)
        loss = torch.mean(per_item_loss)

        if loss is None:
            return None

        # Backpropagate: clear stale gradients on both the network and the optimizer.
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(loss)
Example #2
0
    def command_generation_greedy_generation(self, observation_feats,
                                             task_desc_strings,
                                             previous_dynamics):
        """Greedily decode one command per batch item from visual features.

        The observation features are aggregated and projected by the online
        network, concatenated with the mean task-description encoding, and
        the decoder is then queried one token at a time, always keeping the
        arg-max token, until every sequence has produced ``[SEP]`` or
        ``self.max_target_length`` steps have run.  Runs under
        ``torch.no_grad()``.

        Returns:
            ``(commands, current_dynamics)``: the decoded strings with the
            ``[CLS]``/``[SEP]`` markers removed, and the recurrent state
            (``None`` when ``self.recurrent`` is false).
        """
        with torch.no_grad():
            n_items = len(observation_feats)

            # Visual side: aggregate the raw features, then project them.
            vision_hidden = self.online_net.vision_fc(
                self.aggregate_feats_seq(observation_feats))
            # Text side: encode the task description and pool it.
            task_hidden, task_mask = self.encode(task_desc_strings,
                                                 use_model="online")
            task_summary = self.online_net.masked_mean(
                task_hidden, task_mask).unsqueeze(1)
            device = task_summary.device
            vision_hidden = vision_hidden.to(device)
            # batch x k boxes x hid
            context = torch.cat((vision_hidden, task_summary), dim=1)
            context_len = vision_hidden.shape[1] + task_summary.shape[1]
            context_mask = torch.ones((n_items, context_len)).to(device)

            current_dynamics = None
            if self.recurrent:
                pooled_context = self.online_net.masked_mean(
                    context, context_mask)
                if previous_dynamics is None:
                    current_dynamics = self.online_net.rnncell(pooled_context)
                else:
                    current_dynamics = self.online_net.rnncell(
                        pooled_context, previous_dynamics)

            # greedy generation: every sequence starts from the [CLS] token
            sequences = [[self.word2id["[CLS]"]] for _ in range(n_items)]
            finished = np.zeros(n_items)
            for _ in range(self.max_target_length):
                step_input = copy.deepcopy(sequences)
                step_input = pad_sequences(
                    step_input, maxlen=max_len(step_input)).astype('int32')
                step_input = to_pt(step_input, self.use_cuda)
                # mask of ground truth should be the same
                step_mask = compute_mask(step_input)
                # batch x target_length x vocab
                token_probs = self.online_net.vision_decode(
                    step_input, step_mask, context, context_mask,
                    current_dynamics)
                # keep only the last position, then take the arg-max token
                next_tokens = np.argmax(to_np(token_probs[:, -1]), -1)  # batch
                for b in range(n_items):
                    # Sequences that already emitted [SEP] stop growing.
                    if finished[b] == 0:
                        sequences[b] = sequences[b] + [next_tokens[b]]
                    if next_tokens[b] == self.word2id["[SEP]"]:
                        finished[b] = 1
                if np.sum(finished) == n_items:
                    break

            decoded = [self.tokenizer.decode(seq) for seq in sequences]
            commands = []
            for text in decoded:
                text = text.replace("[CLS]", "").replace("[SEP]", "").strip()
                commands.append(text.replace(" in / on ", " in/on "))
            return commands, current_dynamics
Example #3
0
 def encode_text(self, input_word_ids):
     """Embed a batch of word ids and pass it through the encoder stack.

     Args:
         input_word_ids: word-id tensor, batch x seq_len.

     Returns:
         ``(encoding, mask)``: the output of the last encoder layer and the
         mask produced by ``compute_mask`` for ``input_word_ids``.
     """
     word_mask = compute_mask(input_word_ids)
     hidden = self.embed(input_word_ids, word_mask)  # batch x seq_len x emb
     # Outer product of the mask with itself: batch x seq_len x seq_len
     pairwise_mask = torch.bmm(word_mask.unsqueeze(-1), word_mask.unsqueeze(1))
     for layer_idx in range(self.encoder_layers):
         layer = self.encoder[layer_idx]
         # Fourth argument passed to the layer: grows by (conv_num + 2) per layer, starting at 1.
         start_index = layer_idx * (self.encoder_conv_num + 2) + 1
         hidden = layer(hidden, word_mask, pairwise_mask, start_index, self.encoder_layers)  # batch x time x enc
     return hidden, word_mask
Example #4
0
    def command_generation_teacher_force(self, observation_feats,
                                         task_desc_strings, target_strings):
        """Take one teacher-forced optimisation step on a batch of visual features.

        The aggregated, projected observation features are concatenated with
        the mean task-description encoding and given to the vision decoder
        together with the ``[CLS]``-prefixed targets.  Predictions are scored
        against the ``[SEP]``-suffixed targets; the mean loss is
        back-propagated, gradients are clipped to ``self.clip_grad_norm`` and
        the optimizer is stepped.

        Returns:
            ``(to_np(pred), to_np(loss))``: the decoder output and the mean
            batch loss.
        """
        decoder_in_strings = [
            " ".join(["[CLS]", *target.split()]) for target in target_strings
        ]
        decoder_out_strings = [
            " ".join([*target.split(), "[SEP]"]) for target in target_strings
        ]
        n_items = len(observation_feats)

        # Visual side: aggregate the raw features, then project them.
        vision_hidden = self.online_net.vision_fc(
            self.aggregate_feats_seq(observation_feats))
        # Text side: encode the task description and pool it.
        task_hidden, task_mask = self.encode(task_desc_strings,
                                             use_model="online")
        task_summary = self.online_net.masked_mean(task_hidden,
                                                   task_mask).unsqueeze(1)
        device = task_summary.device
        vision_hidden = vision_hidden.to(device)
        # batch x k boxes x hid
        context = torch.cat((vision_hidden, task_summary), dim=1)
        context_len = vision_hidden.shape[1] + task_summary.shape[1]
        context_mask = torch.ones((n_items, context_len)).to(device)

        decoder_in = self.get_word_input(decoder_in_strings)
        # batch x target_length
        expected = self.get_word_input(decoder_out_strings)
        # mask of ground truth should be the same
        decoder_mask = compute_mask(decoder_in)
        # batch x target_length x vocab
        pred = self.online_net.vision_decode(decoder_in, decoder_mask,
                                             context, context_mask, None)

        masked_pred = pred * decoder_mask.unsqueeze(-1)
        per_item_loss = NegativeLogLoss(masked_pred,
                                        expected,
                                        decoder_mask,
                                        smoothing_eps=self.smoothing_eps)
        loss = torch.mean(per_item_loss)

        if loss is None:
            return None, None

        # Backpropagate: clear stale gradients on both the network and the optimizer.
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(),
                                       self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(pred), to_np(loss)
Example #5
0
    def command_generation_recurrent_teacher_force(self, seq_observation_strings, seq_task_desc_strings, seq_target_strings, contains_first_step=False):
        """Take one teacher-forced optimisation step over a replayed sequence of steps.

        The task description of the first step is encoded once and reused for
        all ``self.dagger_replay_sample_history_length`` steps.  At every step
        the recurrent state is advanced with ``self.online_net.rnncell`` and
        the decoder is run with teacher forcing.

        Args:
            seq_observation_strings: per-step batches of observation strings.
            seq_task_desc_strings: per-step batches of task descriptions; only
                index 0 is read.
            seq_target_strings: per-step batches of target command strings.
            contains_first_step: when false, steps whose index is below
                ``self.dagger_replay_sample_update_from`` only advance the
                (detached) recurrent state and contribute no loss.

        Returns:
            The mean of the per-step losses converted with ``to_np``, or
            ``None`` when no step contributed a loss (nothing to optimise).
        """
        loss_list = []
        previous_dynamics = None
        h_td, td_mask = self.encode(seq_task_desc_strings[0], use_model="online")
        for step_no in range(self.dagger_replay_sample_history_length):
            # Decoder inputs start with [CLS]; expected outputs end with [SEP].
            input_target_strings = [" ".join(["[CLS]"] + item.split()) for item in seq_target_strings[step_no]]
            output_target_strings = [" ".join(item.split() + ["[SEP]"]) for item in seq_target_strings[step_no]]

            input_obs = self.get_word_input(seq_observation_strings[step_no])
            h_obs, obs_mask = self.encode(seq_observation_strings[step_no], use_model="online")
            aggregated_obs_representation = self.online_net.aggretate_information(h_obs, obs_mask, h_td, td_mask)  # batch x obs_length x hid

            averaged_representation = self.online_net.masked_mean(aggregated_obs_representation, obs_mask)  # batch x hid
            if previous_dynamics is not None:
                current_dynamics = self.online_net.rnncell(averaged_representation, previous_dynamics)
            else:
                current_dynamics = self.online_net.rnncell(averaged_representation)

            input_target = self.get_word_input(input_target_strings)
            ground_truth = self.get_word_input(output_target_strings)  # batch x target_length
            target_mask = compute_mask(input_target)  # mask of ground truth should be the same
            pred = self.online_net.decode(input_target, target_mask, aggregated_obs_representation, obs_mask, current_dynamics, input_obs)  # batch x target_length x vocab

            previous_dynamics = current_dynamics
            if (not contains_first_step) and step_no < self.dagger_replay_sample_update_from:
                # Burn-in step: keep the state but cut the graph so no gradient
                # flows back through it, and skip the loss.
                previous_dynamics = previous_dynamics.detach()
                continue

            batch_loss = NegativeLogLoss(pred * target_mask.unsqueeze(-1), ground_truth, target_mask, smoothing_eps=self.smoothing_eps)
            loss = torch.mean(batch_loss)
            loss_list.append(loss)

        # torch.stack raises on an empty list, so bail out explicitly when
        # every step was skipped (or the history length is zero).
        if not loss_list:
            return None
        loss = torch.stack(loss_list).mean()
        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(loss)
Example #6
0
    def command_generation_beam_search_generation(self, observation_strings, task_desc_strings, previous_dynamics):
        """Generate candidate commands for every batch item with beam search.

        When ``self.beam_width`` is 1 the call is delegated to
        ``command_generation_greedy_generation`` and each greedy result is
        wrapped in a one-element list.  Otherwise each batch item is searched
        independently: hypotheses live in a ``PriorityQueue`` keyed by
        ``BeamSearchNode.val``, and a hypothesis is finished once it ends with
        ``[SEP]`` or reaches ``self.max_target_length``.  Runs under
        ``torch.no_grad()``.

        Args:
            observation_strings: one observation string per batch item.
            task_desc_strings: one task description per batch item.
            previous_dynamics: recurrent state from the previous step, or
                ``None``.

        Returns:
            ``(res, current_dynamics)``: ``res[b]`` is the list of decoded
            candidates for batch item ``b`` (at most ``self.generate_top_k``,
            sorted by ascending queue score, ``[CLS]``/``[SEP]`` removed) and
            ``current_dynamics`` is the recurrent state (``None`` when
            ``self.recurrent`` is false).
        """
        with torch.no_grad():

            batch_size = len(observation_strings)
            beam_width = self.beam_width
            if beam_width == 1:
                res, current_dynamics = self.command_generation_greedy_generation(observation_strings, task_desc_strings, previous_dynamics)
                res = [[item] for item in res]
                return res, current_dynamics
            generate_top_k = self.generate_top_k
            res = []

            input_obs = self.get_word_input(observation_strings)
            h_obs, obs_mask = self.encode(observation_strings, use_model="online")
            h_td, td_mask = self.encode(task_desc_strings, use_model="online")
            aggregated_obs_representation = self.online_net.aggretate_information(h_obs, obs_mask, h_td, td_mask)  # batch x obs_length x hid

            if self.recurrent:
                averaged_representation = self.online_net.masked_mean(aggregated_obs_representation, obs_mask)  # batch x hid
                if previous_dynamics is not None:
                    current_dynamics = self.online_net.rnncell(averaged_representation, previous_dynamics)
                else:
                    current_dynamics = self.online_net.rnncell(averaged_representation)
            else:
                current_dynamics = None

            for b in range(batch_size):

                # starts from CLS tokens
                __input_target_list = [self.word2id["[CLS]"]]
                # Slice (rather than index) so the batch dimension of size 1 is kept.
                __input_obs = input_obs[b: b + 1]  # 1 x obs_len
                __obs_mask = obs_mask[b: b + 1]  # 1 x obs_len
                __aggregated_obs_representation = aggregated_obs_representation[b: b + 1]  # 1 x obs_len x hid
                if current_dynamics is not None:
                    __current_dynamics = current_dynamics[b: b + 1]  # 1 x hid
                else:
                    __current_dynamics = None
                ended_nodes = []

                # starting node -  previous node, input target, logp, length
                node = BeamSearchNode(None, __input_target_list, 0, 1)
                nodes_queue = PriorityQueue()
                # start the queue
                nodes_queue.put((node.val, node))
                queue_size = 1

                while True:
                    # give up when decoding takes too long
                    if queue_size > 2000:
                        break
                    # Every hypothesis has been consumed before enough of them
                    # finished: stop instead of blocking forever on get().
                    if nodes_queue.empty():
                        break

                    # fetch the best node
                    score, n = nodes_queue.get()
                    __input_target_list = n.input_target

                    # The root node (no predecessor) is never treated as finished.
                    if (n.input_target[-1] == self.word2id["[SEP]"] or n.length >= self.max_target_length) and n.previous_node is not None:
                        ended_nodes.append((score, n))
                        # if we reached maximum # of sentences required
                        if len(ended_nodes) >= generate_top_k:
                            break
                        else:
                            continue

                    input_target = pad_sequences([__input_target_list], dtype='int32')
                    input_target = to_pt(input_target, self.use_cuda)
                    target_mask = compute_mask(input_target)
                    # decode for one step using decoder
                    pred = self.online_net.decode(input_target, target_mask, __aggregated_obs_representation, __obs_mask, __current_dynamics, __input_obs)  # 1 x target_length x vocab
                    pred = pred[0][-1].cpu()
                    # Add a tiny epsilon only where the probability is not
                    # positive, so log() stays finite there.  Such entries get
                    # log(1e-8) and rank last; multiplying them by zero instead
                    # would give them log-prob 0, i.e. the *best* possible score.
                    epsilon = torch.le(pred, 0.0).float() * 1e-8  # vocab
                    log_pred = torch.log(pred + epsilon)  # vocab

                    top_beam_width_log_probs, top_beam_width_indicies = torch.topk(log_pred, beam_width)
                    next_nodes = []

                    for new_k in range(beam_width):
                        # Store plain Python ints so token lists stay homogeneous.
                        pos = top_beam_width_indicies[new_k].item()
                        log_p = top_beam_width_log_probs[new_k].item()
                        node = BeamSearchNode(n, __input_target_list + [pos], n.log_prob + log_p, n.length + 1)
                        next_nodes.append((node.val, node))

                    # put them into queue
                    for entry in next_nodes:
                        nodes_queue.put(entry)
                    # increase qsize (one node was popped, len(next_nodes) were pushed)
                    queue_size += len(next_nodes) - 1

                # choose n best paths; if nothing finished, fall back to the
                # best unfinished hypotheses, never popping more than exist.
                if not ended_nodes:
                    ended_nodes = [nodes_queue.get() for _ in range(min(generate_top_k, nodes_queue.qsize()))]

                utterances = []
                for score, n in sorted(ended_nodes, key=operator.itemgetter(0)):
                    utte = n.input_target
                    utte_string = self.tokenizer.decode(utte)
                    utterances.append(utte_string)
                utterances = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in utterances]
                utterances = [item.replace(" in / on ", " in/on ") for item in utterances]
                res.append(utterances)
            return res, current_dynamics