Example No. 1
def loop_length_beam(option, opt, model, loader, vocab, device, teacher_model,
                     dict_mapping):
    b4 = []       # Bleu_4 per length_beam_size
    m = []        # METEOR
    r = []        # ROUGE_L
    c = []        # CIDEr
    ave_len = []  # average caption length
    for lbs in range(1, 11):  # sweep length_beam_size from 1 to 10
        option['length_beam_size'] = lbs
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=True)
        b4.append(metric["Bleu_4"])
        m.append(metric["METEOR"])
        r.append(metric["ROUGE_L"])
        c.append(metric["CIDEr"])
        ave_len.append(metric['ave_length'])

    print(b4)
    print(m)
    print(r)
    print(c)
    print(ave_len)
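
# A minimal follow-up sketch for the sweep above, assuming matplotlib is
# available; plot_length_beam_sweep and the output path are illustrative
# names, not from the project.
import matplotlib.pyplot as plt

def plot_length_beam_sweep(b4, m, r, c, lbs_values=range(1, 11)):
    # One curve per metric; the x-axis is the length_beam_size that produced it.
    for values, label in [(b4, 'Bleu_4'), (m, 'METEOR'),
                          (r, 'ROUGE_L'), (c, 'CIDEr')]:
        plt.plot(list(lbs_values), values, marker='o', label=label)
    plt.xlabel('length_beam_size')
    plt.ylabel('score')
    plt.legend()
    plt.savefig('./length_beam_sweep.png')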
Example No. 2
def loop_category(option, opt, model, device):
    loop_logger = CsvLogger(
        filepath='./category_results',
        filename='ARVC_%s%s.csv' %
        (option['dataset'],
         '_ag' if option.get('method') == 'ag' else ''),
        #fieldsnames=['novel', 'unique', 'usage', 'ave_length', 'gram4'],
        fieldsnames=[
            'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L',
            'CIDEr', 'Sum', 'loss'
        ],
    )

    for i in range(20):
        loader = get_loader(option, mode=opt.em, specific=i)
        vocab = loader.dataset.get_vocab()
        #metric = run_eval(option, model, None, loader, vocab, device, print_sent=opt.print_sent, no_score=True, analyze=True)
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          print_sent=opt.print_sent,
                          no_score=False,
                          analyze=False)

        loop_logger.write(metric)
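
# CsvLogger comes from the project. A minimal stand-in consistent with how it
# is constructed and used here (filepath + filename + fieldsnames, then
# write(dict)) might look like this; the real implementation may differ.
import csv
import os

class CsvLogger:
    def __init__(self, filepath, filename, fieldsnames):
        os.makedirs(filepath, exist_ok=True)
        self.path = os.path.join(filepath, filename)
        self.fieldsnames = fieldsnames
        if not os.path.exists(self.path):
            with open(self.path, 'w', newline='') as f:
                csv.DictWriter(f, fieldnames=self.fieldsnames).writeheader()

    def write(self, row):
        # Silently drop keys that were not declared in fieldsnames.
        with open(self.path, 'a', newline='') as f:
            csv.DictWriter(f, fieldnames=self.fieldsnames,
                           extrasaction='ignore').writerow(row)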
Example No. 3
def loop_iterations(option, opt, model, loader, vocab, device, teacher_model,
                    dict_mapping):
    loop_logger = CsvLogger(filepath='./loop_results',
                            filename=opt.loop + '.csv',
                            fieldsnames=[
                                'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
                                'METEOR', 'ROUGE_L', 'CIDEr', 'Sum',
                                'iterations', 'lbs', 'novel', 'unique',
                                'usage', 'ave_length'
                            ])
    for i in range(1, 11):  # sweep the number of refinement iterations
        option['iterations'] = i
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=True)

        metric['iterations'] = option['iterations']
        metric['lbs'] = option['length_beam_size']
        metric.pop('loss')
        loop_logger.write(metric)
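
# Once the sweep CSV exists, the best iteration count can be read back with a
# small helper; 'Sum' is one of the fieldsnames logged above, and best_row is
# an illustrative name.
import csv

def best_row(csv_path):
    # Return the logged row with the highest 'Sum' score.
    with open(csv_path) as f:
        rows = list(csv.DictReader(f))
    return max(rows, key=lambda row: float(row['Sum']))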
Example No. 4
def main(opt):
    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))
    # print(option)

    model = get_model(option)
    checkpoint = torch.load(os.path.join(opt.model_path, 'best', opt.model_name))
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    for key in ['info_corpus', 'reference', 'feats_i', 'feats_m', 'feats_a']:
        if isinstance(option[key], list):
            for i in range(len(option[key])):
                option[key][i] = option[key][i].replace('/home/yangbang/VideoCaptioning', '/Users/yangbang/Desktop/VC_data')
        else:
            option[key] = option[key].replace('/home/yangbang/VideoCaptioning', '/Users/yangbang/Desktop/VC_data')

    loader = get_loader(option, mode=opt.em, print_info=False, specific=opt.specific)
    vocab = loader.dataset.get_vocab()

    if opt.oe:
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name, print_sent=opt.print_sent, no_score=opt.ns,
                          save_videodatainfo=opt.sv)
        print(metric)
    else:
        encoder_outputs = []
        category = []
        for data in tqdm(loader, ncols=150, leave=False):
            with torch.no_grad():
                results, cate, _, _ = get_forword_results(option, model, data, device=device, only_data=True)
                encoder_outputs.append(results['enc_output'][0])
                category.append(cate)
        encoder_outputs = torch.cat(encoder_outputs, dim=0).cpu().numpy()
        category = torch.cat(category, dim=0).view(-1).cpu().numpy()
        print(encoder_outputs.shape, category.shape)

        encoder_outputs = encoder_outputs.mean(1)  # average over the frame axis
        pca = manifold.TSNE(n_components=opt.pca)  # t-SNE, despite the variable name
        data = pca.fit_transform(encoder_outputs)
        plot_several_category(data, category)
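
# plot_several_category comes from the project. A stand-in inferred from the
# call site (data is (N, 2) after t-SNE, assuming opt.pca == 2; category holds
# integer ids) could be:
import matplotlib.pyplot as plt
import numpy as np

def plot_several_category(data, category):
    # One scatter colour per category id.
    for c in np.unique(category):
        mask = category == c
        plt.scatter(data[mask, 0], data[mask, 1], s=5, label=str(c))
    plt.legend(markerscale=3, fontsize='small')
    plt.show()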
Example No. 5
def loop_category(option, opt, model, device, teacher_model, dict_mapping):
    loop_logger = CsvLogger(
        filepath='./category_results',
        filename='NAVC_%s_%s%s.csv' %
        (option['method'], 'AE' if opt.nv_scale == 100 else '', opt.paradigm),
        fieldsnames=['novel', 'unique', 'usage', 'ave_length', 'gram4'])
    for i in range(20):
        loader = get_loader(option, mode=opt.em, specific=i)
        vocab = loader.dataset.get_vocab()
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          print_sent=opt.print_sent,
                          no_score=True,
                          analyze=True,
                          teacher_model=teacher_model,
                          dict_mapping=dict_mapping)
        loop_logger.write(metric)
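
# For reference, the filename template above expands like this; the values are
# illustrative ('NACF' as the method, the default paradigm 'mp').
method, nv_scale, paradigm = 'NACF', 100, 'mp'
filename = 'NAVC_%s_%s%s.csv' % (method, 'AE' if nv_scale == 100 else '',
                                 paradigm)
assert filename == 'NAVC_NACF_AEmp.csv'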
Example No. 6
def plot(option, opt, model, loader, vocab, device, teacher_model,
         dict_mapping):
    # iteration 1 is drawn in skyblue, later iterations in dodgerblue
    colors = [
        'skyblue', 'dodgerblue', 'dodgerblue', 'dodgerblue', 'dodgerblue'
    ]
    for i, iteration in enumerate([1, 2, 3, 4, 5]):
        option['iterations'] = iteration
        metric, x, y = run_eval(option,
                                model,
                                None,
                                loader,
                                vocab,
                                device,
                                json_path=opt.json_path,
                                json_name=opt.json_name,
                                print_sent=opt.print_sent,
                                teacher_model=teacher_model,
                                length_crit=torch.nn.SmoothL1Loss(),
                                dict_mapping=dict_mapping,
                                length_bias=opt.lb,
                                analyze=True,
                                plot=True,
                                no_score=True,
                                top_n=30)

        ax = plt.subplot(1, 5, i + 1)
        ax.barh(x, y, color=colors[i])
        ax.set_title('iteration=%d' % iteration)
        #plt.tick_params(labelsize=13)
    plt.subplots_adjust(left=0.12,
                        bottom=None,
                        right=0.98,
                        top=None,
                        wspace=0.4,
                        hspace=None)
    fig = plt.gcf()
    fig.set_size_inches(15, 10)
    plt.savefig('./iteration.png')
    plt.show()
Example No. 7
def main():
    '''Main Function'''
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-df', '--default', default=False, action='store_true')
    parser.add_argument('-method', '--method', default='ARB', type=str)
    parser.add_argument('-dataset', '--dataset', default='MSRVTT', type=str)
    parser.add_argument('--default_model_name', default='best.pth.tar', type=str)
    parser.add_argument('-scope', '--scope', default='', type=str)
    parser.add_argument('-record', '--record', default=False, action='store_true')
    parser.add_argument('-field', '--field', nargs='+', type=str, default=['seed'])
    parser.add_argument('-val_and_test', '--val_and_test', default=False, action='store_true')

    parser.add_argument('-model_path', '--model_path', type=str)
    parser.add_argument('-teacher_path', '--teacher_path', type=str)

    parser.add_argument('-bs', '--beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-ba', '--beam_alpha', type=float, default=1.0)
    parser.add_argument('-topk', '--topk', type=int, default=1)

    # NA decoding
    parser.add_argument('-i', '--iterations', type=int, default=5)
    parser.add_argument('-lbs', '--length_beam_size', type=int, default=6)
    parser.add_argument('-q', '--q', type=int, default=1)
    parser.add_argument('-qi', '--q_iterations', type=int, default=1)
    parser.add_argument('-paradigm', '--paradigm', type=str, default='mp')
    parser.add_argument('-use_ct', '--use_ct', default=False, action='store_true')
    parser.add_argument('-md', '--masking_decision', default=False, action='store_true')
    parser.add_argument('-ncd', '--no_candidate_decision', default=False, action='store_true')
    parser.add_argument('--algorithm_print_sent', default=False, action='store_true')

    parser.add_argument('-batch_size', '--batch_size', type=int, default=128)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-em', '--evaluation_mode', type=str, default='test')
    parser.add_argument('-print_sent', action='store_true')
    parser.add_argument('-json_path', type=str, default='')
    parser.add_argument('-json_name', type=str, default='')
    parser.add_argument('-ns', '--no_score', default=False, action='store_true')
    parser.add_argument('-analyze', default=False, action='store_true')

    parser.add_argument('-latency', default=False, action='store_true')
    
    parser.add_argument('-specific', default=-1, type=int)
    parser.add_argument('-collect_path', type=str, default='./collected_captions')
    parser.add_argument('-collect', default=False, action='store_true')
    parser.add_argument('-nobc', '--not_only_best_candidate', default=False, action='store_true')

    opt = parser.parse_args()

    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    teacher_model = None
    dict_mapping = {}

    if opt.default:
        if opt.dataset.lower() == 'msvd':
            opt.dataset = 'Youtube2Text'
        opt.model_path = os.path.join(
            Constants.base_checkpoint_path,
            opt.dataset,
            opt.method,
            opt.scope,
            opt.default_model_name
        )
        if opt.method in ['NAB', 'NACF']:
            opt.teacher_path = os.path.join(
                Constants.base_checkpoint_path,
                opt.dataset,
                'ARB',
                opt.scope,
                opt.default_model_name
            )
            assert os.path.exists(opt.teacher_path)
    else:
        assert opt.model_path and os.path.exists(opt.model_path)

    model, option, other_info = load_model_and_opt(opt.model_path, device, return_other_info=True)
    if getattr(opt, 'teacher_path', None) is not None:
        print('Loading teacher model from %s' % opt.teacher_path)
        teacher_model, teacher_option = load_model_and_opt(opt.teacher_path, device)
        dict_mapping = get_dict_mapping(option, teacher_option)

    option['reference'] = option['reference'].replace('msvd_refs.pkl', 'refs.pkl')
    option['info_corpus'] = option['info_corpus'].replace('info_corpus_0.pkl', 'info_corpus.pkl')

    if not opt.default:
        # preserve the dataset recorded at training time when merging CLI args
        _ = option['dataset']
        option.update(vars(opt))
        option['dataset'] = _
    else:
        if option['decoding_type'] != 'NARFormer':
            option['topk'] = opt.topk
            option['beam_size'] = 5
            option['beam_alpha'] = 1.0
        else:
            option['algorithm_print_sent'] = opt.algorithm_print_sent
            option['paradigm'] = opt.paradigm
            option['iterations'] = 5
            option['length_beam_size'] = 6
            option['beam_alpha'] = 1.35 if opt.dataset == 'MSRVTT' else 1.0
            option['q'] = 1
            option['q_iterations'] = 1 if opt.use_ct else 0
        option['use_ct'] = opt.use_ct
    
    if opt.collect:
        prepare_collect_config(option, opt)

    if opt.latency:
        opt.batch_size = 1
        option['batch_size'] = 1

    if opt.val_and_test:
        modes = ['validate', 'test']
        csv_filenames = ['validation_record.csv', 'testing_record.csv']
    else:
        modes = [opt.evaluation_mode]
        csv_filenames = ['validation_record.csv' if opt.evaluation_mode == 'validate' else 'testing_record.csv']
    
    crit = get_criterion_during_evaluation(option)

    for mode, csv_filename in zip(modes, csv_filenames):
        loader = get_loader(option, mode=mode, print_info=True, specific=opt.specific, batch_size=opt.batch_size)
        vocab = loader.dataset.get_vocab()

        if opt.record:
            summarywriter = SummaryWriter(os.path.join(option['checkpoint_path'], mode))
        else:
            summarywriter = None

        metric = run_eval(option, model, crit, loader, vocab, device, 
            teacher_model=teacher_model,
            dict_mapping=dict_mapping,
            json_path=opt.json_path, 
            json_name=opt.json_name, 
            print_sent=opt.print_sent, 
            no_score=opt.no_score,  
            analyze=opt.record or opt.analyze,
            collect_best_candidate_iterative_results=opt.collect,
            collect_path=opt.collect_path,
            summarywriter=summarywriter,
            global_step=option['seed']
        )
        
        print(metric)
        if opt.record:
            fieldsnames = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
                            'METEOR', 'ROUGE_L', 'CIDEr', 'Sum', 
                            'ave_length', 'novel', 'unique', 'usage']
            if crit is not None:
                fieldsnames += crit.get_fieldsnames()
            logger = CsvLogger(filepath=option['checkpoint_path'], filename=csv_filename,
                                    fieldsnames=fieldsnames + opt.field)
            if 'loss' in metric:
                metric.pop('loss')

            for key in opt.field:
                metric[key] = option[key]
            logger.write(metric)
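
# load_model_and_opt comes from the project. A stand-in modeled on the
# explicit checkpoint-loading code in Example No. 9 (settings -> get_model ->
# load_state_dict) might look like this; everything beyond the 'settings' and
# 'state_dict' keys is an assumption.
import torch

def load_model_and_opt(model_path, device, return_other_info=False):
    checkpoint = torch.load(model_path, map_location=device)
    option = checkpoint['settings']
    model = get_model(option)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    if return_other_info:
        # 'other_info' is a hypothetical key; the real checkpoint may differ.
        return model, option, checkpoint.get('other_info', {})
    return model, option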
Example No. 8
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument(
        '-model_path',
        nargs='+',
        type=str,
        default=[
            "/home/yangbang/VideoCaptioning/0219save/Youtube2Text/IEL_ARFormer/EBN1_SS0_NDL1_WC0_MI/",
            #"/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed920/",
            "/home/yangbang/VideoCaptioning/0219save/VATEX/IEL_ARFormer/EBN1_SS1_NDL1_WC0_M/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed1314_ag/",
        ])
    parser.add_argument(
        '-model_name',
        nargs='+',
        type=str,
        default=[
            '0044_240095_254102_253703_251149_247202.pth.tar',
            #'0028_177617_180524_183734_183213_182417.pth.tar',
            "0011_176183_177176_180332_180729_178864.pth.tar",
            '0099_160093_057474.pth.tar',
            "0025_179448_180500_184018_184037_182508.pth.tar",
        ])
    parser.add_argument('-i', '--index', default=0, type=int)

    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-beam_alpha', type=float, default=1.0)
    parser.add_argument('-batch_size',
                        type=int,
                        default=128,
                        help='Batch size')
    parser.add_argument('-topk',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-em', type=str, default='test')
    parser.add_argument('-print_sent', action='store_true')
    parser.add_argument('-json_path', type=str, default='')
    parser.add_argument('-json_name', type=str, default='')
    parser.add_argument('-ns', default=False, action='store_true')
    parser.add_argument('-sv', default=False, action='store_true')
    parser.add_argument('-analyze', default=False, action='store_true')
    parser.add_argument('-write_time', default=False, action='store_true')
    parser.add_argument('-mid_path', default='best', type=str)
    parser.add_argument('-specific', default=-1, type=int)
    parser.add_argument('-category', default=False, action='store_true')
    parser.add_argument('-sp',
                        '--saved_with_pickle',
                        default=False,
                        action='store_true')
    parser.add_argument('-pp',
                        '--pickle_path',
                        default='./AR_topk_collect_results')
    parser.add_argument('-ca',
                        '--collect_analyze',
                        default=False,
                        action='store_true')

    parser.add_argument('-cv', '--cross_validation', type=int, default=2)

    opt = parser.parse_args()
    opt.model_path = opt.model_path[opt.index]
    opt.model_name = opt.model_name[opt.index]
    if opt.cross_validation == 1:
        source_dataset = 'MSRVTT'
        src_pre = 'msrvtt'
        src_wct = '2'
        target_dataset = 'Youtube2Text'
        tgt_pre = 'msvd'
        tgt_wct = '0'
    else:
        source_dataset = 'Youtube2Text'
        src_pre = 'msvd'
        src_wct = '0'
        target_dataset = 'MSRVTT'
        tgt_pre = 'msrvtt'
        tgt_wct = '2'

    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))
    if opt.saved_with_pickle:
        if not os.path.exists(opt.pickle_path):
            os.makedirs(opt.pickle_path)
        string = ''
        if 'Youtube2Text' in opt.model_path:
            dataset_name = 'msvd'
        elif 'MSRVTT' in opt.model_path:
            dataset_name = 'msrvtt'
        else:
            # guard against an unbound dataset_name below
            raise ValueError('cannot infer the dataset from model_path: %s' %
                             opt.model_path)
        if option.get('method', None) == 'ag':
            string = '_ag'
        opt.pickle_path = os.path.join(
            opt.pickle_path, '%s_%d%s.pkl' % (dataset_name, opt.topk, string))

    if opt.collect_analyze:
        collect_analyze(opt)
    else:
        device = torch.device('cuda' if not opt.no_cuda else 'cpu')
        if opt.analyze:
            opt.batch_size = 1
            option['batch_size'] = 1

        #print(option)

        checkpoint = torch.load(
            os.path.join(opt.model_path, opt.mid_path, opt.model_name))
        #option = checkpoint['settings']
        #option.update(vars(opt))
        model = get_model(option)

        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        if opt.category:
            loop_category(option, opt, model, device)
        else:
            loader = get_loader(option,
                                mode=opt.em,
                                print_info=True,
                                specific=opt.specific)
            vocab = loader.dataset.get_vocab()

            #print(model)
            calculate_novel = True
            if opt.cross_validation != 2:
                option['dataset'] = target_dataset
                option['feats_a'] = []
                option['feats_m'] = [
                    item.replace(source_dataset, target_dataset)
                    for item in option['feats_m']
                ]
                option['feats_m'] = [
                    item.replace(src_pre, tgt_pre)
                    for item in option['feats_m']
                ]
                option['feats_i'] = [
                    item.replace(source_dataset, target_dataset)
                    for item in option['feats_i']
                ]
                option['feats_i'] = [
                    item.replace(src_pre, tgt_pre)
                    for item in option['feats_i']
                ]

                option['reference'] = option['reference'].replace(
                    source_dataset, target_dataset)
                option['reference'] = option['reference'].replace(
                    src_pre, tgt_pre)

                option['info_corpus'] = option['info_corpus'].replace(
                    source_dataset, target_dataset)
                option['info_corpus'] = option['info_corpus'].replace(
                    src_wct, tgt_wct)
                option['info_corpus'] = option['info_corpus'].replace(
                    'Youtube0Text', 'Youtube2Text')

                loader = get_loader(option,
                                    mode=opt.em,
                                    print_info=True,
                                    specific=opt.specific)
                calculate_novel = False

            print(len(vocab))  # vocab from the source-dataset loader; the model decodes with it

            metric = run_eval(option,
                              model,
                              None,
                              loader,
                              vocab,
                              device,
                              json_path=opt.json_path,
                              json_name=opt.json_name,
                              print_sent=opt.print_sent,
                              no_score=opt.ns,
                              save_videodatainfo=opt.sv,
                              analyze=opt.analyze,
                              saved_with_pickle=opt.saved_with_pickle,
                              pickle_path=opt.pickle_path,
                              write_time=opt.write_time,
                              calculate_novel=calculate_novel)

            print(metric)
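
# The cross-validation branch above swaps datasets by plain string
# substitution on the stored paths. With illustrative paths (not the real
# ones):
src = '/data/MSRVTT/msrvtt_refs.pkl'
tgt = src.replace('MSRVTT', 'Youtube2Text').replace('msrvtt', 'msvd')
assert tgt == '/data/Youtube2Text/msvd_refs.pkl'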
Example No. 9
def main(opt):
    '''Main Function'''
    if opt.collect:
        if not os.path.exists(opt.collect_path):
            os.makedirs(opt.collect_path)

    device = torch.device('cuda' if not opt.no_cuda else 'cpu')

    model, option = load(opt.model_path,
                         opt.model_name,
                         device,
                         mid_path='best')
    option.update(vars(opt))
    set_seed(option['seed'])

    if not opt.nt:
        #teacher_path = os.path.join(option["checkpoint_path"].replace('NARFormer', 'ARFormer') + '_SS1_0_70')
        #teacher_name = 'teacher.pth.tar'
        #teacher_model, teacher_option = load(teacher_path, teacher_name, device, mid_path='', from_checkpoint=True)

        checkpoint = torch.load(opt.teacher_path)
        teacher_option = checkpoint['settings']
        teacher_model = get_model(teacher_option)
        teacher_model.load_state_dict(checkpoint['state_dict'])
        teacher_model.to(device)

        assert teacher_option['vocab_size'] == option['vocab_size']

        #dict_mapping = get_dict_mapping(option, teacher_option)
        dict_mapping = {}
    else:
        teacher_model = None
        dict_mapping = {}
    '''
    model = get_model(option)
    pth = os.path.join(opt.model_path, 'tmp_models')
    vali_loader = get_loader(option, mode='validate')
    test_loader = get_loader(option, mode='test')
    vocab = vali_loader.dataset.get_vocab()
    logger = CsvLogger(
        filepath=pth, 
        filename='evaluate.csv', 
        fieldsnames=['epoch', 'split', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L', 'CIDEr', 'Sum', 'lbs', 'i', 'ba']
        )
    for file in os.listdir(pth):
        if '.pth.tar' not in file:
            continue
        epoch = file.split('_')[1]
        checkpoint = torch.load(os.path.join(pth, file))
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        best = 0
        best_info = ()
        for lbs in range(1, 11):
            for i in [1, 3, 5, 10]:
                option['length_beam_size'] = lbs
                option['iterations'] = i
                metric = run_eval(option, model, None, test_loader, vocab, device, json_path=opt.json_path, json_name=opt.json_name, print_sent=opt.print_sent)
                metric.pop('loss')
                metric['lbs'] = lbs
                metric['ba'] = opt.beam_alpha
                metric['i'] = i
                metric['split'] = 'test'
                metric['epoch'] = epoch
                logger.write(metric)
                if metric['Sum'] > best:
                    best = metric['Sum']
                    best_info = (lbs, i)
                print(lbs, i, metric['Sum'], best)


    '''
    '''
    # rec length predicted results
    rec = {}
    for data in tqdm(loader, ncols=150, leave=False):
        with torch.no_grad():
            results = get_forword_results(option, model, data, device=device, only_data=False)
            for i in range(results['pred_length'].size(0)):
                res = results['pred_length'][i].topk(5)[1].tolist()
                for item in res:
                    rec[item] = rec.get(item, 0) + 1
    for i in range(50):
        if i in rec.keys():
            print(i, rec[i])
    '''
    if opt.plot:
        # build the loader/vocab here: neither exists yet on this code path
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        plot(option, opt, model, loader, vocab, device, teacher_model,
             dict_mapping)
    elif opt.loop:
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        #loop_iterations(option, opt, model, loader, vocab, device, teacher_model, dict_mapping)
        loop_length_beam(option, opt, model, loader, vocab, device,
                         teacher_model, dict_mapping)
        #loop_iterations(option, opt, model, device, teacher_model, dict_mapping)
    elif opt.category:
        loop_category(option, opt, model, device, teacher_model, dict_mapping)
    else:
        loader = get_loader(option,
                            mode=opt.em,
                            print_info=True,
                            specific=opt.specific)
        vocab = loader.dataset.get_vocab()
        filename = '%s_%s_%s_i%db%da%03d%s.pkl' % (
            option['dataset'], option['method'],
            ('%s' % ('AE' if opt.nv_scale == 100 else '')) + opt.paradigm,
            opt.iterations, opt.length_beam_size, int(
                100 * opt.beam_alpha), '_all' if opt.em == 'all' else '')
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=opt.analyze,
                          collect_best_candidate_iterative_results=opt.collect,
                          collect_path=os.path.join(opt.collect_path,
                                                    filename),
                          no_score=opt.ns,
                          write_time=opt.write_time)
        #collect_path=os.path.join(opt.collect_path, opt.collect),
        print(metric)
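
# The pickle filename template above expands as follows; the values mirror the
# non-autoregressive defaults from Example No. 7 (iterations=5,
# length_beam_size=6, beam_alpha=1.35 on MSRVTT) and are otherwise
# illustrative.
filename = '%s_%s_%s_i%db%da%03d%s.pkl' % (
    'MSRVTT', 'NACF', 'AEmp', 5, 6, int(100 * 1.35), '')
assert filename == 'MSRVTT_NACF_AEmp_i5b6a135.pkl'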
Example No. 10
def main(opt):
    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))
    #print(option)

    model = get_model(option)
    checkpoint = torch.load(
        os.path.join(opt.model_path, 'best', opt.model_name))
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    loader = get_loader(option,
                        mode=opt.em,
                        print_info=False,
                        specific=opt.specific)
    vocab = loader.dataset.get_vocab()

    # rec length predicted results
    rec = {}
    length = len(option['modality']) - sum(option['skip_info'])
    num_gates = 3
    gate_data = [[[] for _ in range(num_gates)] for __ in range(length)]

    opt.pca_name = opt.pca_name + '_%s' % opt.em + (
        '_mean' if opt.mean else '') + ('_all' if opt.all else '')
    opt.pca_path = os.path.join(opt.pca_path, opt.model_path.split('/')[-2])
    print(opt.pca_path)
    print(opt.model_path.split('/'))
    if not os.path.exists(opt.pca_path):
        os.makedirs(opt.pca_path)

    if opt.plot:
        visualize(opt, option)
    elif opt.oe:
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          no_score=opt.ns,
                          save_videodatainfo=opt.sv)
        print(metric)
    else:
        for data in tqdm(loader, ncols=150, leave=False):
            with torch.no_grad():
                results, _, _ = get_forword_results(option,
                                                    model,
                                                    data,
                                                    device=device,
                                                    only_data=True)
                gate = results['gate']
                assert len(gate) == length
                assert len(gate[0]) == num_gates - 1
                assert isinstance(results['enc_output'], list)
                assert len(results['enc_output']) == length
                for i in range(length):
                    gate[i].append(results['enc_output'][i])
                    for j in range(num_gates):
                        gate_data[i][j].append(gate[i][j])

        for i in range(length):
            for j in range(num_gates):
                gate_data[i][j] = torch.cat(
                    gate_data[i][j],
                    dim=0)  #[len_dataset, n_frames, dim_hidden]
                print(i, j, gate_data[i][j][0, 0, :10].tolist())

                if i == 0:
                    data = torch.cat(
                        [gate_data[0][j],
                         torch.cat(gate_data[1][j], dim=0)],
                        dim=0).cpu().numpy()
                else:
                    data = gate_data[i][j].cpu().numpy()
                name = '%s_%d_%d.npy' % (opt.pca_name, i, j)
                if opt.mean:
                    data = data.mean(1)
                    pca = manifold.TSNE(n_components=opt.pca)
                    #pca = PCA(n_components=opt.pca)     # load the PCA algorithm; set the number of principal components after reduction to 2
                    #collect = pca.fit_transform(data)   # reduce the samples' dimensionality
                    #print(pca.explained_variance_ratio_)
                    collect = pca.fit_transform(data)  # reduce the samples' dimensionality
                elif opt.all:
                    bsz, seq_len, dim = data.shape
                    data = data.reshape(bsz * seq_len, dim)
                    pca = manifold.TSNE(n_components=opt.pca)
                    collect = pca.fit_transform(data)  # reduce the samples' dimensionality
                else:
                    assert len(data.shape) == 3
                    seq_len = data.shape[1]
                    collect = []
                    for nf in range(seq_len):
                        x = data[:, nf, :]
                        pca = manifold.TSNE(n_components=opt.pca)
                        #pca = PCA(n_components=opt.pca)     # load the PCA algorithm; set the number of principal components after reduction to 2
                        reduced_x = pca.fit_transform(x)  # reduce the samples' dimensionality
                        collect.append(reduced_x)
                    collect = np.stack(collect, 1)
                print(name, collect.shape)
                np.save(os.path.join(opt.pca_path, name), collect)

            #print('--------')
            #print(i, resetgate[i].max(2)[0].max(0)[0].tolist())
            #print(i, inputgate[i].max(2)[0].max(0)[0].tolist())
            #print('--------')
            #print(i, resetgate[i].min(2)[0].min(0)[0].tolist())
            #print(i, inputgate[i].min(2)[0].min(0)[0].tolist())