示例#1
0
 def forward(self, input, offsets=None):
     """Embed the input bags twice (mean- and max-pooled) and concatenate.

     Args:
         input: LongTensor of indices into ``self.weight``.
         offsets: bag start positions when ``input`` is 1-D; ``None`` for
             2-D input where each row is one bag.

     Returns:
         Tensor of shape ``(num_bags, 2 * embedding_dim)`` — mean-pooled
         features followed by max-pooled features.
     """
     # F.embedding_bag takes (input, weight, offsets, ...); the original
     # (weight, input) order is the deprecated legacy form that PyTorch
     # only accepts via a compatibility shim, so pass indices first.
     out1 = F.embedding_bag(input, self.weight, offsets,
                            self.max_norm, self.norm_type,
                            self.scale_grad_by_freq, 'mean')
     out2 = F.embedding_bag(input, self.weight, offsets,
                            self.max_norm, self.norm_type,
                            self.scale_grad_by_freq, 'max')
     return torch.cat([out1, out2], 1)
示例#2
0
 def forward(self, x, offsets=None):
     """Pooled embedding lookup over a hashed weight table.

     When ``self.lens`` is falsy the instance's own ``self.hashed_weight``
     is indexed; otherwise the module-level ``hashed_weight`` is used.
     NOTE(review): that global is not defined in this block — verify it
     exists at module scope before the else branch can run.
     """
     table = self.hashed_weight if not self.lens else hashed_weight
     return F.embedding_bag(x,
                            table[self.weight_idx],
                            offsets=offsets,
                            mode=self.mode)
示例#3
0
 def forward(self, feats_in_l, idx_targets, sizes_subg):
     """Pool per-layer subgraph features into one embedding per target node.

     Args:
         feats_in_l: list of per-layer node feature tensors; all layers
             share the same node ordering, ``feats_in_l[-1]`` is the last
             layer's output.
         idx_targets: indices of the target ("center") nodes.
         sizes_subg: 1-D tensor of node counts per subgraph; nodes of each
             subgraph occupy a contiguous block of rows.

     Pooling is selected by ``self.type_pool`` ('center', 'max'/'mean'/'sum',
     or 'sort'); ``self.type_res`` controls whether a residue function
     (JK-style) combines the layers. Returns ``self.f_norm(self.nn(feat_in))``.
     """
     if self.type_pool == 'center':
         if self.type_res == 'none':
             return feats_in_l[-1][idx_targets]
         else:  # regular JK
             feats_root_l = [f[idx_targets] for f in feats_in_l]
             feat_in = self.f_residue(feats_root_l)
     elif self.type_pool in ['max', 'mean', 'sum']:
         # first pool subgraph at each layer, then residue
         # cumsum → roll → zero turns the sizes into an exclusive prefix
         # sum, i.e. the starting row offset of each subgraph's node block.
         offsets = torch.cumsum(sizes_subg, dim=0)
         offsets = torch.roll(offsets, 1)
         offsets[0] = 0
         idx = torch.arange(feats_in_l[-1].shape[0]).to(
             feats_in_l[-1].device)
         if self.type_res == 'none':
             # embedding_bag treats the feature matrix as the "weight"
             # table and pools each subgraph's contiguous rows in one shot.
             feat_pool = F.embedding_bag(idx,
                                         feats_in_l[-1],
                                         offsets,
                                         mode=self.type_pool)
             feat_root = feats_in_l[-1][idx_targets]
         else:
             feat_pool_l = []
             for feat in feats_in_l:
                 feat_pool = F.embedding_bag(idx,
                                             feat,
                                             offsets,
                                             mode=self.type_pool)
                 feat_pool_l.append(feat_pool)
             feat_pool = self.f_residue(feat_pool_l)
             feat_root = self.f_residue(
                 [f[idx_targets] for f in feats_in_l])
         feat_in = torch.cat([feat_root, feat_pool], dim=1)
     elif self.type_pool == 'sort':
         if self.type_res == 'none':
             feat_pool_in = feats_in_l[-1]
             feat_root = feats_in_l[-1][idx_targets]
         else:
             feat_pool_in = self.f_residue(feats_in_l)
             feat_root = self.f_residue(
                 [f[idx_targets] for f in feats_in_l])
         # Map each node row to its subgraph id for global_sort_pool.
         arange = torch.arange(sizes_subg.size(0)).to(sizes_subg.device)
         idx_batch = torch.repeat_interleave(arange, sizes_subg)
         feat_pool_k = global_sort_pool(feat_pool_in, idx_batch,
                                        self.k)  # #subg x (k * F)
         feat_pool = self.nn_pool(feat_pool_k)
         feat_in = torch.cat([feat_root, feat_pool], dim=1)
     else:
         raise NotImplementedError
     return self.f_norm(self.nn(feat_in))
示例#4
0
 def get(self, input_: TensorList) -> FloatTensorType:
     """Pool embeddings for each bag described by ``input_``.

     Args:
         input_: TensorList whose ``data`` holds the flat indices and
             whose ``offsets`` mark the bag boundaries.

     Returns:
         Tensor of shape ``(num_bags, embedding_dim)``.
     """
     if input_.size(0) == 0:
         # Match the dtype/device of the non-empty path, which returns a
         # tensor derived from self.weight (the original always returned
         # a CPU float32 tensor here, even for non-float32/GPU weights).
         return torch.empty((0, self.weight.size(1)),
                            dtype=self.weight.dtype,
                            device=self.weight.device)
     return F.embedding_bag(
         input_.data.long(), self.weight, input_.offsets[:-1],
         max_norm=self.max_norm, sparse=True,
     )
示例#5
0
 def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
     """Run embedding_bag against the quant-dequantized weight.

     The weight is materialized through ``self.get_weight()`` so the
     lookup sees the (fake-)quantized values rather than the raw ones.
     """
     qdq_weight = self.get_weight()
     return F.embedding_bag(input,
                            qdq_weight,
                            offsets,
                            max_norm=self.max_norm,
                            norm_type=self.norm_type,
                            scale_grad_by_freq=self.scale_grad_by_freq,
                            mode=self.mode,
                            sparse=self.sparse,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=self.include_last_offset,
                            padding_idx=self.padding_idx)
示例#6
0
 def forward(self, input, offsets=None, per_sample_weights=None, params=None):
     """Pooled embedding lookup with an optionally overridden parameter set.

     ``params`` defaults to this module's own named parameters; a caller
     (e.g. a meta-learning inner loop) may pass a substitute mapping that
     provides a 'weight' entry.
     """
     if params is None:
         params = OrderedDict(self.named_parameters())
     lookup_weight = params['weight']
     return F.embedding_bag(input,
                            lookup_weight,
                            offsets,
                            max_norm=self.max_norm,
                            norm_type=self.norm_type,
                            scale_grad_by_freq=self.scale_grad_by_freq,
                            mode=self.mode,
                            sparse=self.sparse,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=self.include_last_offset)
示例#7
0
文件: adapter.py 项目: izhx/nmnlp
 def merge_piece_emb(self, pieces, inputs) -> torch.Tensor:
     """Overwrite selected word slots with the pooled embedding of their pieces.

     ``pieces`` maps a (sentence, word) position to the word-piece index
     tensor for that word; each such slot of ``inputs`` is replaced by the
     mean (embedding_bag's default mode) of its piece embeddings.
     """
     zero_offset = torch.tensor([0], dtype=torch.long)
     emb_table = self.bert.embeddings.word_embeddings.weight
     for (sent_i, word_i), piece_ids in pieces.items():
         pooled = embedding_bag(piece_ids, emb_table,
                                zero_offset.to(piece_ids.device))
         inputs[sent_i, word_i, :] = pooled
     return inputs
示例#8
0
文件: atj.py 项目: ATJNet2020/ATJ-Net
 def aggregate_by_embbag(self, weight, key_index, mode):
   """Pool consecutive rows of ``weight`` into bags.

   Every row of ``weight`` participates exactly once; ``key_index.start``
   lists the first-row index of each bag, so bag ``i`` covers rows
   ``start[i] .. start[i+1]-1`` (the last bag runs to the end).

   Args:
       weight: ``(num_rows, dim)`` embedding rows to pool.
       key_index: object whose ``start`` attribute holds the per-bag
           starting row indices.
       mode: embedding_bag pooling mode, e.g. 'sum', 'mean' or 'max'.

   Returns:
       Tensor of shape ``(num_bags, dim)``.
   """
   offsets = torch.tensor(
       key_index.start, dtype=torch.long, device=self.device)
   # Renamed from `input` to avoid shadowing the builtin.
   row_ids = torch.arange(
       weight.shape[0], dtype=torch.long, device=self.device)
   return F.embedding_bag(
       input=row_ids, weight=weight, offsets=offsets, mode=mode)
 def forward(self, x, offsets=None):
     """Sum of pooled lookups over each hashed-weight index table.

     One embedding_bag is run per index table in
     ``self.weight_idx_list[:self.update_count]`` (at least the first
     table is always used) and the results are accumulated element-wise.
     """
     def lookup(idx_table):
         # One pooled lookup against the rows selected by idx_table.
         return F.embedding_bag(x,
                                self.hashed_weight[idx_table, :],
                                offsets=offsets,
                                mode=self.mode,
                                sparse=self.sparse)

     total = lookup(self.weight_idx_list[0])
     for k in range(1, self.update_count):
         total = total + lookup(self.weight_idx_list[k])
     return total
 def forward(self, x, offsets=None):
     """Pooled embedding lookup backed by the shared hashed weight buffer.

     ``self.weight_idx`` gathers rows of ``self.hashed_weight`` to build
     the virtual embedding table that ``x`` indexes into.
     """
     virtual_table = self.hashed_weight[self.weight_idx, :]
     return F.embedding_bag(x,
                            virtual_table,
                            offsets=offsets,
                            mode=self.mode,
                            sparse=self.sparse)
示例#11
0
文件: nn_ops.py 项目: yuguo68/pytorch
 def forward(self):
     """Exercise embedding-related functional ops and report how many ran.

     Returns:
         int: the number of ops exercised (3).
     """
     input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
     input2 = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
     embedding_matrix = torch.rand(10, 3)
     offsets = torch.tensor([0, 4])
     # len() takes a single argument; the original passed three, which
     # raises TypeError.  Collect the op results in a tuple instead.
     return len((
         F.embedding(input, embedding_matrix),
         F.embedding_bag(input2, embedding_matrix, offsets),
         F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
     ))
    def forward(self, input, offsets=None, per_sample_weights=None):
        """Quotient-remainder (QR) compressed embedding-bag lookup.

        Each index is split into a quotient and a remainder with respect
        to ``self.num_collisions``; the two compact tables are looked up
        separately and combined according to ``self.operation``.

        Args:
            input: indices into the virtual (uncompressed) table.
            offsets: bag boundaries when ``input`` is 1-D.
            per_sample_weights: optional per-index scaling weights.

        Returns:
            Combined embedding tensor; 'concat' doubles the embedding
            dim, 'add'/'mult' keep the per-table dim.

        Raises:
            ValueError: if ``self.operation`` is not 'concat', 'add'
                or 'mult'.
        """
        input_q = (input // self.num_collisions).long()
        input_r = torch.remainder(input, self.num_collisions).long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets,
                                  self.max_norm, self.norm_type,
                                  self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets,
                                  self.max_norm, self.norm_type,
                                  self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r
        elif self.operation == 'mult':
            embed = embed_q * embed_r
        else:
            # Previously an unknown operation fell through and raised a
            # confusing NameError on `embed`; fail fast instead.
            raise ValueError(f"unsupported operation: {self.operation!r}")

        return embed
示例#13
0
 def test_embedding_bag(self):
     """Smoke-test F.embedding_bag on a random CUDA input.

     Only verifies that the op executes with this configuration (mean
     pooling over 2-D input, no offsets, dtype ``self.dtype``); the
     result itself is not inspected.
     """
     vocab_size, embed_dim = 1024, 32
     inp = torch.randint(0, vocab_size, (128, 16), device='cuda')
     weight = torch.randn(vocab_size,
                          embed_dim,
                          device='cuda',
                          dtype=self.dtype)
     output = F.embedding_bag(inp,
                              weight,
                              offsets=None,
                              max_norm=None,
                              norm_type=2,
                              scale_grad_by_freq=False,
                              mode='mean',
                              sparse=False)
示例#14
0
    def __init__(self,
                 name_or_path: str,
                 freeze: str = 'all',
                 layer_num: int = 1,  # number of layers to use, counting back from the last
                 transform_dim: int = 0,  # output dim each used layer is projected down to
                 scalar_mix: Dict[str, Any] = None,
                 word_piece: str = 'first',  # requires the input ids to come first
                 **kwargs):
        """BERT-backed word encoder.

        Loads a pretrained BertModel, optionally freezes it, optionally adds
        per-layer projections, and configures how multi-piece words are
        pooled and how the selected hidden layers are mixed.

        Args:
            name_or_path: model name or path for ``BertModel.from_pretrained``.
            freeze: 'all' freezes every BERT parameter; any other value
                leaves them trainable.
            layer_num: how many hidden layers (from the top) to use.
            transform_dim: if > 0, each used layer gets a NonLinear
                projection down to this size and ``output_dim`` is updated.
            scalar_mix: kwargs for ScalarMixWithDropout; ``None`` disables it.
            word_piece: 'first' keeps only the first piece of each word;
                anything else pools the pieces' embeddings.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(name_or_path)
        self.bert.encoder.output_hidden_states = True
        self.bert.config.output_hidden_states = True

        # index of the hidden-states entry in the BERT output tuple
        self.index = 2
        self.layer_num = layer_num
        self.output_dim = self.bert.config.hidden_size

        if freeze == 'all':
            for param in self.bert.parameters():
                param.requires_grad = False

        if transform_dim > 0:
            self.word_transform = ModuleList([NonLinear(
                self.output_dim, transform_dim) for _ in range(self.layer_num)])
            self.output_dim = transform_dim
        else:
            self.word_transform = None

        if word_piece == 'first':
            self.word_piece = None
        else:  # mean of pieces
            # embedding_bag's default mode is 'mean'; with a single zero
            # offset every index in x forms one bag.
            offset = torch.tensor([0], dtype=torch.long)
            # self.register_buffer("offset", offset)
            self.word_piece = lambda x: embedding_bag(
                x, self.bert.embeddings.word_embeddings.weight, offset.to(x.device))

        self.scalar_mix = None if scalar_mix is None else ScalarMixWithDropout(
            layer_num, **scalar_mix)
        # With a single layer there is nothing to mix — just take it,
        # overriding any ScalarMixWithDropout built above.
        if layer_num == 1:
            self.scalar_mix = lambda x, *args: x[0]
示例#15
0
    def forward(self, features, mask):
        """Pool token ``features`` into one embedding per span or sequence.

        ``mask`` is either a boolean tensor marking which tokens belong to
        the pooled unit, or a ``(begins, ends)`` tuple of span boundaries.
        The pooling behaviour is selected by ``self.mode``: 'first', 'last',
        'max', 'abs-max', 'sum', 'mean', 'softmax', 'softmax-abs' or
        'attention'.
        """
        device = features.device
        if self.mode == "attention" and isinstance(mask, tuple):
            # Expand (begins, ends) into a boolean token mask for attention.
            position = torch.arange(features.shape[-2], device=device).reshape([1] * (features.ndim - 2) + [features.shape[-2]])
            mask = (mask[0].unsqueeze(-1) <= position) & (position < mask[1].unsqueeze(-1))
            features = features.unsqueeze(-3)
        if isinstance(mask, tuple):
            original_dtype = features.dtype
            if features.dtype == torch.int or features.dtype == torch.long:
                features = features.float()
            begins, ends = mask
            # 'first'/'last' shrink each span to a single token so the
            # embedding_bag 'max' below reduces to a plain selection.
            if self.mode == "first":
                ends = torch.minimum(begins + 1, ends)
            if self.mode == "last":
                begins = torch.maximum(ends - 1, begins)
            begins = begins.expand(*features.shape[:begins.ndim - 1], begins.shape[-1]).clamp_min(0)
            ends = ends.expand(*features.shape[:begins.ndim - 1], ends.shape[-1]).clamp_min(0)
            final_shape = (*begins.shape, *features.shape[begins.ndim:])
            features = features.view(-1, features.shape[-2], features.shape[-1])
            begins = begins.reshape(features.shape[0], begins.numel() // features.shape[0] if len(features) else 0)
            ends = ends.reshape(features.shape[0], ends.numel() // features.shape[0] if len(features) else 0)

            # Flatten all spans into one index list plus per-span offsets
            # so a single embedding_bag call pools every span at once.
            max_window_size = max(0, int((ends - begins).max())) if 0 not in ends.shape else 0
            flat_indices = torch.arange(max_window_size, device=device)[None, None, :] + begins[..., None]
            flat_indices_mask = flat_indices < ends[..., None]
            flat_indices += torch.arange(len(flat_indices), device=device)[:, None, None] * features.shape[1]

            flat_indices = flat_indices[flat_indices_mask]
            res = F.embedding_bag(
                input=flat_indices,
                weight=self.dropout(features.reshape(-1, features.shape[-1])),
                offsets=torch.cat([torch.tensor([0], device=device), flat_indices_mask.sum(-1).reshape(-1)]).cumsum(0)[:-1].clamp_max(flat_indices.shape[0]),
                mode=self.mode if self.mode not in ("first", "last") else "max",
            ).reshape(final_shape)
            if res.dtype != original_dtype:
                res = res.type(original_dtype)
            return res
        elif torch.is_tensor(mask):
            features = features
            features = self.dropout(features)
            # Reduce the mask to only the first (or last) token per unit.
            if self.mode == "first":
                mask = ~shift(mask.long(), n=1, dim=-1).cumsum(-1).bool() & mask
            elif self.mode == "last":
                mask = mask.flip(-1)
                mask = (~shift(mask.long(), n=1, dim=-1).cumsum(-1).bool() & mask).flip(-1)

            if mask.ndim <= features.ndim - 1:
                mask = mask.unsqueeze(-1)
            if 0 in mask.shape:
                return features.sum(-2)
            if self.mode == "attention":
                weights = self.key_proj(features).masked_fill(~mask, -100000).softmax(-2)  # ... tokens heads
                values = self.value_proj(features) if self.value_proj is not None else features
                values = values.view(*values.shape[:-1], weights.shape[-1], -1)  # ... tokens heads dim
                res = torch.einsum('...nhd,...nh->...hd', values, weights)
                return res.view(*res.shape[:-2], -1)
            elif self.mode == "max":
                features = features.masked_fill(~mask, -100000).max(-2).values.masked_fill(~(mask.any(-2)), 0)
            elif self.mode == "abs-max":
                values, indices = features.abs().masked_fill(~mask, -100000).max(-2)
                features = features.gather(dim=-2, index=indices.unsqueeze(1)).squeeze(1)
            elif self.mode in ("sum", "mean", "first", "last"):
                features = features.masked_fill(~mask, 0).sum(-2)
                if self.mode == "mean":
                    features = features / mask.float().sum(-2).clamp_min(1.)
            elif self.mode == "softmax":
                weights = (features.detach() * self.alpha).masked_fill(~mask, -100000).softmax(-2)
                features = torch.einsum('...nd,...nd->...d', weights, features.masked_fill(~mask, 0))
            elif self.mode == "softmax-abs":
                weights = (features.detach().abs() * self.alpha).masked_fill(~mask, -100000).softmax(-2)
                features = torch.einsum('...nd,...nd->...d', weights, features.masked_fill(~mask, 0))
            return features
示例#16
0
    def forward(self, batch_idx, pairs, negs, offsets, lists):
        """Aspect-aware skip-gram loss with negative sampling.

        Args:
            batch_idx: batch index (unused in the computation below).
            pairs: (N, 2) tensor of (center, context) node pairs.
            negs: negative-sample node ids per pair.
            offsets: bag offsets into ``lists`` for embedding_bag.
            lists: flat context-node lists pooled per center.

        Returns:
            The scalar loss, or ``(loss, reg_value)`` when ``self.isReg``.
        """
        centers = pairs[:, 0]
        embed_centers = self.center_embedding(centers)  # N x dim
        # One pooled context embedding per aspect: aspect k owns rows
        # [k*num_nodes, (k+1)*num_nodes) of the shared aspect table.
        embed_contexts_means = torch.stack([
            F.embedding_bag(
                lists,
                self.aspect_embedding.weight[k * self.num_nodes:(k + 1) *
                                             self.num_nodes],
                offsets,
                mode=self.pooling) for k in range(self.num_aspects)
        ], 1)  # N x K x dim

        if self.isSoftmax:
            # Apply softmax
            aspect_softmax = torch.bmm(embed_contexts_means,
                                       embed_centers.unsqueeze(-1)).squeeze(
                                           -1)  # N x K
            # 1-1. Gumbel Softmax
            if self.isGumbelSoftmax:
                # In fact, following the original Gumbel-softmax, the input for F.gumbel_softmax() should be logit (i.e., unnormalized log probabilities.)
                # However, we found that unnormalized probabilities without log are numerically more stable, and performs on par with logit.
                aspect_softmax = F.gumbel_softmax(aspect_softmax,
                                                  tau=self.tau_gumbel,
                                                  hard=self.isHard)
            elif self.isNormalSoftmax:
                # 1-2. Softmax
                aspect_softmax = F.softmax(aspect_softmax, dim=1)

        contexts = pairs[:, 1]
        total_contexts_idxs = torch.cat([
            k * self.num_nodes + contexts.unsqueeze(1)
            for k in range(self.num_aspects)
        ], 1)
        aspect_embedding_context = self.aspect_embedding(
            total_contexts_idxs)  # N x K x dim

        score_pos = torch.bmm(
            aspect_embedding_context, embed_centers.unsqueeze(-1)).squeeze(
                -1)  # (N x K x dim) x (N x dim x 1) = (N x K)
        score_pos = -F.logsigmoid(score_pos)

        embed_contexts_negs = [
            self.aspect_embedding(k * self.num_nodes + negs)
            for k in range(self.num_aspects)
        ]  # [N x num_neg x dim] * K
        # NOTE(review): `k` is unused in this loop — enumerate is redundant.
        score_negs = [
            torch.bmm(embed_contexts_neg,
                      embed_centers.unsqueeze(-1)).squeeze(-1)
            for k, embed_contexts_neg in enumerate(embed_contexts_negs)
        ]
        score_neg = torch.stack([
            -torch.sum(F.logsigmoid(-score_neg), dim=1)
            for score_neg in score_negs
        ], 1)

        if self.isSoftmax:
            sg_loss = aspect_softmax * (score_pos + score_neg)
        else:
            sg_loss = (score_pos + score_neg) / self.num_aspects

        sg_loss = torch.mean(sg_loss)

        final_loss = sg_loss

        # Aspect regularization
        if self.isReg:
            # NOTE(review): div_metric stays None when num_aspects < 2, which
            # would make div_reg fail below — presumably num_aspects >= 2.
            div_metric = None
            # N x K x dim
            aspect_emb_reshaped = self.aspect_embedding.weight.view(
                self.num_aspects, self.num_nodes,
                self.dim).permute(1, 0, 2).contiguous()
            # Penalize pairs of aspects whose per-node embeddings are more
            # similar than self.threshold.
            for i in range(self.num_aspects):
                for j in range(i + 1, self.num_aspects):
                    sim_matrix = F.cosine_similarity(
                        aspect_emb_reshaped[:, i, :],
                        aspect_emb_reshaped[:, j, :])
                    mask = torch.abs(sim_matrix) > self.threshold
                    if i == 0 and j == 1:
                        div_metric = (torch.abs(
                            torch.masked_select(sim_matrix, mask))).sum()
                    else:
                        div_metric += (torch.abs(
                            torch.masked_select(sim_matrix, mask))).sum()

            div_reg = self.reg_coef * div_metric

            final_loss += div_reg
            return final_loss, (self.reg_coef * div_metric).item()

        return final_loss
示例#17
0
 def forward(self, x):
     """Pooled lookup of ``x`` against the virtual hashed-weight table."""
     virtual_weight = self.hashed_weight[self.weight_idx]
     return F.embedding_bag(x, virtual_weight)
示例#18
0
 def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
     """embedding_bag over the fake-quantized view of the weight."""
     fq_weight = self.weight_fake_quant(self.weight)
     return F.embedding_bag(input,
                            fq_weight,
                            offsets,
                            max_norm=self.max_norm,
                            norm_type=self.norm_type,
                            scale_grad_by_freq=self.scale_grad_by_freq,
                            mode=self.mode,
                            sparse=self.sparse,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=self.include_last_offset,
                            padding_idx=self.padding_idx)
示例#19
0
 def forward(self, input, offsets=None, per_sample_weights=None):
     """Quantize the weight, pool with embedding_bag, then quantize the activation."""
     q_weight = self.weight_quantizer(self.weight)
     pooled = F.embedding_bag(input,
                              q_weight,
                              offsets,
                              max_norm=self.max_norm,
                              norm_type=self.norm_type,
                              scale_grad_by_freq=self.scale_grad_by_freq,
                              mode=self.mode,
                              sparse=self.sparse,
                              per_sample_weights=per_sample_weights,
                              include_last_offset=self.include_last_offset)
     return self.activation_quantizer(pooled)
    def rollout(self):
        """Run one batched episode in the environment and collect trajectories.

        Drives the navigation/ask agent for up to ``self.episode_len`` steps:
        encodes the instruction (re-encoding whenever a help request or a
        route departure changes it), queries the nav policy and — depending
        on ``self.hparams.ask_baseline`` — either the learned ask policy or
        the ask teacher, steps the environment, and book-keeps everything
        into per-example ``traj`` dicts.  When ``self.is_eval`` is False it
        also accumulates ``self.nav_loss`` / ``self.ask_loss`` /
        ``self.ask_reason_loss`` and finishes with ``self._compute_loss()``.

        Returns:
            list[dict]: one trajectory record per example in the batch.
        """
        # Reset environment.
        obs = self.env.reset()
        self.batch_size = len(obs)

        # Trajectory history.
        traj = [{
            'scan':
            ob['scan'],
            'instr_id':
            ob['instr_id'],
            'agent_pose': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
            'agent_mode': ['main'],
            'agent_ask': [],
            'agent_nav': [],
            'instruction': [ob['instruction']],
            'target_viewpoints': [ob['target_viewpoints']],
            'nav_prob': [],
            'message': [],
            'time_on_task': [0],
            'time': [0],
            'teacher_nav': [],
            'teacher_ask': [],
            'teacher_reason': [],
            'agent_reason': [],
            'agent_reason_prob': [],
            'adj_loc_list': []
        } for ob in obs]

        # Initial decoder states.
        nav_a, ask_a = self.model.reset(self.batch_size)

        ended = [False] * self.batch_size

        should_encode_instruction = True

        info_list = [[] for _ in range(self.batch_size)]
        nav_logits, ask_logits, ask_reason_logits = [], [], []
        nav_pos_targets = []
        last_ask = [-1] * self.batch_size

        device = nav_a.device
        self.nav_loss = torch.tensor(0., device=device)
        self.ask_loss = torch.tensor(0., device=device)
        self.ask_reason_loss = torch.tensor(0., device=device)

        for time_step in range(self.episode_len):

            # Encode instruction
            if should_encode_instruction:
                ctx_seq, ctx_mask = self._text_context_variable(obs)
                nav_ctx, ask_ctx = self.model.encode(ctx_seq, ctx_mask)
                if not self.hparams.no_reset_inter:
                    self.model.reset_text_decoder(self.batch_size)

            # Masks
            nav_a_embeds, nav_logit_mask = self._nav_action_variable(obs)
            ask_logit_mask = self._ask_action_variable(obs)

            # Visual features
            curr_view_features, goal_view_features = \
                self._visual_feature_variable(obs)

            # Time feature
            time = self.from_numpy(np.array([ob['time'] for ob in obs]))
            time_on_task = self.from_numpy(
                np.array([ob['time_on_task'] for ob in obs]))

            # Query nav policy
            nav_logit = self.model.decode_nav(time, time_on_task, nav_a,
                                              nav_a_embeds, nav_ctx, ctx_mask,
                                              curr_view_features,
                                              goal_view_features,
                                              nav_logit_mask)

            # Query nav teacher
            nav_target_list = self.teacher.next_nav(obs)
            nav_a = self._next_action(nav_logit, self.nav_feedback)
            nav_a_list = nav_a.tolist()

            # Compute distribution
            nav_dist = self._compute_nav_dist(obs, nav_logit)
            nav_dist_list = nav_dist.tolist()

            if self.hparams.ask_baseline is None:
                # Query ask policy
                ask_logit, ask_reason_logit = self.model.decode_ask(
                    time, time_on_task, ask_a, nav_dist, ask_ctx, ctx_mask,
                    curr_view_features, goal_view_features, ask_logit_mask)

                ask_logits.append(ask_logit)
                ask_a = self._next_action(ask_logit, self.ask_feedback)
                ask_a_list = ask_a.tolist()

                ask_reason_logits.append(ask_reason_logit)
                ask_reason_prob_list = torch.sigmoid(ask_reason_logit).tolist()
            else:
                # Query ask teacher
                for i, ob in enumerate(obs):
                    ob['last_ask'] = last_ask[i]
                ask_a_list, ask_reason = self.teacher.next_ask(obs)
                for i in range(self.batch_size):
                    ask_a_list[i] = max(0, ask_a_list[i])
                    if ask_a_list[i] == self.ask_actions.index('request_help'):
                        last_ask[i] = time_step

            should_encode_instruction = False
            anna_messages = [None] * self.batch_size

            for i in range(self.batch_size):

                # Perfect language instruction interpretation
                if self.hparams.perfect_interpretation and obs[i][
                        'mode'] == 'on_route':
                    nav_a_list[i] = max(0, nav_target_list[i])

                # If request
                if ask_a_list[i] == self.ask_actions.index('request_help'):
                    # Query ANNA for route instruction and departure node
                    anna_messages[i] = self.anna(obs[i])
                    # Agent should not move
                    nav_a_list[i] = 0
                    # Teacher nav action should be ignored
                    nav_target_list[i] = -1
                    # Need to re-encode the instruction.
                    should_encode_instruction = True
                else:
                    # If agent decides to depart route, re-encode instruction
                    if nav_a_list[i] == 0 and obs[i]['mode'] == 'on_route':
                        should_encode_instruction = True

            nav_pos_targets.append(np.array(nav_target_list, dtype=np.int64))
            nav_logits.append(nav_logit)

            nav_logit_list = nav_logit.tolist()
            for i in range(self.batch_size):
                info_list[i].append({
                    'ob':
                    obs[i],
                    'nav_dist':
                    nav_dist_list[i],
                    'nav_target':
                    nav_target_list[i],
                    'nav_a':
                    nav_a_list[i],
                    'nav_argmax':
                    int(np.argmax(nav_logit_list[i])),
                    'ask_a':
                    ask_a_list[i],
                    'num_a':
                    int(nav_logit.size(1))
                })

            # Retrieve embedding of the taken nav action.
            nav_a = nav_a_embeds[np.arange(self.batch_size),
                                 nav_a_list, :].detach()

            # Update ask action mask
            ask_a = torch.tensor(ask_a_list, dtype=torch.long, device=device)
            self.model.ask_module.update_action_mask(
                ask_a != self.ask_actions.index('request_help'))

            adj_loc_lists = [ob['adj_loc_list'] for ob in obs]

            # Take the nav action.
            obs = self.env.step(nav_a_list, anna_messages)

            unaligned_nav_dist = F.softmax(nav_logit, dim=1).tolist()

            # Book-keeping
            for i, ob in enumerate(obs):
                if not ended[i]:
                    traj[i]['agent_pose'].append(
                        (ob['viewpoint'], ob['heading'], ob['elevation'], i))
                    traj[i]['agent_mode'].append(ob['mode'])
                    traj[i]['instruction'].append(ob['instruction'])
                    traj[i]['target_viewpoints'].append(
                        ob['target_viewpoints'])
                    traj[i]['message'].append(anna_messages[i])
                    traj[i]['time_on_task'].append(ob['time_on_task'])
                    traj[i]['time'].append(ob['time'])
                    traj[i]['adj_loc_list'].append(adj_loc_lists[i])

                    traj[i]['teacher_nav'].append(nav_target_list[i])

                    if self.hparams.ask_baseline is None:
                        agent_reasons = []
                        out_str = []
                        for k, prob in enumerate(ask_reason_prob_list[i]):
                            label = AskTeacher.reason_labels[k]
                            out_str.append('%s %.1f' % (label[0], prob * 100))
                            if prob >= 0.5:
                                agent_reasons.append(label)
                        traj[i]['agent_reason'].append(agent_reasons)
                        out_str = ' '.join(out_str)
                        traj[i]['agent_reason_prob'].append(out_str)
                    else:
                        traj[i]['teacher_ask'].append(ask_a_list[i])
                        traj[i]['teacher_reason'].append(ask_reason[i])

                    traj[i]['agent_ask'].append(ask_a_list[i])

                    prob_str = ' '.join([
                        ('%d-%.2f' % (loc['absViewIndex'], x))
                        for loc, x in zip(adj_loc_lists[i],
                                          unaligned_nav_dist[i])
                    ])
                    if ask_a_list[i] == self.ask_actions.index('request_help'):
                        traj[i]['agent_nav'].append(-1)
                        traj[i]['nav_prob'].append(prob_str)
                    else:
                        traj[i]['agent_nav'].append(nav_a_list[i])
                        traj[i]['nav_prob'].append(
                            '%d %.2f %s' %
                            (nav_a_list[i],
                             unaligned_nav_dist[i][nav_a_list[i]], prob_str))

                    ended[i] |= ob['ended']

            if all(ended):
                break

        for i in range(self.batch_size):
            info_list[i].append({'ob': obs[i]})

        # RETROSPECTIVE navigation teacher
        # Look back at the trajectory and decide when the agent should have requested
        if self.hparams.ask_baseline is None:
            ask_targets, ask_reason_targets, ask_reasons = \
                self.teacher.all_ask(info_list)
            for t, target, reason in zip(traj, ask_targets, ask_reasons):
                l = len(t['agent_ask'])
                t['teacher_ask'] = target[:l].tolist()
                t['teacher_reason'] = reason[:l]

        nav_neg_targets, neg_offsets = \
            self.teacher.all_neg_nav(info_list)

        if not self.is_eval:
            # Help-request loss
            if self.hparams.ask_baseline is None:
                # seq_len x batch
                ask_targets = self.from_numpy(ask_targets.transpose())
                ask_reason_targets = self.from_numpy(
                    ask_reason_targets.swapaxes(0, 1)).float()

                for ask_logit, ask_target, ask_reason_logit, ask_reason_target \
                    in zip(ask_logits, ask_targets, ask_reason_logits,
                    ask_reason_targets):

                    # Ask loss
                    ask_loss = self.ask_criterion(ask_logit, ask_target)

                    # Ask reason loss
                    ask_reason_loss = self.ask_reason_criterion(
                        ask_reason_logit, ask_reason_target)
                    ask_reason_loss = ask_reason_loss.mean(dim=-1)
                    mask = (ask_target != -1)
                    normalizer = mask.sum().item()
                    if normalizer > 0:
                        ask_reason_loss = (ask_reason_loss * mask.float()).sum() / \
                            normalizer
                    else:
                        ask_reason_loss = 0.

                    if self.hparams.no_reason:
                        self.ask_loss += ask_loss
                    else:
                        self.ask_loss += ask_loss + ask_reason_loss

            # Navigation loss
            nav_pos_targets = self.from_numpy(np.stack(nav_pos_targets))
            neg_offsets = self.from_numpy(neg_offsets)

            for nav_logit, nav_pos_target, nav_neg_target, neg_offset in \
                zip(nav_logits, nav_pos_targets, nav_neg_targets, neg_offsets):

                # -log P(a+)
                nav_pos_loss = self.nav_criterion(nav_logit, nav_pos_target)

                # K = number of negative actions
                # 1/K sum -log P(a-)
                # embedding_bag's default 'mean' mode averages the selected
                # log-softmax rows per bag, giving the 1/K factor directly.
                nav_log_softmax = F.log_softmax(nav_logit, dim=1).view(-1, 1)
                nav_neg_target = self.from_numpy(nav_neg_target)
                nav_neg_loss = -F.embedding_bag(
                    nav_neg_target, nav_log_softmax, neg_offset).squeeze(1)

                mask = (nav_pos_target != -1)
                normalizer = mask.sum().item()
                if normalizer > 0:
                    nav_neg_loss = (nav_neg_loss *
                                    mask.float()).sum() / normalizer
                else:
                    nav_neg_loss = 0.

                # nav_loss = -log P(a+) + alpha/K sum log P(a-)
                self.nav_loss += nav_pos_loss - self.hparams.alpha * nav_neg_loss

            self._compute_loss()

        return traj
示例#21
0
 def forward(self, input):
     """Pooled lookup against the fake-quantized weight (embedding_bag defaults)."""
     fq_weight = self.weight_fake_quant(self.weight)
     return F.embedding_bag(input, fq_weight)