Example #1
0
    def _compute_relation_log_likelihood(self, device, relation_list, pair_object_features, meta_data, object_num, relation_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
        """Return a dense (relation_num x object_num x object_num x 1) tensor of
        log-likelihood scores for the queried relations over all object pairs.

        Positions with no computed score are filled with `default_log_likelihood`.
        If `self._normalize` and `normalized_probability` both hold (and no
        'relation_pairobject_map' is supplied), scores are normalized per image
        cluster before being scattered into the dense result.

        NOTE(review): assumes `pair_object_features` is a dict with 'features'
        (first dim = pair count) and 'index' (rows used as image/subject/object
        indices) — confirm against callers.
        """
        pair_num = pair_object_features['features'].size()[0]
        relation_num = len(relation_list)

        # Map relation names to their (1-based) vocabulary indices.
        temp = itemgetter(*relation_list)(self._ontology._vocabulary['arg_to_idx'])
        if self._cached:
            # itemgetter returns a bare value for a single key; normalize to a list.
            if not isinstance(temp, (list, tuple)):
                temp = [temp]

            # Shift to 0-based, then remap through the cached relation index
            # (presumably column positions in the cached feature matrix — verify).
            temp = (np.array(temp, dtype=np.int64) - 1).tolist()
            temp = itemgetter(*temp)(self._ontology._relation_reveresed_index)
            ind = torch.tensor(temp, dtype=torch.int64, device=device)
        else:
            # Uncached path: features are scored on the fly, so plain 0-based
            # vocabulary indices select the relation columns.
            ind = torch.tensor(temp, dtype=torch.int64, device=device)
            ind = ind - 1

        # A single relation yields a 0-dim tensor; keep downstream indexing uniform.
        if ind.dim() == 0:
            ind = ind.unsqueeze(0)

        if 'relation_pairobject_map' in meta_data: # Only for direct-supervision, object-level training/testing
            relation_seq = list(range(relation_num))
            # One score per relation, taken at its mapped pair row.
            if self._cached:
                res = pair_object_features['features'][meta_data['relation_pairobject_map'], ind]
            else:
                res = self._embedding_network(self._relation_network(pair_object_features['features']))[meta_data['relation_pairobject_map'], ind]

            # Scatter per-relation scores into the dense (subject, object) grid.
            result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
            result[relation_seq, pair_object_features['index'][1], pair_object_features['index'][2], 0] = res
        else:
            # Pair every queried relation with every candidate pair in its image.
            _, ind1, ind2 = util.find_sparse_pair_indices(relation_image_map, pair_object_features['index'][0], device, exclude_self_relations=False)
            if self._cached:
                result = pair_object_features['features'][ind2, ind[ind1]]
            else:
                result = self._embedding_network(self._relation_network(pair_object_features['features']))[ind2, ind[ind1]]

            if self._normalize and normalized_probability:
                temp = default_log_likelihood * torch.ones(relation_num, pair_num, dtype=pair_object_features['features'].dtype, device=device)
                temp[ind1, ind2] = result

                # Normalize in log space within each image cluster:
                # log p - log(sum over the cluster of exp(log p)).
                cluster_map = self._build_map(relation_image_map)
                if cluster_map is not None:
                    denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, temp.exp())))
                    temp = temp - denom

                temp = temp.unsqueeze(2)
            else:
                temp = default_log_likelihood * torch.ones(relation_num, pair_num, 1, dtype=pair_object_features['features'].dtype, device=device)
                temp[ind1, ind2, 0] = result

            # Scatter per-pair scores into the dense (subject, object) grid.
            result = default_log_likelihood * torch.ones(relation_num, object_num, object_num, 1, dtype=pair_object_features['features'].dtype, device=device)
            result[:, pair_object_features['index'][1], pair_object_features['index'][2], :] = temp

        return result
    def apply_modulations(self,
                          modulations,
                          input_variable_set,
                          predicate_question_map=None):
        """Modulate this set's log-attention with per-predicate coefficients.

        Columns of `modulations` act as: 0 -> alpha, 1 -> beta, 2 -> c,
        3 -> d, 4 -> blend gate g (later columns are optional; missing ones
        fall back to constants). A `None` argument leaves the attention
        untouched. Returns self for chaining.
        """
        if modulations is None:
            return self

        max_activation = 10
        num_coeffs = modulations.size()[1]

        alpha = max_activation * modulations[:, 0].unsqueeze(1)
        beta = max_activation * modulations[:, 1].unsqueeze(1)

        if num_coeffs > 2:
            c = max_activation * modulations[:, 2].unsqueeze(1)
        else:
            c = torch.ones(1, device=self._device, dtype=modulations.dtype)

        if num_coeffs > 3:
            d = modulations[:, 3].unsqueeze(1)
        else:
            d = 0.5 * torch.ones(1, device=self._device, dtype=modulations.dtype)

        # Soft gate in log space: the new attention is the log-ratio of the
        # "on" term against the sum of the "off" and "on" terms.
        on_term = alpha * self._log_attention + util.safe_log(c) + util.safe_log(d)
        off_term = (beta * util.log_not(self._log_attention) +
                    util.safe_log(1.0 - d)).exp()
        self._log_attention = on_term - util.safe_log(off_term + on_term.exp())

        if num_coeffs > 4:
            # Blend with the input variable set's attention, optionally routed
            # through the predicate-to-question map.
            g = modulations[:, 4].unsqueeze(1)
            own_part = g * self._log_attention.exp()
            if predicate_question_map is None:
                other_part = (1.0 - g) * input_variable_set._log_attention.exp()
            else:
                other_part = (1.0 - g) * util.mm(
                    predicate_question_map,
                    input_variable_set._log_attention.exp())
            self._log_attention = util.safe_log(own_part + other_part)

        return self
Example #3
0
 def _compute_attribute_log_likelihood(self,
                                       device,
                                       attribute_list,
                                       object_features,
                                       meta_data,
                                       object_num,
                                       attribute_image_map,
                                       object_image_map,
                                       default_log_likelihood=-30):
     """Mock implementation: draw a uniform random likelihood for every
     (attribute, object) slot and return the (safe) log of the scores."""
     slot_count = len(attribute_list) * self._object_num
     scores = torch.rand(slot_count, device=self._device)
     # Uncomment to inspect the sampled grid:
     # print("\nAttribute likelihood:")
     # print(scores.view(len(attribute_list), -1))
     return util.safe_log(scores)
Example #4
0
 def _compute_relation_log_likelihood(self,
                                      device,
                                      relation_list,
                                      pair_object_features,
                                      meta_data,
                                      object_num,
                                      relation_image_map,
                                      object_image_map,
                                      default_log_likelihood=-30):
     """Mock implementation: draw a uniform random likelihood for every
     (relation, subject, object) slot and return the (safe) log of the scores."""
     slot_count = len(relation_list) * (self._object_num ** 2)
     scores = torch.rand(slot_count, device=self._device)
     # Uncomment to inspect the sampled grid:
     # print("\nRelation likelihood:")
     # print(scores.view(len(relation_list), self._object_num, self._object_num))
     return util.safe_log(scores)
Example #5
0
    def _compute_attribute_log_likelihood(self, device, attribute_list, object_features, meta_data, object_num, attribute_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
        """Return a dense (attribute_num x object_num x 1) tensor of
        log-likelihood scores for the queried attributes over all objects.

        Positions with no computed score are filled with `default_log_likelihood`.
        If `self._normalize` and `normalized_probability` both hold, scores are
        normalized per image cluster before being returned.

        NOTE(review): the `object_num` parameter is immediately shadowed by
        `object_features.size()[0]` — confirm callers pass a consistent value.
        """
        object_num = object_features.size()[0]
        attribute_num = len(attribute_list)
        
        # Map attribute names to their (1-based) vocabulary indices.
        temp = itemgetter(*attribute_list)(self._ontology._vocabulary['arg_to_idx'])
        # if self._cached:
        #     temp = np.array(temp, dtype=np.int64) - 1
        #     temp = np.array(itemgetter(*temp.tolist())(self._ontology._attribute_reveresed_index), dtype=np.int64)
        #     ind = torch.tensor(temp, dtype=torch.int64, device=device)
        # else:
        ind = torch.tensor(temp, dtype=torch.int64, device=device)
        ind = ind - 1  # shift to 0-based column indices
        
        # A single attribute yields a 0-dim tensor; keep indexing uniform.
        if ind.dim() == 0:
            ind = ind.unsqueeze(0)

        # Pair every queried attribute with every object in its image.
        _, ind1, ind2 = util.find_sparse_pair_indices(attribute_image_map, object_image_map, device, exclude_self_relations=False)
        if self._cached:
            output_features = object_features[ind2, ind[ind1]]
        else:
            output_features = self._embedding_network(self._attribute_network(object_features))[ind2, ind[ind1]]
        
        # Reshape into the dense version
        if self._normalize and normalized_probability:
            result = default_log_likelihood * torch.ones(attribute_num, object_num, dtype=object_features.dtype, device=device)
            result[ind1, ind2] = output_features

            # Normalize in log space within each image cluster:
            # log p - log(sum over the cluster of exp(log p)).
            cluster_map = self._build_map(attribute_image_map)
            if cluster_map is not None:
                denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, result.exp())))
                result = result - denom

            result = result.unsqueeze(2)
        else:
            result = default_log_likelihood * torch.ones(attribute_num, object_num, 1, dtype=object_features.dtype, device=device)
            result[ind1, ind2, 0] = output_features

        return result
Example #6
0
    def _forward_core(self, log_prior, log_likelihood, quantifiers, dim_order, batch_object_map, predicate_question_map):
        """Compute per-dimension marginal log scores over the predicate grid.

        Args (shapes per the original inline comments):
            log_prior: (question_num x arity x object_num) log attention priors.
            log_likelihood: (predicate_num x object_num x ... x object_num).
            quantifiers: (predicate_num x arity).
            dim_order: permutation of the arity dimensions to process.
            batch_object_map: sparse (question_num x object_num) map used to
                aggregate objects per question.
            predicate_question_map: optional sparse map, required when
                predicate_num != question_num.

        Returns:
            (predicate_num x arity x object_num) tensor of marginal log scores.
        """
        object_num = log_prior.size()[2]
        predicate_num = log_likelihood.size()[0]
        question_num = log_prior.size()[0]

        # Broadcast the per-question priors onto predicates when they are not 1:1.
        if predicate_question_map is not None and predicate_num != question_num:
            log_p = util.mm(predicate_question_map, log_prior.view(question_num, -1)).view(predicate_num, self._arity, -1)
        else:
            log_p = log_prior

        result = torch.zeros(predicate_num, self._arity, object_num, device=self._device, dtype=log_prior.dtype)

        if question_num > 1 and self._arity > 1:
            # Indices of the diagonals of the (question_num x ... x question_num)
            # block structure after flattening the trailing question dims.
            coeff = (question_num ** (self._arity - 1) - question_num) / (question_num - 1)
            ind = torch.arange(question_num, dtype=torch.int64, device=self._device)
            ind += (ind * coeff).long()

        # FIX: obj_diag_ind is consumed below in every i != j iteration, which
        # happens whenever arity > 1 regardless of object_num. The previous
        # guard (`object_num > 1 and self._arity > 1`) left it undefined for
        # object_num == 1 and raised NameError.
        if self._arity > 1:
            obj_diag_ind = np.arange(object_num, dtype=np.int64).tolist()

        for a in range(self._arity):
            i = dim_order[a] + 1
            log_posterior = log_likelihood

            # Fold the prior and quantifier into every dimension except i, then
            # aggregate that dimension per question through the sparse map.
            for b in range(self._arity):
                j = dim_order[b] + 1

                if i != j:
                    # Multiply by the prior for dimension j (in log space).
                    if self._trainable_gate:
                        log_posterior = self._nlg[j - 1](log_posterior, log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1]))
                    else:
                        log_posterior = log_posterior + log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1])

                    # Apply the quantifier: EXISTS maps log p -> log(1 - p).
                    if quantifiers[:, j - 1].numel() == 1:
                        if quantifiers[:, j - 1] == Quantifier.EXISTS:
                            log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                    else:
                        log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)

                    # Discounting the diagonal part (self-relations)
                    # REVIEW: Only works for arity <= 2
                    log_posterior[:, obj_diag_ind, obj_diag_ind] = 0

                    # Sum out dimension j per question via the sparse map:
                    # move j to the front, flatten, multiply, restore layout.
                    s1 = log_posterior.size()
                    log_posterior = log_posterior.transpose(0, j)
                    s2 = list(log_posterior.size())

                    log_posterior = log_posterior.contiguous().view(s1[j], -1)
                    log_posterior = util.mm(batch_object_map, log_posterior)
                    s2[0] = question_num
                    log_posterior = log_posterior.view(s2).transpose(0, j)

                    # Re-apply the quantifier after aggregation.
                    if quantifiers[:, j - 1].numel() == 1:
                        if quantifiers[:, j - 1] == Quantifier.EXISTS:
                            log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                    else:
                        log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)

            # Fold in the prior for the target dimension i itself.
            if self._trainable_gate:
                log_posterior = self._nlg[i - 1](log_posterior, log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1]))
            else:
                log_posterior = log_posterior + log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1])

            log_posterior = log_posterior.transpose(1, i).contiguous().view(predicate_num, object_num, -1)

            if question_num > 1 and self._arity > 1:
                # Keep only each question's own diagonal block, masked by the
                # per-question object assignment.
                log_posterior = log_posterior[:, :, ind]
                mask = batch_object_map.transpose(0, 1).to_dense().unsqueeze(0)
                log_posterior = (log_posterior * mask).sum(dim=2)
            else:
                log_posterior = log_posterior.squeeze(2)

            result[:, i - 1, :] = log_posterior

        return result # predicate_num x arity x object_num
Example #7
0
    def _compute_loss(self, program_batch_list, prediction):
        """Compute the batch training loss, dispatching on prediction['type'].

        Branches:
            STATEMENT: negative sum of log-probabilities (returned directly,
                without the optional L1 term).
            BINARY: BCE between predicted probabilities and yes/no answers.
            OBJECT_STATEMENT: per-object BCE weighted by meta-data weights.
            QUERY: cross-entropy over each question's answer options via a
                sparse question-to-predicate map.
            SCENE_GRAPH: weighted BCE over attribute and relation targets.

        Raises:
            ValueError: for an unrecognized question type (previously this
                fell through and crashed with NameError on the unbound `loss`).

        When self._config['l1_lambda'] > 0, an L1 penalty over all trainable
        parameters (normalized by their count) is added.
        """
        if prediction['type'] == QuestionType.STATEMENT:
            return -prediction['log_probability'].sum()

        # Words accepted as an affirmative answer (shared by both BCE branches).
        yes_words = ['yes', 'yeah', 'yep', 'yup', 'aye', 'yea']

        if prediction['type'] == QuestionType.BINARY:
            target = []
            for program_batch in program_batch_list:
                target.append([a in yes_words for a in program_batch._answers])

            target = list(sum(target, []))
            # NOTE(review): norm is all ones, so weight=1.0/norm is currently a
            # no-op — kept for parity with the weighted variants below.
            norm = torch.ones(len(target), device=self._device)

            target = torch.tensor(target,
                                  dtype=torch.float32,
                                  device=self._device)
            loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'].exp(),
                target,
                weight=1.0 / norm,
                reduction='sum')

        elif prediction['type'] == QuestionType.OBJECT_STATEMENT:
            target = []
            for program_batch in program_batch_list:
                target.append([
                    a in yes_words
                    for sublist in program_batch._answers for a in sublist
                ])

            target = list(sum(target, []))
            # Per-object loss weights supplied by the batch meta-data.
            weights = torch.tensor(list(
                sum([
                    program_batch._meta_data['weights']
                    for program_batch in program_batch_list
                ], [])),
                                   device=self._device)

            target = torch.tensor(target,
                                  dtype=torch.float32,
                                  device=self._device)
            loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'].exp(),
                target,
                weight=weights,
                reduction='sum')

        elif prediction['type'] == QuestionType.QUERY:
            predicate_num = prediction['log_probability'].size()[0]
            all_answers = [
                program_batch._answers for program_batch in program_batch_list
            ]
            all_answers = list(sum(all_answers, []))

            # One-hot targets over each question's candidate options.
            target = [[a == o for o in op]
                      for a, op in zip(all_answers, prediction['options'])]
            question_num = len(target)
            # Question (row) index for every flattened option entry.
            x_ind = torch.tensor(list(
                chain.from_iterable(
                    repeat(a, len(b)) for a, b in enumerate(target))),
                                 dtype=torch.int64,
                                 device=self._device)
            y_ind = torch.arange(predicate_num,
                                 dtype=torch.int64,
                                 device=self._device)
            all_ones = torch.ones(predicate_num, device=self._device)
            index = torch.stack([x_ind, y_ind])

            # Sparse (question_num x predicate_num) indicator map.
            # NOTE(review): torch.sparse.FloatTensor is a legacy constructor;
            # torch.sparse_coo_tensor is the modern equivalent.
            if self._device.type == 'cuda':
                question_predicate_map = torch.cuda.sparse.FloatTensor(
                    index,
                    all_ones,
                    torch.Size([question_num, predicate_num]),
                    device=self._device)
            else:
                question_predicate_map = torch.sparse.FloatTensor(
                    index,
                    all_ones,
                    torch.Size([question_num, predicate_num]),
                    device=self._device)

            target = list(sum(target, []))
            target = torch.tensor(target,
                                  dtype=torch.float32,
                                  device=self._device)

            # Softmax-style cross-entropy per question:
            # log(sum_j exp(score_j)) - score of the correct option.
            score = prediction['log_probability']
            loss = util.safe_log(
                util.mm(
                    question_predicate_map,
                    score.unsqueeze(1).exp())).sum() - (target * score).sum()

        elif prediction['type'] == QuestionType.SCENE_GRAPH:
            # Stack per-batch attribute targets/weights into single tensors.
            attr_target = torch.tensor(np.concatenate([
                pb._meta_data['attribute_answer'] for pb in program_batch_list
            ],
                                                      axis=0),
                                       dtype=torch.float32,
                                       device=self._device)
            attr_weight = torch.tensor(np.concatenate([
                pb._meta_data['attribute_weight'] for pb in program_batch_list
            ],
                                                      axis=0),
                                       dtype=torch.float32,
                                       device=self._device)

            # Same for relation targets/weights.
            rel_target = torch.tensor(np.concatenate([
                pb._meta_data['relation_answer'] for pb in program_batch_list
            ],
                                                     axis=0),
                                      dtype=torch.float32,
                                      device=self._device)
            rel_weight = torch.tensor(np.concatenate([
                pb._meta_data['relation_weight'] for pb in program_batch_list
            ],
                                                     axis=0),
                                      dtype=torch.float32,
                                      device=self._device)

            # log_probability[0] holds attribute scores, [1] relation scores.
            attr_loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'][0].exp(),
                attr_target,
                weight=attr_weight,
                reduction='sum')
            rel_loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'][1].exp(),
                rel_target,
                weight=rel_weight,
                reduction='sum')

            loss = attr_loss + rel_loss

        else:
            # FIX: previously an unknown type fell through and `loss` was
            # unbound, producing a confusing NameError below.
            raise ValueError(
                'Unsupported question type: {}'.format(prediction['type']))

        # Optional L1 regularization over all trainable parameters.
        if 'l1_lambda' in self._config and self._config['l1_lambda'] > 0:
            all_params = torch.cat([
                x.view(-1) for x in filter(lambda p: p.requires_grad,
                                           self.model.parameters())
            ])
            loss += (self._config['l1_lambda'] * torch.norm(all_params, 1) /
                     max(1, all_params.numel()))

        return loss