예제 #1
0
def main(opt):
    '''Evaluate an ensemble of pretrained models and print the metric.

    For every path in ``opt.ensemble_checkpoint_paths`` the matching
    ``opt_info.json`` is read from the directory given by the part of the
    path before ``'/best/'``, a model is built with ``get_model`` and its
    ``state_dict`` is restored from the checkpoint.  The merged options are
    written to ``ensemble_opt.json`` under
    ``<results_path>/<em>/<dataset>_<names...>`` before evaluation runs.

    Args:
        opt: namespace-like object (anything ``vars()`` accepts) providing at
            least ``no_cuda``, ``ensemble_checkpoint_paths``, ``names``,
            ``results_path``, ``em``, ``specific``, ``json_path``,
            ``json_name``, ``print_sent`` and ``ns``.

    Raises:
        AssertionError: if a truthy value already present in the options
            disagrees with a model's ``opt_info.json`` for a checklist key.
    '''
    device = torch.device('cpu' if opt.no_cuda else 'cuda')
    # Shallow copy: the keys rewritten below must not leak back into the
    # caller's namespace (otherwise a second call would prepend the dataset
    # name to ``names`` twice).
    opt = dict(vars(opt))
    checklist = [
        'info_corpus', 'vocab_size',
        'with_category', 'num_category',
        'decoder_type', 'dataset', 'n_frames', 'max_len', 'top_down',
    ]

    def check(d, other_d, key):
        '''Return other_d[key], asserting it matches d[key] when that is truthy.'''
        value = other_d[key]
        # NOTE(review): a falsy value in ``d`` (0, False, '', None) skips the
        # comparison, presumably to treat it as "not specified" -- confirm.
        if d.get(key, False):
            assert d[key] == value, "%s %s %s" % (key, str(d[key]), str(value))
        return value

    opt_list = []
    model_list = []
    for pretrained_path in opt['ensemble_checkpoint_paths']:
        checkpoint_dir = pretrained_path.split('/best/')[0]
        with open(os.path.join(checkpoint_dir, 'opt_info.json')) as f:
            info = json.load(f)
        opt_list.append(info)
        # map_location lets checkpoints saved on a GPU be restored when
        # running with -no_cuda on a machine without CUDA.
        checkpoint = torch.load(pretrained_path, map_location=device)
        model = get_model(info)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        model_list.append(model)
        print(checkpoint['test_result'])

    # Every model must agree on the checklist keys; the agreed values
    # overwrite whatever the options held.
    for info in opt_list:
        for k in checklist:
            opt[k] = check(opt, info, k)

    # Build a new list rather than inserting into the caller's list; it is
    # stored back so the dumped JSON and the evaluator see the full names.
    names = [opt['dataset']] + list(opt['names'])
    opt['names'] = names
    pth = os.path.join(opt['results_path'], opt['em'], '_'.join(names))
    os.makedirs(pth, exist_ok=True)

    with open(os.path.join(pth, 'ensemble_opt.json'), 'w') as f:
        json.dump(opt, f)

    loader_list = [
        get_loader(item, mode=opt['em'], print_info=False,
                   specific=opt['specific'])
        for item in opt_list
    ]

    # The vocabulary is taken from the first loader only.
    vocab = loader_list[0].dataset.get_vocab()

    metric = run_eval_ensemble(
        opt, opt_list, model_list, None, loader_list, vocab, device,
        json_path=opt['json_path'], json_name=opt['json_name'],
        print_sent=opt['print_sent'], no_score=opt['ns'], analyze=True)
    print(metric)
예제 #2
0
def load(checkpoint_path,
         checkpoint_name,
         device,
         mid_path='',
         opt_name='opt_info.json',
         from_checkpoint=False):
    '''Restore a model and its options from a checkpoint directory.

    Args:
        checkpoint_path: base directory of the run.
        checkpoint_name: file name of the checkpoint to load.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.
        mid_path: optional sub-directory between ``checkpoint_path`` and
            ``checkpoint_name``.
        opt_name: name of the JSON options file inside ``checkpoint_path``;
            only read when ``from_checkpoint`` is False.
        from_checkpoint: when True, take the options from the checkpoint's
            ``'settings'`` entry instead of the JSON file.

    Returns:
        A ``(model, opt)`` tuple: the model built by ``get_model(opt)`` with
        the checkpoint's ``'state_dict'`` loaded, and the options used.
    '''
    # map_location lets a checkpoint saved on a GPU be restored on a host
    # where that device is unavailable (e.g. CPU-only evaluation).
    checkpoint = torch.load(
        os.path.join(checkpoint_path, mid_path, checkpoint_name),
        map_location=device)
    if from_checkpoint:
        opt = checkpoint['settings']
    else:
        # Context manager so the options file is closed deterministically.
        with open(os.path.join(checkpoint_path, opt_name)) as f:
            opt = json.load(f)
    model = get_model(opt)

    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    return model, opt
예제 #3
0
def main():
    '''Parse the command line, restore one model and evaluate it.

    ``-i/--index`` selects one entry from the ``-model_path`` and
    ``-model_name`` lists.  The model's ``opt_info.json`` is loaded and then
    overridden by the command-line values.  Depending on the flags the
    function either calls ``collect_analyze``, ``loop_category`` or
    ``run_eval`` (printing the returned metric).

    ``-cv/--cross_validation`` controls cross-dataset evaluation:
    ``2`` (default) evaluates on the model's own dataset, ``1`` evaluates an
    MSRVTT model on Youtube2Text, and any other value evaluates a
    Youtube2Text model on MSRVTT.
    '''

    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument(
        '-model_path',
        nargs='+',
        type=str,
        default=[
            "/home/yangbang/VideoCaptioning/0219save/Youtube2Text/IEL_ARFormer/EBN1_SS0_NDL1_WC0_MI/",
            #"/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed920/",
            "/home/yangbang/VideoCaptioning/0219save/VATEX/IEL_ARFormer/EBN1_SS1_NDL1_WC0_M/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed1314_ag/",
        ])
    parser.add_argument(
        '-model_name',
        nargs='+',
        type=str,
        default=[
            '0044_240095_254102_253703_251149_247202.pth.tar',
            #'0028_177617_180524_183734_183213_182417.pth.tar',
            "0011_176183_177176_180332_180729_178864.pth.tar",
            '0099_160093_057474.pth.tar',
            "0025_179448_180500_184018_184037_182508.pth.tar",
        ])
    parser.add_argument('-i', '--index', default=0, type=int)

    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-beam_alpha', type=float, default=1.0)
    parser.add_argument('-batch_size',
                        type=int,
                        default=128,
                        help='Batch size')
    parser.add_argument('-topk',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-em', type=str, default='test')
    parser.add_argument('-print_sent', action='store_true')
    parser.add_argument('-json_path', type=str, default='')
    parser.add_argument('-json_name', type=str, default='')
    parser.add_argument('-ns', default=False, action='store_true')
    parser.add_argument('-sv', default=False, action='store_true')
    parser.add_argument('-analyze', default=False, action='store_true')
    parser.add_argument('-write_time', default=False, action='store_true')
    parser.add_argument('-mid_path', default='best', type=str)
    parser.add_argument('-specific', default=-1, type=int)
    parser.add_argument('-category', default=False, action='store_true')
    parser.add_argument('-sp',
                        '--saved_with_pickle',
                        default=False,
                        action='store_true')
    parser.add_argument('-pp',
                        '--pickle_path',
                        default='./AR_topk_collect_results')
    parser.add_argument('-ca',
                        '--collect_analyze',
                        default=False,
                        action='store_true')

    parser.add_argument('-cv', '--cross_validation', type=int, default=2)

    opt = parser.parse_args()
    # Collapse the path/name lists to the single entry selected by --index.
    opt.model_path = opt.model_path[opt.index]
    opt.model_name = opt.model_name[opt.index]

    # Source -> target substitutions used only when cross_validation != 2.
    if opt.cross_validation == 1:
        source_dataset, src_pre, src_wct = 'MSRVTT', 'msrvtt', '2'
        target_dataset, tgt_pre, tgt_wct = 'Youtube2Text', 'msvd', '0'
    else:
        source_dataset, src_pre, src_wct = 'Youtube2Text', 'msvd', '0'
        target_dataset, tgt_pre, tgt_wct = 'MSRVTT', 'msrvtt', '2'

    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    with open(opt_pth, 'r') as f:
        option = json.load(f)
    # Command-line values take precedence over the stored training options.
    option.update(vars(opt))

    if opt.saved_with_pickle:
        os.makedirs(opt.pickle_path, exist_ok=True)
        if 'Youtube2Text' in opt.model_path:
            dataset_name = 'msvd'
        elif 'MSRVTT' in opt.model_path:
            dataset_name = 'msrvtt'
        else:
            # Other datasets (e.g. the VATEX default path) used to hit an
            # unbound ``dataset_name``; fall back to the stored dataset name.
            dataset_name = str(option.get('dataset', 'unknown')).lower()
        string = '_ag' if option.get('method', None) == 'ag' else ''
        # Only ``opt.pickle_path`` is rewritten; ``option['pickle_path']``
        # keeps the directory value copied by the update above.
        opt.pickle_path = os.path.join(
            opt.pickle_path, '%s_%d%s.pkl' % (dataset_name, opt.topk, string))

    if opt.collect_analyze:
        collect_analyze(opt)
        return

    device = torch.device('cpu' if opt.no_cuda else 'cuda')
    if opt.analyze:
        # Analysis runs with one sample per batch.
        opt.batch_size = 1
        option['batch_size'] = 1

    # map_location lets GPU-saved checkpoints load when -no_cuda is used on
    # a machine without CUDA.
    checkpoint = torch.load(
        os.path.join(opt.model_path, opt.mid_path, opt.model_name),
        map_location=device)
    model = get_model(option)

    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    if opt.category:
        loop_category(option, opt, model, device)
        return

    loader = get_loader(option,
                        mode=opt.em,
                        print_info=True,
                        specific=opt.specific)
    # The vocabulary always comes from the model's own (source) dataset,
    # even when the loader is replaced by the target dataset below.
    vocab = loader.dataset.get_vocab()

    calculate_novel = True
    if opt.cross_validation != 2:

        def retarget(text):
            '''Swap the source dataset name and file prefix for the target's.'''
            return text.replace(source_dataset,
                                target_dataset).replace(src_pre, tgt_pre)

        option['dataset'] = target_dataset
        option['feats_a'] = []
        option['feats_m'] = [retarget(item) for item in option['feats_m']]
        option['feats_i'] = [retarget(item) for item in option['feats_i']]
        option['reference'] = retarget(option['reference'])

        # The digit swap (src_wct -> tgt_wct) touches every occurrence of
        # that digit, so 'Youtube2Text' can become 'Youtube0Text'; the last
        # replace restores it.
        option['info_corpus'] = (
            option['info_corpus']
            .replace(source_dataset, target_dataset)
            .replace(src_wct, tgt_wct)
            .replace('Youtube0Text', 'Youtube2Text'))

        loader = get_loader(option,
                            mode=opt.em,
                            print_info=True,
                            specific=opt.specific)
        calculate_novel = False

    print(len(vocab))

    metric = run_eval(option,
                      model,
                      None,
                      loader,
                      vocab,
                      device,
                      json_path=opt.json_path,
                      json_name=opt.json_name,
                      print_sent=opt.print_sent,
                      no_score=opt.ns,
                      save_videodatainfo=opt.sv,
                      analyze=opt.analyze,
                      saved_with_pickle=opt.saved_with_pickle,
                      pickle_path=opt.pickle_path,
                      write_time=opt.write_time,
                      calculate_novel=calculate_novel)

    print(metric)
예제 #4
0
def main(opt):
    '''Restore a model (and optionally a teacher) and run one evaluation mode.

    The model and its options are loaded from ``opt.model_path`` /
    ``opt.model_name`` under the ``best`` sub-directory, then the options are
    overridden by the attributes of ``opt``.  Unless ``opt.nt`` is set, a
    teacher model is restored from ``opt.teacher_path``.  Exactly one mode is
    then executed, checked in this order: ``opt.plot``, ``opt.loop``,
    ``opt.category``, otherwise a plain ``run_eval`` whose metric is printed.

    Args:
        opt: parsed command-line namespace.

    Raises:
        AssertionError: if the teacher's vocabulary size differs from the
            model's.
    '''
    if opt.collect:
        os.makedirs(opt.collect_path, exist_ok=True)

    device = torch.device('cpu' if opt.no_cuda else 'cuda')

    model, option = load(opt.model_path,
                         opt.model_name,
                         device,
                         mid_path='best')
    option.update(vars(opt))
    set_seed(option['seed'])

    teacher_model = None
    # No token mapping between student and teacher is used; an empty dict is
    # passed to the evaluation helpers in every mode.
    dict_mapping = {}
    if not opt.nt:
        # map_location lets a GPU-saved teacher load when -no_cuda is used
        # on a machine without CUDA.
        checkpoint = torch.load(opt.teacher_path, map_location=device)
        teacher_option = checkpoint['settings']
        teacher_model = get_model(teacher_option)
        teacher_model.load_state_dict(checkpoint['state_dict'])
        teacher_model.to(device)

        assert teacher_option['vocab_size'] == option['vocab_size']

    if opt.plot:
        # ``loader`` and ``vocab`` used to be referenced here before being
        # assigned; build them the same way the loop branch does.
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        plot(option, opt, model, loader, vocab, device, teacher_model,
             dict_mapping)
    elif opt.loop:
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        loop_length_beam(option, opt, model, loader, vocab, device,
                         teacher_model, dict_mapping)
    elif opt.category:
        loop_category(option, opt, model, device, teacher_model, dict_mapping)
    else:
        loader = get_loader(option,
                            mode=opt.em,
                            print_info=True,
                            specific=opt.specific)
        vocab = loader.dataset.get_vocab()
        # File name encodes dataset, method, paradigm ('AE' prefix when
        # nv_scale == 100), iterations, length beam size and
        # 100 * beam_alpha, plus an '_all' suffix when evaluating mode 'all'.
        filename = '%s_%s_%s_i%db%da%03d%s.pkl' % (
            option['dataset'], option['method'],
            ('AE' if opt.nv_scale == 100 else '') + opt.paradigm,
            opt.iterations, opt.length_beam_size,
            int(100 * opt.beam_alpha),
            '_all' if opt.em == 'all' else '')
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=opt.analyze,
                          collect_best_candidate_iterative_results=bool(
                              opt.collect),
                          collect_path=os.path.join(opt.collect_path,
                                                    filename),
                          no_score=opt.ns,
                          write_time=opt.write_time)
        print(metric)