def main():
    opt = BaseOptions().parse()
    torch.manual_seed(opt.seed)
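    # turn off cuDNN autotuning and force deterministic kernels for reproducible runs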
    cudnn.benchmark = False
    cudnn.deterministic = True
    np.random.seed(opt.seed)

    dset = TVQADataset(opt)
    opt.vocab_size = len(dset.word2idx)

    model = TVQANet(opt)


    if opt.device.type == "cuda":
        print("CUDA enabled.")
        if len(opt.device_ids) > 1:
            print("Use multi GPU", opt.device_ids)
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids, output_device=0)  # use multi GPU
        model.to(opt.device)


    # optionally resume from a saved checkpoint:
    # model.load_state_dict(torch.load("./path/best_release_7420.pth"))

    criterion = nn.CrossEntropyLoss(reduction="sum").to(opt.device)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr,
        weight_decay=opt.wd)
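    # note: the filter above skips parameters with requires_grad=False (e.g. frozen pretrained weights)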

    best_acc = 0.
    start_epoch = 0
    early_stopping_cnt = 0
    early_stopping_flag = False
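    # stop once validation accuracy fails to improve for opt.max_es_cnt consecutive epochs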

    for epoch in range(start_epoch, opt.n_epoch):
        if not early_stopping_flag:

            niter = epoch * np.ceil(len(dset) / float(opt.bsz))  # global step at the start of this epoch (unused below)

            cur_acc = train(opt, dset, model, criterion, optimizer, epoch, best_acc)

            is_best = cur_acc > best_acc
            best_acc = max(cur_acc, best_acc)
            if not is_best:
                early_stopping_cnt += 1
                if early_stopping_cnt >= opt.max_es_cnt:
                    early_stopping_flag = True
            else:
                early_stopping_cnt = 0
        else:
            print("=> early stop with valid acc %.4f" % best_acc)
            break  # early stop

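        # one-off manual schedule: pin the learning rate to 2e-4 after epoch 10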
        if epoch == 10:
            for g in optimizer.param_groups:
                g['lr'] = 0.0002

    return opt.results_dir.split("/")[1]
Example #2
def main_inference():
    print("Loading config...")
    opt = BaseOptions().parse()
    print("Loading dataset...")
    dset = TVQADataset(opt, paths)
    print("Loading model...")
    model = TVQANet(opt)

    # use opt.device if specified; otherwise fall back to the best available device (GPU > CPU)
    device = torch.device("cuda:0" if opt.device != '-2'
                          and torch.cuda.is_available() else "cpu")
    model.to(device)

    cudnn.benchmark = True

    # load pre-trained model if it exists
    loadPreTrainedModel(model=model, modelPath=paths["pretrained_model"])

    model.eval()
    model.inference_mode = True
    torch.set_grad_enabled(False)
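    # eval mode + inference flag + disabled autograd: dropout is off and no gradient history is recorded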
    print("Evaluation Starts:\n")
    predictions = inference(opt, dset, model)
    print("predictions {}".format(predictions.keys()))
    pred_path = paths["pretrained_model"].replace(
        "best_valid.pth", "{}_inference_predictions.json".format(opt.mode))
    save_json(predictions, pred_path)
Example #3
    @staticmethod
    def get_fake_inputs(device="cuda:0"):
        bsz = 16
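        # dummy batch: every token id is 1 and every length equals the full padded size,
        # matching the (sequence, length) pairs the forward pass expects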
        q = torch.ones(bsz, 25).long().to(device)
        q_l = torch.ones(bsz).fill_(25).long().to(device)
        a = torch.ones(bsz, 5, 20).long().to(device)
        a_l = torch.ones(bsz, 5).fill_(20).long().to(device)
        a0, a1, a2, a3, a4 = [a[:, i, :] for i in range(5)]
        a0_l, a1_l, a2_l, a3_l, a4_l = [a_l[:, i] for i in range(5)]
        sub = torch.ones(bsz, 300).long().to(device)
        sub_l = torch.ones(bsz).fill_(300).long().to(device)
        vcpt = torch.ones(bsz, 300).long().to(device)
        vcpt_l = torch.ones(bsz).fill_(300).long().to(device)
        vid = torch.ones(bsz, 100, 2048).to(device)
        vid_l = torch.ones(bsz).fill_(100).long().to(device)
        return q, q_l, a0, a0_l, a1, a1_l, a2, a2_l, a3, a3_l, a4, a4_l, sub, sub_l, vcpt, vcpt_l, vid, vid_l


if __name__ == '__main__':
    from config import BaseOptions
    import sys
    sys.argv[1:] = ["--input_streams", "sub"]
    opt = BaseOptions().parse()

    model = ABC(opt)
    model.to(opt.device)
    test_in = model.get_fake_inputs(device=opt.device)
    test_out = model(*test_in)
    print(test_out.size())
Example #4
File: main.py Project: sunutf/TVQA
        valid_qids += [int(x) for x in qids]
        valid_loss.append(loss.item())
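        # argmax over the answer logits gives the predicted answer index for each sample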
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        if opt.debug:
            break

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    valid_loss = sum(valid_loss) / float(len(valid_loss))
    return valid_acc, valid_loss


if __name__ == "__main__":
    torch.manual_seed(2018)
    opt = BaseOptions().parse()
    writer = SummaryWriter(opt.results_dir)
    opt.writer = writer

    dset = TVQADataset(opt)
    opt.vocab_size = len(dset.word2idx)
    model = ABC(opt)
    if not opt.no_glove:
        model.load_embedding(dset.vocab_embedding)
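        # initializes the nn.Embedding weights from the dataset's pretrained GloVe vectors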

    model.to(opt.device)
    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss(reduction="sum").to(opt.device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=opt.lr,
                                 weight_decay=opt.wd)  # closing arguments assumed, mirroring the identical call in the other examples
Example #5
def main():
    opt = BaseOptions().parse()
    torch.manual_seed(opt.seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    np.random.seed(opt.seed)

    writer = SummaryWriter(opt.results_dir)
    opt.writer = writer
    dset = TVQADataset(opt)
    opt.vocab_size = len(dset.word2idx)
    model = STAGE(opt)

    count_parameters(model)

    if opt.device.type == "cuda":
        print("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            print("Use multi GPU", opt.device_ids)
            model = torch.nn.DataParallel(
                model, device_ids=opt.device_ids)  # use multi GPU

    criterion = nn.CrossEntropyLoss(reduction="sum").to(opt.device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=opt.lr,
                                 weight_decay=opt.wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.5,
                                                           patience=10,
                                                           verbose=True)
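    # "max" mode: halve the learning rate once validation accuracy has plateaued for 10 epochs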

    best_acc = 0.
    start_epoch = 0
    early_stopping_cnt = 0
    early_stopping_flag = False
    for epoch in range(start_epoch, opt.n_epoch):
        if not early_stopping_flag:
            use_hard_negatives = epoch + 1 > opt.hard_negative_start  # whether to use hard negative sampling
            niter = epoch * np.ceil(len(dset) / float(opt.bsz))
            opt.writer.add_scalar("learning_rate",
                                  float(optimizer.param_groups[0]["lr"]),
                                  niter)
            cur_acc = train(opt,
                            dset,
                            model,
                            criterion,
                            optimizer,
                            epoch,
                            best_acc,
                            use_hard_negatives=use_hard_negatives)
            scheduler.step(cur_acc)  # decrease lr when acc is not improving
            # remember best acc
            is_best = cur_acc > best_acc
            best_acc = max(cur_acc, best_acc)
            if not is_best:
                early_stopping_cnt += 1
                if early_stopping_cnt >= opt.max_es_cnt:
                    early_stopping_flag = True
            else:
                early_stopping_cnt = 0
        else:
            print("=> early stop with valid acc %.4f" % best_acc)
            opt.writer.export_scalars_to_json(
                os.path.join(opt.results_dir, "all_scalars.json"))
            opt.writer.close()
            break  # early stop break

        if opt.debug:
            break

    return opt.results_dir.split("/")[1], opt.debug
Example #6
    @staticmethod
    def get_fake_inputs(device="cuda:0"):
        bsz = 16
        q = torch.ones(bsz, 25).long().to(device)
        q_l = torch.ones(bsz).fill_(25).long().to(device)
        a = torch.ones(bsz, 5, 20).long().to(device)
        a_l = torch.ones(bsz, 5).fill_(20).long().to(device)
        a0, a1, a2, a3, a4 = [a[:, i, :] for i in range(5)]
        a0_l, a1_l, a2_l, a3_l, a4_l = [a_l[:, i] for i in range(5)]
        sub = torch.ones(bsz, 300).long().to(device)
        sub_l = torch.ones(bsz).fill_(300).long().to(device)
        vcpt = torch.ones(bsz, 300).long().to(device)
        vcpt_l = torch.ones(bsz).fill_(300).long().to(device)
        vid = torch.ones(bsz, 100, 2048).to(device)
        vid_l = torch.ones(bsz).fill_(100).long().to(device)
        return q, q_l, a0, a0_l, a1, a1_l, a2, a2_l, a3, a3_l, a4, a4_l, sub, sub_l, vcpt, vcpt_l, vid, vid_l


if __name__ == '__main__':
    from config import BaseOptions
    import sys
    # sys.argv[1:] = ["--input_streams", "sub"]
    opt = BaseOptions().parse()
    opt.vocab_size = 5
    model = ABC(opt)
    model.to(opt.device)
    test_in = model.get_fake_inputs(device=opt.device)
    test_out = model(*test_in)
    print(test_out.size())
Example #7
        valid_loss.append(loss.item())
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        if opt.debug:
            break

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    valid_loss = sum(valid_loss) / float(len(valid_loss))

    return valid_acc, valid_loss


if __name__ == "__main__":
    torch.manual_seed(2018)
    opt = BaseOptions().parse()
    writer = SummaryWriter(opt.results_dir)
    opt.writer = writer
    plotter = VisdomLinePlotter(env_name=opt.jobname)
    opt.plotter = plotter
    dset = TVQADataset(opt)
    if opt.bert is None:
        opt.vocab_size = len(dset.word2idx)
    if opt.disable_streams is None:
        opt.disable_streams = []

    # dynamic model imports
    import importlib
    print(opt.jobname)
    print(opt.modelname)
Example #8
class ABC(nn.Module):
    def __init__(self, opt):
        super(ABC, self).__init__()
        self.vid_flag = "imagenet" in opt.input_streams
        self.sub_flag = "sub" in opt.input_streams
        self.vcpt_flag = "vcpt" in opt.input_streams
        self.reg_flag = "regional" in opt.input_streams
        self.topk = opt.topk
        self.opt = opt
        hidden_size_1 = opt.hsz1
        hidden_size_2 = opt.hsz2
        n_layers_cls = opt.n_layers_cls
        vid_feat_size = opt.vid_feat_size
        embedding_size = opt.embedding_size

        # text encoding: learned embedding table when BERT is disabled, otherwise BERT features + projection
        if opt.bert is None:
            vocab_size = opt.vocab_size
            self.embedding = nn.Embedding(vocab_size, embedding_size)
        else:
            self.bert_fc = nn.Linear(768, 300)
            if opt.bert in ["default"]:
                self.bert = BertModel.from_pretrained('bert-base-uncased')
            elif opt.bert == "multi_choice":
                self.bert = BertForMultipleChoice.from_pretrained('bert-base-uncased')
            elif opt.bert == "qa":
                self.bert = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
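            # bert_fc projects the 768-d BERT features down to the 300-d space the LSTM encoder expects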

        self.bidaf = BidafAttn(hidden_size_1 * 3, method="dot")  # no parameter for dot
        self.lstm_raw = RNNEncoder(300, hidden_size_1, bidirectional=True, dropout_p=0, n_layers=1, rnn_type="lstm")


        if self.vid_flag:
            print("activate video stream")
            self.video_fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(vid_feat_size, embedding_size),
                nn.Tanh(),
            )
            self.lstm_mature_vid = RNNEncoder(hidden_size_1 * 2 * 5, hidden_size_2, bidirectional=True,
                                              dropout_p=0, n_layers=1, rnn_type="lstm")
            self.classifier_vid = MLP(hidden_size_2*2, 1, 500, n_layers_cls)

        if self.sub_flag:
            print("activate sub stream")
            self.lstm_mature_sub = RNNEncoder(hidden_size_1 * 2 * 5, hidden_size_2, bidirectional=True,
                                              dropout_p=0, n_layers=1, rnn_type="lstm")
            self.classifier_sub = MLP(hidden_size_2*2, 1, 500, n_layers_cls)

        if self.vcpt_flag:
            print("activate vcpt stream")
            self.lstm_mature_vcpt = RNNEncoder(hidden_size_1 * 2 * 5, hidden_size_2, bidirectional=True,
                                               dropout_p=0, n_layers=1, rnn_type="lstm")
            self.classifier_vcpt = MLP(hidden_size_2*2, 1, 500, n_layers_cls)

        if self.reg_flag:
            print("activate regional stream")
            self.regional_fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(2048, embedding_size),
                nn.Tanh(),
            )
            self.lstm_mature_reg = RNNEncoder(hidden_size_1 * 2 * 5, hidden_size_2, bidirectional=True,
                                              dropout_p=0, n_layers=1, rnn_type="lstm")
            self.classifier_reg = MLP(hidden_size_2*2, 1, 500, n_layers_cls)

    def load_embedding(self, pretrained_embedding):
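        # pretrained_embedding: a numpy array of shape (vocab_size, embedding_size)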
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embedding))

    def forward(self, q, q_l, a0, a0_l, a1, a1_l, a2, a2_l, a3, a3_l, a4, a4_l,
                sub, sub_l, vcpt, vcpt_l, vid, vid_l, reg, reg_l, regtopk, regtopk_l):
        if self.opt.bert is None:  # non-BERT path: learned word embeddings
            e_q = self.embedding(q)
            e_a0 = self.embedding(a0)
            e_a1 = self.embedding(a1)
            e_a2 = self.embedding(a2)
            e_a3 = self.embedding(a3)
            e_a4 = self.embedding(a4)

            raw_out_q, _ = self.lstm_raw(e_q, q_l)
            raw_out_a0, _ = self.lstm_raw(e_a0, a0_l)
            raw_out_a1, _ = self.lstm_raw(e_a1, a1_l)
            raw_out_a2, _ = self.lstm_raw(e_a2, a2_l)
            raw_out_a3, _ = self.lstm_raw(e_a3, a3_l)
            raw_out_a4, _ = self.lstm_raw(e_a4, a4_l)
        else:
            with torch.no_grad():
                e_q  = self.bert(q)[0]
                e_a0 = self.bert(a0)[0]
                e_a1 = self.bert(a1)[0]
                e_a2 = self.bert(a2)[0]
                e_a3 = self.bert(a3)[0]
                e_a4 = self.bert(a4)[0]

            e_q  = F.leaky_relu(self.bert_fc(e_q))
            e_a0 = F.leaky_relu(self.bert_fc(e_a0))
            e_a1 = F.leaky_relu(self.bert_fc(e_a1))
            e_a2 = F.leaky_relu(self.bert_fc(e_a2))
            e_a3 = F.leaky_relu(self.bert_fc(e_a3))
            e_a4 = F.leaky_relu(self.bert_fc(e_a4))

            raw_out_q, _ = self.lstm_raw(e_q, q_l)
            raw_out_a0, _ = self.lstm_raw(e_a0, a0_l)
            raw_out_a1, _ = self.lstm_raw(e_a1, a1_l)
            raw_out_a2, _ = self.lstm_raw(e_a2, a2_l)
            raw_out_a3, _ = self.lstm_raw(e_a3, a3_l)
            raw_out_a4, _ = self.lstm_raw(e_a4, a4_l)

        #### Subs ####
        if self.sub_flag:
            if self.opt.bert is None:  # learned word embeddings
                e_sub = self.embedding(sub)
            else:
                e_sub = self.bert(sub)[0]
                e_sub = F.relu(self.bert_fc(e_sub))  # note: plain relu here, leaky_relu in the other streams
            raw_out_sub, _ = self.lstm_raw(e_sub, sub_l)  # encode through the shared LSTM
            sub_out = self.stream_processor(self.lstm_mature_sub, self.classifier_sub, raw_out_sub, sub_l,
                                            raw_out_q, q_l, raw_out_a0, a0_l, raw_out_a1, a1_l,
                                            raw_out_a2, a2_l, raw_out_a3, a3_l, raw_out_a4, a4_l) #Fusion happens in here for subtitles
        else:
            sub_out = 0

        #### Vcpt ####
        if self.vcpt_flag:
            if self.opt.bert is None:  # learned word embeddings
                e_vcpt = self.embedding(vcpt)
            else:
                e_vcpt = self.bert(vcpt)[0]
                e_vcpt = F.leaky_relu(self.bert_fc(e_vcpt))
            raw_out_vcpt, _ = self.lstm_raw(e_vcpt, vcpt_l)
            vcpt_out = self.stream_processor(self.lstm_mature_vcpt, self.classifier_vcpt, raw_out_vcpt, vcpt_l,
                                             raw_out_q, q_l, raw_out_a0, a0_l, raw_out_a1, a1_l,
                                             raw_out_a2, a2_l, raw_out_a3, a3_l, raw_out_a4, a4_l)
        else:
            vcpt_out = 0

        #### Imgnet ####
        if self.vid_flag:
            e_vid = self.video_fc(vid)
            raw_out_vid, _ = self.lstm_raw(e_vid, vid_l)
            vid_out = self.stream_processor(self.lstm_mature_vid, self.classifier_vid, raw_out_vid, vid_l,
                                            raw_out_q, q_l, raw_out_a0, a0_l, raw_out_a1, a1_l,
                                            raw_out_a2, a2_l, raw_out_a3, a3_l, raw_out_a4, a4_l)
        else:
            vid_out = 0

        #### Reg ####
        if self.reg_flag:
            e_reg = self.regional_fc(reg)
            raw_out_reg, _ = self.lstm_raw(e_reg, reg_l)
            reg_out = self.stream_processor(self.lstm_mature_reg, self.classifier_reg, raw_out_reg, reg_l,
                                            raw_out_q, q_l, raw_out_a0, a0_l, raw_out_a1, a1_l,
                                            raw_out_a2, a2_l, raw_out_a3, a3_l, raw_out_a4, a4_l)
        else:
            reg_out = 0


        # Total
        out = sub_out + vcpt_out + vid_out + reg_out # adding zeros has no effect on backward
        return out.squeeze()


    # Regular stream processor for the imgnet, vcpt, and subtitle streams
    def stream_processor(self, lstm_mature, classifier, ctx_embed, ctx_l,
                         q_embed, q_l, a0_embed, a0_l, a1_embed, a1_l, a2_embed, a2_l, a3_embed, a3_l, a4_embed, a4_l):
        u_q, _ = self.bidaf(ctx_embed, ctx_l, q_embed, q_l)
        u_a0, _ = self.bidaf(ctx_embed, ctx_l, a0_embed, a0_l)
        u_a1, _ = self.bidaf(ctx_embed, ctx_l, a1_embed, a1_l)
        u_a2, _ = self.bidaf(ctx_embed, ctx_l, a2_embed, a2_l)
        u_a3, _ = self.bidaf(ctx_embed, ctx_l, a3_embed, a3_l)
        u_a4, _ = self.bidaf(ctx_embed, ctx_l, a4_embed, a4_l)

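        # fuse context with question- and answer-attended features; this five-way concat
        # is why the mature RNNEncoder takes hidden_size_1 * 2 * 5 input dimensions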
        concat_a0 = torch.cat([ctx_embed, u_a0, u_q, u_a0 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a1 = torch.cat([ctx_embed, u_a1, u_q, u_a1 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a2 = torch.cat([ctx_embed, u_a2, u_q, u_a2 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a3 = torch.cat([ctx_embed, u_a3, u_q, u_a3 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a4 = torch.cat([ctx_embed, u_a4, u_q, u_a4 * ctx_embed, u_q * ctx_embed], dim=-1)

        mature_maxout_a0, _ = lstm_mature(concat_a0, ctx_l)
        mature_maxout_a1, _ = lstm_mature(concat_a1, ctx_l)
        mature_maxout_a2, _ = lstm_mature(concat_a2, ctx_l)
        mature_maxout_a3, _ = lstm_mature(concat_a3, ctx_l)
        mature_maxout_a4, _ = lstm_mature(concat_a4, ctx_l)

        if self.topk == 1:
            mature_maxout_a0 = max_along_time(mature_maxout_a0, ctx_l).unsqueeze(1)
            mature_maxout_a1 = max_along_time(mature_maxout_a1, ctx_l).unsqueeze(1)
            mature_maxout_a2 = max_along_time(mature_maxout_a2, ctx_l).unsqueeze(1)
            mature_maxout_a3 = max_along_time(mature_maxout_a3, ctx_l).unsqueeze(1)
            mature_maxout_a4 = max_along_time(mature_maxout_a4, ctx_l).unsqueeze(1)
        else:
            mature_maxout_a0 = max_avg_along_time(mature_maxout_a0, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a1 = max_avg_along_time(mature_maxout_a1, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a2 = max_avg_along_time(mature_maxout_a2, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a3 = max_avg_along_time(mature_maxout_a3, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a4 = max_avg_along_time(mature_maxout_a4, ctx_l, self.topk).unsqueeze(1)

        mature_answers = torch.cat([
            mature_maxout_a0, mature_maxout_a1, mature_maxout_a2, mature_maxout_a3, mature_maxout_a4
        ], dim=1)
        out = classifier(mature_answers)  # (B, 5)
        return out


    # Regional-feature stream processor, intended for attention across 20 objects.
    # It carries a distinct name so it does not shadow stream_processor above;
    # its body is currently identical to stream_processor.
    def regional_stream_processor(self, lstm_mature, classifier, ctx_embed, ctx_l,
                                  q_embed, q_l, a0_embed, a0_l, a1_embed, a1_l, a2_embed, a2_l, a3_embed, a3_l, a4_embed, a4_l):
        u_q, _ = self.bidaf(ctx_embed, ctx_l, q_embed, q_l)
        u_a0, _ = self.bidaf(ctx_embed, ctx_l, a0_embed, a0_l)
        u_a1, _ = self.bidaf(ctx_embed, ctx_l, a1_embed, a1_l)
        u_a2, _ = self.bidaf(ctx_embed, ctx_l, a2_embed, a2_l)
        u_a3, _ = self.bidaf(ctx_embed, ctx_l, a3_embed, a3_l)
        u_a4, _ = self.bidaf(ctx_embed, ctx_l, a4_embed, a4_l)

        concat_a0 = torch.cat([ctx_embed, u_a0, u_q, u_a0 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a1 = torch.cat([ctx_embed, u_a1, u_q, u_a1 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a2 = torch.cat([ctx_embed, u_a2, u_q, u_a2 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a3 = torch.cat([ctx_embed, u_a3, u_q, u_a3 * ctx_embed, u_q * ctx_embed], dim=-1)
        concat_a4 = torch.cat([ctx_embed, u_a4, u_q, u_a4 * ctx_embed, u_q * ctx_embed], dim=-1)

        mature_maxout_a0, _ = lstm_mature(concat_a0, ctx_l)
        mature_maxout_a1, _ = lstm_mature(concat_a1, ctx_l)
        mature_maxout_a2, _ = lstm_mature(concat_a2, ctx_l)
        mature_maxout_a3, _ = lstm_mature(concat_a3, ctx_l)
        mature_maxout_a4, _ = lstm_mature(concat_a4, ctx_l)

        if self.topk == 1:
            mature_maxout_a0 = max_along_time(mature_maxout_a0, ctx_l).unsqueeze(1)
            mature_maxout_a1 = max_along_time(mature_maxout_a1, ctx_l).unsqueeze(1)
            mature_maxout_a2 = max_along_time(mature_maxout_a2, ctx_l).unsqueeze(1)
            mature_maxout_a3 = max_along_time(mature_maxout_a3, ctx_l).unsqueeze(1)
            mature_maxout_a4 = max_along_time(mature_maxout_a4, ctx_l).unsqueeze(1)
        else:
            mature_maxout_a0 = max_avg_along_time(mature_maxout_a0, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a1 = max_avg_along_time(mature_maxout_a1, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a2 = max_avg_along_time(mature_maxout_a2, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a3 = max_avg_along_time(mature_maxout_a3, ctx_l, self.topk).unsqueeze(1)
            mature_maxout_a4 = max_avg_along_time(mature_maxout_a4, ctx_l, self.topk).unsqueeze(1)

        mature_answers = torch.cat([
            mature_maxout_a0, mature_maxout_a1, mature_maxout_a2, mature_maxout_a3, mature_maxout_a4
        ], dim=1)
        out = classifier(mature_answers)  # (B, 5)
        return out

    def pointer_network(self):
        pass  # placeholder, not yet implemented

    @staticmethod
    def get_fake_inputs(device="cuda:0"):
        bsz = 16
        q = torch.ones(bsz, 25).long().to(device)
        q_l = torch.ones(bsz).fill_(25).long().to(device)
        a = torch.ones(bsz, 5, 20).long().to(device)
        a_l = torch.ones(bsz, 5).fill_(20).long().to(device)
        a0, a1, a2, a3, a4 = [a[:, i, :] for i in range(5)]
        a0_l, a1_l, a2_l, a3_l, a4_l = [a_l[:, i] for i in range(5)]
        sub = torch.ones(bsz, 300).long().to(device)
        sub_l = torch.ones(bsz).fill_(300).long().to(device)
        vcpt = torch.ones(bsz, 300).long().to(device)
        vcpt_l = torch.ones(bsz).fill_(300).long().to(device)
        vid = torch.ones(bsz, 100, 2048).to(device)
        vid_l = torch.ones(bsz).fill_(100).long().to(device)
        # regional features: shapes assumed to mirror the imagenet stream, since forward() expects
        # reg, reg_l, regtopk, regtopk_l in addition to the 18 inputs above
        reg = torch.ones(bsz, 100, 2048).to(device)
        reg_l = torch.ones(bsz).fill_(100).long().to(device)
        regtopk = torch.ones(bsz, 100, 2048).to(device)  # unused in the forward shown; shape assumed
        regtopk_l = torch.ones(bsz).fill_(100).long().to(device)
        return (q, q_l, a0, a0_l, a1, a1_l, a2, a2_l, a3, a3_l, a4, a4_l,
                sub, sub_l, vcpt, vcpt_l, vid, vid_l, reg, reg_l, regtopk, regtopk_l)


if __name__ == '__main__':
    from config import BaseOptions
    import sys
    sys.argv[1:] = ["--input_streams", "sub"]
    opt = BaseOptions().parse()

    model = ABC(opt)
    model.to(opt.device)
    test_in = model.get_fake_inputs(device=opt.device)
    test_out = model(*test_in)
    print(test_out.size())
Example #9
import torch
from torch.autograd import Variable
from PIL import Image
from reconstruction import *
from relighting import *
from direct_intrinsics_sn import DirectIntrinsicsSN
from infer import main, set_experiment
from utils import Cuda, create_image
from normal_weights.models import net as normal_net
from config import BaseOptions
from smoothing import average_frames, average_frames_warp
from scipy import ndimage
import numpy as np

print("PyTorch can see",torch.cuda.device_count(),"GPU(s). Current device:",torch.cuda.current_device())

inst = BaseOptions()
parser = inst.parser
opt = parser.parse_args()

IMG_SCALE = 1. / 255
IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
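
# A minimal sketch (an assumption, not shown in this snippet) of how these constants
# are typically applied to a uint8 HxWx3 frame before it enters the networks:
def preprocess(img_uint8):
    img = img_uint8.astype(np.float32) * IMG_SCALE    # scale pixels into [0, 1]
    img = (img - IMG_MEAN) / IMG_STD                  # ImageNet mean/std normalization
    return torch.from_numpy(img.transpose(2, 0, 1))   # HWC -> CHW tensor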

def initNetworks():
    # decomposition network
    net = DirectIntrinsicsSN(3, ['color', 'color', 'class'])
    net.load_state_dict(torch.load(opt.intrinseg_weights_loc))
    cuda = Cuda(0)
    net = net.cuda(device=cuda.device)
    net.eval()