def main():
    """Train a Reformer encoder with a two-output linear head.

    Command-line options come from ``add_argument()``; the batch size,
    gradient-accumulation steps and the optional optimizer section are read
    from the DeepSpeed JSON file named by ``cmd_args.ds_conf``.  Depending on
    ``cmd_args.use_deepspeed`` the model is trained through DeepSpeed engines
    or through a plain ``torch.optim.Adam`` loop.

    Side effects:
        * creates ``cmd_args.output_folder`` and writes ``training_conf.pkl``,
          ``mol_lang.pkl``, ``training_log.log`` and checkpoints below
          ``saved_model/model`` inside it;
        * overwrites the debug dumps ``src.pkl``, ``trg.pkl``, ``ts_src.pkl``
          and ``ts_trg.pkl`` in the current working directory on every batch.
    """
    cmd_args = add_argument()

    # Parse the DeepSpeed configuration once instead of re-opening it per key.
    ds_conf = _read_json(cmd_args.ds_conf)
    train_batch_size = ds_conf['train_batch_size']
    gradient_accumulation_steps = ds_conf['gradient_accumulation_steps']
    deepspeed_optimizer = ds_conf.get('optimizer') is not None

    dim = cmd_args.dim
    epochs = cmd_args.epochs
    validate_every = cmd_args.validate_every
    save_every = cmd_args.save_every
    output_folder = cmd_args.output_folder

    os.makedirs(output_folder, exist_ok=True)
    _dump_pickle(cmd_args, os.path.join(output_folder, 'training_conf.pkl'))

    # The vocabulary is both loaded from and written back to this one path.
    mol_lang_path = os.path.join(output_folder, 'mol_lang.pkl')
    read_corpus = readMRPC if cmd_args.mrpc_test else readMolecules
    mol_lang, tr_samples, ts_samples = read_corpus(
        molecule_file_tr=cmd_args.path_to_file_tr,
        molecule_file_ts=cmd_args.path_to_file_ts,
        saved_molecule_lang=mol_lang_path,
        num_examples_tr=cmd_args.num_examples_tr,
        num_examples_ts=cmd_args.num_examples_ts,
        min_len_molecule=cmd_args.min_len_mol,
        max_len_molecule=cmd_args.max_len_mol,
        shuffle=True)
    _dump_pickle(mol_lang, mol_lang_path)

    dataset_batch_size = train_batch_size if device == 'cuda' else 1
    train_dataset = MolecularSimilarityDataset(tr_samples, mol_lang,
                                               dataset_batch_size)
    test_dataset = MolecularSimilarityDataset(ts_samples, mol_lang,
                                              dataset_batch_size)

    # The model's maximum sequence length is twice the molecule length limit.
    max_seq_len = cmd_args.max_len_mol * 2
    axial_shape = compute_axial_position_shape(max_seq_len)
    print('Axial Embedding shape:', axial_shape)

    model = ReformerLM(
        num_tokens=mol_lang.n_words,
        dim=dim,
        bucket_size=cmd_args.bucket_size,
        depth=cmd_args.depth,
        heads=cmd_args.heads,
        n_hashes=cmd_args.n_hashes,
        max_seq_len=max_seq_len,
        ff_chunks=cmd_args.ff_chunks,
        attn_chunks=cmd_args.attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=axial_shape,
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=cmd_args.use_full_attn).to(device)

    # The head consumes the encoder embeddings, so its input width follows
    # ``dim`` (previously hard-coded to 512, which only worked for dim == 512).
    linear_regressor = Linear(dim, 2).to(device)

    model = TrainingWrapper(model, ignore_index=PAD_IDX,
                            pad_value=PAD_IDX).to(device)

    ckpt_dir = os.path.join(output_folder, 'saved_model', 'model')
    os.makedirs(os.path.join(output_folder, 'saved_model'), exist_ok=True)
    model_ckp_max = _latest_checkpoint_id(ckpt_dir)

    # device_count() is 0 on CPU-only hosts; clamp so the report cannot
    # raise ZeroDivisionError.
    gpus_mini_batch = ((train_batch_size // gradient_accumulation_steps)
                       // max(torch.cuda.device_count(), 1))
    print('gpus_mini_batch:', gpus_mini_batch,
          'with gradient_accumulation_steps:', gradient_accumulation_steps)

    # ``with`` guarantees the log is closed even if training raises.
    with open(os.path.join(output_folder, 'training_log.log'),
              'a') as log_file:
        log_file.write(
            "\n\n\n{}\tStarting new training from chekpoint: EncoderDecoder-{}\n".
            format(datetime.datetime.now(), model_ckp_max))
        log_file.flush()

        if cmd_args.use_deepspeed:
            _train_with_deepspeed(cmd_args, model, linear_regressor,
                                  train_dataset, test_dataset,
                                  deepspeed_optimizer, epochs, validate_every,
                                  save_every, ckpt_dir, model_ckp_max,
                                  log_file)
        else:
            _train_with_torch(model, linear_regressor, train_dataset,
                              test_dataset, train_batch_size, epochs,
                              validate_every, save_every, ckpt_dir,
                              model_ckp_max, log_file)


def _read_json(path):
    """Return the parsed content of the JSON file at *path*."""
    with open(path) as handle:
        return json.load(handle)


def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file handle afterwards."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _latest_checkpoint_id(ckpt_dir):
    """Return the largest integer-named entry of *ckpt_dir*, or 0.

    0 is returned when the directory is missing, empty, or contains an entry
    whose name is not an integer.
    """
    try:
        return max(int(entry) for entry in os.listdir(ckpt_dir))
    except (OSError, ValueError):
        return 0


def _padding_mask(batch):
    """Return a boolean tensor that is True wherever *batch* is not PAD_IDX."""
    return batch != PAD_IDX


def _save_torch_checkpoint(ckpt_dir, ckpt_id, tr_step, model,
                           linear_regressor, optimizer):
    """Write model, head and optimizer state to <ckpt_dir>/<ckpt_id>/sts_model.pt."""
    PATH = os.path.join(ckpt_dir, str(ckpt_id), 'sts_model.pt')
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(
        {
            'step': tr_step,
            'model_state_dict': model.state_dict(),
            'linear_state_dict': linear_regressor.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, PATH)


def _train_with_deepspeed(cmd_args, model, linear_regressor, train_dataset,
                          test_dataset, deepspeed_optimizer, epochs,
                          validate_every, save_every, ckpt_dir, model_ckp_max,
                          log_file):
    """Run the training loop with the encoder and head wrapped in DeepSpeed engines."""
    model_params = filter(lambda p: p.requires_grad, model.parameters())
    linear_params = filter(lambda p: p.requires_grad,
                           linear_regressor.parameters())

    if not deepspeed_optimizer:
        print('No DeepSpeed optimizer found. Using RangerLars.')
        model_optimizer = RangerLars(model.parameters())
        linear_optimizer = RangerLars(linear_regressor.parameters())
        model_engine, _, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=model,
            optimizer=model_optimizer,
            model_parameters=model_params,
            training_data=train_dataset)
        linear_engine, _, _, _ = deepspeed.initialize(
            args=cmd_args,
            model=linear_regressor,
            optimizer=linear_optimizer,
            model_parameters=linear_params)
    else:
        print('Found optimizer in the DeepSpeed configurations. Using it.')
        model_engine, _, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=model,
            model_parameters=model_params,
            training_data=train_dataset)
        linear_engine, _, _, _ = deepspeed.initialize(
            args=cmd_args,
            model=linear_regressor,
            model_parameters=linear_params)

    model_engine.load_checkpoint(ckpt_dir, model_ckp_max)
    testloader = model_engine.deepspeed_io(test_dataset)

    tr_step = 0  # stays 0 when no batch is processed
    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(tqdm(trainloader)):
            tr_step = ((eph * len(trainloader)) + i) + 1
            src = pair[0]
            trg = pair[1]
            _dump_pickle(src, 'src.pkl')
            _dump_pickle(trg, 'trg.pkl')

            model_engine.train()
            linear_engine.train()

            src = src.to(model_engine.local_rank)
            trg = trg.to(linear_engine.local_rank)

            print("Sample:", src)
            print("Target:", trg)
            print("Target Shape:", trg.shape)
            print("len Samples:", len(src))

            enc_keys = model_engine(src,
                                    return_loss=False,
                                    input_mask=_padding_mask(src))
            print('enc_keys shape', enc_keys.shape)

            # Classify from the first position only, as the non-DeepSpeed
            # loop does; the full sequence cannot be scored against one
            # label per sample.
            logits = linear_engine(enc_keys[:, 0, :])
            preds = torch.softmax(logits, dim=1)
            print('preds shape', preds.shape)

            # cross_entropy applies log-softmax itself, so it gets raw logits.
            train_loss = F.cross_entropy(logits, trg)
            # NOTE(review): DeepSpeed documents ``engine.backward(loss)``;
            # with two engines sharing one loss the plain backward is kept
            # here - confirm against the DeepSpeed version in use.
            train_loss.backward()
            model_engine.step()
            linear_engine.step()
            train_loss_value = train_loss.item()
            print('Training Loss:', train_loss_value)

            if tr_step % validate_every == 0:
                # NOTE (original author): validating through the parallel
                # test loader may interfere with training; a plain
                # validation pass was intended instead.
                val_loss = []
                model_engine.eval()
                linear_engine.eval()
                with torch.no_grad():
                    for ts_pair in tqdm(testloader):
                        ts_src = ts_pair[0]
                        ts_trg = ts_pair[1]
                        _dump_pickle(ts_src, 'ts_src.pkl')
                        _dump_pickle(ts_trg, 'ts_trg.pkl')

                        ts_src = ts_src.to(model_engine.local_rank)
                        ts_trg = ts_trg.to(linear_engine.local_rank)

                        ts_enc_keys = model_engine(
                            ts_src,
                            return_loss=False,
                            input_mask=_padding_mask(ts_src))
                        ts_logits = linear_engine(ts_enc_keys[:, 0, :])
                        val_loss.append(
                            F.cross_entropy(ts_logits, ts_trg).item())

                print(
                    f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                )
                # Log the training loss of this step, not the last
                # validation batch loss.
                log_file.write(
                    'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
                    .format(i, train_loss_value, np.mean(val_loss),
                            np.median(val_loss), np.std(val_loss)))
            else:
                log_file.write('Step: {}\tTraining Loss:{}\n'.format(
                    i, train_loss_value))

            log_file.flush()

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                model_engine.save_checkpoint(
                    ckpt_dir, str(model_ckp_max + tr_step + 1))

    print('\tSaving Final Checkpoint')
    model_engine.save_checkpoint(ckpt_dir, str(model_ckp_max + tr_step + 1))


def _train_with_torch(model, linear_regressor, train_dataset, test_dataset,
                      train_batch_size, epochs, validate_every, save_every,
                      ckpt_dir, model_ckp_max, log_file):
    """Run the training loop with one Adam optimizer over encoder and head."""
    model_optimizer = torch.optim.Adam(
        list(model.parameters()) + list(linear_regressor.parameters()))

    PATH = os.path.join(ckpt_dir, str(model_ckp_max), 'sts_model.pt')
    if os.path.exists(PATH):
        print('********** Found Checkpoint. Loading:', PATH)
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        checkpoint = torch.load(PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        linear_regressor.load_state_dict(checkpoint['linear_state_dict'])
        model_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    trainloader = DataLoader(train_dataset,
                             batch_size=train_batch_size,
                             shuffle=False)
    testloader = DataLoader(test_dataset,
                            batch_size=train_batch_size,
                            shuffle=False)

    train_loss_list = []
    tr_step = 0  # stays 0 when no batch is processed
    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(tqdm(trainloader)):
            tr_step = ((eph * len(trainloader)) + i) + 1
            src = pair[0]
            trg = pair[1]
            _dump_pickle(src, 'src.pkl')
            _dump_pickle(trg, 'trg.pkl')

            model.train()
            linear_regressor.train()

            src = src.to(device)
            trg = trg.to(device)

            enc_keys = model(src,
                             return_loss=False,
                             input_mask=_padding_mask(src))
            # Only the embedding at the first position feeds the head.
            logits = linear_regressor(enc_keys[:, 0, :])

            # cross_entropy applies log-softmax itself, so it gets raw logits.
            train_loss = F.cross_entropy(logits, trg)

            # Clear gradients left over from the previous step; without this
            # every step would add to the gradients of all earlier ones.
            model_optimizer.zero_grad()
            train_loss.backward()
            model_optimizer.step()

            train_loss_value = train_loss.item()
            train_loss_list.append(train_loss_value)

            if tr_step % validate_every == 0:
                val_loss = []
                ACC_list = []
                MCC_list = []
                model.eval()
                linear_regressor.eval()
                with torch.no_grad():
                    for ts_pair in tqdm(testloader):
                        ts_src = ts_pair[0]
                        ts_trg = ts_pair[1]
                        _dump_pickle(ts_src, 'ts_src.pkl')
                        _dump_pickle(ts_trg, 'ts_trg.pkl')

                        ts_src = ts_src.to(device)
                        ts_trg = ts_trg.to(device)

                        ts_enc_keys = model(ts_src,
                                            return_loss=False,
                                            input_mask=_padding_mask(ts_src))
                        ts_logits = linear_regressor(ts_enc_keys[:, 0, :])
                        # The metrics still receive probabilities.
                        ts_pred = torch.softmax(ts_logits, dim=1)

                        ACC, MCC = compute_simple_metrics(ts_pred, ts_trg)
                        ACC_list.append(ACC)
                        MCC_list.append(MCC)
                        val_loss.append(
                            F.cross_entropy(ts_logits, ts_trg).item())

                # '\T' was an invalid escape; the neighbouring lines use '\t'.
                print(
                    f'\tTrain Loss: LAST: {train_loss_list[-1]}, AVG: {np.mean(train_loss_list)}, MEDIAN: {np.median(train_loss_list)}, STD: {np.std(train_loss_list)} '
                )
                print(
                    f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                )
                print(
                    f'\tValidation ACC: AVG: {np.mean(ACC_list)}, MEDIAN: {np.median(ACC_list)}, STD: {np.std(ACC_list)} '
                )
                print(
                    f'\tValidation MCC: AVG: {np.mean(MCC_list)}, MEDIAN: {np.median(MCC_list)}, STD: {np.std(MCC_list)} '
                )
                # Log the training loss of this step, not the last
                # validation batch loss.
                log_file.write(
                    'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
                    .format(i, train_loss_value, np.mean(val_loss),
                            np.median(val_loss), np.std(val_loss)))
            else:
                log_file.write('Step: {}\tTraining Loss:{}\n'.format(
                    i, train_loss_value))

            log_file.flush()

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                _save_torch_checkpoint(ckpt_dir,
                                       model_ckp_max + tr_step + 1, tr_step,
                                       model, linear_regressor,
                                       model_optimizer)

    print('\tSaving Final Checkpoint')
    _save_torch_checkpoint(ckpt_dir, model_ckp_max + tr_step + 1, tr_step,
                           model, linear_regressor, model_optimizer)
model=model, model_parameters=model.parameters(), training_data=train_dataset) # training for i, data in enumerate(trainloader): model_engine.train() data = data.to(model_engine.local_rank) loss = model_engine(data, return_loss=True) model_engine.backward(loss) model_engine.step() print(loss.item()) if i % VALIDATE_EVERY == 0: model.eval() with torch.no_grad(): inp = random.choice(val_dataset)[:-1] loss = model(inp[None, :].cuda(), return_loss=True) print(f'validation loss: {loss.item()}') if i % GENERATE_EVERY == 0: model.eval() inp = random.choice(val_dataset)[:-1] prime = decode_tokens(inp) print(f'%s \n\n %s', (prime, '*' * 100)) sample = model.generate(inp.cuda(), GENERATE_LENGTH) output_str = decode_tokens(sample) print(output_str)
def test_encdec_v1(input_lang, target_lang, dim, bucket_size, depth, heads, n_hashes, vir_seq_len, ff_chunks, attn_chunks, mol_seq_len, cmd_args, train_dataset, test_dataset, output_folder, train_batch_size, epochs, validate_every, save_every, checkpoint_id, deepspeed_optimizer, use_full_attn, gradient_accumulation_steps, filter_thres): results = { 'generated_seq': [], 'generated_mol': [], 'target_mol': [], 'input_genome': [] } encoder = ReformerLM( num_tokens=input_lang.n_words, dim=dim, bucket_size=bucket_size, depth=depth, heads=heads, n_hashes=n_hashes, max_seq_len=vir_seq_len, ff_chunks=ff_chunks, attn_chunks=attn_chunks, weight_tie=True, weight_tie_embedding=True, axial_position_emb=True, axial_position_shape=compute_axial_position_shape(vir_seq_len), axial_position_dims=(dim // 2, dim // 2), return_embeddings=True, use_full_attn=use_full_attn).to(device) decoder = ReformerLM( num_tokens=target_lang.n_words, dim=dim, bucket_size=bucket_size, depth=depth, heads=heads, n_hashes=n_hashes, ff_chunks=ff_chunks, attn_chunks=attn_chunks, max_seq_len=mol_seq_len, axial_position_emb=True, axial_position_shape=compute_axial_position_shape(mol_seq_len), axial_position_dims=(dim // 2, dim // 2), weight_tie=True, weight_tie_embedding=True, causal=True, use_full_attn=use_full_attn).to(device) SAVE_DIR = os.sep.join([output_folder, 'saved_model']) if checkpoint_id: enc_ckp_max = checkpoint_id dec_ckp_max = checkpoint_id else: try: enc_ckp_max = np.max([ int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR, 'encoder'])) ]) except Exception as e: print('Exception:', e) enc_ckp_max = 0 try: dec_ckp_max = np.max([ int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR, 'decoder'])) ]) except: dec_ckp_max = 0 encoder = TrainingWrapper(encoder, ignore_index=PAD_IDX, pad_value=PAD_IDX).to(device) decoder = TrainingWrapper(decoder, ignore_index=PAD_IDX, pad_value=PAD_IDX).to(device) ''' encoder_params = filter(lambda p: p.requires_grad, encoder.parameters()) decoder_params = 
filter(lambda p: p.requires_grad, decoder.parameters()) if deepspeed_optimizer == False: print('No DeepSpeed optimizer found. Using RangerLars.') encoder_optimizer = RangerLars(encoder.parameters()) decoder_optimizer = RangerLars(decoder.parameters()) encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize( args=cmd_args, model=encoder, optimizer=encoder_optimizer, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True ) decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize( args=cmd_args, model=decoder, optimizer=decoder_optimizer, model_parameters=decoder_params, training_data=test_dataset, dist_init_required=False ) else: print('Found optimizer in the DeepSpeed configurations. Using it.') encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=encoder, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True) decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(args=cmd_args, model=decoder, model_parameters=decoder_params, training_data=test_dataset, dist_init_required=False) _, encoder_client_sd = encoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckp_max) _, decoder_client_sd = decoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckp_max) gpus_mini_batch = (train_batch_size// gradient_accumulation_steps) // torch.cuda.device_count() print('gpus_mini_batch:', gpus_mini_batch, 'with gradient_accumulation_steps:', gradient_accumulation_steps) for pair in tqdm(testloader): encoder_engine.eval() decoder_engine.eval() encoder.eval() decoder.eval() with torch.no_grad(): ts_src = pair[0] ts_trg = pair[1] input_genome = [[input_lang.index2word[gen_idx.item()] for gen_idx in smpl] for smpl in pair[0]] target_mol = [[target_lang.index2word[mol_idx.item()] for mol_idx in smpl] for smpl in pair[1]] ts_src = ts_src.to(encoder_engine.local_rank) #ts_src.to(device) # ts_trg = 
ts_trg.to(decoder_engine.local_rank) #ts_trg.to(device) # print('ts_src.shape', ts_src.shape) print('ts_src.shape', ts_trg.shape) enc_keys = encoder(ts_src) #encoder_engine(ts_src) yi = torch.tensor([[SOS_token] for _ in range(gpus_mini_batch)]).long().to(decoder_engine.local_rank) #to(device) # #sample = decoder_engine.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token) sample = decoder.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token) actual_mol = [] for mol_seq in sample.cpu().numpy(): for mol_idx in mol_seq: actual_mol.append(target_lang.index2word[mol_idx]) print('Generated Seq:', sample) print('Generated Mol:', actual_mol) print('Real Mol:', target_mol[:target_mol.index(target_lang.index2word[EOS_token])]) results['generated_seq'].append(sample) results['generated_mol'].append(actual_mol) results['target_mol'].append(target_mol) results['input_genome'].append(input_genome) print('Saving Test Results..') pickle.dump(results, open(os.sep.join([output_folder,'test_results.pkl']), 'wb')) ''' encoder_checkpoint = os.sep.join([ output_folder, 'saved_model', 'encoder', enc_ckp_max, 'mp_rank_00_model_states.pt' ]) decoder_checkpoint = os.sep.join([ output_folder, 'saved_model', 'decoder', dec_ckp_max, 'mp_rank_00_model_states.pt' ]) encoder.load_state_dict( torch.load(encoder_checkpoint, map_location=torch.device(device))['module']) decoder.load_state_dict( torch.load(decoder_checkpoint, map_location=torch.device(device))['module']) real_batch_size = train_batch_size // gradient_accumulation_steps test_loader = DataLoader(dataset=test_dataset, batch_size=real_batch_size, shuffle=True) if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] 
on 3 GPUs encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder) encoder.to(device) decoder.to(device) for pair in tqdm(test_loader): encoder.eval() decoder.eval() with torch.no_grad(): ts_src = torch.tensor(np.array([pair[0].numpy()])).to(device) ts_trg = torch.tensor(np.array([pair[1].numpy()])).to(device) input_genome = [ input_lang.index2word[gen_idx.item()] for gen_idx in pair[0] ] target_mol = [ target_lang.index2word[mol_idx.item()] for mol_idx in pair[1] ] enc_keys = encoder(ts_src) yi = torch.tensor([[SOS_token]]).long().to(device) sample = decoder.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=filter_thres, keys=enc_keys, eos_token=EOS_token) actual_mol = [] for mol_seq in sample.cpu().numpy(): for mol_idx in mol_seq: actual_mol.append(target_lang.index2word[mol_idx]) print('Generated Seq:', sample) print('Generated Mol:', actual_mol) print( 'Real Mol:', target_mol[:target_mol.index(target_lang. index2word[EOS_token])]) results['generated_seq'].append(sample) results['generated_mol'].append(actual_mol) results['target_mol'].append(target_mol) results['input_genome'].append(input_genome) print('Saving Test Results..') pickle.dump(results, open(os.sep.join([output_folder, 'test_results.pkl']), 'wb')) '''