Example #1
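The snippets below all exercise ignite's create_lr_scheduler_with_warmup helper. For context, here is a minimal import block they appear to assume (an assumption, not part of the originals; the exact module path varies with the ignite version, and the later examples additionally need their own project-specific imports such as hydra or functools):

import pytest
import torch
# In recent ignite versions these live in ignite.handlers; older releases
# exposed them from ignite.contrib.handlers.param_scheduler.
from ignite.engine import Engine, Events
from ignite.contrib.handlers.param_scheduler import (
    ConcatScheduler,
    CosineAnnealingScheduler,
    LinearCyclicalScheduler,
    create_lr_scheduler_with_warmup,
)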
def test_create_lr_scheduler_with_warmup_with_real_model(dummy_model_factory):

    model = dummy_model_factory(with_grads=False, with_frozen_layer=False)
    init_lr = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=init_lr)
    scaled_lr = 0.02
    warmup_duration = 5
    step_size = 2
    gamma = 0.97

    output_simulated_values = [None] * 50

    create_lr_scheduler_with_warmup(
        torch.optim.lr_scheduler.StepLR(optimizer,
                                        step_size=step_size,
                                        gamma=gamma),
        warmup_start_value=0.0,
        warmup_end_value=scaled_lr,
        warmup_duration=warmup_duration,
        output_simulated_values=output_simulated_values,
    )

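    # Expected shape of the simulated schedule: linear warmup from 0.0 to
    # scaled_lr over the first `warmup_duration` events, after which the
    # wrapped StepLR takes over at the optimizer's initial lr and decays it
    # by `gamma` every `step_size` events.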
    assert output_simulated_values[0] == [0, 0.0]
    assert output_simulated_values[warmup_duration - 1] == [
        warmup_duration - 1,
        scaled_lr,
    ]
    assert output_simulated_values[warmup_duration] == [
        warmup_duration, init_lr
    ]
    v = [warmup_duration + step_size, init_lr * gamma]
    assert output_simulated_values[warmup_duration + step_size] == v
Example #2
def test_create_lr_scheduler_with_warmup():
    with pytest.raises(TypeError):
        create_lr_scheduler_with_warmup(12,
                                        warmup_start_value=0.0,
                                        warmup_end_value=0.1,
                                        warmup_duration=10)

    def _test(lr_scheduler, optimizer, warmup_end_value_to_check=None):
        num_iterations = 10
        max_epochs = 20

        warmup_duration = 10
        warmup_end_value = 0.1

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=warmup_end_value,
            warmup_duration=warmup_duration,
            output_simulated_values=simulated_values)

        lrs = []
        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]['lr'])

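        # The scheduler is applied at ITERATION_STARTED, so the lr recorded at
        # ITERATION_COMPLETED above is the value actually used for that update.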
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * num_iterations
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in simulated_values])

        if warmup_end_value_to_check is None:
            warmup_end_value_to_check = warmup_end_value
        assert lrs[warmup_duration] == warmup_end_value_to_check

    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer)

    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    lr_scheduler = LinearCyclicalScheduler(optimizer=optimizer,
                                           param_name='lr',
                                           start_value=1.0,
                                           end_value=0.0,
                                           cycle_size=10)
    _test(lr_scheduler, optimizer, 1.0)
Example #3
    def _test(lr_scheduler, optimizer):
        num_iterations = 10
        max_epochs = 20

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=0.1,
            warmup_duration=10,
            output_simulated_values=simulated_values)

        lrs = []
        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]['lr'])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * num_iterations
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in simulated_values])
Example #4
    def _test(save_history):
        tensor = torch.ones([1], requires_grad=True)
        optimizer = torch.optim.SGD([tensor], lr=0.001)

        max_epochs = 25
        lr_max_value = 0.4
        num_iterations_per_epoch = 128
        num_iterations = max_epochs * num_iterations_per_epoch
        warmup_duration = 5 * num_iterations_per_epoch
        cooldown_duration = 5 * num_iterations_per_epoch

        scheduler_1 = LinearCyclicalScheduler(
            optimizer,
            "lr",
            start_value=lr_max_value,
            end_value=lr_max_value * 0.9,
            cycle_size=(num_iterations - warmup_duration - cooldown_duration) * 2,
        )

        scheduler_2 = LinearCyclicalScheduler(
            optimizer, "lr", start_value=lr_max_value, end_value=0.0, cycle_size=cooldown_duration * 2
        )

        lr_scheduler = ConcatScheduler(
            schedulers=[scheduler_1, scheduler_2],
            durations=[num_iterations - warmup_duration - cooldown_duration],
            save_history=False,
        )
        lr_values = [None] * num_iterations
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr_max_value,
            warmup_duration=warmup_duration,
            save_history=save_history,
            output_simulated_values=lr_values,
        )
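        # Snapshot the scheduler state so the two runs below start from an
        # identical schedule; it is restored at the end of each loop iteration.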
        state_dict = scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * num_iterations_per_epoch

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == pytest.approx([v for i, v in lr_values])

            if save_history:
                param_history = trainer.state.param_history["lr"]
                assert lrs == pytest.approx([v[0] for v in param_history])

            scheduler.load_state_dict(state_dict)
Example #5
    def _test(
        lr_scheduler,
        optimizer,
        warmup_start_value,
        warmup_end_value,
        warmup_duration,
        warmup_end_next_value,
    ):
        num_iterations = 10
        max_epochs = 20

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=warmup_start_value,
            warmup_end_value=warmup_end_value,
            warmup_duration=warmup_duration,
            output_simulated_values=simulated_values,
        )
        state_dict = scheduler.state_dict()
        trainer = Engine(lambda engine, batch: None)

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        @trainer.on(Events.ITERATION_STARTED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == pytest.approx([v for i, v in simulated_values])

            assert lrs[0] == pytest.approx(warmup_start_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            assert lrs[warmup_duration - 1] == pytest.approx(warmup_end_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            assert lrs[warmup_duration] == pytest.approx(warmup_end_next_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            scheduler.load_state_dict(state_dict)
Example #6
def attach_lr_warmup(trainer, config, lr_scheduler):

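    # Config convention used here: a negative warmup_duration means "that many
    # epochs" and is converted to iterations via steps_per_epoch, while a
    # warmup_end_value of -1 means "warm up to the base learning rate".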
    warmup_duration = (
        config['warmup_duration'] if config['warmup_duration'] > 0
        else config['steps_per_epoch'] * -config['warmup_duration']
    )

    warmup_end_value = (
        config['warmup_end_value'] if config['warmup_end_value'] != -1
        else config['learning_rate']
    )

    scheduler_with_warmup = create_lr_scheduler_with_warmup(
        lr_scheduler,
        config['warmup_start_value'],
        warmup_end_value,
        warmup_duration,
    )

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_with_warmup)
Example #7
model.to(device)

# multi-GPU: wrap the model only when more than one GPU is available
if torch.cuda.device_count() > 1:
    print('==================== Use {} GPUs ===================='.format(torch.cuda.device_count()))
    model = nn.DataParallel(model)

# loss function
loss_fn = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=5e-4)

# scheduler
scheduler = CosineAnnealingScheduler(optimizer, 'lr', init_lr, end_lr, 4*len(trainloader), cycle_mult=1.5, start_value_mult=0.1)
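# Wrap the cosine schedule with a one-epoch linear warmup (0 -> init_lr over
# len(trainloader) iterations); warmup_end_value matches the cosine start
# value, so the handoff after warmup is seamless.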
scheduler = create_lr_scheduler_with_warmup(scheduler, warmup_start_value=0., warmup_end_value=init_lr, warmup_duration=len(trainloader))

# create trainer
trainer = create_trainer(model, optimizer, loss_fn, device=device)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# add timer for each iteration
timer = Timer(average=False)

# logging training loss
def log_loss(engine):
    i = engine.state.iteration
    e = engine.state.epoch

    if i % 100 == 0:
        print('[Iters {:0>7d}/{:0>2d}, {:.2f}s/100 iters, lr={:.4E}] loss={:.4f}'.format(i, e, timer.value(), optimizer.param_groups[0]['lr'], engine.state.output))
Example #8
def test_create_lr_scheduler_with_warmup():

    with pytest.raises(TypeError,
                       match=r"Argument lr_scheduler should be a subclass of"):
        create_lr_scheduler_with_warmup(12,
                                        warmup_start_value=0.0,
                                        warmup_end_value=0.1,
                                        warmup_duration=10)

    t1 = torch.zeros([1], requires_grad=True)
    # A) opt lr != warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)

    with pytest.raises(
            ValueError,
            match=r"Argument warmup_duration should be at least 2 events"):
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=0.1,
            warmup_duration=1,
        )

    with pytest.raises(
            ValueError,
            match=r"Argument warmup_duration should be at least 2 events"):
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=0.1,
            warmup_duration="abc",
        )

    with pytest.raises(
            TypeError,
            match=r"Argument output_simulated_values should be a list of None"
    ):
        simulated_values = ()
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=0.1,
            warmup_duration=10,
            output_simulated_values=simulated_values,
        )

    def _test(
        lr_scheduler,
        optimizer,
        warmup_start_value,
        warmup_end_value,
        warmup_duration,
        warmup_end_next_value,
    ):
        num_iterations = 10
        max_epochs = 20

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=warmup_start_value,
            warmup_end_value=warmup_end_value,
            warmup_duration=warmup_duration,
            output_simulated_values=simulated_values,
        )
        state_dict = scheduler.state_dict()
        trainer = Engine(lambda engine, batch: None)

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        @trainer.on(Events.ITERATION_STARTED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == pytest.approx([v for i, v in simulated_values])

            assert lrs[0] == pytest.approx(warmup_start_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            assert lrs[warmup_duration - 1] == pytest.approx(warmup_end_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            assert lrs[warmup_duration] == pytest.approx(warmup_end_next_value), \
                "lrs={}".format(lrs[:warmup_duration + num_iterations])
            scheduler.load_state_dict(state_dict)

    t1 = torch.zeros([1], requires_grad=True)
    # A) opt lr != warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.05, 10, 0.2)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.05, 2, 0.2)

    # B) opt lr == warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.2, 10, 0.2 * 0.98)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.2, 2, 0.2 * 0.98)

    # C) lr_scheduler start_value != warmup_end_value
    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer,
        param_name="lr",
        start_value=0.8,
        end_value=0.0,
        cycle_size=10,
    )
    _test(lr_scheduler, optimizer, 0.01, 0.05, 10, 0.8)
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer,
        param_name="lr",
        start_value=0.8,
        end_value=0.0,
        cycle_size=10,
    )
    _test(lr_scheduler, optimizer, 0.01, 0.05, 2, 0.8)

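    # The cyclical scheduler decays from 0.8 to 0.0 over half a cycle (5 steps,
    # i.e. 0.16 per step). In case D, where warmup_end_value equals the 0.8
    # start value, the assertions check that the duplicated boundary value is
    # not repeated: the first post-warmup lr is the next cyclical value,
    # 0.8 - 0.8 / 5 = 0.64.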
    # D) lr_scheduler start_value == warmup_end_value
    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer,
        param_name="lr",
        start_value=0.8,
        end_value=0.0,
        cycle_size=10,
    )
    _test(lr_scheduler, optimizer, 0.01, 0.8, 10, 0.8 - (0.8 / 5.0))
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer,
        param_name="lr",
        start_value=0.8,
        end_value=0.0,
        cycle_size=10,
    )
    _test(lr_scheduler, optimizer, 0.01, 0.8, 2, 0.8 - (0.8 / 5.0))
Example #9
def main(parser_args):
    """Main function to create trainer engine, add handlers to train and validation engines.
    Then runs train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()

    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels,
                         parser_args.depth, parser_args.laplacian_type,
                         parser_args.kernel_size)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)

    def trainer(engine, batch):
        """Train Function to define train engine.
        Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            :obj:`torch.tensor` : train loss for that batch and epoch
        """
        unet.train()
        data, labels = batch
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)

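        # Flatten (batch, vertices, classes) predictions to (B*V, C) and turn
        # the one-hot labels into class indices so CrossEntropyLoss scores
        # each pixel independently.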
        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]

        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)

    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform)

    engine_train.add_event_handler(
        Events.EPOCH_STARTED,
        lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED,
                                   TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1]
            },
            engine_train.state.epoch,
        )
        writer.close()

    step_scheduler = StepLR(optimizer,
                            step_size=parser_args.steplr_step_size,
                            gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train)
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train,
                    optimizer,
                    unet,
                    log_dir=parser_args.tensorboard_path)

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

    torch.save(unet.state_dict(),
               parser_args.model_save_path + "unet_state.pt")
Example #10
def train(cfg):
    print(cfg.pretty())

    ###################################################################
    # Dataset
    ###################################################################
    wt = Dataset(batch_size=cfg.train.batch_size,
                 bptt_len=cfg.train.bptt_len,
                 dataset_cls=hydra.utils.get_class(cfg.dataset.name))

    ###################################################################
    # Models
    ###################################################################
    base_embedding = hydra.utils.instantiate(cfg.embedding,
                                             ntokens=len(wt.text_field.vocab) +
                                             3)
    embedding = TransformerEmbedding(
        embedding=base_embedding,
        max_length=cfg.train.bptt_len,
        embedding_size=base_embedding.embedding_size,
        use_positional_embedding=False)
    encoder = TransformerEncoder(query_dim=cfg.encoder.query_dim,
                                 att_num_units=cfg.encoder.att_num_units,
                                 ffn_num_unit=cfg.encoder.ffn_num_unit,
                                 max_ext=cfg.encoder.max_ext)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()

    # wandb.watch(model)

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=cfg.encoder.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)
    opt = optim.Adam(parameters, lr=cfg.train.lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()

        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()

        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]

        loss.backward()
        nn.utils.clip_grad_norm_(parameters, cfg.train.clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()

        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None

        target_sample = []
        result_sample = []
        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()

            out, out_past = model(text, engine.state.eval_past)

            vocab = wt.text_field.vocab
            idx = list(range(32))
            sample = random.choices(idx, k=5)
            for id_sample in sample:
                s = []
                for target_id in target[id_sample]:
                    s.append(vocab.itos[target_id])
                target_sample.append(" ".join(s))

                s = []
                for result_id in out.max(-1)[1][id_sample]:
                    s.append(vocab.itos[result_id])
                result_sample.append(" ".join(s))
            # engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {
                "val_loss": loss.item(),
                "sample": (target_sample, result_sample)
            }

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ",
              np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0],
                                                "lr",
                                                0.0,
                                                2.5e-4,
                                                cycle_size=len(wt.train_iter))
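    # Linear warmup from 0.0 to 2.5e-4 over the first 200 iterations, after
    # which the cosine schedule takes over (the positional arguments are
    # warmup_start_value, warmup_end_value and warmup_duration).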
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0,
                                                       2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(
        train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(
        train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(
        eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    # tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger = WandbLogger(project="language_model", entity="akurniawan")
    tb_logger.watch(model)

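    # Throttle a log handler so that it only fires every num_steps iterations.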
    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)

        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(
                         cfg.train.log_steps,
                         OutputHandler(tag="training",
                                       output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(
                         tag="validation",
                         output_transform=lambda loss: loss,
                         another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps,
    #                                           OptimizerParamsHandler(opt)),
    #                  event_name=Events.ITERATION_STARTED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps,
    #                                           WeightsScalarHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps,
    #                                           GradsScalarHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(500, WeightsHistHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(500, GradsHistHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=cfg.train.epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()
Example #11
def train(epochs=500,
          batch_size=32,
          bptt_len=70,
          lr=0.00025,
          log_steps=200,
          clip_grad=0.25,
          log_dir="experiments"):
    ###################################################################
    # Dataset
    ###################################################################
    wt = wikitext103(batch_size=batch_size, bptt_len=bptt_len)
    # wt = wikitext2(batch_size=batch_size, bptt_len=bptt_len)

    ###################################################################
    # Configs
    ###################################################################
    embedding_config = DropEmbedding.Hyperparams(len(wt.text_field.vocab) + 3,
                                                 ninp=512)
    encoder_config = TransformerEncoder.Hyperparams(
        att_num_units=[512, 512, 512, 512, 512, 512], max_ext=384)

    ###################################################################
    # Models
    ###################################################################
    base_embedding = DropEmbedding(embedding_config)
    embedding = TransformerEmbedding(embedding=base_embedding,
                                     max_length=bptt_len,
                                     embedding_size=embedding_config.ninp,
                                     use_positional_embedding=False)
    encoder = TransformerEncoder(encoder_config)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=encoder_config.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)
    opt = optim.Adam(parameters, lr=lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()

        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()

        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]

        loss.backward()
        nn.utils.clip_grad_norm_(parameters, clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()

        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None

        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()

            out, out_past = model(text, engine.state.eval_past)
            engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {"val_loss": loss.item()}

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ",
              np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0],
                                                "lr",
                                                0.0,
                                                2.5e-4,
                                                cycle_size=len(wt.train_iter))
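    # Same warmup pattern as the previous example: 0.0 -> 2.5e-4 over the
    # first 200 iterations before the cosine schedule takes over.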
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0,
                                                       2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(
        train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(
        train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(
        eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    tb_logger = TensorboardLogger(log_dir=log_dir)

    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)

        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(
                         log_steps,
                         OutputHandler(tag="training",
                                       output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(
                         tag="validation",
                         output_transform=lambda loss: loss,
                         another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              OptimizerParamsHandler(opt)),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              WeightsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              GradsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, WeightsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, GradsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()