Example #1
def train_model(args,
                model,
                train,
                dev,
                teacher_model=None,
                save_path=None,
                maxsteps=None):

    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('./runs/{}'.format(args.prefix + args.hp_str))

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            betas=(0.9, 0.98),
            eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                './models/' + args.load_from + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    best = Best(max,
                'corpus_bleu',
                'corpus_gleu',
                'gleu',
                'loss',
                'i',
                model=model,
                opt=opt,
                path=save_path,
                gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss',
                          'distance', 'alter_loss', 'distance2',
                          'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')

    for iters, batch in enumerate(train):

        iters += offset

        if iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(
                               args.model_name, iters))

        if iters % args.eval_every == 0:
            progressbar.close()
            dev_metrics.reset()

            if args.distillation:
                outputs_course = valid_model(args,
                                             model,
                                             dev,
                                             dev_metrics,
                                             distillation=True,
                                             teacher_model=None)

            outputs_data = valid_model(
                args,
                model,
                dev,
                None if args.distillation else dev_metrics,
                teacher_model=None,
                print_out=True)
            if args.tensorboard and (not args.debug):
                writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu,
                                  iters)
                writer.add_scalar('dev/Loss', dev_metrics.loss, iters)
                writer.add_scalar('dev/GLEU_corpus_',
                                  outputs_data['corpus_gleu'], iters)
                writer.add_scalar('dev/BLEU_corpus_',
                                  outputs_data['corpus_bleu'], iters)

                if args.distillation:
                    writer.add_scalar('dev/GLEU_corpus_dis',
                                      outputs_course['corpus_gleu'], iters)
                    writer.add_scalar('dev/BLEU_corpus_dis',
                                      outputs_course['corpus_bleu'], iters)

            if not args.debug:
                best.accumulate(outputs_data['corpus_bleu'],
                                outputs_data['corpus_gleu'], dev_metrics.gleu,
                                dev_metrics.loss, iters)
                args.logger.info(
                    'the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'
                    .format(best.i, best.gleu, best.corpus_gleu,
                            best.corpus_bleu))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # --- set up a new progress bar ---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if maxsteps is None:
            maxsteps = args.maximum_steps

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break

        # --- training --- #
        model.train()

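        # Note: this appears to be the standard Transformer inverse-square-root
        # warmup ("Noam") schedule, lr ~ d_model^-0.5 * min(i^-0.5, i * warmup^-1.5),
        # scaled by lr0 * 10; when disabled, a small constant rate is used instead.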
        def get_learning_rate(i, lr0=0.1, disable=False):
            if not disable:
                return lr0 * 10 / math.sqrt(args.d_model) * min(
                    1 / math.sqrt(i), i /
                    (args.warmup * math.sqrt(args.warmup)))
            return 0.00002

        opt.param_groups[0]['lr'] = get_learning_rate(
            iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch, args.distillation)
        input_reorder, fertility_cost, decoder_inputs = None, None, inputs
        batch_fer = batch.fer_dec if args.distillation else batch.fer

        #print(input_masks.size(), target_masks.size(), input_masks.sum())

        if type(model) is FastTransformer:
            inputs, input_reorder, input_masks, fertility_cost = model.prepare_initial(
                encoding, sources, source_masks, input_masks, batch_fer)

        # Maximum Likelihood Training
        if not args.finetuning:
            loss = model.cost(targets,
                              target_masks,
                              out=model(encoding, source_masks, inputs,
                                        input_masks))
            if args.fertility:
                loss += fertility_cost

        else:
            # finetuning:

            # loss_student (MLE)
            if not args.fertility:
                decoding, out, probs = model(encoding,
                                             source_masks,
                                             inputs,
                                             input_masks,
                                             return_probs=True,
                                             decoding=True)
                loss_student = model.batched_cost(targets, target_masks,
                                                  probs)  # student-loss (MLE)
                decoder_masks = input_masks

            else:  # Note: MLE and decoding produce different translations, so we need to run the same code twice
                # truth
                decoding, out, probs = model(encoding,
                                             source_masks,
                                             inputs,
                                             input_masks,
                                             decoding=True,
                                             return_probs=True)
                loss_student = model.cost(targets, target_masks, out=out)
                decoder_masks = input_masks

                # baseline
                decoder_inputs_b, _, decoder_masks_b, _, _ = model.prepare_initial(
                    encoding,
                    sources,
                    source_masks,
                    input_masks,
                    None,
                    mode='mean')
                decoding_b, out_b, probs_b = model(
                    encoding,
                    source_masks,
                    decoder_inputs_b,
                    decoder_masks_b,
                    decoding=True,
                    return_probs=True)  # decode again

                # reinforce
                decoder_inputs_r, _, decoder_masks_r, _, _ = model.prepare_initial(
                    encoding,
                    sources,
                    source_masks,
                    input_masks,
                    None,
                    mode='reinforce')
                decoding_r, out_r, probs_r = model(
                    encoding,
                    source_masks,
                    decoder_inputs_r,
                    decoder_masks_r,
                    decoding=True,
                    return_probs=True)  # decode again

            if args.fertility:
                loss_student += fertility_cost

            # loss_teacher (RKL+REINFORCE)
            teacher_model.eval()
            if not args.fertility:
                inputs_student_index, _, targets_student_soft, _, _, _, encoding_teacher, _ = model.quick_prepare(
                    batch, False, decoding, probs, decoder_masks,
                    decoder_masks, source_masks)
                out_teacher, probs_teacher = teacher_model(
                    encoding_teacher,
                    source_masks,
                    inputs_student_index.detach(),
                    decoder_masks,
                    return_probs=True)
                loss_teacher = teacher_model.batched_cost(
                    targets_student_soft, decoder_masks,
                    probs_teacher.detach())
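                # interpolate the teacher (distillation) loss with the student
                # MLE loss; args.beta1 weights the student term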
                loss = (
                    1 - args.beta1
                ) * loss_teacher + args.beta1 * loss_student  # final results

            else:
                inputs_student_index, _, targets_student_soft, _, _, _, encoding_teacher, _ = model.quick_prepare(
                    batch, False, decoding, probs, decoder_masks,
                    decoder_masks, source_masks)
                out_teacher, probs_teacher = teacher_model(
                    encoding_teacher,
                    source_masks,
                    inputs_student_index.detach(),
                    decoder_masks,
                    return_probs=True)
                loss_teacher = teacher_model.batched_cost(
                    targets_student_soft, decoder_masks,
                    probs_teacher.detach())

                inputs_student_index, _ = model.prepare_inputs(
                    batch, decoding_b, False, decoder_masks_b)
                targets_student_soft, _ = model.prepare_targets(
                    batch, probs_b, False, decoder_masks_b)

                out_teacher, probs_teacher = teacher_model(
                    encoding_teacher,
                    source_masks,
                    inputs_student_index.detach(),
                    decoder_masks_b,
                    return_probs=True)

                _, loss_1 = teacher_model.batched_cost(targets_student_soft,
                                                       decoder_masks_b,
                                                       probs_teacher.detach(),
                                                       True)

                inputs_student_index, _ = model.prepare_inputs(
                    batch, decoding_r, False, decoder_masks_r)
                targets_student_soft, _ = model.prepare_targets(
                    batch, probs_r, False, decoder_masks_r)

                out_teacher, probs_teacher = teacher_model(
                    encoding_teacher,
                    source_masks,
                    inputs_student_index.detach(),
                    decoder_masks_r,
                    return_probs=True)
                _, loss_2 = teacher_model.batched_cost(targets_student_soft,
                                                       decoder_masks_r,
                                                       probs_teacher.detach(),
                                                       True)

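                # REINFORCE reward: positive when the sampled fertilities
                # (loss_2) give a lower teacher loss than the mean-fertility
                # baseline (loss_1); centered, broadcast over source positions,
                # and masked before being passed to reinforce() below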
                rewards = -(loss_2 - loss_1).data
                rewards = rewards - rewards.mean()
                rewards = rewards.expand_as(source_masks)
                rewards = rewards * source_masks

                model.predictor.saved_fertilities.reinforce(
                    0.1 * rewards.contiguous().view(-1, 1))
                loss = (
                    1 - args.beta1
                ) * loss_teacher + args.beta1 * loss_student  # final loss for the fertility/REINFORCE branch

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        # train the student
        if args.finetuning and args.fertility:
            torch.autograd.backward(
                (loss, model.predictor.saved_fertilities),
                (torch.ones(1).cuda(loss.get_device()), None))
        else:
            loss.backward()
        opt.step()

        info = 'training step={}, loss={:.3f}, lr={:.5f}'.format(
            iters, export(loss), opt.param_groups[0]['lr'])
        if args.finetuning:
            info += '| NA:{:.3f}, AR:{:.3f}'.format(export(loss_student),
                                                    export(loss_teacher))
            if args.fertility:
                info += '| RL: {:.3f}'.format(export(rewards.mean()))

        if args.fertility:
            info += '| RE:{:.3f}'.format(export(fertility_cost))

        if args.tensorboard and (not args.debug):
            writer.add_scalar('train/Loss', export(loss), iters)

        progressbar.update(1)
        progressbar.set_description(info)
Example #2
                    corpus_bleu = outputs_data['corpus_bleu']

                args.logger.info('model:' + args.prefix + args.hp_str + "\n")

        if args.tensorboard and (not args.debug):
            writer.add_scalar('dev/zero_shot_BLEU', corpus_bleu0, iters)
            writer.add_scalar('dev/fine_tune_BLEU', corpus_bleu, iters)

        args.logger.info('validation done.\n')
        model.load_fast_weights(weights)  # --- coming back to normal

        # -- restart the progressbar --
        progressbar = tqdm(total=args.eval_every, desc='start training')

        if not args.debug:
            best.accumulate(corpus_bleu, iters)
            args.logger.info(
                'the best model is achieved at {},  corpus BLEU={}'.format(
                    best.i, best.corpus_bleu))

    # ----- meta-training ------- #
    model.train()
    if iters > args.maximum_steps:
        args.logger.info('reached the maximum updating steps.')
        break

    # ----- inner-loop ------
    selected = random.randint(0, args.n_lang -
                              1)  # randomly pick one language pair
    if args.cross_meta_learning:
Example #3
def train_model(args,
                model,
                train,
                dev,
                src=None,
                trg=None,
                trg_len_dic=None,
                teacher_model=None,
                save_path=None,
                maxsteps=None):

    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(str(args.event_path / args.id_str))

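    # per-iteration denoising weights for the iterative decoder: either a
    # constant weight for every refinement step, or (layerwise_denoising_weight)
    # a schedule that decreases roughly linearly from ~0.9 down to 0.1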
    if type(model) is FastTransformer and args.denoising_prob > 0.0:
        denoising_weights = [
            args.denoising_weight for idx in range(args.train_repeat_dec)
        ]
        denoising_out_weights = [
            args.denoising_out_weight for idx in range(args.train_repeat_dec)
        ]

    if type(model) is FastTransformer and args.layerwise_denoising_weight:
        start, end = 0.9, 0.1
        diff = (start - end) / (args.train_repeat_dec - 1)
        denoising_weights = np.arange(start=end, stop=start,
                                      step=diff).tolist()[::-1] + [0.1]

    # optimizer
    for k, p in zip(model.state_dict().keys(), model.parameters()):
        # only fine-tune the layers responsible for predicting the target length
        if args.finetune_trg_len:
            if "pred_len" not in k:
                p.requires_grad = False
        else:
            if "pred_len" in k:
                p.requires_grad = False

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                str(args.model_path / args.load_from) + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    if not args.finetune_trg_len:
        best = Best(max,
                    *[
                        'BLEU_dec{}'.format(ii + 1)
                        for ii in range(args.valid_repeat_dec)
                    ],
                    'i',
                    model=model,
                    opt=opt,
                    path=str(args.model_path / args.id_str),
                    gpu=args.gpu,
                    which=range(args.valid_repeat_dec))
    else:
        best = Best(max,
                    *['pred_target_len_correct'],
                    'i',
                    model=model,
                    opt=opt,
                    path=str(args.model_path / args.id_str),
                    gpu=args.gpu,
                    which=[0])
    train_metrics = Metrics(
        'train loss',
        *['loss_{}'.format(idx + 1) for idx in range(args.train_repeat_dec)],
        data_type="avg")
    dev_metrics = Metrics(
        'dev loss',
        *['loss_{}'.format(idx + 1) for idx in range(args.valid_repeat_dec)],
        data_type="avg")

    if "predict" in args.trg_len_option:
        train_metrics_trg = Metrics('train loss target',
                                    *[
                                        "pred_target_len_loss",
                                        "pred_target_len_correct",
                                        "pred_target_len_approx"
                                    ],
                                    data_type="avg")
        train_metrics_average = Metrics(
            'train loss average',
            *["average_target_len_correct", "average_target_len_approx"],
            data_type="avg")
        dev_metrics_trg = Metrics('dev loss target',
                                  *[
                                      "pred_target_len_loss",
                                      "pred_target_len_correct",
                                      "pred_target_len_approx"
                                  ],
                                  data_type="avg")
        dev_metrics_average = Metrics(
            'dev loss average',
            *["average_target_len_correct", "average_target_len_approx"],
            data_type="avg")
    else:
        train_metrics_trg = None
        train_metrics_average = None
        dev_metrics_trg = None
        dev_metrics_average = None

    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    if maxsteps is None:
        maxsteps = args.maximum_steps

    #targetlength = TargetLength()
    for iters, train_batch in enumerate(train):
        #targetlength.accumulate( train_batch )
        #continue

        iters += offset

        if args.save_every > 0 and iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(
                    best.model.state_dict(),
                    '{}_iter={}.pt'.format(str(args.model_path / args.id_str),
                                           iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(
                               str(args.model_path / args.id_str), iters))

        if iters % args.eval_every == 0:
            torch.cuda.empty_cache()
            gc.collect()
            dev_metrics.reset()
            if dev_metrics_trg is not None:
                dev_metrics_trg.reset()
            if dev_metrics_average is not None:
                dev_metrics_average.reset()
            outputs_data = valid_model(args,
                                       model,
                                       dev,
                                       dev_metrics,
                                       dev_metrics_trg=dev_metrics_trg,
                                       dev_metrics_average=dev_metrics_average,
                                       teacher_model=None,
                                       print_out=True,
                                       trg_len_dic=trg_len_dic)
            #outputs_data = [0, [0,0,0,0], 0, 0]
            if args.tensorboard and (not args.debug):
                for ii in range(args.valid_repeat_dec):
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1),
                                      getattr(dev_metrics,
                                              "loss_{}".format(ii + 1)),
                                      iters)  # NLL averaged over dev corpus
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1),
                                      outputs_data['real'][ii][0],
                                      iters)  # NOTE corpus bleu

                if "predict" in args.trg_len_option:
                    writer.add_scalar("dev/single/pred_target_len_loss",
                                      outputs_data["pred_target_len_loss"],
                                      iters)
                    writer.add_scalar("dev/single/pred_target_len_correct",
                                      outputs_data["pred_target_len_correct"],
                                      iters)
                    writer.add_scalar("dev/single/pred_target_len_approx",
                                      outputs_data["pred_target_len_approx"],
                                      iters)
                    writer.add_scalar(
                        "dev/single/average_target_len_correct",
                        outputs_data["average_target_len_correct"], iters)
                    writer.add_scalar(
                        "dev/single/average_target_len_approx",
                        outputs_data["average_target_len_approx"], iters)
                """
                writer.add_scalars('dev/total/BLEUs', {"iter_{}".format(idx+1):bleu for idx, bleu in enumerate(outputs_data['bleu']) }, iters)
                writer.add_scalars('dev/total/Losses',
                    { "iter_{}".format(idx+1):getattr(dev_metrics, "loss_{}".format(idx+1))
                     for idx in range(args.valid_repeat_dec) },
                     iters )
                """

            if not args.debug:
                if not args.finetune_trg_len:
                    best.accumulate(*[xx[0] for xx in outputs_data['real']],
                                    iters)

                    values = list(best.metrics.values())
                    args.logger.info("best model : {}, {}".format( "BLEU=[{}]".format(", ".join( [ str(x) for x in values[:args.valid_repeat_dec] ] ) ), \
                                                                  "i={}".format( values[args.valid_repeat_dec] ), ) )
                else:
                    best.accumulate(*[outputs_data['pred_target_len_correct']],
                                    iters)
                    values = list(best.metrics.values())
                    args.logger.info("best model : {}".format(
                        "pred_target_len_correct = {}".format(values[0])))

            args.logger.info('model:' + args.prefix + args.hp_str)

            # --- set up a new progress bar ---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every,
                                   desc='start training.')

            if type(model) is FastTransformer and args.anneal_denoising_weight:
                for ii, bb in enumerate([xx[0]
                                         for xx in outputs_data['real']][:-1]):
                    denoising_weights[ii] = 0.9 - 0.1 * int(
                        math.floor(bb / 3.0))

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break

        model.train()

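        # two LR schedules: the usual inverse-sqrt Transformer warmup, and a
        # linear anneal from args.lr down to 1e-5 over args.anneal_steps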
        def get_lr_transformer(i, lr0=0.1):
            return lr0 * 10 / math.sqrt(args.d_model) * min(
                1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))

        def get_lr_anneal(iters, lr0=0.1):
            lr_end = 1e-5
            return max(0, (args.lr - lr_end) * (args.anneal_steps - iters) /
                       args.anneal_steps) + lr_end

        if args.lr_schedule == "fixed":
            opt.param_groups[0]['lr'] = args.lr
        elif args.lr_schedule == "anneal":
            opt.param_groups[0]['lr'] = get_lr_anneal(iters + 1)
        elif args.lr_schedule == "transformer":
            opt.param_groups[0]['lr'] = get_lr_transformer(iters + 1)

        opt.zero_grad()

        if args.dataset == "mscoco":
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            _, source_masks,\
            encoding, batch_size, rest = model.quick_prepare_mscoco(train_batch, all_captions=train_batch[1], fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            sources, source_masks,\
            encoding, batch_size, rest = model.quick_prepare(train_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp)

        losses = []
        if type(model) is Transformer:
            loss = model.cost(targets,
                              target_masks,
                              out=model(encoding, source_masks, decoder_inputs,
                                        decoder_masks))
            losses.append(loss)

        elif type(model) is FastTransformer:
            all_logits = []
            all_denoising_masks = []
            for iter_ in range(args.train_repeat_dec):
                curr_iter = min(iter_, args.num_decs - 1)
                next_iter = min(curr_iter + 1, args.num_decs - 1)

                out = model(encoding,
                            source_masks,
                            decoder_inputs,
                            decoder_masks,
                            iter_=curr_iter,
                            return_probs=False)

                if args.self_distil > 0.0:
                    loss, logits_masked = model.cost(targets,
                                                     target_masks,
                                                     out=out,
                                                     iter_=curr_iter,
                                                     return_logits=True)
                else:
                    loss = model.cost(targets,
                                      target_masks,
                                      out=out,
                                      iter_=curr_iter)

                logits = model.decoder[curr_iter].out(out)

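                # the next iteration's decoder input is built from either the
                # greedy argmax or a multinomial sample of the current logits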
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    probs = softmax(logits)
                    probs_sz = probs.size()
                    logits_ = Variable(probs.data, requires_grad=False)
                    argmax = torch.multinomial(
                        logits_.contiguous().view(-1, probs_sz[-1]),
                        1).view(*probs_sz[:-1])

                if args.self_distil > 0.0:
                    all_logits.append(logits_masked)

                losses.append(loss)

                decoder_inputs_ = 0
                denoising_mask = 1
                if args.next_dec_input in ["both", "emb"]:
                    if args.denoising_prob > 0.0 and np.random.rand(
                    ) < args.denoising_prob:
                        cor = corrupt_target(targets, decoder_masks,
                                             len(trg.vocab),
                                             denoising_weights[iter_],
                                             args.corruption_probs)

                        emb = F.embedding(
                            cor, model.decoder[next_iter].out.weight *
                            math.sqrt(args.d_model))
                        denoising_mask = 0
                    else:
                        emb = F.embedding(
                            argmax, model.decoder[next_iter].out.weight *
                            math.sqrt(args.d_model))

                    if args.denoising_out_weight > 0:
                        if denoising_out_weights[iter_] > 0.0:
                            corrupted_argmax = corrupt_target(
                                argmax, decoder_masks,
                                denoising_out_weights[iter_])
                        else:
                            corrupted_argmax = argmax
                        emb = F.embedding(
                            corrupted_argmax,
                            model.decoder[next_iter].out.weight *
                            math.sqrt(args.d_model))
                    decoder_inputs_ += emb
                all_denoising_masks.append(denoising_mask)

                if args.next_dec_input in ["both", "out"]:
                    decoder_inputs_ += out
                decoder_inputs = decoder_inputs_

            # self-distillation loss if requested
            if args.self_distil > 0.0:
                self_distil_losses = []

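                # pairwise MSE between intermediate decoder logits, weighted by
                # the inverse iteration gap 1/(j - i); the denoising masks zero
                # out pairs whose inputs were corrupted at that iteration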
                for logits_i in range(1, len(all_logits) - 1):
                    self_distill_loss_i = 0.0
                    for logits_j in range(logits_i + 1, len(all_logits)):
                        self_distill_loss_i += \
                                all_denoising_masks[logits_j] * \
                                all_denoising_masks[logits_i] * \
                                (1/(logits_j-logits_i)) * args.self_distil * F.mse_loss(all_logits[logits_i], all_logits[logits_j].detach())

                    self_distil_losses.append(self_distill_loss_i)

                self_distil_loss = sum(self_distil_losses)

        loss = sum(losses)

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, *losses, print_iter=None)
        if train_metrics_trg is not None:
            train_metrics_trg.accumulate(batch_size,
                                         *[rest[0], rest[1], rest[2]])
        if train_metrics_average is not None:
            train_metrics_average.accumulate(batch_size, *[rest[3], rest[4]])
        if type(model) is FastTransformer and args.self_distil > 0.0:
            (loss + self_distil_loss).backward()
        else:
            if "predict" in args.trg_len_option:
                if args.finetune_trg_len:
                    rest[0].backward()
                else:
                    loss.backward()
            else:
                loss.backward()

        if args.grad_clip > 0:
            total_norm = nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        mid_str = ''
        if type(model) is FastTransformer and args.self_distil > 0.0:
            mid_str += 'distil={:.5f}, '.format(
                self_distil_loss.cpu().data.numpy()[0])
        if type(model) is FastTransformer and "predict" in args.trg_len_option:
            mid_str += 'pred_target_len_loss={:.5f}, '.format(
                rest[0].cpu().data.numpy()[0])
        if type(model) is FastTransformer and args.denoising_prob > 0.0:
            mid_str += "/".join(
                ["{:.1f}".format(ff) for ff in denoising_weights[:-1]]) + ", "

        info = 'update={}, loss={}, {}lr={:.1e}'.format(
            iters, "/".join(["{:.3f}".format(export(ll)) for ll in losses]),
            mid_str, opt.param_groups[0]['lr'])

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(
                    iters, str(train_metrics)))
        else:
            progressbar.update(1)
            progressbar.set_description(info)

        if iters % args.eval_every == 0 and args.tensorboard and (
                not args.debug):
            for idx in range(args.train_repeat_dec):
                writer.add_scalar(
                    'train/single/Loss_{}'.format(idx + 1),
                    getattr(train_metrics, "loss_{}".format(idx + 1)), iters)
            if "predict" in args.trg_len_option:
                writer.add_scalar(
                    "train/single/pred_target_len_loss",
                    getattr(train_metrics_trg, "pred_target_len_loss"), iters)
                writer.add_scalar(
                    "train/single/pred_target_len_correct",
                    getattr(train_metrics_trg, "pred_target_len_correct"),
                    iters)
                writer.add_scalar(
                    "train/single/pred_target_len_approx",
                    getattr(train_metrics_trg, "pred_target_len_approx"),
                    iters)
                writer.add_scalar(
                    "train/single/average_target_len_correct",
                    getattr(train_metrics_average,
                            "average_target_len_correct"), iters)
                writer.add_scalar(
                    "train/single/average_target_len_approx",
                    getattr(train_metrics_average,
                            "average_target_len_approx"), iters)

            train_metrics.reset()
            if train_metrics_trg is not None:
                train_metrics_trg.reset()
            if train_metrics_average is not None:
                train_metrics_average.reset()
Example #4
def train_model(args, model, train, dev, save_path=None, maxsteps=None):

    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('./runs/{}'.format(args.prefix + args.hp_str))

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            betas=(0.9, 0.98),
            eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):  # very important.
            offset, opt_states = torch.load(
                './models/' + args.load_from + '.pt.states',
                map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    best = Best(max,
                'corpus_bleu',
                'corpus_gleu',
                'gleu',
                'loss',
                'i',
                model=model,
                opt=opt,
                path=save_path,
                gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss',
                          'distance', 'alter_loss', 'distance2',
                          'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')

    for iters, batch in enumerate(train):

        iters += offset

        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()],
                           '{}_iter={}.pt.states'.format(
                               args.model_name, iters))

        # --- validation --- #
        if iters % args.eval_every == 0:
            progressbar.close()
            dev_metrics.reset()

            if args.distillation:
                outputs_course = valid_model(args,
                                             model,
                                             dev,
                                             dev_metrics,
                                             distillation=True)

            outputs_data = valid_model(
                args,
                model,
                dev,
                None if args.distillation else dev_metrics,
                print_out=True)
            if args.tensorboard and (not args.debug):
                writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu,
                                  iters)
                writer.add_scalar('dev/Loss', dev_metrics.loss, iters)
                writer.add_scalar('dev/GLEU_corpus_',
                                  outputs_data['corpus_gleu'], iters)
                writer.add_scalar('dev/BLEU_corpus_',
                                  outputs_data['corpus_bleu'], iters)

                if args.distillation:
                    writer.add_scalar('dev/GLEU_corpus_dis',
                                      outputs_course['corpus_gleu'], iters)
                    writer.add_scalar('dev/BLEU_corpus_dis',
                                      outputs_course['corpus_bleu'], iters)

            if not args.debug:
                best.accumulate(outputs_data['corpus_bleu'],
                                outputs_data['corpus_gleu'], dev_metrics.gleu,
                                dev_metrics.loss, iters)
                args.logger.info(
                    'the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'
                    .format(best.i, best.gleu, best.corpus_gleu,
                            best.corpus_bleu))
            args.logger.info('model:' + args.prefix + args.hp_str)

            # --- set up a new progress bar ---
            progressbar = tqdm(total=args.eval_every, desc='start training.')

        if maxsteps is None:
            maxsteps = args.maximum_steps

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break

        # --- training --- #
        model.train()

        def get_learning_rate(i, lr0=0.1, disable=False):
            if not disable:
                return lr0 * 10 / math.sqrt(args.d_model) * min(
                    1 / math.sqrt(i), i /
                    (args.warmup * math.sqrt(args.warmup)))
            return 0.00002

        opt.param_groups[0]['lr'] = get_learning_rate(
            iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch, args.distillation)
        input_reorder, fertility_cost, decoder_inputs = None, None, inputs

        #print(input_masks.size(), target_masks.size(), input_masks.sum())
        if type(model) is FastTransformer:
            batch_fer = batch.fer_dec if args.distillation else batch.fer
            inputs, input_reorder, input_masks, fertility_cost = model.prepare_initial(
                encoding, sources, source_masks, input_masks, batch_fer)

        # Maximum Likelihood Training
        loss = model.cost(targets,
                          target_masks,
                          out=model(encoding, source_masks, inputs,
                                    input_masks))
        if hasattr(args, 'fertility') and args.fertility:
            loss += fertility_cost

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()
        opt.step()

        info = 'training step={}, loss={:.3f}, lr={:.5f}'.format(
            iters, export(loss), opt.param_groups[0]['lr'])
        if hasattr(args, 'fertility') and args.fertility:
            info += '| RE:{:.3f}'.format(export(fertility_cost))

        if args.tensorboard and (not args.debug):
            writer.add_scalar('train/Loss', export(loss), iters)

        progressbar.update(1)
        progressbar.set_description(info)
Example #5
def train_model(args, model, train, dev, save_path=None, maxsteps=None, writer=None):

    # optimizer
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):   # very important.
            offset, opt_states = torch.load(args.models_dir + '/' + args.load_from + '.pt.states',
                                            map_location=lambda storage, loc: storage.cuda())
            if not args.finetune:  # if finetune, do not have history
                opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    args.eval_every *= args.inter_size


    best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i', model=model, opt=opt, path=save_path, gpu=args.gpu)
    train_metrics = Metrics('train', 'loss', 'real', 'fake')
    dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'fertility_loss', 'corpus_gleu')
    progressbar = tqdm(total=args.eval_every, desc='start training.')
    examples = 0
    first_step = True
    loss_outer = 0

    for iters, batch in enumerate(train):

        iters += offset
        
        # --- saving --- #
        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}_iter={}.pt'.format(args.model_name, iters))
                torch.save([iters, best.opt.state_dict()], '{}_iter={}.pt.states'.format(args.model_name, iters))


        # --- validation --- #
        if ((args.eval_every_examples == -1) and (iters % args.eval_every == 0)) \
            or ((args.eval_every_examples > 0) and (examples > args.eval_every_examples)) \
            or first_step:

            first_step = False

            if args.eval_every_examples > 0:
                examples = examples % args.eval_every_examples

            for dev_iters, dev_batch in enumerate(dev):

                progressbar.close()
                dev_metrics.reset()

                if args.distillation:
                    outputs_course = valid_model(args, model, dev, dev_metrics, distillation=True)

                outputs_data = valid_model(args, model, dev, None if args.distillation else dev_metrics, print_out=True)
                if args.tensorboard and (not args.debug):
                    writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters / args.inter_size)
                    writer.add_scalar('dev/Loss', dev_metrics.loss, iters / args.inter_size)
                    writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters / args.inter_size)
                    writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters / args.inter_size)


                if not args.debug:
                    best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'], dev_metrics.gleu, dev_metrics.loss, iters / args.inter_size)
                    args.logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format(
                        best.i, best.gleu, best.corpus_gleu, best.corpus_bleu))
                args.logger.info('model:' + args.prefix + args.hp_str)

            # --- set up a new progress bar ---
            progressbar = tqdm(total=args.eval_every, desc='start training.')


        if maxsteps is None:
            maxsteps = args.maximum_steps

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break


        # --- training --- #
        model.train()
        def get_learning_rate(i, lr0=0.1, disable=False):
            if not disable:
                return lr0 * 10 / math.sqrt(args.d_model) * min(1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
            return 0.00002
        
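        # gradient accumulation over args.inter_size micro-batches: gradients
        # are zeroed at the start of each group, each micro-batch loss is scaled
        # by 1/inter_size, and opt.step() is only called on the last micro-batch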
        if iters % args.inter_size == 0:
            opt.param_groups[0]['lr'] = get_learning_rate(iters / args.inter_size + 1, disable=args.disable_lr_schedule)
            opt.zero_grad()
            loss_outer = 0

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch, args.distillation)
        input_reorder, fertility_cost, decoder_inputs = None, None, inputs

        examples += batch_size

        # Maximum Likelihood Training
        loss = model.cost(targets, target_masks, out=model(encoding, source_masks, inputs, input_masks)) / args.inter_size
        loss_outer = loss_outer + loss

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, loss, print_iter=None)
        train_metrics.reset()

        loss.backward()
        
        if iters % args.inter_size == (args.inter_size - 1):

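            # optionally update only the universal output embedding by zeroing
            # the gradients of every other parameter before stepping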
            if args.universal_options == 'no_update_encdec':
                for p in model.parameters():
                    if p is not model.encoder.uni_out.weight:
                        if p.grad is not None:
                            p.grad.detach_()
                            p.grad.zero_()

            opt.step()

            info = 'training step={}, loss={:.3f}, lr={:.8f}'.format(iters / args.inter_size, export(loss_outer), opt.param_groups[0]['lr'])
            if args.tensorboard and (not args.debug):
                writer.add_scalar('train/Loss', export(loss_outer), iters / args.inter_size)

            progressbar.update(1)
            progressbar.set_description(info)
Example #6
def train_model(args, model, train, dev, src, trg, teacher_model=None, save_path=None, maxsteps=None):

    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
        writer = SummaryWriter('{}{}'.format(args.event_path, args.prefix+args.hp_str))

    # optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9)
    else:
        raise NotImplementedError

    # if resume training
    if (args.load_from is not None) and (args.resume):
        with torch.cuda.device(args.gpu):   # very important.
            offset, opt_states = torch.load(os.path.join(args.model_path, args.load_from + '.pt.states'),
                                            map_location=lambda storage, loc: storage.cuda())
            opt.load_state_dict(opt_states)
    else:
        offset = 0

    # metrics
    if save_path is None:
        save_path = args.model_name

    best = Best(max, *['BLEU_dec{}'.format(ii+1) for ii in range(args.valid_repeat_dec)], \
                     'i', model=model, opt=opt, path=save_path, gpu=args.gpu, \
                     which=range(args.valid_repeat_dec))
    train_metrics = Metrics('train loss', *['loss_{}'.format(idx+1) for idx in range(args.train_repeat_dec)], data_type = "avg")
    dev_metrics = Metrics('dev loss', *['loss_{}'.format(idx+1) for idx in range(args.valid_repeat_dec)], data_type = "avg")
    if not args.no_tqdm:
        progressbar = tqdm(total=args.eval_every, desc='start training.')

    for iters, batch in enumerate(train):
        iters += offset

        if iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iter={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(best.model.state_dict(), '{}.pt'.format(args.model_name))
                torch.save([iters, best.opt.state_dict()], '{}.pt.states'.format(args.model_name))

        if iters % args.eval_every == 0:
            dev_metrics.reset()
            outputs_data = valid_model(args, model, dev, dev_metrics, teacher_model=None, print_out=True)

            if args.tensorboard and (not args.debug):
                for ii in range(args.valid_repeat_dec):
                    writer.add_scalar('dev/single/Loss_{}'.format(ii + 1), getattr(dev_metrics, "loss_{}".format(ii+1)), iters)
                    writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1), outputs_data['bleu'][ii], iters)

                writer.add_scalars('dev/multi/BLEUs', {"iter_{}".format(idx+1):bleu for idx, bleu in enumerate(outputs_data['bleu']) }, iters)
                writer.add_scalars('dev/multi/Losses', \
                    { "iter_{}".format(idx+1):getattr(dev_metrics, "loss_{}".format(idx+1)) \
                     for idx in range(args.valid_repeat_dec) }, \
                     iters)

            if not args.debug:
                best.accumulate(*outputs_data['bleu'], iters)
                values = list( best.metrics.values() )
                args.logger.info("best model : {}, {}".format( "BLEU=[{}]".format(", ".join( [ str(x) for x in values[:args.valid_repeat_dec] ] ) ), \
                                                              "i={}".format( values[args.valid_repeat_dec] ), ) )
            args.logger.info('model:' + args.prefix + args.hp_str)

            # --- set up a new progress bar ---
            if not args.no_tqdm:
                progressbar.close()
                progressbar = tqdm(total=args.eval_every, desc='start training.')

        if maxsteps is None:
            maxsteps = args.maximum_steps

        if iters > maxsteps:
            args.logger.info('reached the maximum updating steps.')
            break

        # --- training --- #
        model.train()
        def get_learning_rate(i, lr0=0.1, disable=False):
            if not disable:
                return max(0.00003, args.lr / math.pow(5, math.floor(i/50000)))
                '''
                return lr0 * 10 / math.sqrt(args.d_model) * min(
                        1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
                '''
            return args.lr
        opt.param_groups[0]['lr'] = get_learning_rate(iters + 1, disable=args.disable_lr_schedule)
        opt.zero_grad()

        # prepare the data
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks,\
        encoding, batch_size = model.quick_prepare(batch)

        #print(input_masks.size(), target_masks.size(), input_masks.sum())

        if type(model) is Transformer:
            decoder_inputs, decoder_masks = inputs, input_masks
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                    model.prepare_initial(encoding, sources, source_masks, input_masks)
            initial_inputs = decoder_inputs

        if type(model) is Transformer:
            out = model(encoding, source_masks, decoder_inputs, decoder_masks)
            loss = model.cost(targets, target_masks, out)
        elif type(model) is FastTransformer:
            losses = []
            for iter_ in range(args.train_repeat_dec):

                curr_iter = min(iter_, args.num_shared_dec-1)
                next_iter = min(curr_iter + 1, args.num_shared_dec-1)

                out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter)
                losses.append( model.cost(targets, target_masks, out=out, iter_=curr_iter) )

                logits = model.decoder[curr_iter].out(out)
                if args.use_argmax:
                    _, argmax = torch.max(logits, dim=-1)
                else:
                    logits = softmax(logits)
                    logits_sz = logits.size()
                    logits_ = Variable(logits.data, requires_grad=False)
                    argmax = torch.multinomial(logits_.contiguous().view(-1, logits_sz[-1]), 1)\
                            .view(*logits_sz[:-1])

                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                             math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

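                # optional diversity penalty: discourage adjacent decoder output
                # vectors from being similar (mean clamped cosine similarity)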
                if args.diff_loss_w > 0 and ((args.diff_loss_dec1 == False) or (args.diff_loss_dec1 == True and iter_ == 0)):
                    num_words = out.size(1)

                    # first L2 normalize
                    out_norm = out.div(out.norm(p=2, dim=-1, keepdim=True))

                    # calculate loss
                    diff_loss = torch.mean((out_norm[:,1:,:] * out_norm[:,:-1,:]).sum(-1).clamp(min=0)) * args.diff_loss_w

                    # add this losses to all losses
                    losses.append(diff_loss)

            loss = sum(losses)

        # accumulate the training metrics
        train_metrics.accumulate(batch_size, *losses, print_iter=None)

        # train the student
        loss.backward()
        if args.grad_clip > 0:
            total_norm = nn.utils.clip_grad_norm(params, args.grad_clip)
        opt.step()

        info = 'training step={}, loss={}, lr={:.5f}'.format(
                    iters,
                    "/".join(["{:.3f}".format(export(ll)) for ll in losses]),
                    opt.param_groups[0]['lr'])

        if iters % args.eval_every == 0 and args.tensorboard and (not args.debug):
            for idx in range(args.train_repeat_dec):
                writer.add_scalar('train/single/Loss_{}'.format(idx+1), export(losses[idx]), iters)

        if args.no_tqdm:
            if iters % args.eval_every == 0:
                args.logger.info(train_metrics)
        else:
            progressbar.update(1)
            progressbar.set_description(info)
        train_metrics.reset()