Example #1
def main():
    """Main script."""
    torch.manual_seed(2667)

    ################################################################################
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', help='train/test')
    parser.add_argument('--save_path', type=str, default=os.path.expanduser('~/kable_management/blp_paper/.results/hme_tvqa/saved_models'),
                        help='path for saving trained models')
    parser.add_argument('--test', type=int, default=0, help='0 | 1')
    parser.add_argument('--jobname', type=str , default='hme_tvqa-default',
                        help='jobname to direct to plotter') 
    parser.add_argument("--pool_type", type=str, default="default", choices=["default", "LinearSum", "ConcatMLP", "MCB", "MFH", "MFB", "MLB", "Block"], help="Which pooling technique to use")
    parser.add_argument("--pool_in_dims", type=int, nargs='+', default=[300,300], help="Input dimensions to pooling layers")
    parser.add_argument("--pool_out_dim", type=int, default=600, help="Output dimension to pooling layers")
    parser.add_argument("--pool_hidden_dim", type=int, default=1500, help="Some pooling types come with a hidden internal dimension")
    
    parser.add_argument("--results_dir_base", type=str, default=os.path.expanduser("~/kable_management/.results/test"))
    parser.add_argument("--log_freq", type=int, default=400, help="print, save training info")
    parser.add_argument("--lr", type=float, default=3e-4, help="learning rate")
    parser.add_argument("--wd", type=float, default=1e-5, help="weight decay")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
    parser.add_argument("--max_es_cnt", type=int, default=3, help="number of epochs to early stop")
    parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
    parser.add_argument("--test_bsz", type=int, default=100, help="mini-batch size for testing")
    parser.add_argument("--device", type=int, default=0, help="gpu ordinal, -1 indicates cpu")
    parser.add_argument("--no_core_driver", action="store_true",
                                help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
    parser.add_argument("--word_count_threshold", type=int, default=2, help="word vocabulary threshold")
    parser.add_argument("--no_glove", action="store_true", help="not use glove vectors")
    parser.add_argument("--no_ts", action="store_true", help="no timestep annotation, use full length feature")
    parser.add_argument("--input_streams", type=str, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here
                                help="input streams for the model")
    ################ My stuff
    parser.add_argument("--modelname", type=str, default="tvqa_abc", help="name of the model ot use")
    parser.add_argument("--lrtype", type=str, choices=["adam", "cyclic", "radam", "lrelu"], default="adam", help="Kind of learning rate")
    parser.add_argument("--poolnonlin", type=str, choices=["tanh", "relu", "sigmoid", "None", "lrelu"], default="None", help="add nonlinearities to pooling layer")
    parser.add_argument("--pool_dropout", type=float, default=0.0, help="Dropout value for the projections")
    parser.add_argument("--testrun", type=bool, default=False, help="set True to stop writing and visdom")
    parser.add_argument("--topk", type=int, default=1, help="To use instead of max pooling")
    parser.add_argument("--nosub", type=bool, default=False, help="Ignore the sub stream")
    parser.add_argument("--noimg", type=bool, default=False, help="Ignore the imgnet stream")
    parser.add_argument("--bert", type=str, choices=["default", "mine", "multi_choice", "qa"], default=None, help="What kind of BERT model to use")
    parser.add_argument("--reg_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/regional_features/100p.h5"),
                                help="regional features")
    parser.add_argument("--my_vcpt", type=bool, default=False, help="Use my extracted visual concepts")
    parser.add_argument("--lanecheck", type=bool, default=False, help="Validation lane checks")
    parser.add_argument("--lanecheck_path", type=str, help="Validation lane check path")
    parser.add_argument("--best_path", type=str, help="Path to best model")
    parser.add_argument("--disable_streams", type=str, default=None, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here
                                help="disable the input stream from voting in the softmax of model outputs")
    parser.add_argument("--dset", choices=["valid", "test", "train"], default="valid", type=str, help="The dataset to use")
    ########################

    parser.add_argument("--n_layers_cls", type=int, default=1, help="number of layers in classifier")
    parser.add_argument("--hsz1", type=int, default=150, help="hidden size for the first lstm")
    parser.add_argument("--hsz2", type=int, default=300, help="hidden size for the second lstm")
    parser.add_argument("--embedding_size", type=int, default=300, help="word embedding dim")
    parser.add_argument("--max_sub_l", type=int, default=300, help="max length for subtitle")
    parser.add_argument("--max_vcpt_l", type=int, default=300, help="max length for visual concepts")
    parser.add_argument("--max_vid_l", type=int, default=480, help="max length for video feature")
    parser.add_argument("--vocab_size", type=int, default=0, help="vocabulary size")
    parser.add_argument("--no_normalize_v", action="store_true", help="do not normalize video featrue")
    # Data paths
    parser.add_argument("--train_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_train_processed.json"),
                                help="train set path")
    parser.add_argument("--valid_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_val_processed.json"),
                                help="valid set path")
    parser.add_argument("--test_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_test_public_processed.json"),
                                help="test set path")
    parser.add_argument("--glove_path", type=str, default=os.path.expanduser("~/kable_management/data/word_embeddings/glove.6B.300d.txt"),
                                help="GloVe pretrained vector path")
    parser.add_argument("--vcpt_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/vcpt_features/py2_det_visual_concepts_hq.pickle"),
                                help="visual concepts feature path")
    parser.add_argument("--vid_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/imagenet_features/tvqa_imagenet_pool5_hq.h5"),
                                help="imagenet feature path")
    parser.add_argument("--vid_feat_size", type=int, default=2048,
                                help="visual feature dimension")
    parser.add_argument("--word2idx_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_word2idx.pickle"),
                                help="word2idx cache path")
    parser.add_argument("--idx2word_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_idx2word.pickle"),
                                help="idx2word cache path")
    parser.add_argument("--vocab_embedding_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_vocab_embedding.pickle"),
                                help="vocab_embedding cache path")
    parser.add_argument("--regional_topk", type=int, default=-1, help="Pick top-k scoring regional features across all frames")
    
    args = parser.parse_args()

    plotter = VisdomLinePlotter(env_name="TEST")  # originally env_name=args.jobname
    args.plotter = plotter
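    # Derived options: the --no_* switches are plain store_true flags, so the
    # positive attributes below are just their negations; the "core" HDF5
    # driver loads the whole feature file into RAM, while None reads from disk.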
    args.normalize_v = not args.no_normalize_v
    args.device = torch.device("cuda:%d" % args.device if args.device >= 0 else "cpu")
    args.with_ts = not args.no_ts
    args.input_streams = [] if args.input_streams is None else args.input_streams
    args.vid_feat_flag = "imagenet" in args.input_streams
    args.h5driver = None if args.no_core_driver else "core"

    
    #args.dataset = 'msvd_qa'
    args.dataset = 'tvqa'

    args.word_dim = 300
    args.vocab_num = 4000
    args.pretrained_embedding = os.path.expanduser("~/kable_management/data/msvd_qa/word_embedding.npy")
    args.video_feature_dim = 4096
    args.video_feature_num = 480#20
    args.answer_num = 1000
    args.memory_dim = 256
    args.batch_size = 32
    args.reg_coeff = 1e-5
    args.learning_rate = 0.001
    args.preprocess_dir = os.path.expanduser("~/kable_management/data/hme_tvqa/")
    args.log = os.path.expanduser('~/kable_management/blp_paper/.results/msvd/logs')
    args.hidden_size = 512
    args.image_feature_net = 'concat'
    args.layer = 'fc'
    ################################################################################

    # model.load_embedding(dset.vocab_embedding)
    # opt.vocab_size = len(dset.word2idx)

    ### add arguments ###
    args.max_sequence_length = 480

    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    # Create the base directory for saved models
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    #dataset = dt.MSVDQA(args.batch_size, args.preprocess_dir)
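    # Build the TVQA training set; tvqa_dataset.pad_collate is assumed to pad the
    # variable-length text/feature fields so each batch stacks into fixed-size tensors.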
    dataset = tvqa_dataset.TVQADataset(args, mode="train")
    train_loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True, collate_fn=tvqa_dataset.pad_collate)

    args.memory_type='_mrm2s'
    args.save_model_path = os.path.join(args.save_path, 'model_%s_%s%s' % (args.image_feature_net, args.layer, args.memory_type))
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)



    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    answer_vocab_size = args.answer_num
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)
    answerset = pd.read_csv(os.path.expanduser('~/kable_management/data/msvd_qa/answer_set.txt'), header=None)[0]


    rnn = TGIFAttentionTwoStream(args, feat_channel, feat_dim, text_embed_size, args.hidden_size,
                         voc_len, num_layers, word_matrix, answer_vocab_size = answer_vocab_size,
                         max_len=max_sequence_length)
    #rnn = rnn.cuda()
    rnn = rnn.to(args.device)

    if args.test == 1:
        rnn.load_state_dict(torch.load(os.path.join(args.save_model_path, 'rnn-3800-vl_0.316.pkl')))



    # loss function
    criterion = nn.CrossEntropyLoss(size_average=True).cuda()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    
    iter = 0
    ######################################################################
    best_val_acc = 0.0
    best_test_acc = 0.0
    train_loss = []
    train_acc = []
    ######################################################################
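    # Training loop: each TVQA batch is repackaged into the data_dict format the
    # HME-style TGIFAttentionTwoStream model expects (correct answer ids, concatenated
    # video features, per-sample frame counts, and the five candidate sequences).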
    for epoch in range(0, 10000):
        #dataset.reset_train()
        for batch_idx, batch in enumerate(train_loader):
            tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(batch, args.max_sub_l, args.max_vcpt_l,
                                                                    args.max_vid_l, device='cpu')  # or args.device
            assert tvqa_data[18].shape[2] == 4096
            data_dict = {}
            data_dict['num_mult_choices'] = 5
            data_dict['answers'] = targets  # int32 torch tensor of correct answer ids
            data_dict['video_features'] = torch.cat([tvqa_data[18], tvqa_data[16]], dim=2).unsqueeze(2).unsqueeze(2)  # (bsz, 35 max, 1, 1, 6144)
            data_dict['video_lengths'] = tvqa_data[17].tolist()  # list (bsz) of frame counts, max is 35
            data_dict['candidates'] = tvqa_data[0]  # torch tensor (bsz, 5, 35), 35 = max question+answer length
            data_dict['candidate_lengths'] = tvqa_data[1]  # list of lists: length of each question+answer

            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = rnn.accuracy(predictions, targets)
            print('Train iter %d, loss %.3f, acc %.2f' % (iter, loss.item(), acc.item()))
            train_loss.append(loss.item())
            train_acc.append(acc.item())
Example #2
def main():
    """Main script."""
    torch.manual_seed(1)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', help='train/test')
    parser.add_argument('--save_path',
                        type=str,
                        default='./pretrain_models/',
                        help='path for saving trained models')
    parser.add_argument('--test', type=int, default=0, help='0 | 1')
    parser.add_argument('--memory_type', type=str, help='_mrm2s | _stvqa | _enc_dec | _co_mem')

    args = parser.parse_args()

    args.dataset = 'zh_vqa'
    args.word_dim = 300
    args.vocab_num = 4000
    args.pretrained_embedding = './data_zhqa/word_embedding.npy'
    args.video_feature_dim = 4096
    args.video_feature_num = 20
    args.memory_dim = 256
    args.batch_size = 8
    args.reg_coeff = 1e-5
    args.learning_rate = 0.001
    args.preprocess_dir = 'data_zhqa'
    args.log = './logs'
    args.hidden_size = 256

    dataset = dt.ZHVQA(args.batch_size, './data_zhqa')
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = args.save_path + 'model_%s_%s%s' % (
        args.image_feature_net, args.layer, args.memory_type)
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)

    if args.memory_type == '_mrm2s':
        rnn = AttentionTwoStream(feat_channel,
                                 feat_dim,
                                 text_embed_size,
                                 args.hidden_size,
                                 voc_len,
                                 num_layers,
                                 word_matrix,
                                 max_len=max_sequence_length)
        rnn = rnn.to(device)

    elif args.memory_type == '_stvqa':
        feat_channel *= 2
        rnn = TGIFBenchmark(feat_channel,
                            feat_dim,
                            text_embed_size,
                            args.hidden_size,
                            voc_len,
                            num_layers,
                            word_matrix,
                            max_len=max_sequence_length)
        rnn = rnn.to(device)

    elif args.memory_type == '_enc_dec':
        feat_channel *= 2
        rnn = LSTMEncDec(feat_channel,
                         feat_dim,
                         text_embed_size,
                         args.hidden_size,
                         voc_len,
                         num_layers,
                         word_matrix,
                         max_len=max_sequence_length)
        rnn = rnn.to(device)

    elif args.memory_type == '_co_mem':
        rnn = CoMemory(feat_channel,
                       feat_dim,
                       text_embed_size,
                       args.hidden_size,
                       voc_len,
                       num_layers,
                       word_matrix,
                       max_len=max_sequence_length)
        rnn = rnn.to(device)

    else:
        raise Exception('Please specify memory_type')

    # loss function
    criterion = MultipleChoiceLoss(margin=1, size_average=True).to(device)

    optimizer = torch.optim.Adam(rnn.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=0.0005)

    iter = 0
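    # Train for 20 epochs; every 300 iterations switch to eval mode, sweep the full
    # validation and test example streams, and checkpoint the weights with the
    # val/test accuracies encoded in the filename.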
    for epoch in range(0, 20):
        dataset.reset_train()

        while dataset.has_train_batch:
            iter += 1

            vgg, c3d, questions, answers, question_lengths = dataset.get_train_batch(
            )
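            # getInput (defined elsewhere in this codebase) is assumed to pack the
            # VGG/C3D features, candidate questions, answers and lengths into the
            # data_dict consumed by the model's forward pass.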
            data_dict = getInput(vgg, c3d, questions, answers,
                                 question_lengths)
            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = rnn.accuracy(predictions, targets)
            print('Train iter %d, loss %.3f, acc %.2f' %
                  (iter, loss.data, acc.item()))

            if iter % 300 == 0:
                rnn.eval()

                # val iterate over examples
                with torch.no_grad():

                    idx = 0
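                    # AverageMeter is assumed to be the usual running-average helper
                    # (as in the PyTorch ImageNet example): update(val, n) accumulates
                    # val with weight n and .avg returns the weighted mean.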
                    accuracy = AverageMeter()

                    while dataset.has_val_example:
                        if idx % 10 == 0:
                            print('Val iter %d/%d' % (idx, dataset.val_example_total))
                        idx += 1

                        vgg, c3d, questions, answers, question_lengths = dataset.get_val_example(
                        )
                        data_dict = getInput(vgg, c3d, questions, answers,
                                             question_lengths)
                        outputs, predictions = rnn(data_dict)
                        targets = data_dict['answers']

                        acc = rnn.accuracy(predictions, targets)
                        accuracy.update(acc.item(), len(vgg))

                    val_acc = accuracy.avg
                    print('Val iter %d, acc %.3f' % (iter, val_acc))
                    dataset.reset_val()

                    idx = 0
                    accuracy = AverageMeter()

                    while dataset.has_test_example:
                        if idx % 10 == 0:
                            print('Test iter %d/%d' % (idx, dataset.test_example_total))
                        idx += 1

                        vgg, c3d, questions, answers, question_lengths, _ = dataset.get_test_example(
                        )
                        data_dict = getInput(vgg, c3d, questions, answers,
                                             question_lengths)
                        outputs, predictions = rnn(data_dict)
                        targets = data_dict['answers']

                        acc = rnn.accuracy(predictions, targets)
                        accuracy.update(acc.item(), len(vgg))

                    test_acc = accuracy.avg
                    print('Test iter %d, acc %.3f' % (iter, accuracy.avg))
                    dataset.reset_test()

                    torch.save(
                        rnn.state_dict(),
                        os.path.join(
                            args.save_model_path,
                            'rnn-%04d-vl_%.3f-t_%.3f.pkl' %
                            (iter, val_acc, test_acc)))
                rnn.train()
Example #3
def main(args):
    torch.manual_seed(1)

    ### add arguments ###
    args.vc_dir = './data/Vocabulary'
    args.df_dir = './data/dataset'
    args.max_sequence_length = 35
    args.model_name = args.task + args.feat_type

    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = args.save_path + '%s_%s_%s%s' % (
        args.task, args.image_feature_net, args.layer, args.memory_type)
    # Create model directory
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    ######################################################################################
    ## This part of dataset code is adopted from
    ## https://github.com/YunseokJANG/tgif-qa/blob/master/code/gifqa/data_util/tgif.py
    ######################################################################################
    print('Start loading TGIF dataset')
    train_dataset = DatasetTGIF(dataset_name='train',
                                image_feature_net=args.image_feature_net,
                                layer=args.layer,
                                max_length=args.max_sequence_length,
                                data_type=args.task,
                                dataframe_dir=args.df_dir,
                                vocab_dir=args.vc_dir)
    train_dataset.load_word_vocabulary()

    val_dataset = train_dataset.split_dataset(ratio=0.1)
    val_dataset.share_word_vocabulary_from(train_dataset)

    test_dataset = DatasetTGIF(dataset_name='test',
                               image_feature_net=args.image_feature_net,
                               layer=args.layer,
                               max_length=args.max_sequence_length,
                               data_type=args.task,
                               dataframe_dir=args.df_dir,
                               vocab_dir=args.vc_dir)

    test_dataset.share_word_vocabulary_from(train_dataset)

    print('dataset lengths train/val/test %d/%d/%d' %
          (len(train_dataset), len(val_dataset), len(test_dataset)))

    #############################
    # get video feature dimension
    #############################
    video_feature_dimension = train_dataset.get_video_feature_dimension()
    feat_channel = video_feature_dimension[3]
    feat_dim = video_feature_dimension[2]
    text_embed_size = train_dataset.GLOVE_EMBEDDING_SIZE
    answer_vocab_size = None

    #############################
    # get word vector dimension
    #############################
    word_matrix = train_dataset.word_matrix
    voc_len = word_matrix.shape[0]
    assert text_embed_size == word_matrix.shape[1]

    #############################
    # Parameters
    #############################
    SEQUENCE_LENGTH = args.max_sequence_length
    VOCABULARY_SIZE = train_dataset.n_words
    assert VOCABULARY_SIZE == voc_len
    FEAT_DIM = train_dataset.get_video_feature_dimension()[1:]

    train_iter = train_dataset.batch_iter(args.num_epochs, args.batch_size)

    # Create model directory
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
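
    # Pick the loss to match the task: Count is treated as regression (MSE over the
    # predicted count), Action/Trans are five-way multiple choice scored with the
    # margin-based MultipleChoiceLoss, and FrameQA is open-ended classification
    # over the answer vocabulary (cross-entropy).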

    if args.task == 'Count':
        # add L2 loss
        criterion = nn.MSELoss(size_average=True).cuda()
    elif args.task in ['Action', 'Trans']:
        from embed_loss import MultipleChoiceLoss
        criterion = MultipleChoiceLoss(num_option=5,
                                       margin=1,
                                       size_average=True).cuda()
    elif args.task == 'FrameQA':
        # add classification loss
        answer_vocab_size = len(train_dataset.ans2idx)
        print('Vocabulary size', answer_vocab_size, VOCABULARY_SIZE)
        criterion = nn.CrossEntropyLoss(size_average=True).cuda()

    if args.memory_type == '_mrm2s':
        rnn = AttentionTwoStream(args.task,
                                 feat_channel,
                                 feat_dim,
                                 text_embed_size,
                                 args.hidden_size,
                                 voc_len,
                                 args.num_layers,
                                 word_matrix,
                                 answer_vocab_size=answer_vocab_size,
                                 max_len=args.max_sequence_length)
    else:
        raise ValueError('Invalid memory_type: %s' % args.memory_type)

    rnn = rnn.cuda()

    #  to directly test, load pre-trained model, replace with your model to test your model
    if args.test == 1:
        if args.task == 'Count':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Count_concat_fc_mrm2s/rnn-1300-l3.257-a27.942.pkl'
                ))
        elif args.task == 'Action':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Action_concat_fc_mrm2s/rnn-0800-l0.137-a84.663.pkl'
                ))
        elif args.task == 'Trans':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Trans_concat_fc_mrm2s/rnn-1500-l0.246-a78.068.pkl'
                ))
        elif args.task == 'FrameQA':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/FrameQA_concat_fc_mrm2s/rnn-4200-l1.233-a69.361.pkl'
                ))
        else:
            assert 1 == 2, 'Invalid task'

    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    iter = 0

    if args.task == 'Count':

        best_val_loss = 100.0
        best_val_iter = 0.0

        # this is a regression problem, predict a value from 1-10
        for batch_chunk in train_iter:

            if args.test == 0:
                video_features = torch.from_numpy(
                    batch_chunk['video_features'].astype(np.float32)).cuda()
                video_lengths = batch_chunk['video_lengths']
                question_words = torch.from_numpy(
                    batch_chunk['question_words'].astype(np.int64)).cuda()
                question_lengths = batch_chunk['question_lengths']
                answers = torch.from_numpy(batch_chunk['answer'].astype(
                    np.float32)).cuda()

                data_dict = {}
                data_dict['video_features'] = video_features
                data_dict['video_lengths'] = video_lengths
                data_dict['question_words'] = question_words
                data_dict['question_lengths'] = question_lengths
                data_dict['answers'] = answers

                outputs, targets, predictions = rnn(data_dict, 'Count')
                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                acc = rnn.accuracy(predictions, targets.int())
                print('Train %s iter %d, loss %.3f, acc %.2f' %
                      (args.task, iter, loss.data, acc.item()))

            if iter % 100 == 0:
                rnn.eval()
                with torch.no_grad():

                    if args.test == 0:
                        ##### Validation ######
                        n_iter = len(val_dataset) / args.batch_size
                        losses = AverageMeter()
                        accuracy = AverageMeter()

                        iter_val = 0
                        for batch_chunk in val_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_val % 10 == 0:
                                print('%d/%d' % (iter_val, n_iter))

                            iter_val += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            question_words = torch.from_numpy(
                                batch_chunk['question_words'].astype(
                                    np.int64)).cuda()
                            question_lengths = batch_chunk['question_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(
                                    np.float32)).cuda()

                            # print(question_words)
                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['question_words'] = question_words
                            data_dict['question_lengths'] = question_lengths
                            data_dict['answers'] = answers

                            outputs, targets, predictions = rnn(
                                data_dict, 'Count')
                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets.int())

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        if best_val_loss > losses.avg:
                            best_val_loss = losses.avg
                            best_val_iter = iter

                        print(
                            '[Val] iter %d, loss %.3f, acc %.2f, best loss %.3f at iter %d'
                            % (iter, losses.avg, accuracy.avg, best_val_loss,
                               best_val_iter))

                        torch.save(
                            rnn.state_dict(),
                            os.path.join(
                                args.save_model_path,
                                'rnn-%04d-l%.3f-a%.3f.pkl' %
                                (iter, losses.avg, accuracy.avg)))

                    if 1 == 1:
                        ###### Test ######
                        n_iter = len(test_dataset) / args.batch_size
                        losses = AverageMeter()
                        accuracy = AverageMeter()

                        iter_test = 0
                        for batch_chunk in test_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_test % 10 == 0:
                                print('%d/%d' % (iter_test, n_iter))

                            iter_test += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            question_words = torch.from_numpy(
                                batch_chunk['question_words'].astype(
                                    np.int64)).cuda()
                            question_lengths = batch_chunk['question_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(
                                    np.float32)).cuda()

                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['question_words'] = question_words
                            data_dict['question_lengths'] = question_lengths
                            data_dict['answers'] = answers

                            outputs, targets, predictions = rnn(
                                data_dict, 'Count')
                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets.int())

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        print('[Test] iter %d, loss %.3f, acc %.2f' %
                              (iter, losses.avg, accuracy.avg))
                        if args.test == 1:
                            exit()

                rnn.train()

            iter += 1

    elif args.task in ['Action', 'Trans']:

        best_val_acc = 0.0
        best_val_iter = 0.0

        # this is a multiple-choice problem, predict probability of each class
        for batch_chunk in train_iter:

            if args.test == 0:
                video_features = torch.from_numpy(
                    batch_chunk['video_features'].astype(np.float32)).cuda()
                video_lengths = batch_chunk['video_lengths']
                candidates = torch.from_numpy(batch_chunk['candidates'].astype(
                    np.int64)).cuda()
                candidate_lengths = batch_chunk['candidate_lengths']
                answers = torch.from_numpy(batch_chunk['answer'].astype(
                    np.int32)).cuda()
                num_mult_choices = batch_chunk['num_mult_choices']

                data_dict = {}
                data_dict['video_features'] = video_features
                data_dict['video_lengths'] = video_lengths
                data_dict['candidates'] = candidates
                data_dict['candidate_lengths'] = candidate_lengths
                data_dict['answers'] = answers
                data_dict['num_mult_choices'] = num_mult_choices

                outputs, targets, predictions = rnn(data_dict, args.task)

                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                acc = rnn.accuracy(predictions, targets.long())
                print('Train %s iter %d, loss %.3f, acc %.2f' %
                      (args.task, iter, loss.data, acc.item()))

            if iter % 100 == 0:
                rnn.eval()
                with torch.no_grad():

                    if args.test == 0:
                        n_iter = len(val_dataset) / args.batch_size
                        losses = AverageMeter()
                        accuracy = AverageMeter()
                        iter_val = 0
                        for batch_chunk in val_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_val % 10 == 0:
                                print('%d/%d' % (iter_val, n_iter))

                            iter_val += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            candidates = torch.from_numpy(
                                batch_chunk['candidates'].astype(
                                    np.int64)).cuda()
                            candidate_lengths = batch_chunk[
                                'candidate_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(np.int32)).cuda()
                            num_mult_choices = batch_chunk['num_mult_choices']

                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['candidates'] = candidates
                            data_dict['candidate_lengths'] = candidate_lengths
                            data_dict['answers'] = answers
                            data_dict['num_mult_choices'] = num_mult_choices

                            outputs, targets, predictions = rnn(
                                data_dict, args.task)

                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets.long())

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        if best_val_acc < accuracy.avg:
                            best_val_acc = accuracy.avg
                            best_val_iter = iter

                        print(
                            '[Val] iter %d, loss %.3f, acc %.2f, best acc %.3f at iter %d'
                            % (iter, losses.avg, accuracy.avg, best_val_acc,
                               best_val_iter))
                        torch.save(
                            rnn.state_dict(),
                            os.path.join(
                                args.save_model_path,
                                'rnn-%04d-l%.3f-a%.3f.pkl' %
                                (iter, losses.avg, accuracy.avg)))

                    if 1 == 1:
                        n_iter = len(test_dataset) / args.batch_size
                        losses = AverageMeter()
                        accuracy = AverageMeter()
                        iter_test = 0
                        for batch_chunk in test_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_test % 10 == 0:
                                print('%d/%d' % (iter_test, n_iter))

                            iter_test += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            candidates = torch.from_numpy(
                                batch_chunk['candidates'].astype(
                                    np.int64)).cuda()
                            candidate_lengths = batch_chunk[
                                'candidate_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(np.int32)).cuda()
                            num_mult_choices = batch_chunk['num_mult_choices']
                            #question_word_nums = batch_chunk['question_word_nums']

                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['candidates'] = candidates
                            data_dict['candidate_lengths'] = candidate_lengths
                            data_dict['answers'] = answers
                            data_dict['num_mult_choices'] = num_mult_choices

                            outputs, targets, predictions = rnn(
                                data_dict, args.task)

                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets.long())

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        print('[Test] iter %d, loss %.3f, acc %.2f' %
                              (iter, losses.avg, accuracy.avg))
                        if args.test == 1:
                            exit()

                rnn.train()

            iter += 1

    elif args.task == 'FrameQA':

        best_val_acc = 0.0
        best_val_iter = 0.0

        # this is a multiple-choice problem, predict probability of each class
        for batch_chunk in train_iter:

            if args.test == 0:
                video_features = torch.from_numpy(
                    batch_chunk['video_features'].astype(np.float32)).cuda()
                video_lengths = batch_chunk['video_lengths']
                question_words = torch.from_numpy(
                    batch_chunk['question_words'].astype(np.int64)).cuda()
                question_lengths = batch_chunk['question_lengths']
                answers = torch.from_numpy(batch_chunk['answer'].astype(
                    np.int64)).cuda()

                data_dict = {}
                data_dict['video_features'] = video_features
                data_dict['video_lengths'] = video_lengths
                data_dict['question_words'] = question_words
                data_dict['question_lengths'] = question_lengths
                data_dict['answers'] = answers

                outputs, targets, predictions = rnn(data_dict, args.task)

                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                acc = rnn.accuracy(predictions, targets)
                print('Train %s iter %d, loss %.3f, acc %.2f' %
                      (args.task, iter, loss.data, acc.item()))

            if iter % 100 == 0:
                rnn.eval()

                with torch.no_grad():

                    if args.test == 0:
                        losses = AverageMeter()
                        accuracy = AverageMeter()
                        n_iter = len(val_dataset) / args.batch_size

                        iter_val = 0
                        for batch_chunk in val_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_val % 10 == 0:
                                print('%d/%d' % (iter_val, n_iter))

                            iter_val += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            question_words = torch.from_numpy(
                                batch_chunk['question_words'].astype(
                                    np.int64)).cuda()
                            question_lengths = batch_chunk['question_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(np.int64)).cuda()

                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['question_words'] = question_words
                            data_dict['question_lengths'] = question_lengths
                            data_dict['answers'] = answers

                            outputs, targets, predictions = rnn(
                                data_dict, args.task)

                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets)

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        if best_val_acc < accuracy.avg:
                            best_val_acc = accuracy.avg
                            best_val_iter = iter

                        print(
                            '[Val] iter %d, loss %.3f, acc %.2f, best acc %.3f at iter %d'
                            % (iter, losses.avg, accuracy.avg, best_val_acc,
                               best_val_iter))
                        torch.save(
                            rnn.state_dict(),
                            os.path.join(
                                args.save_model_path,
                                'rnn-%04d-l%.3f-a%.3f.pkl' %
                                (iter, losses.avg, accuracy.avg)))

                    if 1 == 1:
                        losses = AverageMeter()
                        accuracy = AverageMeter()
                        n_iter = len(test_dataset) / args.batch_size

                        iter_test = 0
                        for batch_chunk in test_dataset.batch_iter(
                                1, args.batch_size, shuffle=False):
                            if iter_test % 10 == 0:
                                print('%d/%d' % (iter_test, n_iter))

                            iter_test += 1

                            video_features = torch.from_numpy(
                                batch_chunk['video_features'].astype(
                                    np.float32)).cuda()
                            video_lengths = batch_chunk['video_lengths']
                            question_words = torch.from_numpy(
                                batch_chunk['question_words'].astype(
                                    np.int64)).cuda()
                            question_lengths = batch_chunk['question_lengths']
                            answers = torch.from_numpy(
                                batch_chunk['answer'].astype(np.int64)).cuda()

                            data_dict = {}
                            data_dict['video_features'] = video_features
                            data_dict['video_lengths'] = video_lengths
                            data_dict['question_words'] = question_words
                            data_dict['question_lengths'] = question_lengths
                            data_dict['answers'] = answers

                            outputs, targets, predictions = rnn(
                                data_dict, args.task)

                            loss = criterion(outputs, targets)
                            acc = rnn.accuracy(predictions, targets)

                            losses.update(loss.item(), video_features.size(0))
                            accuracy.update(acc.item(), video_features.size(0))

                        print('[Test] iter %d, loss %.3f, acc %.2f' %
                              (iter, losses.avg, accuracy.avg))
                        if args.test == 1:
                            exit()

                rnn.train()

            iter += 1
Example #4
def main():
    """Main script."""

    if args.server == '780':
        args.data_path = '/home/jp/data/tgif-qa/data'

    args.feat_dir = os.path.join(args.data_path, 'feats')
    args.vc_dir = os.path.join(args.data_path, 'Vocabulary')
    args.df_dir = os.path.join(args.data_path, 'dataset')

    args.model_name = args.task
    args.pin_memory = False
    args.dataset = 'tgif_qa'
    args.log = './logs'
    # args.val_epoch_step = 5
    args.val_epoch_step = 1
    args.two_loss = args.two_loss > 0
    args.birnn = args.birnn > 0

    assert args.ablation in ['none', 'gcn', 'global', 'local', 'only_local']
    assert args.fusion_type in [
        'none', 'coattn', 'single_visual', 'single_semantic', 'coconcat',
        'cosiamese'
    ]

    args.save_model_path = args.save_path + 'MMModel/'
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    full_dataset = TGIFQA(dataset_name='train',
                          q_max_length=args.q_max_length,
                          v_max_length=args.v_max_length,
                          max_n_videos=args.max_n_videos,
                          data_type=args.task,
                          csv_dir=args.df_dir,
                          vocab_dir=args.vc_dir,
                          feat_dir=args.feat_dir)
    test_dataset = TGIFQA(dataset_name='test',
                          q_max_length=args.q_max_length,
                          v_max_length=args.v_max_length,
                          max_n_videos=args.max_n_videos,
                          data_type=args.task,
                          csv_dir=args.df_dir,
                          vocab_dir=args.vc_dir,
                          feat_dir=args.feat_dir)

    val_size = int(args.val_ratio * len(full_dataset))
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size])
    print('Dataset lengths train/val/test %d/%d/%d' %
          (len(train_dataset), len(val_dataset), len(test_dataset)))
    # train_dataset = full_dataset
    # val_dataloader = None

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=args.pin_memory,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                pin_memory=args.pin_memory,
                                worker_init_fn=_init_fn)
    test_dataloader = DataLoader(test_dataset,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=args.pin_memory,
                                 worker_init_fn=_init_fn)

    print('Load data successful.')

    if args.prefetch == 'nvidia':
        train_dataloader = nvidia_prefetcher(train_dataloader)
    elif args.prefetch == 'background':
        train_dataloader = BackgroundGenerator(train_dataloader)
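
    # Fixed feature dimensions assumed by the model: 2048-d ResNet appearance
    # features and 4096-d C3D motion features per clip; the text side reuses the
    # dataset's GloVe embedding matrix.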

    args.resnet_input_size = 2048
    args.c3d_input_size = 4096

    args.text_embed_size = train_dataset.dataset.GLOVE_EMBEDDING_SIZE
    args.answer_vocab_size = None

    args.word_matrix = train_dataset.dataset.word_matrix
    args.voc_len = args.word_matrix.shape[0]
    assert args.text_embed_size == args.word_matrix.shape[1]

    VOCABULARY_SIZE = train_dataset.dataset.n_words
    assert VOCABULARY_SIZE == args.voc_len

    ### criterions
    if args.task == 'Count':
        # add L2 loss
        criterion = nn.MSELoss().to(device)
    elif args.task in ['Action', 'Trans']:
        from embed_loss import MultipleChoiceLoss
        criterion = MultipleChoiceLoss(num_option=5,
                                       margin=1,
                                       size_average=True).to(device)
    elif args.task == 'FrameQA':
        # add classification loss
        args.answer_vocab_size = len(train_dataset.dataset.ans2idx)
        print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE))
        criterion = nn.CrossEntropyLoss().to(device)

    if not args.test:
        train(args, train_dataloader, val_dataloader, test_dataloader,
              criterion)
    else:
        assert args.checkpoint[:5] == args.task[:5]
        model = torch.load(args.save_model_path + args.checkpoint)
        test(args, model, test_dataloader, 0, criterion)
Example #5
def main():
    """Main script."""
    torch.manual_seed(1)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', help='train/test')
    parser.add_argument('--save_path',
                        type=str,
                        default='./saved_models/',
                        help='path for saving trained models')
    parser.add_argument(
        '--split',
        type=int,
        help='which of the three splits to train/val/test, option: 0 | 1 | 2')
    parser.add_argument('--test', type=int, default=0, help='0 | 1')
    parser.add_argument('--memory_type',
                        type=str,
                        help='_mrm2s | _stvqa | _enc_dec | _co_mem')

    args = parser.parse_args()

    args.dataset = 'ego_vqa'
    args.word_dim = 300
    args.vocab_num = 4000
    args.pretrained_embedding = 'data/word_embedding.npy'
    args.video_feature_dim = 4096
    args.video_feature_num = 20
    args.memory_dim = 256
    args.batch_size = 8
    args.reg_coeff = 1e-5
    args.learning_rate = 0.001
    args.preprocess_dir = 'data'
    args.log = './logs'
    args.hidden_size = 256

    video_cam = [
        '1_D', '1_M', '2_F', '2_M', '3_F', '3_M', '4_F', '4_M', '5_D', '5_M',
        '6_D', '6_M', '7_D', '7_M', '8_M', '8_X'
    ]

    with open('./data/data_split.json', 'r') as f:
        splits = json.load(f)
    assert args.split < len(splits)
    video_cam_split = splits[args.split]

    dataset = dt.EgoVQA(args.batch_size, './data', './data/feats',
                        video_cam_split)

    #args.memory_type = '_mrm2s'   # HME-QA: see HME-QA paper
    #args.memory_type = '_stvqa'  # st-vqa: see TGIF-QA paper
    #args.memory_type = '_enc_dec' # plain LSTM
    #args.memory_type = '_co_mem' # Co-Mem
    # TODO: if possible, add new algorithm
    # TODO: analyze different types of questions

    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = args.save_path + 'model_%s_%s%s' % (
        args.image_feature_net, args.layer, args.memory_type)
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)
    if not os.path.exists(
            os.path.join(args.save_model_path, 's%d' % (args.split))):
        os.makedirs(os.path.join(args.save_model_path, 's%d' % (args.split)))

    args.pretrain_model_path = './pretrain_models/' + 'model_%s_%s%s' % (
        args.image_feature_net, args.layer, args.memory_type)

    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)

    if args.memory_type == '_mrm2s':
        rnn = AttentionTwoStream(feat_channel,
                                 feat_dim,
                                 text_embed_size,
                                 args.hidden_size,
                                 voc_len,
                                 num_layers,
                                 word_matrix,
                                 max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-5100-vl_80.451-t_81.002.pkl'

    elif args.memory_type == '_stvqa':
        feat_channel *= 2
        rnn = TGIFBenchmark(feat_channel,
                            feat_dim,
                            text_embed_size,
                            args.hidden_size,
                            voc_len,
                            num_layers,
                            word_matrix,
                            max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-0400-vl_69.603-t_68.562.pkl'

    elif args.memory_type == '_enc_dec':
        feat_channel *= 2
        rnn = LSTMEncDec(feat_channel,
                         feat_dim,
                         text_embed_size,
                         args.hidden_size,
                         voc_len,
                         num_layers,
                         word_matrix,
                         max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-1200-vl_66.518-t_65.817.pkl'

    elif args.memory_type == '_co_mem':
        rnn = CoMemory(feat_channel,
                       feat_dim,
                       text_embed_size,
                       args.hidden_size,
                       voc_len,
                       num_layers,
                       word_matrix,
                       max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-1200-vl_76.516-t_75.708.pkl'

    else:
        raise ValueError('unsupported memory_type: %s' % args.memory_type)

    #################################
    # load pretrain model to finetune
    #################################

    pretrain_path = os.path.join(
        './pretrain_models', 'model_%s_%s%s' %
        (args.image_feature_net, args.layer, args.memory_type),
        best_pretrain_state)

    if os.path.isfile(pretrain_path):
        #rnn.load_state_dict(torch.load(pretrain_path))
        print('Load from', pretrain_path)
    else:
        print('Cannot load', pretrain_path)

    # loss function
    criterion = MultipleChoiceLoss(margin=1, size_average=True).cuda()

    optimizer = torch.optim.Adam(rnn.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=0.0005)

    best_test_acc = 0.0
    best_test_iter = 0

    iter = 0

    for epoch in range(0, 20):
        dataset.reset_train()

        while dataset.has_train_batch:
            iter += 1

            vgg, c3d, questions, answers, question_lengths = dataset.get_train_batch()
            data_dict = getInput(vgg, c3d, questions, answers,
                                 question_lengths)
            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = rnn.accuracy(predictions, targets)
            print('Train iter %d, loss %.3f, acc %.2f' %
                  (iter, loss.item(), acc.item()))

        if epoch >= 0:
            rnn.eval()

            # val iterate over examples
            with torch.no_grad():

                idx = 0
                accuracy = AverageMeter()

                while dataset.has_val_example:
                    if idx % 10 == 0:
                        print('Val iter %d/%d' % (idx, dataset.val_example_total))
                    idx += 1

                    vgg, c3d, questions, answers, question_lengths = dataset.get_val_example()
                    data_dict = getInput(vgg, c3d, questions, answers,
                                         question_lengths)
                    outputs, predictions = rnn(data_dict)
                    targets = data_dict['answers']

                    acc = rnn.accuracy(predictions, targets)
                    accuracy.update(acc.item(), len(vgg))

                val_acc = accuracy.avg
                print('Val iter %d, acc %.3f' % (iter, val_acc))
                dataset.reset_val()

                idx = 0
                accuracy = AverageMeter()

                while dataset.has_test_example:
                    if idx % 10 == 0:
                        print('Test iter %d/%d' % (idx, dataset.test_example_total))
                    idx += 1

                    vgg, c3d, questions, answers, question_lengths, _, _ = dataset.get_test_example()
                    data_dict = getInput(vgg, c3d, questions, answers,
                                         question_lengths)
                    outputs, predictions = rnn(data_dict)
                    targets = data_dict['answers']

                    acc = rnn.accuracy(predictions, targets)
                    accuracy.update(acc.item(), len(vgg))

                test_acc = accuracy.avg
                print('Test iter %d, acc %.3f' % (iter, accuracy.avg))
                dataset.reset_test()

                if best_test_acc < accuracy.avg:
                    best_test_acc = accuracy.avg
                    best_test_iter = iter

                print('[Test] iter %d, acc %.3f, best acc %.3f at iter %d' %
                      (iter, test_acc, best_test_acc, best_test_iter))

                torch.save(
                    rnn.state_dict(),
                    os.path.join(
                        args.save_model_path, 's%d' % (args.split, ),
                        'rnn-%04d-vl_%.3f-t_%.3f.pkl' %
                        (iter, val_acc, test_acc)))
                rnn.train()
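Example #5 above (and Example #6 below) accumulate validation and test accuracy with an AverageMeter that is not defined in this listing. A minimal sketch of the interface the loops rely on, update(value, n) plus an .avg attribute, is given here; it follows the common utility-class pattern and is an assumption rather than the original code:

class AverageMeter(object):
    """Keeps a running sum and count so .avg is the sample-weighted mean of all updates."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val: latest measurement (e.g. accuracy of one example or batch); n: samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

With this definition, accuracy.update(acc.item(), len(vgg)) followed by accuracy.avg reproduces the weighted averages printed in the evaluation loops.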
Example #6
def main():
    """Main script."""
    torch.manual_seed(2667)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', help='train/test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=os.path.expanduser(
            '~/kable_management/blp_paper/.results/egovqa/saved_models'),
        help='path for saving trained models')
    parser.add_argument(
        '--split',
        type=int,
        help='which of the three splits to train/val/test, option: 0 | 1 | 2')
    parser.add_argument('--test', type=int, default=0, help='0 | 1')
    parser.add_argument('--memory_type',
                        type=str,
                        help='_mrm2s | _stvqa | _enc_dec | _co_mem')
    ######################################################################
    parser.add_argument('--jobname',
                        type=str,
                        default='egovqa-default',
                        help='jobname to direct to plotter')
    parser.add_argument("--pool_type",
                        type=str,
                        default="default",
                        choices=[
                            "default", "LinearSum", "ConcatMLP", "MCB", "MFH",
                            "MFB", "MLB", "Block", "Tucker", "BlockTucker",
                            "Mutan"
                        ],
                        help="Which pooling technique to use")
    parser.add_argument("--pool_in_dims",
                        type=int,
                        nargs='+',
                        default=[300, 300],
                        help="Input dimensions to pooling layers")
    parser.add_argument("--pool_out_dim",
                        type=int,
                        default=600,
                        help="Output dimension to pooling layers")
    parser.add_argument(
        "--pool_hidden_dim",
        type=int,
        default=1500,
        help="Some pooling types come with a hidden internal dimension")

    ######################################################################

    args = parser.parse_args()

    ######################################################################
    plotter = VisdomLinePlotter(env_name=args.jobname)
    args.plotter = plotter
    ######################################################################

    args.dataset = 'ego_vqa'
    args.word_dim = 300
    args.vocab_num = 4000
    args.pretrained_embedding = os.path.expanduser(
        '~/kable_management/data/EgoVQA/data/word_embedding.npy')
    args.video_feature_dim = 4096
    args.video_feature_num = 20
    args.memory_dim = 256
    args.batch_size = 8
    args.reg_coeff = 1e-5
    args.learning_rate = 0.001
    #args.preprocess_dir = 'data'
    args.preprocess_dir = os.path.expanduser(
        "~/kable_management/data/EgoVQA/data")
    args.log = './logs'
    args.hidden_size = 256

    video_cam = [
        '1_D', '1_M', '2_F', '2_M', '3_F', '3_M', '4_F', '4_M', '5_D', '5_M',
        '6_D', '6_M', '7_D', '7_M', '8_M', '8_X'
    ]

    with open(
            os.path.expanduser(
                "~/kable_management/data/EgoVQA/data/data_split.json"),
            'r') as f:
        splits = json.load(f)

    assert args.split < len(splits)
    video_cam_split = splits[args.split]

    #dataset = dt.EgoVQA(args.batch_size, './data', './data/feats', video_cam_split)
    dataset = dt.EgoVQA(
        args.batch_size,
        os.path.expanduser("~/kable_management/data/EgoVQA/data"),
        os.path.expanduser("~/kable_management/data/EgoVQA/data/feats"),
        video_cam_split)

    #args.memory_type = '_mrm2s'   # HME-QA: see HME-QA paper
    #args.memory_type = '_stvqa'  # st-vqa: see TGIF-QA paper
    #args.memory_type = '_enc_dec' # plain LSTM
    #args.memory_type = '_co_mem' # Co-Mem
    # TODO: if possible, add new algorithm
    # TODO: analyze different types of questions

    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = os.path.join(
        args.save_path, 'model_%s_%s%s' %
        (args.image_feature_net, args.layer, args.memory_type))
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)
    if not os.path.exists(
            os.path.join(args.save_model_path, 's%d' % (args.split))):
        os.makedirs(os.path.join(args.save_model_path, 's%d' % (args.split)))

    args.pretrain_model_path = os.path.join(
        os.path.expanduser(
            "~/kable_management/blp_paper/.results/egovqa/pretrain_models"),
        'model_%s_%s%s' % (args.image_feature_net, args.layer,
                           args.memory_type))

    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)

    if args.memory_type == '_mrm2s':
        rnn = AttentionTwoStream(args,
                                 feat_channel,
                                 feat_dim,
                                 text_embed_size,
                                 args.hidden_size,
                                 voc_len,
                                 num_layers,
                                 word_matrix,
                                 max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-5100-vl_80.451-t_81.002.pkl'

    elif args.memory_type == '_stvqa':
        feat_channel *= 2
        rnn = TGIFBenchmark(args,
                            feat_channel,
                            feat_dim,
                            text_embed_size,
                            args.hidden_size,
                            voc_len,
                            num_layers,
                            word_matrix,
                            max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-0400-vl_69.603-t_68.562.pkl'

    elif args.memory_type == '_enc_dec':
        feat_channel *= 2
        rnn = LSTMEncDec(args,
                         feat_channel,
                         feat_dim,
                         text_embed_size,
                         args.hidden_size,
                         voc_len,
                         num_layers,
                         word_matrix,
                         max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-1200-vl_66.518-t_65.817.pkl'

    elif args.memory_type == '_co_mem':
        rnn = CoMemory(args,
                       feat_channel,
                       feat_dim,
                       text_embed_size,
                       args.hidden_size,
                       voc_len,
                       num_layers,
                       word_matrix,
                       max_len=max_sequence_length)
        rnn = rnn.cuda()

        best_pretrain_state = 'rnn-1200-vl_76.516-t_75.708.pkl'

    else:
        raise ValueError('unsupported memory_type: %s' % args.memory_type)

    #################################
    # load pretrain model to finetune
    #################################

    pretrain_path = os.path.join(
        os.path.expanduser(
            "~/kable_management/blp_paper/.results/egovqa/pretrain_models"),
        'model_%s_%s%s' %
        (args.image_feature_net, args.layer, args.memory_type),
        best_pretrain_state)

    if os.path.isfile(pretrain_path):
        #rnn.load_state_dict(torch.load(pretrain_path))
        print('Load from', pretrain_path)
    else:
        print('Cannot load', pretrain_path)

    # loss function
    criterion = MultipleChoiceLoss(margin=1, size_average=True).cuda()

    optimizer = torch.optim.Adam(rnn.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=0.0005)

    best_test_acc = 0.0
    best_val_acc = 0.0
    best_test_iter = 0
    train_loss = []
    train_acc = []

    iter = 0

    for epoch in range(0, 500):  #100
        dataset.reset_train()

        while dataset.has_train_batch:
            iter += 1

            vgg, c3d, questions, answers, question_lengths = dataset.get_train_batch()
            data_dict = getInput(vgg, c3d, questions, answers,
                                 question_lengths)
            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = rnn.accuracy(predictions, targets)
            ######################################################################
            #args.plotter.plot("accuracy", "train", args.jobname, iter, acc.item())
            args.plotter.plot("loss", "train", args.jobname, iter, loss.item())
            ######################################################################
            print('Train iter %d, loss %.3f, acc %.2f' %
                  (iter, loss.item(), acc.item()))

        if epoch >= 0:
            rnn.eval()

            # val iterate over examples
            with torch.no_grad():

                idx = 0
                accuracy = AverageMeter()

                while dataset.has_val_example:
                    if idx % 10 == 0:
                        print('Val iter %d/%d' % (idx, dataset.val_example_total))
                    idx += 1

                    vgg, c3d, questions, answers, question_lengths = dataset.get_val_example()
                    data_dict = getInput(vgg, c3d, questions, answers,
                                         question_lengths)
                    outputs, predictions = rnn(data_dict)
                    targets = data_dict['answers']

                    acc = rnn.accuracy(predictions, targets)
                    accuracy.update(acc.item(), len(vgg))

                val_acc = accuracy.avg
                print('Val iter %d, acc %.3f' % (iter, val_acc))
                ######################################################################
                if best_val_acc < val_acc:
                    best_val_acc = val_acc
                    args.plotter.text_plot(
                        args.jobname + " val", args.jobname + " val " +
                        str(round(best_val_acc, 4)) + " " + str(iter))
                args.plotter.plot("accuracy", "val", args.jobname, iter,
                                  val_acc)
                ######################################################################
                dataset.reset_val()

                idx = 0
                accuracy = AverageMeter()

                while dataset.has_test_example:
                    if idx % 10 == 0:
                        print('Test iter %d/%d' % (idx, dataset.test_example_total))
                    idx += 1

                    vgg, c3d, questions, answers, question_lengths, _, _ = dataset.get_test_example()
                    data_dict = getInput(vgg, c3d, questions, answers,
                                         question_lengths)
                    outputs, predictions = rnn(data_dict)
                    targets = data_dict['answers']

                    acc = rnn.accuracy(predictions, targets)
                    accuracy.update(acc.item(), len(vgg))

                test_acc = accuracy.avg
                print('Test iter %d, acc %.3f' % (iter, accuracy.avg))
                dataset.reset_test()

                ######################################################################
                if best_test_acc < test_acc:
                    best_test_iter = iter
                    best_test_acc = test_acc
                    args.plotter.text_plot(
                        args.jobname + " test", args.jobname + " test " +
                        str(round(best_test_acc, 4)) + " " + str(iter))
                args.plotter.plot("accuracy", "test", args.jobname, iter,
                                  test_acc)
                ######################################################################

                print('[Test] iter %d, acc %.3f, best acc %.3f at iter %d' %
                      (iter, test_acc, best_test_acc, best_test_iter))

                torch.save(
                    rnn.state_dict(),
                    os.path.join(
                        args.save_model_path, 's%d' % (args.split, ),
                        'rnn-%04d-vl_%.3f-t_%.3f.pkl' %
                        (iter, val_acc, test_acc)))
                rnn.train()
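Example #6 reports progress through args.plotter, a VisdomLinePlotter exposing plot(var_name, split_name, title, x, y) and text_plot(name, text). The plotter class itself is not shown anywhere in these examples; the following is a minimal sketch consistent with those two calls, assuming the visdom package, and should be treated as illustrative rather than the original implementation:

import numpy as np
from visdom import Visdom

class VisdomLinePlotter(object):
    """Sketch: appends (x, y) points to named line plots in a single visdom environment."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}        # var_name -> visdom line window
        self.texts = {}        # name -> visdom text window

    def plot(self, var_name, split_name, title_name, x, y):
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=self.env,
                opts=dict(legend=[split_name], title=title_name,
                          xlabel='iteration', ylabel=var_name))
        else:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')

    def text_plot(self, name, text):
        # overwrite a text pane so it always shows the latest "best accuracy" string
        if name not in self.texts:
            self.texts[name] = self.viz.text(text, env=self.env)
        else:
            self.viz.text(text, env=self.env, win=self.texts[name])

Each variable ("accuracy", "loss") gets its own window, and each split ("train", "val", "test") becomes a named trace inside it, matching how the training and evaluation loops above call the plotter.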
Example #7
def main(args):
    torch.manual_seed(2667)  #1

    ### add arguments ###
    args.vc_dir = os.path.expanduser(
        "~/kable_management/data/tgif_qa/data/Vocabulary")
    args.df_dir = os.path.expanduser(
        "~/kable_management/data/tgif_qa/data/dataset")

    # Dataset
    dataset = tvqa_dataset.TVQADataset(args, mode="train")

    args.max_sequence_length = 50  #35
    args.model_name = args.task + args.feat_type

    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'

    args.save_model_path = args.save_path + '%s_%s_%s%s' % (
        args.task, args.image_feature_net, args.layer, args.memory_type)
    # Create model directory
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    #############################
    # get video feature dimension
    #############################
    video_feature_dimension = (480, 1, 1, 6144)
    feat_channel = video_feature_dimension[3]
    feat_dim = video_feature_dimension[2]
    text_embed_size = 300
    answer_vocab_size = None

    #############################
    # get word vector dimension
    #############################
    word_matrix = dataset.vocab_embedding
    voc_len = len(dataset.word2idx)
    assert text_embed_size == word_matrix.shape[1]

    #############################
    # Parameters
    #############################
    SEQUENCE_LENGTH = args.max_sequence_length
    VOCABULARY_SIZE = voc_len
    assert VOCABULARY_SIZE == voc_len
    FEAT_DIM = (1, 1, 6144)

    # Create model directory
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    if args.task == 'Count':
        # add L2 loss
        criterion = nn.MSELoss(size_average=True).to(args.device)
    elif args.task in ['Action', 'Trans']:
        from embed_loss import MultipleChoiceLoss
        criterion = MultipleChoiceLoss(num_option=5,
                                       margin=1,
                                       size_average=True).to(args.device)
    elif args.task == 'FrameQA':
        # add classification loss
        answer_vocab_size = len(dataset.ans2idx)
        print('Vocabulary size', answer_vocab_size, VOCABULARY_SIZE)
        criterion = nn.CrossEntropyLoss(size_average=True).to(args.device)

    if args.memory_type == '_mrm2s':
        rnn = TGIFAttentionTwoStream(args,
                                     args.task,
                                     feat_channel,
                                     feat_dim,
                                     text_embed_size,
                                     args.hidden_size,
                                     voc_len,
                                     args.num_layers,
                                     word_matrix,
                                     answer_vocab_size=answer_vocab_size,
                                     max_len=args.max_sequence_length)
    else:
        raise ValueError('unsupported memory_type: %s' % args.memory_type)

    rnn = rnn.to(args.device)

    # To test directly, load a pre-trained model; replace the checkpoint with your own to evaluate your model.
    if args.test == 1:
        if args.task == 'Count':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Count_concat_fc_mrm2s/rnn-1300-l3.257-a27.942.pkl'
                ))
        elif args.task == 'Action':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Action_concat_fc_mrm2s/rnn-0800-l0.137-a84.663.pkl'
                ))
        elif args.task == 'Trans':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/Trans_concat_fc_mrm2s/rnn-1500-l0.246-a78.068.pkl'
                ))
        elif args.task == 'FrameQA':
            rnn.load_state_dict(
                torch.load(
                    './saved_models/FrameQA_concat_fc_mrm2s/rnn-4200-l1.233-a69.361.pkl'
                ))
        else:
            raise ValueError('Invalid task: %s' % args.task)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    iter = 0

    if args.task in ['Action', 'Trans']:
        best_val_acc = 0.0
        best_val_iter = 0.0
        best_test_acc = 0.0
        train_acc = []
        train_loss = []

        for epoch in range(0, args.num_epochs):
            dataset.set_mode("train")
            train_loader = DataLoader(dataset,
                                      batch_size=args.bsz,
                                      shuffle=True,
                                      collate_fn=tvqa_dataset.pad_collate,
                                      num_workers=args.num_workers)
            for batch_idx, batch in enumerate(train_loader):
                # if(batch_idx<51):
                #     iter+=1
                #     continue
                # else:
                #     pass
                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                    batch,
                    args.max_sub_l,
                    args.max_vcpt_l,
                    args.max_vid_l,
                    device=args.device)
                assert (tvqa_data[18].shape[2] == 4096)
                data_dict = {}
                data_dict['num_mult_choices'] = 5
                # int32 torch tensor of correct answer ids
                data_dict['answers'] = targets
                # (bsz, 35(max), 1, 1, 6144)
                data_dict['video_features'] = torch.cat(
                    [tvqa_data[18], tvqa_data[16]],
                    dim=2).unsqueeze(2).unsqueeze(2)
                # list, (bsz), number of frames per clip (max is 35)
                data_dict['video_lengths'] = tvqa_data[17].tolist()
                # torch tensor (bsz, 5, 35) (max question length)
                data_dict['candidates'] = tvqa_data[0]
                # list of lists, lengths of each question+answer
                data_dict['candidate_lengths'] = tvqa_data[1]
                # optionally clamp features and lengths to max_sequence_length
                if args.hard_restrict:
                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]
                    data_dict['video_lengths'] = [
                        min(i, args.max_sequence_length)
                        for i in data_dict['video_lengths']
                    ]
                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                        if data_dict['video_lengths'][xx] == 0:
                            data_dict['video_lengths'][xx] = 1
                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length
                #import ipdb; ipdb.set_trace()
                #print("boom")
                outputs, targets, predictions = rnn(data_dict, args.task)

                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc = rnn.accuracy(predictions, targets.long())
                #print("boom")
                print('Train %s iter %d, loss %.3f, acc %.2f' %
                      (args.task, iter, loss.item(), acc.item()))
                ######################################################################
                #args.plotter.plot("accuracy", "train", args.jobname, iter, acc.item())
                #args.plotter.plot("loss", "train", args.jobname, iter, int(loss.data))
                train_acc.append(acc.item())
                train_loss.append(loss.item())
                ######################################################################

                if (batch_idx % args.log_freq) == 0:  #(args.log_freq-1):
                    dataset.set_mode("valid")
                    valid_loader = DataLoader(
                        dataset,
                        batch_size=args.test_bsz,
                        shuffle=True,
                        collate_fn=tvqa_dataset.pad_collate,
                        num_workers=args.num_workers)
                    rnn.eval()
                    ######################################################################
                    train_loss = sum(train_loss) / float(len(train_loss))
                    train_acc = sum(train_acc) / float(len(train_acc))
                    args.plotter.plot("accuracy", "train",
                                      args.jobname + " " + args.task, iter,
                                      train_acc)
                    #import ipdb; ipdb.set_trace()
                    args.plotter.plot("loss", "train",
                                      args.jobname + " " + args.task, iter,
                                      train_loss)
                    train_loss = []
                    train_acc = []
                    ######################################################################
                    with torch.no_grad():
                        if args.test == 0:
                            n_iter = len(valid_loader)
                            losses = []
                            accuracy = []
                            iter_val = 0
                            for batch_idx, batch in enumerate(valid_loader):
                                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                                    batch,
                                    args.max_sub_l,
                                    args.max_vcpt_l,
                                    args.max_vid_l,
                                    device=args.device)
                                assert (tvqa_data[18].shape[2] == 4096)
                                data_dict = {}
                                data_dict['num_mult_choices'] = 5
                                # int32 torch tensor of correct answer ids
                                data_dict['answers'] = targets
                                # (bsz, 35(max), 1, 1, 6144)
                                data_dict['video_features'] = torch.cat(
                                    [tvqa_data[18], tvqa_data[16]],
                                    dim=2).unsqueeze(2).unsqueeze(2)
                                # list, (bsz), number of frames per clip (max is 35)
                                data_dict['video_lengths'] = tvqa_data[17].tolist()
                                # torch tensor (bsz, 5, 35) (max question length)
                                data_dict['candidates'] = tvqa_data[0]
                                # list of lists, lengths of each question+answer
                                data_dict['candidate_lengths'] = tvqa_data[1]
                                # optionally clamp features and lengths to max_sequence_length
                                if args.hard_restrict:
                                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]
                                    data_dict['video_lengths'] = [
                                        min(i, args.max_sequence_length)
                                        for i in data_dict['video_lengths']
                                    ]
                                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                                        if data_dict['video_lengths'][xx] == 0:
                                            data_dict['video_lengths'][xx] = 1
                                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length

                                outputs, targets, predictions = rnn(
                                    data_dict, args.task)

                                loss = criterion(outputs, targets)
                                acc = rnn.accuracy(predictions, targets.long())

                                losses.append(loss.item())
                                accuracy.append(acc.item())
                            losses = sum(losses) / n_iter
                            accuracy = sum(accuracy) / n_iter

                            ######################################################################
                            if best_val_acc < accuracy:
                                best_val_iter = iter
                                best_val_acc = accuracy
                                args.plotter.text_plot(
                                    args.task + args.jobname + " val",
                                    args.task + args.jobname + " val " +
                                    str(round(best_val_acc, 4)) + " " +
                                    str(iter))
                            #args.plotter.plot("loss", "val", args.jobname, iter, losses.avg)
                            args.plotter.plot("accuracy", "val",
                                              args.jobname + " " + args.task,
                                              iter, accuracy)
                            ######################################################################

                            #print('[Val] iter %d, loss %.3f, acc %.2f, best acc %.3f at iter %d' % (
                            #                iter, losses, accuracy, best_val_acc, best_val_iter))
                            # torch.save(rnn.state_dict(), os.path.join(args.save_model_path, 'rnn-%04d-l%.3f-a%.3f.pkl' % (
                            #                 iter, losses.avg, accuracy.avg)))
                    rnn.train()
                    dataset.set_mode("train")

                iter += 1
    elif args.task == 'Count':
        pass