コード例 #1
0
ファイル: validate.py プロジェクト: theluckygod/hcrn-videoqa
        cfg.dataset.motion_feat = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.motion_feat.format(cfg.dataset.name))

    test_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.test_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'test_num': cfg.test.test_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': False
    }
    test_loader = VideoQADataLoader(**test_loader_kwargs)
    model_kwargs.update({'vocab': test_loader.vocab})
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])

    if cfg.test.write_preds:
        acc, preds, gts, v_ids, q_ids = validate(cfg, model, test_loader,
                                                 device, cfg.test.write_preds)

        sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
            test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
        sys.stdout.flush()

        # write predictions for visualization purposes
        output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
        if not os.path.exists(output_dir):
コード例 #2
0
def train(cfg):
    """Train an HCRN network and checkpoint the best-validating epoch.

    Builds the training ``VideoQADataLoader`` (and a validation one when
    ``cfg.val.flag`` is set), constructs ``HCRN.HCRNNetwork``, optionally
    loads GloVe vectors and a previous checkpoint, then runs epochs from the
    restored epoch up to ``cfg.train.max_epochs``.  Per-batch progress is
    written to stdout; per-epoch summaries go through ``logging``.  After each
    epoch the model is validated (when enabled) and saved to
    ``<cfg.dataset.save_dir>/ckpt/model.pt`` whenever the validation metric
    improves (lower is better for ``'count'``, higher otherwise).

    Args:
        cfg: Experiment configuration.  The attributes read here are under
            ``cfg.dataset``, ``cfg.train``, ``cfg.val``, ``cfg.bert`` plus
            ``cfg.num_workers``, ``cfg.multi_gpus`` and ``cfg.exp_name``.

    Raises:
        FileExistsError: If the ``preds`` or ``ckpt`` path under
            ``cfg.dataset.save_dir`` exists but is not a directory.
    """

    def _bold(text, color):
        """Return ``text`` wrapped in bold terminal colour codes."""
        return colored(text, color, attrs=['bold'])

    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.train_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'train_num': cfg.train.train_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': True,
    }
    if cfg.bert.flag:
        # Precomputed BERT features are read from disk; otherwise only the
        # model identifier is forwarded to the loader.
        if 'precomputed' in cfg.bert.model:
            train_loader_kwargs['question_feat'] = cfg.bert.train_question_feat
        train_loader_kwargs['bert_model'] = cfg.bert.model

    train_loader = VideoQADataLoader(**train_loader_kwargs)
    logging.info("number of train instances: {}".format(
        len(train_loader.dataset)))
    if cfg.val.flag:
        val_loader_kwargs = {
            'question_type': cfg.dataset.question_type,
            'question_pt': cfg.dataset.val_question_pt,
            'vocab_json': cfg.dataset.vocab_json,
            'appearance_feat': cfg.dataset.appearance_feat,
            'motion_feat': cfg.dataset.motion_feat,
            'val_num': cfg.val.val_num,
            'batch_size': cfg.train.batch_size,
            'num_workers': cfg.num_workers,
            'shuffle': False,
        }
        if cfg.bert.flag:
            if 'precomputed' in cfg.bert.model:
                val_loader_kwargs['question_feat'] = cfg.bert.val_question_feat
            val_loader_kwargs['bert_model'] = cfg.bert.model
        val_loader = VideoQADataLoader(**val_loader_kwargs)
        logging.info("number of val instances: {}".format(
            len(val_loader.dataset)))

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vision_dim': cfg.train.vision_dim,
        'module_dim': cfg.train.module_dim,
        'word_dim': cfg.train.word_dim,
        'k_max_frame_level': cfg.train.k_max_frame_level,
        'k_max_clip_level': cfg.train.k_max_clip_level,
        'spl_resolution': cfg.train.spl_resolution,
        'vocab': train_loader.vocab,
        'question_type': cfg.dataset.question_type,
        'hcrn_model': cfg.train.hcrn_model,
        'subvids': cfg.train.subvids,
    }
    if cfg.bert.flag:
        model_kwargs['bert_model'] = cfg.bert.model
        model_kwargs['word_dim'] = cfg.bert.word_dim
    # The vocabulary is left out of the kwargs stored with the checkpoint.
    model_kwargs_tosave = {
        k: v
        for k, v in model_kwargs.items() if k != 'vocab'
    }
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    logging.info('num of params: {}'.format(pytorch_total_params))
    logging.info(model)

    if cfg.train.glove:
        logging.info('load glove vectors')
        train_loader.glove_matrix = torch.FloatTensor(
            train_loader.glove_matrix).to(device)
        with torch.no_grad():
            model.linguistic_input_unit.encoder_embed.weight.set_(
                train_loader.glove_matrix)
    if torch.cuda.device_count() > 1 and cfg.multi_gpus:
        model = model.cuda()
        logging.info("Using {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=None)

    optimizer = optim.Adam(model.parameters(), cfg.train.lr)

    start_epoch = 0
    # 'count' is scored by an error (lower is better); everything else by
    # accuracy (higher is better).
    if cfg.dataset.question_type == 'count':
        best_val = 100.0
    else:
        best_val = 0
    if cfg.train.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        start_epoch = ckpt['epoch'] + 1
        # NOTE(review): best_val is not restored from the checkpoint, so the
        # first validated epoch after a restore is compared against the
        # initial value above and may overwrite a better checkpoint --
        # confirm this is intended.
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    if cfg.dataset.question_type in ['frameqa', 'none']:
        criterion = nn.CrossEntropyLoss().to(device)
    elif cfg.dataset.question_type == 'count':
        criterion = nn.MSELoss().to(device)
    else:
        # Without this branch `criterion` stayed unbound for the remaining
        # question types and the train_it() call below raised
        # UnboundLocalError.
        # NOTE(review): assumes train_it does not need a criterion for those
        # types -- confirm against train_it.
        criterion = None

    logging.info("Start training........")
    for epoch in range(start_epoch, cfg.train.max_epochs):
        logging.info('>>>>>> epoch {epoch} <<<<<<'.format(
            epoch=colored("{}".format(epoch), "green", attrs=["bold"])))
        model.train()
        total_acc, count = 0, 0
        batch_mse_sum = 0.0
        total_loss, avg_loss = 0.0, 0.0
        train_accuracy = 0
        for i, batch in enumerate(train_loader):
            # Fractional epoch, used only for the progress display.
            progress = epoch + i / len(train_loader)
            # The first two batch fields are discarded here; the third holds
            # the answers and the rest is forwarded to train_it.
            _, _, answers, *batch_input = [todevice(x, device) for x in batch]
            answers = todevice(answers, device).squeeze()
            loss, metric = train_it(batch_input, answers, model, criterion,
                                    optimizer, cfg)
            total_loss += loss
            avg_loss = total_loss / (i + 1)
            if cfg.dataset.question_type == 'count':
                batch_avg_mse = metric / answers.size(0)
                batch_mse_sum += metric
                count += answers.size(0)
                avg_mse = batch_mse_sum / count
                sys.stdout.write(
                    "\rProgress = {progress}   ce_loss = {ce_loss}   avg_loss = {avg_loss}    train_mse = {train_mse}    avg_mse = {avg_mse}    exp: {exp_name}"
                    .format(progress=_bold("{:.3f}".format(progress),
                                           "green"),
                            ce_loss=_bold("{:.4f}".format(loss), "blue"),
                            avg_loss=_bold("{:.4f}".format(avg_loss), "red"),
                            train_mse=_bold("{:.4f}".format(batch_avg_mse),
                                            "blue"),
                            avg_mse=_bold("{:.4f}".format(avg_mse), "red"),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
            else:
                total_acc += metric.sum()
                count += answers.size(0)
                train_accuracy = total_acc / count
                sys.stdout.write(
                    "\rProgress = {progress}   ce_loss = {ce_loss}   avg_loss = {avg_loss}    train_acc = {train_acc}    avg_acc = {avg_acc}    exp: {exp_name}"
                    .format(progress=_bold("{:.3f}".format(progress),
                                           "green"),
                            ce_loss=_bold("{:.4f}".format(loss), "blue"),
                            avg_loss=_bold("{:.4f}".format(avg_loss), "red"),
                            train_acc=_bold("{:.4f}".format(metric.mean()),
                                            "blue"),
                            avg_acc=_bold("{:.4f}".format(train_accuracy),
                                          "red"),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
        sys.stdout.write("\n")
        # Learning-rate decay every 5 epochs for 'count', every 10 otherwise.
        decay_every = 5 if cfg.dataset.question_type == 'count' else 10
        if (epoch + 1) % decay_every == 0:
            optimizer = step_decay(cfg, optimizer)
        sys.stdout.flush()
        # NOTE: for 'count' runs train_accuracy is never updated, so avg_acc
        # is logged as 0 here.
        logging.info("Epoch = %s   avg_loss = %.3f    avg_acc = %.3f" %
                     (epoch, avg_loss, train_accuracy))

        if cfg.val.flag:
            output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
            os.makedirs(output_dir, exist_ok=True)
            valid_acc = validate(cfg,
                                 model,
                                 val_loader,
                                 device,
                                 write_preds=False)
            if cfg.dataset.question_type == 'count':
                improved = valid_acc < best_val
            else:
                improved = valid_acc > best_val
            if improved:
                best_val = valid_acc
                # Save best model
                ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt')
                os.makedirs(ckpt_dir, exist_ok=True)
                save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                                os.path.join(ckpt_dir, 'model.pt'), best_val)
                sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % (ckpt_dir))
                sys.stdout.flush()

            logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc)
            sys.stdout.write(
                '~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format(
                    valid_acc=_bold("{:.4f}".format(valid_acc), "red")))
            sys.stdout.flush()
コード例 #3
0
def train(cfg):
    """Train an HCRN network and checkpoint the best-validating epoch.

    Builds the training ``VideoQADataLoader`` (and a validation one when
    ``cfg.val.flag`` is set), constructs ``HCRN.HCRNNetwork``, optionally
    loads GloVe vectors and a previous checkpoint, then runs epochs from the
    restored epoch up to ``cfg.train.max_epochs``.  The loss depends on
    ``cfg.dataset.question_type``: a summed hinge loss over 5 candidates per
    question for ``'action'``/``'transition'``, MSE for ``'count'`` and
    cross-entropy otherwise.  After each epoch the model is validated (when
    enabled) and saved to ``<cfg.dataset.save_dir>/ckpt/model.pt`` whenever
    the validation metric improves (lower is better for ``'count'``).

    Args:
        cfg: Experiment configuration.  The attributes read here are under
            ``cfg.dataset``, ``cfg.train``, ``cfg.val`` plus
            ``cfg.num_workers``, ``cfg.multi_gpus`` and ``cfg.exp_name``.

    Raises:
        FileExistsError: If the ``preds`` or ``ckpt`` path under
            ``cfg.dataset.save_dir`` exists but is not a directory.
    """

    def _bold(text, color):
        """Return ``text`` wrapped in bold terminal colour codes."""
        return colored(text, color, attrs=['bold'])

    def _apply_gradients(loss, net, opt):
        """Backpropagate ``loss``, clip the gradient norm to 12 and step."""
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=12)
        opt.step()

    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.train_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'train_num': cfg.train.train_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': True
    }
    train_loader = VideoQADataLoader(**train_loader_kwargs)
    logging.info("number of train instances: {}".format(
        len(train_loader.dataset)))
    if cfg.val.flag:
        val_loader_kwargs = {
            'question_type': cfg.dataset.question_type,
            'question_pt': cfg.dataset.val_question_pt,
            'vocab_json': cfg.dataset.vocab_json,
            'appearance_feat': cfg.dataset.appearance_feat,
            'motion_feat': cfg.dataset.motion_feat,
            'val_num': cfg.val.val_num,
            'batch_size': cfg.train.batch_size,
            'num_workers': cfg.num_workers,
            'shuffle': False
        }
        val_loader = VideoQADataLoader(**val_loader_kwargs)
        logging.info("number of val instances: {}".format(
            len(val_loader.dataset)))

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vision_dim': cfg.train.vision_dim,
        'module_dim': cfg.train.module_dim,
        'word_dim': cfg.train.word_dim,
        'k_max_frame_level': cfg.train.k_max_frame_level,
        'k_max_clip_level': cfg.train.k_max_clip_level,
        'spl_resolution': cfg.train.spl_resolution,
        'vocab': train_loader.vocab,
        'question_type': cfg.dataset.question_type
    }
    # The vocabulary is left out of the kwargs stored with the checkpoint.
    model_kwargs_tosave = {
        k: v
        for k, v in model_kwargs.items() if k != 'vocab'
    }
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    logging.info('num of params: {}'.format(pytorch_total_params))
    logging.info(model)

    if cfg.train.glove:
        logging.info('load glove vectors')
        train_loader.glove_matrix = torch.FloatTensor(
            train_loader.glove_matrix).to(device)
        with torch.no_grad():
            model.linguistic_input_unit.encoder_embed.weight.set_(
                train_loader.glove_matrix)
    if torch.cuda.device_count() > 1 and cfg.multi_gpus:
        model = model.cuda()
        logging.info("Using {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=None)

    optimizer = optim.Adam(model.parameters(), cfg.train.lr)

    start_epoch = 0
    # 'count' is scored by an error (lower is better); everything else by
    # accuracy (higher is better).
    if cfg.dataset.question_type == 'count':
        best_val = 100.0
    else:
        best_val = 0
    if cfg.train.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
    # 'action'/'transition' use the inline hinge loss below and need no
    # criterion object.
    if cfg.dataset.question_type in ['frameqa', 'none']:
        criterion = nn.CrossEntropyLoss().to(device)
    elif cfg.dataset.question_type == 'count':
        criterion = nn.MSELoss().to(device)

    logging.info("Start training........")
    for epoch in range(start_epoch, cfg.train.max_epochs):
        logging.info('>>>>>> epoch {epoch} <<<<<<'.format(
            epoch=colored("{}".format(epoch), "green", attrs=["bold"])))
        model.train()
        total_acc, count = 0, 0
        batch_mse_sum = 0.0
        total_loss, avg_loss = 0.0, 0.0
        train_accuracy = 0
        for i, batch in enumerate(train_loader):
            # Fractional epoch, used only for the progress display.
            progress = epoch + i / len(train_loader)
            # The first two batch fields are discarded here; the third holds
            # the answers and the rest is the model input.
            _, _, answers, *batch_input = [todevice(x, device) for x in batch]
            # Use the selected device instead of a hard-coded .cuda() so the
            # loop also runs on CPU-only hosts.
            answers = answers.to(device).squeeze()
            batch_size = answers.size(0)
            optimizer.zero_grad()
            logits = model(*batch_input)
            if cfg.dataset.question_type in ['action', 'transition']:
                # Offset of each question's first candidate in the flattened
                # logits: [0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 10, ...]
                batch_agg = np.repeat(np.arange(batch_size), 5) * 5
                answers_agg = tile(answers, 0, 5)
                # Hinge loss: every candidate is pushed at least a margin of
                # 1.0 below the logit of the correct candidate.
                correct_idx = answers_agg + torch.from_numpy(batch_agg).to(
                    device)
                loss = torch.max(torch.tensor(0.0, device=device),
                                 1.0 + logits - logits[correct_idx])
                loss = loss.sum()
                _apply_gradients(loss, model, optimizer)
                preds = torch.argmax(logits.view(batch_size, 5), dim=1)
                agreeings = (preds == answers)
            elif cfg.dataset.question_type == 'count':
                answers = answers.unsqueeze(-1)
                loss = criterion(logits, answers.float())
                _apply_gradients(loss, model, optimizer)
                # Round to the nearest integer and clamp to [1, 10].
                preds = (logits + 0.5).long().clamp(min=1, max=10)
                batch_mse = (preds - answers)**2
            else:
                loss = criterion(logits, answers)
                _apply_gradients(loss, model, optimizer)
                agreeings = batch_accuracy(logits, answers)
            total_loss += loss.detach()
            avg_loss = total_loss / (i + 1)

            if cfg.dataset.question_type == 'count':
                batch_avg_mse = batch_mse.sum().item() / answers.size(0)
                batch_mse_sum += batch_mse.sum().item()
                count += answers.size(0)
                avg_mse = batch_mse_sum / count
                sys.stdout.write(
                    "\rProgress = {progress}   ce_loss = {ce_loss}   avg_loss = {avg_loss}    train_mse = {train_mse}    avg_mse = {avg_mse}    exp: {exp_name}"
                    .format(progress=_bold("{:.3f}".format(progress),
                                           "green"),
                            ce_loss=_bold("{:.4f}".format(loss.item()),
                                          "blue"),
                            avg_loss=_bold("{:.4f}".format(avg_loss), "red"),
                            train_mse=_bold("{:.4f}".format(batch_avg_mse),
                                            "blue"),
                            avg_mse=_bold("{:.4f}".format(avg_mse), "red"),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
            else:
                total_acc += agreeings.sum().item()
                count += answers.size(0)
                train_accuracy = total_acc / count
                sys.stdout.write(
                    "\rProgress = {progress}   ce_loss = {ce_loss}   avg_loss = {avg_loss}    train_acc = {train_acc}    avg_acc = {avg_acc}    exp: {exp_name}"
                    .format(progress=_bold("{:.3f}".format(progress),
                                           "green"),
                            ce_loss=_bold("{:.4f}".format(loss.item()),
                                          "blue"),
                            avg_loss=_bold("{:.4f}".format(avg_loss), "red"),
                            train_acc=_bold(
                                "{:.4f}".format(
                                    agreeings.float().mean().item()), "blue"),
                            avg_acc=_bold("{:.4f}".format(train_accuracy),
                                          "red"),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
        sys.stdout.write("\n")
        # Learning-rate decay every 5 epochs for 'count', every 10 otherwise.
        decay_every = 5 if cfg.dataset.question_type == 'count' else 10
        if (epoch + 1) % decay_every == 0:
            optimizer = step_decay(cfg, optimizer)
        sys.stdout.flush()
        # NOTE: for 'count' runs train_accuracy is never updated, so avg_acc
        # is logged as 0 here.
        logging.info("Epoch = %s   avg_loss = %.3f    avg_acc = %.3f" %
                     (epoch, avg_loss, train_accuracy))

        if cfg.val.flag:
            output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
            os.makedirs(output_dir, exist_ok=True)
            valid_acc = validate(cfg,
                                 model,
                                 val_loader,
                                 device,
                                 write_preds=False)
            if cfg.dataset.question_type == 'count':
                improved = valid_acc < best_val
            else:
                improved = valid_acc > best_val
            if improved:
                best_val = valid_acc
                # Save best model
                ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt')
                os.makedirs(ckpt_dir, exist_ok=True)
                save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                                os.path.join(ckpt_dir, 'model.pt'))
                sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % (ckpt_dir))
                sys.stdout.flush()

            logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc)
            sys.stdout.write(
                '~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format(
                    valid_acc=_bold("{:.4f}".format(valid_acc), "red")))
            sys.stdout.flush()
コード例 #4
0
def _decode_tokens(tokens, vocab):
    """Map the non-zero token indices in ``tokens`` to their ``vocab`` entry."""
    return [vocab[word.item()] for word in tokens if word != 0]


def process_final(cfg):
    """Evaluate the saved checkpoint on the test split and report predictions.

    Loads ``<cfg.dataset.save_dir>/ckpt/model.pt``, resolves the dataset file
    paths on ``cfg.dataset`` in place, runs ``validate`` on the test loader
    and prints the test accuracy.  When ``cfg.test.write_preds`` is set, one
    record per test question is built, a few samples are printed, and the
    records are returned.  Nothing is written to disk apart from creating the
    ``preds`` directory under ``cfg.dataset.save_dir``.

    Args:
        cfg: Experiment configuration.  ``cfg.dataset`` is mutated: for
            datasets other than ``'tgif-qa'``/``'tgif-qa-infer'`` the
            question type is forced to ``'none'`` and fixed file-name
            templates are used.

    Returns:
        list: ``[]`` when ``cfg.test.write_preds`` is false.  Otherwise a
        list of dicts with keys ``video_id``, ``question_id``,
        ``video_name``, ``question``, ``answer`` and ``prediction``, plus
        ``detail`` for ``'action'``/``'transition'`` (softmax probabilities
        as a list) and ``'frameqa'`` (top answers with their probabilities).

    Raises:
        AssertionError: If the dataset name or question type is not
            supported, or the data directory or checkpoint does not exist.
        FileExistsError: If the ``preds`` path exists but is not a directory.
    """
    assert cfg.dataset.name in [
        'tgif-qa', 'tgif-qa-infer', 'msrvtt-qa', 'msvd-qa'
    ]
    assert cfg.dataset.question_type in [
        'frameqa', 'count', 'transition', 'action', 'none'
    ]
    # check if the data folder exists
    assert os.path.exists(cfg.dataset.data_dir)

    # 'tgif-qa-infer' always runs on CPU.
    if cfg.dataset.name != 'tgif-qa-infer':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'
    ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
    print('ckpt:: ', ckpt)
    assert os.path.exists(ckpt)
    # load pretrained model
    loaded = torch.load(ckpt, map_location='cpu')
    model_kwargs = loaded['model_kwargs']

    name = cfg.dataset.name
    data_dir = cfg.dataset.data_dir
    if name in ('tgif-qa', 'tgif-qa-infer'):
        # Question/vocab files live under data_dir, features under video_dir.
        cfg.dataset.test_question_pt = os.path.join(
            data_dir,
            cfg.dataset.test_question_pt.format(name,
                                                cfg.dataset.question_type))
        cfg.dataset.vocab_json = os.path.join(
            data_dir,
            cfg.dataset.vocab_json.format(name, cfg.dataset.question_type))
        cfg.dataset.appearance_feat = os.path.join(
            cfg.dataset.video_dir, cfg.dataset.appearance_feat.format(name))
        cfg.dataset.motion_feat = os.path.join(
            cfg.dataset.video_dir, cfg.dataset.motion_feat.format(name))
    else:
        # Other datasets have a single question type and fixed file names,
        # all located under data_dir.
        cfg.dataset.question_type = 'none'
        cfg.dataset.test_question_pt = os.path.join(
            data_dir, '{}_test_questions.pt'.format(name))
        cfg.dataset.vocab_json = os.path.join(data_dir,
                                              '{}_vocab.json'.format(name))
        cfg.dataset.appearance_feat = os.path.join(
            data_dir, '{}_appearance_feat.h5'.format(name))
        cfg.dataset.motion_feat = os.path.join(
            data_dir, '{}_motion_feat.h5'.format(name))

    test_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.test_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'test_num': cfg.test.test_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': False
    }
    test_loader = VideoQADataLoader(**test_loader_kwargs)
    model_kwargs.update({'vocab': test_loader.vocab})
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])

    if not cfg.test.write_preds:
        acc = validate(cfg, model, test_loader, device, cfg.test.write_preds)
        sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
            test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
        sys.stdout.flush()
        return []

    acc, preds, gts, v_ids, q_ids, logits = validate(
        cfg, model, test_loader, device, cfg.test.write_preds)

    question_type = cfg.dataset.question_type
    is_multichoice = question_type in ['action', 'transition']
    has_detail = is_multichoice or question_type == 'frameqa'
    detail = []

    if is_multichoice:
        sm = torch.nn.Softmax(dim=1)
        print('origin_value::', logits.t())
        probs = sm(logits.t())
        print('>>>> Probs: ', type(probs), probs.size(), probs)
        # .cpu() is a no-op for tensors that are already on the CPU.
        detail = probs.cpu().numpy().tolist()
    elif question_type == 'frameqa':
        # Explicit dim: the implicit-dim form of Softmax is deprecated.  The
        # last dimension is also the one topk operates on below.
        sm = torch.nn.Softmax(dim=-1)
        probs = sm(logits)
        print('>>>> Probs: ', type(probs), probs.size(), probs)
        answer_vocab = test_loader.vocab['answer_idx_to_token']
        # Never ask for more entries than there are answer classes.
        values, top_idx = torch.topk(probs, min(5, probs.size(-1)))
        print('>>>Top5 ', top_idx)
        top_answer = []
        for row_idx, row_values in zip(top_idx, values):
            answer_ids = list(row_idx.cpu().numpy())
            answer_probs = list(row_values.cpu().numpy())
            print('FRAMEQA-pred.item:: ', answer_ids, answer_probs)
            top_answer.append(([answer_vocab[ix] for ix in answer_ids],
                               [float(v) for v in answer_probs]))
        print('FRAMEQA-topk:: ', top_answer)
        detail = top_answer

    sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
        test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
    sys.stdout.flush()

    # The predictions directory is still created, although the records are
    # only returned to the caller and not written to a file here.
    output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
    os.makedirs(output_dir, exist_ok=True)

    # Multiple-choice questions share one vocabulary with their answer
    # candidates; the other types use a question-only vocabulary.
    if is_multichoice:
        vocab = test_loader.vocab['question_answer_idx_to_token']
    else:
        vocab = test_loader.vocab['question_idx_to_token']

    # SECURITY: pickle.load can execute arbitrary code; only point
    # test_question_pt at trusted files.
    with open(cfg.dataset.test_question_pt, 'rb') as f:
        obj = pickle.load(f)
    questions = obj['questions']
    org_v_names = obj['video_names']
    org_q_ids = obj['question_id']
    ans_candidates = obj['ans_candidates'] if is_multichoice else None

    # str(question id) -> (video name, question tokens, answer candidates)
    records = {}
    for idx, org_q_id in enumerate(org_q_ids):
        candidates = ans_candidates[idx] if is_multichoice else None
        records[str(org_q_id)] = (org_v_names[idx], questions[idx],
                                  candidates)

    columns = [
        np.hstack(v_ids).tolist(),
        np.hstack(q_ids).tolist(), gts, preds
    ]
    if has_detail:
        columns.append(detail)
    instances = []
    for row in zip(*columns):
        video_id, q_id, answer, pred = row[:4]
        video_name, question, _ = records[str(q_id)]
        instance = {
            'video_id': video_id,
            'question_id': q_id,
            # Multiple-choice records keep the stored name object; the other
            # types stringify it.
            'video_name': video_name if is_multichoice else str(video_name),
            'question': _decode_tokens(question, vocab),
            'answer': answer,
            'prediction': pred,
        }
        if has_detail:
            instance['detail'] = row[4]
        instances.append(instance)

    sys.stdout.write('Display 10 samples...\n')

    # Show up to 10 samples (1 for 'tgif-qa-infer'), bounded by the number of
    # available predictions so small test sets do not raise IndexError.
    sample_size = 1 if cfg.dataset.name == 'tgif-qa-infer' else 10
    sample_size = min(sample_size, len(q_ids), len(preds), len(gts))

    for idx in range(sample_size):
        video_name, question, candidates = records[str(q_ids[idx].item())]
        print('Video name: {}'.format(video_name))
        print('Question: ' + ' '.join(_decode_tokens(question, vocab)) + '?')
        if is_multichoice:
            for cand_id, candidate in enumerate(candidates):
                print('({}): '.format(cand_id) +
                      ' '.join(_decode_tokens(candidate, vocab)))
        print('Prediction: {}'.format(preds[idx]))
        print('Groundtruth: {}'.format(gts[idx]))

    return instances