Beispiel #1
0
    model.to(device=args.device)

    early_stopping = EarlyStopping(args.patience)
    trainer = Trainer(model, {
        'train': train,
        'valid': valid
    },
                      optimizer,
                      losses=('ppl', ),
                      early_stopping=early_stopping,
                      max_norm=args.max_norm,
                      checkpoint=Checkpoint('EncoderDecoder',
                                            mode='nlast',
                                            keep=3).setup(args))
    trainer.add_loggers(StdLogger())
    # trainer.add_loggers(VisdomLogger(env='encdec'))
    trainer.add_loggers(TensorboardLogger(comment='encdec'))

    hook = make_encdec_hook(args.target, beam=args.beam)
    trainer.add_hook(hook, hooks_per_epoch=args.hooks_per_epoch)
    hook = u.make_schedule_hook(
        inflection_sigmoid(len(train) * 2, 1.75, inverse=True))
    trainer.add_hook(hook, hooks_per_epoch=1000)

    (model,
     valid_loss), test_loss = trainer.train(args.epochs,
                                            args.checkpoint,
                                            shuffle=True,
                                            use_schedule=args.use_schedule)
Beispiel #2
0
def kl_sigmoid_annealing_schedule(inflection, steepness=3):
    """Create a sigmoid-shaped schedule, presumably for KL-term annealing.

    Thin wrapper: forwards both arguments positionally to
    ``inflection_sigmoid`` and returns its result unchanged.

    Args:
        inflection: Inflection point passed through to ``inflection_sigmoid``
            (callers elsewhere pass multiples of ``len(train)``, so likely a
            batch/step count -- TODO confirm against its definition).
        steepness: Steepness passed through to ``inflection_sigmoid``;
            defaults to 3.

    Returns:
        Whatever ``inflection_sigmoid`` returns (apparently a schedule usable
        with ``make_schedule_hook`` -- verify).
    """
    sigmoid_schedule = inflection_sigmoid(inflection, steepness)
    return sigmoid_schedule
Beispiel #3
0
        early_stopping = EarlyStopping(args.patience)

    checkpoint = None
    if args.save:
        checkpoint = Checkpoint(m.__class__.__name__, keep=3).setup(args)

    model_hook = u.make_lm_hook(
        d, temperature=args.temperature, max_seq_len=args.max_seq_len,
        device=args.device, level=args.level, early_stopping=early_stopping,
        checkpoint=checkpoint)
    trainer.add_hook(model_hook, hooks_per_epoch=args.hooks_per_epoch)

    # - scheduled sampling hook
    if args.use_schedule:
        schedule = inflection_sigmoid(
            len(train) * args.schedule_inflection, args.schedule_steepness,
            a=args.schedule_init, inverse=True)
        trainer.add_hook(
            u.make_schedule_hook(schedule, verbose=True), hooks_per_epoch=10e4)

    # - lr schedule hook
    if args.lr_schedule_factor < 1.0:
        hook = make_lr_hook(
            optimizer, args.lr_schedule_factor, args.lr_schedule_checkpoints)
        # run a hook args.checkpoint * 4 batches
        trainer.add_hook(hook, hooks_per_epoch=args.lr_checkpoints_per_epoch)

    # loggers
    trainer.add_loggers(StdLogger())
    if args.visdom:
        visdom_logger = VisdomLogger(