Example #1
File: main.py Project: sunutf/TVQA
def validate(opt, dset, model, mode="valid"):
    dset.set_mode(mode)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset,
                              batch_size=opt.test_bsz,
                              shuffle=False,
                              collate_fn=pad_collate)

    valid_qids = []
    valid_loss = []
    valid_corrects = []
    for _, batch in enumerate(valid_loader):
        model_inputs, targets, qids = preprocess_inputs(batch,
                                                        opt.max_sub_l,
                                                        opt.max_vcpt_l,
                                                        opt.max_vid_l,
                                                        device=opt.device)
        outputs = model(*model_inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        valid_qids += [int(x) for x in qids]
        valid_loss.append(loss.item())
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        if opt.debug:
            break

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    valid_loss = sum(valid_loss) / float(len(valid_corrects))
    return valid_acc, valid_loss
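Note that criterion, pad_collate, and preprocess_inputs are free names inside validate; they are assumed to be defined elsewhere in main.py. A minimal usage sketch under that assumption (opt, dset, and model built as in the other examples):

import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # assumed module-level loss used inside validate()
valid_acc, valid_loss = validate(opt, dset, model, mode="valid")
print("valid acc %.4f, valid loss %.4f" % (valid_acc, valid_loss))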
Example #2
def train(opt, dset, model, criterion, optimizer, epoch, previous_best_acc):
    dset.set_mode("train")
    model.train()
    train_loader = DataLoader(dset, batch_size=opt.bsz, shuffle=True, collate_fn=pad_collate)

    train_loss = []
    valid_acc_log = ["batch_idx\tacc"]
    train_corrects = []
    torch.set_grad_enabled(True)
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        model_inputs, targets, _ = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l, opt.max_vid_l,
                                                     device=opt.device)
        outputs = model(*model_inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        train_loss.append(loss.item())
        pred_ids = outputs.data.max(1)[1]
        train_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()
        if batch_idx % opt.log_freq == 0:
            niter = epoch * len(train_loader) + batch_idx

            train_acc = sum(train_corrects) / float(len(train_corrects))
            train_loss = sum(train_loss) / float(len(train_corrects))
            opt.writer.add_scalar("Train/Acc", train_acc, niter)
            opt.writer.add_scalar("Train/Loss", train_loss, niter)

            # Test
            valid_acc, valid_loss = validate(opt, dset, model, mode="valid")
            opt.writer.add_scalar("Valid/Loss", valid_loss, niter)

            valid_log_str = "%02d\t%.4f" % (batch_idx, valid_acc)
            valid_acc_log.append(valid_log_str)
            if valid_acc > previous_best_acc:
                previous_best_acc = valid_acc
                torch.save(model.state_dict(), os.path.join(opt.results_dir, "best_valid.pth"))
            print(" Train Epoch %d loss %.4f acc %.4f Val loss %.4f acc %.4f"
                  % (epoch, train_loss, train_acc, valid_loss, valid_acc))

            # reset to train
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []

        if opt.debug:
            break

    # additional log
    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc
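A hedged sketch of the epoch driver this train() expects: previous_best_acc is threaded through successive epochs so the best-checkpoint logic keeps working (opt, dset, model, criterion, and optimizer are assumed to be constructed as in the surrounding examples):

best_acc = 0.0
for epoch in range(opt.n_epoch):
    best_acc = train(opt, dset, model, criterion, optimizer, epoch, best_acc)
    # early stopping on opt.max_es_cnt could be layered on top of best_acc here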
Example #3
def test(opt, dset, model):
    dset.set_mode(opt.mode)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=False, collate_fn=pad_collate)

    qid2preds = {}
    qid2targets = {}
    for valid_idx, batch in tqdm(enumerate(valid_loader)):
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l, opt.max_vid_l,
                                                        device=opt.device)
        outputs = model(*model_inputs)
        pred_ids = outputs.data.max(1)[1].cpu().numpy().tolist()
        cur_qid2preds = {qid: pred for qid, pred in zip(qids, pred_ids)}
        qid2preds = merge_two_dicts(qid2preds, cur_qid2preds)
        cur_qid2targets = {qid: target for qid, target in zip(qids, targets)}
        qid2targets = merge_two_dicts(qid2targets, cur_qid2targets)
    return qid2preds, qid2targets
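merge_two_dicts is imported from the project's utils and is not shown on this page; it is most likely the standard copy-and-update recipe. A sketch under that assumption:

def merge_two_dicts(a, b):
    """Assumed helper: return a new dict with b's entries overriding a's."""
    merged = a.copy()
    merged.update(b)
    return merged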
Example #4
def main():
    """Main script."""
    torch.manual_seed(2667)

    ################################################################################
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', help='train/test')
    parser.add_argument('--save_path', type=str, default=os.path.expanduser('~/kable_management/blp_paper/.results/hme_tvqa/saved_models'),
                        help='path for saving trained models')
    parser.add_argument('--test', type=int, default=0, help='0 | 1')
    parser.add_argument('--jobname', type=str , default='hme_tvqa-default',
                        help='jobname to direct to plotter') 
    parser.add_argument("--pool_type", type=str, default="default", choices=["default", "LinearSum", "ConcatMLP", "MCB", "MFH", "MFB", "MLB", "Block"], help="Which pooling technique to use")
    parser.add_argument("--pool_in_dims", type=int, nargs='+', default=[300,300], help="Input dimensions to pooling layers")
    parser.add_argument("--pool_out_dim", type=int, default=600, help="Output dimension to pooling layers")
    parser.add_argument("--pool_hidden_dim", type=int, default=1500, help="Some pooling types come with a hidden internal dimension")
    
    parser.add_argument("--results_dir_base", type=str, default=os.path.expanduser("~/kable_management/.results/test"))
    parser.add_argument("--log_freq", type=int, default=400, help="print, save training info")
    parser.add_argument("--lr", type=float, default=3e-4, help="learning rate")
    parser.add_argument("--wd", type=float, default=1e-5, help="weight decay")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
    parser.add_argument("--max_es_cnt", type=int, default=3, help="number of epochs to early stop")
    parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
    parser.add_argument("--test_bsz", type=int, default=100, help="mini-batch size for testing")
    parser.add_argument("--device", type=int, default=0, help="gpu ordinal, -1 indicates cpu")
    parser.add_argument("--no_core_driver", action="store_true",
                                help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
    parser.add_argument("--word_count_threshold", type=int, default=2, help="word vocabulary threshold")
    parser.add_argument("--no_glove", action="store_true", help="not use glove vectors")
    parser.add_argument("--no_ts", action="store_true", help="no timestep annotation, use full length feature")
    parser.add_argument("--input_streams", type=str, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here
                                help="input streams for the model")
    ################ My stuff
    parser.add_argument("--modelname", type=str, default="tvqa_abc", help="name of the model ot use")
    parser.add_argument("--lrtype", type=str, choices=["adam", "cyclic", "radam", "lrelu"], default="adam", help="Kind of learning rate")
    parser.add_argument("--poolnonlin", type=str, choices=["tanh", "relu", "sigmoid", "None", "lrelu"], default="None", help="add nonlinearities to pooling layer")
    parser.add_argument("--pool_dropout", type=float, default=0.0, help="Dropout value for the projections")
    parser.add_argument("--testrun", type=bool, default=False, help="set True to stop writing and visdom")
    parser.add_argument("--topk", type=int, default=1, help="To use instead of max pooling")
    parser.add_argument("--nosub", type=bool, default=False, help="Ignore the sub stream")
    parser.add_argument("--noimg", type=bool, default=False, help="Ignore the imgnet stream")
    parser.add_argument("--bert", type=str, choices=["default", "mine", "multi_choice", "qa"], default=None, help="What kind of BERT model to use")
    parser.add_argument("--reg_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/regional_features/100p.h5"),
                                help="regional features")
    parser.add_argument("--my_vcpt", type=bool, default=False, help="Use my extracted visual concepts")
    parser.add_argument("--lanecheck", type=bool, default=False, help="Validation lane checks")
    parser.add_argument("--lanecheck_path", type=str, help="Validation lane check path")
    parser.add_argument("--best_path", type=str, help="Path to best model")
    parser.add_argument("--disable_streams", type=str, default=None, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here
                                help="disable the input stream from voting in the softmax of model outputs")
    parser.add_argument("--dset", choices=["valid", "test", "train"], default="valid", type=str, help="The dataset to use")
    ########################

    parser.add_argument("--n_layers_cls", type=int, default=1, help="number of layers in classifier")
    parser.add_argument("--hsz1", type=int, default=150, help="hidden size for the first lstm")
    parser.add_argument("--hsz2", type=int, default=300, help="hidden size for the second lstm")
    parser.add_argument("--embedding_size", type=int, default=300, help="word embedding dim")
    parser.add_argument("--max_sub_l", type=int, default=300, help="max length for subtitle")
    parser.add_argument("--max_vcpt_l", type=int, default=300, help="max length for visual concepts")
    parser.add_argument("--max_vid_l", type=int, default=480, help="max length for video feature")
    parser.add_argument("--vocab_size", type=int, default=0, help="vocabulary size")
    parser.add_argument("--no_normalize_v", action="store_true", help="do not normalize video featrue")
    # Data paths
    parser.add_argument("--train_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_train_processed.json"),
                                help="train set path")
    parser.add_argument("--valid_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_val_processed.json"),
                                help="valid set path")
    parser.add_argument("--test_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_test_public_processed.json"),
                                help="test set path")
    parser.add_argument("--glove_path", type=str, default=os.path.expanduser("~/kable_management/data/word_embeddings/glove.6B.300d.txt"),
                                help="GloVe pretrained vector path")
    parser.add_argument("--vcpt_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/vcpt_features/py2_det_visual_concepts_hq.pickle"),
                                help="visual concepts feature path")
    parser.add_argument("--vid_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/imagenet_features/tvqa_imagenet_pool5_hq.h5"),
                                help="imagenet feature path")
    parser.add_argument("--vid_feat_size", type=int, default=2048,
                                help="visual feature dimension")
    parser.add_argument("--word2idx_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_word2idx.pickle"),
                                help="word2idx cache path")
    parser.add_argument("--idx2word_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_idx2word.pickle"),
                                help="idx2word cache path")
    parser.add_argument("--vocab_embedding_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_vocab_embedding.pickle"),
                                help="vocab_embedding cache path")
    parser.add_argument("--regional_topk", type=int, default=-1, help="Pick top-k scoring regional features across all frames")
    
    args = parser.parse_args()

    plotter = VisdomLinePlotter(env_name="TEST")#args.jobname)
    args.plotter = plotter
    args.normalize_v = not args.no_normalize_v
    args.device = torch.device("cuda:%d" % args.device if args.device >= 0 else "cpu")
    args.with_ts = not args.no_ts
    args.input_streams = [] if args.input_streams is None else args.input_streams
    args.vid_feat_flag = True if "imagenet" in args.input_streams else False
    args.h5driver = None if args.no_core_driver else "core"

    
    #args.dataset = 'msvd_qa'
    args.dataset = 'tvqa'

    args.word_dim = 300
    args.vocab_num = 4000
    args.pretrained_embedding = os.path.expanduser("~/kable_management/data/msvd_qa/word_embedding.npy")
    args.video_feature_dim = 4096
    args.video_feature_num = 480#20
    args.answer_num = 1000
    args.memory_dim = 256
    args.batch_size = 32
    args.reg_coeff = 1e-5
    args.learning_rate = 0.001
    args.preprocess_dir = os.path.expanduser("~/kable_management/data/hme_tvqa/")
    args.log = os.path.expanduser('~/kable_management/blp_paper/.results/msvd/logs')
    args.hidden_size = 512
    args.image_feature_net = 'concat'
    args.layer = 'fc'
    ################################################################################

    #model.load_embedding(dset.vocab_embedding)
    #opt.vocab_size = len(dset.word2idx)

    ##############################################################################
    ### add arguments ###
    args.max_sequence_length = 480

    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    #args.save_model_path = args.save_path + '%s_%s_%s%s' % (args.task,args.image_feature_net,args.layer,args.memory_type)
    ## Create model directory
    # if not os.path.exists(args.save_model_path):
    #     os.makedirs(args.save_model_path)

    #############################
    # Parameters
    #############################
    # NOTE: this block referenced `train_dataset` and `voc_len`, which are not defined
    # at this point in main(); the working feature/vocab setup appears further below.
    # SEQUENCE_LENGTH = args.max_sequence_length
    # VOCABULARY_SIZE = train_dataset.n_words
    # assert VOCABULARY_SIZE == voc_len
    # FEAT_DIM = train_dataset.get_video_feature_dimension()[1:]
    # train_iter = train_dataset.batch_iter(args.num_epochs, args.batch_size)


    # Create model directory
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    from embed_loss import MultipleChoiceLoss
    criterion = MultipleChoiceLoss(num_option=5, margin=1, size_average=True).cuda()

    
    
    # NOTE: this model construction also ran before feat_channel, feat_dim, voc_len and
    # word_matrix were defined; the model is actually built further below.
    # if args.memory_type=='_mrm2s':
    #     rnn = TGIFAttentionTwoStream(args, 'Trans', feat_channel, feat_dim, text_embed_size, args.hidden_size,
    #                          voc_len, args.num_layers, word_matrix, answer_vocab_size = answer_vocab_size,
    #                          max_len=args.max_sequence_length)
    # else:
    #     assert 1==2
    # rnn = rnn.cuda()
    ##############################################################################

    #dataset = dt.MSVDQA(args.batch_size, args.preprocess_dir)
    dataset = tvqa_dataset.TVQADataset(args, mode="train")
    train_loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True, collate_fn=tvqa_dataset.pad_collate)

    args.memory_type='_mrm2s'
    args.save_model_path = args.save_path + 'model_%s_%s%s' % (args.image_feature_net,args.layer,args.memory_type)
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)



    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    answer_vocab_size = args.answer_num
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)
    answerset = pd.read_csv(os.path.expanduser('~/kable_management/data/msvd_qa/answer_set.txt'), header=None)[0]


    rnn = TGIFAttentionTwoStream(args, feat_channel, feat_dim, text_embed_size, args.hidden_size,
                         voc_len, num_layers, word_matrix, answer_vocab_size = answer_vocab_size,
                         max_len=max_sequence_length)
    #rnn = rnn.cuda()
    rnn = rnn.to(args.device)

    if args.test == 1:
        rnn.load_state_dict(torch.load(os.path.join(args.save_model_path, 'rnn-3800-vl_0.316.pkl')))



    # loss function
    criterion = nn.CrossEntropyLoss(size_average=True).cuda()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    
    iter = 0
    ######################################################################
    best_val_acc = 0.0
    best_test_acc = 0.0
    train_loss = []
    train_acc = []
    ######################################################################
    for epoch in range(0, 10000):
        #dataset.reset_train()
        for batch_idx, batch in enumerate(train_loader):
            tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(batch, args.max_sub_l, args.max_vcpt_l, args.max_vid_l,
                                                     device='cpu')#args.device
            assert( tvqa_data[18].shape[2] == 4096 )
            data_dict = {}                                         
            data_dict['num_mult_choices'] = 5
            data_dict['answers'] = targets #'int32 torch tensor correct answer ids'
            data_dict['video_features'] = torch.cat( [tvqa_data[18], tvqa_data[16]], dim=2 ).unsqueeze(2).unsqueeze(2)#'(bsz, 35(max), 1, 1, 6144)'
            data_dict['video_lengths'] = tvqa_data[17].tolist() #'list, (bsz), number of frames, #max is 35'
            data_dict['candidates'] = tvqa_data[0]#'torch tensor (bsz, 5, 35) (max q length)'
            data_dict['candidate_lengths'] = tvqa_data[1]#'list of list, lengths of each question+ans'
            #import ipdb; ipdb.set_trace()

            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #import ipdb; ipdb.set_trace()

            acc = rnn.accuracy(predictions, targets)
            print('Train iter %d, loss %.3f, acc %.2f' % (iter,loss.data,acc.item()))
            train_loss.append(loss.item())
            train_acc.append(acc.item())
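rnn.accuracy belongs to the HME-style model and is not reproduced in these snippets; a plausible sketch of what it computes (the fraction of predicted answer ids matching the targets), offered only as an assumption:

import torch

def accuracy(predictions, targets):
    # assumed behaviour: mean of exact matches between predicted and target answer ids
    return (predictions == targets).float().mean()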
Example #5
def train(opt, dset, model, criterion, optimizer, epoch, previous_best_acc,
          scheduler):
    dset.set_mode("train")
    if opt.rubi:  # The model and the question/subtitle-only RuBi-style model will be packed together; unpack here
        model, rubi_model = model
        rubi_model.train()
    model.train()
    train_loader = DataLoader(dset,
                              batch_size=opt.bsz,
                              shuffle=True,
                              collate_fn=pad_collate)
    train_loss = []
    valid_acc_log = ["batch_idx\tacc"]
    train_corrects = []
    torch.set_grad_enabled(True)
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Process inputs
        if (opt.lrtype == "cyclic"):
            scheduler.batch_step()
        model_inputs, targets, _ = preprocess_inputs(batch,
                                                     opt.max_sub_l,
                                                     opt.max_vcpt_l,
                                                     opt.max_vid_l,
                                                     device=opt.device)
        # model output
        if opt.dual_stream:
            outputs = model(*model_inputs)
            if opt.lanecheck:
                raise Exception(
                    "Not implemented: lanechecking with dual stream")
        else:
            outputs = model(*model_inputs)
            if opt.rubi:
                rubi_outputs = rubi_model(*model_inputs)

        # Loss
        if not opt.rubi:
            if opt.lanecheck:
                loss = criterion(outputs[-1], targets)
            else:
                loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            if opt.lanecheck:
                rubi_in = {  # This may be confusing, but this is because of my naming scheme conflicting with Remi's
                    'logits': None,
                    'logits_q': rubi_outputs[-1],
                    'logits_rubi': outputs[-1]
                }
            else:
                rubi_in = {  # This may be confusing, but this is because of my naming scheme conflicting with Remi's
                    'logits': None,
                    'logits_q': rubi_outputs,
                    'logits_rubi': outputs
                }

            rubi_targets = {'class_id': targets}
            losses = criterion(rubi_in, rubi_targets)
            loss, q_loss = losses['loss_mm_q'], losses[
                'loss_q']  # loss is the fused rubi loss
            optimizer.zero_grad()
            loss.backward()
            q_loss.backward()  # push the question loss backwards too
            optimizer.step()
        opt.plotter.text_plot(opt.jobname + " epoch",
                              opt.jobname + " " + str(epoch))

        # measure accuracy and record loss
        train_loss.append(loss.item())
        if opt.lanecheck:
            pred_ids = outputs[-1].data.max(1)[1]
        else:
            pred_ids = outputs.data.max(1)[1]
        train_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()
        if (batch_idx % opt.log_freq == 2):  #(opt.log_freq-1)):
            niter = epoch * len(train_loader) + batch_idx
            train_acc = sum(train_corrects) / float(len(train_corrects))
            train_loss = sum(train_loss) / float(
                len(train_loss))  #from train_corrects

            # Plotting
            if opt.testrun == False:
                opt.writer.add_scalar("Train/Acc", train_acc, niter)
                opt.writer.add_scalar("Train/Loss", train_loss, niter)
                opt.plotter.plot("accuracy", "train", opt.jobname, niter,
                                 train_acc)
                opt.plotter.plot("loss", "train", opt.jobname, niter,
                                 train_loss)

            # Validation
            if opt.dual_stream or opt.deep_cca:
                valid_acc, _ = validate(opt, dset, model, mode="valid")
            else:
                valid_acc, val_lanecheck_dict = validate_lanecheck(
                    opt, dset, model, mode="valid")
            if opt.testrun == False:
                #opt.writer.add_scalar("Valid/Loss", valid_loss, niter)
                opt.plotter.plot("accuracy", "val", opt.jobname, niter,
                                 valid_acc)
                #opt.plotter.plot("loss", "val", opt.jobname, niter, valid_loss)
            valid_log_str = "%02d\t%.4f" % (batch_idx, valid_acc)
            valid_acc_log.append(valid_log_str)

            # If this is the best run yet
            if valid_acc > previous_best_acc:
                previous_best_acc = valid_acc

                # Plot best accuracy so far in text box
                if opt.testrun == False:
                    opt.plotter.text_plot(
                        opt.jobname + " val", opt.jobname + " val " +
                        str(round(previous_best_acc, 4)))

                # Save the predictions for validation and training datasets, and the state dictionary of the model
                #_, train_lanecheck_dict = validate_lanecheck(opt, dset, model, mode="train")
                if (not opt.dual_stream) and (not opt.deep_cca):
                    save_pickle(val_lanecheck_dict,
                                opt.lanecheck_path + '_valid')
                torch.save(model.state_dict(),
                           os.path.join(opt.results_dir, "best_valid.pth"))

            # reset to train
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []

        if opt.debug:
            break

    # additional log
    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc
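scheduler.batch_step() implies a per-batch cyclic learning-rate scheduler (a custom CyclicLR-style class, not the torch.optim.lr_scheduler API). A minimal driver sketch under that assumption, with opt.rubi off so model is a single module, and passing scheduler=None whenever opt.lrtype is not "cyclic":

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.wd)
scheduler = None  # only a cyclic-LR object (exposing batch_step()) when opt.lrtype == "cyclic"
best_acc = 0.0
for epoch in range(opt.n_epoch):
    best_acc = train(opt, dset, model, criterion, optimizer, epoch, best_acc, scheduler)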
Example #6
def validate_lanecheck(opt, dset, model, mode="valid"):
    dset.set_mode(mode)  # switch the dataset to the requested split
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset,
                              batch_size=opt.test_bsz,
                              shuffle=False,
                              collate_fn=pad_collate)
    lanecheck_dict = {}
    valid_corrects = []
    #opt.lanecheck = True
    if opt.disable_streams is not None:
        for d_stream in opt.disable_streams:
            if d_stream in opt.input_streams:
                opt.input_streams.remove(d_stream)
    else:
        opt.disable_streams = []

    for batch_idx, batch in enumerate(valid_loader):
        # Accuracy
        model_inputs, targets, qids = preprocess_inputs(batch,
                                                        opt.max_sub_l,
                                                        opt.max_vcpt_l,
                                                        opt.max_vid_l,
                                                        device=opt.device)
        sub_out, vcpt_out, vid_out, reg_out, regtopk_out, outputs = model(
            *model_inputs)
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        # Add the ground truth to the end of each output response, and then the predicted ID for the question after that
        if 'sub' in opt.input_streams and not opt.dual_stream:
            sub_out = torch.cat(
                (sub_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1),
                 pred_ids.cpu().float().unsqueeze(1)),
                dim=1)
        if 'vcpt' in opt.input_streams and not opt.dual_stream:
            vcpt_out = torch.cat(
                (vcpt_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1),
                 pred_ids.cpu().float().unsqueeze(1)),
                dim=1)
        if 'imagenet' in opt.input_streams:
            vid_out = torch.cat(
                (vid_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1),
                 pred_ids.cpu().float().unsqueeze(1)),
                dim=1)
        if ('regional' in opt.input_streams) and opt.regional_topk == -1:
            reg_out = torch.cat(
                (reg_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1),
                 pred_ids.cpu().float().unsqueeze(1)),
                dim=1)
        if opt.regional_topk != -1:
            regtopk_out = torch.cat((regtopk_out.cpu().squeeze(),
                                     targets.cpu().float().unsqueeze(1),
                                     pred_ids.cpu().float().unsqueeze(1)),
                                    dim=1)

        # Add them to the lanecheck dictionary
        for id_idx in range(len(qids)):
            lanecheck_dict[qids[id_idx]] = {}
            if 'sub' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['sub_out'] = sub_out[id_idx]
            if 'vcpt' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vcpt_out'] = vcpt_out[id_idx]
            if 'imagenet' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vid_out'] = vid_out[id_idx]
            if ('regional' in opt.input_streams) and opt.regional_topk == -1:
                lanecheck_dict[qids[id_idx]]['reg_out'] = reg_out[id_idx]
            if opt.regional_topk != -1:
                lanecheck_dict[
                    qids[id_idx]]['regtopk_out'] = regtopk_out[id_idx]

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    lanecheck_dict['acc'] = valid_acc
    #opt.lanecheck = False
    return valid_acc, lanecheck_dict
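Each per-question entry stores the chosen streams' answer logits with the ground-truth id and the predicted id appended as the last two elements. A short inspection sketch, assuming 'imagenet' was among the input streams so 'vid_out' exists:

# pick any question id; the 'acc' key holds the overall accuracy
qid = next(k for k in lanecheck_dict if k != 'acc')
entry = lanecheck_dict[qid]
gt_id, pred_id = int(entry['vid_out'][-2]), int(entry['vid_out'][-1])
print("qid %s: ground truth %d, predicted %d" % (str(qid), gt_id, pred_id))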
Example #7
def validate(opt, dset, model, mode="valid"):
    dset.set_mode(opt.dset)  # switch the dataset to the split given by opt.dset
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=True, collate_fn=pad_collate)
    if opt.bert != None:
        bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
    else:
        # Word embedding lookup GloVE
        from utils import load_pickle
        idx2word = load_pickle(opt.idx2word_path)
    #valid_qids = []
    lanecheck_dict = {}
    valid_corrects = []
    if opt.disable_streams is not None:
        for d_stream in opt.disable_streams:
            if d_stream in opt.input_streams:
                opt.input_streams.remove(d_stream)
    else:
        opt.disable_streams = []
    for batch_idx, batch in enumerate(valid_loader):
        print(round(batch_idx/len(valid_loader)*100, 2), "percent complete")
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l, opt.max_vid_l,
                                                        device=opt.device)
        if opt.lanecheck:
            sub_out, vcpt_out, vid_out, reg_out, regtopk_out, outputs = model(*model_inputs)
        else:
            # assumed fallback: without lanecheck the model returns only the fused logits
            outputs = model(*model_inputs)
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        # measure accuracy and record loss
        #valid_qids += [int(x) for x in qids]

        batch_q , _    = getattr(batch, 'q')
        batch_a0, _    = getattr(batch, 'a0')
        batch_a1, _    = getattr(batch, 'a1')
        batch_a2, _    = getattr(batch, 'a2')
        batch_a3, _    = getattr(batch, 'a3')
        batch_a4, _    = getattr(batch, 'a4')
        if 'sub' in opt.input_streams:
            batch_sub, _   = getattr(batch, 'sub')
        if 'vcpt' in opt.input_streams:
            batch_vcpt, _  = getattr(batch, 'vcpt')

        # Add the ground truth to the end of each output response, and then the predicted ID for the question after that
        if 'sub' in opt.input_streams:
            sub_out = torch.cat((sub_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1), pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'vcpt' in opt.input_streams:
            vcpt_out = torch.cat((vcpt_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1), pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'imagenet' in opt.input_streams:
            vid_out = torch.cat((vid_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1), pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if ('regional' in opt.input_streams) and opt.regional_topk == -1:
            reg_out = torch.cat((reg_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1), pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if opt.regional_topk != -1:
            regtopk_out = torch.cat((regtopk_out.cpu().squeeze(), targets.cpu().float().unsqueeze(1), pred_ids.cpu().float().unsqueeze(1)), dim=1)

        # Add them to the lanecheck dictionary
        for id_idx in range(len(qids)):
            lanecheck_dict[qids[id_idx]] = {}
            if 'sub' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['sub_out']       = sub_out[id_idx]
            if 'vcpt' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vcpt_out']      = vcpt_out[id_idx]
            if 'imagenet' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vid_out']       = vid_out[id_idx]
            if ('regional' in opt.input_streams) and opt.regional_topk == -1:
                lanecheck_dict[qids[id_idx]]['reg_out']       = reg_out[id_idx]
            if opt.regional_topk != -1:
                lanecheck_dict[qids[id_idx]]['regtopk_out']   = regtopk_out[id_idx]

            #Vcpt
            #decode from bert
            if opt.bert != None:
                lanecheck_dict[qids[id_idx]]['q']                = bert_tok.decode(batch_q[id_idx]) 
                lanecheck_dict[qids[id_idx]]['a0']               = bert_tok.decode(batch_a0[id_idx])  
                lanecheck_dict[qids[id_idx]]['a1']               = bert_tok.decode(batch_a1[id_idx]) 
                lanecheck_dict[qids[id_idx]]['a2']               = bert_tok.decode(batch_a2[id_idx])   
                lanecheck_dict[qids[id_idx]]['a3']               = bert_tok.decode(batch_a3[id_idx])  
                lanecheck_dict[qids[id_idx]]['a4']               = bert_tok.decode(batch_a4[id_idx]) 
                if 'sub' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['sub']          = bert_tok.decode(batch_sub[id_idx])  
                if 'vcpt' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['vcpt']         = bert_tok.decode(batch_vcpt[id_idx])
            else:
                # Decode from GloVE
                #idx2word
                lanecheck_dict[qids[id_idx]]['q']   = [ idx2word[int(word)] for word in batch_q[id_idx] ]
                lanecheck_dict[qids[id_idx]]['a0']  = [ idx2word[int(word)] for word in batch_a0[id_idx] ]
                lanecheck_dict[qids[id_idx]]['a1']  = [ idx2word[int(word)] for word in batch_a1[id_idx] ]
                lanecheck_dict[qids[id_idx]]['a2']  = [ idx2word[int(word)] for word in batch_a2[id_idx] ]
                lanecheck_dict[qids[id_idx]]['a3']  = [ idx2word[int(word)] for word in batch_a3[id_idx] ]
                lanecheck_dict[qids[id_idx]]['a4']  = [ idx2word[int(word)] for word in batch_a4[id_idx] ]
                if 'sub' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['sub'] = [ idx2word[int(word)] for word in batch_sub[id_idx] ]
                if 'vcpt' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['vcpt'] = [ idx2word[int(word)] for word in batch_vcpt[id_idx] ]
    
    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    lanecheck_dict['valid_acc'] = valid_acc
    print('valid acc', valid_acc)
    from utils import save_pickle
    save_pickle(lanecheck_dict, opt.lanecheck_path+'_'+opt.dset)
    return valid_acc
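save_pickle and load_pickle come from the project's utils module and are not reproduced here; they almost certainly wrap the standard pickle calls. A sketch under that assumption:

import pickle

def save_pickle(obj, path):
    # assumed utils helper: dump an object to a binary pickle file
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_pickle(path):
    # assumed utils helper: load a pickled object back
    with open(path, "rb") as f:
        return pickle.load(f)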
Example #8
def main(args):
    torch.manual_seed(2667)  #1

    ### add arguments ###
    args.vc_dir = os.path.expanduser(
        "~/kable_management/data/tgif_qa/data/Vocabulary")
    args.df_dir = os.path.expanduser(
        "~/kable_management/data/tgif_qa/data/dataset")

    # Dataset
    dataset = tvqa_dataset.TVQADataset(args, mode="train")

    args.max_sequence_length = 50  #35
    args.model_name = args.task + args.feat_type

    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = args.save_path + '%s_%s_%s%s' % (
        args.task, args.image_feature_net, args.layer, args.memory_type)
    # Create model directory
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    #############################
    # get video feature dimension
    #############################
    video_feature_dimension = (480, 1, 1, 6144)
    feat_channel = video_feature_dimension[3]
    feat_dim = video_feature_dimension[2]
    text_embed_size = 300
    answer_vocab_size = None

    #############################
    # get word vector dimension
    #############################
    word_matrix = dataset.vocab_embedding
    voc_len = len(dataset.word2idx)
    assert text_embed_size == word_matrix.shape[1]

    #############################
    # Parameters
    #############################
    SEQUENCE_LENGTH = args.max_sequence_length
    VOCABULARY_SIZE = voc_len
    assert VOCABULARY_SIZE == voc_len
    FEAT_DIM = (1, 1, 6144)

    # Create model directory
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    if args.task == 'Count':
        # add L2 loss
        criterion = nn.MSELoss(size_average=True).to(args.device)
    elif args.task in ['Action', 'Trans']:
        from embed_loss import MultipleChoiceLoss
        criterion = MultipleChoiceLoss(num_option=5,
                                       margin=1,
                                       size_average=True).to(args.device)
    elif args.task == 'FrameQA':
        # add classification loss
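        # NOTE: `train_dataset` below is carried over from the original TGIF-QA script
        # and is not defined in this file, so the FrameQA branch would fail as written.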
        answer_vocab_size = len(train_dataset.ans2idx)
        print('Vocabulary size', answer_vocab_size, VOCABULARY_SIZE)
        criterion = nn.CrossEntropyLoss(size_average=True).to(args.device)

    if args.memory_type == '_mrm2s':
        rnn = TGIFAttentionTwoStream(args,
                                     args.task,
                                     feat_channel,
                                     feat_dim,
                                     text_embed_size,
                                     args.hidden_size,
                                     voc_len,
                                     args.num_layers,
                                     word_matrix,
                                     answer_vocab_size=answer_vocab_size,
                                     max_len=args.max_sequence_length)
    else:
        assert 1 == 2

    rnn = rnn.to(args.device)

    #  to directly test, load pre-trained model, replace with your model to test your model
    if args.test == 1:
        if args.task == 'Count':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Count_concat_fc_mrm2s/rnn-1300-l3.257-a27.942.pkl'
                ))
        elif args.task == 'Action':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Action_concat_fc_mrm2s/rnn-0800-l0.137-a84.663.pkl'
                ))
        elif args.task == 'Trans':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Trans_concat_fc_mrm2s/rnn-1500-l0.246-a78.068.pkl'
                ))
        elif args.task == 'FrameQA':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/FrameQA_concat_fc_mrm2s/rnn-4200-l1.233-a69.361.pkl'
                ))
        else:
            assert 1 == 2, 'Invalid task'

    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    iter = 0

    if args.task in ['Action', 'Trans']:
        best_val_acc = 0.0
        best_val_iter = 0.0
        best_test_acc = 0.0
        train_acc = []
        train_loss = []

        for epoch in range(0, args.num_epochs):
            dataset.set_mode("train")
            train_loader = DataLoader(dataset,
                                      batch_size=args.bsz,
                                      shuffle=True,
                                      collate_fn=tvqa_dataset.pad_collate,
                                      num_workers=args.num_workers)
            for batch_idx, batch in enumerate(train_loader):
                # if(batch_idx<51):
                #     iter+=1
                #     continue
                # else:
                #     pass
                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                    batch,
                    args.max_sub_l,
                    args.max_vcpt_l,
                    args.max_vid_l,
                    device=args.device)
                assert (tvqa_data[18].shape[2] == 4096)
                data_dict = {}
                data_dict['num_mult_choices'] = 5
                data_dict['answers'] = targets  #'int32 torch tensor correct answer ids'
                data_dict['video_features'] = torch.cat(
                    [tvqa_data[18], tvqa_data[16]],
                    dim=2).unsqueeze(2).unsqueeze(2)  #'(bsz, 35(max), 1, 1, 6144)'
                data_dict['video_lengths'] = tvqa_data[17].tolist()  #'list, (bsz), number of frames, #max is 35'
                data_dict['candidates'] = tvqa_data[0]  #'torch tensor (bsz, 5, 35) (max q length)'
                data_dict['candidate_lengths'] = tvqa_data[1]  #'list of list, lengths of each question+ans'
                # if hard restrict
                if args.hard_restrict:
                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]  #'(bsz, 35(max), 1, 1, 6144)'
                    data_dict['video_lengths'] = [
                        i if i <= args.max_sequence_length else args.max_sequence_length
                        for i in data_dict['video_lengths']
                    ]
                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                        if data_dict['video_lengths'][xx] == 0:
                            data_dict['video_lengths'][xx] = 1
                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length
                #import ipdb; ipdb.set_trace()
                #print("boom")
                outputs, targets, predictions = rnn(data_dict, args.task)

                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc = rnn.accuracy(predictions, targets.long())
                #print("boom")
                print('Train %s iter %d, loss %.3f, acc %.2f' %
                      (args.task, iter, loss.data, acc.item()))
                ######################################################################
                #args.plotter.plot("accuracy", "train", args.jobname, iter, acc.item())
                #args.plotter.plot("loss", "train", args.jobname, iter, int(loss.data))
                train_acc.append(acc.item())
                train_loss.append(loss.item())
                ######################################################################

                if (batch_idx % args.log_freq) == 0:  #(args.log_freq-1):
                    dataset.set_mode("valid")
                    valid_loader = DataLoader(
                        dataset,
                        batch_size=args.test_bsz,
                        shuffle=True,
                        collate_fn=tvqa_dataset.pad_collate,
                        num_workers=args.num_workers)
                    rnn.eval()
                    ######################################################################
                    train_loss = sum(train_loss) / float(len(train_loss))
                    train_acc = sum(train_acc) / float(len(train_acc))
                    args.plotter.plot("accuracy", "train",
                                      args.jobname + " " + args.task, iter,
                                      train_acc)
                    #import ipdb; ipdb.set_trace()
                    args.plotter.plot("loss", "train",
                                      args.jobname + " " + args.task, iter,
                                      train_loss)
                    train_loss = []
                    train_acc = []
                    ######################################################################
                    with torch.no_grad():
                        if args.test == 0:
                            n_iter = len(valid_loader)
                            losses = []
                            accuracy = []
                            iter_val = 0
                            for batch_idx, batch in enumerate(valid_loader):
                                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                                    batch,
                                    args.max_sub_l,
                                    args.max_vcpt_l,
                                    args.max_vid_l,
                                    device=args.device)
                                assert (tvqa_data[18].shape[2] == 4096)
                                data_dict = {}
                                data_dict['num_mult_choices'] = 5
                                data_dict['answers'] = targets  #'int32 torch tensor correct answer ids'
                                data_dict['video_features'] = torch.cat(
                                    [tvqa_data[18], tvqa_data[16]],
                                    dim=2).unsqueeze(2).unsqueeze(2)  #'(bsz, 35(max), 1, 1, 6144)'
                                data_dict['video_lengths'] = tvqa_data[17].tolist()  #'list, (bsz), number of frames, #max is 35'
                                data_dict['candidates'] = tvqa_data[0]  #'torch tensor (bsz, 5, 35) (max q length)'
                                data_dict['candidate_lengths'] = tvqa_data[1]  #'list of list, lengths of each question+ans'
                                # if hard restrict
                                if args.hard_restrict:
                                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]
                                    data_dict['video_lengths'] = [
                                        i if i <= args.max_sequence_length else args.max_sequence_length
                                        for i in data_dict['video_lengths']
                                    ]
                                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                                        if data_dict['video_lengths'][xx] == 0:
                                            data_dict['video_lengths'][xx] = 1
                                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length

                                outputs, targets, predictions = rnn(
                                    data_dict, args.task)

                                loss = criterion(outputs, targets)
                                acc = rnn.accuracy(predictions, targets.long())

                                losses.append(loss.item())
                                accuracy.append(acc.item())
                            losses = sum(losses) / n_iter
                            accuracy = sum(accuracy) / n_iter

                            ######################################################################
                            if best_val_acc < accuracy:
                                best_val_iter = iter
                                best_val_acc = accuracy
                                args.plotter.text_plot(
                                    args.task + args.jobname + " val",
                                    args.task + args.jobname + " val " +
                                    str(round(best_val_acc, 4)) + " " +
                                    str(iter))
                            #args.plotter.plot("loss", "val", args.jobname, iter, losses.avg)
                            args.plotter.plot("accuracy", "val",
                                              args.jobname + " " + args.task,
                                              iter, accuracy)
                            ######################################################################

                            #print('[Val] iter %d, loss %.3f, acc %.2f, best acc %.3f at iter %d' % (
                            #                iter, losses, accuracy, best_val_acc, best_val_iter))
                            # torch.save(rnn.state_dict(), os.path.join(args.save_model_path, 'rnn-%04d-l%.3f-a%.3f.pkl' % (
                            #                 iter, losses.avg, accuracy.avg)))
                    rnn.train()
                    dataset.set_mode("train")

                iter += 1
    elif args.task == 'Count':
        pass