Пример #1
0
    def log_probability(self, hard_mode=False):
        """Compute a log-probability tensor from the stored attention state.

        Both modes wrap the aggregation in two ``util.log_parametric_not``
        calls driven by ``self._quantifier``; they differ in how the scores
        are combined through ``self._batch_object_map``:

        * ``hard_mode=True``: the scores are multiplied element-wise by the
          dense map and reduced with a minimum over dim 1.
        * ``hard_mode=False``: the scores are aggregated with ``util.mm`` and
          only the diagonal of the product is kept.

        When ``self._predicate_question_map`` is set it is additionally
        applied through ``util.mm``.

        Args:
            hard_mode: selects the min-based reduction instead of the
                matrix-product one.

        Returns:
            The tensor produced by the final ``util.log_parametric_not`` call.
        """
        question_map = self._predicate_question_map
        object_map = self._batch_object_map

        if not hard_mode:
            # Soft path: transpose the attention, aggregate, keep the diagonal.
            scores = self._log_attention.transpose(0, 1).contiguous()
            scores = util.log_parametric_not(
                scores, self._quantifier.unsqueeze(0), 1)

            pooled = util.mm(object_map, scores)
            if question_map is not None:
                pooled = util.mm(question_map, pooled)

            return util.log_parametric_not(pooled.diag(), self._quantifier, 1)

        # Hard path: mask with the dense map and take the minimum over dim 1.
        scores = util.log_parametric_not(
            self._log_attention, self._quantifier.unsqueeze(1), 1)

        dense_map = object_map.to_dense()
        if question_map is not None:
            dense_map = util.mm(question_map, dense_map)

        lowest, _ = (dense_map * scores).min(1)
        return util.log_parametric_not(lowest, self._quantifier, 1)
Пример #2
0
    def squeeze(self, predicate_question_map):
        """Return a new state projected through the transposed map.

        Each of the two tensors in ``self._state`` is left-multiplied (via
        ``util.mm``) by ``predicate_question_map.transpose(0, 1)``.

        Args:
            predicate_question_map: matrix whose transpose is applied to
                both state tensors.

        Returns:
            A new ``BatchAttentionState`` with the same name and device,
            converted to ``self.dtype``.
        """
        projected = tuple(
            util.mm(predicate_question_map.transpose(0, 1), self._state[k])
            for k in (0, 1))

        squeezed = BatchAttentionState(self._name, self._device, projected)
        return squeezed.to(self.dtype)
Пример #3
0
    def _compute_relation_log_likelihood(self, device, relation_list, pair_object_features, meta_data, object_num, relation_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
        """Build a dense relation log-likelihood tensor for object pairs.

        Relation names are mapped to column indices through
        ``self._ontology._vocabulary['arg_to_idx']`` (shifted by -1, and, in
        cached mode, remapped through ``self._ontology._relation_reveresed_index``).
        Scores are read either directly from ``pair_object_features['features']``
        (cached mode) or from the relation and embedding networks, and are
        scattered into a tensor pre-filled with ``default_log_likelihood``.

        Args:
            device: device on which index and result tensors are created.
            relation_list: relation names to score.
            pair_object_features: mapping with a 'features' tensor and an
                'index' triple used to place pairs in the result.
            meta_data: mapping; the presence of 'relation_pairobject_map'
                selects the direct-supervision path.
            object_num: size of the two object axes of the result.
            relation_image_map: passed to ``util.find_sparse_pair_indices``
                and ``self._build_map``.
            object_image_map: not used by this method.
            default_log_likelihood: fill value for entries that get no score.
            normalized_probability: together with ``self._normalize``, enables
                the normalisation by ``self._build_map``'s cluster map.

        Returns:
            Tensor of size (len(relation_list), object_num, object_num, 1).
        """
        features = pair_object_features['features']
        pair_num = features.size()[0]
        relation_num = len(relation_list)

        def filled(*shape):
            # Tensor of the requested shape holding default_log_likelihood.
            return default_log_likelihood * torch.ones(*shape, dtype=features.dtype, device=device)

        vocab_ids = itemgetter(*relation_list)(self._ontology._vocabulary['arg_to_idx'])
        if self._cached:
            # itemgetter returns a bare value for a single relation.
            if not isinstance(vocab_ids, (list, tuple)):
                vocab_ids = [vocab_ids]

            zero_based = (np.array(vocab_ids, dtype=np.int64) - 1).tolist()
            cached_columns = itemgetter(*zero_based)(self._ontology._relation_reveresed_index)
            ind = torch.tensor(cached_columns, dtype=torch.int64, device=device)
        else:
            ind = torch.tensor(vocab_ids, dtype=torch.int64, device=device) - 1

        if ind.dim() == 0:
            ind = ind.unsqueeze(0)

        if 'relation_pairobject_map' in meta_data:
            # Only for direct-supervision, object-level training/testing
            pair_map = meta_data['relation_pairobject_map']
            if self._cached:
                scores = features[pair_map, ind]
            else:
                scores = self._embedding_network(self._relation_network(features))[pair_map, ind]

            result = filled(relation_num, object_num, object_num, 1)
            result[list(range(relation_num)), pair_object_features['index'][1], pair_object_features['index'][2], 0] = scores
            return result

        _, rel_ids, pair_ids = util.find_sparse_pair_indices(relation_image_map, pair_object_features['index'][0], device, exclude_self_relations=False)
        if self._cached:
            scores = features[pair_ids, ind[rel_ids]]
        else:
            scores = self._embedding_network(self._relation_network(features))[pair_ids, ind[rel_ids]]

        if self._normalize and normalized_probability:
            per_pair = filled(relation_num, pair_num)
            per_pair[rel_ids, pair_ids] = scores

            cluster_map = self._build_map(relation_image_map)
            if cluster_map is not None:
                denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, per_pair.exp())))
                per_pair = per_pair - denom

            per_pair = per_pair.unsqueeze(2)
        else:
            per_pair = filled(relation_num, pair_num, 1)
            per_pair[rel_ids, pair_ids, 0] = scores

        # Scatter the per-pair scores onto the (object, object) grid.
        result = filled(relation_num, object_num, object_num, 1)
        result[:, pair_object_features['index'][1], pair_object_features['index'][2], :] = per_pair

        return result
Пример #4
0
    def apply_modulations(self,
                          modulations,
                          input_variable_set,
                          predicate_question_map=None):
        """Re-weight ``self._log_attention`` in place using modulation columns.

        Nothing happens when ``modulations`` is None. Otherwise the columns
        of ``modulations`` are read as follows (columns 0 to 2 are scaled by
        a fixed factor of 10):

        * column 0 (``alpha``) and column 1 (``beta``): multipliers on the
          log-attention and on ``util.log_not`` of it;
        * column 2 (``c``): optional, defaults to a tensor of ones;
        * column 3 (``d``): optional, defaults to 0.5;
        * column 4 (``g``): optional; when present the result is mixed with
          ``input_variable_set._log_attention`` (mapped through
          ``predicate_question_map`` if one is given).

        Args:
            modulations: 2-D tensor indexed as ``[:, k]``, or None.
            input_variable_set: object providing ``_log_attention`` for the
                optional mixing step.
            predicate_question_map: optional matrix applied via ``util.mm``
                to the input attention before mixing.

        Returns:
            ``self``.
        """
        if modulations is None:
            return self

        max_activation = 10
        column_num = modulations.size()[1]

        alpha = modulations[:, 0].unsqueeze(1) * max_activation
        beta = modulations[:, 1].unsqueeze(1) * max_activation

        if column_num > 2:
            c = modulations[:, 2].unsqueeze(1) * max_activation
        else:
            c = torch.ones(1, device=self._device, dtype=modulations.dtype)

        if column_num > 3:
            d = modulations[:, 3].unsqueeze(1)
        else:
            d = 0.5 * torch.ones(
                1, device=self._device, dtype=modulations.dtype)

        positive = (alpha * self._log_attention + util.safe_log(c) +
                    util.safe_log(d))
        negative = (beta * util.log_not(self._log_attention) +
                    util.safe_log(1.0 - d))
        self._log_attention = positive - util.safe_log(
            negative.exp() + positive.exp())

        if column_num > 4:
            g = modulations[:, 4].unsqueeze(1)
            carried = input_variable_set._log_attention.exp()
            if predicate_question_map is not None:
                carried = util.mm(predicate_question_map, carried)

            self._log_attention = util.safe_log(
                g * self._log_attention.exp() + (1.0 - g) * carried)

        return self
Пример #5
0
    def compute_total_log_probability(self, log_posterior, quantifier, batch_object_map, predicate_question_map):
        """Aggregate a per-object log-posterior into one value per predicate and argument.

        The posterior is passed through ``util.log_parametric_not`` with the
        quantifier, aggregated over objects with ``batch_object_map`` (and
        ``predicate_question_map`` when given), reduced to the entries where
        the first and last axes coincide, and passed through
        ``util.log_parametric_not`` once more.

        Args:
            log_posterior: tensor of size (predicate_num, arity, object_num);
                a 2-D input gets a leading axis added.
            quantifier: tensor of size (predicate_num, arity); a 1-D input
                gets a leading axis added.
            batch_object_map: matrix applied via ``util.mm`` to the flattened
                (object_num, -1) posterior.
            predicate_question_map: optional matrix applied via ``util.mm``
                after ``batch_object_map``.

        Returns:
            Tensor of size (predicate_num, arity).
        """
        if log_posterior.dim() < 3:
            log_posterior = log_posterior.unsqueeze(0)

        if quantifier.dim() < 2:
            quantifier = quantifier.unsqueeze(0)

        predicate_num, arity, object_num = log_posterior.size()
        log_posterior = util.log_parametric_not(log_posterior, quantifier.unsqueeze(2), 1)

        # Put objects on the leading axis so the maps can be applied as matrix products.
        flat = log_posterior.transpose(0, 2).contiguous().view(object_num, -1)
        log_posterior = util.mm(batch_object_map, flat)
        if predicate_question_map is not None:
            log_posterior = util.mm(predicate_question_map, log_posterior)

        # Keep only the entries where the row index matches the predicate index.
        ind = torch.arange(predicate_num, dtype=torch.int64, device=log_posterior.device)
        log_posterior = log_posterior.view(predicate_num, arity, predicate_num)[ind, :, ind] # predicate_num x arity

        log_posterior = util.log_parametric_not(log_posterior, quantifier, 1)

        # The original ended with a bare `return`, discarding the result.
        return log_posterior # predicate_num x arity
Пример #6
0
    def _compute_attribute_log_likelihood(self, device, attribute_list, object_features, meta_data, object_num, attribute_image_map, \
        object_image_map, default_log_likelihood=-30, normalized_probability=True):
        """Build a dense attribute log-likelihood tensor for all objects.

        Attribute names are mapped to column indices through
        ``self._ontology._vocabulary['arg_to_idx']`` (shifted by -1). Scores
        are read directly from ``object_features`` in cached mode, or from the
        attribute and embedding networks otherwise, and scattered into a
        tensor pre-filled with ``default_log_likelihood``.

        Args:
            device: device on which index and result tensors are created.
            attribute_list: attribute names to score.
            object_features: per-object feature tensor; its first dimension
                defines the object count.
            meta_data: not used by this method.
            object_num: ignored; overwritten from ``object_features``.
            attribute_image_map: passed to ``util.find_sparse_pair_indices``
                and ``self._build_map``.
            object_image_map: passed to ``util.find_sparse_pair_indices``.
            default_log_likelihood: fill value for entries that get no score.
            normalized_probability: together with ``self._normalize``, enables
                the normalisation by ``self._build_map``'s cluster map.

        Returns:
            Tensor of size (len(attribute_list), object_features.size()[0], 1).
        """
        # The object count always comes from the features, not the argument.
        object_num = object_features.size()[0]
        attribute_num = len(attribute_list)
        fill_options = dict(dtype=object_features.dtype, device=device)

        vocab_ids = itemgetter(*attribute_list)(self._ontology._vocabulary['arg_to_idx'])
        ind = torch.tensor(vocab_ids, dtype=torch.int64, device=device) - 1
        if ind.dim() == 0:
            ind = ind.unsqueeze(0)

        _, attr_ids, obj_ids = util.find_sparse_pair_indices(attribute_image_map, object_image_map, device, exclude_self_relations=False)
        if self._cached:
            scores = object_features[obj_ids, ind[attr_ids]]
        else:
            scores = self._embedding_network(self._attribute_network(object_features))[obj_ids, ind[attr_ids]]

        if not (self._normalize and normalized_probability):
            # Unnormalised: scatter straight into the 3-D result.
            result = default_log_likelihood * torch.ones(attribute_num, object_num, 1, **fill_options)
            result[attr_ids, obj_ids, 0] = scores
            return result

        # Normalised: scatter into 2-D, subtract the cluster-wise log-sum, then add the trailing axis.
        result = default_log_likelihood * torch.ones(attribute_num, object_num, **fill_options)
        result[attr_ids, obj_ids] = scores

        cluster_map = self._build_map(attribute_image_map)
        if cluster_map is not None:
            denom = util.mm(cluster_map.transpose(0, 1), util.safe_log(util.mm(cluster_map, result.exp())))
            result = result - denom

        return result.unsqueeze(2)
Пример #7
0
    def _forward_core(self, log_prior, log_likelihood, quantifiers, dim_order, batch_object_map, predicate_question_map):
        """Compute the per-argument log-posterior attention for a batch of predicates.

        For every argument position the likelihood is combined with the priors
        of all *other* positions (each of which is then aggregated away with
        ``batch_object_map`` between two quantifier-dependent negations), and
        finally with the prior of the position itself.

        Args:
            log_prior: (question_num x arity x object_num).
            log_likelihood: (predicate_num x object_num x ... x object_num).
            quantifiers: (predicate_num x arity).
            dim_order: sequence mapping loop position to argument index.
            batch_object_map: matrix applied via ``util.mm`` to aggregate
                objects; also densified to build a mask.
            predicate_question_map: optional matrix used to expand the prior
                when predicate_num != question_num. In most cases
                question_num == predicate_num; otherwise this must be given.

        Returns:
            Tensor of size (predicate_num x arity x object_num).
        """
        object_num = log_prior.size()[2]
        # batch_size = batch_object_map.size()[0]
        predicate_num = log_likelihood.size()[0]
        question_num = log_prior.size()[0]

        if predicate_question_map is not None and predicate_num != question_num:
            log_p = util.mm(predicate_question_map, log_prior.view(question_num, -1)).view(predicate_num, self._arity, -1)
        else:
            log_p = log_prior

        result = torch.zeros(predicate_num, self._arity, object_num, device=self._device, dtype=log_prior.dtype)

        if question_num > 1 and self._arity > 1:
            # Compute the indices of the diagonals of (question_num x ... x question_num)
            coeff = (question_num ** (self._arity - 1) - question_num) / (question_num - 1)
            ind = torch.arange(question_num, dtype=torch.int64, device=self._device)
            ind += (ind * coeff).long()

        if self._arity > 1:
            # Needed whenever the inner loop hits i != j, i.e. for any arity > 1.
            # The original also required object_num > 1, which left this name
            # undefined (UnboundLocalError) for single-object batches.
            obj_diag_ind = np.arange(object_num, dtype=np.int64).tolist()

        for a in range(self._arity):
            i = dim_order[a] + 1
            log_posterior = log_likelihood

            for b in range(self._arity):
                j = dim_order[b] + 1

                if i != j:
                    # Multiply the prior
                    if self._trainable_gate:
                        log_posterior = self._nlg[j - 1](log_posterior, log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1]))
                    else:
                        log_posterior = log_posterior + log_p[:, j - 1, :].view([predicate_num] + self._reshape_dim[j - 1])

                    if quantifiers[:, j - 1].numel() == 1:
                        if quantifiers[:, j - 1] == Quantifier.EXISTS:
                            log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                    else:
                        log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)

                    # Discounting the diagonal part (self-relations)
                    # REVIEW: Only works for arity <= 2
                    log_posterior[:, obj_diag_ind, obj_diag_ind] = 0

                    # Move axis j to the front so the objects can be aggregated
                    # with a single matrix product, then move it back.
                    s1 = log_posterior.size()
                    log_posterior = log_posterior.transpose(0, j)
                    s2 = list(log_posterior.size())

                    log_posterior = log_posterior.contiguous().view(s1[j], -1)
                    log_posterior = util.mm(batch_object_map, log_posterior)
                    s2[0] = question_num
                    log_posterior = log_posterior.view(s2).transpose(0, j)

                    if quantifiers[:, j - 1].numel() == 1:
                        if quantifiers[:, j - 1] == Quantifier.EXISTS:
                            log_posterior = util.safe_log(1.0 - util.safe_exp(log_posterior))
                    else:
                        log_posterior = util.log_parametric_not(log_posterior, quantifiers[:, j - 1].view([-1] + self._arity*[1]), 1)

            # Finally combine with the prior of the argument being updated.
            if self._trainable_gate:
                log_posterior = self._nlg[i - 1](log_posterior, log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1]))
            else:
                log_posterior = log_posterior + log_p[:, i - 1, :].view([predicate_num] + self._reshape_dim[i - 1])

            log_posterior = log_posterior.transpose(1, i).contiguous().view(predicate_num, object_num, -1)

            if question_num > 1 and self._arity > 1:
                # Keep the diagonal entries and collapse them with the dense batch/object mask.
                log_posterior = log_posterior[:, :, ind]
                mask = batch_object_map.transpose(0, 1).to_dense().unsqueeze(0)
                log_posterior = (log_posterior * mask).sum(dim=2)
            else:
                log_posterior = log_posterior.squeeze(2)

            result[:, i - 1, :] = log_posterior

        return result # predicate_num x arity x object_num
Пример #8
0
    def forward(self, op_id, world, subject_variable_set, object_variable_set, relation_list, predicate_question_map=None, default_log_likelihood=-30, normalized_probability=True):
        """Apply a batch of relation predicates to a (subject, object) pair of variable sets.

        Entries of ``relation_list`` that are None, empty or '_' (after
        ``strip``) are treated as blanks. If every entry is blank the two
        input sets are returned unchanged. Otherwise the oracle is queried for
        the non-blank relations only, blank rows are filled with
        ``default_log_likelihood``, and after the update the attention of the
        blank rows is copied back from the input sets.

        Args:
            op_id: key used to look up (and then remove) pending modulations
                in ``self._subject_modulations`` / ``self._object_modulations``.
            world: object providing ``batch_size()``; also passed to the oracle.
            subject_variable_set: variable set for the first argument.
            object_variable_set: variable set for the second argument.
            relation_list: relation name or list of relation names.
            predicate_question_map: None, a sparse (predicate_num x
                question_num) map, or a list/tuple of question indices from
                which such a map is built.
            default_log_likelihood: fill value for blank predicates; also
                forwarded to the oracle and to ``self._blc``.
            normalized_probability: forwarded to the oracle.

        Returns:
            A ``(new_subject_set, new_object_set)`` pair of ``BatchVariableSet``
            objects, or the two inputs unchanged when all relations are blank.
        """
        assert subject_variable_set.batch_size() == object_variable_set.batch_size(), "The subject and object variable sets must have the same batch size."
        assert subject_variable_set.object_num() == object_variable_set.object_num(), "The subject and object variable sets must have the same number of objects."

        if not isinstance(relation_list, list):
            relation_list = [relation_list]
        
        # ind[k] is True when the k-th relation is a real (non-blank) predicate.
        ind = [val is not None and val.strip() not in ['', '_'] for val in relation_list]
        if not any(ind):
            return subject_variable_set, object_variable_set

        device = subject_variable_set.device

        question_num = world.batch_size()
        predicate_num = len(relation_list)

        # A list/tuple map gives, per predicate, the index of its question;
        # turn it into a sparse (predicate_num x question_num) indicator matrix.
        if predicate_question_map is not None and isinstance(predicate_question_map, (list, tuple)):
            all_ones = torch.ones(predicate_num, device=device, dtype=subject_variable_set.dtype)
            x_ind = torch.arange(predicate_num, dtype=torch.int64, device=device)
            y_ind = torch.tensor(predicate_question_map, dtype=torch.int64, device=device)
            index = torch.stack([x_ind, y_ind])
            
            if util.is_cuda(device):
                predicate_question_map = torch.cuda.sparse.FloatTensor(index, all_ones, 
                                torch.Size([predicate_num, question_num]), device=device)
            else:
                predicate_question_map = torch.sparse.FloatTensor(index, all_ones, 
                                torch.Size([predicate_num, question_num]), device=device)

        assert (question_num == predicate_num) or \
            (predicate_question_map is not None and predicate_question_map.size() == torch.Size([predicate_num, question_num])),\
            "Batch size mismatch."

        # Stack subject/object along a new axis 1, giving one row per argument position.
        log_attention = torch.stack([subject_variable_set._log_attention, object_variable_set._log_attention]).transpose(0, 1).contiguous()
        dim_order = [0, 1]
        quantifier = torch.stack([subject_variable_set._quantifier, object_variable_set._quantifier]).transpose(0, 1).contiguous()

        if predicate_question_map is not None:
            quantifier = util.mm(predicate_question_map, quantifier)

        # Drop the blank entries before asking the oracle.
        if not all(ind):
            relation_list = list(compress(relation_list, ind))

        is_any_negated, is_neg, r_list = util.detect_negations(relation_list, device)
        is_neg = torch.tensor(is_neg, dtype=subject_variable_set.dtype, device=device)
        
        # Question index of every non-blank predicate.
        if predicate_question_map is not None:
            relation_image_map = predicate_question_map._indices()[1, ind]
        else:
            relation_image_map = torch.arange(predicate_num, dtype=torch.int64, device=device)[ind]

        log_likelihood = self._oracle(TokenType.RELATION, r_list, relation_image_map, world, default_log_likelihood=default_log_likelihood, normalized_probability=normalized_probability)

        if isinstance(log_likelihood, torch.Tensor):
            if log_likelihood.dim() < 4:
                log_likelihood = log_likelihood.unsqueeze(0)

        if not all(ind):
            # Some predicates are blank: expand the likelihood back to predicate_num rows.
            if isinstance(log_likelihood, torch.Tensor):
                ll = default_log_likelihood * torch.ones(predicate_num, log_likelihood.size()[1], log_likelihood.size()[2], log_likelihood.size()[3], device=device, dtype=subject_variable_set.dtype)
                
                if any(ind):
                    ind = torch.tensor(ind, dtype=torch.bool, device=device)
                    ll[ind, :, :, :] = log_likelihood
            else:
                # Non-tensor likelihood: indexed here with the 'index' and 'size' keys;
                # its first index row is remapped to the full predicate numbering.
                ll = log_likelihood
                if any(ind):
                    seq = torch.arange(predicate_num, dtype=torch.int64, device=device)
                    ll['index'][0] = seq[ind][ll['index'][0]]
                    ll['size'][0] = predicate_num

            if is_any_negated:
                is_negated = torch.zeros(predicate_num, device=device, dtype=subject_variable_set.dtype)
                is_negated[ind] = is_neg
            else:
                is_negated = None

            log_attention_posterior = self._blc(log_attention, ll, quantifier, dim_order, 
                        subject_variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

            # Blank predicates keep the incoming attention unchanged.
            # NOTE(review): on the non-tensor likelihood path `ind` is still a Python list
            # here, so `ind == False` is the scalar False rather than an element-wise
            # mask - confirm that path behaves as intended.
            log_attention_posterior[ind == False, 0, :] = subject_variable_set._log_attention[ind == False, :]
            log_attention_posterior[ind == False, 1, :] = object_variable_set._log_attention[ind == False, :]
        else:        
            is_negated = is_neg if is_any_negated else None

            log_attention_posterior = self._blc(log_attention, log_likelihood, quantifier, dim_order, 
                        subject_variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

        # Both output sets receive the (expanded) subject quantifier.
        if predicate_question_map is None:
            quantifier = subject_variable_set._quantifier
        else:
            quantifier = util.mm(predicate_question_map, subject_variable_set._quantifier.unsqueeze(1)).squeeze(1)

        new_subject_set = BatchVariableSet(subject_variable_set._name, subject_variable_set._device, 
            subject_variable_set.object_num(), predicate_num, quantifiers=quantifier, 
            log_attention=log_attention_posterior[:, 0, :], batch_object_map=subject_variable_set._batch_object_map,
            predicate_question_map=predicate_question_map, base_cumulative_loss=subject_variable_set.cumulative_loss() + object_variable_set.cumulative_loss(),
            prev_variable_sets_num=subject_variable_set._prev_variable_sets_num + object_variable_set._prev_variable_sets_num + 1).to(subject_variable_set.dtype)

        new_object_set = BatchVariableSet(object_variable_set._name, object_variable_set._device, 
            object_variable_set.object_num(), predicate_num, quantifiers=quantifier, 
            log_attention=log_attention_posterior[:, 1, :], batch_object_map=object_variable_set._batch_object_map,
            predicate_question_map=predicate_question_map, base_cumulative_loss=subject_variable_set.cumulative_loss() + object_variable_set.cumulative_loss(),
            prev_variable_sets_num=subject_variable_set._prev_variable_sets_num + object_variable_set._prev_variable_sets_num + 1).to(object_variable_set.dtype)
        
        # Modulations registered for this op are applied once and then removed.
        if op_id in self._subject_modulations:
            new_subject_set = new_subject_set.apply_modulations(self._subject_modulations[op_id], subject_variable_set, predicate_question_map).to(subject_variable_set.dtype)
            self._subject_modulations.pop(op_id, 'No Key found')

        if op_id in self._object_modulations:
            new_object_set = new_object_set.apply_modulations(self._object_modulations[op_id], object_variable_set, predicate_question_map).to(object_variable_set.dtype)
            self._object_modulations.pop(op_id, 'No Key found')

        return new_subject_set, new_object_set
Пример #9
0
    def forward(self, op_id, world, variable_set, attribute_list, predicate_question_map=None, default_log_likelihood=-30, normalized_probability=True):        
        """Apply a batch of attribute predicates to a variable set.

        Entries of ``attribute_list`` that are None, empty or '_' (after
        ``strip``) are treated as blanks. If every entry is blank the input
        set is returned unchanged. Otherwise the oracle is queried for the
        non-blank attributes only, blank rows are filled with
        ``default_log_likelihood``, and after the update the attention of the
        blank rows is copied back from the input set.

        Args:
            op_id: key used to look up (and then remove) a pending modulation
                in ``self._modulations``.
            world: passed through to the oracle.
            variable_set: the variable set to filter; provides the batch size,
                attention, quantifier and batch/object map.
            attribute_list: attribute name or list of attribute names.
            predicate_question_map: None, a sparse (predicate_num x
                question_num) map, or a list/tuple of question indices from
                which such a map is built.
            default_log_likelihood: fill value for blank predicates; also
                forwarded to the oracle and to ``self._blc``.
            normalized_probability: forwarded to the oracle.

        Returns:
            A new ``BatchVariableSet``, or ``variable_set`` itself when all
            attributes are blank.
        """
        if not isinstance(attribute_list, list):
            attribute_list = [attribute_list]
        
        # ind[k] is True when the k-th attribute is a real (non-blank) predicate.
        ind = [val is not None and val.strip() not in ['', '_'] for val in attribute_list]
        if not any(ind):
            return variable_set

        device = variable_set.device

        question_num = variable_set.batch_size()
        predicate_num = len(attribute_list)
        
        # A list/tuple map gives, per predicate, the index of its question;
        # turn it into a sparse (predicate_num x question_num) indicator matrix.
        if predicate_question_map is not None and isinstance(predicate_question_map, (list, tuple)):
            all_ones = torch.ones(predicate_num, device=device, dtype=variable_set.dtype)
            x_ind = torch.arange(predicate_num, dtype=torch.int64, device=device)
            y_ind = torch.tensor(predicate_question_map, dtype=torch.int64, device=device)
            index = torch.stack([x_ind, y_ind])
            
            if util.is_cuda(device):
                predicate_question_map = torch.cuda.sparse.FloatTensor(index, all_ones, 
                                torch.Size([predicate_num, question_num]), device=device)
            else:
                predicate_question_map = torch.sparse.FloatTensor(index, all_ones, 
                                torch.Size([predicate_num, question_num]), device=device)

        assert (question_num == predicate_num) or \
            (predicate_question_map is not None and predicate_question_map.size() == torch.Size([predicate_num, question_num])),\
            "Batch size mismatch."

        quantifier = variable_set._quantifier.unsqueeze(1)
        if predicate_question_map is not None:
            quantifier = util.mm(predicate_question_map, quantifier)

        # Drop the blank entries before asking the oracle.
        if not all(ind):
            attribute_list = list(compress(attribute_list, ind))

        is_any_negated, is_neg, a_list = util.detect_negations(attribute_list, device)
        is_neg = torch.tensor(is_neg, dtype=variable_set.dtype, device=device)

        # Question index of every non-blank predicate.
        if predicate_question_map is not None:
            attribute_image_map = predicate_question_map._indices()[1, ind]
        else:
            attribute_image_map = torch.arange(predicate_num, dtype=torch.int64, device=device)[ind]

        log_likelihood = self._oracle(TokenType.ATTRIBUTE, a_list, attribute_image_map, world, default_log_likelihood=default_log_likelihood, normalized_probability=normalized_probability)

        if isinstance(log_likelihood, torch.Tensor):
            if log_likelihood.dim() < 3:
                log_likelihood = log_likelihood.unsqueeze(0)

        if not all(ind):
            # Some predicates are blank: expand the likelihood back to predicate_num rows.
            if isinstance(log_likelihood, torch.Tensor):
                ll = default_log_likelihood * torch.ones(predicate_num, log_likelihood.size()[1], log_likelihood.size()[2], device=device, dtype=variable_set.dtype)

                if any(ind):
                    ind = torch.tensor(ind, dtype=torch.bool, device=device)
                    ll[ind, :, :] = log_likelihood
            else:
                # Non-tensor likelihood: indexed here with the 'index' and 'size' keys;
                # its first index row is remapped to the full predicate numbering.
                ll = log_likelihood
                if any(ind):
                    seq = torch.arange(predicate_num, dtype=torch.int64, device=device)
                    ll['index'][0] = seq[ind][ll['index'][0]]
                    ll['size'][0] = predicate_num

            if is_any_negated:
                is_negated = torch.zeros(predicate_num, device=device, dtype=variable_set.dtype)
                is_negated[ind] = is_neg
            else:
                is_negated = None

            log_attention = self._blc(variable_set._log_attention.unsqueeze(1), ll, quantifier, [0], 
                variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

            # Blank predicates keep the incoming attention unchanged.
            # NOTE(review): on the non-tensor likelihood path `ind` is still a Python list
            # here, so `ind == False` is the scalar False rather than an element-wise
            # mask - confirm that path behaves as intended.
            log_attention[ind == False, 0, :] = variable_set._log_attention[ind == False, :]
        else:
            is_negated = is_neg if is_any_negated else None

            log_attention = self._blc(variable_set._log_attention.unsqueeze(1), log_likelihood, quantifier, [0], 
                variable_set._batch_object_map, predicate_question_map, is_negated, default_log_likelihood=default_log_likelihood)

        # The output set receives the (expanded) input quantifier.
        if predicate_question_map is None:
            quantifier = variable_set._quantifier
        else:
            quantifier = util.mm(predicate_question_map, variable_set._quantifier.unsqueeze(1)).squeeze(1)

        res = BatchVariableSet(variable_set._name, variable_set._device, variable_set.object_num(), predicate_num, quantifiers=quantifier, \
            log_attention=log_attention[:, 0, :], batch_object_map=variable_set._batch_object_map, predicate_question_map=predicate_question_map, \
            base_cumulative_loss=variable_set.cumulative_loss(), prev_variable_sets_num=variable_set._prev_variable_sets_num + 1).to(variable_set.dtype)

        # A modulation registered for this op is applied once and then removed.
        if op_id in self._modulations:
            res = res.apply_modulations(self._modulations[op_id], variable_set, predicate_question_map).to(variable_set.dtype)
            self._modulations.pop(op_id, 'No Key found')
        
        return res
Пример #10
0
    def expand(self, predicate_question_map):
        """Return a new state projected through ``predicate_question_map``.

        Each of the two tensors in ``self._state`` is left-multiplied (via
        ``util.mm``) by ``predicate_question_map``.

        Args:
            predicate_question_map: matrix applied to both state tensors.

        Returns:
            A new ``BatchAttentionState`` with the same name and device,
            converted to ``self.dtype``.
        """
        projected = tuple(
            util.mm(predicate_question_map, self._state[k]) for k in (0, 1))

        expanded = BatchAttentionState(self._name, self._device, projected)
        return expanded.to(self.dtype)
Пример #11
0
    def _compute_loss(self, program_batch_list, prediction):
        """Compute the training loss for one prediction over a list of program batches.

        The form of the loss is selected by ``prediction['type']``:

        * ``STATEMENT``: negative sum of ``prediction['log_probability']``,
          returned immediately (the L1 term below is not added).
        * ``BINARY``: summed binary cross-entropy between the exponentiated
          log-probabilities and the yes/no answers.
        * ``OBJECT_STATEMENT``: same, over the flattened per-object answers,
          weighted by each batch's ``_meta_data['weights']``.
        * ``QUERY``: for every question, ``util.safe_log`` of the summed
          exponentiated option scores, minus the scores of the options equal
          to the answer.
        * ``SCENE_GRAPH``: sum of the weighted binary cross-entropies of the
          attribute and relation predictions.

        When ``self._config['l1_lambda']`` is present and positive, the L1
        norm of all trainable parameters of ``self.model``, divided by their
        count, is added with that weight.

        Args:
            program_batch_list: batches providing ``_answers`` and ``_meta_data``.
            prediction: mapping with 'type', 'log_probability' and, for
                ``QUERY``, 'options'.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if ``prediction['type']`` is not one of the types above.
        """
        question_type = prediction['type']
        if question_type == QuestionType.STATEMENT:
            return -prediction['log_probability'].sum()

        # Answers counted as a positive target.
        affirmative = ('yes', 'yeah', 'yep', 'yup', 'aye', 'yea')

        if question_type == QuestionType.BINARY:
            target = [
                a in affirmative
                for program_batch in program_batch_list
                for a in program_batch._answers
            ]
            norm = torch.ones(len(target), device=self._device)

            target = torch.tensor(target,
                                  dtype=torch.float32,
                                  device=self._device)
            loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'].exp(),
                target,
                weight=1.0 / norm,
                reduction='sum')

        elif question_type == QuestionType.OBJECT_STATEMENT:
            target = [
                a in affirmative
                for program_batch in program_batch_list
                for sublist in program_batch._answers
                for a in sublist
            ]
            weights = torch.tensor(list(
                chain.from_iterable(program_batch._meta_data['weights']
                                    for program_batch in program_batch_list)),
                                   device=self._device)

            target = torch.tensor(target,
                                  dtype=torch.float32,
                                  device=self._device)
            loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'].exp(),
                target,
                weight=weights,
                reduction='sum')

        elif question_type == QuestionType.QUERY:
            predicate_num = prediction['log_probability'].size()[0]
            all_answers = list(
                chain.from_iterable(program_batch._answers
                                    for program_batch in program_batch_list))

            # target[q][k] is True when option k of question q is the answer.
            target = [[a == o for o in op]
                      for a, op in zip(all_answers, prediction['options'])]
            question_num = len(target)

            # Sparse (question_num x predicate_num) indicator that assigns
            # every option to the question it belongs to.
            x_ind = torch.tensor(list(
                chain.from_iterable(
                    repeat(a, len(b)) for a, b in enumerate(target))),
                                 dtype=torch.int64,
                                 device=self._device)
            y_ind = torch.arange(predicate_num,
                                 dtype=torch.int64,
                                 device=self._device)
            all_ones = torch.ones(predicate_num, device=self._device)
            index = torch.stack([x_ind, y_ind])

            if self._device.type == 'cuda':
                question_predicate_map = torch.cuda.sparse.FloatTensor(
                    index,
                    all_ones,
                    torch.Size([question_num, predicate_num]),
                    device=self._device)
            else:
                question_predicate_map = torch.sparse.FloatTensor(
                    index,
                    all_ones,
                    torch.Size([question_num, predicate_num]),
                    device=self._device)

            target = torch.tensor(list(chain.from_iterable(target)),
                                  dtype=torch.float32,
                                  device=self._device)

            score = prediction['log_probability']
            loss = util.safe_log(
                util.mm(
                    question_predicate_map,
                    score.unsqueeze(1).exp())).sum() - (target * score).sum()

        elif question_type == QuestionType.SCENE_GRAPH:

            def gather(key):
                # Concatenate one meta-data array across all batches.
                return torch.tensor(np.concatenate(
                    [pb._meta_data[key] for pb in program_batch_list], axis=0),
                                    dtype=torch.float32,
                                    device=self._device)

            attr_target = gather('attribute_answer')
            attr_weight = gather('attribute_weight')
            rel_target = gather('relation_answer')
            rel_weight = gather('relation_weight')

            attr_loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'][0].exp(),
                attr_target,
                weight=attr_weight,
                reduction='sum')
            rel_loss = nn.functional.binary_cross_entropy(
                prediction['log_probability'][1].exp(),
                rel_target,
                weight=rel_weight,
                reduction='sum')

            loss = attr_loss + rel_loss

        else:
            # Previously an unknown type fell through and failed later with an
            # UnboundLocalError on `loss`; fail early with a clear message.
            raise ValueError(
                'Unsupported question type: {!r}'.format(question_type))

        if 'l1_lambda' in self._config and self._config['l1_lambda'] > 0:
            all_params = torch.cat([
                x.view(-1) for x in filter(lambda p: p.requires_grad,
                                           self.model.parameters())
            ])
            loss += (self._config['l1_lambda'] * torch.norm(all_params, 1) /
                     max(1, all_params.numel()))

        return loss