def _compute_relation_log_likelihood(self, device, relation_list, pair_object_features, meta_data, object_num, relation_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
    """Build a dense (relation_num x object_num x object_num x 1) tensor of
    relation log-likelihoods from sparse per-pair features.

    Args:
        device: torch device on which every tensor created here is placed.
        relation_list: relation names, looked up in the ontology vocabulary.
        pair_object_features: dict with 'features' (per-pair scores when
            self._cached, otherwise raw features fed through the relation and
            embedding networks) and 'index' (pair index tensors — [0] is used
            as a per-pair image id, [1]/[2] as the two object positions;
            presumed from the indexing below, confirm against callers).
        meta_data: may contain 'relation_pairobject_map' to enable the
            direct-supervision path.
        object_num: per-image object count used to size the dense output.
        relation_image_map: maps each relation to its image; also used to
            build the per-image normalization map.
        object_image_map: unused in this method (kept for signature parity
            with the attribute version).
        default_log_likelihood: fill value for entries with no evidence.
        normalized_probability: if True (and self._normalize), subtract the
            per-image log partition function.

    Returns:
        Tensor of shape (relation_num, object_num, object_num, 1).
    """
    pair_num = pair_object_features['features'].size()[0]
    relation_num = len(relation_list)
    # Map relation names to vocabulary indices. itemgetter returns a bare
    # scalar for a single-element argument list, hence the isinstance check.
    temp = itemgetter(*relation_list)(self._ontology._vocabulary['arg_to_idx'])
    if self._cached:
        if not isinstance(temp, (list, tuple)):
            temp = [temp]
        # Cached feature columns are addressed through the relation reversed
        # index (note: '_relation_reveresed_index' is the attribute's actual
        # spelling elsewhere in the project).
        temp = (np.array(temp, dtype=np.int64) - 1).tolist()
        temp = itemgetter(*temp)(self._ontology._relation_reveresed_index)
        ind = torch.tensor(temp, dtype=torch.int64, device=device)
    else:
        # Vocabulary indices are 1-based; shift to 0-based column indices.
        ind = torch.tensor(temp, dtype=torch.int64, device=device)
        ind = ind - 1
    if ind.dim() == 0:
        # Single relation: keep ind 1-D so the advanced indexing below works.
        ind = ind.unsqueeze(0)
    if 'relation_pairobject_map' in meta_data:
        # Only for direct-supervision, object-level training/testing:
        # each relation is tied to one specific pair via the provided map.
        relation_seq = list(range(relation_num))
        if self._cached:
            res = pair_object_features['features'][meta_data['relation_pairobject_map'], ind]
        else:
            res = self._embedding_network(self._relation_network(pair_object_features['features']))[meta_data['relation_pairobject_map'], ind]
        result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
        result[relation_seq, pair_object_features['index'][1], pair_object_features['index'][2], 0] = res
    else:
        # Pair every relation with every candidate pair of the same image.
        _, ind1, ind2 = util.find_sparse_pair_indices(relation_image_map, pair_object_features['index'][0], device, exclude_self_relations=False)
        if self._cached:
            result = pair_object_features['features'][ind2, ind[ind1]]
        else:
            result = self._embedding_network(self._relation_network(pair_object_features['features']))[ind2, ind[ind1]]
        if self._normalize and normalized_probability:
            # Scatter into a (relation_num x pair_num) grid, then subtract
            # the per-image log partition function computed via cluster_map.
            temp = default_log_likelihood * torch.ones(relation_num, pair_num, dtype=pair_object_features['features'].dtype, device=device)
            temp[ind1, ind2] = result
            cluster_map = self._build_map(relation_image_map)
            if cluster_map is not None:
                denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, temp.exp())))
                temp = temp - denom
            temp = temp.unsqueeze(2)
        else:
            temp = default_log_likelihood * torch.ones(relation_num, pair_num, 1, dtype=pair_object_features['features'].dtype, device=device)
            temp[ind1, ind2, 0] = result
        # Expand the per-pair scores into the dense per-object-pair layout.
        result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
        result[:, pair_object_features['index'][1], pair_object_features['index'][2], :] = temp
    return result
def apply_modulations(self, modulations, input_variable_set, predicate_question_map=None):
    """Reshape this variable set's log-attention with learned modulations.

    Columns of `modulations` act as (0) a gain on the current attention,
    (1) a gain on its complement, (2)/(3) optional bias terms, and (4) an
    optional mixing weight against `input_variable_set`'s attention
    (routed through `predicate_question_map` when given). All arithmetic
    is done in log-space via the `util` helpers. Returns self for chaining.
    """
    if modulations is None:
        return self

    gain = 10  # maximum activation scale applied to the raw modulations
    num_channels = modulations.size()[1]

    alpha = gain * modulations[:, 0].unsqueeze(1)
    beta = gain * modulations[:, 1].unsqueeze(1)
    if num_channels > 2:
        c = gain * modulations[:, 2].unsqueeze(1)
    else:
        c = torch.ones(1, device=self._device, dtype=modulations.dtype)
    if num_channels > 3:
        d = modulations[:, 3].unsqueeze(1)
    else:
        d = 0.5 * torch.ones(1, device=self._device, dtype=modulations.dtype)

    # Soft log-space competition between the attended and complementary mass;
    # both terms read the *old* attention before it is overwritten.
    kept = alpha * self._log_attention + util.safe_log(c) + util.safe_log(d)
    rival = beta * util.log_not(self._log_attention) + util.safe_log(1.0 - d)
    self._log_attention = kept - util.safe_log(rival.exp() + kept.exp())

    if num_channels > 4:
        # Convex mix with the input variable set's attention in prob space.
        g = modulations[:, 4].unsqueeze(1)
        other_attention = input_variable_set._log_attention.exp()
        if predicate_question_map is not None:
            other_attention = util.mm(predicate_question_map, other_attention)
        self._log_attention = util.safe_log(
            g * self._log_attention.exp() + (1.0 - g) * other_attention)
    return self
def _compute_attribute_log_likelihood(self, device, attribute_list, object_features, meta_data, object_num, attribute_image_map, object_image_map, default_log_likelihood=-30):
    """Debug stub: return uniformly random attribute log-likelihoods.

    Produces one random probability per (attribute, object) slot — all other
    arguments are accepted only to match the real implementation's signature.
    """
    slot_count = len(attribute_list) * self._object_num
    random_scores = torch.rand(slot_count, device=self._device)
    return util.safe_log(random_scores)
def _compute_relation_log_likelihood(self, device, relation_list, pair_object_features, meta_data, object_num, relation_image_map, object_image_map, default_log_likelihood=-30):
    """Debug stub: return uniformly random relation log-likelihoods.

    Produces one random probability per (relation, object, object) slot —
    all other arguments exist only for signature parity with the real model.
    """
    slot_count = len(relation_list) * (self._object_num ** 2)
    random_scores = torch.rand(slot_count, device=self._device)
    return util.safe_log(random_scores)
def _compute_attribute_log_likelihood(self, device, attribute_list, object_features, meta_data, object_num, attribute_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
    """Build a dense (attribute_num x object_num x 1) tensor of attribute
    log-likelihoods from per-object features.

    Args:
        device: torch device on which every tensor created here is placed.
        attribute_list: attribute names, looked up in the ontology vocabulary.
        object_features: (object_num x feature_dim) tensor; per-object scores
            when self._cached, otherwise raw features fed through the
            attribute and embedding networks.
        meta_data: unused in this method.
        object_num: shadowed immediately by object_features' row count.
        attribute_image_map: maps each attribute to its image; also used to
            build the per-image normalization map.
        object_image_map: maps each object to its image.
        default_log_likelihood: fill value for entries with no evidence.
        normalized_probability: if True (and self._normalize), subtract the
            per-image log partition function.

    Returns:
        Tensor of shape (attribute_num, object_num, 1).
    """
    object_num = object_features.size()[0]
    attribute_num = len(attribute_list)
    # Map attribute names to 1-based vocabulary indices.
    temp = itemgetter(*attribute_list)(self._ontology._vocabulary['arg_to_idx'])
    # NOTE(review): a cached path remapping through
    # self._ontology._attribute_reveresed_index (cf. the relation version)
    # was disabled here; cached features are indexed by vocab index - 1.
    ind = torch.tensor(temp, dtype=torch.int64, device=device)
    ind = ind - 1  # vocabulary indices are 1-based
    if ind.dim() == 0:
        # Single attribute: keep ind 1-D so advanced indexing works below.
        ind = ind.unsqueeze(0)
    # Pair every attribute with every object of the same image.
    _, ind1, ind2 = util.find_sparse_pair_indices(attribute_image_map, object_image_map, device, exclude_self_relations=False)
    if self._cached:
        output_features = object_features[ind2, ind[ind1]]
    else:
        output_features = self._embedding_network(self._attribute_network(object_features))[ind2, ind[ind1]]
    # Reshape into the dense version
    if self._normalize and normalized_probability:
        # Scatter into an (attribute_num x object_num) grid, then subtract
        # the per-image log partition function computed via cluster_map.
        result = default_log_likelihood * torch.ones(attribute_num, object_num, dtype=object_features.dtype, device=device)
        result[ind1, ind2] = output_features
        cluster_map = self._build_map(attribute_image_map)
        if cluster_map is not None:
            denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, result.exp())))
            result = result - denom
        result = result.unsqueeze(2)
    else:
        result = default_log_likelihood * torch.ones(attribute_num, object_num, 1, dtype=object_features.dtype, device=device)
        result[ind1, ind2, 0] = output_features
    return result
def _forward_core(self, log_prior, log_likelihood, quantifiers, dim_order, batch_object_map, predicate_question_map):
    """Aggregate a predicate's log-likelihood tensor against per-argument
    log-priors, producing one log-posterior attention vector per argument.

    Args:
        log_prior: (question_num x arity x object_num) log attention priors.
        log_likelihood: (predicate_num x object_num x ... x object_num),
            one object axis per argument.
        quantifiers: (predicate_num x arity); EXISTS entries switch the
            summed log-space product into the log(1 - prod(1 - p)) form via
            negation before and after the reduction.
        dim_order: permutation of argument positions to process.
        batch_object_map: sparse (question_num x object_num) map used to sum
            out each question's own objects.
        predicate_question_map: maps questions to predicates when
            predicate_num != question_num; otherwise may be None.

    Returns:
        (predicate_num x arity x object_num) tensor of log posteriors.
    """
    object_num = log_prior.size()[2]
    predicate_num = log_likelihood.size()[0]
    question_num = log_prior.size()[0]
    if predicate_question_map is not None and predicate_num != question_num:
        # Broadcast the question-level priors up to predicate level.
        log_p = util.mm(predicate_question_map, log_prior.view(question_num, -1)).view(predicate_num, self._arity, -1)
    else:
        log_p = log_prior
    result = torch.zeros(predicate_num, self._arity, object_num, device=self._device, dtype=log_prior.dtype)
    if question_num > 1 and self._arity > 1:
        # Compute the indices of the diagonals of (question_num x ... x question_num):
        # after the batched reduction each question's own slice sits on this
        # diagonal of the flattened trailing dimension.
        coeff = (question_num ** (self._arity - 1) - question_num) / (question_num - 1)
        ind = torch.arange(question_num, dtype=torch.int64, device=self._device)
        ind += (ind * coeff).long()
    if object_num > 1 and self._arity > 1:
        # Diagonal (self-relation) object indices, zeroed out below.
        # NOTE(review): undefined when object_num == 1 but arity > 1 — the
        # i != j branch would then raise NameError; verify callers.
        obj_diag_ind = np.arange(object_num, dtype=np.int64).tolist()
    for a in range(self._arity):
        i = dim_order[a] + 1  # tensor dim of the argument being kept
        log_posterior = log_likelihood
        for b in range(self._arity):
            j = dim_order[b] + 1  # tensor dim being summed out
            if i != j:
                # Multiply the prior (add in log-space), optionally through
                # a trainable gate module.
                if self._trainable_gate:
                    log_posterior = self._nlg[j - 1](log_posterior, log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1]))
                else:
                    log_posterior = log_posterior + log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1])
                # Pre-reduction quantifier handling: EXISTS negates so the
                # log-space sum computes log(prod(1 - p)); a per-predicate
                # quantifier vector goes through the parametric NOT instead.
                if quantifiers[:, j - 1].numel() == 1:
                    if quantifiers[:, j - 1] == Quantifier.EXISTS:
                        log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                else:
                    log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)
                # Discounting the diagonal part (self-relations)
                # REVIEW: Only works for arity <= 2
                log_posterior[:, obj_diag_ind, obj_diag_ind] = 0
                s1 = log_posterior.size()
                # Move dim j to the front so the sparse mm can sum over it.
                log_posterior = log_posterior.transpose(0, j)
                s2 = list(log_posterior.size())
                log_posterior = log_posterior.contiguous().view(s1[j], -1)
                log_posterior = util.mm(batch_object_map, log_posterior)
                s2[0] = question_num
                log_posterior = log_posterior.view(s2).transpose(0, j)
                # Post-reduction quantifier handling: undoes the EXISTS
                # negation, yielding log(1 - prod(1 - p)).
                if quantifiers[:, j - 1].numel() == 1:
                    if quantifiers[:, j - 1] == Quantifier.EXISTS:
                        log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                else:
                    log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)
        # Finally multiply in the prior of the kept argument i.
        if self._trainable_gate:
            log_posterior = self._nlg[i - 1](log_posterior, log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1]))
        else:
            log_posterior = log_posterior + log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1])
        log_posterior = log_posterior.transpose(1, i).contiguous().view(predicate_num, object_num, -1)
        if question_num > 1 and self._arity > 1:
            # Keep each question's own block, then mask to its own objects.
            log_posterior = log_posterior[:, :, ind]
            mask = batch_object_map.transpose(0, 1).to_dense().unsqueeze(0)
            log_posterior = (log_posterior * mask).sum(dim=2)
        else:
            log_posterior = log_posterior.squeeze(2)
        result[:, i - 1, :] = log_posterior
    return result  # predicate_num x arity x object_num
def _compute_loss(self, program_batch_list, prediction):
    """Compute the training loss for a batch of programs.

    Dispatches on prediction['type'] (a QuestionType):
      STATEMENT        -> negative sum of log-probabilities (returned
                          directly, skipping the L1 penalty).
      BINARY           -> BCE between exp(log_prob) and yes/no targets.
      OBJECT_STATEMENT -> per-object BCE weighted by meta-data weights.
      QUERY            -> cross entropy over answer options, with the
                          per-question partition computed via a sparse
                          question-to-option map.
      SCENE_GRAPH      -> weighted BCE over attribute and relation grids.

    An optional L1 penalty (self._config['l1_lambda']) is added on top.

    Args:
        program_batch_list: batches exposing `_answers` and `_meta_data`.
        prediction: dict with 'type', 'log_probability' and, for QUERY
            questions, 'options'.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if prediction['type'] is not one of the handled kinds
            (previously this surfaced as an UnboundLocalError on `loss`).
    """
    # Words treated as an affirmative ("yes") answer.
    affirmative = ['yes', 'yeah', 'yep', 'yup', 'aye', 'yea']
    if prediction['type'] == QuestionType.STATEMENT:
        return -prediction['log_probability'].sum()
    if prediction['type'] == QuestionType.BINARY:
        target = []
        for program_batch in program_batch_list:
            target.append([a in affirmative for a in program_batch._answers])
        target = list(sum(target, []))
        norm = torch.ones(len(target), device=self._device)
        target = torch.tensor(target, dtype=torch.float32, device=self._device)
        loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'].exp(),
            target,
            weight=1.0 / norm,
            reduction='sum')
    elif prediction['type'] == QuestionType.OBJECT_STATEMENT:
        # Answers are nested per-object lists; weights come from meta data.
        target = []
        for program_batch in program_batch_list:
            target.append([
                a in affirmative
                for sublist in program_batch._answers for a in sublist
            ])
        target = list(sum(target, []))
        weights = torch.tensor(list(
            sum([
                program_batch._meta_data['weights']
                for program_batch in program_batch_list
            ], [])),
            device=self._device)
        target = torch.tensor(target, dtype=torch.float32, device=self._device)
        loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'].exp(),
            target,
            weight=weights,
            reduction='sum')
    elif prediction['type'] == QuestionType.QUERY:
        predicate_num = prediction['log_probability'].size()[0]
        all_answers = [
            program_batch._answers for program_batch in program_batch_list
        ]
        all_answers = list(sum(all_answers, []))
        # target[q][o] is True iff option o is the gold answer of question q.
        target = [[a == o for o in op]
                  for a, op in zip(all_answers, prediction['options'])]
        question_num = len(target)
        # Row (question) index of every option, flattened across questions.
        x_ind = torch.tensor(list(
            chain.from_iterable(
                repeat(a, len(b)) for a, b in enumerate(target))),
            dtype=torch.int64,
            device=self._device)
        y_ind = torch.arange(predicate_num, dtype=torch.int64, device=self._device)
        all_ones = torch.ones(predicate_num, device=self._device)
        index = torch.stack([x_ind, y_ind])
        # torch.sparse_coo_tensor replaces the deprecated, device-dependent
        # torch.sparse.FloatTensor / torch.cuda.sparse.FloatTensor pair and
        # handles device placement itself.
        question_predicate_map = torch.sparse_coo_tensor(
            index, all_ones, (question_num, predicate_num), device=self._device)
        target = list(sum(target, []))
        target = torch.tensor(target, dtype=torch.float32, device=self._device)
        score = prediction['log_probability']
        # Cross entropy: per-question log partition minus the gold score.
        loss = util.safe_log(
            util.mm(
                question_predicate_map,
                score.unsqueeze(1).exp())).sum() - (target * score).sum()
    elif prediction['type'] == QuestionType.SCENE_GRAPH:
        def _gather(key):
            # Concatenate one meta-data array across all program batches.
            return torch.tensor(np.concatenate(
                [pb._meta_data[key] for pb in program_batch_list], axis=0),
                dtype=torch.float32,
                device=self._device)

        attr_target = _gather('attribute_answer')
        attr_weight = _gather('attribute_weight')
        rel_target = _gather('relation_answer')
        rel_weight = _gather('relation_weight')
        attr_loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'][0].exp(),
            attr_target,
            weight=attr_weight,
            reduction='sum')
        rel_loss = nn.functional.binary_cross_entropy(
            prediction['log_probability'][1].exp(),
            rel_target,
            weight=rel_weight,
            reduction='sum')
        loss = attr_loss + rel_loss
    else:
        raise ValueError(
            'Unsupported prediction type: {}'.format(prediction['type']))
    if 'l1_lambda' in self._config and self._config['l1_lambda'] > 0:
        # Mean-normalized L1 penalty over all trainable parameters.
        all_params = torch.cat([
            x.view(-1)
            for x in filter(lambda p: p.requires_grad, self.model.parameters())
        ])
        loss += (self._config['l1_lambda'] * torch.norm(all_params, 1) /
                 max(1, all_params.numel()))
    return loss