Example #1
def get_graph_feature(x, k=20, idx=None):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.contiguous()
    x = x.view(batch_size, -1, num_points).contiguous()
    if idx is None:
        idx = knn(x, k=k)  # (batch_size, num_points, k)
    # device = torch.device('cuda')

    idx_base = torch.arange(0, batch_size).view(-1, 1, 1) * num_points
    idx_base = idx_base.cuda(torch.get_device(x))
    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()
    # gather the k nearest neighbors of every point
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2)

    return feature
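A minimal usage sketch for the function above, assuming a knn(x, k) helper is defined elsewhere and a CUDA device is available; the shapes are illustrative, not taken from the original repository.

points = torch.rand(8, 3, 1024).cuda()       # hypothetical batch of 3-D point clouds
edge_feat = get_graph_feature(points, k=20)  # -> torch.Size([8, 6, 1024, 20]): (batch, 2*channels, points, k)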
Example #2
    def forward(self, context, gold_cui_cano_and_def_concatenated, gold_cuidx,
                mention_uniq_id):
        batch_num = context['tokens'].size(0)
        contextualized_mention = self.mention_encoder(context)
        encoded_entites = self.entity_encoder(
            cano_and_def_concatnated_text=gold_cui_cano_and_def_concatenated)

        contextualized_mention_forcossim = normalize(contextualized_mention,
                                                     dim=1)
        encoded_entites_forcossim = normalize(encoded_entites, dim=1)
        # scores = contextualized_mention_forcossim.mm(encoded_entites_forcossim.t())
        scores = contextualized_mention.mm(encoded_entites.t())

        device = torch.get_device(scores) if self.cuda_flag else torch.device(
            'cpu')

        target = torch.LongTensor(torch.arange(batch_num)).to(device)
        loss = F.cross_entropy(scores, target, reduction="mean")

        output = {'loss': loss}
        if self.args.add_mse:
            output['loss'] += 0.01 * self.mesloss(contextualized_mention,
                                                  encoded_entites)

        if self.istrainflag:
            golds = torch.eye(batch_num).to(device)
            self.accuracy(scores, torch.argmax(golds, dim=1))
        else:
            output['gold_cuidx'] = gold_cuidx
            output['encoded_mentions'] = contextualized_mention
        return output
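A small illustration (values made up, not from the original code) of the in-batch objective used above: every mention is scored against every entity in the same batch, the gold entity of mention i sits in column i, so torch.arange(batch_num) serves directly as the cross-entropy targets.

scores = torch.randn(4, 4)              # hypothetical mention-vs-entity score matrix
target = torch.arange(4)                # gold entity of mention i is entity i
loss = F.cross_entropy(scores, target)  # in-batch negatives objective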
Example #3
    def nms(self, centers, X, b):
        """
        Non max suprression.
        :param centers: center of clusters
        :param X: points to be clustered
        :param b: band width used to get the centers
        """
        membership = 2.0 - 2.0 * centers @ torch.transpose(X, 1, 0)

        # which cluster center is closer to the points
        membership = torch.min(membership, 0)[1]

        # Find the unique clusters which is closer to at least one point
        uniques, counts_ = np.unique(membership.data.cpu().numpy(),
                                     return_counts=True)

        # count of the number of points belonging to unique cluster ids above
        counts = torch.from_numpy(counts_.astype(np.float32)).cuda(
            torch.get_device(centers))

        num_mem_cluster = torch.zeros(
            (X.shape[0])).cuda(torch.get_device(centers))

        # Contains the count of number of points belonging to a
        # unique cluster
        num_mem_cluster[uniques] = counts

        # distance of clusters from each other
        dist = 2.0 - 2.0 * centers @ torch.transpose(centers, 1, 0)

        # find the nearest neighbors to each cluster based on some threshold
        # TODO this could be b ** 2
        cluster_nbrs = dist < b
        cluster_nbrs = cluster_nbrs.float()

        cluster_center_ids = torch.unique(
            torch.max(cluster_nbrs[uniques] * num_mem_cluster.reshape((1, -1)),
                      1)[1])
        # pruned centers
        centers = centers[cluster_center_ids]

        # assign labels to the input points
        # It is assumed that the embeddings lie on the hypersphere and are normalized
        temp = centers @ torch.transpose(X, 1, 0)
        labels = torch.max(temp, 0)[1]
        return centers, cluster_center_ids, labels
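A hedged call sketch for the method above (names and shapes are made up): the broadcasting in the pruning step assumes one candidate center per input point, as in mean-shift, so centers and X share their first dimension, and both are expected to be unit-normalized.

emb = torch.nn.functional.normalize(torch.rand(1024, 64), dim=1).cuda()
centers = emb.clone()                                            # mean-shift style: one candidate center per point
pruned, center_ids, labels = clusterer.nms(centers, emb, b=0.3)  # `clusterer` is a hypothetical instance of the enclosing class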
Example #4
    def forward(self, context, gold_title_and_desc_concatenated, gold_and_negatives_title_and_desc_concatenated,
                gold_duidx, mention_uniq_id, labels_for_BCEWithLogitsLoss):
        batch_num = context['tokens'].size(0)
        contextualized_mention = self.mention_encoder(context)
        encoded_entites = self.entity_encoder(title_and_desc_concatnated_text=gold_title_and_desc_concatenated)

        contextualized_mention_forcossim = normalize(contextualized_mention, dim=1)
        if self.args.search_method == 'cossim':
            encoded_entites_forcossim = normalize(encoded_entites, dim=1)
            scores = contextualized_mention_forcossim.mm(encoded_entites_forcossim.t())
        elif self.args.search_method == 'indexflatip':
            scores = contextualized_mention.mm(encoded_entites.t())
        else:
            assert self.args.search_method == 'indexflatl2'
            scores = - self.calc_L2distance(contextualized_mention.view(batch_num, 1, -1), encoded_entites) # FIXED

        device = torch.get_device(scores) if self.cuda_flag else torch.device('cpu')
        target = torch.LongTensor(torch.arange(batch_num)).to(device)

        if self.args.search_method in ['cossim','indexflatip']:
            loss = F.cross_entropy(scores, target, reduction="mean")
        else:
            loss = self.BCEWloss(scores, torch.eye(batch_num).cuda())

        output = {'loss': loss}

        if self.args.add_mse_for_biencoder:
            output['loss'] += self.mesloss(contextualized_mention, encoded_entites)

        if self.args.add_hard_negatives:
            batch_, gold_plus_negs_num, maxTokenLengthInBatch = gold_and_negatives_title_and_desc_concatenated['tokens'].size()
            docked_tokenlist = {'tokens': gold_and_negatives_title_and_desc_concatenated['tokens'].view(batch_ * gold_plus_negs_num, -1)}
            encoded_entities_from_hard_negatives_idx0isgold = self.entity_encoder(docked_tokenlist).view(batch_, gold_plus_negs_num, -1)

            if self.args.search_method == 'cossim':
                encoded_mention_forcossim_ = contextualized_mention_forcossim.repeat(1, gold_plus_negs_num).view(batch_ * gold_plus_negs_num, -1)
                scores_for_hard_negatives = (encoded_mention_forcossim_ * normalize(
                    encoded_entities_from_hard_negatives_idx0isgold.view(batch_ * gold_plus_negs_num, -1), dim=1)).sum(1).view(batch_, gold_plus_negs_num)
            elif self.args.search_method == 'indexflatip':
                encoded_mention_ = contextualized_mention.repeat(1, gold_plus_negs_num).view(batch_ * gold_plus_negs_num, -1)
                scores_for_hard_negatives = (encoded_mention_ * encoded_entities_from_hard_negatives_idx0isgold.view(
                    batch_ * gold_plus_negs_num, -1)).sum(1).view(batch_, gold_plus_negs_num)
            else:
                raise NotImplementedError

            loss += self.BCEWloss(scores_for_hard_negatives, labels_for_BCEWithLogitsLoss)

        if self.istrainflag:
            golds = torch.eye(batch_num).to(device)
            self.accuracy(scores, torch.argmax(golds, dim=1))
        else:
            output['gold_duidx'] = gold_duidx
            output['encoded_mentions'] = contextualized_mention

        return output
Example #5
    def forward(self, input, target, type='from_bottom_to_up'):
        assert (input.shape == target.shape)
        cal_func_dict = {
            'from_top_to_down': self.calc_dfs,
            'from_bottom_to_up': self.calc_dfs_from_bottom_to_up
        }
        assert (type in cal_func_dict.keys())
        cal_func = cal_func_dict[type]

        B, C, H, W = input.shape
        HL = int(math.log(H, 2))
        WL = int(math.log(W, 2))
        L = min(self.max_size, int(math.pow(2, min(HL, WL))))
        max_deep = int(math.log(L, 2))

        pad_h = -H % L
        pad_w = -W % L
        # F.pad pads the last dimension first: (width_left, width_right, height_top, height_bottom)
        input_pad = F.pad(input, [0, pad_w, 0, pad_h], mode='reflect')
        target_pad = F.pad(target, [0, pad_w, 0, pad_h], mode='reflect')

        if C == 3:
            device = torch.get_device(
                input) if 'cuda' in input.type() else 'cpu'
            self.rgb2gray = self.rgb2gray.to(device)
            input_gray = (input_pad * self.rgb2gray).sum(dim=1).squeeze(1)
            target_gray = (target_pad * self.rgb2gray).sum(dim=1).squeeze(1)
        else:
            input_gray = input_pad.squeeze(1)
            target_gray = target_pad.squeeze(1)

        B, H, W = input_gray.shape
        loss_tile = []
        for h in range(0, H, L):
            for w in range(0, W, L):
                loss_tile.append(
                    cal_func(input_gray[:, h:h + L, w:w + L],
                             target_gray[:, h:h + L, w:w + L], max_deep))
        return sum(loss_tile) / len(loss_tile)
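A quick check of the padding arithmetic above, with made-up values: in Python, -H % L gives exactly the amount needed to round H up to the next multiple of L.

H, L = 300, 64
pad_h = -H % L   # 20, because 300 + 20 = 320 = 5 * 64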
Example #6
    def forward(self, x, concept_matrix):
        device = torch.device(torch.get_device(x))

        x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
        x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F]
        x_hidden, _ = self.rnn(x_hidden)
        x_hidden = x_hidden[:, -1, :]

        # Predefined Concept Module
        stock_to_concept = concept_matrix

        stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1)
        stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix)

        stock_to_concept_sum = stock_to_concept_sum + (torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device))
        stock_to_concept = stock_to_concept / stock_to_concept_sum  # stock-to-concept-tag weights
        hidden = torch.t(stock_to_concept).mm(x_hidden)

        hidden = hidden[hidden.sum(1) != 0]
        
        concept_to_stock = self.cal_cos_similarity(x_hidden, hidden) 
        concept_to_stock = self.softmax_t2s(concept_to_stock)

        e_shared_info = concept_to_stock.mm(hidden)
        e_shared_info = self.fc_es(e_shared_info)

        e_shared_back = self.fc_es_back(e_shared_info)
        output_es = self.fc_es_fore(e_shared_info)
        output_es = self.leaky_relu(output_es)

        
        # Hidden Concept Module
        i_shared_info = x_hidden - e_shared_back
        hidden = i_shared_info  # every stock gets its own hidden tag, so there are 280 hidden tags
        i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden)
        dim = i_stock_to_concept.shape[0]
        diag = i_stock_to_concept.diagonal(0)
        i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device)
        row = torch.linspace(0,dim-1,dim).to(device).long()
        column =i_stock_to_concept.max(1)[1].long()
        value = i_stock_to_concept.max(1)[0]
        i_stock_to_concept[row, column] = 10
        i_stock_to_concept[i_stock_to_concept!=10]=0
        i_stock_to_concept[row, column] = value
        i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0)!=0).float()*diag)
        hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t()
        hidden = hidden[hidden.sum(1)!=0]

        i_concept_to_stock = self.cal_cos_similarity(i_shared_info, hidden)
        i_concept_to_stock = self.softmax_t2s(i_concept_to_stock)
        i_shared_info = i_concept_to_stock.mm(hidden)
        i_shared_info = self.fc_is(i_shared_info)

        i_shared_back = self.fc_is_back(i_shared_info)
        output_is = self.fc_is_fore(i_shared_info)
        output_is = self.leaky_relu(output_is)

        # Individual Information Module
        individual_info  = x_hidden - e_shared_back - i_shared_back
        output_indi = individual_info
        output_indi = self.fc_indi(output_indi)
        output_indi = self.leaky_relu(output_indi)
        pred_indi = self.fc_out_indi(output_indi).squeeze()

        # Stock Trend Prediction
        all_info = output_es + output_is + output_indi
        pred_all = self.fc_out(all_info).squeeze()

        return pred_all
Example #7
    def forward(self, batch):
        img = batch['img']
        ques_len = batch['question_len']
        ques_emb = self.qembedding(batch['question'])

        #Load Chargrid (chargrid on the fly)
        labels = batch['labels']
        bboxes = batch['bboxes']
        n_label = batch["n_label"]
        #BATCH SIZE
        #chargrid = torch.zeros((labels.shape[0],256,256,39),device=torch.get_device(labels))
        chargrid = torch.zeros((labels.shape[0],39,256,256),device=torch.get_device(labels))
        #create chargrid on the fly

        for batch_id in range(labels.shape[0]):
            for label_id in range(n_label[batch_id].item()):
                x,y,x2,y2 = bboxes[batch_id,label_id,:]
                label_box = labels[batch_id,label_id].repeat((x2-x,y2-y,1)).transpose(2,0)
                chargrid[batch_id,:,y:y2,x:x2] = label_box
        #print(f"chargrid: {time.time()-start:.4f}",)
        #chargrid = chargrid.permute(0,3,1,2).contiguous()

        #Load Chargrid (chargrid created beforehand
        #chargrid = batch['chargrid']


        self.qlstm.flatten_parameters()
        ques = sequences.dynamic_rnn(self.qlstm, ques_emb, ques_len)
        # answer using questions only
        if self.kind == 'lstm':
            scores = self.qclassifier(ques)
            return F.log_softmax(scores, dim=1)
        img = self.img_net(img)
        chargrid = self.chargrid_net(chargrid)

        #Decision Factor
        #curr_dec_weight = self.dec_weight_sigm(self.decision_weight)
        #chargrid = chargrid * curr_dec_weight
        #img = img * (1-curr_dec_weight)

        #entitygrid = chargrid + img

        # answer using questions + images; no relational structure
        if self.kind == 'cnn+lstm':
            ipt = torch.cat([ques, img.view(len(img), -1)], dim=1)
            scores = self.cnn_lstm_classifier(ipt)
            return F.log_softmax(scores, dim=1)
        # RN implementation treating pixels as objects
        # (f and g as in the RN paper)
        assert self.kind == 'rn'

        #Chargrid: Concat img and chargrid and conv
        entitygrid = torch.cat([img,chargrid],dim=1)
        entitygrid = self.entitygrid_net(entitygrid)
        context = 0
        pairs = self.img_to_pairs(entitygrid, ques)
        #pairs = self.img_to_pairs(img, ques)
        N, N_pairs, _ = pairs.size()
        context = self.g(pairs.view(N*N_pairs, -1))
        context = context.view(N, N_pairs, -1).mean(dim=1)
        scores = self.f(context)
        return F.log_softmax(scores, dim=1)
Example #8
def valid(epoch,tensorboard_client,global_iteration, valid_set, load_image=True, model_name=None, val_split="val_easy"):
    #run_name = val_split
    
    print("Inside validation ", epoch)
    dataset = iter(valid_set)
    model.eval()  # eval_mode
    class_correct = Counter()
    class_total = Counter()
    prediction = []

    with torch.no_grad():

        for i, (image, question, q_len, answer, answer_class, labels, bboxes, n_label, data_index) in enumerate(tqdm(dataset)):
            image, question, q_len, labels, bboxes, n_label = (
                image.to(device),
                question.to(device),
                torch.tensor(q_len),
                labels.to(device),
                bboxes,#bboxes.to(device),
                n_label
            )

            batch_size = labels.shape[0]
            n_channel = labels.shape[-1]
            chargrid = torch.zeros((batch_size,n_channel,224,224),device=torch.get_device(labels))

            #Chargrid Creation 
            for batch_id in range(labels.shape[0]):
                for label_id in range(n_label[batch_id].item()):
                    x,y,x2,y2 = bboxes[batch_id,label_id,:]
                    label_box = labels[batch_id,label_id].repeat((x2-x,y2-y,1)).transpose(2,0)
                    chargrid[batch_id,:,y:y2,x:x2] = label_box

            output = model(image, question, q_len, chargrid)
            argmax_output = output.data.cpu().numpy().argmax(1)
            numpy_answer = answer.numpy()
            correct = argmax_output == numpy_answer
            for c, class_ in zip(correct, answer_class):
                if c:  # if correct
                    class_correct[class_] += 1
                class_total[class_] += 1

            prediction.append([data_index,numpy_answer,argmax_output])

            if (("IMG" in model_name) or ("SAN" in model_name)) and type(epoch) == type(0.1) and (
                    i * batch_size // 2) > (
                    6e4):  # intermediate train, only val on 10% of the validation set
                break  # early break validation loop

    class_correct['total'] = sum(class_correct.values())
    class_total['total'] = sum(class_total.values())

    print("class_correct", class_correct)
    print("class_total", class_total)

    #Debug
    # with open('log/log_' + model_name + '_{}_'.format(round(epoch + 1, 4)) + val_split + '.txt', 'w') as w:
    #     for k, v in class_total.items():
    #         w.write('{}: {:.5f}\n'.format(k, class_correct[k] / v))
    #     # TODO: save the model here!

    total_score = class_correct['total'] / class_total['total']
    print('Avg Acc: {:.5f}'.format(total_score))

    visualize_val(global_iteration,model,tensorboard_client,val_split,class_total,class_correct)

    return prediction,total_score
Example #9
def train(epoch,tensorboard_client,global_iteration,word_dic,answer_dic,load_image=True, model_name=None):
    run_name = "train"
    model.train(True)  # train mode
    if isinstance(model,nn.DataParallel):
        vqa_model = model.module
    else:
        vqa_model = model

    dataset = iter(train_set)
    #Debug
    pbar = tqdm(dataset)
    #pbar = dataset
    n_batch = len(pbar)
    moving_loss = 0  # it will change when loop over data

    #Chargrid: Visualize
    attention_map = SaveFeatures(vqa_model.attention)
    chargrid_act1 = SaveFeatures(vqa_model.chargrid_net[0])
    chargrid_act3 = SaveFeatures(vqa_model.chargrid_net[3])
    img_act0 = SaveFeatures(vqa_model.resnet[0])

    tensorboard_client.register_hook("chargrid_act1",chargrid_act1)
    tensorboard_client.register_hook("chargrid_act3",chargrid_act3)
    tensorboard_client.register_hook("img_act0",img_act0)

    norm_img = mpl.colors.Normalize(vmin=-1,vmax=1)
    plt.style.use('seaborn-white')

    #Train Epoch
    #start = torch.cuda.Event(enable_timing=True)
    #end = torch.cuda.Event(enable_timing=True)

    

    print(device)
    print(next(model.parameters()).is_cuda, "next(model.parameters()).is_cuda")
    #Chargrid load labels,bboxes
    for i, (image, question, q_len, answer, question_class, labels, bboxes, n_label, data_index) in enumerate(pbar):


        #start.record()
        image, question, q_len, answer,labels,bboxes,n_label = (
            image.to(device),
            question.to(device),
            torch.tensor(q_len),
            answer.to(device),
            labels.to(device),
            bboxes,#bboxes.to(device),
            n_label
        )
        #end.record()
        #torch.cuda.synchronize()
        #print("Loading: ",start.elapsed_time(end))
        #start.record()
        #Chargrid: Creation
        batch_size = labels.shape[0]
        n_channel = labels.shape[-1]
        chargrid = torch.zeros((batch_size,n_channel,224,224),device=torch.get_device(labels))
        #create chargrid on the fly
        #start = time.time()
        for batch_id in range(labels.shape[0]):
            for label_id in range(n_label[batch_id].item()):
                x,y,x2,y2 = bboxes[batch_id,label_id,:]
                label_box = labels[batch_id,label_id].repeat((x2-x,y2-y,1)).transpose(2,0)
                chargrid[batch_id,:,y:y2,x:x2] = label_box

        #end.record()
        #torch.cuda.synchronize()
        #print("Chargrid: ",start.elapsed_time(end))

        model.zero_grad()
        output = model(image, question, q_len, chargrid)
        #end.record()
        #torch.cuda.synchronize()
        #print("Forward: ",start.elapsed_time(end))
        #start.record()
        #SANDY: add the OCR tokens at the beginning of the question
        loss = criterion(output, answer)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
        #end.record()
        #torch.cuda.synchronize()
        #print("Backward: ",start.elapsed_time(end))
        #start.record()
        item_correct = output.data.cpu().numpy().argmax(1) == answer.data.cpu().numpy()
        correct = item_correct.sum() / batch_size

        if moving_loss == 0:
            moving_loss = correct
            # print("moving_loss = correct")

        else:
            moving_loss = moving_loss * 0.99 + correct * 0.01
            # print("moving_loss = moving_loss * 0.99 + correct * 0.01")

        # pbar.set_description(
        #     'Epoch: {}; Loss: {:.5f}; Acc: {:.5f}; Correct:{:.5f}; LR: {:.6f}, Batch_Acc: {:.5f}'.format(
        #         epoch + 1,
        #         loss.detach().item(),  # 0.00 for YES model
        #         moving_loss,
        #         correct,
        #         optimizer.param_groups[0]['lr'],  # 0.00  for YES model
        #         np.mean((output.argmax(1) == answer).cpu().numpy())
        #     )
        # )
        # Debug: the trailing "and 1 == 0" keeps this intermediate validation disabled
        if (("IMG" in model_name) or ("SAN" in model_name)) and i % 10000 == 0 and i != 0 and 1 == 0:
            # valid(epoch + float(i * batch_size / 2325316), model_name=model_name, val_split="val_easy",
            #       load_image=load_image)
            valid(epoch + float(i * batch_size / 2325316),tensorboard_client,global_iteration, valid_set_easy, model_name=model_name,
                load_image=load_image, val_split="val_easy")

            model.train(True)

        #Chargrid: visualize
        visualize_train(
            global_iteration,run_name,tensorboard_client,
            loss,moving_loss)

        #if global_iteration % 50 * batch_size == 0:
        if global_iteration % train_visualization_iteration == 0:


            #Visualize Input divided by correctness
            for is_correct,correct_class in [(True,"correct"),(False,"incorrect")]:
                select_mask = item_correct == is_correct
                n_pictures = min(np.sum(select_mask),8)
                if n_pictures != 0:

                    #Image Layer
                    visu_net = vqa_model.resnet
                    tensorboard_client.add_conv2(
                        global_iteration,
                        visu_net[0],
                        "Image_Conv1",
                        "img_act0",
                        select_mask,
                        n_pictures,
                        f"_{correct_class}"
                    )

                    visu_img = image[select_mask][:n_pictures].cpu().numpy()
                    visu_question = question[select_mask][:n_pictures].cpu().numpy()
                    visu_answer = answer[select_mask][:n_pictures].cpu().numpy()
                    visu_output = output[select_mask][:n_pictures].data.cpu().numpy().argmax(1)

                    tensorboard_client.add_figure_with_question(
                        global_iteration,
                        visu_img,
                        visu_question,
                        visu_answer,
                        visu_output,
                        "Input",
                        f"_{correct_class}")

                    #Attention
                    #attention_features = F.pad(attention_map.features[:16].detach().cpu(),(2,2,2,2))
                    attention_features = attention_map.get_features()[select_mask][:n_pictures].detach().cpu()     
                    glimpses = attention_features.size(1)
                    grid_size = attention_features.size(2)
                    attention_features = attention_features.view(n_pictures, glimpses, -1)
                    #attention_features = F.softmax(attention_features, dim=-1).unsqueeze(2)
                    attention_features = attention_features.view(n_pictures, glimpses, grid_size, grid_size)
                    if glimpses == 2:
                        att1,att2 = torch.split(attention_features,1,1)
                        tensorboard_client.add_images(
                            global_iteration,
                            att2,
                            f"Attention/glimps2_{correct_class}")
                    else:
                        att1 = attention_features

                    tensorboard_client.add_images(
                        global_iteration,
                        att1,
                        f"Attention/glimps1_{correct_class}")
                    


            #Chargrid
            # visu_net = model.chargrid_net
            # tensorboard_client.add_conv2(
            #     global_iteration,
            #     visu_net[0],
            #     "Chargrid_Conv1",
            #     chargrid_act1,
            #     16
            # )
            # tensorboard_client.add_conv2(
            #     global_iteration,
            #     visu_net[3],
            #     "Chargrid_Conv2",
            #     chargrid_act3,
            #     16
            # )

            # #Chargrid Input
            # visu_chargrid = torch.sum(chargrid[:16],dim=1,keepdim=True).cpu().numpy()
            # tensorboard_client.add_figure_with_question(
            #     global_iteration,
            #     visu_chargrid,
            #     visu_question,
            #     visu_answer,
            #     visu_output,
            #     "Chargrid")

        #Replace by batch_size
        global_iteration += 1

    #valid(epoch + float(i * batch_size / 2325316),tensorboard_client,global_iteration, train_set, model_name=model_name,
    #            load_image=load_image, val_split="train")
    
    return global_iteration
Example #10
def assign_pairs(sbj_bboxes,
                 sbj_labels,
                 sbj_idxs,
                 obj_bboxes,
                 obj_labels,
                 obj_idxs,
                 gt_bboxes,
                 gt_labels,
                 gt_rel,
                 gt_instid,
                 th=0.5):
    """

    :param sbj_bboxes: m x 4 (tensor)
    :param sbj_labels: m x 1 (tensor)
    :param sbj_idxs:   m x 1 (tensor) subject indices in combined_bboxes
    :param obj_bboxes: m x 4 (tensor)
    :param obj_labels: m x 1 (tensor)
    :param obj_idxs:   m x 1 (tensor) object indices in combined_bboxes
    :param gt_bboxes:  n x 4 (tensor)
    :param gt_labels:  n x 1 (tensor)
    :param gt_rel:     k x 3 (tensor) [subject instance id, object instance id, relation id]
    :param gt_instid:  n x 1 (tensor) instance indices
    :param th:         IoU threshold (default: 0.5)
    :return:
        relations:       one row per candidate pair,
                         [subject semantic id, object semantic id, subject index, object index]
        relation_labels: matched relation ids for each candidate pair
                         (tensor([0]) when no ground-truth relation matches)
    """

    total_cand = sbj_bboxes.shape[0]
    total_rela = gt_rel.shape[0]

    # return variables
    relations = []
    relation_labels = []

    # compute overlaps
    overlaps_sbj_gt = bbox_overlaps(sbj_bboxes, gt_bboxes)
    overlaps_obj_gt = bbox_overlaps(obj_bboxes, gt_bboxes)

    # iterate over candidate subject-object pairs
    for i in range(total_cand):
        # labels for detected subject and object
        sbj_label = sbj_labels[i][0]
        obj_label = obj_labels[i][0]
        found = False

        r = [sbj_label, obj_label, sbj_idxs[i], obj_idxs[i]]
        r_l = []
        for j in range(total_rela):
            # relations
            gt_sbj_instid, gt_obj_instid, gt_rel_id = gt_rel[j, :]
            sbj_instid_idx = (gt_instid == gt_sbj_instid.item()).nonzero()[0]
            obj_instid_idx = (gt_instid == gt_obj_instid.item()).nonzero()[0]

            assert len(sbj_instid_idx) == 1
            assert len(obj_instid_idx) == 1

            gt_sbj_label = gt_labels[sbj_instid_idx][0]
            gt_obj_label = gt_labels[obj_instid_idx][0]

            overlap_s = overlaps_sbj_gt[i, sbj_instid_idx]
            overlap_o = overlaps_obj_gt[i, obj_instid_idx]

            positive = (sbj_label.item() == gt_sbj_label.item()
                        and obj_label.item() == gt_obj_label.item()
                        and overlap_s.item() >= th and overlap_o.item() >= th)

            if positive:
                r_l.append(gt_rel_id)
                found = True
        if not found:
            relation_labels.append(torch.tensor([0]))
        else:
            relation_labels.append(r_l)
        relations.append(r)

    assert len(relations) == len(relation_labels)

    device = torch.get_device(sbj_bboxes)
    relations = torch.FloatTensor(relations).to(device)

    return relations, relation_labels
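Every snippet above leans on torch.get_device, so a small closing sketch of its behavior (assuming a CUDA build of PyTorch): it returns the integer device index of a CUDA tensor, and for CPU tensors it raises or returns -1 depending on the PyTorch version, which is why several examples guard the call with a CUDA check before falling back to the CPU device; Tensor.device is the more general accessor.

t = torch.rand(2, 3).cuda()     # assumes a CUDA device is available
idx = torch.get_device(t)       # integer device index, e.g. 0
dev = t.device                  # device(type='cuda', index=0)
u = torch.rand(2, 3).cuda(idx)  # .cuda() accepts the integer index, as the snippets above do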