def log_probability(self, hard_mode=False):
    """Return one aggregated log-probability per row of this variable set.

    The per-object log attention is passed through ``util.log_parametric_not``
    with the set's quantifiers, pooled over the objects selected by
    ``self._batch_object_map`` (routed through ``self._predicate_question_map``
    when one is set), and passed through ``util.log_parametric_not`` again.

    Args:
        hard_mode: if True, pool with a membership-masked ``min`` over objects;
            otherwise pool with matrix products and take the diagonal.

    Returns:
        A 1-D tensor of log-probabilities.
    """
    quantifier = self._quantifier

    if not hard_mode:
        # Soft pooling: put objects on the rows so the maps can be applied with mm.
        objects_first = self._log_attention.transpose(0, 1).contiguous()
        pooled = util.mm(
            self._batch_object_map,
            util.log_parametric_not(objects_first, quantifier.unsqueeze(0), 1))
        if self._predicate_question_map is not None:
            pooled = util.mm(self._predicate_question_map, pooled)
        # Each row keeps only the entry for its own column.
        return util.log_parametric_not(pooled.diag(), quantifier, 1)

    # Hard pooling: zero out non-member objects, then keep the smallest value per row.
    negated = util.log_parametric_not(self._log_attention, quantifier.unsqueeze(1), 1)
    membership = self._batch_object_map.to_dense()
    if self._predicate_question_map is not None:
        membership = util.mm(self._predicate_question_map, membership)
    smallest = (membership * negated).min(1)[0]
    return util.log_parametric_not(smallest, quantifier, 1)
def squeeze(self, predicate_question_map):
    """Map both state tensors from predicate rows back to question rows.

    Each of ``self._state[0]`` and ``self._state[1]`` is left-multiplied by the
    transpose of ``predicate_question_map``. The result is wrapped in a new
    ``BatchAttentionState`` with this state's name and device, cast to
    ``self.dtype``.
    """
    to_questions = predicate_question_map.transpose(0, 1)
    first, second = (util.mm(to_questions, part) for part in (self._state[0], self._state[1]))
    return BatchAttentionState(self._name, self._device, (first, second)).to(self.dtype)
def _compute_relation_log_likelihood(self, device, relation_list, pair_object_features, meta_data, object_num, relation_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
    """Score the requested relations on object pairs as a dense log-likelihood tensor.

    Args:
        device: device for the index tensors and the output.
        relation_list: relation names, looked up in
            ``self._ontology._vocabulary['arg_to_idx']``.
        pair_object_features: dict with ``'features'`` (one row per object pair)
            and ``'index'``. ``'index'[0]`` is passed to
            ``util.find_sparse_pair_indices`` together with ``relation_image_map``
            (presumably the image index of each pair — TODO confirm);
            ``'index'[1]`` / ``'index'[2]`` give the two object positions of each
            pair in the output.
        meta_data: when it contains ``'relation_pairobject_map'``, that map selects
            the feature row for each relation directly (direct-supervision path).
        object_num: size of the two object axes of the output.
        relation_image_map: per-relation index matched against the pairs'
            ``'index'[0]``; also used to build the normalisation map.
        object_image_map: not used by this method.
        default_log_likelihood: fill value for entries that are not scored.
        normalized_probability: together with ``self._normalize``, enables the
            cluster normalisation step (non-direct-supervision path only).

    Returns:
        Tensor of shape (relation_num x object_num x object_num x 1), filled with
        ``default_log_likelihood`` everywhere except the scored entries.
    """
    pair_num = pair_object_features['features'].size()[0]
    relation_num = len(relation_list)
    # Vocabulary ids of the relations; itemgetter yields a scalar for a single relation.
    temp = itemgetter(*relation_list)(self._ontology._vocabulary['arg_to_idx'])

    if self._cached:
        # Cached feature columns are addressed through the ontology's reversed
        # relation index, looked up with (vocabulary id - 1).
        if not isinstance(temp, (list, tuple)):
            temp = [temp]
        temp = (np.array(temp, dtype=np.int64) - 1).tolist()
        temp = itemgetter(*temp)(self._ontology._relation_reveresed_index)
        ind = torch.tensor(temp, dtype=torch.int64, device=device)
    else:
        # Network output columns are addressed by (vocabulary id - 1).
        ind = torch.tensor(temp, dtype=torch.int64, device=device)
        ind = ind - 1

    # A single relation produces a 0-d tensor; make it 1-D.
    if ind.dim() == 0:
        ind = ind.unsqueeze(0)

    if 'relation_pairobject_map' in meta_data: # Only for direct-supervision, object-level training/testing
        relation_seq = list(range(relation_num))
        # One score per relation, read from the pair row chosen by the map.
        if self._cached:
            res = pair_object_features['features'][meta_data['relation_pairobject_map'], ind]
        else:
            res = self._embedding_network(self._relation_network(pair_object_features['features']))[meta_data['relation_pairobject_map'], ind]
        result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
        result[relation_seq, pair_object_features['index'][1], pair_object_features['index'][2], 0] = res
    else:
        # ind1: relation position, ind2: pair row, for every matching (relation, pair).
        _, ind1, ind2 = util.find_sparse_pair_indices(relation_image_map, pair_object_features['index'][0], device, exclude_self_relations=False)
        if self._cached:
            result = pair_object_features['features'][ind2, ind[ind1]]
        else:
            result = self._embedding_network(self._relation_network(pair_object_features['features']))[ind2, ind[ind1]]

        if self._normalize and normalized_probability:
            # (relation_num x pair_num) scores, default-filled.
            temp = default_log_likelihood * torch.ones(relation_num, pair_num, dtype=pair_object_features['features'].dtype, device=device)
            temp[ind1, ind2] = result
            cluster_map = self._build_map(relation_image_map)
            if cluster_map is not None:
                # Subtract the log of the summed probability of each cluster,
                # broadcast back to its members (presumably cluster_map is a 0/1
                # cluster-by-relation membership map — TODO confirm).
                denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, temp.exp())))
                temp = temp - denom
            temp = temp.unsqueeze(2)
        else:
            temp = default_log_likelihood * torch.ones(relation_num, pair_num, 1, dtype=pair_object_features['features'].dtype, device=device)
            temp[ind1, ind2, 0] = result

        # Scatter the per-pair scores into the dense (object x object) layout.
        result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
        result[:, pair_object_features['index'][1], pair_object_features['index'][2], :] = temp

    return result
def apply_modulations(self, modulations, input_variable_set, predicate_question_map=None):
    """Reshape this set's log attention in place using per-row modulation parameters.

    Column k of ``modulations`` supplies, for each row:
      0 -> alpha (scaled by 10), 1 -> beta (scaled by 10),
      2 -> c (scaled by 10; 1 when the column is absent),
      3 -> d (0.5 when absent), 4 -> g (optional gate).

    With ``L`` the current log attention and
    ``N = alpha * L + log(c) + log(d)``, the attention becomes
    ``N - log(exp(beta * log_not(L) + log(1 - d)) + exp(N))``.
    When a gate column exists, the result is then mixed in probability space
    with ``input_variable_set``'s attention (mapped through
    ``predicate_question_map`` when given): ``log(g * p + (1 - g) * p_input)``.

    Returns:
        ``self`` (left untouched when ``modulations`` is None).
    """
    if modulations is None:
        return self

    max_activation = 10
    column_count = modulations.size()[1]

    def column(k):
        # Column k as a (rows x 1) tensor so it broadcasts over objects.
        return modulations[:, k].unsqueeze(1)

    alpha = column(0) * max_activation
    beta = column(1) * max_activation
    if column_count > 2:
        c = column(2) * max_activation
    else:
        c = torch.ones(1, device=self._device, dtype=modulations.dtype)
    if column_count > 3:
        d = column(3)
    else:
        d = 0.5 * torch.ones(1, device=self._device, dtype=modulations.dtype)

    prior = self._log_attention
    log_numerator = alpha * prior + util.safe_log(c) + util.safe_log(d)
    log_alternative = beta * util.log_not(prior) + util.safe_log(1.0 - d)
    self._log_attention = log_numerator - util.safe_log(log_alternative.exp() + log_numerator.exp())

    if column_count > 4:
        g = column(4)
        source = input_variable_set._log_attention.exp()
        if predicate_question_map is not None:
            source = util.mm(predicate_question_map, source)
        self._log_attention = util.safe_log(g * self._log_attention.exp() + (1.0 - g) * source)

    return self
def compute_total_log_probability(self, log_posterior, quantifier, batch_object_map, predicate_question_map):
    """Aggregate per-object log posteriors into one value per (predicate, argument).

    The posteriors are passed through ``util.log_parametric_not`` with the
    quantifiers, pooled over objects with ``batch_object_map`` (and
    ``predicate_question_map`` when given), reduced to each predicate's own
    row, and passed through ``util.log_parametric_not`` again.

    Args:
        log_posterior: (predicate_num x arity x object_num) tensor; a 2-D input
            is treated as a single predicate.
        quantifier: (predicate_num x arity) tensor; a 1-D input is treated as a
            single predicate.
        batch_object_map: map applied with ``util.mm`` on the object axis.
        predicate_question_map: optional map applied after ``batch_object_map``.

    Returns:
        (predicate_num x arity) tensor of log-probabilities.
    """
    if log_posterior.dim() < 3:
        log_posterior = log_posterior.unsqueeze(0)
    if quantifier.dim() < 2:
        quantifier = quantifier.unsqueeze(0)
    predicate_num, arity, object_num = log_posterior.size()

    log_posterior = util.log_parametric_not(log_posterior, quantifier.unsqueeze(2), 1)

    # Objects on the rows: (object_num x arity * predicate_num).
    flat = log_posterior.transpose(0, 2).contiguous().view(object_num, -1)
    pooled = util.mm(batch_object_map, flat)
    if predicate_question_map is not None:
        pooled = util.mm(predicate_question_map, pooled)

    # Keep, for every predicate, the column block that belongs to itself.
    ind = torch.arange(predicate_num, dtype=torch.int64, device=pooled.device)
    log_posterior = pooled.view(predicate_num, arity, predicate_num)[ind, :, ind]  # predicate_num x arity
    log_posterior = util.log_parametric_not(log_posterior, quantifier, 1)
    # Previously the function ended with a bare ``return`` and yielded None.
    return log_posterior  # predicate_num x arity
def _compute_attribute_log_likelihood(self, device, attribute_list, object_features, meta_data, object_num, attribute_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
    """Score the requested attributes on objects as a dense log-likelihood tensor.

    Args:
        device: device for the index tensors and the output.
        attribute_list: attribute names, looked up in
            ``self._ontology._vocabulary['arg_to_idx']``.
        object_features: per-object features; when ``self._cached`` they are
            indexed directly as scores, otherwise passed through
            ``self._attribute_network`` and ``self._embedding_network``.
        meta_data: not used by this method.
        object_num: ignored — overwritten with ``object_features.size()[0]``.
        attribute_image_map: per-attribute index passed, with
            ``object_image_map``, to ``util.find_sparse_pair_indices``; also used
            to build the normalisation map.
        object_image_map: per-object index matched against ``attribute_image_map``.
        default_log_likelihood: fill value for unscored (attribute, object) entries.
        normalized_probability: together with ``self._normalize``, enables the
            cluster normalisation step.

    Returns:
        Tensor of shape (attribute_num x object_num x 1).
    """
    # NOTE(review): the object_num argument is discarded here.
    object_num = object_features.size()[0]
    attribute_num = len(attribute_list)
    # Vocabulary ids; itemgetter yields a scalar for a single attribute.
    temp = itemgetter(*attribute_list)(self._ontology._vocabulary['arg_to_idx'])

    # NOTE(review): unlike the relation scorer, the cached path does not remap
    # ids through a reversed index (the remapping below is commented out), so
    # cached features are addressed by (vocabulary id - 1) — confirm the cache
    # layout matches.
    # if self._cached:
    #     temp = np.array(temp, dtype=np.int64) - 1
    #     temp = np.array(itemgetter(*temp.tolist())(self._ontology._attribute_reveresed_index), dtype=np.int64)
    #     ind = torch.tensor(temp, dtype=torch.int64, device=device)
    # else:
    ind = torch.tensor(temp, dtype=torch.int64, device=device)
    ind = ind - 1

    # A single attribute produces a 0-d tensor; make it 1-D.
    if ind.dim() == 0:
        ind = ind.unsqueeze(0)

    # ind1: attribute position, ind2: object row, for every matching (attribute, object).
    _, ind1, ind2 = util.find_sparse_pair_indices(attribute_image_map, object_image_map, device, exclude_self_relations=False)

    if self._cached:
        output_features = object_features[ind2, ind[ind1]]
    else:
        output_features = self._embedding_network(self._attribute_network(object_features))[ind2, ind[ind1]]

    # Reshape into the dense version
    if self._normalize and normalized_probability:
        result = default_log_likelihood * torch.ones(attribute_num, object_num, dtype=object_features.dtype, device=device)
        result[ind1, ind2] = output_features
        cluster_map = self._build_map(attribute_image_map)
        if cluster_map is not None:
            # Subtract the log of the summed probability of each cluster,
            # broadcast back to its members (presumably cluster_map is a 0/1
            # cluster-by-attribute membership map — TODO confirm).
            denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, result.exp())))
            result = result - denom
        result = result.unsqueeze(2)
    else:
        result = default_log_likelihood * torch.ones(attribute_num, object_num, 1, dtype=object_features.dtype, device=device)
        result[ind1, ind2, 0] = output_features

    return result
def _forward_core(self, log_prior, log_likelihood, quantifiers, dim_order, batch_object_map, predicate_question_map):
    """Propagate a predicate likelihood into per-argument log attention posteriors.

    For each argument slot ``i`` (visited in ``dim_order``):
      * the likelihood is combined with the prior of every other slot ``j``;
      * slot ``j`` is quantified (``Quantifier.EXISTS`` negation or
        ``util.log_parametric_not``), its self-relation diagonal is zeroed,
        and it is pooled over each question's objects via ``batch_object_map``
        and quantified again;
      * finally the prior of slot ``i`` itself is applied.

    Args:
        log_prior: (question_num x arity x object_num) log priors.
        log_likelihood: (predicate_num x object_num x ... x object_num).
        quantifiers: (predicate_num x arity).
        dim_order: 0-based order in which argument slots are visited.
        batch_object_map: map with question_num rows applied with ``util.mm``
            over the object axis.
        predicate_question_map: required when question_num != predicate_num.

    Returns:
        (predicate_num x arity x object_num) tensor of log posteriors.
    """
    object_num = log_prior.size()[2]
    predicate_num = log_likelihood.size()[0]
    question_num = log_prior.size()[0]

    # Bring the priors to predicate level when predicates and questions differ.
    if predicate_question_map is not None and predicate_num != question_num:
        log_p = util.mm(predicate_question_map, log_prior.view(question_num, -1)).view(predicate_num, self._arity, -1)
    else:
        log_p = log_prior

    result = torch.zeros(predicate_num, self._arity, object_num, device=self._device, dtype=log_prior.dtype)

    if question_num > 1 and self._arity > 1:
        # Compute the indices of the diagonals of (question_num x ... x question_num)
        coeff = (question_num ** (self._arity - 1) - question_num) / (question_num - 1)
        ind = torch.arange(question_num, dtype=torch.int64, device=self._device)
        ind += (ind * coeff).long()

    if self._arity > 1:
        # Object-diagonal indices used to drop self-relations. Defined for every
        # object_num (including 1) because it is read whenever arity > 1.
        obj_diag_ind = np.arange(object_num, dtype=np.int64).tolist()

    for a in range(self._arity):
        i = dim_order[a] + 1
        log_posterior = log_likelihood
        for b in range(self._arity):
            j = dim_order[b] + 1
            if i != j:
                # Multiply the prior of slot j
                if self._trainable_gate:
                    log_posterior = self._nlg[j - 1](log_posterior, log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1]))
                else:
                    log_posterior = log_posterior + log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1])

                if quantifiers[:, j - 1].numel() == 1:
                    if quantifiers[:, j - 1] == Quantifier.EXISTS:
                        log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                else:
                    log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity * [1]), 1)

                # Discounting the diagonal part (self-relations)
                # REVIEW: Only works for arity <= 2
                log_posterior[:, obj_diag_ind, obj_diag_ind] = 0

                # Pool dimension j over each question's objects: move it to the
                # front, flatten, apply batch_object_map, restore the layout.
                s1 = log_posterior.size()
                log_posterior = log_posterior.transpose(0, j)
                s2 = list(log_posterior.size())
                log_posterior = log_posterior.contiguous().view(s1[j], -1)
                log_posterior = util.mm(batch_object_map, log_posterior)
                s2[0] = question_num
                log_posterior = log_posterior.view(s2).transpose(0, j)

                if quantifiers[:, j - 1].numel() == 1:
                    if quantifiers[:, j - 1] == Quantifier.EXISTS:
                        log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                else:
                    log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity * [1]), 1)

        # Multiply the prior of slot i itself
        if self._trainable_gate:
            log_posterior = self._nlg[i - 1](log_posterior, log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1]))
        else:
            log_posterior = log_posterior + log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1])

        log_posterior = log_posterior.transpose(1, i).contiguous().view(predicate_num, object_num, -1)

        if question_num > 1 and self._arity > 1:
            # Keep the diagonal (same-question) entries, then pick, for every
            # object, the column of the question it belongs to.
            log_posterior = log_posterior[:, :, ind]
            mask = batch_object_map.transpose(0, 1).to_dense().unsqueeze(0)
            log_posterior = (log_posterior * mask).sum(dim=2)
        else:
            log_posterior = log_posterior.squeeze(2)

        result[:, i - 1, :] = log_posterior

    return result  # predicate_num x arity x object_num
def forward(self, op_id, world, subject_variable_set, object_variable_set, relation_list, predicate_question_map=None, default_log_likelihood=-30, normalized_probability=True):
    """Apply relation predicates to a (subject, object) pair of variable sets.

    Args:
        op_id: key of pending entries in ``self._subject_modulations`` /
            ``self._object_modulations``; matching entries are applied and removed.
        world: passed to the oracle; ``world.batch_size()`` is the question count.
        subject_variable_set, object_variable_set: input sets; they must share
            batch size and object count.
        relation_list: one relation name per predicate (a single name is
            wrapped in a list). ``None``, ``''`` and ``'_'`` are placeholders:
            those predicates keep their prior attention.
        predicate_question_map: optional predicate -> question map, either a
            sparse (predicate_num x question_num) tensor or a list/tuple of
            question indices (one per predicate). Required when the number of
            predicates differs from the number of questions.
        default_log_likelihood: fill value for unscored entries.
        normalized_probability: forwarded to the oracle.

    Returns:
        ``(new_subject_set, new_object_set)``; the inputs themselves when every
        entry of ``relation_list`` is a placeholder.
    """
    assert subject_variable_set.batch_size() == object_variable_set.batch_size(), "The subject and object variable sets must have the same batch size."
    assert subject_variable_set.object_num() == object_variable_set.object_num(), "The subject and object variable sets must have the same number of objects."

    if not isinstance(relation_list, list):
        relation_list = [relation_list]

    is_valid = [val is not None and val.strip() not in ['', '_'] for val in relation_list]
    if not any(is_valid):
        return subject_variable_set, object_variable_set

    device = subject_variable_set.device
    question_num = world.batch_size()
    predicate_num = len(relation_list)

    if predicate_question_map is not None and isinstance(predicate_question_map, (list, tuple)):
        # Sparse 0/1 (predicate_num x question_num) map with one entry per row.
        all_ones = torch.ones(predicate_num, device=device, dtype=subject_variable_set.dtype)
        x_ind = torch.arange(predicate_num, dtype=torch.int64, device=device)
        y_ind = torch.tensor(predicate_question_map, dtype=torch.int64, device=device)
        index = torch.stack([x_ind, y_ind])
        if util.is_cuda(device):
            predicate_question_map = torch.cuda.sparse.FloatTensor(index, all_ones, torch.Size([predicate_num, question_num]), device=device)
        else:
            predicate_question_map = torch.sparse.FloatTensor(index, all_ones, torch.Size([predicate_num, question_num]), device=device)

    assert (question_num == predicate_num) or \
        (predicate_question_map is not None and predicate_question_map.size() == torch.Size([predicate_num, question_num])), \
        "Batch size mismatch."

    # [:, 0] is the subject attention, [:, 1] the object attention.
    log_attention = torch.stack([subject_variable_set._log_attention, object_variable_set._log_attention]).transpose(0, 1).contiguous()
    dim_order = [0, 1]
    quantifier = torch.stack([subject_variable_set._quantifier, object_variable_set._quantifier]).transpose(0, 1).contiguous()
    if predicate_question_map is not None:
        quantifier = util.mm(predicate_question_map, quantifier)

    all_valid = all(is_valid)
    if not all_valid:
        relation_list = list(compress(relation_list, is_valid))

    is_any_negated, is_neg, r_list = util.detect_negations(relation_list, device)
    is_neg = torch.tensor(is_neg, dtype=subject_variable_set.dtype, device=device)

    if predicate_question_map is not None:
        relation_image_map = predicate_question_map._indices()[1, is_valid]
    else:
        relation_image_map = torch.arange(predicate_num, dtype=torch.int64, device=device)[is_valid]

    log_likelihood = self._oracle(TokenType.RELATION, r_list, relation_image_map, world, default_log_likelihood=default_log_likelihood, normalized_probability=normalized_probability)
    if isinstance(log_likelihood, torch.Tensor) and log_likelihood.dim() < 4:
        log_likelihood = log_likelihood.unsqueeze(0)

    if all_valid:
        is_negated = is_neg if is_any_negated else None
        log_attention_posterior = self._blc(log_attention, log_likelihood, quantifier, dim_order, subject_variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)
    else:
        # One bool mask for every path below; a plain list would make
        # ``mask == False`` a scalar False and silently skip the prior restore.
        valid = torch.tensor(is_valid, dtype=torch.bool, device=device)

        # Expand the oracle output (valid predicates only) back to all predicates.
        if isinstance(log_likelihood, torch.Tensor):
            ll = default_log_likelihood * torch.ones(predicate_num, log_likelihood.size()[1], log_likelihood.size()[2], log_likelihood.size()[3], device=device, dtype=subject_variable_set.dtype)
            ll[valid, :, :, :] = log_likelihood
        else:
            # Sparse form: re-index the predicate coordinate into the full range.
            ll = log_likelihood
            seq = torch.arange(predicate_num, dtype=torch.int64, device=device)
            ll['index'][0] = seq[valid][ll['index'][0]]
            ll['size'][0] = predicate_num

        if is_any_negated:
            is_negated = torch.zeros(predicate_num, device=device, dtype=subject_variable_set.dtype)
            is_negated[valid] = is_neg
        else:
            is_negated = None

        log_attention_posterior = self._blc(log_attention, ll, quantifier, dim_order, subject_variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

        # Placeholder predicates keep their prior attention. The prior is taken
        # at predicate level so this also holds when predicates map to
        # questions through predicate_question_map.
        if predicate_question_map is None:
            prior = log_attention
        else:
            prior = util.mm(predicate_question_map, log_attention.view(log_attention.size()[0], -1)).view(predicate_num, 2, -1)
        invalid = ~valid
        log_attention_posterior[invalid] = prior[invalid]

    if predicate_question_map is None:
        quantifier = subject_variable_set._quantifier
    else:
        quantifier = util.mm(predicate_question_map, subject_variable_set._quantifier.unsqueeze(1)).squeeze(1)

    # NOTE(review): both output sets are built with the subject's quantifier;
    # the object's quantifier is not used here — confirm this is intended.
    new_subject_set = BatchVariableSet(subject_variable_set._name, subject_variable_set._device, subject_variable_set.object_num(), predicate_num, quantifiers=quantifier,
                                       log_attention=log_attention_posterior[:, 0, :], batch_object_map=subject_variable_set._batch_object_map, predicate_question_map=predicate_question_map,
                                       base_cumulative_loss=subject_variable_set.cumulative_loss() + object_variable_set.cumulative_loss(),
                                       prev_variable_sets_num=subject_variable_set._prev_variable_sets_num + object_variable_set._prev_variable_sets_num + 1).to(subject_variable_set.dtype)
    new_object_set = BatchVariableSet(object_variable_set._name, object_variable_set._device, object_variable_set.object_num(), predicate_num, quantifiers=quantifier,
                                      log_attention=log_attention_posterior[:, 1, :], batch_object_map=object_variable_set._batch_object_map, predicate_question_map=predicate_question_map,
                                      base_cumulative_loss=subject_variable_set.cumulative_loss() + object_variable_set.cumulative_loss(),
                                      prev_variable_sets_num=subject_variable_set._prev_variable_sets_num + object_variable_set._prev_variable_sets_num + 1).to(object_variable_set.dtype)

    if op_id in self._subject_modulations:
        new_subject_set = new_subject_set.apply_modulations(self._subject_modulations[op_id], subject_variable_set, predicate_question_map).to(subject_variable_set.dtype)
        self._subject_modulations.pop(op_id, 'No Key found')

    if op_id in self._object_modulations:
        new_object_set = new_object_set.apply_modulations(self._object_modulations[op_id], object_variable_set, predicate_question_map).to(object_variable_set.dtype)
        self._object_modulations.pop(op_id, 'No Key found')

    return new_subject_set, new_object_set
def forward(self, op_id, world, variable_set, attribute_list, predicate_question_map=None, default_log_likelihood=-30, normalized_probability=True):
    """Apply attribute predicates to a variable set.

    Args:
        op_id: key of a pending entry in ``self._modulations``; a matching
            entry is applied and removed.
        world: passed to the oracle.
        variable_set: input set; ``variable_set.batch_size()`` is the question count.
        attribute_list: one attribute name per predicate (a single name is
            wrapped in a list). ``None``, ``''`` and ``'_'`` are placeholders:
            those predicates keep their prior attention.
        predicate_question_map: optional predicate -> question map, either a
            sparse (predicate_num x question_num) tensor or a list/tuple of
            question indices (one per predicate). Required when the number of
            predicates differs from the number of questions.
        default_log_likelihood: fill value for unscored entries.
        normalized_probability: forwarded to the oracle.

    Returns:
        A new ``BatchVariableSet`` with one row per predicate, or
        ``variable_set`` itself when every entry is a placeholder.
    """
    if not isinstance(attribute_list, list):
        attribute_list = [attribute_list]

    is_valid = [val is not None and val.strip() not in ['', '_'] for val in attribute_list]
    if not any(is_valid):
        return variable_set

    device = variable_set.device
    question_num = variable_set.batch_size()
    predicate_num = len(attribute_list)

    if predicate_question_map is not None and isinstance(predicate_question_map, (list, tuple)):
        # Sparse 0/1 (predicate_num x question_num) map with one entry per row.
        all_ones = torch.ones(predicate_num, device=device, dtype=variable_set.dtype)
        x_ind = torch.arange(predicate_num, dtype=torch.int64, device=device)
        y_ind = torch.tensor(predicate_question_map, dtype=torch.int64, device=device)
        index = torch.stack([x_ind, y_ind])
        if util.is_cuda(device):
            predicate_question_map = torch.cuda.sparse.FloatTensor(index, all_ones, torch.Size([predicate_num, question_num]), device=device)
        else:
            predicate_question_map = torch.sparse.FloatTensor(index, all_ones, torch.Size([predicate_num, question_num]), device=device)

    assert (question_num == predicate_num) or \
        (predicate_question_map is not None and predicate_question_map.size() == torch.Size([predicate_num, question_num])), \
        "Batch size mismatch."

    quantifier = variable_set._quantifier.unsqueeze(1)
    if predicate_question_map is not None:
        quantifier = util.mm(predicate_question_map, quantifier)

    all_valid = all(is_valid)
    if not all_valid:
        attribute_list = list(compress(attribute_list, is_valid))

    is_any_negated, is_neg, a_list = util.detect_negations(attribute_list, device)
    is_neg = torch.tensor(is_neg, dtype=variable_set.dtype, device=device)

    if predicate_question_map is not None:
        attribute_image_map = predicate_question_map._indices()[1, is_valid]
    else:
        attribute_image_map = torch.arange(predicate_num, dtype=torch.int64, device=device)[is_valid]

    log_likelihood = self._oracle(TokenType.ATTRIBUTE, a_list, attribute_image_map, world, default_log_likelihood=default_log_likelihood, normalized_probability=normalized_probability)
    if isinstance(log_likelihood, torch.Tensor) and log_likelihood.dim() < 3:
        log_likelihood = log_likelihood.unsqueeze(0)

    if all_valid:
        is_negated = is_neg if is_any_negated else None
        log_attention = self._blc(variable_set._log_attention.unsqueeze(1), log_likelihood, quantifier, [0], variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)
    else:
        # One bool mask for every path below; a plain list would make
        # ``mask == False`` a scalar False and silently skip the prior restore.
        valid = torch.tensor(is_valid, dtype=torch.bool, device=device)

        # Expand the oracle output (valid predicates only) back to all predicates.
        if isinstance(log_likelihood, torch.Tensor):
            ll = default_log_likelihood * torch.ones(predicate_num, log_likelihood.size()[1], log_likelihood.size()[2], device=device, dtype=variable_set.dtype)
            ll[valid, :, :] = log_likelihood
        else:
            # Sparse form: re-index the predicate coordinate into the full range.
            ll = log_likelihood
            seq = torch.arange(predicate_num, dtype=torch.int64, device=device)
            ll['index'][0] = seq[valid][ll['index'][0]]
            ll['size'][0] = predicate_num

        if is_any_negated:
            is_negated = torch.zeros(predicate_num, device=device, dtype=variable_set.dtype)
            is_negated[valid] = is_neg
        else:
            is_negated = None

        log_attention = self._blc(variable_set._log_attention.unsqueeze(1), ll, quantifier, [0], variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

        # Placeholder predicates keep their prior attention, taken at predicate
        # level so this also holds when predicate_question_map is used.
        if predicate_question_map is None:
            prior = variable_set._log_attention
        else:
            prior = util.mm(predicate_question_map, variable_set._log_attention)
        invalid = ~valid
        log_attention[invalid, 0, :] = prior[invalid]

    if predicate_question_map is None:
        quantifier = variable_set._quantifier
    else:
        quantifier = util.mm(predicate_question_map, variable_set._quantifier.unsqueeze(1)).squeeze(1)

    res = BatchVariableSet(variable_set._name, variable_set._device, variable_set.object_num(), predicate_num, quantifiers=quantifier,
                           log_attention=log_attention[:, 0, :], batch_object_map=variable_set._batch_object_map, predicate_question_map=predicate_question_map,
                           base_cumulative_loss=variable_set.cumulative_loss(), prev_variable_sets_num=variable_set._prev_variable_sets_num + 1).to(variable_set.dtype)

    if op_id in self._modulations:
        res = res.apply_modulations(self._modulations[op_id], variable_set, predicate_question_map).to(variable_set.dtype)
        self._modulations.pop(op_id, 'No Key found')

    return res
def expand(self, predicate_question_map):
    """Map both state tensors from question rows to predicate rows.

    Each of ``self._state[0]`` and ``self._state[1]`` is left-multiplied by
    ``predicate_question_map``. The result is wrapped in a new
    ``BatchAttentionState`` with this state's name and device, cast to
    ``self.dtype``.
    """
    expanded_state = tuple(
        util.mm(predicate_question_map, self._state[k]) for k in range(2))
    return BatchAttentionState(self._name, self._device, expanded_state).to(self.dtype)
def _compute_loss(self, program_batch_list, prediction):
    """Compute the summed training loss for one batch of predictions.

    The loss depends on ``prediction['type']``:

    * ``STATEMENT``: negative sum of ``prediction['log_probability']``.
    * ``BINARY``: binary cross-entropy of ``exp(log_probability)`` against
      "yes"-like answers.
    * ``OBJECT_STATEMENT``: binary cross-entropy over the flattened per-object
      answers, weighted by each batch's ``_meta_data['weights']``.
    * ``QUERY``: for each question, log-sum-exp of its options' scores minus
      the score(s) of the option(s) equal to the answer.
    * ``SCENE_GRAPH``: weighted binary cross-entropy for attributes plus
      relations, from the batches' ``_meta_data`` answers and weights.

    For every type except ``STATEMENT``, an L1 penalty on trainable
    parameters (mean absolute value times ``l1_lambda``) is added when
    ``self._config['l1_lambda'] > 0``.

    Raises:
        ValueError: if ``prediction['type']`` is none of the types above.
    """
    affirmative = ['yes', 'yeah', 'yep', 'yup', 'aye', 'yea']
    question_type = prediction['type']

    if question_type == QuestionType.STATEMENT:
        # NOTE(review): this early return skips the L1 penalty below — confirm intended.
        return -prediction['log_probability'].sum()

    if question_type == QuestionType.BINARY:
        target = [a in affirmative for program_batch in program_batch_list for a in program_batch._answers]
        target = torch.tensor(target, dtype=torch.float32, device=self._device)
        loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'].exp(), target, reduction='sum')

    elif question_type == QuestionType.OBJECT_STATEMENT:
        target = [a in affirmative
                  for program_batch in program_batch_list
                  for sublist in program_batch._answers
                  for a in sublist]
        target = torch.tensor(target, dtype=torch.float32, device=self._device)
        # Explicit float32 so integer weights do not yield an integer weight tensor.
        weights = torch.tensor(
            list(chain.from_iterable(program_batch._meta_data['weights'] for program_batch in program_batch_list)),
            dtype=torch.float32, device=self._device)
        loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'].exp(), target, weight=weights, reduction='sum')

    elif question_type == QuestionType.QUERY:
        score = prediction['log_probability']
        predicate_num = score.size()[0]
        all_answers = list(chain.from_iterable(program_batch._answers for program_batch in program_batch_list))
        # target[q][k] is True when option k of question q equals its answer.
        target = [[a == o for o in op] for a, op in zip(all_answers, prediction['options'])]
        question_num = len(target)

        # Sparse (question_num x predicate_num) map: the predicates are the
        # options of all questions laid out one question after another.
        x_ind = torch.tensor(
            list(chain.from_iterable(repeat(q, len(options)) for q, options in enumerate(target))),
            dtype=torch.int64, device=self._device)
        y_ind = torch.arange(predicate_num, dtype=torch.int64, device=self._device)
        all_ones = torch.ones(predicate_num, device=self._device)
        index = torch.stack([x_ind, y_ind])
        if self._device.type == 'cuda':
            question_predicate_map = torch.cuda.sparse.FloatTensor(
                index, all_ones, torch.Size([question_num, predicate_num]), device=self._device)
        else:
            question_predicate_map = torch.sparse.FloatTensor(
                index, all_ones, torch.Size([question_num, predicate_num]), device=self._device)

        target = torch.tensor(list(chain.from_iterable(target)), dtype=torch.float32, device=self._device)
        loss = util.safe_log(util.mm(question_predicate_map, score.unsqueeze(1).exp())).sum() - (target * score).sum()

    elif question_type == QuestionType.SCENE_GRAPH:
        def _concat_meta(key):
            # Concatenate one _meta_data array across batches as a float32 tensor.
            return torch.tensor(
                np.concatenate([pb._meta_data[key] for pb in program_batch_list], axis=0),
                dtype=torch.float32, device=self._device)

        attr_target = _concat_meta('attribute_answer')
        attr_weight = _concat_meta('attribute_weight')
        rel_target = _concat_meta('relation_answer')
        rel_weight = _concat_meta('relation_weight')

        attr_loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'][0].exp(), attr_target, weight=attr_weight, reduction='sum')
        rel_loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'][1].exp(), rel_target, weight=rel_weight, reduction='sum')
        loss = attr_loss + rel_loss

    else:
        raise ValueError('Unsupported question type: {!r}'.format(question_type))

    if 'l1_lambda' in self._config and self._config['l1_lambda'] > 0:
        trainable = [p.view(-1) for p in self.model.parameters() if p.requires_grad]
        # torch.cat rejects an empty list; with no trainable parameters there is nothing to penalise.
        if trainable:
            all_params = torch.cat(trainable)
            loss = loss + (self._config['l1_lambda'] * torch.norm(all_params, 1) / max(1, all_params.numel()))

    return loss