def __init__(self, model, beam_size=1):
    """Keep a reference to ``model`` and prepare it for inference.

    The model is put into evaluation mode via ``eval()`` as soon as it is
    stored. A new ``SRTransition`` instance is created and kept on the
    object alongside ``beam_size``.

    :param model: network object; must provide an ``eval()`` method.
    :param beam_size: stored unchanged as ``self.beam_size`` (default 1);
        presumably the beam width used when decoding — confirm at call sites.
    """
    self.model = model
    # switch off training-only behaviour before this object is used
    self.model.eval()
    self.beam_size = beam_size
    self.transition = SRTransition()
def _oracle_loss(model, transition, criterion, discourse):
    """Return the loss of ``model`` along the oracle action sequence of ``discourse``.

    At every parser state the model's scores are recorded together with the
    index of the oracle label. The session is then advanced with the oracle
    action (shift or reduce with its nuclearity), so the next state the model
    sees follows the gold trajectory.
    """
    conf = SRConfiguration(discourse.strip())
    state = model.new_state(conf)
    session = Session(conf, transition, state=state)
    scores = []
    grounds = []
    for label in sr_oracle(discourse):
        action, nuclear = label
        scores.append(model(session.state))
        grounds.append(model.label2idx[label])
        if action == SRTransition.SHIFT:
            session(action)
            session.state = model.shift(session.state)
        else:
            session(action, nuclear=nuclear)
            session.state = model.reduce(session.state)
    # build the target tensor directly as int64 instead of float -> long
    targets = torch.tensor(grounds, dtype=torch.long)
    return criterion(torch.stack(scores), targets)


def train(cdtb):
    """Train a shift-reduce parsing model on ``cdtb.train`` with oracle supervision.

    Hyper-parameters (``num_epoch``, ``batch_size``, ``eval_every``, ``lr``,
    ``l2_penalty``, ``model_dir``) are read from ``config`` under
    ``config_section``. Gradients are accumulated over ``batch_size``
    discourses before each RMSprop step. After every ``eval_every`` optimizer
    steps the model is evaluated on ``cdtb.validate`` (and on ``cdtb.test``,
    whose result is not used here). Each time the validation score improves,
    the full model is saved with ``torch.save`` to ``"<model_dir>.<score>"``.
    When training ends, the best checkpoint (if any was saved) is copied to
    ``model_dir``.

    :param cdtb: corpus object exposing ``train``, ``validate`` and ``test``.
    """
    model = build_model(cdtb.train)
    transition = SRTransition()
    num_epoch = config.get(config_section, "num_epoch", rtype=int)
    batch_size = config.get(config_section, "batch_size", rtype=int)
    eval_every = config.get(config_section, "eval_every", rtype=int)
    lr = config.get(config_section, "lr", rtype=float)
    l2 = config.get(config_section, "l2_penalty", rtype=float)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=l2)
    # halve the learning rate once 3, 6 and 12 epochs have completed
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6, 12], gamma=0.5)
    model.train()
    optimizer.zero_grad()
    step = 0
    batch = 0
    batch_loss = 0.
    best_model_score = 0.
    best_model_path = None
    model_dir = config.get(config_section, "model_dir")
    # Shuffle indices rather than the discourse objects themselves: numpy would
    # otherwise try to turn sequence-like discourses into a nested ndarray.
    # Same RNG draw and same order as permuting the list directly.
    train_set = list(cdtb.train)
    for epoch in range(1, num_epoch + 1):
        # report the lr actually in use; get_lr() is deprecated and misreports
        # at MultiStepLR milestones when called outside scheduler.step()
        print("learning rate: %f" % optimizer.param_groups[0]["lr"])
        for index in np.random.permutation(len(train_set)):
            discourse = train_set[index]
            step += 1
            loss = _oracle_loss(model, transition, criterion, discourse)
            loss.backward()
            batch_loss += loss.item()
            if step % batch_size != 0:
                continue
            batch += 1
            optimizer.step()
            optimizer.zero_grad()
            print("step %d, epoch: %d, batch: %d, batch loss: %.3f" % (step, epoch, batch, batch_loss / batch_size))
            batch_loss = 0.
            if batch % eval_every == 0:
                model_score = evaluate(model, cdtb.validate)
                evaluate(model, cdtb.test)
                if model_score > best_model_score:
                    best_model_score = model_score
                    best_model_path = "%s.%.3f" % (model_dir, model_score)
                    with open(best_model_path, "wb+") as best_model_fd:
                        print("save new best model to %s.%.3f" % (model_dir, model_score))
                        torch.save(model, best_model_fd)
                # always return to training mode after evaluation, whether or
                # not a new best model was saved
                model.train()
        # PyTorch >= 1.1: advance the schedule after the epoch's optimizer steps
        scheduler.step()
    # copy best model to model dir; skip if no checkpoint was ever written,
    # otherwise copy2 would raise FileNotFoundError after the full run
    if best_model_path is not None:
        shutil.copy2(best_model_path, model_dir)