Example #1
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]

    def start_of_epoch_handler(trainer):
        if trainer.current_epoch == epoch_to_terminate_on:
            trainer.terminate()

    trainer = Trainer(batches_per_epoch, MagicMock(return_value=1),
                      MagicMock(), MagicMock())
    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED,
                              start_of_epoch_handler)

    assert not trainer.should_terminate

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    # epoch is not completed so counter is not incremented
    assert trainer.current_epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # completes first iteration
    assert trainer.current_iteration == (epoch_to_terminate_on *
                                         len(batches_per_epoch)) + 1
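The snippets in this listing omit their imports. A minimal scaffold that the legacy-API examples (Trainer with TrainingEvents) appear to assume, with the ignite import path taken from Examples #25 and #27 below and the rest inferred from usage:

# Assumed test scaffolding for the legacy-API examples (inferred, not from the source).
import uuid
import numpy as np
from unittest.mock import MagicMock, Mock, call
from pytest import raises
from ignite.trainer import Trainer, TrainingEvents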
Example #2
def test_current_validation_iteration_counter_increases_every_iteration():
    validation_batches = [1, 2, 3]
    trainer = Trainer([1], MagicMock(return_value=1), validation_batches,
                      MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 0
            self.total_count = 0

        def __call__(self, trainer):
            assert trainer.current_validation_iteration == self.current_iteration_count
            self.current_iteration_count += 1
            self.total_count += 1

        def clear(self):
            self.current_iteration_count = 0

    iteration_counter = IterationCounter()

    def clear_counter(trainer, counter):
        counter.clear()

    trainer.add_event_handler(TrainingEvents.VALIDATION_STARTING,
                              clear_counter, iteration_counter)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              iteration_counter)

    trainer.run(max_epochs=max_epochs, validate_every_epoch=True)

    assert iteration_counter.total_count == max_epochs * len(
        validation_batches)
Example #3
def test_training_iteration_events_are_fired():
    max_epochs = 5
    num_batches = 3
    data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(MagicMock(return_value=1))

    mock_manager = Mock()
    iteration_started = Mock()
    trainer.add_event_handler(Events.ITERATION_STARTED, iteration_started)

    iteration_complete = Mock()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)

    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    state = trainer.run(data, max_epochs=max_epochs)

    assert iteration_started.call_count == num_batches * max_epochs
    assert iteration_complete.call_count == num_batches * max_epochs

    expected_calls = []
    for i in range(max_epochs * num_batches):
        expected_calls.append(call.iteration_started(trainer, state))
        expected_calls.append(call.iteration_complete(trainer, state))

    assert mock_manager.mock_calls == expected_calls
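Example #3 and the other state-returning snippets target a later, Engine-style API: the update function is the only required constructor argument, run() takes the data and returns a state object, and handlers receive that state. A hedged scaffold for these examples; the import path for Events moved between early releases, and in current ignite the role of Trainer(update_fn) is played by Engine(update_fn) with an update function taking (engine, batch):

# Assumed scaffolding for the Events/state style examples; paths vary by version.
from unittest.mock import MagicMock, Mock, call
from ignite.engine import Events  # current location; early releases differed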
Example #4
def test_validation_iteration_events_are_fired_when_validate_is_called_explicitly():
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(),
                      validation_inference_function=MagicMock(return_value=1))

    mock_manager = Mock()
    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              iteration_started)

    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED,
                              iteration_complete)

    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    assert iteration_started.call_count == 0
    assert iteration_complete.call_count == 0

    trainer.validate()

    assert iteration_started.call_count == num_batches
    assert iteration_complete.call_count == num_batches
Example #5
def test_validation_iteration_events_are_fired():
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(return_value=1),
                      validation_inference_function=MagicMock(return_value=1))

    mock_manager = Mock()
    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              iteration_started)

    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED,
                              iteration_complete)

    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    trainer.run(max_epochs=max_epochs)

    assert iteration_started.call_count == num_batches * max_epochs
    assert iteration_complete.call_count == num_batches * max_epochs

    expected_calls = []
    for i in range(max_epochs * num_batches):
        expected_calls.append(call.iteration_started(trainer))
        expected_calls.append(call.iteration_complete(trainer))

    assert mock_manager.mock_calls == expected_calls
Example #6
def test_adding_handler_for_non_existent_event_throws_error():
    trainer = Trainer(MagicMock(), MagicMock(), MagicMock(), MagicMock())

    event_name = uuid.uuid4()
    while event_name in TrainingEvents.__members__.values():
        event_name = uuid.uuid4()

    with raises(ValueError):
        trainer.add_event_handler(event_name, lambda x: x)
Example #7
def test_adding_multiple_event_handlers():
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    handlers = [MagicMock(), MagicMock()]
    for handler in handlers:
        trainer.add_event_handler(TrainingEvents.TRAINING_STARTED, handler)

    trainer.run(validate_every_epoch=False)
    for handler in handlers:
        handler.assert_called_once_with(trainer)
Example #8
def test_custom_exception_handler():
    value_error = ValueError()
    training_update_function = MagicMock(side_effect=value_error)

    trainer = Trainer(training_update_function)
    exception_handler = MagicMock()
    trainer.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)
    state = trainer.run([1])

    # only one call from _run_once_over_data, since the exception is swallowed
    exception_handler.assert_has_calls([call(trainer, state, value_error)])
Example #9
def test_exception_handler_called_on_error():
  training_update_function = MagicMock(side_effect=ValueError())

  trainer = Trainer([1], training_update_function, MagicMock(), MagicMock())
  exception_handler = MagicMock()
  trainer.add_event_handler(TrainingEvents.EXCEPTION_RAISED, exception_handler)

  with raises(ValueError):
    trainer.run()

  exception_handler.assert_called_once_with(trainer)
Example #10
def test_terminate_stops_training_mid_epoch():
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch + 3  # i.e. part way through the 2nd epoch
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def end_of_iteration_handler(trainer):
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              end_of_iteration_handler)
    trainer.run(training_data=[None] * num_iterations_per_epoch, max_epochs=3)
    # completes the iteration in which terminate was called
    assert trainer.current_iteration == iteration_to_stop + 1
    # epoch counter starts from 0 in this API
    assert trainer.current_epoch == np.ceil(
        iteration_to_stop / num_iterations_per_epoch) - 1
Example #11
def test_terminate_stops_training_mid_epoch():
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch + 3  # i.e. part way through the 2nd epoch
    trainer = Trainer(MagicMock(return_value=1))

    def start_of_iteration_handler(trainer, state):
        if state.iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED,
                              start_of_iteration_handler)
    state = trainer.run(data=[None] * num_iterations_per_epoch, max_epochs=3)
    # completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
    assert (state.iteration == iteration_to_stop)
    assert state.epoch == np.ceil(
        iteration_to_stop / num_iterations_per_epoch)  # epochs count from 1 here
Example #12
def test_current_epoch_counter_increases_every_epoch():
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 0

        def __call__(self, trainer):
            assert trainer.current_epoch == self.current_epoch_count
            self.current_epoch_count += 1

    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED, EpochCounter())

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    assert trainer.current_epoch == max_epochs
Example #13
def test_terminate_after_training_iteration_skips_validation_run():
  num_iterations_per_epoch = 10
  iteration_to_stop = num_iterations_per_epoch - 1
  trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                    training_update_function=MagicMock(return_value=1),
                    validation_data=MagicMock(),
                    validation_inference_function=MagicMock())

  def end_of_iteration_handler(trainer):
    if trainer.current_iteration == iteration_to_stop:
      trainer.terminate()

  trainer.validate = MagicMock()

  trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED, end_of_iteration_handler)
  trainer.run(max_epochs=3, validate_every_epoch=True)
  assert trainer.validate.call_count == 0
Example #14
def test_terminate_at_end_of_epoch_stops_training():
    max_epochs = 5
    last_epoch_to_run = 3

    def end_of_epoch_handler(trainer):
        if trainer.current_epoch == last_epoch_to_run:
            trainer.terminate()

    trainer = Trainer(MagicMock(return_value=1))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, end_of_epoch_handler)

    assert not trainer.should_terminate

    trainer.run([1], max_epochs=max_epochs)

    assert trainer.current_epoch == last_epoch_to_run
    assert trainer.should_terminate
Example #15
def test_terminate_after_training_iteration_skips_validation_run():
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def end_of_iteration_handler(trainer):
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    trainer.validate = MagicMock()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, MagicMock())
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              end_of_iteration_handler)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=3)
    assert trainer.validate.call_count == 0
Example #16
def test_terminate_at_end_of_epoch_stops_training():
  max_epochs = 5
  last_epoch_to_run = 3

  def end_of_epoch_handler(trainer):
    if trainer.current_epoch == last_epoch_to_run:
      trainer.terminate()

  trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
  trainer.add_event_handler(TrainingEvents.EPOCH_COMPLETED, end_of_epoch_handler)

  assert not trainer.should_terminate

  trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

  assert trainer.current_epoch == last_epoch_to_run + 1  # counter is incremented at end of loop
  assert trainer.should_terminate
Example #17
def test_current_epoch_counter_increases_every_epoch():
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 1

        def __call__(self, trainer, state):
            assert state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    trainer.add_event_handler(Events.EPOCH_STARTED, EpochCounter())

    state = trainer.run([1], max_epochs=max_epochs)

    assert state.epoch == max_epochs
Example #18
def test_with_trainer(dirname):
    def update_fn(batch):
        pass

    name = 'model'
    trainer = Trainer(update_fn)
    handler = ModelCheckpoint(dirname,
                              _PREFIX,
                              create_dir=False,
                              n_saved=2,
                              save_interval=1)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {name: 42})
    trainer.run([0], max_epochs=4)

    expected = ['{}_{}_{}.pth'.format(_PREFIX, name, i) for i in [3, 4]]

    assert sorted(os.listdir(dirname)) == expected
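For context: the ModelCheckpoint handler above is called with a dict mapping a base name to the object to save, and it rotates files so only the n_saved most recent remain, which is exactly what the expected filename list checks. A hedged sketch with a hypothetical directory and prefix (save_interval belongs to the old handler signature used in this test):

# Saved files follow '{prefix}_{name}_{step}.pth'; with n_saved=2 and four
# triggering epochs, only steps 3 and 4 survive on disk.
from ignite.handlers import ModelCheckpoint

handler = ModelCheckpoint('/tmp/models', 'myprefix', create_dir=True,
                          n_saved=2, save_interval=1)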
Example #19
def test_current_iteration_counter_increases_every_iteration():
    training_batches = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 1

        def __call__(self, trainer, state):
            assert state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    trainer.add_event_handler(Events.ITERATION_STARTED, IterationCounter())

    state = trainer.run(training_batches, max_epochs=max_epochs)

    assert state.iteration == max_epochs * len(training_batches)
Example #20
def test_current_iteration_counter_increases_every_iteration():
  training_batches = [1, 2, 3]
  trainer = Trainer(training_batches, MagicMock(return_value=1), MagicMock(), MagicMock())
  max_epochs = 5

  class IterationCounter(object):
    def __init__(self):
      self.current_iteration_count = 0

    def __call__(self, trainer):
      assert trainer.current_iteration == self.current_iteration_count
      self.current_iteration_count += 1

  trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED, IterationCounter())

  trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

  assert trainer.current_iteration == max_epochs * len(training_batches)
Example #21
def test_args_and_kwargs_are_passed_to_event():
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    kwargs = {'a': 'a', 'b': 'b'}
    args = (1, 2, 3)
    handlers = []
    for event in TrainingEvents:
        handler = MagicMock()
        trainer.add_event_handler(event, handler, *args, **kwargs)
        handlers.append(handler)

    trainer.run(max_epochs=1, validate_every_epoch=False)
    called_handlers = [handle for handle in handlers if handle.called]
    assert len(called_handlers) > 0

    for handler in called_handlers:
        handler_args, handler_kwargs = handler.call_args
        assert handler_args[0] == trainer
        assert handler_args[1:] == args
        assert handler_kwargs == kwargs
Example #22
def test_terminate_stops_trainer_when_called_during_validation():
  num_iterations_per_epoch = 10
  iteration_to_stop = 3  # i.e. part way through the 3rd validation run
  epoch_to_stop = 2
  trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                    training_update_function=MagicMock(return_value=1),
                    validation_data=[None] * num_iterations_per_epoch,
                    validation_inference_function=MagicMock(return_value=1))

  def end_of_iteration_handler(trainer):
    if (trainer.current_epoch == epoch_to_stop and
          trainer.current_validation_iteration == iteration_to_stop):
      trainer.terminate()

  trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED, end_of_iteration_handler)
  trainer.run(max_epochs=4, validate_every_epoch=True)

  assert trainer.current_epoch == epoch_to_stop
  # should complete the iteration when terminate called
  assert trainer.current_validation_iteration == iteration_to_stop + 1
  assert trainer.current_iteration == (epoch_to_stop + 1) * num_iterations_per_epoch
Example #23
def test_terminate_stops_trainer_when_called_during_validation():
    num_iterations_per_epoch = 10
    iteration_to_stop = 3  # i.e. part way through the 3rd validation run
    epoch_to_stop = 2
    trainer = Trainer(MagicMock(return_value=1), MagicMock(return_value=1))

    def end_of_iteration_handler(trainer):
        if (trainer.current_epoch == epoch_to_stop
                and trainer.current_validation_iteration == iteration_to_stop):

            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, [None] * num_iterations_per_epoch)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              end_of_iteration_handler)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=4)

    assert trainer.current_epoch == epoch_to_stop
    # should complete the iteration when terminate called
    assert trainer.current_validation_iteration == iteration_to_stop + 1
    assert trainer.current_iteration == (epoch_to_stop +
                                         1) * num_iterations_per_epoch
Example #24
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]

    trainer = Trainer(MagicMock(return_value=1))

    def start_of_epoch_handler(trainer, state):
        if state.epoch == epoch_to_terminate_on:
            trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, start_of_epoch_handler)

    assert not trainer.should_terminate

    state = trainer.run(batches_per_epoch, max_epochs=max_epochs)

    # epoch is not completed so counter is not incremented
    assert state.epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # completes first iteration
    assert state.iteration == (
        (epoch_to_terminate_on - 1) * len(batches_per_epoch)) + 1
Example #25
def train(epochs=10,
          batch_size=64,
          latent_dim=2,
          hidden_dim=400,
          use_gpu=False):
    from torch.autograd import Variable
    from ignite.trainer import Trainer, TrainingEvents
    import logging
    import numpy as np
    import torch
    logger = logging.getLogger('ignite')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    from tensorboardX import SummaryWriter
    from autoencoders.models.sampling import cvae_reconstructions
    from autoencoders.models.utils import flatten, to_one_hot
    from autoencoders.data.mnist import mnist_dataloader
    from autoencoders.utils.tensorboard import run_path
    from autoencoders.models.loss import VAELoss
    from autoencoders.utils.notifications import send_training_complete_push

    experiment_name = run_path('cvae/l{}_h{}_b{}_adam_3e-4'.format(
        latent_dim, hidden_dim, batch_size))

    checkpoint_path = '_'.join(experiment_name.split('/')[1:])

    writer = SummaryWriter(experiment_name)

    model = ConditionalVariationalAutoencoder(latent_dim=latent_dim,
                                              hidden_dim=hidden_dim,
                                              n_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), 3e-4)
    criterion = VAELoss()

    train_loader = mnist_dataloader(path='data',
                                    batch_size=batch_size,
                                    download=True)

    val_loader = mnist_dataloader(path='data',
                                  batch_size=batch_size,
                                  train=False,
                                  download=True)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()

        inputs, targets = batch
        inputs = flatten(Variable(inputs))
        targets = Variable(
            to_one_hot(targets, batch_size=batch_size, n_classes=10))

        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)
        loss.backward()
        optimizer.step()

        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()

        inputs, targets = batch
        inputs = flatten(Variable(inputs))
        targets = Variable(
            to_one_hot(targets, batch_size=batch_size, n_classes=10))

        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)

        return loss.data[0]

    def on_epoch_completed(trainer, writer):
        print(trainer.current_epoch, np.mean(trainer.training_history))
        writer.add_scalar('loss/training loss',
                          np.mean(trainer.training_history),
                          trainer.current_epoch)

    def on_validation(trainer, writer):
        print(trainer.current_epoch, np.mean(trainer.validation_history))
        writer.add_scalar('loss/validation loss',
                          np.mean(trainer.validation_history),
                          trainer.current_epoch)
        writer.add_image('image', cvae_reconstructions(model, val_loader),
                         trainer.current_epoch)

    def on_complete(trainer):
        desc = 'Val Loss: {}'.format(np.mean(trainer.validation_history))
        send_training_complete_push(checkpoint_path, desc)

        torch.save(model.state_dict(), 'models/{}.cpt'.format(checkpoint_path))

    trainer = Trainer(train_loader, training_update_function, val_loader,
                      validation_inference_function)

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              on_epoch_completed, writer)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              on_validation, writer)

    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED, on_complete)

    trainer.run(max_epochs=epochs)
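This script, like the other training examples here, targets pre-0.4 PyTorch: Variable wrappers, volatile=True for inference, and loss.data[0] to read a scalar loss. A sketch of the same update step on PyTorch 0.4 and later (not from the original source; assumes model, optimizer and criterion are in scope and keeps the flatten/one-hot preprocessing implicit):

# Rough modern equivalent of training_update_function above (PyTorch >= 0.4).
def training_update_function(batch):
    inputs, targets = batch  # Variable wrappers are no longer needed
    model.train()
    optimizer.zero_grad()
    output, mu, logvar = model(inputs, targets)
    loss = criterion(output, inputs, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()  # replaces loss.data[0]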
Example #26
def run():
    train_dataloader = data.DataLoader(CMAPSSData(train=True),
                                       shuffle=True,
                                       batch_size=300,
                                       pin_memory=USE_CUDA,
                                       collate_fn=collate_fn)

    validation_dataloader = data.DataLoader(CMAPSSData(train=False),
                                            shuffle=True,
                                            batch_size=1000,
                                            pin_memory=USE_CUDA,
                                            collate_fn=collate_fn)

    model = TTE(24, 128, n_layers=1, dropout=0.2)
    print(model)
    if USE_CUDA:
        model.cuda()
    optimizer = optim.RMSprop(model.parameters(), lr=0.001)
    loss_fn = weibull_loglikelihood

    def training_update(batch):
        model.train()
        optimizer.zero_grad()
        inputs, lengths, targets = batch
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
        inputs = Variable(inputs)
        targets = Variable(targets)
        outputs = model(inputs, lengths)
        loss = loss_fn(outputs, targets, lengths)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference(batch):
        model.eval()
        inputs, lengths, targets = batch
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
        inputs = Variable(inputs, volatile=True)
        targets = Variable(targets, volatile=True)
        outputs = model(inputs, lengths)
        loss = loss_fn(outputs, targets, lengths)
        return (loss.data[0], outputs.data[:, :, 0],
                outputs.data[:, :, 1], targets.data)

    trainer = Trainer(training_update)
    evaluator = Evaluator(validation_inference)
    progress = Progress()
    plot_interval = 1

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        Evaluate(evaluator, validation_dataloader, epoch_interval=1))

    @trainer.on(Events.EPOCH_STARTED)
    def epoch_started(trainer):
        print('Epoch {:4}/{}'.format(trainer.current_epoch,
                                     trainer.max_epochs),
              end='')

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(trainer):
        if trainer.current_iteration % plot_interval == 0:
            avg_loss = trainer.history.simple_moving_average(window_size=100)
            values = [('iter', trainer.current_iteration), ('loss', avg_loss)]
            progress.update(values)

    @evaluator.on(Events.COMPLETED)
    def epoch_completed(evaluator):
        history = evaluator.history[0]
        loss = history[0]
        alpha = history[1]
        beta = history[2]
        target = history[3]

        mae = torch.mean(torch.abs(alpha - target[:, :, 0]))
        alpha = alpha.mean()
        beta = beta.mean()

        values = [('val_loss', loss), ('mae', mae), ('alpha', alpha),
                  ('beta', beta)]
        progress.update(values, end=True)

    trainer.run(train_dataloader, max_epochs=600)

    return model
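Example #26 mixes the two registration styles ignite supports: explicit add_event_handler calls and the @trainer.on decorator. The decorator is shorthand for the same registration; a minimal illustration with a hypothetical handler:

# @trainer.on(event) registers the decorated function via add_event_handler:
@trainer.on(Events.EPOCH_STARTED)
def log_epoch(trainer):
    print('starting epoch', trainer.current_epoch)

# ...which is equivalent to:
# trainer.add_event_handler(Events.EPOCH_STARTED, log_epoch)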
Example #27
def train(batch_size=512, epochs=100):
    from torch.autograd import Variable
    from ignite.trainer import Trainer, TrainingEvents
    import logging
    logger = logging.getLogger('ignite')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    from tensorboardX import SummaryWriter
    from autoencoders.models.sampling import vae_reconstructions
    from autoencoders.data.mnist import mnist_dataloader
    from autoencoders.utils.tensorboard import run_path
    from autoencoders.models.loss import VAELoss
    import numpy as np
    import torch
    use_gpu = torch.cuda.is_available()

    writer = SummaryWriter(run_path('vae'))

    model = VariationalAutoencoder()
    optimizer = torch.optim.Adam(model.parameters(), 3e-4)
    criterion = VAELoss()

    train_loader = mnist_dataloader(path='data',
                                    batch_size=batch_size,
                                    download=True)

    val_loader = mnist_dataloader(path='data',
                                  batch_size=batch_size,
                                  train=False,
                                  download=True)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()

        inputs, _ = batch
        inputs = Variable(inputs)
        if use_gpu:
            inputs = inputs.cuda()

        output, mu, logvar = model(inputs)
        loss = criterion(output, inputs.view(-1, 784), mu, logvar)
        loss.backward()
        optimizer.step()

        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()

        inputs, _ = batch
        inputs = Variable(inputs)
        if use_gpu:
            inputs = inputs.cuda()

        output, mu, logvar = model(inputs)
        loss = criterion(output, inputs, mu, logvar)

        return loss.data[0]

    def on_epoch_completed(trainer, writer):
        writer.add_scalar('loss/training loss',
                          np.mean(trainer.training_history),
                          trainer.current_epoch)

    def on_validation(trainer, writer):
        writer.add_scalar('loss/validation loss',
                          np.mean(trainer.validation_history),
                          trainer.current_epoch)

        writer.add_image('image', vae_reconstructions(model, val_loader),
                         trainer.current_epoch)

    trainer = Trainer(train_loader, training_update_function, val_loader,
                      validation_inference_function)

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              on_epoch_completed, writer)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              on_validation, writer)

    trainer.run(max_epochs=epochs)
Example #28
def run(batch_size, val_batch_size, epochs, lr, momentum, log_interval,
        logger):
    vis = visdom.Visdom()
    if not vis.check_connection():
        raise RuntimeError(
            "Visdom server not running. Please run python -m visdom.server")

    data_transform = Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])

    train_loader = DataLoader(MNIST(download=True,
                                    root=".",
                                    transform=data_transform,
                                    train=True),
                              batch_size=batch_size,
                              shuffle=True)

    val_loader = DataLoader(MNIST(download=False,
                                  root=".",
                                  transform=data_transform,
                                  train=False),
                            batch_size=val_batch_size,
                            shuffle=False)

    model = Net()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()
        data, target = Variable(batch[0]), Variable(batch[1])
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()
        data, target = Variable(batch[0], volatile=True), Variable(batch[1])
        output = model(data)
        loss = F.nll_loss(output, target, size_average=False).data[0]
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(target.data.view_as(pred)).sum()
        return loss, correct

    trainer = Trainer(training_update_function)
    evaluator = Evaluator(validation_inference_function)

    # trainer event handlers
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              log_simple_moving_average,
                              window_size=100,
                              metric_name="NLL",
                              should_log=lambda trainer:
                              trainer.current_iteration % log_interval == 0,
                              logger=logger)
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        get_plot_training_loss_handler(vis, plot_every=log_interval))
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        Evaluate(evaluator, val_loader, epoch_interval=1))

    # evaluator event handlers
    evaluator.add_event_handler(
        Events.COMPLETED, get_log_validation_loss_and_accuracy_handler(logger))
    evaluator.add_event_handler(Events.COMPLETED,
                                get_plot_validation_accuracy_handler(vis),
                                trainer)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
Example #29
def main(args, logger):
    if args.reverse:
        preprocessing = reverse_sentence
    else:
        preprocessing = None
    SRC = data.Field(init_token=BOS,
                     eos_token=EOS,
                     include_lengths=True,
                     preprocessing=preprocessing)
    TRG = data.Field(init_token=BOS, eos_token=EOS)

    if args.dataset == 'enja':
        train, val, test = SmallEnJa.splits(exts=('.en', '.ja'),
                                            fields=(SRC, TRG))
    elif args.dataset == 'wmt14':
        train, val, test = WMT14.splits(
            exts=('.en', '.de'),
            fields=(SRC, TRG),
            filter_pred=lambda ex: len(ex.src) <= 50 and len(ex.trg) <= 50)

    SRC.build_vocab(train.src, max_size=args.src_vocab)
    TRG.build_vocab(train.trg, max_size=args.trg_vocab)

    stoi = TRG.vocab.stoi

    train_iter, val_iter = data.BucketIterator.splits(
        (train, val),
        batch_sizes=(args.batch, args.batch * 2),
        repeat=False,
        sort_within_batch=True,
        sort_key=SmallEnJa.sort_key,
        device=args.gpu[0] if len(args.gpu) == 1 else -1)
    test_iter = data.Iterator(test,
                              batch_size=1,
                              repeat=False,
                              sort=False,
                              train=False,
                              device=args.gpu[0] if args.gpu else -1)

    model = Seq2Seq(len(SRC.vocab), args.embed, args.encoder_hidden,
                    len(TRG.vocab), args.embed, args.decoder_hidden,
                    args.encoder_layers, not args.encoder_unidirectional,
                    args.decoder_layers, SRC.vocab.stoi[SRC.pad_token],
                    stoi[TRG.pad_token], stoi[TRG.init_token],
                    stoi[TRG.eos_token], args.dropout_ratio,
                    args.attention_type)
    model.prepare_translation(TRG.vocab.itos, args.max_length)

    translate = model.translate
    if len(args.gpu) >= 2:
        model = DataParallel(model, device_ids=args.gpu, dim=1).cuda()
    elif len(args.gpu) == 1:
        model.cuda(args.gpu[0])

    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
        scheduler = None
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
        scheduler = MultiStepLR(optimizer,
                                milestones=list(range(8, 12)),
                                gamma=0.5)

    trainer = Trainer(
        TeacherForceUpdater(model, optimizer, model, args.gradient_clipping))
    evaluator = Evaluator(TeacherForceInference(model, model))
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              log_training_average_nll,
                              logger=logger)
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              Evaluate(evaluator, val_iter, epoch_interval=1))
    evaluator.add_event_handler(Events.COMPLETED,
                                get_log_validation_ppl(val.trg),
                                logger=logger)
    if args.best_file is not None:
        evaluator.add_event_handler(Events.COMPLETED,
                                    BestModelSnapshot(model, 'ppl', 1e10, le),
                                    args.best_file, logger)
        trainer.add_event_handler(Events.COMPLETED,
                                  ComputeBleu(model, test.trg, translate),
                                  test_iter, args.best_file, logger)
    if scheduler is not None:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda trainer: scheduler.step())
    trainer.run(train_iter, max_epochs=args.epoch)
Example #30
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        log_interval, save_interval, compare_interval):

    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)
    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_update_function(batch):
        transformer.train()
        lr_decay.step()
        opt.zero_grad()

        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        loss.backward()
        opt.step()

        return softmaxed_predictions.data, loss.data[0], batch.trg.data

    def validation_inference_function(batch):
        transformer.eval()
        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        return loss.data[0]

    trainer = Trainer(train_iter, training_update_function, val_iter,
                      validation_inference_function)
    trainer.add_event_handler(TrainingEvents.TRAINING_STARTED,
                              restore_checkpoint_hook(transformer, model_dir))
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                              log_training_simple_moving_average,
                              window_size=10,
                              metric_name="CrossEntropy",
                              should_log=lambda trainer:
                              trainer.current_iteration % log_interval == 0,
                              history_transform=lambda history: history[1])
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer:
                              trainer.current_iteration % save_interval == 0)
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                              print_current_prediction_hook(target_vocab),
                              should_print=lambda trainer:
                              trainer.current_iteration % compare_interval == 0)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              log_validation_simple_moving_average,
                              window_size=10,
                              metric_name="CrossEntropy")
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer: True)
    trainer.run(max_epochs=epochs, validate_every_epoch=True)