Exemplo n.º 1
0
def train_model(args, model, train, dev, save_path=None, maxsteps=None, writer=None):
    """Run the training loop over ``train`` with periodic checkpointing and validation.

    Gradients are accumulated over ``args.inter_size`` consecutive batches:
    they are zeroed when ``iters % args.inter_size == 0``, every batch loss is
    divided by ``args.inter_size``, and ``opt.step()`` runs on the last batch
    of each group.

    Args:
        args: run configuration. NOTE: ``args.eval_every`` is multiplied in
            place by ``args.inter_size``, so the caller's object is mutated.
        model: the model to train; must provide ``parameters()``,
            ``quick_prepare()``, ``cost()``, ``train()`` and be callable.
        train: iterable yielding training batches.
        dev: validation data, handed to ``valid_model``.
        save_path: forwarded to ``Best`` as ``path``; defaults to
            ``args.model_name``.
        maxsteps: the loop stops once the (offset) iteration counter exceeds
            this value; defaults to ``args.maximum_steps``.
        writer: optional summary writer exposing ``add_scalar``. When it is
            ``None`` scalar logging is skipped.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not ``'Adam'``.
    """

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):   # very important.
            offset, opt_states = torch.load(args.models_dir + '/' + args.load_from + '.pt.states',
                                            map_location=lambda storage, loc: storage.cuda())
            if not args.finetune:  # if finetune, do not have history
                opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    # Resolved once up front; the value cannot change while the loop runs.
    if maxsteps is None:
        maxsteps = args.maximum_steps

    args.eval_every *= args.inter_size

    # `writer` defaults to None, so only log when one was actually supplied.
    log_to_tensorboard = args.tensorboard and (not args.debug) and (writer is not None)

    def get_learning_rate(i, lr0=0.1, disable=False):
        # Warm-up followed by inverse-square-root decay, unless disabled,
        # in which case a fixed rate is used.
        if not disable:
            return lr0 * 10 / math.sqrt(args.d_model) * min(1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
        return 0.00002

    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i', model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')
    examples = 0
    first_step = True
    loss_outer = 0

    for iters, batch in enumerate(train):

        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()], '{}_iter={}.pt.states'.format(args.model_name, iters))

        # --- validation --- #
        # Triggered by iteration count, by number of examples seen, or
        # unconditionally on the very first step.
        if ((args.eval_every_examples == -1) and (iters % args.eval_every == 0)) \
            or ((args.eval_every_examples > 0) and (examples > args.eval_every_examples)) \
            or first_step:

            first_step = False

            if args.eval_every_examples > 0:
                examples = examples % args.eval_every_examples

            # NOTE(review): valid_model receives the whole `dev` object, so this
            # loop appears to repeat the full validation once per dev batch.
            # Left unchanged because the intent cannot be confirmed from here.
            for dev_iters, dev_batch in enumerate(dev):

                progressbar.close()
                dev_metrics.reset()

                if args.distillation:
                    # Return value is not used; the call fills dev_metrics.
                    valid_model(args, model, dev, dev_metrics, distillation=True)

                outputs_data = valid_model(args, model, dev, None if args.distillation else dev_metrics, print_out=True)
                if log_to_tensorboard:
                    writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters / args.inter_size)
                    writer.add_scalar('dev/Loss', dev_metrics.loss, iters / args.inter_size)
                    writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters / args.inter_size)
                    writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters / args.inter_size)

                if not args.debug:
                    best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'], dev_metrics.gleu, dev_metrics.loss, iters / args.inter_size)
                    args.logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format(
                        best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
                args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()

        # First batch of an accumulation group: refresh lr, clear gradients.
        if iters % args.inter_size == 0:
            opt.param_groups[0]['lr'] = get_learning_rate(iters / args.inter_size + 1, disable=args.disable_lr_schedule)
            opt.zero_grad()
            loss_outer = 0

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch, args.distillation)

        examples += batch_size

        # Maximum Likelihood Training
        loss = model.cost(targets, target_masks, out=model(encoding, source_masks, inputs, input_masks)) / args.inter_size
        loss_outer = loss_outer + loss

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()

        # Last batch of an accumulation group: apply the summed gradients.
        if iters % args.inter_size == (args.inter_size - 1):

            if args.universal_options == 'no_update_encdec':
                # Zero every gradient except the one of encoder.uni_out.weight,
                # so only that parameter is changed by the step below.
                for p in model.parameters():
                    if p is not model.encoder.uni_out.weight:
                        if p.grad is not None:
                            p.grad.detach_()
                            p.grad.zero_()

            opt.step()

            info = 'training step={}, loss={:.3f}, lr={:.8f}'.format(iters / args.inter_size, export(loss_outer), opt.param_groups[0]['lr'])
            if log_to_tensorboard:
                writer.add_scalar('train/Loss', export(loss_outer), iters / args.inter_size)

            progressbar.update(1)
            progressbar.set_description(info)

    # Release the last progress bar, whether the loop ended or broke out.
    progressbar.close()
Exemplo n.º 2
0
def train_model(args, model, train, dev, src, trg, teacher_model=None, save_path=None, maxsteps=None):
    """Train ``model`` on ``train`` with periodic checkpointing and validation on ``dev``.

    A plain ``Transformer`` is trained with one decoding pass per batch. A
    ``FastTransformer`` is trained with ``args.train_repeat_dec`` refinement
    passes, where each pass's predictions are embedded and fed to the next
    pass; an optional penalty (``args.diff_loss_w``) on the similarity of
    adjacent output positions is added to the total loss.

    Args:
        args: run configuration (optimizer, paths, schedule, logging flags).
        model: a ``Transformer`` or ``FastTransformer`` instance. The checks
            use exact type identity, not ``isinstance``.
        train: iterable yielding training batches.
        dev: validation data, handed to ``valid_model``.
        src: not referenced in this function.
        trg: not referenced in this function.
        teacher_model: not referenced in this function (``valid_model`` is
            always called with ``teacher_model=None``).
        save_path: forwarded to ``Best`` as ``path``; defaults to
            ``args.model_name``.
        maxsteps: the loop stops once the (offset) iteration counter exceeds
            this value; defaults to ``args.maximum_steps``.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not ``'Adam'`` or the
            model is neither a ``Transformer`` nor a ``FastTransformer``.
    """

    log_to_tensorboard = args.tensorboard and (not args.debug)
    if log_to_tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('{}{}'.format(args.event_path, args.prefix+args.hp_str))

    # optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):   # very important.
            offset, opt_states = torch.load(os.path.join(args.model_path, args.load_from + '.pt.states'),
                                            map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    # Resolved once up front; the value cannot change while the loop runs.
    if maxsteps is None:
        maxsteps = args.maximum_steps

    def get_learning_rate(i, disable=False):
        # Step decay: divide args.lr by 5 every 50000 steps, never going
        # below 3e-5. With the schedule disabled, args.lr is used as is.
        if disable:
            return args.lr
        return max(0.00003, args.lr / math.pow(5, math.floor(i/50000)))

    best = Best(max, *['BLEU_dec{}'.format(ii+1) for ii in range(args.valid_repeat_dec)], \
                     'i', model=model, opt=opt, path=save_path, gpu=args.gpu, \
                     which=range(args.valid_repeat_dec))
    train_metrics = Metrics('train loss', *['loss_{}'.format(idx+1) for idx in range(args.train_repeat_dec)], data_type = "avg")
    dev_metrics = Metrics('dev loss', *['loss_{}'.format(idx+1) for idx in range(args.valid_repeat_dec)], data_type = "avg")
    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    for iters, batch in enumerate(train):
        iters += offset

        # --- saving --- #
        # The same two files are overwritten on every save.
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}.pt'.format(args.model_name))
                torch.save([iters, best.opt.state_dict()], '{}.pt.states'.format(args.model_name))

        # --- validation --- #
        if iters % args.eval_every == 0:
            dev_metrics.reset()
            outputs_data = valid_model(args, model, dev, dev_metrics, teacher_model=None, print_out=True)

            if log_to_tensorboard:
                for ii in range(args.valid_repeat_dec):
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1), getattr(dev_metrics, "loss_{}".format(ii+1)), iters)
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1), outputs_data['bleu'][ii], iters)

                writer.add_scalars('dev/multi/BLEUs', {"iter_{}".format(idx+1):bleu for idx, bleu in enumerate(outputs_data['bleu']) }, iters)
                writer.add_scalars('dev/multi/Losses', \
                    { "iter_{}".format(idx+1):getattr(dev_metrics, "loss_{}".format(idx+1)) \
                     for idx in range(args.valid_repeat_dec) }, \
                     iters)

            if not args.debug:
                best.accumulate(*outputs_data['bleu'], iters)
                values = list( best.metrics.values() )
                args.logger.info("best model : {}, {}".format( "BLEU=[{}]".format(", ".join( [ str(x) for x in values[:args.valid_repeat_dec] ] ) ), \
                                                              "i={}".format( values[args.valid_repeat_dec] ), ) )
            args.logger.info('model:' + args.prefix + args.hp_str)

            # ---set-up a new progressor---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every, desc='start training.')

        if iters > maxsteps:
            args.logger.info('reach the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        opt.param_groups[0]['lr'] = get_learning_rate(iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch)

        # `losses` holds every term of the total loss in the order it was
        # produced; `dec_losses` holds only the per-pass decoder losses, which
        # are the ones matching the loss_1..loss_N metric / scalar names.
        #
        # Exact type checks are kept on purpose: isinstance() would send a
        # model down the wrong branch if one class subclasses the other.
        if type(model) is Transformer:
            out = model(encoding, source_masks, inputs, input_masks)
            loss = model.cost(targets, target_masks, out)
            losses = [loss]
            dec_losses = [loss]
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                    model.prepare_initial(encoding, sources, source_masks, input_masks)

            losses = []
            dec_losses = []
            for iter_ in range(args.train_repeat_dec):

                # Passes beyond the number of distinct decoders reuse the last one.
                curr_iter = min(iter_, args.num_shared_dec-1)
                next_iter = min(curr_iter + 1, args.num_shared_dec-1)

                out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter)
                dec_loss = model.cost(targets, target_masks, out=out, iter_=curr_iter)
                losses.append(dec_loss)
                dec_losses.append(dec_loss)

                # Pick the tokens fed to the next pass: argmax or a sample
                # drawn from the (detached) output distribution.
                logits = model.decoder[curr_iter].out(out)
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    logits = softmax(logits)
                    logits_sz = logits.size()
                    logits_ = Variable(logits.data, requires_grad=False)
                    argmax = torch.multinomial(logits_.contiguous().view(-1, logits_sz[-1]), 1)\
                            .view(*logits_sz[:-1])

                # Embed the chosen tokens with the next decoder's output
                # projection weights, scaled by sqrt(d_model).
                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                             math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                if args.diff_loss_w > 0 and ((args.diff_loss_dec1 == False) or (args.diff_loss_dec1 == True and iter_ == 0)):
                    # first L2 normalize
                    out_norm = out.div(out.norm(p=2, dim=-1, keepdim=True))

                    # Mean positive cosine similarity between adjacent
                    # positions along dim 1, weighted by diff_loss_w.
                    diff_loss = torch.mean((out_norm[:,1:,:] * out_norm[:,:-1,:]).sum(-1).clamp(min=0)) * args.diff_loss_w

                    # Part of the total loss, but not a decoder loss.
                    losses.append(diff_loss)

            loss = sum(losses)
        else:
            raise NotImplementedError(
                'unsupported model type: {}'.format(type(model).__name__))

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, *dec_losses, print_iter=None)

        # train the student
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        info = 'training step={}, loss={}, lr={:.5f}'.format(
                    iters,
                    "/".join(["{:.3f}".format(export(ll)) for ll in losses]),
                    opt.param_groups[0]['lr'])

        if iters % args.eval_every == 0 and log_to_tensorboard:
            for idx, dec_loss in enumerate(dec_losses):
                writer.add_scalar('train/single/Loss_{}'.format(idx+1), export(dec_loss), iters)

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info(train_metrics)
        else:
            progressbar.update(1)
            progressbar.set_description(info)
        train_metrics.reset()

    # Release the last progress bar, whether the loop ended or broke out.
    if not args.no_tqdm:
        progressbar.close()