def forward(self, input, offsets=None):
    """Concatenate mean-pooled and max-pooled bag embeddings.

    Args:
        input: 1-D tensor of indices (with `offsets` delimiting bags) or a
            2-D tensor where each row is one bag.
        offsets: bag start positions when `input` is 1-D.

    Returns:
        Tensor of shape (num_bags, 2 * embedding_dim): the mean-pooled and
        max-pooled embeddings concatenated along dim 1.

    Note: the original code passed arguments in the legacy
    ``F.embedding_bag(weight, input, ...)`` order, which PyTorch only accepts
    via a deprecated auto-swap path (with a warning). The modern
    ``(input, weight, ...)`` order is used here; behavior is unchanged.
    """
    out_mean = F.embedding_bag(input, self.weight, offsets, self.max_norm,
                               self.norm_type, self.scale_grad_by_freq,
                               'mean')
    out_max = F.embedding_bag(input, self.weight, offsets, self.max_norm,
                              self.norm_type, self.scale_grad_by_freq,
                              'max')
    return torch.cat([out_mean, out_max], 1)
def forward(self, x, offsets=None):
    """Bag-pool embeddings for the index bags in `x` using a hashed weight table."""
    if not self.lens:
        table = self.hashed_weight
    else:
        # NOTE(review): this branch reads a module-level `hashed_weight`
        # global (the original had a commented-out `global` statement);
        # presumably it is a table shared across layers — confirm it is
        # defined at call time.
        table = hashed_weight
    return F.embedding_bag(x, table[self.weight_idx],
                           offsets=offsets, mode=self.mode)
def forward(self, feats_in_l, idx_targets, sizes_subg):
    """Combine per-layer subgraph features into one prediction per subgraph.

    Args:
        feats_in_l: list of per-layer feature matrices; rows cover all nodes
            of the batched subgraphs (last entry = deepest layer).
        idx_targets: row indices of the target (root) node of each subgraph.
        sizes_subg: 1-D tensor of subgraph sizes, used to build bag offsets.

    Returns:
        self.f_norm(self.nn(feat_in)) where feat_in depends on
        self.type_pool ('center' / 'max' / 'mean' / 'sum' / 'sort') and
        self.type_res ('none' or a residual/JK combiner in self.f_residue).
    """
    if self.type_pool == 'center':
        if self.type_res == 'none':
            # No residue: the answer is just the root-node feature itself.
            return feats_in_l[-1][idx_targets]
        else:  # regular JK
            feats_root_l = [f[idx_targets] for f in feats_in_l]
            feat_in = self.f_residue(feats_root_l)
    elif self.type_pool in ['max', 'mean', 'sum']:
        # first pool subgraph at each layer, then residue
        # Exclusive prefix-sum of subgraph sizes -> embedding_bag offsets.
        offsets = torch.cumsum(sizes_subg, dim=0)
        offsets = torch.roll(offsets, 1)
        offsets[0] = 0
        # Identity indices: embedding_bag then pools consecutive rows per bag.
        idx = torch.arange(feats_in_l[-1].shape[0]).to(
            feats_in_l[-1].device)
        if self.type_res == 'none':
            feat_pool = F.embedding_bag(idx, feats_in_l[-1], offsets,
                                        mode=self.type_pool)
            feat_root = feats_in_l[-1][idx_targets]
        else:
            feat_pool_l = []
            for feat in feats_in_l:
                feat_pool = F.embedding_bag(idx, feat, offsets,
                                            mode=self.type_pool)
                feat_pool_l.append(feat_pool)
            feat_pool = self.f_residue(feat_pool_l)
            feat_root = self.f_residue(
                [f[idx_targets] for f in feats_in_l])
        feat_in = torch.cat([feat_root, feat_pool], dim=1)
    elif self.type_pool == 'sort':
        if self.type_res == 'none':
            feat_pool_in = feats_in_l[-1]
            feat_root = feats_in_l[-1][idx_targets]
        else:
            feat_pool_in = self.f_residue(feats_in_l)
            feat_root = self.f_residue(
                [f[idx_targets] for f in feats_in_l])
        # Map each node row to its subgraph id for the sort-pool helper.
        arange = torch.arange(sizes_subg.size(0)).to(sizes_subg.device)
        idx_batch = torch.repeat_interleave(arange, sizes_subg)
        feat_pool_k = global_sort_pool(feat_pool_in, idx_batch, self.k)  # #subg x (k * F)
        feat_pool = self.nn_pool(feat_pool_k)
        feat_in = torch.cat([feat_root, feat_pool], dim=1)
    else:
        raise NotImplementedError
    return self.f_norm(self.nn(feat_in))
def get(self, input_: TensorList) -> FloatTensorType:
    """Sparse bag lookup for the entries of `input_` (default mean pooling).

    Returns an empty (0, dim) tensor when no rows are requested, otherwise
    a mean-pooled embedding per bag delimited by `input_.offsets[:-1]`.
    """
    if input_.size(0) == 0:
        # Nothing to look up: hand back a correctly shaped empty result.
        return torch.empty((0, self.weight.size(1)))
    indices = input_.data.long()
    bag_offsets = input_.offsets[:-1]
    return F.embedding_bag(
        indices,
        self.weight,
        bag_offsets,
        max_norm=self.max_norm,
        sparse=True,
    )
def forward(self, input: Tensor, offsets: Optional[Tensor] = None,
            per_sample_weights: Optional[Tensor] = None) -> Tensor:
    """Embedding-bag forward through the quantize-then-dequantize weight.

    The weight is obtained from self.get_weight() (quant/dequant simulation);
    all other options come from the module's configuration.
    """
    simulated_weight = self.get_weight()
    return F.embedding_bag(input, simulated_weight, offsets,
                           max_norm=self.max_norm,
                           norm_type=self.norm_type,
                           scale_grad_by_freq=self.scale_grad_by_freq,
                           mode=self.mode,
                           sparse=self.sparse,
                           per_sample_weights=per_sample_weights,
                           include_last_offset=self.include_last_offset,
                           padding_idx=self.padding_idx)
def forward(self, input, offsets=None, per_sample_weights=None, params=None):
    """Functional embedding-bag forward with optionally overridden parameters.

    When `params` is None the module's own named parameters are used;
    otherwise `params['weight']` supplies the lookup table (meta-learning
    style parameter injection).
    """
    if params is None:
        params = OrderedDict(self.named_parameters())
    lookup_table = params['weight']
    return F.embedding_bag(input, lookup_table, offsets, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq,
                           self.mode, self.sparse, per_sample_weights,
                           self.include_last_offset)
def merge_piece_emb(self, pieces, inputs) -> torch.Tensor:
    """Overwrite word slots in `inputs` with the mean embedding of their pieces.

    `pieces` maps (sentence, word) positions to tensors of word-piece ids;
    each slot receives the mean (embedding_bag default) of those piece
    embeddings from the BERT word-embedding table. Mutates and returns
    `inputs`.
    """
    zero_offset = torch.tensor([0], dtype=torch.long)
    piece_table = self.bert.embeddings.word_embeddings.weight
    for (sent, word), piece_ids in pieces.items():
        pooled = embedding_bag(piece_ids, piece_table,
                               zero_offset.to(piece_ids.device))
        inputs[sent, word, :] = pooled
    return inputs
def aggregate_by_embbag(self, weight, key_index, mode):
    """Pool consecutive rows of `weight` into bags delimited by key_index.start.

    Every row participates exactly once (input is arange over rows);
    `key_index.start` provides the bag start offsets. Returns one pooled
    row per bag, combined according to `mode`.
    """
    bag_offsets = torch.tensor(
        key_index.start, dtype=torch.long, device=self.device)
    row_ids = torch.arange(
        weight.shape[0], dtype=torch.long, device=self.device)
    return F.embedding_bag(
        weight=weight, input=row_ids, offsets=bag_offsets, mode=mode)
def forward(self, x, offsets=None):
    """Sum of embedding-bag lookups over every hashed-weight index set.

    Each entry of self.weight_idx_list selects a row subset of
    self.hashed_weight; the bag lookups over those tables are summed.
    """
    # The first table is always consulted; remaining ones accumulate on top.
    result = F.embedding_bag(x,
                             self.hashed_weight[self.weight_idx_list[0], :],
                             offsets=offsets, mode=self.mode,
                             sparse=self.sparse)
    for step in range(1, self.update_count):
        table = self.hashed_weight[self.weight_idx_list[step], :]
        result = result + F.embedding_bag(x, table, offsets=offsets,
                                          mode=self.mode, sparse=self.sparse)
    return result
def forward(self, x, offsets=None):
    """Bag lookup into the hashed-weight rows selected by self.weight_idx."""
    lookup_table = self.hashed_weight[self.weight_idx, :]
    return F.embedding_bag(x, lookup_table, offsets=offsets,
                           mode=self.mode, sparse=self.sparse)
def forward(self):
    """Exercise F.embedding, F.embedding_bag and F.one_hot on fixed inputs.

    Returns:
        int: the number of results produced (always 3).

    Note: the original code called ``len`` with three arguments, which is a
    ``TypeError`` — ``len`` accepts exactly one argument. The three op
    results are now gathered into a tuple first, preserving the intent of
    exercising all three ops and returning their count.
    """
    input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    input2 = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    embedding_matrix = torch.rand(10, 3)
    offsets = torch.tensor([0, 4])
    results = (
        F.embedding(input, embedding_matrix),
        F.embedding_bag(input2, embedding_matrix, offsets),
        F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
    )
    return len(results)
def forward(self, input, offsets=None, per_sample_weights=None):
    """Quotient-remainder (QR) embedding-bag lookup.

    Each index is split into a quotient and a remainder with respect to
    self.num_collisions; the two partial embeddings (tables weight_q and
    weight_r) are combined according to self.operation:
    'concat' (dim 1), 'add', or 'mult'.

    Raises:
        ValueError: if self.operation is not one of the supported modes
            (previously this fell through and raised an opaque
            UnboundLocalError on the return statement).
    """
    input_q = (input // self.num_collisions).long()
    input_r = torch.remainder(input, self.num_collisions).long()
    embed_q = F.embedding_bag(input_q, self.weight_q, offsets,
                              self.max_norm, self.norm_type,
                              self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    embed_r = F.embedding_bag(input_r, self.weight_r, offsets,
                              self.max_norm, self.norm_type,
                              self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    if self.operation == 'concat':
        embed = torch.cat((embed_q, embed_r), dim=1)
    elif self.operation == 'add':
        embed = embed_q + embed_r
    elif self.operation == 'mult':
        embed = embed_q * embed_r
    else:
        raise ValueError(
            "operation must be 'concat', 'add' or 'mult', got %r"
            % (self.operation,))
    return embed
def test_embedding_bag(self):
    """Smoke test: mean-mode F.embedding_bag on random CUDA inputs runs cleanly."""
    vocab_size = 1024
    embed_dim = 32
    indices = torch.randint(0, vocab_size, (128, 16), device='cuda')
    table = torch.randn(vocab_size, embed_dim,
                        device='cuda', dtype=self.dtype)
    output = F.embedding_bag(indices, table, offsets=None, max_norm=None,
                             norm_type=2, scale_grad_by_freq=False,
                             mode='mean', sparse=False)
def __init__(self,
             name_or_path: str,
             freeze: str = 'all',
             layer_num: int = 1,  # how many layers to take, counting back from the last one
             transform_dim: int = 0,  # target dim each layer is projected down to (0 = no projection)
             scalar_mix: Dict[str, Any] = None,
             word_piece: str = 'first',  # requires the input ids to be the first input
             **kwargs):
    """BERT-based word embedder.

    Loads a pretrained BertModel, optionally freezes it, optionally adds a
    per-layer projection, and configures how word-piece embeddings are
    pooled into word embeddings ('first' piece vs mean of all pieces).
    """
    super().__init__()
    self.bert = BertModel.from_pretrained(name_or_path)
    # Request hidden states from every layer (set on both encoder and config).
    self.bert.encoder.output_hidden_states = True
    self.bert.config.output_hidden_states = True
    self.index = 2
    self.layer_num = layer_num
    self.output_dim = self.bert.config.hidden_size
    if freeze == 'all':
        # Freeze every BERT parameter so only downstream layers train.
        for param in self.bert.parameters():
            param.requires_grad = False
    if transform_dim > 0:
        # One non-linear projection per selected layer.
        self.word_transform = ModuleList([NonLinear(
            self.output_dim, transform_dim) for _ in range(self.layer_num)])
        self.output_dim = transform_dim
    else:
        self.word_transform = None
    if word_piece == 'first':
        self.word_piece = None
    else:  # mean of pieces
        # Single-bag offset: embedding_bag with offsets=[0] mean-pools all
        # piece ids in x into one vector.
        offset = torch.tensor([0], dtype=torch.long)
        # self.register_buffer("offset", offset)
        self.word_piece = lambda x: embedding_bag(
            x, self.bert.embeddings.word_embeddings.weight,
            offset.to(x.device))
    self.scalar_mix = None if scalar_mix is None else ScalarMixWithDropout(
        layer_num, **scalar_mix)
    if layer_num == 1:
        # With a single layer there is nothing to mix; just take it.
        self.scalar_mix = lambda x, *args: x[0]
def forward(self, features, mask):
    """Pool token features along the token axis (second-to-last).

    `mask` is either a (begins, ends) tuple of span boundaries or a boolean
    tensor. self.mode selects the pooling flavor: "mean", "sum", "max",
    "first", "last", "abs-max", "softmax", "softmax-abs" or "attention".
    """
    device = features.device
    if self.mode == "attention" and isinstance(mask, tuple):
        # Convert (begin, end) spans to a boolean position mask so the
        # attention path can reuse the tensor-mask branch below.
        position = torch.arange(features.shape[-2], device=device).reshape([1] * (features.ndim - 2) + [features.shape[-2]])
        mask = (mask[0].unsqueeze(-1) <= position) & (position < mask[1].unsqueeze(-1))
        features = features.unsqueeze(-3)
    if isinstance(mask, tuple):
        # Span pooling via F.embedding_bag over flattened window indices.
        original_dtype = features.dtype
        if features.dtype == torch.int or features.dtype == torch.long:
            # embedding_bag requires floating-point weights.
            features = features.float()
        begins, ends = mask
        if self.mode == "first":
            # Shrink each span to its first token.
            ends = torch.minimum(begins + 1, ends)
        if self.mode == "last":
            # Shrink each span to its last token.
            begins = torch.maximum(ends - 1, begins)
        begins = begins.expand(*features.shape[:begins.ndim - 1], begins.shape[-1]).clamp_min(0)
        ends = ends.expand(*features.shape[:begins.ndim - 1], ends.shape[-1]).clamp_min(0)
        final_shape = (*begins.shape, *features.shape[begins.ndim:])
        features = features.view(-1, features.shape[-2], features.shape[-1])
        begins = begins.reshape(features.shape[0], begins.numel() // features.shape[0] if len(features) else 0)
        ends = ends.reshape(features.shape[0], ends.numel() // features.shape[0] if len(features) else 0)
        # Widest span determines the index window; 0 when any dim is empty.
        max_window_size = max(0, int((ends - begins).max())) if 0 not in ends.shape else 0
        flat_indices = torch.arange(max_window_size, device=device)[None, None, :] + begins[..., None]
        flat_indices_mask = flat_indices < ends[..., None]
        # Offset indices by the batch row so they address the flattened table.
        flat_indices += torch.arange(len(flat_indices), device=device)[:, None, None] * features.shape[1]
        flat_indices = flat_indices[flat_indices_mask]
        res = F.embedding_bag(
            input=flat_indices,
            weight=self.dropout(features.reshape(-1, features.shape[-1])),
            # Exclusive prefix sum of per-span counts -> bag offsets.
            offsets=torch.cat([torch.tensor([0], device=device), flat_indices_mask.sum(-1).reshape(-1)]).cumsum(0)[:-1].clamp_max(flat_indices.shape[0]),
            # "first"/"last" spans were already reduced to one token, so any
            # reducing mode works; "max" is used for them.
            mode=self.mode if self.mode not in ("first", "last") else "max",
        ).reshape(final_shape)
        if res.dtype != original_dtype:
            # Restore integer dtype if the input was int/long.
            res = res.type(original_dtype)
        return res
    elif torch.is_tensor(mask):
        features = features
        features = self.dropout(features)
        if self.mode == "first":
            # Keep only the first True position of each mask run.
            # NOTE(review): `shift` is a project helper — presumably a
            # roll/pad shift along `dim`; confirm its padding semantics.
            mask = ~shift(mask.long(), n=1, dim=-1).cumsum(-1).bool() & mask
        elif self.mode == "last":
            # Same trick on the reversed mask keeps only the last position.
            mask = mask.flip(-1)
            mask = (~shift(mask.long(), n=1, dim=-1).cumsum(-1).bool() & mask).flip(-1)
        if mask.ndim <= features.ndim - 1:
            mask = mask.unsqueeze(-1)
        if 0 in mask.shape:
            # Degenerate empty mask: sum over tokens yields zeros of the
            # right shape.
            return features.sum(-2)
        if self.mode == "attention":
            weights = self.key_proj(features).masked_fill(~mask, -100000).softmax(-2)  # ... tokens heads
            values = self.value_proj(features) if self.value_proj is not None else features
            values = values.view(*values.shape[:-1], weights.shape[-1], -1)  # ... tokens heads dim
            res = torch.einsum('...nhd,...nh->...hd', values, weights)
            return res.view(*res.shape[:-2], -1)
        elif self.mode == "max":
            # Masked positions get -100000 so they never win the max; bags
            # with no valid position are zeroed afterwards.
            features = features.masked_fill(~mask, -100000).max(-2).values.masked_fill(~(mask.any(-2)), 0)
        elif self.mode == "abs-max":
            # Pick the element with the largest magnitude, keeping its sign.
            values, indices = features.abs().masked_fill(~mask, -100000).max(-2)
            features = features.gather(dim=-2, index=indices.unsqueeze(1)).squeeze(1)
        elif self.mode in ("sum", "mean", "first", "last"):
            features = features.masked_fill(~mask, 0).sum(-2)
            if self.mode == "mean":
                # clamp_min avoids division by zero for empty bags.
                features = features / mask.float().sum(-2).clamp_min(1.)
        elif self.mode == "softmax":
            # Weights come from detached activations scaled by alpha.
            weights = (features.detach() * self.alpha).masked_fill(~mask, -100000).softmax(-2)
            features = torch.einsum('...nd,...nd->...d', weights, features.masked_fill(~mask, 0))
        elif self.mode == "softmax-abs":
            weights = (features.detach().abs() * self.alpha).masked_fill(~mask, -100000).softmax(-2)
            features = torch.einsum('...nd,...nd->...d', weights, features.masked_fill(~mask, 0))
        return features
def forward(self, batch_idx, pairs, negs, offsets, lists):
    """Skip-gram style loss over multi-aspect node embeddings.

    Args:
        batch_idx: batch index (unused in the body).
        pairs: (N, 2) tensor of (center, context) node ids.
        negs: negative-sample node ids per pair.
        offsets: bag offsets into `lists` for the context pooling.
        lists: flattened context node ids for embedding_bag pooling.

    Returns:
        The scalar loss, or (loss, regularization_value) when self.isReg.
    """
    centers = pairs[:, 0]
    embed_centers = self.center_embedding(centers)  # N x dim
    # Pool each aspect's slice of the aspect-embedding table over the
    # context lists: one pooled context vector per aspect.
    embed_contexts_means = torch.stack([
        F.embedding_bag(
            lists,
            self.aspect_embedding.weight[k * self.num_nodes:(k + 1) * self.num_nodes],
            offsets,
            mode=self.pooling) for k in range(self.num_aspects)
    ], 1)  # N x K x dim
    if self.isSoftmax:
        # Apply softmax
        aspect_softmax = torch.bmm(embed_contexts_means,
                                   embed_centers.unsqueeze(-1)).squeeze(
                                       -1)  # N x K
        # 1-1. Gumbel Softmax
        if self.isGumbelSoftmax:
            # In fact, following the original Gumbel-softmax, the input for
            # F.gumbel_softmax() should be logit (i.e., unnormalized log
            # probabilities.) However, we found that unnormalized
            # probabilities without log are numerically more stable, and
            # performs on par with logit.
            aspect_softmax = F.gumbel_softmax(aspect_softmax,
                                              tau=self.tau_gumbel,
                                              hard=self.isHard)
        elif self.isNormalSoftmax:
            # 1-2. Softmax
            aspect_softmax = F.softmax(aspect_softmax, dim=1)
    contexts = pairs[:, 1]
    # Context id of each pair, repeated once per aspect slice.
    total_contexts_idxs = torch.cat([
        k * self.num_nodes + contexts.unsqueeze(1)
        for k in range(self.num_aspects)
    ], 1)
    aspect_embedding_context = self.aspect_embedding(
        total_contexts_idxs)  # N x K x dim
    score_pos = torch.bmm(
        aspect_embedding_context, embed_centers.unsqueeze(-1)).squeeze(
            -1)  # (N x K x dim) x (N x dim x 1) = (N x K)
    score_pos = -F.logsigmoid(score_pos)
    embed_contexts_negs = [
        self.aspect_embedding(k * self.num_nodes + negs)
        for k in range(self.num_aspects)
    ]  # [N x num_neg x dim] * K
    score_negs = [
        torch.bmm(embed_contexts_neg,
                  embed_centers.unsqueeze(-1)).squeeze(-1)
        for k, embed_contexts_neg in enumerate(embed_contexts_negs)
    ]
    # Negative-sampling term, summed over negatives per aspect.
    score_neg = torch.stack([
        -torch.sum(F.logsigmoid(-score_neg), dim=1)
        for score_neg in score_negs
    ], 1)
    if self.isSoftmax:
        # Weight per-aspect losses by the (gumbel-)softmax assignment.
        sg_loss = aspect_softmax * (score_pos + score_neg)
    else:
        # Uniform average over aspects.
        sg_loss = (score_pos + score_neg) / self.num_aspects
    sg_loss = torch.mean(sg_loss)
    final_loss = sg_loss
    # Aspect regularization
    if self.isReg:
        div_metric = None
        # N x K x dim
        aspect_emb_reshaped = self.aspect_embedding.weight.view(
            self.num_aspects, self.num_nodes, self.dim).permute(
                1, 0, 2).contiguous()
        # Penalize aspect pairs whose per-node embeddings are too similar
        # (|cosine| above self.threshold).
        for i in range(self.num_aspects):
            for j in range(i + 1, self.num_aspects):
                sim_matrix = F.cosine_similarity(
                    aspect_emb_reshaped[:, i, :],
                    aspect_emb_reshaped[:, j, :])
                mask = torch.abs(sim_matrix) > self.threshold
                if i == 0 and j == 1:
                    div_metric = (torch.abs(
                        torch.masked_select(sim_matrix, mask))).sum()
                else:
                    div_metric += (torch.abs(
                        torch.masked_select(sim_matrix, mask))).sum()
        div_reg = self.reg_coef * div_metric
        final_loss += div_reg
        # NOTE(review): the isReg path returns a (loss, value) tuple while
        # the plain path returns just the loss — callers must handle both.
        return final_loss, (self.reg_coef * div_metric).item()
    return final_loss
def forward(self, x):
    """Mean-pool (embedding_bag default) the hashed-weight rows for bags `x`."""
    lookup_table = self.hashed_weight[self.weight_idx]
    return F.embedding_bag(x, lookup_table)
def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
    """Embedding-bag forward through the fake-quantized weight (QAT path)."""
    observed_weight = self.weight_fake_quant(self.weight)
    return F.embedding_bag(input, observed_weight, offsets,
                           max_norm=self.max_norm,
                           norm_type=self.norm_type,
                           scale_grad_by_freq=self.scale_grad_by_freq,
                           mode=self.mode,
                           sparse=self.sparse,
                           per_sample_weights=per_sample_weights,
                           include_last_offset=self.include_last_offset,
                           padding_idx=self.padding_idx)
def forward(self, input, offsets=None, per_sample_weights=None):
    """Quantize the weight, run embedding_bag, then quantize the activations."""
    quantized_weight = self.weight_quantizer(self.weight)
    bagged = F.embedding_bag(input, quantized_weight, offsets,
                             self.max_norm, self.norm_type,
                             self.scale_grad_by_freq, self.mode,
                             self.sparse, per_sample_weights,
                             self.include_last_offset)
    return self.activation_quantizer(bagged)
def rollout(self):
    """Run one batched episode of the navigation-with-help agent.

    Resets the environment, steps the navigation and help-request (ask)
    policies for up to self.episode_len steps while recording a trajectory
    per batch element, then (when training) accumulates the ask / ask-reason
    / navigation losses onto self.ask_loss / self.ask_reason_loss /
    self.nav_loss and calls self._compute_loss().

    Returns:
        list[dict]: one trajectory record per batch element.
    """
    # Reset environment.
    obs = self.env.reset()
    self.batch_size = len(obs)
    # Trajectory history.
    traj = [{
        'scan': ob['scan'],
        'instr_id': ob['instr_id'],
        'agent_pose': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
        'agent_mode': ['main'],
        'agent_ask': [],
        'agent_nav': [],
        'instruction': [ob['instruction']],
        'target_viewpoints': [ob['target_viewpoints']],
        'nav_prob': [],
        'message': [],
        'time_on_task': [0],
        'time': [0],
        'teacher_nav': [],
        'teacher_ask': [],
        'teacher_reason': [],
        'agent_reason': [],
        'agent_reason_prob': [],
        'adj_loc_list': []
    } for ob in obs]
    # Initial decoder states.
    nav_a, ask_a = self.model.reset(self.batch_size)
    ended = [False] * self.batch_size
    should_encode_instruction = True
    info_list = [[] for _ in range(self.batch_size)]
    nav_logits, ask_logits, ask_reason_logits = [], [], []
    nav_pos_targets = []
    last_ask = [-1] * self.batch_size
    device = nav_a.device
    self.nav_loss = torch.tensor(0., device=device)
    self.ask_loss = torch.tensor(0., device=device)
    self.ask_reason_loss = torch.tensor(0., device=device)
    for time_step in range(self.episode_len):
        # Encode instruction (only when it changed since the last step).
        if should_encode_instruction:
            ctx_seq, ctx_mask = self._text_context_variable(obs)
            nav_ctx, ask_ctx = self.model.encode(ctx_seq, ctx_mask)
            if not self.hparams.no_reset_inter:
                self.model.reset_text_decoder(self.batch_size)
        # Masks
        nav_a_embeds, nav_logit_mask = self._nav_action_variable(obs)
        ask_logit_mask = self._ask_action_variable(obs)
        # Visual features
        curr_view_features, goal_view_features = \
            self._visual_feature_variable(obs)
        # Time feature
        time = self.from_numpy(np.array([ob['time'] for ob in obs]))
        time_on_task = self.from_numpy(
            np.array([ob['time_on_task'] for ob in obs]))
        # Query nav policy
        nav_logit = self.model.decode_nav(time, time_on_task, nav_a,
                                          nav_a_embeds, nav_ctx, ctx_mask,
                                          curr_view_features,
                                          goal_view_features,
                                          nav_logit_mask)
        # Query nav teacher
        nav_target_list = self.teacher.next_nav(obs)
        # Sample/argmax the nav action per self.nav_feedback.
        nav_a = self._next_action(nav_logit, self.nav_feedback)
        nav_a_list = nav_a.tolist()
        # Compute distribution
        nav_dist = self._compute_nav_dist(obs, nav_logit)
        nav_dist_list = nav_dist.tolist()
        if self.hparams.ask_baseline is None:
            # Query ask policy
            ask_logit, ask_reason_logit = self.model.decode_ask(
                time, time_on_task, ask_a, nav_dist, ask_ctx, ctx_mask,
                curr_view_features, goal_view_features, ask_logit_mask)
            ask_logits.append(ask_logit)
            ask_a = self._next_action(ask_logit, self.ask_feedback)
            ask_a_list = ask_a.tolist()
            ask_reason_logits.append(ask_reason_logit)
            ask_reason_prob_list = torch.sigmoid(ask_reason_logit).tolist()
        else:
            # Query ask teacher
            for i, ob in enumerate(obs):
                ob['last_ask'] = last_ask[i]
            ask_a_list, ask_reason = self.teacher.next_ask(obs)
            for i in range(self.batch_size):
                # -1 (no-op) targets are mapped to action 0.
                ask_a_list[i] = max(0, ask_a_list[i])
                if ask_a_list[i] == self.ask_actions.index('request_help'):
                    last_ask[i] = time_step
        # Reset; any request below turns this back on for the next step.
        should_encode_instruction = False
        anna_messages = [None] * self.batch_size
        for i in range(self.batch_size):
            # Perfect language instruction interpretation
            if self.hparams.perfect_interpretation and obs[i][
                    'mode'] == 'on_route':
                nav_a_list[i] = max(0, nav_target_list[i])
            # If request
            if ask_a_list[i] == self.ask_actions.index('request_help'):
                # Query ANNA for route instruction and departure node
                anna_messages[i] = self.anna(obs[i])
                # Agent should not move
                nav_a_list[i] = 0
                # Teacher nav action should be ignored
                nav_target_list[i] = -1
                # Need to re-encode the instruction.
                should_encode_instruction = True
            else:
                # If agent decides to depart route, re-encode instruction
                if nav_a_list[i] == 0 and obs[i]['mode'] == 'on_route':
                    should_encode_instruction = True
        nav_pos_targets.append(np.array(nav_target_list, dtype=np.int64))
        nav_logits.append(nav_logit)
        nav_logit_list = nav_logit.tolist()
        for i in range(self.batch_size):
            info_list[i].append({
                'ob': obs[i],
                'nav_dist': nav_dist_list[i],
                'nav_target': nav_target_list[i],
                'nav_a': nav_a_list[i],
                'nav_argmax': int(np.argmax(nav_logit_list[i])),
                'ask_a': ask_a_list[i],
                'num_a': int(nav_logit.size(1))
            })
        # Retrieve embedding of the taken nav action.
        nav_a = nav_a_embeds[np.arange(self.batch_size),
                             nav_a_list, :].detach()
        # Update ask action mask
        ask_a = torch.tensor(ask_a_list, dtype=torch.long, device=device)
        self.model.ask_module.update_action_mask(
            ask_a != self.ask_actions.index('request_help'))
        adj_loc_lists = [ob['adj_loc_list'] for ob in obs]
        # Take the nav action.
        obs = self.env.step(nav_a_list, anna_messages)
        unaligned_nav_dist = F.softmax(nav_logit, dim=1).tolist()
        # Book-keeping
        for i, ob in enumerate(obs):
            if not ended[i]:
                # NOTE(review): this appends a 4-tuple (with batch index i)
                # while the initial pose entry is a 3-tuple — confirm
                # downstream consumers handle both.
                traj[i]['agent_pose'].append(
                    (ob['viewpoint'], ob['heading'], ob['elevation'], i))
                traj[i]['agent_mode'].append(ob['mode'])
                traj[i]['instruction'].append(ob['instruction'])
                traj[i]['target_viewpoints'].append(
                    ob['target_viewpoints'])
                traj[i]['message'].append(anna_messages[i])
                traj[i]['time_on_task'].append(ob['time_on_task'])
                traj[i]['time'].append(ob['time'])
                traj[i]['adj_loc_list'].append(adj_loc_lists[i])
                traj[i]['teacher_nav'].append(nav_target_list[i])
                if self.hparams.ask_baseline is None:
                    # Record which ask-reasons crossed the 0.5 threshold.
                    agent_reasons = []
                    out_str = []
                    for k, prob in enumerate(ask_reason_prob_list[i]):
                        label = AskTeacher.reason_labels[k]
                        out_str.append('%s %.1f' % (label[0], prob * 100))
                        if prob >= 0.5:
                            agent_reasons.append(label)
                    traj[i]['agent_reason'].append(agent_reasons)
                    out_str = ' '.join(out_str)
                    traj[i]['agent_reason_prob'].append(out_str)
                else:
                    traj[i]['teacher_ask'].append(ask_a_list[i])
                    traj[i]['teacher_reason'].append(ask_reason[i])
                traj[i]['agent_ask'].append(ask_a_list[i])
                prob_str = ' '.join([
                    ('%d-%.2f' % (loc['absViewIndex'], x))
                    for loc, x in zip(adj_loc_lists[i],
                                      unaligned_nav_dist[i])
                ])
                if ask_a_list[i] == self.ask_actions.index('request_help'):
                    # A request step has no nav action to record.
                    traj[i]['agent_nav'].append(-1)
                    traj[i]['nav_prob'].append(prob_str)
                else:
                    traj[i]['agent_nav'].append(nav_a_list[i])
                    traj[i]['nav_prob'].append(
                        '%d %.2f %s' %
                        (nav_a_list[i],
                         unaligned_nav_dist[i][nav_a_list[i]], prob_str))
                ended[i] |= ob['ended']
        if all(ended):
            break
    for i in range(self.batch_size):
        info_list[i].append({'ob': obs[i]})
    # RETROSPECTIVE navigation teacher
    # Look back at the trajectory and decide when the agent should have requested
    if self.hparams.ask_baseline is None:
        ask_targets, ask_reason_targets, ask_reasons = \
            self.teacher.all_ask(info_list)
        for t, target, reason in zip(traj, ask_targets, ask_reasons):
            l = len(t['agent_ask'])
            t['teacher_ask'] = target[:l].tolist()
            t['teacher_reason'] = reason[:l]
    nav_neg_targets, neg_offsets = \
        self.teacher.all_neg_nav(info_list)
    if not self.is_eval:
        # Help-request loss
        if self.hparams.ask_baseline is None:
            # seq_len x batch
            ask_targets = self.from_numpy(ask_targets.transpose())
            ask_reason_targets = self.from_numpy(
                ask_reason_targets.swapaxes(0, 1)).float()
            for ask_logit, ask_target, ask_reason_logit, ask_reason_target \
                    in zip(ask_logits, ask_targets, ask_reason_logits,
                           ask_reason_targets):
                # Ask loss
                ask_loss = self.ask_criterion(ask_logit, ask_target)
                # Ask reason loss
                ask_reason_loss = self.ask_reason_criterion(
                    ask_reason_logit, ask_reason_target)
                ask_reason_loss = ask_reason_loss.mean(dim=-1)
                # Ignore steps whose ask target is -1.
                mask = (ask_target != -1)
                normalizer = mask.sum().item()
                if normalizer > 0:
                    ask_reason_loss = (ask_reason_loss *
                                       mask.float()).sum() / normalizer
                else:
                    ask_reason_loss = 0.
                if self.hparams.no_reason:
                    self.ask_loss += ask_loss
                else:
                    self.ask_loss += ask_loss + ask_reason_loss
        # Navigation loss
        nav_pos_targets = self.from_numpy(np.stack(nav_pos_targets))
        neg_offsets = self.from_numpy(neg_offsets)
        for nav_logit, nav_pos_target, nav_neg_target, neg_offset in \
                zip(nav_logits, nav_pos_targets, nav_neg_targets,
                    neg_offsets):
            # -log P(a+)
            nav_pos_loss = self.nav_criterion(nav_logit, nav_pos_target)
            # K = number of negative actions
            # 1/K sum -log P(a-)
            # embedding_bag mean-pools the per-action log-probs of each
            # negative bag delimited by neg_offset.
            nav_log_softmax = F.log_softmax(nav_logit, dim=1).view(-1, 1)
            nav_neg_target = self.from_numpy(nav_neg_target)
            nav_neg_loss = -F.embedding_bag(
                nav_neg_target, nav_log_softmax, neg_offset).squeeze(1)
            # Ignore steps whose positive target is -1.
            mask = (nav_pos_target != -1)
            normalizer = mask.sum().item()
            if normalizer > 0:
                nav_neg_loss = (nav_neg_loss *
                                mask.float()).sum() / normalizer
            else:
                nav_neg_loss = 0.
            # nav_loss = -log P(a+) + alpha/K sum log P(a-)
            self.nav_loss += nav_pos_loss - self.hparams.alpha * nav_neg_loss
        self._compute_loss()
    return traj
def forward(self, input):
    """Mean-mode (embedding_bag default) lookup through the fake-quantized weight."""
    observed_weight = self.weight_fake_quant(self.weight)
    return F.embedding_bag(input, observed_weight)