def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """Calling terminate() from EPOCH_STARTED still lets the first iteration of
    that epoch run to completion before the loop stops."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]

    def stop_on_target_epoch(trainer):
        if trainer.current_epoch == epoch_to_terminate_on:
            trainer.terminate()

    trainer = Trainer(batches_per_epoch, MagicMock(return_value=1),
                      MagicMock(), MagicMock())
    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED, stop_on_target_epoch)
    assert not trainer.should_terminate

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    # epoch is not completed so counter is not incremented
    assert trainer.current_epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # the iteration in flight when terminate() was called still completes
    assert trainer.current_iteration == (
        epoch_to_terminate_on * len(batches_per_epoch)) + 1
def test_current_validation_iteration_counter_increases_every_iteration():
    """current_validation_iteration resets to 0 at the start of every
    validation run and advances once per validation batch."""
    validation_batches = [1, 2, 3]
    max_epochs = 5
    trainer = Trainer([1], MagicMock(return_value=1), validation_batches,
                      MagicMock(return_value=1))

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 0
            self.total_count = 0

        def __call__(self, trainer):
            # counter must match the trainer's own notion of the iteration
            assert trainer.current_validation_iteration == self.current_iteration_count
            self.current_iteration_count += 1
            self.total_count += 1

        def clear(self):
            self.current_iteration_count = 0

    counter = IterationCounter()

    def reset_counter(trainer, c):
        c.clear()

    trainer.add_event_handler(TrainingEvents.VALIDATION_STARTING,
                              reset_counter, counter)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              counter)

    trainer.run(max_epochs=max_epochs, validate_every_epoch=True)
    assert counter.total_count == max_epochs * len(validation_batches)
def test_training_iteration_events_are_fired():
    """ITERATION_STARTED / ITERATION_COMPLETED fire once per batch, strictly
    alternating, with (trainer, state) passed to each handler."""
    max_epochs = 5
    num_batches = 3
    data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(MagicMock(return_value=1))

    manager = Mock()
    started = Mock()
    completed = Mock()
    trainer.add_event_handler(Events.ITERATION_STARTED, started)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, completed)
    # attach both mocks to a single manager so the interleaving is recorded
    manager.attach_mock(started, 'iteration_started')
    manager.attach_mock(completed, 'iteration_complete')

    state = trainer.run(data, max_epochs=max_epochs)

    total_iterations = num_batches * max_epochs
    assert started.call_count == total_iterations
    assert completed.call_count == total_iterations

    expected = []
    for _ in range(total_iterations):
        expected.append(call.iteration_started(trainer, state))
        expected.append(call.iteration_complete(trainer, state))
    assert manager.mock_calls == expected
def test_validation_iteration_events_are_fired_when_validate_is_called_explicitly():
    """validate() alone (no run()) fires validation iteration events exactly
    once per validation batch."""
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(),
                      validation_inference_function=MagicMock(return_value=1))

    manager = Mock()
    started = Mock()
    completed = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED, started)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED, completed)
    manager.attach_mock(started, 'iteration_started')
    manager.attach_mock(completed, 'iteration_complete')

    # nothing fires before validate() is invoked
    assert started.call_count == 0
    assert completed.call_count == 0

    trainer.validate()
    assert started.call_count == num_batches
    assert completed.call_count == num_batches
def test_validation_iteration_events_are_fired():
    """A full run() fires validation iteration events num_batches times per
    epoch, alternating started/completed."""
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(return_value=1),
                      validation_inference_function=MagicMock(return_value=1))

    manager = Mock()
    started = Mock()
    completed = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED, started)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED, completed)
    manager.attach_mock(started, 'iteration_started')
    manager.attach_mock(completed, 'iteration_complete')

    trainer.run(max_epochs=max_epochs)

    total_iterations = num_batches * max_epochs
    assert started.call_count == total_iterations
    assert completed.call_count == total_iterations

    expected = []
    for _ in range(total_iterations):
        expected.append(call.iteration_started(trainer))
        expected.append(call.iteration_complete(trainer))
    assert manager.mock_calls == expected
def test_adding_handler_for_non_existent_event_throws_error():
    """add_event_handler must raise ValueError for an unknown event name."""
    trainer = Trainer(MagicMock(), MagicMock(), MagicMock(), MagicMock())
    # draw random ids until one is guaranteed not to be a real event value
    event_name = uuid.uuid4()
    while event_name in TrainingEvents.__members__.values():
        event_name = uuid.uuid4()
    with raises(ValueError):
        trainer.add_event_handler(event_name, lambda x: x)
def test_adding_multiple_event_handlers():
    """Every handler registered for the same event is invoked exactly once."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    handlers = [MagicMock(), MagicMock()]
    for h in handlers:
        trainer.add_event_handler(TrainingEvents.TRAINING_STARTED, h)

    trainer.run(validate_every_epoch=False)

    for h in handlers:
        h.assert_called_once_with(trainer)
def test_custom_exception_handler():
    """An EXCEPTION_RAISED handler receives the trainer, the state and the
    exception instance raised by the update function."""
    value_error = ValueError()
    failing_update = MagicMock(side_effect=value_error)
    trainer = Trainer(failing_update)

    exception_handler = MagicMock()
    trainer.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)

    state = trainer.run([1])
    # only one call from _run_once_over_data, since the exception is swallowed
    exception_handler.assert_has_calls([call(trainer, state, value_error)])
def test_exception_handler_called_on_error():
    """When the update function raises, the EXCEPTION_RAISED handler fires once
    and the exception still propagates out of run()."""
    failing_update = MagicMock(side_effect=ValueError())
    trainer = Trainer([1], failing_update, MagicMock(), MagicMock())

    exception_handler = MagicMock()
    trainer.add_event_handler(TrainingEvents.EXCEPTION_RAISED, exception_handler)

    with raises(ValueError):
        trainer.run()
    exception_handler.assert_called_once_with(trainer)
def test_terminate_stops_training_mid_epoch():
    """terminate() part way through an epoch completes the current iteration
    and leaves the (0-based) epoch counter on the unfinished epoch."""
    iterations_per_epoch = 10
    iteration_to_stop = iterations_per_epoch + 3  # i.e. part way through the 2nd epoch
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def stop_on_target_iteration(trainer):
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              stop_on_target_iteration)
    trainer.run(training_data=[None] * iterations_per_epoch, max_epochs=3)

    # completes the iteration when terminate called
    assert trainer.current_iteration == iteration_to_stop + 1
    # epoch counter starts from 0, hence the -1
    assert trainer.current_epoch == np.ceil(
        iteration_to_stop / iterations_per_epoch) - 1
def test_terminate_stops_training_mid_epoch():
    """terminate() part way through an epoch completes the current iteration;
    the iteration counter is incremented just before the next iteration, so it
    stays at the terminating iteration."""
    num_iterations_per_epoch = 10
    # i.e. part way through the 2nd epoch (iterations 11-20 belong to epoch 2)
    iteration_to_stop = num_iterations_per_epoch + 3
    trainer = Trainer(MagicMock(return_value=1))

    def start_of_iteration_handler(trainer, state):
        if state.iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
    state = trainer.run(data=[None] * num_iterations_per_epoch, max_epochs=3)

    # completes the iteration but doesn't increment counter (this happens just
    # before a new iteration starts)
    assert (state.iteration == iteration_to_stop)
    # state.epoch is 1-based in this API, so no "-1" correction is needed
    assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch)
def test_current_epoch_counter_increases_every_epoch():
    """current_epoch starts at 0 and advances by one at each EPOCH_STARTED."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 0

        def __call__(self, trainer):
            assert trainer.current_epoch == self.current_epoch_count
            self.current_epoch_count += 1

    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED, EpochCounter())
    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)
    assert trainer.current_epoch == max_epochs
def test_terminate_after_training_iteration_skips_validation_run():
    """terminate() before an epoch finishes must prevent the per-epoch
    validation pass from running at all."""
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1
    trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                      training_update_function=MagicMock(return_value=1),
                      validation_data=MagicMock(),
                      validation_inference_function=MagicMock())

    def stop_on_target_iteration(trainer):
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    # spy on validate() so we can assert it was never entered
    trainer.validate = MagicMock()
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              stop_on_target_iteration)
    trainer.run(max_epochs=3, validate_every_epoch=True)
    assert trainer.validate.call_count == 0
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() from EPOCH_COMPLETED stops the run before the next epoch."""
    max_epochs = 5
    last_epoch_to_run = 3
    trainer = Trainer(MagicMock(return_value=1))

    def stop_after_target_epoch(trainer):
        if trainer.current_epoch == last_epoch_to_run:
            trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_after_target_epoch)
    assert not trainer.should_terminate

    trainer.run([1], max_epochs=max_epochs)

    assert trainer.current_epoch == last_epoch_to_run
    assert trainer.should_terminate
def test_terminate_after_training_iteration_skips_validation_run():
    """terminate() mid-epoch must skip the validation attached to
    TRAINING_EPOCH_COMPLETED."""
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def stop_on_target_iteration(trainer):
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    # spy on validate() so we can assert it was never entered
    trainer.validate = MagicMock()
    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, MagicMock())
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              stop_on_target_iteration)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=3)
    assert trainer.validate.call_count == 0
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() from EPOCH_COMPLETED stops the run; the (0-based) epoch
    counter has already been incremented at the end of the loop."""
    max_epochs = 5
    last_epoch_to_run = 3
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())

    def stop_after_target_epoch(trainer):
        if trainer.current_epoch == last_epoch_to_run:
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.EPOCH_COMPLETED, stop_after_target_epoch)
    assert not trainer.should_terminate

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    # counter is incremented at end of loop
    assert trainer.current_epoch == last_epoch_to_run + 1
    assert trainer.should_terminate
def test_current_epoch_counter_increases_every_epoch():
    """state.epoch is 1-based and advances by one at each EPOCH_STARTED."""
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 1

        def __call__(self, trainer, state):
            assert state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    trainer.add_event_handler(Events.EPOCH_STARTED, EpochCounter())
    state = trainer.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
def test_with_trainer(dirname):
    """ModelCheckpoint keeps only the n_saved most recent checkpoint files."""
    def update_fn(batch):
        pass

    name = 'model'
    trainer = Trainer(update_fn)
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False,
                              n_saved=2, save_interval=1)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {name: 42})
    trainer.run([0], max_epochs=4)

    # with n_saved=2, only the checkpoints from epochs 3 and 4 survive
    expected = ['{}_{}_{}.pth'.format(_PREFIX, name, i) for i in [3, 4]]
    assert sorted(os.listdir(dirname)) == expected
def test_current_iteration_counter_increases_every_iteration():
    """state.iteration is 1-based and advances by one per training batch."""
    training_batches = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 1

        def __call__(self, trainer, state):
            assert state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    trainer.add_event_handler(Events.ITERATION_STARTED, IterationCounter())
    state = trainer.run(training_batches, max_epochs=max_epochs)
    assert state.iteration == max_epochs * len(training_batches)
def test_current_iteration_counter_increases_every_iteration():
    """current_iteration is 0-based and advances by one per training batch."""
    training_batches = [1, 2, 3]
    trainer = Trainer(training_batches, MagicMock(return_value=1),
                      MagicMock(), MagicMock())
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 0

        def __call__(self, trainer):
            assert trainer.current_iteration == self.current_iteration_count
            self.current_iteration_count += 1

    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              IterationCounter())
    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)
    assert trainer.current_iteration == max_epochs * len(training_batches)
def test_args_and_kwargs_are_passed_to_event():
    """Extra positional and keyword arguments given to add_event_handler are
    forwarded to the handler after the trainer itself."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    kwargs = {'a': 'a', 'b': 'b'}
    args = (1, 2, 3)

    # register one handler per event, all with the same extra arguments
    handlers = []
    for event in TrainingEvents:
        handler = MagicMock()
        trainer.add_event_handler(event, handler, *args, **kwargs)
        handlers.append(handler)

    trainer.run(max_epochs=1, validate_every_epoch=False)

    called_handlers = [h for h in handlers if h.called]
    assert called_handlers
    for handler in called_handlers:
        handler_args, handler_kwargs = handler.call_args
        assert handler_args[0] == trainer
        assert handler_args[1:] == args
        assert handler_kwargs == kwargs
def test_terminate_stops_trainer_when_called_during_validation():
    """terminate() inside a validation iteration finishes that iteration and
    stops the whole run."""
    num_iterations_per_epoch = 10
    iteration_to_stop = 3  # i.e. part way through the 2nd validation run
    epoch_to_stop = 2
    trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                      training_update_function=MagicMock(return_value=1),
                      validation_data=[None] * num_iterations_per_epoch,
                      validation_inference_function=MagicMock(return_value=1))

    def stop_during_validation(trainer):
        if (trainer.current_epoch == epoch_to_stop
                and trainer.current_validation_iteration == iteration_to_stop):
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              stop_during_validation)
    trainer.run(max_epochs=4, validate_every_epoch=True)

    assert trainer.current_epoch == epoch_to_stop
    # should complete the iteration when terminate called
    assert trainer.current_validation_iteration == iteration_to_stop + 1
    assert trainer.current_iteration == (epoch_to_stop + 1) * num_iterations_per_epoch
def test_terminate_stops_trainer_when_called_during_validation():
    """terminate() inside a validation iteration (validation driven via the
    _validate handler) finishes that iteration and stops the whole run."""
    num_iterations_per_epoch = 10
    iteration_to_stop = 3  # i.e. part way through the 2nd validation run
    epoch_to_stop = 2
    trainer = Trainer(MagicMock(return_value=1), MagicMock(return_value=1))

    def stop_during_validation(trainer):
        if (trainer.current_epoch == epoch_to_stop
                and trainer.current_validation_iteration == iteration_to_stop):
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, [None] * num_iterations_per_epoch)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              stop_during_validation)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=4)

    assert trainer.current_epoch == epoch_to_stop
    # should complete the iteration when terminate called
    assert trainer.current_validation_iteration == iteration_to_stop + 1
    assert trainer.current_iteration == (epoch_to_stop + 1) * num_iterations_per_epoch
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """Calling terminate() from EPOCH_STARTED still lets the first iteration
    of that (1-based) epoch run to completion."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))

    def stop_on_target_epoch(trainer, state):
        if state.epoch == epoch_to_terminate_on:
            trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, stop_on_target_epoch)
    assert not trainer.should_terminate

    state = trainer.run(batches_per_epoch, max_epochs=max_epochs)

    # epoch is not completed so counter is not incremented
    assert state.epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # completes first iteration of the terminated epoch
    assert state.iteration == (
        (epoch_to_terminate_on - 1) * len(batches_per_epoch)) + 1
def train(epochs=10, batch_size=64, latent_dim=2, hidden_dim=400, use_gpu=False):
    """Train a conditional VAE on MNIST.

    Logs training/validation loss and reconstructions to tensorboard, saves
    the final weights, and sends a push notification on completion.
    """
    from torch.autograd import Variable
    from ignite.trainer import Trainer, TrainingEvents
    import logging
    logger = logging.getLogger('ignite')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

    from tensorboardX import SummaryWriter
    from autoencoders.models.sampling import cvae_reconstructions
    from autoencoders.models.utils import flatten, to_one_hot
    from autoencoders.data.mnist import mnist_dataloader
    from autoencoders.utils.tensorboard import run_path
    from autoencoders.models.loss import VAELoss
    from autoencoders.utils.notifications import send_training_complete_push

    experiment_name = run_path('cvae/l{}_h{}_b{}_adam_3e-4'.format(
        latent_dim, hidden_dim, batch_size))
    checkpoint_path = '_'.join(experiment_name.split('/')[1:])
    writer = SummaryWriter(experiment_name)

    model = ConditionalVariationalAutoencoder(latent_dim=latent_dim,
                                              hidden_dim=hidden_dim,
                                              n_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), 3e-4)
    criterion = VAELoss()

    train_loader = mnist_dataloader(path='data', batch_size=batch_size,
                                    download=True)
    val_loader = mnist_dataloader(path='data', batch_size=batch_size,
                                  train=False, download=True)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()
        inputs, targets = batch
        inputs = flatten(Variable(inputs))
        # class labels are fed to the model one-hot encoded
        targets = Variable(to_one_hot(targets, batch_size=batch_size, n_classes=10))
        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()
        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()
        inputs, targets = batch
        inputs = flatten(Variable(inputs))
        targets = Variable(to_one_hot(targets, batch_size=batch_size, n_classes=10))
        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()
        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)
        return loss.data[0]

    def on_epoch_completed(trainer, writer):
        print(trainer.current_epoch, np.mean(trainer.training_history))
        writer.add_scalar('loss/training loss',
                          np.mean(trainer.training_history),
                          trainer.current_epoch)

    def on_validation(trainer, writer):
        print(trainer.current_epoch, np.mean(trainer.validation_history))
        writer.add_scalar('loss/validation loss',
                          np.mean(trainer.validation_history),
                          trainer.current_epoch)
        writer.add_image('image', cvae_reconstructions(model, val_loader),
                         trainer.current_epoch)

    def on_complete(trainer):
        desc = 'Val Loss: {}'.format(np.mean(trainer.validation_history))
        send_training_complete_push(checkpoint_path, desc)
        torch.save(model.state_dict(), 'models/{}.cpt'.format(checkpoint_path))

    trainer = Trainer(train_loader, training_update_function,
                      val_loader, validation_inference_function)
    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              on_epoch_completed, writer)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              on_validation, writer)
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED, on_complete)
    trainer.run(max_epochs=epochs)
def run():
    """Train the TTE (Weibull time-to-event) model on CMAPSS data and return
    the trained model."""
    train_dataloader = data.DataLoader(CMAPSSData(train=True), shuffle=True,
                                       batch_size=300, pin_memory=USE_CUDA,
                                       collate_fn=collate_fn)
    validation_dataloader = data.DataLoader(CMAPSSData(train=False), shuffle=True,
                                            batch_size=1000, pin_memory=USE_CUDA,
                                            collate_fn=collate_fn)

    model = TTE(24, 128, n_layers=1, dropout=0.2)
    print(model)
    if USE_CUDA:
        model.cuda()

    optimizer = optim.RMSprop(model.parameters(), lr=0.001)
    loss_fn = weibull_loglikelihood

    def training_update(batch):
        model.train()
        optimizer.zero_grad()
        inputs, lengths, targets = batch
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
        inputs = Variable(inputs)
        targets = Variable(targets)
        outputs = model(inputs, lengths)
        loss = loss_fn(outputs, targets, lengths)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference(batch):
        model.eval()
        inputs, lengths, targets = batch
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
        inputs = Variable(inputs, volatile=True)
        targets = Variable(targets, volatile=True)
        outputs = model(inputs, lengths)
        loss = loss_fn(outputs, targets, lengths)
        # also expose the two predicted Weibull parameter channels and the
        # targets so the COMPLETED handler can compute metrics
        return (loss.data[0], outputs.data[:, :, 0],
                outputs.data[:, :, 1], targets.data)

    trainer = Trainer(training_update)
    evaluator = Evaluator(validation_inference)
    progress = Progress()
    plot_interval = 1

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        Evaluate(evaluator, validation_dataloader, epoch_interval=1))

    @trainer.on(Events.EPOCH_STARTED)
    def epoch_started(trainer):
        print('Epoch {:4}/{}'.format(trainer.current_epoch, trainer.max_epochs),
              end='')

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(trainer):
        if trainer.current_iteration % plot_interval == 0:
            avg_loss = trainer.history.simple_moving_average(window_size=100)
            progress.update([('iter', trainer.current_iteration),
                             ('loss', avg_loss)])

    @evaluator.on(Events.COMPLETED)
    def epoch_completed(evaluator):
        loss, alpha, beta, target = evaluator.history[0]
        mae = torch.mean(torch.abs(alpha - target[:, :, 0]))
        progress.update([('val_loss', loss), ('mae', mae),
                         ('alpha', alpha.mean()), ('beta', beta.mean())],
                        end=True)

    trainer.run(train_dataloader, max_epochs=600)
    return model
def train(batch_size=512, epochs=100):
    """Train a VAE on MNIST, logging losses and reconstructions to tensorboard."""
    from torch.autograd import Variable
    from ignite.trainer import Trainer, TrainingEvents
    import logging
    logger = logging.getLogger('ignite')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

    from tensorboardX import SummaryWriter
    from autoencoders.models.sampling import vae_reconstructions
    from autoencoders.data.mnist import mnist_dataloader
    from autoencoders.utils.tensorboard import run_path
    from autoencoders.models.loss import VAELoss
    import numpy as np

    use_gpu = torch.cuda.is_available()
    writer = SummaryWriter(run_path('vae'))

    model = VariationalAutoencoder()
    optimizer = torch.optim.Adam(model.parameters(), 3e-4)
    criterion = VAELoss()

    train_loader = mnist_dataloader(path='data', batch_size=batch_size,
                                    download=True)
    val_loader = mnist_dataloader(path='data', batch_size=batch_size,
                                  train=False, download=True)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()
        inputs, _ = batch
        inputs = Variable(inputs)
        if use_gpu:
            inputs = inputs.cuda()
        output, mu, logvar = model(inputs)
        loss = criterion(output, inputs.view(-1, 784), mu, logvar)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()
        inputs, _ = batch
        inputs = Variable(inputs)
        if use_gpu:
            inputs = inputs.cuda()
        output, mu, logvar = model(inputs)
        # NOTE(review): training compares output against inputs.view(-1, 784)
        # but this passes the unflattened inputs -- confirm VAELoss accepts
        # both shapes, otherwise validation loss is computed differently
        loss = criterion(output, inputs, mu, logvar)
        return loss.data[0]

    def on_epoch_completed(trainer, writer):
        writer.add_scalar('loss/training loss',
                          np.mean(trainer.training_history),
                          trainer.current_epoch)

    def on_validation(trainer, writer):
        writer.add_scalar('loss/validation loss',
                          np.mean(trainer.validation_history),
                          trainer.current_epoch)
        writer.add_image('image', vae_reconstructions(model, val_loader),
                         trainer.current_epoch)

    trainer = Trainer(train_loader, training_update_function,
                      val_loader, validation_inference_function)
    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              on_epoch_completed, writer)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              on_validation, writer)
    trainer.run(max_epochs=epochs)
def run(batch_size, val_batch_size, epochs, lr, momentum, log_interval, logger):
    """Train the MNIST example model with SGD, plotting progress via visdom.

    Raises RuntimeError when no visdom server is reachable.
    """
    vis = visdom.Visdom()
    if not vis.check_connection():
        raise RuntimeError(
            "Visdom server not running. Please run python -m visdom.server")

    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    train_loader = DataLoader(
        MNIST(download=True, root=".", transform=data_transform, train=True),
        batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(
        MNIST(download=False, root=".", transform=data_transform, train=False),
        batch_size=val_batch_size, shuffle=False)

    model = Net()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)

    def training_update_function(batch):
        model.train()
        optimizer.zero_grad()
        data, target = Variable(batch[0]), Variable(batch[1])
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        return loss.data[0]

    def validation_inference_function(batch):
        model.eval()
        data, target = Variable(batch[0], volatile=True), Variable(batch[1])
        output = model(data)
        loss = F.nll_loss(output, target, size_average=False).data[0]
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(target.data.view_as(pred)).sum()
        return loss, correct

    trainer = Trainer(training_update_function)
    evaluator = Evaluator(validation_inference_function)

    # trainer event handlers
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, log_simple_moving_average,
        window_size=100, metric_name="NLL",
        should_log=lambda trainer: trainer.current_iteration % log_interval == 0,
        logger=logger)
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        get_plot_training_loss_handler(vis, plot_every=log_interval))
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, Evaluate(evaluator, val_loader, epoch_interval=1))

    # evaluator event handlers
    evaluator.add_event_handler(
        Events.COMPLETED, get_log_validation_loss_and_accuracy_handler(logger))
    evaluator.add_event_handler(
        Events.COMPLETED, get_plot_validation_accuracy_handler(vis), trainer)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
def main(args, logger):
    """Train a seq2seq translation model on SmallEnJa or WMT14 with optional
    best-model snapshotting and BLEU computation at the end of training."""
    preprocessing = reverse_sentence if args.reverse else None

    SRC = data.Field(init_token=BOS, eos_token=EOS, include_lengths=True,
                     preprocessing=preprocessing)
    TRG = data.Field(init_token=BOS, eos_token=EOS)

    if args.dataset == 'enja':
        train, val, test = SmallEnJa.splits(exts=('.en', '.ja'),
                                            fields=(SRC, TRG))
    elif args.dataset == 'wmt14':
        # keep only reasonably short sentence pairs
        train, val, test = WMT14.splits(
            exts=('.en', '.de'), fields=(SRC, TRG),
            filter_pred=lambda ex: len(ex.src) <= 50 and len(ex.trg) <= 50)

    SRC.build_vocab(train.src, max_size=args.src_vocab)
    TRG.build_vocab(train.trg, max_size=args.trg_vocab)
    stoi = TRG.vocab.stoi

    train_iter, val_iter = data.BucketIterator.splits(
        (train, val), batch_sizes=(args.batch, args.batch * 2), repeat=False,
        sort_within_batch=True, sort_key=SmallEnJa.sort_key,
        device=args.gpu[0] if len(args.gpu) == 1 else -1)
    test_iter = data.Iterator(test, batch_size=1, repeat=False, sort=False,
                              train=False,
                              device=args.gpu[0] if args.gpu else -1)

    model = Seq2Seq(len(SRC.vocab), args.embed, args.encoder_hidden,
                    len(TRG.vocab), args.embed, args.decoder_hidden,
                    args.encoder_layers, not args.encoder_unidirectional,
                    args.decoder_layers, SRC.vocab.stoi[SRC.pad_token],
                    stoi[TRG.pad_token], stoi[TRG.init_token],
                    stoi[TRG.eos_token], args.dropout_ratio,
                    args.attention_type)
    model.prepare_translation(TRG.vocab.itos, args.max_length)
    # keep a handle on the bound method before any DataParallel wrapping
    translate = model.translate

    if len(args.gpu) >= 2:
        model = DataParallel(model, device_ids=args.gpu, dim=1).cuda()
    elif len(args.gpu) == 1:
        model.cuda(args.gpu[0])

    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                               weight_decay=args.weight_decay)
        scheduler = None
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
        scheduler = MultiStepLR(optimizer, milestones=list(range(8, 12)),
                                gamma=0.5)

    trainer = Trainer(
        TeacherForceUpdater(model, optimizer, model, args.gradient_clipping))
    evaluator = Evaluator(TeacherForceInference(model, model))

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              log_training_average_nll, logger=logger)
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              Evaluate(evaluator, val_iter, epoch_interval=1))
    evaluator.add_event_handler(Events.COMPLETED,
                                get_log_validation_ppl(val.trg), logger=logger)

    if args.best_file is not None:
        evaluator.add_event_handler(Events.COMPLETED,
                                    BestModelSnapshot(model, 'ppl', 1e10, le),
                                    args.best_file, logger)
        trainer.add_event_handler(Events.COMPLETED,
                                  ComputeBleu(model, test.trg, translate),
                                  test_iter, args.best_file, logger)

    if scheduler is not None:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda trainer: scheduler.step())

    trainer.run(train_iter, max_epochs=args.epoch)
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        log_interval, save_interval, compare_interval):
    """Train a Transformer translation model with periodic logging,
    checkpointing and prediction printing."""
    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)

    transformer = Transformer(max_length=max_len, enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_update_function(batch):
        transformer.train()
        lr_decay.step()
        opt.zero_grad()
        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)
        # flatten (time, batch, vocab) logits to 2D for CrossEntropyLoss
        flat_predictions = predictions.view(-1, len(target_vocab.itos))
        flat_target = batch.trg.view(-1)
        loss = loss_fn(flat_predictions, flat_target)
        loss.backward()
        opt.step()
        return softmaxed_predictions.data, loss.data[0], batch.trg.data

    def validation_inference_function(batch):
        transformer.eval()
        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)
        flat_predictions = predictions.view(-1, len(target_vocab.itos))
        flat_target = batch.trg.view(-1)
        loss = loss_fn(flat_predictions, flat_target)
        return loss.data[0]

    trainer = Trainer(train_iter, training_update_function,
                      val_iter, validation_inference_function)

    trainer.add_event_handler(TrainingEvents.TRAINING_STARTED,
                              restore_checkpoint_hook(transformer, model_dir))
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        log_training_simple_moving_average, window_size=10,
        metric_name="CrossEntropy",
        should_log=lambda trainer: trainer.current_iteration % log_interval == 0,
        history_transform=lambda history: history[1])
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        save_checkpoint_hook(transformer, model_dir),
        should_save=lambda trainer: trainer.current_iteration % save_interval == 0)
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        print_current_prediction_hook(target_vocab),
        should_print=lambda trainer: trainer.current_iteration % compare_interval == 0)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              log_validation_simple_moving_average,
                              window_size=10, metric_name="CrossEntropy")
    # always save a final checkpoint at the end of training
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer: True)
    trainer.run(max_epochs=epochs, validate_every_epoch=True)