def train(args):

  class uniform_initializer(object):
      def __init__(self, stdv):
          self.stdv = stdv
      def __call__(self, tensor):
          nn.init.uniform_(tensor, -self.stdv, self.stdv)

  opt_dict = {"not_improved": 0, "lr": 1., "best_ppl": 1e4}

  hparams = HParams(**vars(args))
  data = DataUtil(hparams=hparams)

  model_init = uniform_initializer(0.01)
  emb_init = uniform_initializer(0.1)

  model = LSTM_LM(model_init, emb_init, hparams)

  if hparams.eval_from != "":
    model = torch.load(hparams.eval_from)
    model.to(hparams.device)
    with torch.no_grad():
        if args.generate:
            model.generate(args.gen_len, args.num_gen, hparams, data)
        else:
            test(model, data, hparams)

    return 

  model.to(hparams.device)

  trainable_params = [
    p for p in model.parameters() if p.requires_grad]
  num_params = count_params(trainable_params)
  print("Model has {0} params".format(num_params))

  optim = torch.optim.SGD(model.parameters(), lr=1.0)

  step = epoch = decay_cnt = 0
  report_words = report_loss = report_ppl = report_sents = 0
  start_time = time.time()

  model.train()
  while True:
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
    y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
    y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len, \
    y_pos_emb_idxs, batch_size,  eop = data.next_train()


    report_words += (x_count - batch_size)
    report_sents += batch_size

    optim.zero_grad()

    loss = model.reconstruct_error(x_train, x_len)


    loss = loss.mean(dim=-1)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

    optim.step()

    report_loss += loss.item() * batch_size

    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0

      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={}".format(step)
      log_string += " lr={0:<9.7f}".format(opt_dict["lr"])
      log_string += " loss={0:<7.2f}".format(report_loss / report_sents)
      log_string += " |g|={0:<5.2f}".format(grad_norm)

      log_string += " ppl={0:<8.2f}".format(np.exp(report_loss / report_words))

      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)

    if step % args.eval_every == 0:
      with torch.no_grad():
        val_loss, val_ppl = test(model, data, hparams)
        if val_ppl < opt_dict["best_ppl"]:
          print("update best ppl")
          opt_dict["best_ppl"] = val_ppl
          opt_dict["not_improved"] = 0
          torch.save(model, os.path.join(hparams.output_dir, "model.pt"))

        if val_ppl > opt_dict["best_ppl"]:
          opt_dict["not_improved"] += 1
          if opt_dict["not_improved"] >= decay_step:
            opt_dict["not_improved"] = 0
            opt_dict["lr"] = opt_dict["lr"] * lr_decay
            model = torch.load(os.path.join(hparams.output_dir, "model.pt"))
            print("new lr: {0:<9.7f}".format(opt_dict["lr"]))
            decay_cnt += 1
            optim = torch.optim.SGD(model.parameters(), lr=opt_dict["lr"])

      report_words = report_loss = report_ppl = report_sents = 0
      model.train()

    step += 1
    if eop:
      epoch += 1

    if decay_cnt >= max_decay:
      break
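# The function above relies on several names that are not defined in this
# snippet: count_params, test (assumed to return a (val_loss, val_ppl) pair),
# and the training-control constants clip_grad, decay_step, lr_decay, and
# max_decay. A minimal sketch of what they might look like follows; the
# constant values are placeholders of ours, not the original settings.
clip_grad, decay_step, lr_decay, max_decay = 5.0, 5, 0.5, 5  # assumed defaults

def count_params(params):
  # total number of trainable parameters
  return sum(p.numel() for p in params)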
def train():
    print(args)
    if args.load_model and (not args.reset_hparams):
        print("load hparams..")
        hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
        hparams = torch.load(hparams_file_name)
        hparams.load_model = args.load_model
        hparams.n_train_steps = args.n_train_steps
    else:
        hparams = HParams(**vars(args))
        hparams.noise_flag = True

    # build or load model
    print("-" * 80)
    print("Creating model")
    if args.load_model:
        data = DataUtil(hparams=hparams)
        model_file_name = os.path.join(args.output_dir, "model.pt")
        print("Loading model from '{0}'".format(model_file_name))
        model = torch.load(model_file_name)
        if not hasattr(model, 'data'):
            model.data = data
        if not hasattr(model.hparams, 'transformer_wdrop'):
            model.hparams.transformer_wdrop = False

        optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
        print("Loading optimizer from {}".format(optim_file_name))
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        #optim = torch.optim.Adam(trainable_params, lr=hparams.lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=hparams.l2_reg)
        optim = torch.optim.Adam(trainable_params,
                                 lr=hparams.lr,
                                 weight_decay=hparams.l2_reg)
        optimizer_state = torch.load(optim_file_name)
        optim.load_state_dict(optimizer_state)

        extra_file_name = os.path.join(args.output_dir, "extra.pt")
        step, best_val_ppl, best_val_bleu, cur_attempt, cur_decay_attempt, lr = torch.load(
            extra_file_name)
    else:
        if args.pretrained_model:
            model_name = os.path.join(args.pretrained_model, "model.pt")
            print("Loading model from '{0}'".format(model_name))
            model = torch.load(model_name)

            print("load hparams..")
            hparams_file_name = os.path.join(args.pretrained_model,
                                             "hparams.pt")
            reload_hparams = torch.load(hparams_file_name)
            reload_hparams.train_src_file_list = hparams.train_src_file_list
            reload_hparams.train_trg_file_list = hparams.train_trg_file_list
            reload_hparams.dropout = hparams.dropout
            reload_hparams.lr_dec = hparams.lr_dec
            hparams = reload_hparams
            data = DataUtil(hparams=hparams)
            model.data = data
        else:
            data = DataUtil(hparams=hparams)
            if args.model_type == 'seq2seq':
                model = Seq2Seq(hparams=hparams, data=data)
            elif args.model_type == 'transformer':
                model = Transformer(hparams=hparams, data=data)
            else:
                print("Model {} not implemented".format(args.model_type))
                exit(0)
            if args.init_type == "uniform" and not hparams.model_type == "transformer":
                print("initialize uniform with range {}".format(
                    args.init_range))
                for p in model.parameters():
                    p.data.uniform_(-args.init_range, args.init_range)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optim = torch.optim.Adam(trainable_params,
                                 lr=hparams.lr,
                                 weight_decay=hparams.l2_reg)
        #optim = torch.optim.Adam(trainable_params)
        step = 0
        #best_val_ppl = None
        best_val_ppl = 1000
        best_val_bleu = None
        cur_attempt = 0
        cur_decay_attempt = 0
        lr = hparams.lr

    model.set_lm()
    model.to(hparams.device)

    if args.eval_cls:
        classifier_file_name = os.path.join(args.classifier_dir, "model.pt")
        print("Loading model from '{0}'".format(classifier_file_name))
        classifier = torch.load(classifier_file_name).to(hparams.device)
    else:
        classifier = None

    if args.reset_hparams:
        lr = args.lr
    crit = get_criterion(hparams)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_params = count_params(trainable_params)
    print("Model has {0} params".format(num_params))

    if args.load_for_test:
        val_ppl, val_bleu = eval(model,
                                 classifier,
                                 data,
                                 crit,
                                 step,
                                 hparams,
                                 eval_bleu=args.eval_bleu,
                                 valid_batch_size=args.valid_batch_size)
        return

    print("-" * 80)
    print("start training...")
    start_time = log_start_time = time.time()
    target_words = total_loss = total_sents = total_noise_corrects = total_transfer_corrects = 0
    total_bt_loss = total_noise_loss = total_KL_loss = 0.
    target_rules, target_total, target_eos = 0, 0, 0
    total_word_loss, total_rule_loss, total_eos_loss = 0, 0, 0
    total_lm_length = total_trans_length = 0
    model.train()
    #i = 0
    dev_zero = args.dev_zero
    tr_loss, update_batch_size = None, 0
    if hparams.anneal_epoch == 0:
        hparams.noise_weight = 0.
        anneal_rate = 0.
    else:
        hparams.noise_weight = 1.
        if hparams.anneal_epoch == -1:
            anneal_rate = 0.
        else:
            anneal_rate = 1.0 / (data.train_size * args.anneal_epoch //
                                 hparams.batch_size)

    hparams.gs_temp = 1.
    while True:
        step += 1
        (x_train, x_mask, x_count, x_len, x_pos_emb_idxs,
         y_train, y_mask, y_count, y_len, y_pos_emb_idxs,
         y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len,
         y_pos_emb_idxs, batch_size, eop) = data.next_train()
        #print(x_train)
        #print(x_count)
        target_words += (x_count - batch_size)
        total_sents += batch_size
        trans_logits, noise_logits, KL_loss, lm_len, trans_len = model.forward(
            x_train, x_mask, x_len, x_pos_emb_idxs, y_train, y_mask, y_len,
            y_pos_emb_idxs, y_sampled, y_sampled_mask, y_sampled_len)

        total_lm_length += lm_len
        total_trans_length += trans_len

        # not predicting the start symbol
        labels = x_train[:, 1:].contiguous().view(-1)

        cur_tr_loss, trans_loss, noise_loss, cur_tr_acc, cur_tr_transfer_acc = get_performance(
            crit, trans_logits, noise_logits, labels, hparams, x_len)

        assert (cur_tr_loss.item() > 0)

        if hparams.lm:
            cur_tr_loss = cur_tr_loss + hparams.klw * KL_loss.sum()
            total_KL_loss += KL_loss.sum().item()

        hparams.noise_weight = max(0., hparams.noise_weight - anneal_rate)
        if hparams.noise_weight == 0:
            hparams.noise_flag = False

        # if eop:
        #     hparams.gs_temp = max(0.001, hparams.gs_temp * 0.5)

        total_loss += cur_tr_loss.item()
        total_bt_loss += trans_loss.item()
        total_noise_loss += noise_loss.item()

        total_noise_corrects += cur_tr_acc
        total_transfer_corrects += cur_tr_transfer_acc
        if tr_loss is None:
            tr_loss = cur_tr_loss
        else:
            tr_loss = tr_loss + cur_tr_loss
        update_batch_size += batch_size

        if step % args.update_batch == 0:
            # set learning rate
            if args.lr_schedule:
                s = step / args.update_batch + 1
                lr = pow(hparams.d_model, -0.5) * min(
                    pow(s, -0.5), s * pow(hparams.n_warm_ups, -1.5))
                set_lr(optim, lr)
            elif step / args.update_batch < hparams.n_warm_ups:
                base_lr = hparams.lr
                base_lr = base_lr * (step / args.update_batch +
                                     1) / hparams.n_warm_ups
                set_lr(optim, base_lr)
                lr = base_lr
            elif args.lr_dec_steps > 0:
                s = (step / args.update_batch) % args.lr_dec_steps
                lr = args.lr_min + 0.5 * (args.lr_max - args.lr_min) * (
                    1 + np.cos(s * np.pi / args.lr_dec_steps))
                set_lr(optim, lr)
            tr_loss = tr_loss / update_batch_size
            tr_loss.backward()
            #grad_norm = grad_clip(trainable_params, grad_bound=args.clip_grad)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_grad)
            optim.step()
            optim.zero_grad()
            tr_loss = None
            update_batch_size = 0
        # clean up GPU memory
        if step % args.clean_mem_every == 0:
            gc.collect()
        epoch = step // data.n_train_batches
        if (step / args.update_batch) % args.log_every == 0:
            curr_time = time.time()
            since_start = (curr_time - start_time) / 60.0
            elapsed = (curr_time - log_start_time) / 60.0
            log_string = "ep={0:<3d}".format(epoch)
            log_string += " steps={0:<6.2f}".format(
                (step / args.update_batch) / 1000)
            log_string += " lr={0:<9.7f}".format(lr)
            log_string += " total={0:<7.2f}".format(total_loss / total_sents)
            log_string += " neg ELBO={0:<7.2f}".format(
                (total_KL_loss + total_bt_loss) / total_sents)
            log_string += " KL={0:<7.2f}".format(total_KL_loss / total_sents)
            log_string += " |g|={0:<5.2f}".format(grad_norm)

            log_string += " bt_ppl={0:<8.2f}".format(
                np.exp(total_bt_loss / target_words))
            log_string += " n_ppl={0:<8.2f}".format(
                np.exp(total_noise_loss / target_words))
            log_string += " n_acc={0:<5.4f}".format(total_noise_corrects /
                                                    target_words)
            log_string += " bt_acc={0:<5.4f}".format(total_transfer_corrects /
                                                     target_words)

            # noise weight
            log_string += " nw={:.4f}".format(hparams.noise_weight)

            log_string += " lmlen={}".format(total_lm_length // total_sents)
            log_string += " translen={}".format(total_trans_length //
                                                total_sents)
            # log_string += " wpm(k)={0:<5.2f}".format(target_words / (1000 * elapsed))
            log_string += " t={0:<5.2f}".format(since_start)
            print(log_string)
        if args.eval_end_epoch:
            if eop:
                eval_now = True
            else:
                eval_now = False
        elif (step / args.update_batch) % args.eval_every == 0:
            eval_now = True
        else:
            eval_now = False
        if eval_now:
            # based_on_bleu = args.eval_bleu and best_val_ppl is not None and best_val_ppl <= args.ppl_thresh
            based_on_bleu = False
            if args.dev_zero: based_on_bleu = True
            print("target words: {}".format(target_words))
            with torch.no_grad():
                val_ppl, val_bleu = eval(
                    model,
                    classifier,
                    data,
                    crit,
                    step,
                    hparams,
                    eval_bleu=args.eval_bleu,
                    valid_batch_size=args.valid_batch_size)
            if based_on_bleu:
                if best_val_bleu is None or best_val_bleu <= val_bleu:
                    save = True
                    best_val_bleu = val_bleu
                    cur_attempt = 0
                    cur_decay_attempt = 0
                else:
                    save = False
            else:
                if best_val_ppl is None or best_val_ppl >= val_ppl:
                    save = True
                    best_val_ppl = val_ppl
                    cur_attempt = 0
                    cur_decay_attempt = 0
                else:
                    save = False
            if save or args.always_save:
                save_checkpoint([
                    step, best_val_ppl, best_val_bleu, cur_attempt,
                    cur_decay_attempt, lr
                ], model, optim, hparams, args.output_dir)
            elif not args.lr_schedule and step >= hparams.n_warm_ups:
                if cur_decay_attempt >= args.attempt_before_decay:
                    if val_ppl >= 2 * best_val_ppl:
                        print("reload saved best model !!!")
                        model.load_state_dict(
                            torch.load(
                                os.path.join(args.output_dir, "model.dict")))
                        hparams = torch.load(
                            os.path.join(args.output_dir, "hparams.pt"))
                    lr = lr * args.lr_dec
                    set_lr(optim, lr)
                    cur_attempt += 1
                    cur_decay_attempt = 0
                else:
                    cur_decay_attempt += 1
            # reset counter after eval
            log_start_time = time.time()
            target_words = total_sents = total_noise_corrects = total_transfer_corrects = total_loss = 0
            total_bt_loss = total_noise_loss = total_KL_loss = 0.
            target_rules = target_total = target_eos = 0
            total_word_loss = total_rule_loss = total_eos_loss = 0
            total_lm_length = total_trans_length = 0
        if args.patience >= 0:
            if cur_attempt > args.patience: break
        elif args.n_train_epochs > 0:
            if epoch >= args.n_train_epochs: break
        else:
            if step > args.n_train_steps: break
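# The learning-rate logic inside the update step above mixes three regimes:
# an inverse-square-root ("Noam") schedule when args.lr_schedule is set, a
# linear warmup toward hparams.lr otherwise, and an optional cosine cycle
# driven by args.lr_dec_steps. A standalone sketch of the same formulas
# (the function name and signature are ours, not the project's):
def scheduled_lr(update, d_model, n_warm_ups, base_lr,
                 use_noam=False, lr_dec_steps=0, lr_min=0.0, lr_max=1.0):
    # `update` counts optimizer updates, i.e. step / args.update_batch above
    if use_noam:
        s = update + 1
        return d_model ** -0.5 * min(s ** -0.5, s * n_warm_ups ** -1.5)
    if update < n_warm_ups:
        return base_lr * (update + 1) / n_warm_ups  # linear warmup
    if lr_dec_steps > 0:
        s = update % lr_dec_steps  # cosine cycle between lr_max and lr_min
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(s * np.pi / lr_dec_steps))
    return base_lr  # otherwise the caller keeps its current lr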
Example #3
#hparams.data_path=args.data_path
#hparams.src_vocab_list=args.src_vocab_list
#hparams.trg_vocab_list=args.trg_vocab_list
hparams.test_src_file = args.test_src_file
hparams.test_trg_file = args.test_trg_file
hparams.cuda = args.cuda
hparams.beam_size = args.beam_size
hparams.max_len = args.max_len
hparams.batch_size = args.batch_size
hparams.merge_bpe = args.merge_bpe
hparams.out_file = out_file
hparams.nbest = args.nbest
hparams.decode = True

model.hparams.cuda = hparams.cuda
data = DataUtil(hparams=hparams, decode=True)
filts = [model.hparams.pad_id, model.hparams.eos_id, model.hparams.bos_id]

if not hasattr(model, 'data'):
  model.data = data

if args.debug:
  hparams.add_param("target_word_vocab_size", data.target_word_vocab_size)
  hparams.add_param("target_rule_vocab_size", data.target_rule_vocab_size)
  crit = get_criterion(hparams)

out_file = open(hparams.out_file, 'w', encoding='utf-8')

end_of_epoch = False
num_sentences = 0
Example #4
def train():
  if args.load_model and (not args.reset_hparams):
    print("load hparams..")
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.load_model = args.load_model
    hparams.n_train_steps = args.n_train_steps
  else:
    hparams = HParams(
      decode=args.decode,
      data_path=args.data_path,
      train_src_file_list=args.train_src_file_list,
      train_trg_file_list=args.train_trg_file_list,
      dev_src_file=args.dev_src_file,
      dev_trg_file=args.dev_trg_file,
      src_vocab_list=args.src_vocab_list,
      trg_vocab_list=args.trg_vocab_list,
      src_vocab_size=args.src_vocab_size,
      trg_vocab_size=args.trg_vocab_size,
      max_len=args.max_len,
      n_train_sents=args.n_train_sents,
      cuda=args.cuda,
      d_word_vec=args.d_word_vec,
      d_model=args.d_model,
      batch_size=args.batch_size,
      batcher=args.batcher,
      n_train_steps=args.n_train_steps,
      dropout=args.dropout,
      lr=args.lr,
      lr_dec=args.lr_dec,
      l2_reg=args.l2_reg,
      init_type=args.init_type,
      init_range=args.init_range,
      share_emb_softmax=args.share_emb_softmax,
      n_heads=args.n_heads,
      d_k=args.d_k,
      d_v=args.d_v,
      merge_bpe=args.merge_bpe,
      load_model=args.load_model,
      char_ngram_n=args.char_ngram_n,
      max_char_vocab_size=args.max_char_vocab_size,
      char_input=args.char_input,
      char_comb=args.char_comb,
      char_temp=args.char_temp,
      src_char_vocab_from=args.src_char_vocab_from,
      src_char_vocab_size=args.src_char_vocab_size,
      trg_char_vocab_from=args.trg_char_vocab_from,
      trg_char_vocab_size=args.trg_char_vocab_size,
      src_char_only=args.src_char_only,
      trg_char_only=args.trg_char_only,
      semb=args.semb,
      dec_semb=args.dec_semb,
      semb_vsize=args.semb_vsize,
      lan_code_rl=args.lan_code_rl,
      sample_rl=args.sample_rl,
      sep_char_proj=args.sep_char_proj,
      query_base=args.query_base,
      residue=args.residue,
      layer_norm=args.layer_norm,
      src_no_char=args.src_no_char,
      trg_no_char=args.trg_no_char,
      char_gate=args.char_gate,
      shuffle_train=args.shuffle_train,
      ordered_char_dict=args.ordered_char_dict,
      out_c_list=args.out_c_list,
      k_list=args.k_list,
      d_char_vec=args.d_char_vec,
      highway=args.highway,
      n=args.n,
      single_n=args.single_n,
      bpe_ngram=args.bpe_ngram,
      uni=args.uni,
      pretrained_src_emb_list=args.pretrained_src_emb_list,
      pretrained_trg_emb=args.pretrained_trg_emb,
    )
  # build or load model
  print("-" * 80)
  print("Creating model")
  if args.load_model:
    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    if not hasattr(model, 'data'):
      model.data = data

    optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
    print("Loading optimizer from {}".format(optim_file_name))
    trainable_params = [
      p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    optimizer_state = torch.load(optim_file_name)
    optim.load_state_dict(optimizer_state)

    extra_file_name = os.path.join(args.output_dir, "extra.pt")
    step, best_val_ppl, best_val_bleu, cur_attempt, lr = torch.load(extra_file_name)
  else:
    if args.pretrained_model:
      model_name = os.path.join(args.pretrained_model, "model.pt")
      print("Loading model from '{0}'".format(model_name))
      model = torch.load(model_name)
      #if not hasattr(model, 'data'):
      #  model.data = data
      #if not hasattr(model, 'char_ngram_n'):
      #  model.hparams.char_ngram_n = 0
      #if not hasattr(model, 'char_input'):
      #  model.hparams.char_input = None
      print("load hparams..")
      hparams_file_name = os.path.join(args.pretrained_model, "hparams.pt")
      reload_hparams = torch.load(hparams_file_name)
      reload_hparams.train_src_file_list = hparams.train_src_file_list
      reload_hparams.train_trg_file_list = hparams.train_trg_file_list
      reload_hparams.dropout = hparams.dropout
      reload_hparams.lr_dec = hparams.lr_dec
      hparams = reload_hparams
      #hparams.src_vocab_list = reload_hparams.src_vocab_list 
      #hparams.src_vocab_size = reload_hparams.src_vocab_size 
      #hparams.trg_vocab_list = reload_hparams.trg_vocab_list 
      #hparams.trg_vocab_size = reload_hparams.trg_vocab_size 
      #hparams.src_char_vocab_from = reload_hparams.src_char_vocab_from 
      #hparams.src_char_vocab_size = reload_hparams.src_char_vocab_size 
      #hparams.trg_char_vocab_from = reload_hparams.trg_char_vocab_from 
      #hparams.trg_char_vocab_size = reload_hparams.trg_char_vocab_size
      #print(reload_hparams.src_char_vocab_from)
      #print(reload_hparams.src_char_vocab_size)
      data = DataUtil(hparams=hparams)
      model.data = data
    else:
      data = DataUtil(hparams=hparams)
      model = Seq2Seq(hparams=hparams, data=data)
      if args.init_type == "uniform":
        print("initialize uniform with range {}".format(args.init_range))
        for p in model.parameters():
          p.data.uniform_(-args.init_range, args.init_range)
      if args.id_init_sep and args.semb and args.sep_char_proj:
        print("initialize char proj as identity matrix")
        for s in model.encoder.char_emb.sep_proj_list:
          d = s.weight.data.size(0)
          s.weight.data.copy_(torch.eye(d) + args.id_scale*torch.diagflat(torch.ones(d).normal_(0,1)))
    trainable_params = [
      p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr, weight_decay=hparams.l2_reg)
    #optim = torch.optim.Adam(trainable_params)
    step = 0
    best_val_ppl = 100
    best_val_bleu = 0
    cur_attempt = 0
    lr = hparams.lr

  if args.cuda:
    model = model.cuda()

  if args.reset_hparams:
    lr = args.lr
  crit = get_criterion(hparams)
  trainable_params = [
    p for p in model.parameters() if p.requires_grad]
  num_params = count_params(trainable_params)
  print("Model has {0} params".format(num_params))

  print("-" * 80)
  print("start training...")
  start_time = log_start_time = time.time()
  target_words, total_loss, total_corrects = 0, 0, 0
  target_rules, target_total, target_eos = 0, 0, 0
  total_word_loss, total_rule_loss, total_eos_loss = 0, 0, 0
  model.train()
  #i = 0
  dev_zero = args.dev_zero
  while True:
    x_train, x_mask, x_count, x_len, y_train, y_mask, y_count, y_len, batch_size, x_train_char_sparse, y_train_char_sparse, eop, file_idx = data.next_train()
    optim.zero_grad()
    target_words += (y_count - batch_size)
    logits = model.forward(x_train, x_mask, x_len, y_train[:,:-1], y_mask[:,:-1], y_len, x_train_char_sparse, y_train_char_sparse, file_idx=file_idx)
    logits = logits.view(-1, hparams.trg_vocab_size)
    labels = y_train[:,1:].contiguous().view(-1)
    #print(labels)
    tr_loss, tr_acc = get_performance(crit, logits, labels, hparams)
    total_loss += tr_loss.item()
    total_corrects += tr_acc.item()
    step += 1

    if dev_zero:
      dev_zero = False
      #based_on_bleu = args.eval_bleu and best_val_ppl <= args.ppl_thresh
      based_on_bleu = args.eval_bleu
      val_ppl, val_bleu = eval(model, data, crit, step, hparams, eval_bleu=based_on_bleu,
                               valid_batch_size=args.valid_batch_size, tr_logits=logits)
      if based_on_bleu:
        if best_val_bleu <= val_bleu:
          save = True 
          best_val_bleu = val_bleu
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      else:
        if best_val_ppl >= val_ppl:
          save = True
          best_val_ppl = val_ppl
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      if save or args.always_save:
        save_checkpoint([step, best_val_ppl, best_val_bleu, cur_attempt, lr],
                        model, optim, hparams, args.output_dir)
      else:
        lr = lr * args.lr_dec
        set_lr(optim, lr)
      # reset counter after eval
      log_start_time = time.time()
      target_words = total_corrects = total_loss = 0
      target_rules = target_total = target_eos = 0
      total_word_loss = total_rule_loss = total_eos_loss = 0

    tr_loss.div_(batch_size)
    tr_loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
    optim.step()
    # clean up GPU memory
    if step % args.clean_mem_every == 0:
      gc.collect()
    epoch = step // sum(data.n_train_batches)
    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      elapsed = (curr_time - log_start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={0:<6.2f}".format(step / 1000)
      log_string += " lr={0:<9.7f}".format(lr)
      log_string += " loss={0:<7.2f}".format(tr_loss.item())
      log_string += " |g|={0:<5.2f}".format(grad_norm)

      log_string += " ppl={0:<8.2f}".format(np.exp(total_loss / target_words))
      log_string += " acc={0:<5.4f}".format(total_corrects / target_words)

      log_string += " wpm(k)={0:<5.2f}".format(target_words / (1000 * elapsed))
      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)
    if args.eval_end_epoch:
      if eop:
        eval_now = True
      else:
        eval_now = False
    elif step % args.eval_every == 0:
      eval_now = True
    else:
      eval_now = False 
    if eval_now:
      based_on_bleu = args.eval_bleu and best_val_ppl <= args.ppl_thresh
      if args.dev_zero: based_on_bleu = True
      with torch.no_grad():
        val_ppl, val_bleu = eval(model, data, crit, step, hparams, eval_bleu=based_on_bleu,
                                 valid_batch_size=args.valid_batch_size, tr_logits=logits)
      if based_on_bleu:
        if best_val_bleu <= val_bleu:
          save = True 
          best_val_bleu = val_bleu
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      else:
        if best_val_ppl >= val_ppl:
          save = True
          best_val_ppl = val_ppl
          cur_attempt = 0
        else:
          save = False
          cur_attempt += 1
      if save or args.always_save:
        save_checkpoint([step, best_val_ppl, best_val_bleu, cur_attempt, lr],
                        model, optim, hparams, args.output_dir)
      else:
        lr = lr * args.lr_dec
        set_lr(optim, lr)
      # reset counter after eval
      log_start_time = time.time()
      target_words = total_corrects = total_loss = 0
      target_rules = target_total = target_eos = 0
      total_word_loss = total_rule_loss = total_eos_loss = 0
    if args.patience >= 0:
      if cur_attempt > args.patience: break
    elif args.n_train_epochs > 0:
      if epoch >= args.n_train_epochs: break
    else:
      if step > args.n_train_steps: break 
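# set_lr is called throughout these examples but never shown. In PyTorch the
# usual way to change the learning rate in place is to overwrite it in every
# parameter group, so a plausible minimal implementation (an assumption, not
# the project's actual code) is:
def set_lr(optim, lr):
  # write the new learning rate into every parameter group of the optimizer
  for param_group in optim.param_groups:
    param_group["lr"] = lr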
Example #5
src1_w = []
with open(args.lex_file.split(",")[0], "r") as myfile:
    for line in myfile:
        src1_w.append(line.strip())

model.hparams.train_src_file_list = args.lex_file.split(",")
model.hparams.train_trg_file_list = args.lex_file.split(",")
model.hparams.dev_src_file = args.lex_file.split(",")[0]
model.hparams.dev_trg_file = args.lex_file.split(",")[0]
model.hparams.shuffle_train = False
model.hparams.batcher = "sent"
model.hparams.batch_size = len(src1_w)
model.hparams.cuda = args.cuda

data = DataUtil(model.hparams, shuffle=False)

out = args.out_file.split(",")
out1 = open(out[0], "wb")
out2 = open(out[1], "wb")

step = 0
while True:
    gc.collect()
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, \
        y_train, y_mask, y_count, y_len, y_pos_emb_idxs, \
        batch_size, x_train_char_sparse, y_train_char_sparse, eop, file_idx = data.next_train()
    if args.options == 'char-no-spe':
        model.hparams.sep_char_proj = False
        print(x_train[44])
        print(x_train_char_sparse[44])
Example #6
def train():
  if args.load_model and (not args.reset_hparams):
    print("load hparams..")
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.load_model = args.load_model
    hparams.n_train_steps = args.n_train_steps

  else:
    hparams = HParams(**vars(args))

  print("building model...")
  if args.load_model:
    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    trainable_params = [
      p for p in model.parameters() if p.requires_grad]
    num_params = count_params(trainable_params)
    print("Model has {0} params".format(num_params))

    optim_file_name = os.path.join(args.output_dir, "optimizer.pt")
    print("Loading optimizer from {}".format(optim_file_name))
    #optim = torch.optim.Adam(trainable_params, lr=hparams.lr, betas=(0.9, 0.98), eps=1e-9)
    optim = torch.optim.Adam(trainable_params, lr=hparams.lr)
    optimizer_state = torch.load(optim_file_name)
    optim.load_state_dict(optimizer_state)

    extra_file_name = os.path.join(args.output_dir, "extra.pt")
    step, best_loss, best_acc, cur_attempt, lr = torch.load(extra_file_name)
  else:
    data = DataUtil(hparams=hparams)
    if hparams.classifer == "cnn":
        model = CNNClassify(hparams)
    else:
        model = BiLSTMClassify(hparams)
    if args.cuda:
      model = model.cuda()
    #if args.init_type == "uniform":
    #  print("initialize uniform with range {}".format(args.init_range))
    #  for p in model.parameters():
    #    p.data.uniform_(-args.init_range, args.init_range)
    trainable_params = [
      p for p in model.parameters() if p.requires_grad]
    num_params = count_params(trainable_params)
    print("Model has {0} params".format(num_params))

    optim = torch.optim.Adam(trainable_params, lr=hparams.lr)
    step = 0
    best_loss = None
    best_acc = None
    cur_attempt = 0
    lr = hparams.lr

  # per-token losses (summed manually below); the old 'reduce=False' form is deprecated
  crit = nn.CrossEntropyLoss(reduction='none')

  print("-" * 80)
  print("start training...")
  start_time = log_start_time = time.time()
  total_loss, total_batch, acc = 0, 0, 0
  model.train()
  epoch = 0
  while True:
    x_train, x_mask, x_count, x_len, x_pos_emb_idxs, y_train, y_mask, y_count, y_len, y_pos_emb_idxs, y_sampled, y_sampled_mask, y_sampled_count, y_sampled_len, y_pos_emb_idxs, batch_size,  eop = data.next_train()
    step += 1
    #print(x_train)
    #print(x_mask)
    logits = model.forward(x_train, x_mask, x_len, step=step)
    logits = logits.view(-1, hparams.trg_vocab_size)
    labels = y_train.view(-1)

    tr_loss = crit(logits, labels)
    _, preds = torch.max(logits, dim=1)
    val_acc = torch.eq(preds, labels).int().sum()

    acc += val_acc.item()
    tr_loss = tr_loss.sum()
    total_loss += tr_loss.item()
    total_batch += batch_size

    tr_loss.div_(batch_size)
    tr_loss.backward()
    grad_norm = grad_clip(trainable_params, grad_bound=args.clip_grad)
    optim.step()
    optim.zero_grad()
    if eop: epoch += 1
    if step % args.log_every == 0:
      curr_time = time.time()
      since_start = (curr_time - start_time) / 60.0
      elapsed = (curr_time - log_start_time) / 60.0
      log_string = "ep={0:<3d}".format(epoch)
      log_string += " steps={0:<6.2f}".format((step) / 1000)
      log_string += " lr={0:<9.7f}".format(lr)
      log_string += " loss={0:<7.2f}".format(total_loss)
      log_string += " acc={0:<5.4f}".format(acc / total_batch)
      log_string += " |g|={0:<5.2f}".format(grad_norm)


      log_string += " wpm(k)={0:<5.2f}".format(total_batch / (1000 * elapsed))
      log_string += " time(min)={0:<5.2f}".format(since_start)
      print(log_string)
      acc, total_loss, total_batch = 0, 0, 0
      log_start_time = time.time()

    if step % args.eval_every == 0:
      model.eval()
      cur_acc, cur_loss = eval(model, data, crit, step, hparams)
      if not best_acc or best_acc < cur_acc:
        best_loss, best_acc = cur_loss, cur_acc
        cur_attempt = 0
        save_checkpoint([step, best_loss, best_acc, cur_attempt, lr], model, optim, hparams, args.output_dir)
      else:
        if args.lr_dec:
          lr = lr * args.lr_dec
          set_lr(optim, lr)

        cur_attempt += 1
        if args.patience and cur_attempt > args.patience: break
      model.train()
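# save_checkpoint is only visible from its loading side: the examples read
# back "model.pt", "optimizer.pt", "hparams.pt", an "extra.pt" tuple, and in
# one place a "model.dict" state dict. A sketch consistent with those reads
# (treat it as an assumption, not the project's actual implementation):
def save_checkpoint(extra, model, optim, hparams, output_dir):
  torch.save(extra, os.path.join(output_dir, "extra.pt"))
  torch.save(model, os.path.join(output_dir, "model.pt"))
  torch.save(model.state_dict(), os.path.join(output_dir, "model.dict"))
  torch.save(optim.state_dict(), os.path.join(output_dir, "optimizer.pt"))
  torch.save(hparams, os.path.join(output_dir, "hparams.pt"))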
Example #7
      os.makedirs(args.output_dir)
    elif args.reset_output_dir:
      print("-" * 80)
      print("Path {} exists. Remove and remake.".format(args.output_dir))
      shutil.rmtree(args.output_dir)
      os.makedirs(args.output_dir)

    print("-" * 80)
    log_file = os.path.join(args.output_dir, "stdout")
    print("Logging to {}".format(log_file))
    sys.stdout = Logger(log_file)

    train()
  else:
    hparams_file_name = os.path.join(args.output_dir, "hparams.pt")
    hparams = torch.load(hparams_file_name)
    hparams.decode = True
    hparams.test_src_file = args.test_src_file
    hparams.test_trg_file = args.test_trg_file

    data = DataUtil(hparams=hparams)
    model_file_name = os.path.join(args.output_dir, "model.pt")
    print("Loading model from '{0}'".format(model_file_name))
    model = torch.load(model_file_name)
    model.eval()
    hparams.valid_batch_size = args.valid_batch_size
    with torch.no_grad():
        cur_acc, cur_loss = test(model, data, hparams, args.test_src_file, args.test_trg_file, negate=args.negate)
    print("test_acc={}, test_loss={}".format(cur_acc, cur_loss))

Example #8
    hparams.src_char_only = train_hparams.src_char_only
else:
    hparams.src_char_only = False
if not hasattr(train_hparams, 'n') and (train_hparams.char_ngram_n
                                        or train_hparams.char_input):
    hparams.n = 4
if not hasattr(train_hparams, 'single_n') and (train_hparams.char_ngram_n
                                               or train_hparams.char_input):
    hparams.single_n = False
if not hasattr(train_hparams, 'bpe_ngram'):
    hparams.bpe_ngram = False
if not hasattr(train_hparams, 'uni'):
    hparams.uni = False

model.hparams.cuda = hparams.cuda
data = DataUtil(hparams=hparams, decode=True)
filts = [model.hparams.pad_id, model.hparams.eos_id, model.hparams.bos_id]

if not hasattr(model, 'data'):
    model.data = data
if not hasattr(hparams, 'model_type'):
    if type(model) == Seq2Seq:
        hparams.model_type = "seq2seq"
    elif type(model) == Transformer:
        hparams.model_type = 'transformer'

if args.debug:
    hparams.add_param("target_word_vocab_size", data.target_word_vocab_size)
    hparams.add_param("target_rule_vocab_size", data.target_rule_vocab_size)
    crit = get_criterion(hparams)