def train_model(args, model, train, dev, save_path=None, maxsteps=None, writer=None):
    """Train `model` on `train` with gradient accumulation, checkpointing and validation.

    The loop steps the optimizer once every `args.inter_size` batches, writes
    back-up checkpoints every `args.save_every` iterations, and periodically
    calls `valid_model` on `dev`, feeding the results to a `Best` tracker.

    Args:
        args: Run configuration. Attributes read here include `optimizer`,
            `load_from`, `resume`, `finetune`, `models_dir`, `gpu`,
            `model_name`, `eval_every`, `eval_every_examples`, `inter_size`,
            `save_every`, `maximum_steps`, `d_model`, `warmup`,
            `disable_lr_schedule`, `distillation`, `tensorboard`, `debug`,
            `universal_options`, `prefix`, `hp_str` and `logger`.
            `args.eval_every` is multiplied by `args.inter_size` in place.
        model: Model exposing `parameters()`, `quick_prepare()`, `cost()`,
            `train()` and being callable for the forward pass.
        train: Iterable of training batches.
        dev: Validation data handed to `valid_model`.
        save_path: Path given to `Best`; defaults to `args.model_name`.
        maxsteps: Iteration limit; defaults to `args.maximum_steps`.
        writer: Object with `add_scalar(tag, value, step)`; only used when
            `args.tensorboard` is set and `args.debug` is not.

    Raises:
        NotImplementedError: If `args.optimizer` is not 'Adam'.

    NOTE(review): this block arrived with its indentation flattened; the
    nesting below is reconstructed from syntax and data dependencies and
    should be checked against the author's original layout.
    NOTE(review): a second `train_model` is defined later in this file and
    would replace this one at import time -- confirm which is intended.
    """

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):   # very important.
            # The states file holds [iteration, optimizer state_dict]; the
            # same layout is written by the back-up checkpoint code below.
            offset, opt_states = torch.load(args.models_dir + '/' + args.load_from + '.pt.states',
                                            map_location=lambda storage, loc: storage.cuda())
            if not args.finetune:  # if finetune, do not have history
                opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    # eval_every is rescaled by the accumulation factor (mutates args).
    args.eval_every *= args.inter_size

    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i', model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')
    examples = 0        # examples seen since the last example-count-triggered validation
    first_step = True   # forces one validation pass on the very first iteration
    loss_outer = 0      # sum of the scaled losses inside the current accumulation window

    for iters, batch in enumerate(train):

        # Continue the iteration count from the resumed checkpoint.
        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()], '{}_iter={}.pt.states'.format(args.model_name, iters))

        # --- validation --- #
        # Triggered by iteration count (when eval_every_examples == -1), by
        # the number of examples seen (when eval_every_examples > 0), or
        # unconditionally on the first iteration.
        if ((args.eval_every_examples == -1) and (iters % args.eval_every == 0)) \
            or ((args.eval_every_examples > 0) and (examples > args.eval_every_examples)) \
            or first_step:

            first_step = False

            if args.eval_every_examples > 0:
                examples = examples % args.eval_every_examples

            # NOTE(review): `dev_iters` and `dev_batch` are never used and
            # `valid_model` receives the whole `dev` object, so as laid out
            # here validation is repeated once per element of `dev`. The
            # extent of this loop is a reconstruction -- confirm.
            for dev_iters, dev_batch in enumerate(dev):

                progressbar.close()
                dev_metrics.reset()

                if args.distillation:
                    # NOTE(review): `outputs_course` is assigned but not read in this function.
                    outputs_course = valid_model(args, model, dev, dev_metrics, distillation=True)

                # When distilling, dev_metrics was already filled by the call
                # above, so None is passed here instead.
                outputs_data = valid_model(args, model, dev, None if args.distillation else dev_metrics, print_out=True)

                # NOTE(review): `iters / args.inter_size` is true division
                # under Python 3, so the step passed below is a float -- confirm intended.
                if args.tensorboard and (not args.debug):
                    writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters / args.inter_size)
                    writer.add_scalar('dev/Loss', dev_metrics.loss, iters / args.inter_size)
                    writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters / args.inter_size)
                    writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters / args.inter_size)

                if not args.debug:
                    best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'], dev_metrics.gleu, dev_metrics.loss, iters / args.inter_size)
                    args.logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format(
                        best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
                args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if maxsteps is None:
            maxsteps = args.maximum_steps

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()

        def get_learning_rate(i, lr0=0.1, disable=False):
            # Scheduled rate: grows linearly in `i` until i == args.warmup,
            # then decays as 1/sqrt(i); scaled by lr0 * 10 / sqrt(d_model).
            if not disable:
                return lr0 * 10 / math.sqrt(args.d_model) * min(1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
            # Fixed rate used when the schedule is disabled.
            return 0.00002

        # Start of an accumulation window: refresh the learning rate and
        # clear the gradients and the running window loss.
        if iters % args.inter_size == 0:
            opt.param_groups[0]['lr'] = get_learning_rate(iters / args.inter_size + 1, disable=args.disable_lr_schedule)
            opt.zero_grad()
            loss_outer = 0

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch, args.distillation)
        # NOTE(review): these three names are not read again in this function.
        input_reorder, fertility_cost, decoder_inputs = None, None, inputs
        examples += batch_size

        # Maximum Likelihood Training
        # The loss is divided by inter_size so the gradients summed over one
        # window correspond to the mean of the per-batch losses.
        loss = model.cost(targets, target_masks, out=model(encoding, source_masks, inputs, input_masks)) / args.inter_size
        loss_outer = loss_outer + loss

        # accumulate the training metrics
        # NOTE(review): the metrics are reset immediately after being
        # accumulated, so they never span more than one batch -- confirm intended.
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()

        # End of an accumulation window: apply the update.
        if iters % args.inter_size == (args.inter_size - 1):

            if args.universal_options == 'no_update_encdec':
                # Zero every gradient except that of
                # model.encoder.uni_out.weight before the optimizer step.
                for p in model.parameters():
                    if p is not model.encoder.uni_out.weight:
                        if p.grad is not None:
                            p.grad.detach_()
                            p.grad.zero_()

            opt.step()

            info = 'training step={}, loss={:.3f}, lr={:.8f}'.format(iters / args.inter_size, export(loss_outer), opt.param_groups[0]['lr'])
            if args.tensorboard and (not args.debug):
                writer.add_scalar('train/Loss', export(loss_outer), iters / args.inter_size)

            # `info` only exists once a window has completed, so the progress
            # bar is advanced here, once per optimizer step.
            progressbar.update(1)
            progressbar.set_description(info)
def train_model(args, model, train, dev, src, trg, teacher_model=None, save_path=None, maxsteps=None):
    """Train `model` on `train` with periodic checkpointing and validation.

    For a `Transformer` a single maximum-likelihood loss is optimised. For a
    `FastTransformer` the decoder is applied `args.train_repeat_dec` times,
    each pass consuming embeddings of tokens picked (argmax or sampled) from
    the previous pass, and the per-pass losses are summed.

    Args:
        args: Run configuration. Attributes read here include `optimizer`,
            `load_from`, `resume`, `model_path`, `gpu`, `model_name`,
            `save_every`, `eval_every`, `maximum_steps`, `lr`,
            `disable_lr_schedule`, `train_repeat_dec`, `valid_repeat_dec`,
            `num_shared_dec`, `use_argmax`, `sum_out_and_emb`,
            `diff_loss_w`, `diff_loss_dec1`, `d_model`, `grad_clip`,
            `tensorboard`, `debug`, `event_path`, `no_tqdm`, `prefix`,
            `hp_str` and `logger`.
        model: A `Transformer` or `FastTransformer` instance (exact type).
        train: Iterable of training batches.
        dev: Validation data handed to `valid_model`.
        src, trg: Not used by this function; kept for interface
            compatibility.
        teacher_model: Not used by this function; `valid_model` is always
            called with `teacher_model=None`.
        save_path: Path given to `Best`; defaults to `args.model_name`.
        maxsteps: Iteration limit; defaults to `args.maximum_steps`.

    Raises:
        NotImplementedError: If `args.optimizer` is not 'Adam', or if
            `model` is neither a `Transformer` nor a `FastTransformer`.

    NOTE(review): this block arrived with its indentation flattened; the
    nesting below is reconstructed from syntax and data dependencies and
    should be checked against the author's original layout.
    """
    log_to_tensorboard = args.tensorboard and (not args.debug)
    writer = None
    if log_to_tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('{}{}'.format(args.event_path, args.prefix+args.hp_str))

    # optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # Resume: the states file holds [iteration, optimizer state_dict], the
    # same layout written by the back-up checkpoint code below.
    # torch.load unpickles its input, so only trusted checkpoints belong here.
    offset = 0
    if (args.load_from is not None) and args.resume:
        with torch.cuda.device(args.gpu):   # very important.
            offset, opt_states = torch.load(
                os.path.join(args.model_path, args.load_from + '.pt.states'),
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)

    # metrics
    if save_path is None:
        save_path = args.model_name

    best = Best(max, *['BLEU_dec{}'.format(ii+1) for ii in range(args.valid_repeat_dec)],
                'i', model=model, opt=opt, path=save_path, gpu=args.gpu,
                which=range(args.valid_repeat_dec))
    train_metrics = Metrics('train loss', *['loss_{}'.format(idx+1) for idx in range(args.train_repeat_dec)], data_type="avg")
    dev_metrics = Metrics('dev loss', *['loss_{}'.format(idx+1) for idx in range(args.valid_repeat_dec)], data_type="avg")

    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    if maxsteps is None:
        maxsteps = args.maximum_steps

    def get_learning_rate(step, disable=False):
        # Step decay: divide args.lr by 5 every 50000 steps, floored at 3e-5.
        if disable:
            return args.lr
        return max(0.00003, args.lr / math.pow(5, math.floor(step / 50000)))

    for iters, batch in enumerate(train):

        # Continue the iteration count from the resumed checkpoint.
        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}.pt'.format(args.model_name))
                torch.save([iters, best.opt.state_dict()], '{}.pt.states'.format(args.model_name))

        # --- validation --- #
        if iters % args.eval_every == 0:
            dev_metrics.reset()
            outputs_data = valid_model(args, model, dev, dev_metrics, teacher_model=None, print_out=True)

            if log_to_tensorboard:
                for ii in range(args.valid_repeat_dec):
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1), getattr(dev_metrics, "loss_{}".format(ii+1)), iters)
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1), outputs_data['bleu'][ii], iters)
                writer.add_scalars(
                    'dev/multi/BLEUs',
                    {"iter_{}".format(idx+1): bleu for idx, bleu in enumerate(outputs_data['bleu'])},
                    iters)
                writer.add_scalars(
                    'dev/multi/Losses',
                    {"iter_{}".format(idx+1): getattr(dev_metrics, "loss_{}".format(idx+1))
                     for idx in range(args.valid_repeat_dec)},
                    iters)

            if not args.debug:
                best.accumulate(*outputs_data['bleu'], iters)
                # best.metrics holds one BLEU entry per decoding pass followed by 'i'.
                values = list(best.metrics.values())
                bleu_summary = "BLEU=[{}]".format(", ".join(str(x) for x in values[:args.valid_repeat_dec]))
                args.logger.info("best model : {}, {}".format(bleu_summary, "i={}".format(values[args.valid_repeat_dec])))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        opt.param_groups[0]['lr'] = get_learning_rate(iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(batch)

        # Exact-type checks are kept on purpose: whether FastTransformer
        # derives from Transformer is not visible here, and isinstance()
        # could then route it into the wrong branch.
        if type(model) is Transformer:
            out = model(encoding, source_masks, inputs, input_masks)
            loss = model.cost(targets, target_masks, out)
            # Expose the single loss as a list so that the metric
            # accumulation and the progress message below work on this path
            # too (previously `losses` was undefined here -> NameError).
            losses = [loss]
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks)

            losses = []
            for iter_ in range(args.train_repeat_dec):
                # Passes beyond the number of distinct decoders reuse the last one.
                curr_iter = min(iter_, args.num_shared_dec-1)
                next_iter = min(curr_iter + 1, args.num_shared_dec-1)

                out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter)
                losses.append(model.cost(targets, target_masks, out=out, iter_=curr_iter))

                # Pick the tokens fed to the next pass: greedy, or sampled
                # from the detached softmax distribution.
                logits = model.decoder[curr_iter].out(out)
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    logits = softmax(logits)
                    logits_sz = logits.size()
                    logits_ = Variable(logits.data, requires_grad=False)
                    argmax = torch.multinomial(logits_.contiguous().view(-1, logits_sz[-1]), 1)\
                        .view(*logits_sz[:-1])

                # The next decoder's output projection weight doubles as the
                # embedding table, scaled by sqrt(d_model).
                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                # Optional auxiliary loss, applied on every pass or (when
                # diff_loss_dec1 is set) on the first pass only.
                if args.diff_loss_w > 0 and ((not args.diff_loss_dec1) or iter_ == 0):
                    # L2-normalise along the last dim, then penalise positive
                    # dot products between consecutive slices along dim 1
                    # (presumably adjacent target positions -- TODO confirm).
                    out_norm = out.div(out.norm(p=2, dim=-1, keepdim=True))
                    diff_loss = torch.mean((out_norm[:, 1:, :] * out_norm[:, :-1, :]).sum(-1).clamp(min=0)) * args.diff_loss_w
                    # NOTE(review): appended right after this pass's
                    # likelihood loss, so `losses` interleaves both kinds.
                    losses.append(diff_loss)

            loss = sum(losses)
        else:
            raise NotImplementedError('unsupported model type: {}'.format(type(model).__name__))

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, *losses, print_iter=None)

        # train the student
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        info = 'training step={}, loss={}, lr={:.5f}'.format(
            iters,
            "/".join(["{:.3f}".format(export(ll)) for ll in losses]),
            opt.param_groups[0]['lr'])

        if iters % args.eval_every == 0 and log_to_tensorboard:
            # Slicing keeps the original "first train_repeat_dec entries"
            # behaviour without overrunning when fewer losses exist.
            for idx, step_loss in enumerate(losses[:args.train_repeat_dec]):
                writer.add_scalar('train/single/Loss_{}'.format(idx+1), export(step_loss), iters)

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info(train_metrics)
        else:
            progressbar.update(1)
            progressbar.set_description(info)
        train_metrics.reset()

    # Release the progress bar and flush pending TensorBoard events.
    if not args.no_tqdm:
        progressbar.close()
    if writer is not None:
        writer.close()