bs_inner) if use_prog_bar: progressbar.update(1) progressbar.set_description(info) if use_prog_bar: progressbar.close() return model.save_fast_weights() # training start.. best = Best(max, 'corpus_bleu', 'i', model=model, opt=meta_opt, path=args.model_name, gpu=args.gpu) train_metrics = Metrics('train', 'loss', 'real', 'fake') dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'fertility_loss', 'corpus_gleu') # overlall progress-ba progressbar = tqdm(total=args.eval_every, desc='start training') while True: # ----- saving the checkpoint ----- # if iters % args.save_every == 0:
def train_model(args, model, train, dev, teacher_model=None, save_path=None, maxsteps=None):
    """Train ``model`` on ``train``, periodically validating on ``dev``.

    Without ``args.finetuning`` this is plain maximum-likelihood training.
    With ``args.finetuning`` the student loss is mixed with a teacher loss
    computed by ``teacher_model`` on the student's own decodings
    (``(1 - args.beta1) * loss_teacher + args.beta1 * loss_student``). When
    ``args.fertility`` is also set, a REINFORCE reward derived from teacher
    scores of a 'mean' (baseline) and a 'reinforce' decoding is attached to
    ``model.predictor.saved_fertilities``.

    Args:
        args: experiment configuration namespace (optimizer, LR schedule,
            logging, checkpointing and loss-mixing options are read from it).
        model: the model being trained; ``FastTransformer`` instances get
            their decoder inputs from ``model.prepare_initial``.
        train: iterable of training batches.
        dev: validation data handed to ``valid_model``.
        teacher_model: scorer for student decodings; required when
            ``args.finetuning`` is set.
        save_path: where ``Best`` stores the best checkpoint; defaults to
            ``args.model_name``.
        maxsteps: stop once the (offset) iteration count exceeds this;
            defaults to ``args.maximum_steps``.
    """
    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('./runs/{}'.format(args.prefix + args.hp_str))

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                './models/' + args.load_from + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name
    # The step limit does not change during training; resolve it once.
    if maxsteps is None:
        maxsteps = args.maximum_steps

    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i',
                model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss',
                          'distance', 'alter_loss', 'distance2',
                          'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')

    def get_learning_rate(i, lr0=0.1, disable=False):
        # Linear warm-up for args.warmup steps, then inverse-square-root
        # decay, scaled by 1/sqrt(d_model); a small fixed rate if disabled.
        if not disable:
            return lr0 * 10 / math.sqrt(args.d_model) * min(
                1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
        return 0.00002

    for iters, batch in enumerate(train):
        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(args.model_name, iters))

        # --- validation --- #
        if iters % args.eval_every == 0:
            progressbar.close()
            dev_metrics.reset()

            if args.distillation:
                outputs_course = valid_model(args, model, dev, dev_metrics,
                                             distillation=True, teacher_model=None)

            outputs_data = valid_model(
                args, model, dev, None if args.distillation else dev_metrics,
                teacher_model=None, print_out=True)
            if args.tensorboard and (not args.debug):
                writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters)
                writer.add_scalar('dev/Loss', dev_metrics.loss, iters)
                writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters)
                writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters)
                if args.distillation:
                    writer.add_scalar('dev/GLEU_corpus_dis', outputs_course['corpus_gleu'], iters)
                    writer.add_scalar('dev/BLEU_corpus_dis', outputs_course['corpus_bleu'], iters)

            if not args.debug:
                best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'],
                                dev_metrics.gleu, dev_metrics.loss, iters)
                args.logger.info(
                    'the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'
                    .format(best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        opt.param_groups[0]['lr'] = get_learning_rate(
            iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        (inputs, input_masks,
         targets, target_masks,
         sources, source_masks,
         encoding, batch_size) = model.quick_prepare(batch, args.distillation)

        fertility_cost = None
        if type(model) is FastTransformer:
            # Fertility supervision is only looked up for the
            # non-autoregressive model, which is the only one that uses it.
            batch_fer = batch.fer_dec if args.distillation else batch.fer
            inputs, _, input_masks, fertility_cost = model.prepare_initial(
                encoding, sources, source_masks, input_masks, batch_fer)

        # Maximum Likelihood Training
        if not args.finetuning:
            loss = model.cost(targets, target_masks,
                              out=model(encoding, source_masks, inputs, input_masks))
            if args.fertility:
                loss += fertility_cost

        else:
            # finetuning:
            # loss_student (MLE)
            if not args.fertility:
                decoding, out, probs = model(encoding, source_masks, inputs, input_masks,
                                             return_probs=True, decoding=True)
                loss_student = model.batched_cost(targets, target_masks, probs)  # student-loss (MLE)
                decoder_masks = input_masks
            else:
                # Note that MLE and decoding has different translations. We need to run the same code twice
                # truth
                decoding, out, probs = model(encoding, source_masks, inputs, input_masks,
                                             decoding=True, return_probs=True)
                loss_student = model.cost(targets, target_masks, out=out)
                decoder_masks = input_masks

                # baseline
                decoder_inputs_b, _, decoder_masks_b, _, _ = model.prepare_initial(
                    encoding, sources, source_masks, input_masks, None, mode='mean')
                decoding_b, out_b, probs_b = model(
                    encoding, source_masks, decoder_inputs_b, decoder_masks_b,
                    decoding=True, return_probs=True)  # decode again

                # reinforce
                decoder_inputs_r, _, decoder_masks_r, _, _ = model.prepare_initial(
                    encoding, sources, source_masks, input_masks, None, mode='reinforce')
                decoding_r, out_r, probs_r = model(
                    encoding, source_masks, decoder_inputs_r, decoder_masks_r,
                    decoding=True, return_probs=True)  # decode again

            if args.fertility:
                loss_student += fertility_cost

            # loss_teacher (RKL+REINFORCE): the teacher scores the student's
            # own decoding (shared by both fertility settings).
            teacher_model.eval()
            inputs_student_index, _, targets_student_soft, _, _, _, encoding_teacher, _ = model.quick_prepare(
                batch, False, decoding, probs, decoder_masks, decoder_masks, source_masks)
            out_teacher, probs_teacher = teacher_model(
                encoding_teacher, source_masks, inputs_student_index.detach(),
                decoder_masks, return_probs=True)
            loss_teacher = teacher_model.batched_cost(
                targets_student_soft, decoder_masks, probs_teacher.detach())

            if args.fertility:
                # REINFORCE signal for the fertility predictor: score the
                # baseline ('mean') and 'reinforce' decodings with the teacher
                # (second value of batched_cost(..., True), presumably a
                # per-sentence loss) and reward lower teacher loss.
                seq_losses = []
                for dec_x, probs_x, masks_x in ((decoding_b, probs_b, decoder_masks_b),
                                                (decoding_r, probs_r, decoder_masks_r)):
                    index_x, _ = model.prepare_inputs(batch, dec_x, False, masks_x)
                    soft_x, _ = model.prepare_targets(batch, probs_x, False, masks_x)
                    _, probs_teacher_x = teacher_model(
                        encoding_teacher, source_masks, index_x.detach(), masks_x,
                        return_probs=True)
                    _, seq_loss = teacher_model.batched_cost(
                        soft_x, masks_x, probs_teacher_x.detach(), True)
                    seq_losses.append(seq_loss)
                loss_1, loss_2 = seq_losses

                rewards = -(loss_2 - loss_1).data
                rewards = rewards - rewards.mean()  # subtract the mean reward
                rewards = rewards.expand_as(source_masks)
                rewards = rewards * source_masks
                model.predictor.saved_fertilities.reinforce(
                    0.1 * rewards.contiguous().view(-1, 1))

            loss = (1 - args.beta1) * loss_teacher + args.beta1 * loss_student

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        # train the student
        if args.finetuning and args.fertility:
            # saved_fertilities received a reward via .reinforce() above; it is
            # passed to backward with a None gradient (legacy stochastic
            # Variable API -- confirm against the installed PyTorch).
            torch.autograd.backward(
                (loss, model.predictor.saved_fertilities),
                (torch.ones(1).cuda(loss.get_device()), None))
        else:
            loss.backward()
        opt.step()

        info = 'training step={}, loss={:.3f}, lr={:.5f}'.format(
            iters, export(loss), opt.param_groups[0]['lr'])
        if args.finetuning:
            info += '| NA:{:.3f}, AR:{:.3f}'.format(export(loss_student), export(loss_teacher))
            if args.fertility:
                info += '| RL: {:.3f}'.format(export(rewards.mean()))
        if args.fertility:
            info += '| RE:{:.3f}'.format(export(fertility_cost))

        if args.tensorboard and (not args.debug):
            writer.add_scalar('train/Loss', export(loss), iters)

        progressbar.update(1)
        progressbar.set_description(info)
def train_model(args, model, train, dev, src=None, trg=None, trg_len_dic=None,
                teacher_model=None, save_path=None, maxsteps=None):
    """Train ``model`` with iterative-refinement decoding, validating on ``dev``.

    For a ``FastTransformer`` the decoder is unrolled ``args.train_repeat_dec``
    times per batch. Each pass contributes its own loss. The next pass is fed
    an embedding of the current prediction (optionally a corrupted reference
    or corrupted prediction for denoising) and/or the current decoder output,
    depending on ``args.next_dec_input``. Optional extras:
    self-distillation between refinement steps, and a target-length predictor
    (when ``args.trg_len_option`` contains "predict").

    Args:
        args: experiment configuration namespace.
        model: a ``Transformer`` or ``FastTransformer``.
        train: iterable of training batches.
        dev: validation data handed to ``valid_model``.
        src: unused here; kept for call compatibility.
        trg: target field; ``len(trg.vocab)`` is used as the vocabulary size
            for denoising corruption.
        trg_len_dic: forwarded to batch preparation and ``valid_model``.
        teacher_model: unused here; kept for call compatibility.
        save_path: unused here (checkpoints go under
            ``args.model_path / args.id_str``); kept for call compatibility.
        maxsteps: stop once the iteration count exceeds this; defaults to
            ``args.maximum_steps``.
    """
    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(str(args.event_path / args.id_str))

    is_fast = type(model) is FastTransformer

    if is_fast:
        # Per-refinement-step corruption weights. Built for every
        # FastTransformer run because output-side denoising and weight
        # annealing read them even when args.denoising_prob == 0.
        denoising_weights = [args.denoising_weight for _ in range(args.train_repeat_dec)]
        denoising_out_weights = [args.denoising_out_weight for _ in range(args.train_repeat_dec)]
        if args.layerwise_denoising_weight:
            # Decreasing, linearly spaced weights with 0.1 appended for the
            # final step.
            start, end = 0.9, 0.1
            diff = (start - end) / (args.train_repeat_dec - 1)
            denoising_weights = np.arange(start=end, stop=start, step=diff).tolist()[::-1] + [0.1]

    # optimizer
    # Freeze either everything except the target-length predictor
    # (finetune_trg_len) or the target-length predictor alone.
    # named_parameters() keeps each name paired with its tensor; zipping
    # state_dict().keys() with parameters() misaligns as soon as the model
    # has buffers or tied (shared) weights.
    for name, p in model.named_parameters():
        if args.finetune_trg_len:
            if "pred_len" not in name:
                p.requires_grad = False
        elif "pred_len" in name:
            p.requires_grad = False

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                str(args.model_path / args.load_from) + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    checkpoint_prefix = str(args.model_path / args.id_str)
    if not args.finetune_trg_len:
        best = Best(max, *['BLEU_dec{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)],
                    'i', model=model, opt=opt, path=checkpoint_prefix, gpu=args.gpu,
                    which=range(args.valid_repeat_dec))
    else:
        best = Best(max, 'pred_target_len_correct', 'i', model=model, opt=opt,
                    path=checkpoint_prefix, gpu=args.gpu, which=[0])
    train_metrics = Metrics(
        'train loss', *['loss_{}'.format(idx + 1) for idx in range(args.train_repeat_dec)],
        data_type="avg")
    dev_metrics = Metrics(
        'dev loss', *['loss_{}'.format(idx + 1) for idx in range(args.valid_repeat_dec)],
        data_type="avg")

    if "predict" in args.trg_len_option:
        train_metrics_trg = Metrics('train loss target', "pred_target_len_loss",
                                    "pred_target_len_correct", "pred_target_len_approx",
                                    data_type="avg")
        train_metrics_average = Metrics('train loss average', "average_target_len_correct",
                                        "average_target_len_approx", data_type="avg")
        dev_metrics_trg = Metrics('dev loss target', "pred_target_len_loss",
                                  "pred_target_len_correct", "pred_target_len_approx",
                                  data_type="avg")
        dev_metrics_average = Metrics('dev loss average', "average_target_len_correct",
                                      "average_target_len_approx", data_type="avg")
    else:
        train_metrics_trg = None
        train_metrics_average = None
        dev_metrics_trg = None
        dev_metrics_average = None

    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    if maxsteps is None:
        maxsteps = args.maximum_steps

    # Learning-rate schedules (loop-invariant, so defined once).
    def get_lr_transformer(i, lr0=0.1):
        # Linear warm-up, then inverse-square-root decay, scaled by 1/sqrt(d_model).
        return lr0 * 10 / math.sqrt(args.d_model) * min(
            1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))

    def get_lr_anneal(i):
        # Linear decay from args.lr to 1e-5 over args.anneal_steps, then flat.
        lr_end = 1e-5
        return max(0, (args.lr - lr_end) * (args.anneal_steps - i) / args.anneal_steps) + lr_end

    for iters, train_batch in enumerate(train):
        iters += offset

        # --- saving --- #
        if args.save_every > 0 and iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(checkpoint_prefix, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(checkpoint_prefix, iters))

        # --- validation --- #
        if iters % args.eval_every == 0:
            torch.cuda.empty_cache()
            gc.collect()
            dev_metrics.reset()
            if dev_metrics_trg is not None:
                dev_metrics_trg.reset()
            if dev_metrics_average is not None:
                dev_metrics_average.reset()
            outputs_data = valid_model(args, model, dev, dev_metrics,
                                       dev_metrics_trg=dev_metrics_trg,
                                       dev_metrics_average=dev_metrics_average,
                                       teacher_model=None, print_out=True,
                                       trg_len_dic=trg_len_dic)

            if args.tensorboard and (not args.debug):
                for ii in range(args.valid_repeat_dec):
                    # NLL averaged over dev corpus
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1),
                                      getattr(dev_metrics, "loss_{}".format(ii + 1)), iters)
                    # NOTE corpus bleu
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1),
                                      outputs_data['real'][ii][0], iters)

                if "predict" in args.trg_len_option:
                    for key in ("pred_target_len_loss", "pred_target_len_correct",
                                "pred_target_len_approx", "average_target_len_correct",
                                "average_target_len_approx"):
                        writer.add_scalar("dev/single/" + key, outputs_data[key], iters)

            if not args.debug:
                if not args.finetune_trg_len:
                    best.accumulate(*[xx[0] for xx in outputs_data['real']], iters)
                    values = list(best.metrics.values())
                    args.logger.info("best model : {}, {}".format(
                        "BLEU=[{}]".format(", ".join(str(x) for x in values[:args.valid_repeat_dec])),
                        "i={}".format(values[args.valid_repeat_dec])))
                else:
                    best.accumulate(outputs_data['pred_target_len_correct'], iters)
                    values = list(best.metrics.values())
                    args.logger.info("best model : {}".format(
                        "pred_target_len_correct = {}".format(values[0])))

            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every, desc='start training.')

            if is_fast and args.anneal_denoising_weight:
                # Lower each step's input-corruption weight by 0.1 for every
                # 3 points of that step's dev score (the corpus BLEU logged above).
                for ii, bb in enumerate([xx[0] for xx in outputs_data['real']][:-1]):
                    denoising_weights[ii] = 0.9 - 0.1 * int(math.floor(bb / 3.0))

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break

        # --- training --- #
        model.train()

        if args.lr_schedule == "fixed":
            opt.param_groups[0]['lr'] = args.lr
        elif args.lr_schedule == "anneal":
            opt.param_groups[0]['lr'] = get_lr_anneal(iters + 1)
        elif args.lr_schedule == "transformer":
            opt.param_groups[0]['lr'] = get_lr_transformer(iters + 1)
        opt.zero_grad()

        if args.dataset == "mscoco":
            (decoder_inputs, decoder_masks,
             targets, target_masks,
             _, source_masks,
             encoding, batch_size, rest) = model.quick_prepare_mscoco(
                 train_batch, all_captions=train_batch[1], fast=is_fast,
                 inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option,
                 max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            (decoder_inputs, decoder_masks,
             targets, target_masks,
             _, source_masks,
             encoding, batch_size, rest) = model.quick_prepare(
                 train_batch, fast=is_fast,
                 trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio,
                 trg_len_dic=trg_len_dic, bp=args.bp)

        losses = []
        if type(model) is Transformer:
            loss = model.cost(targets, target_masks,
                              out=model(encoding, source_masks, decoder_inputs, decoder_masks))
            losses.append(loss)

        elif is_fast:
            all_logits = []
            all_denoising_masks = []
            for iter_ in range(args.train_repeat_dec):
                # Refinement steps beyond the number of decoders reuse the last one.
                curr_iter = min(iter_, args.num_decs - 1)
                next_iter = min(curr_iter + 1, args.num_decs - 1)

                out = model(encoding, source_masks, decoder_inputs, decoder_masks,
                            iter_=curr_iter, return_probs=False)

                if args.self_distil > 0.0:
                    loss, logits_masked = model.cost(targets, target_masks, out=out,
                                                     iter_=curr_iter, return_logits=True)
                else:
                    loss = model.cost(targets, target_masks, out=out, iter_=curr_iter)

                logits = model.decoder[curr_iter].out(out)
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    # Sample next-step tokens from the softmax instead of taking the argmax.
                    probs = softmax(logits)
                    probs_sz = probs.size()
                    logits_ = Variable(probs.data, requires_grad=False)
                    argmax = torch.multinomial(
                        logits_.contiguous().view(-1, probs_sz[-1]), 1).view(*probs_sz[:-1])

                if args.self_distil > 0.0:
                    all_logits.append(logits_masked)

                losses.append(loss)

                # Build the input for the next refinement step.
                decoder_inputs_ = 0
                denoising_mask = 1
                if args.next_dec_input in ["both", "emb"]:
                    # The next decoder's output projection doubles as its input embedding.
                    emb_weight = model.decoder[next_iter].out.weight * math.sqrt(args.d_model)
                    if args.denoising_prob > 0.0 and np.random.rand() < args.denoising_prob:
                        # Feed a corrupted copy of the reference instead of the prediction.
                        cor = corrupt_target(targets, decoder_masks, len(trg.vocab),
                                             denoising_weights[iter_], args.corruption_probs)
                        emb = F.embedding(cor, emb_weight)
                        denoising_mask = 0
                    else:
                        emb = F.embedding(argmax, emb_weight)

                    if args.denoising_out_weight > 0:
                        if denoising_out_weights[iter_] > 0.0:
                            # Same argument layout as the reference corruption above:
                            # vocabulary size, then weight, then corruption probabilities.
                            corrupted_argmax = corrupt_target(
                                argmax, decoder_masks, len(trg.vocab),
                                denoising_out_weights[iter_], args.corruption_probs)
                        else:
                            corrupted_argmax = argmax
                        emb = F.embedding(corrupted_argmax, emb_weight)
                    decoder_inputs_ += emb
                all_denoising_masks.append(denoising_mask)

                if args.next_dec_input in ["both", "out"]:
                    decoder_inputs_ += out
                decoder_inputs = decoder_inputs_

            # self distillation loss if requested
            if args.self_distil > 0.0:
                self_distil_losses = []
                for logits_i in range(1, len(all_logits) - 1):
                    self_distill_loss_i = 0.0
                    for logits_j in range(logits_i + 1, len(all_logits)):
                        # MSE between step i and a later (detached) step j,
                        # down-weighted by 1/(j - i); pairs whose denoising
                        # mask is 0 contribute nothing.
                        self_distill_loss_i += (
                            all_denoising_masks[logits_j]
                            * all_denoising_masks[logits_i]
                            * (1 / (logits_j - logits_i))
                            * args.self_distil
                            * F.mse_loss(all_logits[logits_i], all_logits[logits_j].detach()))
                    self_distil_losses.append(self_distill_loss_i)
                self_distil_loss = sum(self_distil_losses)

        loss = sum(losses)

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, *losses, print_iter=None)
        if train_metrics_trg is not None:
            train_metrics_trg.accumulate(batch_size, rest[0], rest[1], rest[2])
        if train_metrics_average is not None:
            train_metrics_average.accumulate(batch_size, rest[3], rest[4])

        if is_fast and args.self_distil > 0.0:
            (loss + self_distil_loss).backward()
        elif "predict" in args.trg_len_option and args.finetune_trg_len:
            # Only the target-length predictor is trainable in this mode.
            rest[0].backward()
        else:
            loss.backward()

        if args.grad_clip > 0:
            nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        mid_str = ''
        if is_fast and args.self_distil > 0.0:
            mid_str += 'distil={:.5f}, '.format(self_distil_loss.cpu().data.numpy()[0])
        if is_fast and "predict" in args.trg_len_option:
            mid_str += 'pred_target_len_loss={:.5f}, '.format(rest[0].cpu().data.numpy()[0])
        if is_fast and args.denoising_prob > 0.0:
            mid_str += "/".join(["{:.1f}".format(ff) for ff in denoising_weights[:-1]]) + ", "

        info = 'update={}, loss={}, {}lr={:.1e}'.format(
            iters, "/".join(["{:.3f}".format(export(ll)) for ll in losses]),
            mid_str, opt.param_groups[0]['lr'])

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(iters, str(train_metrics)))
        else:
            progressbar.update(1)
            progressbar.set_description(info)

        if iters % args.eval_every == 0:
            if args.tensorboard and (not args.debug):
                for idx in range(args.train_repeat_dec):
                    writer.add_scalar('train/single/Loss_{}'.format(idx + 1),
                                      getattr(train_metrics, "loss_{}".format(idx + 1)), iters)
                if "predict" in args.trg_len_option:
                    for key in ("pred_target_len_loss", "pred_target_len_correct",
                                "pred_target_len_approx"):
                        writer.add_scalar("train/single/" + key,
                                          getattr(train_metrics_trg, key), iters)
                    for key in ("average_target_len_correct", "average_target_len_approx"):
                        writer.add_scalar("train/single/" + key,
                                          getattr(train_metrics_average, key), iters)

            # Start a fresh averaging window every eval_every updates, whether
            # or not TensorBoard logging is enabled.
            train_metrics.reset()
            if train_metrics_trg is not None:
                train_metrics_trg.reset()
            if train_metrics_average is not None:
                train_metrics_average.reset()
# ----- meta-validation ----- # dev_iters = iters weights = model.save_fast_weights() self_opt = torch.optim.Adam([ p for p in model.get_parameters(type=args.finetune_params) if p.requires_grad ], betas=(0.9, 0.98), eps=1e-9) corpus_bleu = -1 # training start.. best = Best(max, 'corpus_bleu', 'i', model=model, opt=self_opt, path=args.model_name, gpu=args.gpu) dev_metrics = Metrics('dev', 'loss', 'gleu') outputs_data = valid_model(args, model, dev_real, dev_metrics, print_out=False) corpus_bleu0 = outputs_data['corpus_bleu'] fast_weights = [(weights, corpus_bleu0)] if args.tensorboard and (not args.debug): writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'],
def train_model(args, model, train, dev, save_path=None, maxsteps=None):
    """Train ``model`` by maximum likelihood, validating on ``dev`` periodically.

    Args:
        args: experiment configuration namespace (optimizer, LR schedule,
            logging and checkpointing options are read from it).
        model: the model being trained; ``FastTransformer`` instances get
            their decoder inputs and a fertility cost from
            ``model.prepare_initial``.
        train: iterable of training batches.
        dev: validation data handed to ``valid_model``.
        save_path: where ``Best`` stores the best checkpoint; defaults to
            ``args.model_name``.
        maxsteps: stop once the (offset) iteration count exceeds this;
            defaults to ``args.maximum_steps``.
    """
    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('./runs/{}'.format(args.prefix + args.hp_str))

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                './models/' + args.load_from + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name
    # The step limit does not change during training; resolve it once.
    if maxsteps is None:
        maxsteps = args.maximum_steps
    # ``fertility`` may be missing from older configurations.
    use_fertility = getattr(args, 'fertility', False)

    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i',
                model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss',
                          'distance', 'alter_loss', 'distance2',
                          'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')

    def get_learning_rate(i, lr0=0.1, disable=False):
        # Linear warm-up for args.warmup steps, then inverse-square-root
        # decay, scaled by 1/sqrt(d_model); a small fixed rate if disabled.
        if not disable:
            return lr0 * 10 / math.sqrt(args.d_model) * min(
                1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
        return 0.00002

    for iters, batch in enumerate(train):
        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(args.model_name, iters))

        # --- validation --- #
        if iters % args.eval_every == 0:
            progressbar.close()
            dev_metrics.reset()

            if args.distillation:
                outputs_course = valid_model(args, model, dev, dev_metrics,
                                             distillation=True)

            outputs_data = valid_model(
                args, model, dev, None if args.distillation else dev_metrics,
                print_out=True)
            if args.tensorboard and (not args.debug):
                writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters)
                writer.add_scalar('dev/Loss', dev_metrics.loss, iters)
                writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters)
                writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters)
                if args.distillation:
                    writer.add_scalar('dev/GLEU_corpus_dis', outputs_course['corpus_gleu'], iters)
                    writer.add_scalar('dev/BLEU_corpus_dis', outputs_course['corpus_bleu'], iters)

            if not args.debug:
                best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'],
                                dev_metrics.gleu, dev_metrics.loss, iters)
                args.logger.info(
                    'the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'
                    .format(best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        opt.param_groups[0]['lr'] = get_learning_rate(
            iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        (inputs, input_masks,
         targets, target_masks,
         sources, source_masks,
         encoding, batch_size) = model.quick_prepare(batch, args.distillation)

        fertility_cost = None
        if type(model) is FastTransformer:
            batch_fer = batch.fer_dec if args.distillation else batch.fer
            inputs, _, input_masks, fertility_cost = model.prepare_initial(
                encoding, sources, source_masks, input_masks, batch_fer)

        # Maximum Likelihood Training
        loss = model.cost(targets, target_masks,
                          out=model(encoding, source_masks, inputs, input_masks))
        # Only a model that produced a fertility cost can add (or report) it;
        # otherwise ``loss += None`` would raise.
        add_fertility = use_fertility and fertility_cost is not None
        if add_fertility:
            loss += fertility_cost

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()
        opt.step()

        info = 'training step={}, loss={:.3f}, lr={:.5f}'.format(
            iters, export(loss), opt.param_groups[0]['lr'])
        if add_fertility:
            info += '| RE:{:.3f}'.format(export(fertility_cost))

        if args.tensorboard and (not args.debug):
            writer.add_scalar('train/Loss', export(loss), iters)

        progressbar.update(1)
        progressbar.set_description(info)
def train_model(args, model, train, dev, save_path=None, maxsteps=None, writer=None):
    """Train ``model`` with gradient accumulation over ``args.inter_size`` batches.

    Each batch's loss is divided by ``args.inter_size`` and back-propagated.
    The optimizer steps once every ``args.inter_size`` batches, and the
    learning rate is set at the start of each accumulation window.
    Validation runs on the first batch and then either:

    - every ``args.eval_every`` batches (when ``args.eval_every_examples == -1``), or
    - whenever the running example count exceeds ``args.eval_every_examples``
      (when that is > 0); the count is then reduced modulo that value.

    Note: ``args.eval_every`` is multiplied in place by ``args.inter_size``,
    so the caller's ``args`` is modified.

    Args:
        args: experiment configuration namespace.
        model: the model being trained.
        train: iterable of training batches.
        dev: validation data handed to ``valid_model`` (which iterates it).
        save_path: where ``Best`` stores the best checkpoint; defaults to
            ``args.model_name``.
        maxsteps: stop once the batch counter (not the optimizer-step count)
            exceeds this; defaults to ``args.maximum_steps``.
        writer: object with ``add_scalar`` used for logging when
            ``args.tensorboard`` is set and ``args.debug`` is not; logging is
            skipped when it is None.
    """
    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad],
                               betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(args.models_dir + '/' + args.load_from + '.pt.states',
                                            map_location=lambda storage, loc: storage.cuda())
            if not args.finetune:  # if finetune, do not have history
                opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name
    # The step limit does not change during training; resolve it once.
    if maxsteps is None:
        maxsteps = args.maximum_steps

    args.eval_every *= args.inter_size

    log_to_tb = args.tensorboard and (not args.debug) and writer is not None

    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i',
                model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss',
                          'distance', 'alter_loss', 'distance2',
                          'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')
    examples = 0
    first_step = True
    loss_outer = 0

    def get_learning_rate(i, lr0=0.1, disable=False):
        # Linear warm-up for args.warmup steps, then inverse-square-root
        # decay, scaled by 1/sqrt(d_model); a small fixed rate if disabled.
        if not disable:
            return lr0 * 10 / math.sqrt(args.d_model) * min(
                1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
        return 0.00002

    for iters, batch in enumerate(train):
        iters += offset
        # Optimizer-step position used for the LR schedule and for logging.
        step = iters / args.inter_size

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(args.model_name, iters))

        # --- validation --- #
        eval_now = (
            (args.eval_every_examples == -1 and iters % args.eval_every == 0)
            or (args.eval_every_examples > 0 and examples > args.eval_every_examples)
            or first_step)
        if eval_now:
            first_step = False
            if args.eval_every_examples > 0:
                examples = examples % args.eval_every_examples

            # valid_model walks the whole dev set itself, so it is called once
            # per validation, not once per dev batch.
            progressbar.close()
            dev_metrics.reset()

            if args.distillation:
                # Return value unused here; in this mode the call below passes
                # None, so presumably this call is what fills dev_metrics.
                valid_model(args, model, dev, dev_metrics, distillation=True)

            outputs_data = valid_model(args, model, dev,
                                       None if args.distillation else dev_metrics,
                                       print_out=True)
            if log_to_tb:
                writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, step)
                writer.add_scalar('dev/Loss', dev_metrics.loss, step)
                writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], step)
                writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], step)

            if not args.debug:
                best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'],
                                dev_metrics.gleu, dev_metrics.loss, step)
                args.logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format(
                    best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()

        if iters % args.inter_size == 0:
            # First batch of an accumulation window: set the LR for the
            # upcoming update and clear the gradients.
            opt.param_groups[0]['lr'] = get_learning_rate(step + 1, disable=args.disable_lr_schedule)
            opt.zero_grad()
            loss_outer = 0

        # prepare the data
        (inputs, input_masks,
         targets, target_masks,
         _, source_masks,
         encoding, batch_size) = model.quick_prepare(batch, args.distillation)
        examples += batch_size

        # Maximum Likelihood Training (scaled so the accumulated gradient is
        # an average over the window)
        loss = model.cost(targets, target_masks,
                          out=model(encoding, source_masks, inputs, input_masks)) / args.inter_size
        loss_outer = loss_outer + loss

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()

        if iters % args.inter_size == (args.inter_size - 1):
            # Last batch of the window: apply the accumulated update.
            if args.universal_options == 'no_update_encdec':
                # Zero every gradient except encoder.uni_out.weight's before stepping.
                for p in model.parameters():
                    if p is not model.encoder.uni_out.weight and p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()
            opt.step()

        # Reported every batch; loss_outer is the loss accumulated so far in
        # the current window.
        info = 'training step={}, loss={:.3f}, lr={:.8f}'.format(
            step, export(loss_outer), opt.param_groups[0]['lr'])
        if log_to_tb:
            writer.add_scalar('train/Loss', export(loss_outer), step)

        progressbar.update(1)
        progressbar.set_description(info)
def train_model(args, model, train, dev, src, trg, teacher_model=None, save_path=None, maxsteps=None):
    """Train ``model`` on ``train``, periodically checkpointing and validating on ``dev``.

    Every ``args.save_every`` iterations a back-up checkpoint is written
    (``<model_name>.pt`` with the weights and ``<model_name>.pt.states`` with
    ``[iters, optimizer state]``).  Every ``args.eval_every`` iterations
    ``valid_model`` is run on ``dev``, its per-refinement-step BLEU scores are
    fed to ``Best`` and, optionally, TensorBoard scalars are written.
    Training stops once the iteration counter exceeds ``maxsteps`` or
    ``train`` is exhausted.

    Args:
        args: experiment configuration namespace (optimizer, LR schedule,
            logging and decoding options are all read from it).
        model: a ``Transformer`` or ``FastTransformer`` instance.
        train: iterable of training batches accepted by ``model.quick_prepare``.
        dev: validation data forwarded to ``valid_model``.
        src, trg: not used by this function; kept for call compatibility.
        teacher_model: not used by this function; kept for call compatibility.
        save_path: path handed to ``Best``; defaults to ``args.model_name``.
        maxsteps: training stops after the iteration counter exceeds this;
            defaults to ``args.maximum_steps``.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not ``'Adam'`` or the
            model is neither a ``Transformer`` nor a ``FastTransformer``.
    """
    use_tensorboard = args.tensorboard and (not args.debug)
    if use_tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('{}{}'.format(args.event_path, args.prefix + args.hp_str))

    # optimizer: only trainable parameters are optimized (and gradient-clipped).
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training: restore the iteration offset and the optimizer state.
    if (args.load_from is not None) and args.resume:
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                os.path.join(args.model_path, args.load_from + '.pt.states'),
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    if save_path is None:
        save_path = args.model_name
    if maxsteps is None:
        maxsteps = args.maximum_steps

    # metrics: one BLEU / loss slot per refinement step.
    best = Best(max, *['BLEU_dec{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)],
                'i', model=model, opt=opt, path=save_path, gpu=args.gpu,
                which=range(args.valid_repeat_dec))
    train_metrics = Metrics('train loss',
                            *['loss_{}'.format(idx + 1) for idx in range(args.train_repeat_dec)],
                            data_type="avg")
    dev_metrics = Metrics('dev loss',
                          *['loss_{}'.format(idx + 1) for idx in range(args.valid_repeat_dec)],
                          data_type="avg")

    def get_learning_rate(i, disable=False):
        # Step decay: divide args.lr by 5 every 50k iterations, floored at 3e-5.
        if not disable:
            return max(0.00003, args.lr / math.pow(5, math.floor(i / 50000)))
        return args.lr

    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    for iters, batch in enumerate(train):
        iters += offset

        # ----- back-up checkpoint ----- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}.pt'.format(args.model_name))
                torch.save([iters, best.opt.state_dict()], '{}.pt.states'.format(args.model_name))

        # ----- validation ----- #
        if iters % args.eval_every == 0:
            dev_metrics.reset()
            outputs_data = valid_model(args, model, dev, dev_metrics, teacher_model=None, print_out=True)

            if use_tensorboard:
                for ii in range(args.valid_repeat_dec):
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1),
                                      getattr(dev_metrics, "loss_{}".format(ii + 1)), iters)
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1),
                                      outputs_data['bleu'][ii], iters)
                writer.add_scalars('dev/multi/BLEUs',
                                   {"iter_{}".format(idx + 1): bleu
                                    for idx, bleu in enumerate(outputs_data['bleu'])},
                                   iters)
                writer.add_scalars('dev/multi/Losses',
                                   {"iter_{}".format(idx + 1): getattr(dev_metrics, "loss_{}".format(idx + 1))
                                    for idx in range(args.valid_repeat_dec)},
                                   iters)

            if not args.debug:
                best.accumulate(*outputs_data['bleu'], iters)
                # presumably ordered as passed to Best(...): the BLEU slots, then 'i'.
                values = list(best.metrics.values())
                args.logger.info("best model : {}, {}".format(
                    "BLEU=[{}]".format(", ".join([str(x) for x in values[:args.valid_repeat_dec]])),
                    "i={}".format(values[args.valid_repeat_dec])))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        opt.param_groups[0]['lr'] = get_learning_rate(iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
            targets, target_masks, \
            sources, source_masks, \
            encoding, batch_size = model.quick_prepare(batch)

        # `losses` holds exactly one decoder loss per refinement step; these are
        # what train_metrics / TensorBoard label loss_1..loss_N.  Auxiliary
        # penalties are kept apart in `aux_losses` so they cannot shift labels.
        aux_losses = []
        # exact type checks (not isinstance) are kept on purpose: the two model
        # classes must not be confused with each other.
        if type(model) is Transformer:
            out = model(encoding, source_masks, inputs, input_masks)
            losses = [model.cost(targets, target_masks, out)]
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks)
            losses = []
            for iter_ in range(args.train_repeat_dec):
                # steps past num_shared_dec - 1 reuse the last decoder index
                curr_iter = min(iter_, args.num_shared_dec - 1)
                next_iter = min(curr_iter + 1, args.num_shared_dec - 1)

                out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter)
                losses.append(model.cost(targets, target_masks, out=out, iter_=curr_iter))

                # tokens for the next refinement step: greedy argmax or a sample
                logits = model.decoder[curr_iter].out(out)
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    probs = softmax(logits)
                    probs_sz = probs.size()
                    probs_ = Variable(probs.data, requires_grad=False)
                    argmax = torch.multinomial(probs_.contiguous().view(-1, probs_sz[-1]), 1) \
                        .view(*probs_sz[:-1])

                # re-embed with the next decoder's (scaled) output projection
                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                             math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                # diff_loss_dec1 is an argparse flag (bool): when set, the
                # penalty is only applied on the first refinement step.
                if args.diff_loss_w > 0 and (not args.diff_loss_dec1 or iter_ == 0):
                    # penalize positive cosine similarity between neighbouring
                    # positions along dim 1 of `out` (indexing fixes rank 3).
                    out_norm = out.div(out.norm(p=2, dim=-1, keepdim=True))
                    aux_losses.append(
                        torch.mean((out_norm[:, 1:, :] * out_norm[:, :-1, :]).sum(-1).clamp(min=0))
                        * args.diff_loss_w)
        else:
            raise NotImplementedError(
                'unsupported model type: {}'.format(type(model).__name__))

        loss = sum(losses + aux_losses)

        # accumulate the training metrics (per-step decoder losses only)
        train_metrics.accumulate(batch_size, *losses, print_iter=None)

        # train the student
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        if iters % args.eval_every == 0 and use_tensorboard:
            for idx, step_loss in enumerate(losses):
                writer.add_scalar('train/single/Loss_{}'.format(idx + 1), export(step_loss), iters)

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info(train_metrics)
        else:
            info = 'training step={}, loss={}, lr={:.5f}'.format(
                iters,
                "/".join(["{:.3f}".format(export(ll)) for ll in losses + aux_losses]),
                opt.param_groups[0]['lr'])
            progressbar.update(1)
            progressbar.set_description(info)
        train_metrics.reset()

    # release the progress bar and flush pending TensorBoard events
    if not args.no_tqdm:
        progressbar.close()
    if use_tensorboard:
        writer.close()