Example #1
from ignite.engine import Engine, Events, State
# in ignite >= 0.4 this also lives in ignite.handlers
from ignite.contrib.handlers import global_step_from_engine


def test_global_step_from_engine():

    engine = Engine(lambda engine, batch: None)
    engine.state = State()
    engine.state.epoch = 1

    another_engine = Engine(lambda engine, batch: None)
    another_engine.state = State()
    another_engine.state.epoch = 10

    global_step_transform = global_step_from_engine(another_engine)
    res = global_step_transform(engine, Events.EPOCH_COMPLETED)

    assert res == another_engine.state.epoch
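The test verifies that the returned transform ignores the engine it is called with and instead reports the wrapped engine's event count. In practice the wrapped engine is a trainer and the caller is an evaluator or logger, so validation metrics are plotted against the trainer's progress. A minimal sketch of that pattern, using only the Ignite APIs shown above:

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
step_fn = global_step_from_engine(trainer)

@evaluator.on(Events.COMPLETED)
def log_metrics(engine):
    # step_fn returns trainer.state.epoch, not the evaluator's own
    # counter, so both engines' metrics share one x axis.
    print(step_fn(engine, Events.EPOCH_COMPLETED), engine.state.metrics)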
Example #2
import torch
import torch.nn.functional as F
from torch.optim import SGD
from tqdm import tqdm

from ignite.engine import (Events, create_supervised_evaluator,
                           create_supervised_trainer)
from ignite.metrics import Accuracy, Loss
# in ignite >= 0.4 this also lives in ignite.handlers
from ignite.contrib.handlers import global_step_from_engine


# Net and get_data_loaders are defined elsewhere in this example
# (the usual ignite MNIST model and DataLoader factory).
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(F.nll_loss)
                                            },
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch
        it = (engine.state.iteration - 1) % len(train_loader) + 1

        if it % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write("Training Results - "
                   "Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                       engine.state.epoch, avg_accuracy, avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write("Validation Results "
                   "- Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                       engine.state.epoch, avg_accuracy, avg_nll))

        # reset the progress bar for the next epoch
        pbar.n = pbar.last_print_n = 0

    # [ChainerUI] import logger and handler
    # (module path assumed from the chainerui ignite examples; may vary by version)
    from chainerui.contrib.ignite.handler import OutputHandler
    from chainerui.contrib.ignite.logger import ChainerUILogger

    # [ChainerUI] set up the logger; it manages the ChainerUI web client.
    logger = ChainerUILogger()
    # [ChainerUI] to reduce the frequency of metric-posting requests, set the
    # interval option
    train_handler = OutputHandler('train',
                                  output_transform=lambda o: {'nll': o},
                                  interval_step=20)
    logger.attach(trainer,
                  log_handler=train_handler,
                  event_name=Events.ITERATION_COMPLETED)
    # [ChainerUI] to share the same x-axis values as the trainer, use
    # global_step_transform
    val_handler = OutputHandler(
        'val',
        metric_names='all',
        global_step_transform=global_step_from_engine(trainer))
    logger.attach(evaluator,
                  log_handler=val_handler,
                  event_name=Events.EPOCH_COMPLETED)
    # [ChainerUI] to post any metrics still buffered by the interval, run
    # inside a "with" block
    with logger:
        trainer.run(train_loader, max_epochs=epochs)
        pbar.close()
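A hypothetical entry point for the example above; the hyperparameter values below are illustrative defaults, not values prescribed by the source:

if __name__ == '__main__':
    # illustrative defaults (assumed, not part of the original example)
    run(train_batch_size=64, val_batch_size=1000, epochs=10,
        lr=0.01, momentum=0.5, log_interval=10)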