def main():
    """Prepare and persist a dataset split.

    Reads CLI arguments, loads the train corpus and binned validation
    corpora, writes them out as pickles and plain-text source/target
    files under ``data/<dataset>/``, and records length/depth statistics
    in ``data_info.json``.

    Side effects only (filesystem writes); returns ``None``.
    """
    parser = build_parser()
    args = parser.parse_args()
    config = args

    print("Loading Data!")
    train_corpus, val_corpus_bins = load_data(config, num_bins=config.bins)

    data_dir = os.path.join('data', config.dataset)
    # makedirs(exist_ok=True) also creates the parent 'data/' directory if
    # missing, where the original os.mkdir would raise FileNotFoundError.
    os.makedirs(data_dir, exist_ok=True)

    print("Writing Train corpus")
    with open(os.path.join(data_dir, 'train_corpus.pk'), 'wb') as f:
        pickle.dump(train_corpus, f)
    print("Done")

    print("Writing Val corpus bins")
    with open(os.path.join(data_dir, 'val_corpus_bins.pk'), 'wb') as f:
        pickle.dump(val_corpus_bins, f)
    print("Done")

    print("Writing Train text files")
    with open(os.path.join(data_dir, 'train_src.txt'), 'w') as f:
        f.write('\n'.join(train_corpus.source))
    with open(os.path.join(data_dir, 'train_tgt.txt'), 'w') as f:
        f.write('\n'.join(train_corpus.target))
    print("Done")

    print("Writing Val text files")
    for i, val_corpus_bin in enumerate(val_corpus_bins):
        with open(os.path.join(data_dir, 'val_src_bin{}.txt'.format(i)), 'w') as f:
            f.write('\n'.join(val_corpus_bin.source))
        with open(os.path.join(data_dir, 'val_tgt_bin{}.txt'.format(i)), 'w') as f:
            f.write('\n'.join(val_corpus_bin.target))
    print("Done")

    print("Gathering Length and Depth info of the dataset")
    # Unique max parse depths / lengths per split; depth_counter returns a
    # per-line matrix whose row-sum max is the line's depth (see Lang helper).
    train_depths = list({
        train_corpus.Lang.depth_counter(line).sum(1).max()
        for line in train_corpus.source
    })
    train_lens = list({len(line) for line in train_corpus.source})

    val_lens_bins, val_depths_bins = [], []
    for val_corpus in val_corpus_bins:
        val_depths = list({
            val_corpus.Lang.depth_counter(line).sum(1).max()
            for line in val_corpus.source
        })
        val_depths_bins.append(val_depths)
        val_lens = list({len(line) for line in val_corpus.source})
        val_lens_bins.append(val_lens)

    info_dict = {}
    info_dict['Lang'] = '{}-{}'.format(config.lang, config.num_par)
    info_dict['Train Lengths'] = (min(train_lens), max(train_lens))
    info_dict['Train Depths'] = (int(min(train_depths)), int(max(train_depths)))
    info_dict['Train Size'] = len(train_corpus.source)
    for i, (val_lens, val_depths) in enumerate(
            zip(val_lens_bins, val_depths_bins)):
        info_dict['Val Bin-{} Lengths'.format(i)] = (min(val_lens), max(val_lens))
        info_dict['Val Bin-{} Depths'.format(i)] = (int(min(val_depths)),
                                                    int(max(val_depths)))
        info_dict['Val Bin-{} Size'.format(i)] = len(val_corpus_bins[i].source)

    with open(os.path.join(data_dir, 'data_info.json'), 'w') as f:
        json.dump(info_dict, f)
    print("Done")
def main():
    """Entry point for the s2s model: trains from scratch / resumes from a
    checkpoint in 'train' mode, otherwise restores config + checkpoint and
    decodes the test set (greedy for beam width 1, beam search otherwise).
    """
    # Parse arguments
    parser = build_parser()
    args = parser.parse_args()
    args.mode = args.mode.lower()
    # div_gps = args.div_gps
    # div_lam = args.div_lam
    if args.mode == 'train':
        # Blank run name -> timestamp-based name for a fresh training run.
        if len(args.run_name.split()) == 0:
            args.run_name = datetime.fromtimestamp(time.time()).strftime(
                args.date_fmt)
        else:
            args.run_name = args.run_name  # no-op: keep user-supplied name

    # Seed all RNGs for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Capture decode-time settings BEFORE args may be replaced by the
    # pickled training config below (smethod/slam/sparam/outdir survive).
    smethod = str(args.selec)
    # data_sub = basename of the test source file without extension
    # NOTE(review): '/'-split assumes POSIX paths — confirm on Windows.
    data_sub = str(os.path.join('data', args.dataset, 'test',
                                'src.txt')).split('/')[-1].split('.')[0]
    slam = args.slam
    a1 = args.a1
    a2 = args.a2
    b1 = args.b1
    b2 = args.b2
    sparam = [a1, a2, b1, b2]
    outdir = str(args.out_dir)

    # GPU initialization
    device = gpu_init_pytorch(args.gpu)

    log_folder_name = os.path.join('Logs', args.run_name)
    create_save_directories('Logs', 'Model', args.run_name)
    logger = get_logger(__name__, args.run_name, args.log_fmt, logging.INFO,
                        os.path.join(log_folder_name, 's2s.log'))

    if args.mode == 'train':
        # Build vocab from training data and persist it for later decoding.
        train_dataloader, val_dataloader = read_files(args, logger)
        logger.info('Creating vocab ...')
        voc = Voc(args.dataset)
        voc = create_vocab_dict(args, voc, train_dataloader)
        logger.info('Vocab created with number of words = {}'.format(
            voc.nwords))
        logger.info('Saving Vocabulary file')
        with open(os.path.join('Model', args.run_name, 'vocab.p'), 'wb') as f:
            pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info('Vocabulary file saved in {}'.format(
            os.path.join('Model', args.run_name, 'vocab.p')))
    else:
        # Decode mode: reuse the vocabulary saved during training.
        test_dataloader = read_files(args, logger)
        logger.info('Loading Vocabulary file')
        with open(os.path.join('Model', args.run_name, 'vocab.p'), 'rb') as f:
            voc = pickle.load(f)
        logger.info('Vocabulary file Loaded from {}'.format(
            os.path.join('Model', args.run_name, 'vocab.p')))

    # Get Checkpoint, return None if no checkpoint present
    checkpoint = get_latest_checkpoint('Model', args.run_name, logger)

    if args.mode == 'train':
        if checkpoint == None:
            # Fresh run: initialize bookkeeping and save the config snapshot.
            logger.info('Starting a fresh training procedure')
            ep_offset = 0
            min_val_loss = 1e8
            max_val_bleu = 0.0
            config_file_name = os.path.join('Model', args.run_name, 'config.p')
            if args.use_word2vec:
                args.emb_size = 300  # word2vec embeddings are 300-d
            model = s2s(args, voc, device, logger)
            with open(config_file_name, 'wb') as f:
                pickle.dump(vars(args), f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # Resume: args is REPLACED by the pickled training config.
            config_file_name = os.path.join('Model', args.run_name, 'config.p')
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))
            if args.use_word2vec:
                args.emb_size = 300
            model = s2s(args, voc, device, logger)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device)
            logger.info('Resuming Training From ')
            od = OrderedDict()
            od['Epoch'] = ep_offset
            od['Train_loss'] = train_loss
            od['Validation_loss'] = min_val_loss
            od['Validation_Bleu'] = max_val_bleu
            print_log(logger, od)
            ep_offset += 1
        # Call Training function
        train(model, train_dataloader, val_dataloader, voc, device, args,
              logger, ep_offset, min_val_loss, max_val_bleu)
    else:
        if checkpoint == None:
            logger.info('Cannot decode because of absence of checkpoints')
            sys.exit()
        else:
            # Restore training config, then re-apply the decode-time
            # overrides captured from the CLI above.
            config_file_name = os.path.join('Model', args.run_name, 'config.p')
            beam_width = args.beam_width
            gpu = args.gpu
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))
            args.beam_width = beam_width
            args.gpu = gpu
            # args.div_beam = div_beam
            # args.div_gps = div_gps
            # args.div_lam = div_lam
            if args.use_word2vec:
                args.emb_size = 300
            args.slam = slam
            args.sparam = sparam
            args.out_dir = outdir
            model = s2s(args, voc, device, logger)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device)
            logger.info('Decoding from')
            od = OrderedDict()
            od['Epoch'] = ep_offset
            od['Train_Loss'] = train_loss
            od['Validation_Loss'] = min_val_loss
            od['Validation_Bleu'] = max_val_bleu
            print_log(logger, od)
            # Greedy decode for beam width 1, diverse beam search otherwise.
            if args.beam_width == 1:
                decode_greedy(model, test_dataloader, voc, device, args,
                              logger)
            else:
                decode_beam(model, test_dataloader, voc, device, args, logger,
                            smethod, data_sub)
def main():
    """Entry point for the classifier: trains a model (optionally resuming
    from the latest checkpoint) in 'train' mode; otherwise restores the
    saved config + checkpoint and evaluates every test bin, writing a
    per-bin analysis CSV under the results folder.
    """
    parser = build_parser()
    args = parser.parse_args()
    config = args
    mode = config.mode
    is_train = mode == 'train'

    # Set seed for reproducibility.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    random.seed(config.seed)

    # GPU initialization
    device = gpu_init_pytorch(config.gpu)
    # device = 'cpu'

    # Run config files/paths.
    run_name = config.run_name
    config.log_path = os.path.join(log_folder, run_name)
    config.model_path = os.path.join(model_folder, run_name)
    config.board_path = os.path.join(board_path, run_name)

    vocab_path = os.path.join(config.model_path, 'vocab.p')
    config_file = os.path.join(config.model_path, 'config.p')
    log_file = os.path.join(config.log_path, 'log.txt')

    if config.results:
        config.result_path = os.path.join(
            result_folder, 'val_results_{}.json'.format(config.dataset))

    if is_train:
        create_save_directories(config.log_path, config.model_path)
    else:
        create_save_directories(config.log_path, config.result_path)

    logger = get_logger(run_name, log_file, logging.DEBUG)
    writer = SummaryWriter(config.board_path)

    logger.debug('Created Relevant Directories')
    logger.info('Experiment Name: {}'.format(config.run_name))

    # Read files and create/load vocab.
    if is_train:
        logger.debug('Creating Vocab and loading Data ...')
        train_loader, val_loader_bins, voc = load_data(config, logger)
        logger.info(
            'Vocab Created with number of words : {}'.format(voc.nwords))
        with open(vocab_path, 'wb') as f:
            pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info('Vocab saved at {}'.format(vocab_path))
    else:
        logger.info('Loading Vocab File...')
        with open(vocab_path, 'rb') as f:
            voc = pickle.load(f)
        logger.info('Vocab Files loaded from {}'.format(vocab_path))
        logger.info("Loading Test Dataloaders...")
        config.batch_size = 1  # evaluate one example at a time
        test_loader_bins = load_data(config, logger, voc)
        logger.info("Done loading test dataloaders")

    if is_train:
        max_val_acc = 0.0
        epoch_offset = 0
        # Build the model first so it exists even when no checkpoint is
        # found (the original left `model` undefined in that case).
        model = build_model(config=config, voc=voc, device=device,
                            logger=logger)
        if config.load_model:
            checkpoint = get_latest_checkpoint(config.model_path, logger)
            if checkpoint:
                # map_location keeps tensors on CPU until the model moves them
                ckpt = torch.load(checkpoint,
                                  map_location=lambda storage, loc: storage)
                # config.lr = checkpoint['lr']
                model.load_state_dict(ckpt['model_state_dict'])
                model.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        logger.info('Initialized Model')

        # Snapshot the full config so decode mode can restore it.
        with open(config_file, 'wb') as f:
            pickle.dump(vars(config), f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.debug('Config File Saved')

        logger.info('Starting Training Procedure')
        train_model(model, train_loader, val_loader_bins, voc, device, config,
                    logger, epoch_offset, max_val_acc, writer)
    else:
        # Restore the training config; preserve the evaluation-time
        # overrides taken from the CLI (gpu / bias / extraffn / bins).
        gpu = config.gpu
        bias = config.bias
        extraffn = config.extraffn
        with open(config_file, 'rb') as f:
            config = AttrDict(pickle.load(f))
        config.gpu = gpu
        config.bins = len(test_loader_bins)
        config.batch_size = 1
        config.bias = bias
        config.extraffn = extraffn
        # To do: remove it later
        # config.num_labels = 2

        model = build_model(config=config, voc=voc, device=device,
                            logger=logger)
        checkpoint = get_latest_checkpoint(config.model_path, logger)
        ep_offset, train_loss, score, voc = load_checkpoint(
            model, config.mode, checkpoint, logger, device, bins=config.bins)

        logger.info('Prediction from')
        od = OrderedDict()
        od['epoch'] = ep_offset
        od['train_loss'] = train_loss
        # score is per-bin when binning is enabled, scalar otherwise.
        if config.bins != -1:
            for i in range(config.bins):
                od['max_val_acc_bin{}'.format(i)] = score[i]
        else:
            od['max_val_acc'] = score
        print_log(logger, od)

        # Removed a stray pdb.set_trace() here: it blocked every
        # non-interactive evaluation run.
        for i in range(config.bins):
            test_acc_epoch, test_analysis_df = run_test(
                config, model, test_loader_bins[i], voc, device, logger)
            logger.info('Bin {} Accuracy: {}'.format(i, test_acc_epoch))
            test_analysis_df.to_csv(
                os.path.join(
                    result_folder,
                    '{}_{}_test_analysis_bin{}.csv'.format(
                        config.dataset, config.model_type, i)))
        logger.info(
            "Analysis results written to {}...".format(result_folder))
def main():
    """Entry point for the transformer seq2seq experiments.

    Two top-level branches on ``args.mt``:
    * MT branch: separate source/target vocabularies (voc1/voc2), trains
      with Noam-scheduled Adam and label smoothing, validates each epoch
      by greedy decoding (BLEU + task error score), checkpoints every 10
      epochs.
    * Synthetic-data branch: a single shared vocabulary (Syn_Voc) and
      directional source masks, otherwise the same train/validate loop,
      checkpoints every 5 epochs and stores best scores via
      ``store_results``.
    """
    # Parse arguments
    parser = build_parser()
    args = parser.parse_args()

    # Specify seeds for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Configs
    device = gpu_init_pytorch(args.gpu)
    mode = args.mode
    if mode == 'train':
        is_train = True
    else:
        is_train = False

    # ckpt= args.ckpt
    run_name = args.run_name
    args.log_path = os.path.join(log_folder, run_name)
    args.model_path = os.path.join(model_folder, run_name)
    args.board_path = os.path.join(board_path, run_name)
    args.outputs_path = os.path.join(outputs_folder, run_name)

    args_file = os.path.join(args.model_path, 'args.p')  # currently unused
    log_file = os.path.join(args.log_path, 'log.txt')

    if args.results:
        args.result_path = os.path.join(
            result_folder, 'val_results_{}.json'.format(args.dataset))

    logging_var = bool(args.logging)

    if is_train:
        create_save_directories(args.log_path)
        create_save_directories(args.model_path)
        create_save_directories(args.outputs_path)
    else:
        create_save_directories(args.log_path)
        create_save_directories(args.result_path)

    logger = get_logger(run_name, log_file, logging.DEBUG)
    logger.debug('Created Relevant Directories')
    logger.info('Experiment Name: {}'.format(args.run_name))

    if args.mt:
        # ---------- Machine-translation branch: two vocabularies ----------
        vocab1_path = os.path.join(args.model_path, 'vocab1.p')
        vocab2_path = os.path.join(args.model_path, 'vocab2.p')

        if is_train:
            #pdb.set_trace()
            train_dataloader, val_dataloader = load_data(args, logger)
            logger.debug('Creating Vocab...')
            voc1 = Voc()
            voc1.create_vocab_dict(args, 'src', train_dataloader)
            # To Do : Remove Later
            voc1.add_to_vocab_dict(args, 'src', val_dataloader)
            voc2 = Voc()
            voc2.create_vocab_dict(args, 'trg', train_dataloader)
            # To Do : Remove Later
            voc2.add_to_vocab_dict(args, 'trg', val_dataloader)
            logger.info('Vocab Created with number of words : {}'.format(
                voc1.nwords))
            with open(vocab1_path, 'wb') as f:
                pickle.dump(voc1, f, protocol=pickle.HIGHEST_PROTOCOL)
            with open(vocab2_path, 'wb') as f:
                pickle.dump(voc2, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info('Vocab saved at {}'.format(vocab1_path))
        else:
            test_dataloader = load_data(args, logger)
            logger.info('Loading Vocab File...')
            with open(vocab1_path, 'rb') as f:
                voc1 = pickle.load(f)
            with open(vocab2_path, 'rb') as f:
                voc2 = pickle.load(f)
            logger.info(
                'Vocab Files loaded from {}\nNumber of Words: {}'.format(
                    vocab1_path, voc1.nwords))

        # print('Done')
        # TO DO : Load Existing Checkpoints here
        checkpoint = get_latest_checkpoint(args.model_path, logger)

        # Param specs
        layers = args.layers
        heads = args.heads
        d_model = args.d_model
        d_ff = args.d_ff
        max_len = args.max_length
        dropout = args.dropout
        BATCH_SIZE = args.batch_size
        epochs = args.epochs

        if logging_var:
            # NOTE(review): meta_fh/loss_fh are opened but never closed.
            meta_fname = os.path.join(args.log_path, 'meta.txt')
            loss_fname = os.path.join(args.log_path, 'loss.txt')
            meta_fh = open(meta_fname, 'w')
            loss_fh = open(loss_fname, 'w')
            print('Log Files created at: {}'.format(args.log_path))
            write_meta(args, meta_fh)

        """stime= time.time()
        print('Loading Data...')
        train, val, test, SRC, TGT = build_data()
        etime= (time.time()-stime)/60
        print('Data Loaded\nTime Taken:{}'.format(etime ))"""

        pad_idx = voc1.w2id['PAD']
        model = make_model(voc1.nwords, voc2.nwords, N=layers, h=heads,
                           d_model=d_model, d_ff=d_ff, dropout=dropout)
        model.to(device)
        criterion = LabelSmoothing(size=voc2.nwords, padding_idx=pad_idx,
                                   smoothing=0.1)
        criterion.to(device)

        # train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
        #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        #                         batch_size_fn=batch_size_fn, train=True)
        # valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
        #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        #                         batch_size_fn=batch_size_fn, train=False)

        if mode == 'train':
            # Noam learning-rate schedule wrapped around Adam.
            model_opt = NoamOpt(
                model.src_embed[0].d_model, 1, 2000,
                torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                                 eps=1e-9))
            max_val_score = 0.0
            min_error_score = 100.0
            epoch_offset = 0

            for epoch in range(epochs):
                # pdb.set_trace()
                #if epoch%3==0:
                print('Training Epoch: ', epoch)
                model.train()
                run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b)
                           for b in train_dataloader), model,
                          LossCompute(model.generator, criterion,
                                      device=device, opt=model_opt))
                model.eval()
                # loss = run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b) for b in val_dataloader),
                #                  model,
                #                  LossCompute(model.generator, criterion, device=device, opt=None))
                # loss_str= "Epoch: {} \t Val Loss: {}\n".format(epoch,loss)
                # print(loss_str)

                # Validation: greedy-decode every batch, collect refs/hyps
                # for BLEU and accumulate the task error score.
                refs = []
                hyps = []
                error_score = 0
                for i, batch in enumerate(val_dataloader):
                    sent1s = sents_to_idx(voc1, batch['src'], args.max_length)
                    sent2s = sents_to_idx(voc2, batch['trg'], args.max_length)
                    sent1_var, sent2_var, input_len1, input_len2 = process_batch(
                        sent1s, sent2s, voc1, voc2, device,
                        voc1.id2w[pad_idx])
                    sent1s = idx_to_sents(voc1, sent1_var, no_eos=True)
                    sent2s = idx_to_sents(voc2, sent2_var, no_eos=True)
                    #pdb.set_trace()
                    # for l in range(len(batch['src'])):
                    #     if len(batch['src'][l].split())!=9:
                    #         print(l)
                    #for eg in range(sent1_var.size(0)):
                    # assumes sent1_var is (seq, batch) — transposed to
                    # (batch, seq) for the transformer; TODO confirm.
                    src = sent1_var.transpose(0, 1)
                    src_mask = (src != voc1.w2id['PAD']).unsqueeze(-2)
                    #refs.append([' '.join(sent2s[eg])])
                    refs += [[' '.join(sent2s[i])]
                             for i in range(sent2_var.size(1))]
                    # pdb.set_trace()
                    out = greedy_decode(model, src, src_mask, max_len=60,
                                        start_symbol=voc2.w2id['<s>'],
                                        pad=pad_idx)
                    words = []
                    decoded_words = [[] for i in range(out.size(0))]
                    ends = []  # batch rows that already emitted </s>
                    #pdb.set_trace()
                    #print("Translation:", end="\t")
                    for z in range(1, out.size(1)):
                        for b in range(len(decoded_words)):
                            sym = voc2.id2w[out[b, z].item()]
                            if b not in ends:
                                if sym == "</s>":
                                    ends.append(b)
                                    continue
                                #print(sym, end =" ")
                                decoded_words[b].append(sym)
                    with open(args.outputs_path + '/outputs.txt', 'a') as f_out:
                        f_out.write('Batch: ' + str(i) + '\n')
                        f_out.write(
                            '---------------------------------------\n')
                        for z in range(len(decoded_words)):
                            # NOTE(review): bare except + pdb.set_trace()
                            # will hang non-interactive runs — narrow it.
                            try:
                                f_out.write('Example: ' + str(z) + '\n')
                                f_out.write('Source: ' + batch['src'][z] + '\n')
                                f_out.write('Target: ' + batch['trg'][z] + '\n')
                                f_out.write('Generated: ' +
                                            stack_to_string(decoded_words[z]) +
                                            '\n' + '\n')
                            except:
                                logger.warning('Exception: Failed to generate')
                                pdb.set_trace()
                                break
                        f_out.write(
                            '---------------------------------------\n')
                        f_out.close()  # redundant inside `with`
                    hyps += [
                        ' '.join(decoded_words[z])
                        for z in range(len(decoded_words))
                    ]
                    #hyps.append(stack_to_string(words))
                    error_score += cal_score(decoded_words, batch['trg'])
                    #print()
                    #print("Target:", end="\t")
                    for z in range(1, sent2_var.size(0)):
                        sym = voc2.id2w[sent2_var[z, 0].item()]
                        if sym == "</s>":
                            break
                        #print(sym, end =" ")
                    #print()
                    #break
                val_bleu_epoch = bleu_scorer(refs, hyps)
                print('Epoch: {} Val bleu: {}'.format(epoch,
                                                      val_bleu_epoch[0]))
                print('Epoch: {} Val Error: {}'.format(
                    epoch, error_score / len(val_dataloader)))
                # if logging_var:
                #     loss_fh.write(loss_str)
                # Periodic (unconditional, not best-only) checkpointing.
                if epoch % 10 == 0:
                    ckpt_path = os.path.join(args.model_path, 'model.pt')
                    logger.info('Saving Checkpoint at : {}'.format(ckpt_path))
                    torch.save(model.state_dict(), ckpt_path)
                    print('Model saved at: {}'.format(ckpt_path))
        else:
            # Decode-only mode.
            # NOTE(review): loads from args.model_path (a directory in train
            # mode) rather than the model.pt inside it — confirm.
            model.load_state_dict(torch.load(args.model_path))
            model.eval()
            # pdb.set_trace()
            # for i, batch in enumerate(val_dataloader):
            #     sent1s = sents_to_idx(voc1, batch['src'], args.max_length)
            #     sent2s = sents_to_idx(voc2, batch['trg'], args.max_length)
            #     sent1_var, sent2_var, input_len1, input_len2 = process_batch(sent1s, sent2s, voc1, voc2, device)
            #     src = sent1_var.transpose(0, 1)[:1]
            #     src_mask = (src != voc1.w2id['PAD']).unsqueeze(-2)
            #     out = greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=voc2.w2id['<s>'])
            #     print("Translation:", end="\t")
            #     for i in range(1, out.size(1)):
            #         sym = voc2.id2w[out[0, i].item()]
            #         if sym == "</s>": break
            #         print(sym, end =" ")
            #     print()
            #     print("Target:", end="\t")
            #     for i in range(1, sent2_var.size(0)):
            #         sym = voc2.id2w[sent2_var[i, 0].item()]
            #         if sym == "</s>": break
            #         print(sym, end =" ")
            #     print()
            #     break
    else:
        # ---------- Synthetic-data branch: single shared vocabulary ----------
        vocab_path = os.path.join(args.model_path, 'vocab.p')

        if is_train:
            #pdb.set_trace()
            train_dataloader, val_dataloader = load_data(args, logger)
            logger.debug('Creating Vocab...')
            voc = Syn_Voc()
            voc.create_vocab_dict(args, train_dataloader)
            # To Do : Remove Later
            voc.add_to_vocab_dict(args, val_dataloader)
            logger.info('Vocab Created with number of words : {}'.format(
                voc.nwords))
            with open(vocab_path, 'wb') as f:
                pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info('Vocab saved at {}'.format(vocab_path))
        else:
            test_dataloader = load_data(args, logger)
            logger.info('Loading Vocab File...')
            with open(vocab_path, 'rb') as f:
                voc = pickle.load(f)
            logger.info(
                'Vocab Files loaded from {}\nNumber of Words: {}'.format(
                    vocab_path, voc.nwords))

        # print('Done')
        # TO DO : Load Existing Checkpoints here
        # checkpoint = get_latest_checkpoint(args.model_path, logger)

        # Param specs
        layers = args.layers
        heads = args.heads
        d_model = args.d_model
        d_ff = args.d_ff
        max_len = args.max_length
        dropout = args.dropout
        BATCH_SIZE = args.batch_size
        epochs = args.epochs

        if logging_var:
            meta_fname = os.path.join(args.log_path, 'meta.txt')
            loss_fname = os.path.join(args.log_path, 'loss.txt')
            meta_fh = open(meta_fname, 'w')
            loss_fh = open(loss_fname, 'w')
            print('Log Files created at: {}'.format(args.log_path))
            write_meta(args, meta_fh)

        """stime= time.time()
        print('Loading Data...')
        train, val, test, SRC, TGT = build_data()
        etime= (time.time()-stime)/60
        print('Data Loaded\nTime Taken:{}'.format(etime ))"""

        # Shared vocab: source and target embeddings use the same size.
        pad_idx = voc.w2id['PAD']
        model = make_model(voc.nwords, voc.nwords, N=layers, h=heads,
                           d_model=d_model, d_ff=d_ff, dropout=dropout)
        model.to(device)
        logger.info('Initialized Model')
        criterion = LabelSmoothing(size=voc.nwords, padding_idx=pad_idx,
                                   smoothing=0.1)
        criterion.to(device)

        # train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
        #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        #                         batch_size_fn=batch_size_fn, train=True)
        # valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
        #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        #                         batch_size_fn=batch_size_fn, train=False)

        if mode == 'train':
            model_opt = NoamOpt(
                model.src_embed[0].d_model, 1, 3000,
                torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                                 eps=1e-9))
            max_bleu_score = 0.0
            min_error_score = 100.0
            epoch_offset = 0
            logger.info('Starting Training Procedure')
            for epoch in range(epochs):
                # pdb.set_trace()
                #if epoch%3==0:
                print('Training Epoch: ', epoch)
                model.train()
                start_time = time.time()
                run_epoch((rebatch(args, device, voc, voc, pad_idx, b)
                           for b in train_dataloader), model,
                          LossCompute(model.generator, criterion,
                                      device=device, opt=model_opt))
                time_taken = (time.time() - start_time) / 60.0
                logger.debug(
                    'Training for epoch {} completed...\nTime Taken: {}'.
                    format(epoch, time_taken))
                logger.debug('Starting Validation')
                model.eval()
                # loss = run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b) for b in val_dataloader),
                #                  model,
                #                  LossCompute(model.generator, criterion, device=device, opt=None))
                # loss_str= "Epoch: {} \t Val Loss: {}\n".format(epoch,loss)
                # print(loss_str)
                refs = []
                hyps = []
                error_score = 0
                for i, batch in enumerate(val_dataloader):
                    sent1s = sents_to_idx(voc, batch['src'], args.max_length)
                    sent2s = sents_to_idx(voc, batch['trg'], args.max_length)
                    sent1_var, sent2_var, input_len1, input_len2 = process_batch(
                        sent1s, sent2s, voc, voc, device, voc.id2w[pad_idx])
                    sent1s = idx_to_sents(voc, sent1_var, no_eos=True)
                    sent2s = idx_to_sents(voc, sent2_var, no_eos=True)
                    # pdb.set_trace()
                    # for l in range(len(batch['src'])):
                    #     if len(batch['src'][l].split())!=9:
                    #         print(l)
                    #for eg in range(sent1_var.size(0)):
                    src = sent1_var.transpose(0, 1)
                    ### FOR NON-DIRECTIONAL ###
                    # src_mask = (src != voc.w2id['PAD']).unsqueeze(-2)
                    ### FOR DIRECTIONAL ###
                    src_mask = make_std_mask(src, pad_idx)
                    src_mask_bi = make_bi_std_mask(src, pad_idx)
                    src_mask_dec = (src != voc.w2id['PAD']).unsqueeze(-2)
                    #refs.append([' '.join(sent2s[eg])])
                    # refs += [[' '.join(sent2s[i])] for i in range(sent2_var.size(1))]
                    # References come straight from the raw target strings.
                    refs += [[x] for x in batch['trg']]
                    out = greedy_decode(model, src, src_mask, max_len=max_len,
                                        start_symbol=voc.w2id['<s>'],
                                        pad=pad_idx,
                                        src_mask_dec=src_mask_dec,
                                        src_mask_bi=src_mask_bi)
                    words = []
                    decoded_words = [[] for i in range(out.size(0))]
                    ends = []  # batch rows that already emitted </s>
                    # pdb.set_trace()
                    #print("Translation:", end="\t")
                    for z in range(1, out.size(1)):
                        for b in range(len(decoded_words)):
                            sym = voc.id2w[out[b, z].item()]
                            if b not in ends:
                                if sym == "</s>":
                                    ends.append(b)
                                    continue
                                #print(sym, end =" ")
                                decoded_words[b].append(sym)
                    with open(args.outputs_path + '/outputs.txt', 'a') as f_out:
                        f_out.write('Batch: ' + str(i) + '\n')
                        f_out.write(
                            '---------------------------------------\n')
                        for z in range(len(decoded_words)):
                            # NOTE(review): bare except + pdb.set_trace()
                            # will hang non-interactive runs — narrow it.
                            try:
                                f_out.write('Example: ' + str(z) + '\n')
                                f_out.write('Source: ' + batch['src'][z] + '\n')
                                f_out.write('Target: ' + batch['trg'][z] + '\n')
                                f_out.write('Generated: ' +
                                            stack_to_string(decoded_words[z]) +
                                            '\n' + '\n')
                            except:
                                logger.warning('Exception: Failed to generate')
                                pdb.set_trace()
                                break
                        f_out.write(
                            '---------------------------------------\n')
                        f_out.close()  # redundant inside `with`
                    hyps += [
                        ' '.join(decoded_words[z])
                        for z in range(len(decoded_words))
                    ]
                    #hyps.append(stack_to_string(words))
                    if args.ap:
                        error_score += cal_score_AP(decoded_words,
                                                    batch['trg'])
                    else:
                        error_score += cal_score(decoded_words, batch['trg'])
                    #print()
                    #print("Target:", end="\t")
                    for z in range(1, sent2_var.size(0)):
                        sym = voc.id2w[sent2_var[z, 0].item()]
                        if sym == "</s>":
                            break
                        #print(sym, end =" ")
                    #print()
                    #break
                # Track best (lowest) error and best (highest) BLEU.
                if (error_score / len(val_dataloader)) < min_error_score:
                    min_error_score = error_score / len(val_dataloader)
                val_bleu_epoch = bleu_scorer(refs, hyps)
                if max_bleu_score < val_bleu_epoch[0]:
                    max_bleu_score = val_bleu_epoch[0]
                logger.info('Epoch: {} Val bleu: {}'.format(
                    epoch, val_bleu_epoch[0]))
                logger.info('Maximum Bleu: {}'.format(max_bleu_score))
                logger.info('Epoch: {} Val Error: {}'.format(
                    epoch, error_score / len(val_dataloader)))
                logger.info('Minimum Error: {}'.format(min_error_score))
                # if logging_var:
                #     loss_fh.write(loss_str)
                if epoch % 5 == 0:
                    ckpt_path = os.path.join(args.model_path, 'model.pt')
                    logger.info('Saving Checkpoint at : {}'.format(ckpt_path))
                    torch.save(model.state_dict(), ckpt_path)
                    print('Model saved at: {}'.format(ckpt_path))
            store_results(args, max_bleu_score, min_error_score)
            logger.info('Scores saved at {}'.format(args.result_path))
        else:
            # NOTE(review): same model-path concern as the MT branch.
            model.load_state_dict(torch.load(args.model_path))
            model.eval()
def main():
    """Entry point for the SYN-Par model: trains (fresh or resumed from the
    latest checkpoint) in 'train' mode; otherwise restores the pickled
    training config, re-applies decode-time CLI overrides, loads the
    checkpoint and greedy-decodes the test set. Logs to Comet ML unless
    --debug is set.
    """
    parser = build_parser()
    args = parser.parse_args()
    # tree_height arrives as a comma-separated string, e.g. "2,3,4".
    args.tree_height = [int(s) for s in args.tree_height.split(",")]

    # Capture CLI values that must survive the config unpickling below.
    use_ptr = args.use_ptr
    cov_after_ep = args.cov_after_ep
    height_dec = args.height_dec

    if args.mode == 'train':
        # Blank run name -> timestamp-based name for a fresh training run.
        if len(args.run_name.split()) == 0:
            args.run_name = datetime.fromtimestamp(time.time()).strftime(
                args.date_fmt)
    if args.pretrained_encoder == None or args.pretrained_encoder == 'bert_all':
        args.use_attn = True

    # SET SEEDS
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # ASSIGN GPU
    device = gpu_init_pytorch(args.gpu)

    # CREATE LOGGING FOLDER
    log_folder_name = os.path.join('Logs', args.run_name)
    create_save_directories('Logs', 'Models', args.run_name)

    # Comet ML - Log all params (disabled in debug mode)
    experiment = None
    if not args.debug:
        experiment = Experiment(api_key=_API_KEY,
                                project_name=args.project_name,
                                workspace="NAN")
        experiment.set_name(args.run_name)
        experiment.log_parameters(vars(args))

    logger = get_logger(__name__, 'temp_run', args.log_fmt, logging.INFO,
                        os.path.join(log_folder_name, 'SYN-Par.log'))
    logger.info('Run name: {}'.format(args.run_name))

    if args.mode == 'train':
        # Build (or re-load) the joint sentence/tree vocabulary.
        train_dataloader, val_dataloader, test_dataloader = read_files(
            args, logger)
        logger.info('Creating vocab ...')
        voc1 = Voc(args.dataset + 'sents')
        voc2 = Voc(args.dataset + 'trees')
        voc_file = os.path.join('Models', args.run_name, 'vocab.p')
        if (os.path.exists(voc_file)):
            logger.info('Loading vocabulary from {}'.format(
                os.path.join('Models', args.run_name, 'vocab.p')))
            voc = pickle.load(open(voc_file, 'rb'))
        else:
            voc = create_vocab_dict(args, voc1, voc2, train_dataloader)
            logger.info('Vocab created with number of words = {}'.format(
                voc.nwords))
            logger.info('Saving Vocabulary file')
            with open(voc_file, 'wb') as f:
                pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info('Vocabulary file saved in {}'.format(
                os.path.join('Models', args.run_name, 'vocab.p')))
    else:
        # Decode mode: args is REPLACED by the pickled training config; the
        # locals captured here are re-applied as decode-time overrides.
        config_file_name = os.path.join('Models', args.run_name, 'config.p')
        mode = args.mode
        batch_size = args.batch_size
        beam_width = args.beam_width
        gpu = args.gpu
        tree_height2 = 40
        use_glove = args.use_glove
        max_length = args.max_length
        datatype = args.datatype
        res_file = args.res_file
        dataset = args.dataset
        load_from_ep = args.load_from_ep
        run_name = args.run_name
        max_epochs = args.max_epochs
        with open(config_file_name, 'rb') as f:
            args = AttrDict(pickle.load(f))
        args.mode = mode
        args.gpu = gpu
        args.load_from_ep = load_from_ep
        args.dataset = dataset
        args.beam_width = beam_width
        args.gpu = gpu
        args.height_dec = height_dec
        args.tree_height2 = tree_height2
        args.use_glove = use_glove
        args.max_length = max_length
        args.datatype = datatype
        args.res_file = res_file
        args.run_name = run_name
        args.max_epochs = max_epochs
        test_dataloader = read_files(args, logger)
        logger.info('Loading Vocabulary file')
        with open(os.path.join('Models', args.run_name, 'vocab.p'),
                  'rb') as f:
            voc = pickle.load(f)
        logger.info('Vocabulary file Loaded from {}'.format(
            os.path.join('Models', args.run_name, 'vocab.p')))

    # Latest checkpoint (None when no checkpoint is present).
    checkpoint = get_latest_checkpoint('Models', args.run_name, logger,
                                       args.load_from_ep)

    if args.mode == 'train':
        if checkpoint == None:
            # Fresh run: initialize bookkeeping and snapshot the config.
            logger.info('Starting a fresh training procedure')
            ep_offset = 0
            min_val_loss = 1e8
            max_val_bleu = 0.0
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300
            model = SYN_Par(args, voc, device, logger, experiment)
            with open(config_file_name, 'wb') as f:
                pickle.dump(vars(args), f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # Resume: restore training config and re-apply CLI overrides.
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            debug = args.debug
            max_epochs = args.max_epochs
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))
            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300
            args.use_ptr = use_ptr
            args.debug = debug
            args.cov_after_ep = cov_after_ep
            args.max_epochs = max_epochs
            model = SYN_Par(args, voc, device, logger, experiment)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device,
                args.pretrained_encoder)
            logger.info('Resuming Training From ')
            od = OrderedDict()
            # load_checkpoint may return all-None stats (e.g. encoder-only
            # restore); fall back to zeros for display.
            if ep_offset is None and train_loss is None and min_val_loss is None and max_val_bleu is None:
                od['Epoch'] = 0
                od['Train_loss'] = 0.0
                od['Validation_loss'] = 0.0
                od['Validation_Bleu'] = 0.0
            else:
                od['Epoch'] = ep_offset
                od['Train_loss'] = train_loss
                od['Validation_loss'] = min_val_loss
                od['Validation_Bleu'] = float(max_val_bleu)
            print_log(logger, od)
            # NOTE(review): ep_offset += 1 raises TypeError when
            # load_checkpoint returned None above — confirm intended.
            ep_offset += 1
            max_val_bleu = float(max_val_bleu)
        train(model,
              train_dataloader,
              val_dataloader,
              test_dataloader,
              voc,
              device,
              args,
              logger,
              ep_offset,
              min_val_loss,
              max_val_bleu,
              experiment=experiment)
    else:
        if checkpoint == None:
            logger.info('Cannot decode because of absence of checkpoints')
            sys.exit()
        else:
            # Restore training config again and re-apply decode overrides
            # (duplicates the block above; candidate for a shared helper).
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            beam_width = args.beam_width
            gpu = args.gpu
            tree_height2 = 40
            use_glove = args.use_glove
            max_length = args.max_length
            datatype = args.datatype
            res_file = args.res_file
            dataset = args.dataset
            load_from_ep = args.load_from_ep
            run_name = args.run_name
            max_epochs = args.max_epochs
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))
            args.load_from_ep = load_from_ep
            args.dataset = dataset
            args.beam_width = beam_width
            args.gpu = gpu
            args.height_dec = height_dec
            args.tree_height2 = tree_height2
            args.use_glove = use_glove
            args.max_length = max_length
            args.datatype = datatype
            args.res_file = res_file
            args.run_name = run_name
            args.max_epochs = max_epochs
            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300
            model = SYN_Par(args, voc, device, logger, experiment)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device,
                args.pretrained_encoder)
            logger.info('Decoding from')
            od = OrderedDict()
            if ep_offset is None and train_loss is None and min_val_loss is None and max_val_bleu is None:
                od['Epoch'] = 0
                od['Train_loss'] = 0.0
                od['Validation_loss'] = 0.0
                od['Validation_Bleu'] = 0.0
            else:
                od['Epoch'] = ep_offset
                od['Train_loss'] = train_loss
                od['Validation_loss'] = min_val_loss
                od['Validation_Bleu'] = float(max_val_bleu)
            print_log(logger, od)
            '''
            refs, hyps, test_loss_epoch = validation(args, model, test_dataloader, voc, device, logger)
            refs = open('data/controlledgen/test_ref.txt').read().split('\n')[:-1]
            with open('temp_refs.txt', 'w') as f:
                f.write('\n'.join(refs))
            with open('temp_hyps.txt', 'w') as f:
                f.write('\n'.join(hyps))
            bleu_score_test = run_multi_bleu('temp_hyps.txt', 'temp_refs.txt')
            #bleu_score_test = bleu_scorer(refs, hyps)
            #os.remove('temp_hyps.txt')
            #os.remove('temp_refs.txt')
            logger.info('Test BLEU score: {}'.format(bleu_score_test))
            '''
            decode_greedy(model, test_dataloader, voc, device, args, logger)