def loop_length_beam(option, opt, model, loader, vocab, device, teacher_model, dict_mapping):
    # Sweep length_beam_size from 1 to 10 and record corpus-level metrics.
    b4 = []
    m = []
    r = []
    c = []
    ave_len = []
    for lbs in range(1, 11):
        option['length_beam_size'] = lbs
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name,
                          print_sent=opt.print_sent, teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping, analyze=True)
        b4.append(metric["Bleu_4"])
        m.append(metric["METEOR"])
        r.append(metric["ROUGE_L"])
        c.append(metric["CIDEr"])
        ave_len.append(metric['ave_length'])
    print(b4)
    print(m)
    print(r)
    print(c)
    print(ave_len)
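# The helper below is not part of the original script: a minimal, hypothetical
# sketch of how the per-beam-size metric lists printed by loop_length_beam
# could be visualized, assuming matplotlib is available and each list is
# aligned with length_beam_size values 1..10.
def plot_length_beam_sweep(b4, m, r, c):
    import matplotlib.pyplot as plt
    lbs_values = list(range(1, 11))
    for label, values in [('Bleu_4', b4), ('METEOR', m), ('ROUGE_L', r), ('CIDEr', c)]:
        plt.plot(lbs_values, values, marker='o', label=label)
    plt.xlabel('length_beam_size')
    plt.ylabel('score')
    plt.legend()
    plt.savefig('./length_beam_sweep.png')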
def loop_category(option, opt, model, device):
    loop_logger = CsvLogger(
        filepath='./category_results',
        filename='ARVC_%s%s.csv' % (option['dataset'],
                                    '' if option.get('method', None) != 'ag' else '_ag'),
        #fieldsnames=['novel', 'unique', 'usage', 'ave_length', 'gram4'],
        fieldsnames=['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
                     'METEOR', 'ROUGE_L', 'CIDEr', 'Sum', 'loss'],
    )
    # Evaluate each of the 20 video categories separately.
    for i in range(20):
        loader = get_loader(option, mode=opt.em, specific=i)
        vocab = loader.dataset.get_vocab()
        #metric = run_eval(option, model, None, loader, vocab, device, print_sent=opt.print_sent, no_score=True, analyze=True)
        metric = run_eval(option, model, None, loader, vocab, device,
                          print_sent=opt.print_sent, no_score=False, analyze=False)
        loop_logger.write(metric)
def loop_iterations(option, opt, model, loader, vocab, device, teacher_model, dict_mapping):
    loop_logger = CsvLogger(
        filepath='./loop_results',
        filename=opt.loop + '.csv',
        fieldsnames=['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L', 'CIDEr',
                     'Sum', 'iterations', 'lbs', 'novel', 'unique', 'usage', 'ave_length'])
    # Sweep the number of refinement iterations from 1 to 10.
    for i in range(1, 11):
        option['iterations'] = i
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name,
                          print_sent=opt.print_sent, teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping, analyze=True)
        metric['iterations'] = option['iterations']
        metric['lbs'] = option['length_beam_size']
        metric.pop('loss')
        loop_logger.write(metric)
def main(opt):
    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))
    # print(option)
    model = get_model(option)
    checkpoint = torch.load(os.path.join(opt.model_path, 'best', opt.model_name))
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # Remap the data paths recorded at training time to the local machine.
    for key in ['info_corpus', 'reference', 'feats_i', 'feats_m', 'feats_a']:
        if isinstance(option[key], list):
            for i in range(len(option[key])):
                option[key][i] = option[key][i].replace('/home/yangbang/VideoCaptioning',
                                                        '/Users/yangbang/Desktop/VC_data')
        else:
            option[key] = option[key].replace('/home/yangbang/VideoCaptioning',
                                              '/Users/yangbang/Desktop/VC_data')

    loader = get_loader(option, mode=opt.em, print_info=False, specific=opt.specific)
    vocab = loader.dataset.get_vocab()

    if opt.oe:
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name,
                          print_sent=opt.print_sent, no_score=opt.ns,
                          save_videodatainfo=opt.sv)
        print(metric)
    else:
        # Collect encoder outputs and category labels, then visualize them.
        encoder_outputs = []
        category = []
        for data in tqdm(loader, ncols=150, leave=False):
            with torch.no_grad():
                results, cate, _, _ = get_forword_results(option, model, data,
                                                          device=device, only_data=True)
            encoder_outputs.append(results['enc_output'][0])
            category.append(cate)
        encoder_outputs = torch.cat(encoder_outputs, dim=0).cpu().numpy()
        category = torch.cat(category, dim=0).view(-1).cpu().numpy()
        print(encoder_outputs.shape, category.shape)
        encoder_outputs = encoder_outputs.mean(1)
        # Named `pca` for historical reasons, but this actually runs t-SNE.
        pca = manifold.TSNE(n_components=opt.pca)
        data = pca.fit_transform(encoder_outputs)
        plot_several_category(data, category)
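# plot_several_category is called above but not defined in this file. The
# following is a hypothetical sketch of such a helper, assuming `data` is the
# (N, 2) t-SNE embedding and `category` an (N,) array of integer labels.
def plot_several_category_sketch(data, category):
    import matplotlib.pyplot as plt
    # One scatter group per category so each gets its own color and legend entry.
    for c in sorted(set(category.tolist())):
        mask = category == c
        plt.scatter(data[mask, 0], data[mask, 1], s=5, label=str(c))
    plt.legend(markerscale=3, fontsize='small')
    plt.savefig('./tsne_categories.png')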
def loop_category(option, opt, model, device, teacher_model, dict_mapping):
    loop_logger = CsvLogger(
        filepath='./category_results',
        filename='NAVC_%s_%s%s.csv' % (option['method'],
                                       'AE' if opt.nv_scale == 100 else '',
                                       opt.paradigm),
        fieldsnames=['novel', 'unique', 'usage', 'ave_length', 'gram4'])
    for i in range(20):
        loader = get_loader(option, mode=opt.em, specific=i)
        vocab = loader.dataset.get_vocab()
        metric = run_eval(option, model, None, loader, vocab, device,
                          print_sent=opt.print_sent, no_score=True, analyze=True,
                          teacher_model=teacher_model, dict_mapping=dict_mapping)
        loop_logger.write(metric)
def plot(option, opt, model, loader, vocab, device, teacher_model, dict_mapping):
    colors = ['skyblue', 'dodgerblue', 'dodgerblue', 'dodgerblue', 'dodgerblue']
    for i, iteration in enumerate([1, 2, 3, 4, 5]):
        option['iterations'] = iteration
        metric, x, y = run_eval(option, model, None, loader, vocab, device,
                                json_path=opt.json_path, json_name=opt.json_name,
                                print_sent=opt.print_sent,
                                teacher_model=teacher_model,
                                length_crit=torch.nn.SmoothL1Loss(),
                                dict_mapping=dict_mapping,
                                length_bias=opt.lb,
                                analyze=True, plot=True, no_score=True, top_n=30)
        ax = plt.subplot(1, 5, i + 1)
        ax.barh(x, y, color=colors[i])
        ax.set_title('iteration=%d' % iteration)
        #plt.tick_params(labelsize=13)
    plt.subplots_adjust(left=0.12, bottom=None, right=0.98, top=None, wspace=0.4, hspace=None)
    fig = plt.gcf()
    fig.set_size_inches(15, 10)
    plt.savefig('./iteration.png')
    plt.show()
def main():
    '''Main Function'''
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-df', '--default', default=False, action='store_true')
    parser.add_argument('-method', '--method', default='ARB', type=str)
    parser.add_argument('-dataset', '--dataset', default='MSRVTT', type=str)
    parser.add_argument('--default_model_name', default='best.pth.tar', type=str)
    parser.add_argument('-scope', '--scope', default='', type=str)
    parser.add_argument('-record', '--record', default=False, action='store_true')
    parser.add_argument('-field', '--field', nargs='+', type=str, default=['seed'])
    parser.add_argument('-val_and_test', '--val_and_test', default=False, action='store_true')
    parser.add_argument('-model_path', '--model_path', type=str)
    parser.add_argument('-teacher_path', '--teacher_path', type=str)
    parser.add_argument('-bs', '--beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-ba', '--beam_alpha', type=float, default=1.0)
    parser.add_argument('-topk', '--topk', type=int, default=1)

    # NA (non-autoregressive) decoding
    parser.add_argument('-i', '--iterations', type=int, default=5)
    parser.add_argument('-lbs', '--length_beam_size', type=int, default=6)
    parser.add_argument('-q', '--q', type=int, default=1)
    parser.add_argument('-qi', '--q_iterations', type=int, default=1)
    parser.add_argument('-paradigm', '--paradigm', type=str, default='mp')
    parser.add_argument('-use_ct', '--use_ct', default=False, action='store_true')
    parser.add_argument('-md', '--masking_decision', default=False, action='store_true')
    parser.add_argument('-ncd', '--no_candidate_decision', default=False, action='store_true')
    parser.add_argument('--algorithm_print_sent', default=False, action='store_true')

    parser.add_argument('-batch_size', '--batch_size', type=int, default=128)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-em', '--evaluation_mode', type=str, default='test')
    parser.add_argument('-print_sent', action='store_true')
    parser.add_argument('-json_path', type=str, default='')
    parser.add_argument('-json_name', type=str, default='')
    parser.add_argument('-ns', '--no_score', default=False, action='store_true')
    parser.add_argument('-analyze', default=False, action='store_true')
    parser.add_argument('-latency', default=False, action='store_true')
    parser.add_argument('-specific', default=-1, type=int)
    parser.add_argument('-collect_path', type=str, default='./collected_captions')
    parser.add_argument('-collect', default=False, action='store_true')
    parser.add_argument('-nobc', '--not_only_best_candidate', default=False, action='store_true')

    opt = parser.parse_args()
    device = torch.device('cuda' if not opt.no_cuda else 'cpu')

    teacher_model = None
    dict_mapping = {}

    if opt.default:
        if opt.dataset.lower() == 'msvd':
            opt.dataset = 'Youtube2Text'
        opt.model_path = os.path.join(Constants.base_checkpoint_path, opt.dataset,
                                      opt.method, opt.scope, opt.default_model_name)
        # NAB / NACF decode with the help of an autoregressive (ARB) teacher.
        if opt.method in ['NAB', 'NACF']:
            opt.teacher_path = os.path.join(Constants.base_checkpoint_path, opt.dataset,
                                            'ARB', opt.scope, opt.default_model_name)
            assert os.path.exists(opt.teacher_path)
    else:
        assert opt.model_path and os.path.exists(opt.model_path)

    model, option, other_info = load_model_and_opt(opt.model_path, device, return_other_info=True)
    if getattr(opt, 'teacher_path', None) is not None:
        print('Loading teacher model from %s' % opt.teacher_path)
        teacher_model, teacher_option = load_model_and_opt(opt.teacher_path, device)
        dict_mapping = get_dict_mapping(option, teacher_option)

    option['reference'] = option['reference'].replace('msvd_refs.pkl', 'refs.pkl')
    option['info_corpus'] = option['info_corpus'].replace('info_corpus_0.pkl', 'info_corpus.pkl')

    if not opt.default:
        # option.update would overwrite the dataset recorded at training time;
        # keep it.
        _ = option['dataset']
        option.update(vars(opt))
        option['dataset'] = _
    else:
        if option['decoding_type'] != 'NARFormer':
            option['topk'] = opt.topk
            option['beam_size'] = 5
            option['beam_alpha'] = 1.0
        else:
            option['algorithm_print_sent'] = opt.algorithm_print_sent
            option['paradigm'] = opt.paradigm
            option['iterations'] = 5
            option['length_beam_size'] = 6
            option['beam_alpha'] = 1.35 if opt.dataset == 'MSRVTT' else 1.0
            option['q'] = 1
            option['q_iterations'] = 1 if opt.use_ct else 0
            option['use_ct'] = opt.use_ct

    if opt.collect:
        prepare_collect_config(option, opt)

    if opt.latency:
        # Latency measurements require batch size 1.
        opt.batch_size = 1
        option['batch_size'] = 1

    if opt.val_and_test:
        modes = ['validate', 'test']
        csv_filenames = ['validation_record.csv', 'testing_record.csv']
    else:
        modes = [opt.evaluation_mode]
        csv_filenames = ['validation_record.csv' if opt.evaluation_mode == 'validate'
                         else 'testing_record.csv']

    crit = get_criterion_during_evaluation(option)

    for mode, csv_filename in zip(modes, csv_filenames):
        loader = get_loader(option, mode=mode, print_info=True,
                            specific=opt.specific, batch_size=opt.batch_size)
        vocab = loader.dataset.get_vocab()
        if opt.record:
            summarywriter = SummaryWriter(os.path.join(option['checkpoint_path'], mode))
        else:
            summarywriter = None

        metric = run_eval(option, model, crit, loader, vocab, device,
                          teacher_model=teacher_model,
                          dict_mapping=dict_mapping,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          no_score=opt.no_score,
                          analyze=True if opt.record else opt.analyze,
                          collect_best_candidate_iterative_results=opt.collect,
                          collect_path=opt.collect_path,
                          summarywriter=summarywriter,
                          global_step=option['seed'])
        print(metric)

        if opt.record:
            fieldsnames = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L',
                           'CIDEr', 'Sum', 'ave_length', 'novel', 'unique', 'usage']
            if crit is not None:
                fieldsnames += crit.get_fieldsnames()
            logger = CsvLogger(filepath=option['checkpoint_path'], filename=csv_filename,
                               fieldsnames=fieldsnames + opt.field)
            if 'loss' in metric:
                metric.pop('loss')
            for key in opt.field:
                metric[key] = option[key]
            logger.write(metric)
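# Assuming this file is run as a standalone script (no entry point is shown in
# this excerpt), the conventional guard would be:
if __name__ == '__main__':
    main()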
def main():
    '''Main Function'''
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument(
        '-model_path', nargs='+', type=str,
        default=[
            "/home/yangbang/VideoCaptioning/0219save/Youtube2Text/IEL_ARFormer/EBN1_SS0_NDL1_WC0_MI/",
            #"/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed920/",
            "/home/yangbang/VideoCaptioning/0219save/VATEX/IEL_ARFormer/EBN1_SS1_NDL1_WC0_M/",
            "/home/yangbang/VideoCaptioning/0219save/MSRVTT/IEL_ARFormer/EBN1_SS0_NDL1_WC20_MI_seed1314_ag/",
        ])
    parser.add_argument(
        '-model_name', nargs='+', type=str,
        default=[
            '0044_240095_254102_253703_251149_247202.pth.tar',
            #'0028_177617_180524_183734_183213_182417.pth.tar',
            "0011_176183_177176_180332_180729_178864.pth.tar",
            '0099_160093_057474.pth.tar',
            "0025_179448_180500_184018_184037_182508.pth.tar",
        ])
    parser.add_argument('-i', '--index', default=0, type=int)
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-beam_alpha', type=float, default=1.0)
    parser.add_argument('-batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('-topk', type=int, default=1,
                        help='If verbose is set, will output the n_best decoded sentences')
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-em', type=str, default='test')
    parser.add_argument('-print_sent', action='store_true')
    parser.add_argument('-json_path', type=str, default='')
    parser.add_argument('-json_name', type=str, default='')
    parser.add_argument('-ns', default=False, action='store_true')
    parser.add_argument('-sv', default=False, action='store_true')
    parser.add_argument('-analyze', default=False, action='store_true')
    parser.add_argument('-write_time', default=False, action='store_true')
    parser.add_argument('-mid_path', default='best', type=str)
    parser.add_argument('-specific', default=-1, type=int)
    parser.add_argument('-category', default=False, action='store_true')
    parser.add_argument('-sp', '--saved_with_pickle', default=False, action='store_true')
    parser.add_argument('-pp', '--pickle_path', default='./AR_topk_collect_results')
    parser.add_argument('-ca', '--collect_analyze', default=False, action='store_true')
    parser.add_argument('-cv', '--cross_validation', type=int, default=2)
    opt = parser.parse_args()

    opt.model_path = opt.model_path[opt.index]
    opt.model_name = opt.model_name[opt.index]

    # Cross-dataset setup: which dataset the model was trained on (source)
    # and which one it is evaluated on (target).
    if opt.cross_validation == 1:
        source_dataset, src_pre, src_wct = 'MSRVTT', 'msrvtt', '2'
        target_dataset, tgt_pre, tgt_wct = 'Youtube2Text', 'msvd', '0'
    else:
        source_dataset, src_pre, src_wct = 'Youtube2Text', 'msvd', '0'
        target_dataset, tgt_pre, tgt_wct = 'MSRVTT', 'msrvtt', '2'

    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))

    if opt.saved_with_pickle:
        if not os.path.exists(opt.pickle_path):
            os.makedirs(opt.pickle_path)
        string = ''
        if 'Youtube2Text' in opt.model_path:
            dataset_name = 'msvd'
        elif 'MSRVTT' in opt.model_path:
            dataset_name = 'msrvtt'
        if option.get('method', None) == 'ag':
            string = '_ag'
        opt.pickle_path = os.path.join(opt.pickle_path,
                                       '%s_%d%s.pkl' % (dataset_name, opt.topk, string))

    if opt.collect_analyze:
        collect_analyze(opt)
    else:
        device = torch.device('cuda' if not opt.no_cuda else 'cpu')
        if opt.analyze:
            opt.batch_size = 1
            option['batch_size'] = 1
        #print(option)
        checkpoint = torch.load(os.path.join(opt.model_path, opt.mid_path, opt.model_name))
        #option = checkpoint['settings']
        #option.update(vars(opt))
        model = get_model(option)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        if opt.category:
            loop_category(option, opt, model, device)
        else:
            loader = get_loader(option, mode=opt.em, print_info=True, specific=opt.specific)
            vocab = loader.dataset.get_vocab()
            #print(model)
            calculate_novel = True
            if opt.cross_validation != 2:
                # Cross-dataset evaluation: remap every data path from the
                # source dataset to the target dataset. Note that the
                # vocabulary stays that of the source dataset.
                option['dataset'] = target_dataset
                option['feats_a'] = []
                option['feats_m'] = [item.replace(source_dataset, target_dataset)
                                     for item in option['feats_m']]
                option['feats_m'] = [item.replace(src_pre, tgt_pre)
                                     for item in option['feats_m']]
                option['feats_i'] = [item.replace(source_dataset, target_dataset)
                                     for item in option['feats_i']]
                option['feats_i'] = [item.replace(src_pre, tgt_pre)
                                     for item in option['feats_i']]
                option['reference'] = option['reference'].replace(source_dataset, target_dataset)
                option['reference'] = option['reference'].replace(src_pre, tgt_pre)
                option['info_corpus'] = option['info_corpus'].replace(source_dataset, target_dataset)
                option['info_corpus'] = option['info_corpus'].replace(src_wct, tgt_wct)
                # The digit swap above may corrupt the dataset name
                # ('Youtube2Text' -> 'Youtube0Text'); undo it.
                option['info_corpus'] = option['info_corpus'].replace('Youtube0Text', 'Youtube2Text')
                loader = get_loader(option, mode=opt.em, print_info=True, specific=opt.specific)
                calculate_novel = False
            print(len(vocab))
            metric = run_eval(option, model, None, loader, vocab, device,
                              json_path=opt.json_path, json_name=opt.json_name,
                              print_sent=opt.print_sent, no_score=opt.ns,
                              save_videodatainfo=opt.sv, analyze=opt.analyze,
                              saved_with_pickle=opt.saved_with_pickle,
                              pickle_path=opt.pickle_path,
                              write_time=opt.write_time,
                              calculate_novel=calculate_novel)
            print(metric)
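# Illustrative note on the remapping above (the path is hypothetical): with
# cross_validation == 1, '/.../MSRVTT/info_corpus_2.pkl' first becomes
# '/.../Youtube2Text/info_corpus_2.pkl'; the blind '2' -> '0' digit swap then
# yields '/.../Youtube0Text/info_corpus_0.pkl', hence the final fix-up that
# restores 'Youtube2Text'.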
def main(opt):
    '''Main Function'''
    if opt.collect:
        if not os.path.exists(opt.collect_path):
            os.makedirs(opt.collect_path)

    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    model, option = load(opt.model_path, opt.model_name, device, mid_path='best')
    option.update(vars(opt))
    set_seed(option['seed'])

    if not opt.nt:
        #teacher_path = os.path.join(option["checkpoint_path"].replace('NARFormer', 'ARFormer') + '_SS1_0_70')
        #teacher_name = 'teacher.pth.tar'
        #teacher_model, teacher_option = load(teacher_path, teacher_name, device, mid_path='', from_checkpoint=True)
        checkpoint = torch.load(opt.teacher_path)
        teacher_option = checkpoint['settings']
        teacher_model = get_model(teacher_option)
        teacher_model.load_state_dict(checkpoint['state_dict'])
        teacher_model.to(device)
        assert teacher_option['vocab_size'] == option['vocab_size']
        #dict_mapping = get_dict_mapping(option, teacher_option)
        dict_mapping = {}
    else:
        teacher_model = None
        dict_mapping = {}

    '''
    # Disabled: evaluate every intermediate checkpoint in tmp_models over a
    # grid of length beam sizes and iteration counts.
    model = get_model(option)
    pth = os.path.join(opt.model_path, 'tmp_models')
    vali_loader = get_loader(option, mode='validate')
    test_loader = get_loader(option, mode='test')
    vocab = vali_loader.dataset.get_vocab()
    logger = CsvLogger(
        filepath=pth,
        filename='evaluate.csv',
        fieldsnames=['epoch', 'split', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
                     'METEOR', 'ROUGE_L', 'CIDEr', 'Sum', 'lbs', 'i', 'ba'])
    for file in os.listdir(pth):
        if '.pth.tar' not in file:
            continue
        epoch = file.split('_')[1]
        checkpoint = torch.load(os.path.join(pth, file))
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        best = 0
        best_info = ()
        for lbs in range(1, 11):
            for i in [1, 3, 5, 10]:
                option['length_beam_size'] = lbs
                option['iterations'] = i
                metric = run_eval(option, model, None, test_loader, vocab, device,
                                  json_path=opt.json_path, json_name=opt.json_name,
                                  print_sent=opt.print_sent)
                metric.pop('loss')
                metric['lbs'] = lbs
                metric['ba'] = opt.beam_alpha
                metric['i'] = i
                metric['split'] = 'test'
                metric['epoch'] = epoch
                logger.write(metric)
                if metric['Sum'] > best:
                    best = metric['Sum']
                    best_info = (lbs, i)
                print(lbs, i, metric['Sum'], best)
    '''

    '''
    # Disabled: record the distribution of predicted lengths.
    rec = {}
    for data in tqdm(loader, ncols=150, leave=False):
        with torch.no_grad():
            results = get_forword_results(option, model, data, device=device, only_data=False)
        for i in range(results['pred_length'].size(0)):
            res = results['pred_length'][i].topk(5)[1].tolist()
            for item in res:
                rec[item] = rec.get(item, 0) + 1
    for i in range(50):
        if i in rec.keys():
            print(i, rec[i])
    '''

    if opt.plot:
        # The loader and vocabulary were not built in this branch in the
        # original code, which raised a NameError; build them here.
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        plot(option, opt, model, loader, vocab, device, teacher_model, dict_mapping)
    elif opt.loop:
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        #loop_iterations(option, opt, model, loader, vocab, device, teacher_model, dict_mapping)
        loop_length_beam(option, opt, model, loader, vocab, device, teacher_model, dict_mapping)
    elif opt.category:
        loop_category(option, opt, model, device, teacher_model, dict_mapping)
    else:
        loader = get_loader(option, mode=opt.em, print_info=True, specific=opt.specific)
        vocab = loader.dataset.get_vocab()
        # e.g. 'MSRVTT_NACF_mp_i5b6a135.pkl' ('AE' prefixes the paradigm when
        # nv_scale == 100; '_all' is appended when evaluating on all splits).
        filename = '%s_%s_%s_i%db%da%03d%s.pkl' % (
            option['dataset'], option['method'],
            ('AE' if opt.nv_scale == 100 else '') + opt.paradigm,
            opt.iterations, opt.length_beam_size,
            int(100 * opt.beam_alpha),
            '_all' if opt.em == 'all' else '')
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=opt.analyze,
                          collect_best_candidate_iterative_results=opt.collect,
                          collect_path=os.path.join(opt.collect_path, filename),
                          no_score=opt.ns,
                          write_time=opt.write_time)
        print(metric)
def main(opt):
    device = torch.device('cuda' if not opt.no_cuda else 'cpu')
    opt_pth = os.path.join(opt.model_path, 'opt_info.json')
    option = json.load(open(opt_pth, 'r'))
    option.update(vars(opt))
    #print(option)
    model = get_model(option)
    checkpoint = torch.load(os.path.join(opt.model_path, 'best', opt.model_name))
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    loader = get_loader(option, mode=opt.em, print_info=False, specific=opt.specific)
    vocab = loader.dataset.get_vocab()

    # rec length predicted results
    rec = {}
    # One bucket per active modality branch; each branch contributes
    # (num_gates - 1) gate tensors plus its encoder output.
    length = len(option['modality']) - sum(option['skip_info'])
    num_gates = 3
    gate_data = [[[] for _ in range(num_gates)] for _ in range(length)]

    opt.pca_name = opt.pca_name + '_%s' % opt.em + ('_mean' if opt.mean else '') + ('_all' if opt.all else '')
    opt.pca_path = os.path.join(opt.pca_path, opt.model_path.split('/')[-2])
    print(opt.pca_path)
    print(opt.model_path.split('/'))
    if not os.path.exists(opt.pca_path):
        os.makedirs(opt.pca_path)

    if opt.plot:
        visualize(opt, option)
    elif opt.oe:
        metric = run_eval(option, model, None, loader, vocab, device,
                          json_path=opt.json_path, json_name=opt.json_name,
                          print_sent=opt.print_sent, no_score=opt.ns,
                          save_videodatainfo=opt.sv)
        print(metric)
    else:
        for data in tqdm(loader, ncols=150, leave=False):
            with torch.no_grad():
                results, _, _ = get_forword_results(option, model, data, device=device, only_data=True)
            gate = results['gate']
            assert len(gate) == length
            assert len(gate[0]) == num_gates - 1
            assert isinstance(results['enc_output'], list)
            assert len(results['enc_output']) == length
            for i in range(length):
                # Treat the encoder output as an extra "gate" stream.
                gate[i].append(results['enc_output'][i])
                for j in range(num_gates):
                    gate_data[i][j].append(gate[i][j])

        for i in range(length):
            for j in range(num_gates):
                gate_data[i][j] = torch.cat(gate_data[i][j], dim=0)  # [len_dataset, n_frames, dim_hidden]
                print(i, j, gate_data[i][j][0, 0, :10].tolist())
                if i == 0:
                    # gate_data[1][j] is still a list of batches at this point
                    # (i == 1 has not been processed yet), so concatenate it on
                    # the fly before stacking it with branch 0.
                    data = torch.cat([gate_data[0][j], torch.cat(gate_data[1][j], dim=0)],
                                     dim=0).cpu().numpy()
                else:
                    data = gate_data[i][j].cpu().numpy()
                name = '%s_%d_%d.npy' % (opt.pca_name, i, j)
                if opt.mean:
                    data = data.mean(1)
                    pca = manifold.TSNE(n_components=opt.pca)
                    #pca = PCA(n_components=opt.pca)  # load PCA, keeping opt.pca principal components
                    #collect = pca.fit_transform(data)  # project the samples to a lower dimension
                    #print(pca.explained_variance_ratio_)
                    collect = pca.fit_transform(data)  # project the samples to a lower dimension
                elif opt.all:
                    bsz, seq_len, dim = data.shape
                    data = data.reshape(bsz * seq_len, dim)
                    pca = manifold.TSNE(n_components=opt.pca)
                    collect = pca.fit_transform(data)  # project the samples to a lower dimension
                else:
                    # Reduce each frame position independently.
                    assert len(data.shape) == 3
                    seq_len = data.shape[1]
                    collect = []
                    for nf in range(seq_len):
                        x = data[:, nf, :]
                        pca = manifold.TSNE(n_components=opt.pca)
                        #pca = PCA(n_components=opt.pca)  # load PCA, keeping opt.pca principal components
                        reduced_x = pca.fit_transform(x)  # project the samples to a lower dimension
                        collect.append(reduced_x)
                    collect = np.stack(collect, 1)
                print(name, collect.shape)
                np.save(os.path.join(opt.pca_path, name), collect)
                #print('--------')
                #print(i, resetgate[i].max(2)[0].max(0)[0].tolist())
                #print(i, inputgate[i].max(2)[0].max(0)[0].tolist())
                #print('--------')
                #print(i, resetgate[i].min(2)[0].min(0)[0].tolist())
                #print(i, inputgate[i].min(2)[0].min(0)[0].tolist())
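# A minimal sketch, assuming the .npy arrays written above are reloaded later
# (e.g. by the `visualize` helper referenced in main); the file naming mirrors
# the '%s_%d_%d.npy' % (opt.pca_name, i, j) pattern used in the saving loop.
def load_gate_embeddings(pca_path, pca_name, i, j):
    import os
    import numpy as np
    # Shape: (len_dataset, pca_dims) for the -mean variant,
    # (len_dataset * n_frames, pca_dims) for -all, otherwise
    # (len_dataset, n_frames, pca_dims) from the per-frame branch.
    return np.load(os.path.join(pca_path, '%s_%d_%d.npy' % (pca_name, i, j)))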