def train(args):
  class uniform_initializer(object):
    def __init__(self, stdv):
      self.stdv = stdv

    def __call__(self, tensor):
      nn.init.uniform_(tensor, -self.stdv, self.stdv)

  # clip_grad, decay_step, lr_decay and max_decay are module-level
  # hyperparameters defined elsewhere in this script.
  opt_dict = {"not_improved": 0, "lr": 1., "best_ppl": 1e4}

  hparams = HParams(**vars(args))
  data = DataUtil(hparams=hparams)

  model_init = uniform_initializer(0.01)
  emb_init = uniform_initializer(0.1)
  model = LSTM_LM(model_init, emb_init, hparams)

  if hparams.eval_from != "":
    model = torch.load(hparams.eval_from)
    model.to(hparams.device)
    with torch.no_grad():
      if args.generate:
        model.generate(args.gen_len, args.num_gen, hparams, data)
      else:
        test(model, data, hparams)
    return

  model.to(hparams.device)
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  num_params = count_params(trainable_params)
  print("Model has {0} params".format(num_params))

  optim = torch.optim.SGD(model.parameters(), lr=1.0)

  step = epoch = decay_cnt = 0
  report_words = report_loss = report_ppl = report_sents = 0
  start_time = time.time()
  model.train()

  while True:
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
      y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
      y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len, \
      y_pos_emb_idxs, batch_size, eop = data.next_train()

    report_words += (x_count - batch_size)
    report_sents += batch_size

    optim.zero_grad()
    loss = model.reconstruct_error(x_train, x_len)
    loss = loss.mean(dim=-1)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    optim.step()

    report_loss += loss.item() * batch_size

    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={}".format(step)
      log_string += " lr={0:<9.7f}".format(opt_dict["lr"])
      log_string += " loss={0:<7.2f}".format(report_loss / report_sents)
      log_string += " |g|={0:<5.2f}".format(grad_norm)
      log_string += " ppl={0:<8.2f}".format(np.exp(report_loss / report_words))
      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)

    if step % args.eval_every == 0:
      with torch.no_grad():
        val_loss, val_ppl = test(model, data, hparams)
      if val_ppl < opt_dict["best_ppl"]:
        print("update best ppl")
        opt_dict["best_ppl"] = val_ppl
        opt_dict["not_improved"] = 0
        torch.save(model, os.path.join(hparams.output_dir, "model.pt"))
      if val_ppl > opt_dict["best_ppl"]:
        opt_dict["not_improved"] += 1
        if opt_dict["not_improved"] >= decay_step:
          opt_dict["not_improved"] = 0
          opt_dict["lr"] = opt_dict["lr"] * lr_decay
          model = torch.load(os.path.join(hparams.output_dir, "model.pt"))
          print("new lr: {0:<9.7f}".format(opt_dict["lr"]))
          decay_cnt += 1
          optim = torch.optim.SGD(model.parameters(), lr=opt_dict["lr"])
      report_words = report_loss = report_ppl = report_sents = 0
      model.train()

    step += 1
    if eop:
      epoch += 1
    if decay_cnt >= max_decay:
      break
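# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): how the perplexity
# printed above is computed. Each sentence contributes x_count - 1 predicted
# words because the start symbol is never predicted, which is why the loop
# accumulates `report_words += (x_count - batch_size)`. The helper name and
# its arguments are assumptions made for this example only.
import numpy as np

def corpus_ppl(total_nll, total_tokens, n_sents):
  """Perplexity over a corpus where `total_tokens` still counts the start symbol."""
  predicted_words = total_tokens - n_sents
  return float(np.exp(total_nll / predicted_words))

# Example: 2 sentences, 12 tokens including the start symbols, summed NLL 30.0
# -> 10 predicted words, corpus_ppl = exp(3.0) ~= 20.1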
def train():
  print(args)
  if args.load_model and (not args.reset_hparams):
    print("load hparams..")
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.load_model = args.load_model
    hparams.n_train_steps = args.n_train_steps
  else:
    hparams = HParams(**vars(args))
  hparams.noise_flag = True

  # build or load model
  print("-" * 80)
  print("Creating model")
  if args.load_model:
    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    if not hasattr(model, 'data'):
      model.data = data
    if not hasattr(model.hparams, 'transformer_wdrop'):
      model.hparams.transformer_wdrop = False

    optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
    print("Loading optimizer from {}".format(optim_file_name))
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    #optim = torch.optim.Adam(trainable_params, lr=hparams.lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=hparams.l2_reg)
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    optimizer_state = torch.load(optim_file_name)
    optim.load_state_dict(optimizer_state)

    extra_file_name = os.path.join(args.output_dir, "extra.pt")
    step, best_val_ppl, best_val_bleu, cur_attempt, cur_decay_attempt, lr = torch.load(extra_file_name)
  else:
    if args.pretrained_model:
      model_name = os.path.join(args.pretrained_model, "model.pt")
      print("Loading model from '{0}'".format(model_name))
      model = torch.load(model_name)
      print("load hparams..")
      hparams_file_name = os.path.join(args.pretrained_model, "hparams.pt")
      reload_hparams = torch.load(hparams_file_name)
      reload_hparams.train_src_file_list = hparams.train_src_file_list
      reload_hparams.train_trg_file_list = hparams.train_trg_file_list
      reload_hparams.dropout = hparams.dropout
      reload_hparams.lr_dec = hparams.lr_dec
      hparams = reload_hparams
      data = DataUtil(hparams=hparams)
      model.data = data
    else:
      data = DataUtil(hparams=hparams)
      if args.model_type == 'seq2seq':
        model = Seq2Seq(hparams=hparams, data=data)
      elif args.model_type == 'transformer':
        model = Transformer(hparams=hparams, data=data)
      else:
        print("Model {} not implemented".format(args.model_type))
        exit(0)
      if args.init_type == "uniform" and hparams.model_type != "transformer":
        print("initialize uniform with range {}".format(args.init_range))
        for p in model.parameters():
          p.data.uniform_(-args.init_range, args.init_range)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    #optim = torch.optim.Adam(trainable_params)
    step = 0
    #best_val_ppl = None
    best_val_ppl = 1000
    best_val_bleu = None
    cur_attempt = 0
    cur_decay_attempt = 0
    lr = hparams.lr

  model.set_lm()
  model.to(hparams.device)

  if args.eval_cls:
    classifier_file_name = os.path.join(args.classifier_dir, "model.pt")
    print("Loading model from '{0}'".format(classifier_file_name))
    classifier = torch.load(classifier_file_name).to(hparams.device)
  else:
    classifier = None

  if args.reset_hparams:
    lr = args.lr
  crit = get_criterion(hparams)
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  num_params = count_params(trainable_params)
  print("Model has {0} params".format(num_params))

  if args.load_for_test:
    val_ppl, val_bleu = eval(model, classifier, data, crit, step, hparams,
                             eval_bleu=args.eval_bleu,
                             valid_batch_size=args.valid_batch_size)
    return

  print("-" * 80)
  print("start training...")
  start_time = log_start_time = time.time()
  target_words = total_loss = total_sents = total_noise_corrects = total_transfer_corrects = 0
  total_bt_loss = total_noise_loss = total_KL_loss = 0.
  target_rules, target_total, target_eos = 0, 0, 0
  total_word_loss, total_rule_loss, total_eos_loss = 0, 0, 0
  total_lm_length = total_trans_length = 0
  model.train()
  #i = 0
  dev_zero = args.dev_zero
  tr_loss, update_batch_size = None, 0

  if hparams.anneal_epoch == 0:
    hparams.noise_weight = 0.
    anneal_rate = 0.
  else:
    hparams.noise_weight = 1.
    if hparams.anneal_epoch == -1:
      anneal_rate = 0.
    else:
      anneal_rate = 1.0 / (data.train_size * args.anneal_epoch // hparams.batch_size)
  hparams.gs_temp = 1.

  while True:
    step += 1
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
      y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
      y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len, \
      y_pos_emb_idxs, batch_size, eop = data.next_train()
    #print(x_train)
    #print(x_count)
    target_words += (x_count - batch_size)
    total_sents += batch_size

    trans_logits, noise_logits, KL_loss, lm_len, trans_len = model.forward(
        x_train, x_mask, x_len, x_pos_emb_idxs,
        y_train, y_mask, y_len, y_pos_emb_idxs,
        y_sampled, y_sampled_mask, y_sampled_len)
    total_lm_length += lm_len
    total_trans_length += trans_len

    # not predicting the start symbol
    labels = x_train[:, 1:].contiguous().view(-1)
    cur_tr_loss, trans_loss, noise_loss, cur_tr_acc, cur_tr_transfer_acc = get_performance(
        crit, trans_logits, noise_logits, labels, hparams, x_len)
    assert cur_tr_loss.item() > 0

    if hparams.lm:
      cur_tr_loss = cur_tr_loss + hparams.klw * KL_loss.sum()
      total_KL_loss += KL_loss.sum().item()

    hparams.noise_weight = max(0., hparams.noise_weight - anneal_rate)
    if hparams.noise_weight == 0:
      hparams.noise_flag = False
    # if eop:
    #   hparams.gs_temp = max(0.001, hparams.gs_temp * 0.5)

    total_loss += cur_tr_loss.item()
    total_bt_loss += trans_loss.item()
    total_noise_loss += noise_loss.item()
    total_noise_corrects += cur_tr_acc
    total_transfer_corrects += cur_tr_transfer_acc

    if tr_loss is None:
      tr_loss = cur_tr_loss
    else:
      tr_loss = tr_loss + cur_tr_loss
    update_batch_size += batch_size

    if step % args.update_batch == 0:
      # set learning rate
      if args.lr_schedule:
        s = step / args.update_batch + 1
        lr = pow(hparams.d_model, -0.5) * min(pow(s, -0.5), s * pow(hparams.n_warm_ups, -1.5))
        set_lr(optim, lr)
      elif step / args.update_batch < hparams.n_warm_ups:
        base_lr = hparams.lr
        base_lr = base_lr * (step / args.update_batch + 1) / hparams.n_warm_ups
        set_lr(optim, base_lr)
        lr = base_lr
      elif args.lr_dec_steps > 0:
        s = (step / args.update_batch) % args.lr_dec_steps
        lr = args.lr_min + 0.5 * (args.lr_max - args.lr_min) * (1 + np.cos(s * np.pi / args.lr_dec_steps))
        set_lr(optim, lr)

      tr_loss = tr_loss / update_batch_size
      tr_loss.backward()
      #grad_norm = grad_clip(trainable_params, grad_bound=args.clip_grad)
      grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
      optim.step()
      optim.zero_grad()
      tr_loss = None
      update_batch_size = 0

    # clean up GPU memory
    if step % args.clean_mem_every == 0:
      gc.collect()

    epoch = step // data.n_train_batches
    if (step / args.update_batch) % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      elapsed = (curr_time - log_start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={0:<6.2f}".format((step / args.update_batch) / 1000)
      log_string += " lr={0:<9.7f}".format(lr)
      log_string += " total={0:<7.2f}".format(total_loss / total_sents)
      log_string += " neg ELBO={0:<7.2f}".format((total_KL_loss + total_bt_loss) / total_sents)
      log_string += " KL={0:<7.2f}".format(total_KL_loss / total_sents)
      log_string += " |g|={0:<5.2f}".format(grad_norm)
      log_string += " bt_ppl={0:<8.2f}".format(np.exp(total_bt_loss / target_words))
      log_string += " n_ppl={0:<8.2f}".format(np.exp(total_noise_loss / target_words))
      log_string += " n_acc={0:<5.4f}".format(total_noise_corrects / target_words)
      log_string += " bt_acc={0:<5.4f}".format(total_transfer_corrects / target_words)
      # noise weight
      log_string += " nw={:.4f}".format(hparams.noise_weight)
      log_string += " lmlen={}".format(total_lm_length // total_sents)
      log_string += " translen={}".format(total_trans_length // total_sents)
      # log_string += " wpm(k)={0:<5.2f}".format(target_words / (1000 * elapsed))
      log_string += " t={0:<5.2f}".format(since_start)
      print(log_string)

    if args.eval_end_epoch:
      if eop:
        eval_now = True
      else:
        eval_now = False
    elif (step / args.update_batch) % args.eval_every == 0:
      eval_now = True
    else:
      eval_now = False

    if eval_now:
      # based_on_bleu = args.eval_bleu and best_val_ppl is not None and best_val_ppl <= args.ppl_thresh
      based_on_bleu = False
      if args.dev_zero:
        based_on_bleu = True
      print("target words: {}".format(target_words))
      with torch.no_grad():
        val_ppl, val_bleu = eval(model, classifier, data, crit, step, hparams,
                                 eval_bleu=args.eval_bleu,
                                 valid_batch_size=args.valid_batch_size)
      if based_on_bleu:
        if best_val_bleu is None or best_val_bleu <= val_bleu:
          save = True
          best_val_bleu = val_bleu
          cur_attempt = 0
          cur_decay_attempt = 0
        else:
          save = False
      else:
        if best_val_ppl is None or best_val_ppl >= val_ppl:
          save = True
          best_val_ppl = val_ppl
          cur_attempt = 0
          cur_decay_attempt = 0
        else:
          save = False

      if save or args.always_save:
        save_checkpoint([step, best_val_ppl, best_val_bleu, cur_attempt, cur_decay_attempt, lr],
                        model, optim, hparams, args.output_dir)
      elif not args.lr_schedule and step >= hparams.n_warm_ups:
        if cur_decay_attempt >= args.attempt_before_decay:
          if val_ppl >= 2 * best_val_ppl:
            print("reload saved best model !!!")
            model.load_state_dict(torch.load(os.path.join(args.output_dir, "model.dict")))
            hparams = torch.load(os.path.join(args.output_dir, "hparams.pt"))
          lr = lr * args.lr_dec
          set_lr(optim, lr)
          cur_attempt += 1
          cur_decay_attempt = 0
        else:
          cur_decay_attempt += 1

      # reset counter after eval
      log_start_time = time.time()
      target_words = total_sents = total_noise_corrects = total_transfer_corrects = total_loss = 0
      total_bt_loss = total_noise_loss = total_KL_loss = 0.
      target_rules = target_total = target_eos = 0
      total_word_loss = total_rule_loss = total_eos_loss = 0
      total_lm_length = total_trans_length = 0

    if args.patience >= 0:
      if cur_attempt > args.patience:
        break
    elif args.n_train_epochs > 0:
      if epoch >= args.n_train_epochs:
        break
    else:
      if step > args.n_train_steps:
        break
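# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption for clarity, not repository code): the
# learning-rate policy applied in the update block above, as a standalone
# helper. With --lr_schedule it is the inverse-square-root ("Noam") schedule
# lr = d_model**-0.5 * min(s**-0.5, s * n_warm_ups**-1.5); otherwise the rate
# warms up linearly to the base lr and may then follow a cosine cycle of
# length lr_dec_steps. Argument names mirror the hparams/args used above.
import numpy as np

def scheduled_lr(update, d_model, n_warm_ups, base_lr,
                 use_noam=True, lr_dec_steps=0, lr_min=0.0, lr_max=0.0):
  if use_noam:
    s = update + 1
    return d_model ** -0.5 * min(s ** -0.5, s * n_warm_ups ** -1.5)
  if update < n_warm_ups:
    # linear warm-up toward the base learning rate
    return base_lr * (update + 1) / n_warm_ups
  if lr_dec_steps > 0:
    # cosine cycle between lr_min and lr_max
    s = update % lr_dec_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(s * np.pi / lr_dec_steps))
  return base_lr

# Example: the Noam schedule for d_model=512 and 4000 warm-up updates starts
# near 1.7e-7 at update 0 and peaks around 7e-4 at update 4000 before decaying.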
#hparams.data_path = args.data_path
#hparams.src_vocab_list = args.src_vocab_list
#hparams.trg_vocab_list = args.trg_vocab_list
hparams.test_src_file = args.test_src_file
hparams.test_trg_file = args.test_trg_file
hparams.cuda = args.cuda
hparams.beam_size = args.beam_size
hparams.max_len = args.max_len
hparams.batch_size = args.batch_size
hparams.merge_bpe = args.merge_bpe
hparams.out_file = out_file
hparams.nbest = args.nbest
hparams.decode = True

model.hparams.cuda = hparams.cuda
data = DataUtil(hparams=hparams, decode=True)
filts = [model.hparams.pad_id, model.hparams.eos_id, model.hparams.bos_id]
if not hasattr(model, 'data'):
  model.data = data

if args.debug:
  hparams.add_param("target_word_vocab_size", data.target_word_vocab_size)
  hparams.add_param("target_rule_vocab_size", data.target_rule_vocab_size)
crit = get_criterion(hparams)

out_file = open(hparams.out_file, 'w', encoding='utf-8')
end_of_epoch = False
num_sentences = 0
def train():
  if args.load_model and (not args.reset_hparams):
    print("load hparams..")
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.load_model = args.load_model
    hparams.n_train_steps = args.n_train_steps
  else:
    hparams = HParams(
      decode=args.decode,
      data_path=args.data_path,
      train_src_file_list=args.train_src_file_list,
      train_trg_file_list=args.train_trg_file_list,
      dev_src_file=args.dev_src_file,
      dev_trg_file=args.dev_trg_file,
      src_vocab_list=args.src_vocab_list,
      trg_vocab_list=args.trg_vocab_list,
      src_vocab_size=args.src_vocab_size,
      trg_vocab_size=args.trg_vocab_size,
      max_len=args.max_len,
      n_train_sents=args.n_train_sents,
      cuda=args.cuda,
      d_word_vec=args.d_word_vec,
      d_model=args.d_model,
      batch_size=args.batch_size,
      batcher=args.batcher,
      n_train_steps=args.n_train_steps,
      dropout=args.dropout,
      lr=args.lr,
      lr_dec=args.lr_dec,
      l2_reg=args.l2_reg,
      init_type=args.init_type,
      init_range=args.init_range,
      share_emb_softmax=args.share_emb_softmax,
      n_heads=args.n_heads,
      d_k=args.d_k,
      d_v=args.d_v,
      merge_bpe=args.merge_bpe,
      load_model=args.load_model,
      char_ngram_n=args.char_ngram_n,
      max_char_vocab_size=args.max_char_vocab_size,
      char_input=args.char_input,
      char_comb=args.char_comb,
      char_temp=args.char_temp,
      src_char_vocab_from=args.src_char_vocab_from,
      src_char_vocab_size=args.src_char_vocab_size,
      trg_char_vocab_from=args.trg_char_vocab_from,
      trg_char_vocab_size=args.trg_char_vocab_size,
      src_char_only=args.src_char_only,
      trg_char_only=args.trg_char_only,
      semb=args.semb,
      dec_semb=args.dec_semb,
      semb_vsize=args.semb_vsize,
      lan_code_rl=args.lan_code_rl,
      sample_rl=args.sample_rl,
      sep_char_proj=args.sep_char_proj,
      query_base=args.query_base,
      residue=args.residue,
      layer_norm=args.layer_norm,
      src_no_char=args.src_no_char,
      trg_no_char=args.trg_no_char,
      char_gate=args.char_gate,
      shuffle_train=args.shuffle_train,
      ordered_char_dict=args.ordered_char_dict,
      out_c_list=args.out_c_list,
      k_list=args.k_list,
      d_char_vec=args.d_char_vec,
      highway=args.highway,
      n=args.n,
      single_n=args.single_n,
      bpe_ngram=args.bpe_ngram,
      uni=args.uni,
      pretrained_src_emb_list=args.pretrained_src_emb_list,
      pretrained_trg_emb=args.pretrained_trg_emb,
    )

  # build or load model
  print("-" * 80)
  print("Creating model")
  if args.load_model:
    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    if not hasattr(model, 'data'):
      model.data = data

    optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
    print("Loading optimizer from {}".format(optim_file_name))
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    optimizer_state = torch.load(optim_file_name)
    optim.load_state_dict(optimizer_state)

    extra_file_name = os.path.join(args.output_dir, "extra.pt")
    step, best_val_ppl, best_val_bleu, cur_attempt, lr = torch.load(extra_file_name)
  else:
    if args.pretrained_model:
      model_name = os.path.join(args.pretrained_model, "model.pt")
      print("Loading model from '{0}'".format(model_name))
      model = torch.load(model_name)
      #if not hasattr(model, 'data'):
      #  model.data = data
      #if not hasattr(model, 'char_ngram_n'):
      #  model.hparams.char_ngram_n = 0
      #if not hasattr(model, 'char_input'):
      #  model.hparams.char_input = None
      print("load hparams..")
      hparams_file_name = os.path.join(args.pretrained_model, "hparams.pt")
      reload_hparams = torch.load(hparams_file_name)
      reload_hparams.train_src_file_list = hparams.train_src_file_list
      reload_hparams.train_trg_file_list = hparams.train_trg_file_list
      reload_hparams.dropout = hparams.dropout
      reload_hparams.lr_dec = hparams.lr_dec
      hparams = reload_hparams
      #hparams.src_vocab_list = reload_hparams.src_vocab_list
      #hparams.src_vocab_size = reload_hparams.src_vocab_size
      #hparams.trg_vocab_list = reload_hparams.trg_vocab_list
      #hparams.trg_vocab_size = reload_hparams.trg_vocab_size
      #hparams.src_char_vocab_from = reload_hparams.src_char_vocab_from
      #hparams.src_char_vocab_size = reload_hparams.src_char_vocab_size
      #hparams.trg_char_vocab_from = reload_hparams.trg_char_vocab_from
      #hparams.trg_char_vocab_size = reload_hparams.trg_char_vocab_size
      #print(reload_hparams.src_char_vocab_from)
      #print(reload_hparams.src_char_vocab_size)
      data = DataUtil(hparams=hparams)
      model.data = data
    else:
      data = DataUtil(hparams=hparams)
      model = Seq2Seq(hparams=hparams, data=data)
      if args.init_type == "uniform":
        print("initialize uniform with range {}".format(args.init_range))
        for p in model.parameters():
          p.data.uniform_(-args.init_range, args.init_range)
      if args.id_init_sep and args.semb and args.sep_char_proj:
        print("initialize char proj as identity matrix")
        for s in model.encoder.char_emb.sep_proj_list:
          d = s.weight.data.size(0)
          s.weight.data.copy_(torch.eye(d) + args.id_scale * torch.diagflat(torch.ones(d).normal_(0, 1)))

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    #optim = torch.optim.Adam(trainable_params)
    step = 0
    best_val_ppl = 100
    best_val_bleu = 0
    cur_attempt = 0
    lr = hparams.lr

  if args.cuda:
    model = model.cuda()
  if args.reset_hparams:
    lr = args.lr
  crit = get_criterion(hparams)
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  num_params = count_params(trainable_params)
  print("Model has {0} params".format(num_params))

  print("-" * 80)
  print("start training...")
  start_time = log_start_time = time.time()
  target_words, total_loss, total_corrects = 0, 0, 0
  target_rules, target_total, target_eos = 0, 0, 0
  total_word_loss, total_rule_loss, total_eos_loss = 0, 0, 0
  model.train()
  #i = 0
  dev_zero = args.dev_zero

  while True:
    x_train, x_mask, x_count, x_len, y_train, y_mask, y_count, y_len, batch_size, \
      x_train_char_sparse, y_train_char_sparse, eop, file_idx = data.next_train()
    optim.zero_grad()
    target_words += (y_count - batch_size)

    logits = model.forward(x_train, x_mask, x_len,
                           y_train[:, :-1], y_mask[:, :-1], y_len,
                           x_train_char_sparse, y_train_char_sparse, file_idx=file_idx)
    logits = logits.view(-1, hparams.trg_vocab_size)
    labels = y_train[:, 1:].contiguous().view(-1)
    #print(labels)
    tr_loss, tr_acc = get_performance(crit, logits, labels, hparams)
    total_loss += tr_loss.item()
    total_corrects += tr_acc.item()
    step += 1

    if dev_zero:
      dev_zero = False
      #based_on_bleu = args.eval_bleu and best_val_ppl <= args.ppl_thresh
      based_on_bleu = args.eval_bleu
      val_ppl, val_bleu = eval(model, data, crit, step, hparams,
                               eval_bleu=based_on_bleu,
                               valid_batch_size=args.valid_batch_size, tr_logits=logits)
      if based_on_bleu:
        if best_val_bleu <= val_bleu:
          save = True
          best_val_bleu = val_bleu
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      else:
        if best_val_ppl >= val_ppl:
          save = True
          best_val_ppl = val_ppl
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      if save or args.always_save:
        save_checkpoint([step, best_val_ppl, best_val_bleu, cur_attempt, lr],
                        model, optim, hparams, args.output_dir)
      else:
        lr = lr * args.lr_dec
        set_lr(optim, lr)
      # reset counter after eval
      log_start_time = time.time()
      target_words = total_corrects = total_loss = 0
      target_rules = target_total = target_eos = 0
      total_word_loss = total_rule_loss = total_eos_loss = 0

    tr_loss.div_(batch_size)
    tr_loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
    optim.step()
    # clean up GPU memory
    if step % args.clean_mem_every == 0:
      gc.collect()

    epoch = step // sum(data.n_train_batches)
    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      elapsed = (curr_time - log_start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={0:<6.2f}".format(step / 1000)
      log_string += " lr={0:<9.7f}".format(lr)
      log_string += " loss={0:<7.2f}".format(tr_loss.item())
      log_string += " |g|={0:<5.2f}".format(grad_norm)
      log_string += " ppl={0:<8.2f}".format(np.exp(total_loss / target_words))
      log_string += " acc={0:<5.4f}".format(total_corrects / target_words)
      log_string += " wpm(k)={0:<5.2f}".format(target_words / (1000 * elapsed))
      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)

    if args.eval_end_epoch:
      if eop:
        eval_now = True
      else:
        eval_now = False
    elif step % args.eval_every == 0:
      eval_now = True
    else:
      eval_now = False

    if eval_now:
      based_on_bleu = args.eval_bleu and best_val_ppl <= args.ppl_thresh
      if args.dev_zero:
        based_on_bleu = True
      with torch.no_grad():
        val_ppl, val_bleu = eval(model, data, crit, step, hparams,
                                 eval_bleu=based_on_bleu,
                                 valid_batch_size=args.valid_batch_size, tr_logits=logits)
      if based_on_bleu:
        if best_val_bleu <= val_bleu:
          save = True
          best_val_bleu = val_bleu
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      else:
        if best_val_ppl >= val_ppl:
          save = True
          best_val_ppl = val_ppl
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      if save or args.always_save:
        save_checkpoint([step, best_val_ppl, best_val_bleu, cur_attempt, lr],
                        model, optim, hparams, args.output_dir)
      else:
        lr = lr * args.lr_dec
        set_lr(optim, lr)
      # reset counter after eval
      log_start_time = time.time()
      target_words = total_corrects = total_loss = 0
      target_rules = target_total = target_eos = 0
      total_word_loss = total_rule_loss = total_eos_loss = 0

    if args.patience >= 0:
      if cur_attempt > args.patience:
        break
    elif args.n_train_epochs > 0:
      if epoch >= args.n_train_epochs:
        break
    else:
      if step > args.n_train_steps:
        break
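# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not repository code): the checkpoint /
# learning-rate-decay bookkeeping that the training loops above share, reduced
# to its core. After each validation pass the run either saves a new best
# checkpoint and resets the patience counter, or decays the learning rate and
# counts one failed attempt; training stops once cur_attempt exceeds args.patience.
def update_after_eval(val_ppl, best_val_ppl, lr, cur_attempt, lr_dec=0.5):
  """Return (save, best_val_ppl, lr, cur_attempt) after one validation pass."""
  if val_ppl <= best_val_ppl:
    # improved: keep this checkpoint and reset the patience counter
    return True, val_ppl, lr, 0
  # no improvement: decay the learning rate and count a failed attempt
  return False, best_val_ppl, lr * lr_dec, cur_attempt + 1

# Example: with best_val_ppl=100, a validation perplexity of 120 and lr_dec=0.5
# returns (False, 100, lr * 0.5, cur_attempt + 1).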
src1_w = []
with open(args.lex_file.split(",")[0], "r") as myfile:
  for line in myfile:
    src1_w.append(line.strip())

model.hparams.train_src_file_list = args.lex_file.split(",")
model.hparams.train_trg_file_list = args.lex_file.split(",")
model.hparams.dev_src_file = args.lex_file.split(",")[0]
model.hparams.dev_trg_file = args.lex_file.split(",")[0]
model.hparams.shuffle_train = False
model.hparams.batcher = "sent"
model.hparams.batch_size = len(src1_w)
model.hparams.cuda = args.cuda
data = DataUtil(model.hparams, shuffle=False)

out = args.out_file.split(",")
out1 = open(out[0], "wb")
out2 = open(out[1], "wb")

step = 0
while True:
  gc.collect()
  x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
    y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
    batch_size, x_train_char_sparse, y_train_char_sparse, eop, file_idx = data.next_train()
  if args.options == 'char-no-spe':
    model.hparams.sep_char_proj = False
  print(x_train[44])
  print(x_train_char_sparse[44])
def train():
  if args.load_model and (not args.reset_hparams):
    print("load hparams..")
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.load_model = args.load_model
    hparams.n_train_steps = args.n_train_steps
    # (the model, optimizer, and training counters are restored below, once
    #  the model checkpoint itself has been loaded)
  else:
    hparams = HParams(**vars(args))

  print("building model...")
  if args.load_model:
    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_params = count_params(trainable_params)
    print("Model has {0} params".format(num_params))

    optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
    print("Loading optimizer from {}".format(optim_file_name))
    #optim = torch.optim.Adam(trainable_params, lr=hparams.lr, betas=(0.9, 0.98), eps=1e-9)
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr)
    optimizer_state = torch.load(optim_file_name)
    optim.load_state_dict(optimizer_state)

    extra_file_name = os.path.join(args.output_dir, "extra.pt")
    step, best_loss, best_acc, cur_attempt, lr = torch.load(extra_file_name)
  else:
    data = DataUtil(hparams=hparams)
    if hparams.classifer == "cnn":
      model = CNNClassify(hparams)
    else:
      model = BiLSTMClassify(hparams)
    if args.cuda:
      model = model.cuda()
    #if args.init_type == "uniform":
    #  print("initialize uniform with range {}".format(args.init_range))
    #  for p in model.parameters():
    #    p.data.uniform_(-args.init_range, args.init_range)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_params = count_params(trainable_params)
    print("Model has {0} params".format(num_params))
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr)
    step = 0
    best_loss = None
    best_acc = None
    cur_attempt = 0
    lr = hparams.lr

  # per-token losses; summed and normalized by batch size below
  crit = nn.CrossEntropyLoss(reduction='none')

  print("-" * 80)
  print("start training...")
  start_time = log_start_time = time.time()
  total_loss, total_batch, acc = 0, 0, 0
  model.train()
  epoch = 0

  while True:
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
      y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
      y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len, \
      y_pos_emb_idxs, batch_size, eop = data.next_train()
    step += 1
    #print(x_train)
    #print(x_mask)
    logits = model.forward(x_train, x_mask, x_len, step=step)
    logits = logits.view(-1, hparams.trg_vocab_size)
    labels = y_train.view(-1)

    tr_loss = crit(logits, labels)
    _, preds = torch.max(logits, dim=1)
    val_acc = torch.eq(preds, labels).int().sum()
    acc += val_acc.item()

    tr_loss = tr_loss.sum()
    total_loss += tr_loss.item()
    total_batch += batch_size

    tr_loss.div_(batch_size)
    tr_loss.backward()
    grad_norm = grad_clip(trainable_params, grad_bound=args.clip_grad)
    optim.step()
    optim.zero_grad()

    if eop:
      epoch += 1

    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      elapsed = (curr_time - log_start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={0:<6.2f}".format(step / 1000)
      log_string += " lr={0:<9.7f}".format(lr)
      log_string += " loss={0:<7.2f}".format(total_loss)
      log_string += " acc={0:<5.4f}".format(acc / total_batch)
      log_string += " |g|={0:<5.2f}".format(grad_norm)
      log_string += " wpm(k)={0:<5.2f}".format(total_batch / (1000 * elapsed))
      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)
      acc, total_loss, total_batch = 0, 0, 0
      log_start_time = time.time()

    if step % args.eval_every == 0:
      model.eval()
      cur_acc, cur_loss = eval(model, data, crit, step, hparams)
      if not best_acc or best_acc < cur_acc:
        best_loss, best_acc = cur_loss, cur_acc
        cur_attempt = 0
        save_checkpoint([step, best_loss, best_acc, cur_attempt, lr],
                        model, optim, hparams, args.output_dir)
      else:
        if args.lr_dec:
          lr = lr * args.lr_dec
          set_lr(optim, lr)
        cur_attempt += 1
        if args.patience and cur_attempt > args.patience:
          break
      model.train()
# Training entry point: prepare the output directory, redirect stdout to a log
# file, and launch train(). (In the full script this block sits inside a
# train/test dispatch whose outer condition is not shown in this fragment.)
if not os.path.isdir(args.output_dir):
  os.makedirs(args.output_dir)
elif args.reset_output_dir:
  print("-" * 80)
  print("Path {} exists. Remove and remake.".format(args.output_dir))
  shutil.rmtree(args.output_dir)
  os.makedirs(args.output_dir)

print("-" * 80)
log_file = os.path.join(args.output_dir, "stdout")
print("Logging to {}".format(log_file))
sys.stdout = Logger(log_file)
train()

# Test entry point (the other branch of the dispatch): reload the saved
# hparams and classifier, then evaluate on the test files.
hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
hparams = torch.load(hparams_file_name)
hparams.decode = True
hparams.test_src_file = args.test_src_file
hparams.test_trg_file = args.test_trg_file
data = DataUtil(hparams=hparams)

model_file_name = os.path.join(args.output_dir, "model.pt")
print("Loading model from '{0}'".format(model_file_name))
model = torch.load(model_file_name)
model.eval()
hparams.valid_batch_size = args.valid_batch_size
with torch.no_grad():
  cur_acc, cur_loss = test(model, data, hparams, args.test_src_file, args.test_trg_file, negate=args.negate)
print("test_acc={}, test_loss={}".format(cur_acc, cur_loss))
# Back-compat defaults for hparams fields that older checkpoints may lack.
if hasattr(train_hparams, 'src_char_only'):
  hparams.src_char_only = train_hparams.src_char_only
else:
  hparams.src_char_only = False
if not hasattr(train_hparams, 'n') and (train_hparams.char_ngram_n or train_hparams.char_input):
  hparams.n = 4
if not hasattr(train_hparams, 'single_n') and (train_hparams.char_ngram_n or train_hparams.char_input):
  hparams.single_n = False
if not hasattr(train_hparams, 'bpe_ngram'):
  hparams.bpe_ngram = False
if not hasattr(train_hparams, 'uni'):
  hparams.uni = False

model.hparams.cuda = hparams.cuda
data = DataUtil(hparams=hparams, decode=True)
filts = [model.hparams.pad_id, model.hparams.eos_id, model.hparams.bos_id]
if not hasattr(model, 'data'):
  model.data = data
if not hasattr(hparams, 'model_type'):
  if type(model) == Seq2Seq:
    hparams.model_type = "seq2seq"
  elif type(model) == Transformer:
    hparams.model_type = 'transformer'

if args.debug:
  hparams.add_param("target_word_vocab_size", data.target_word_vocab_size)
  hparams.add_param("target_rule_vocab_size", data.target_rule_vocab_size)
crit = get_criterion(hparams)