# Legacy training entry point, written against an early pytorch-ignite release
# (Trainer / TrainingEvents API) and pre-0.4 PyTorch (note the loss.data[0]
# indexing).  The exact import paths for Trainer and TrainingEvents depend on
# the installed ignite version, so they are not spelled out here;
# create_dataset, Transformer, and the *_hook / logging helpers are
# project-local modules.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        log_interval, save_interval, compare_interval):
    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)

    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_update_function(batch):
        transformer.train()
        lr_decay.step()
        opt.zero_grad()

        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        # Flatten (batch, seq_len, vocab) predictions and (batch, seq_len)
        # targets so CrossEntropyLoss sees one row per token.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)
        loss.backward()
        opt.step()

        # loss.data[0] is the pre-0.4 PyTorch idiom for reading a scalar tensor.
        return softmaxed_predictions.data, loss.data[0], batch.trg.data

    def validation_inference_function(batch):
        transformer.eval()

        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)
        return loss.data[0]

    trainer = Trainer(train_iter, training_update_function, val_iter,
                      validation_inference_function)

    # Restore the latest checkpoint (if any) before training starts.
    trainer.add_event_handler(TrainingEvents.TRAINING_STARTED,
                              restore_checkpoint_hook(transformer, model_dir))
    # Log a moving average of the training loss every log_interval iterations.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        log_training_simple_moving_average,
        window_size=10,
        metric_name="CrossEntropy",
        should_log=lambda trainer: trainer.current_iteration % log_interval == 0,
        history_transform=lambda history: history[1])
    # Checkpoint the model every save_interval iterations.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        save_checkpoint_hook(transformer, model_dir),
        should_save=lambda trainer: trainer.current_iteration % save_interval == 0)
    # Periodically print a decoded prediction next to its target sentence.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        print_current_prediction_hook(target_vocab),
        should_print=lambda trainer: trainer.current_iteration % compare_interval == 0)
    # Log a moving average of the validation loss after each validation run.
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              log_validation_simple_moving_average,
                              window_size=10,
                              metric_name="CrossEntropy")
    # Always save a final checkpoint when training finishes.
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer: True)

    trainer.run(max_epochs=epochs, validate_every_epoch=True)
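# The checkpoint hooks used above live elsewhere in the project; the sketch
# below is a hypothetical illustration of the closure pattern they follow here
# (a factory closing over the model and checkpoint directory, returning a
# handler that receives the trainer plus the extra keyword arguments that
# add_event_handler forwards).  The file naming scheme is an assumption, not
# the project's actual implementation; trainer.current_iteration is the same
# attribute the should_save lambdas above rely on.
import glob
import os

import torch


def save_checkpoint_hook(model, model_dir):
    def hook(trainer, should_save):
        # Only write a checkpoint when the caller-supplied predicate says so.
        if should_save(trainer):
            path = os.path.join(
                model_dir,
                "transformer_{}.pt".format(trainer.current_iteration))
            torch.save(model.state_dict(), path)
    return hook


def restore_checkpoint_hook(model, model_dir):
    def hook(trainer):
        # Load the most recent checkpoint found in model_dir, if any exists.
        checkpoints = sorted(
            glob.glob(os.path.join(model_dir, "transformer_*.pt")))
        if checkpoints:
            model.load_state_dict(torch.load(checkpoints[-1]))
    return hook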
# Updated training entry point, written against the pytorch-ignite
# Engine / Events API.  create_dataset, Transformer, and
# validation_result_hook are project-local modules.
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage


def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        val_interval, save_interval, compare_interval):
    logging.basicConfig(filename="validation.log",
                        filemode="w",
                        level=logging.INFO)

    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)

    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_step(engine, batch):
        transformer.train()
        lr_decay.step()
        opt.zero_grad()

        _, predictions = transformer(batch.src, batch.trg)

        # Flatten (batch, seq_len, vocab) predictions and (batch, seq_len)
        # targets so CrossEntropyLoss sees one row per token.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)
        loss.backward()
        opt.step()

        return loss.cpu().item()

    def validation_step(engine, batch):
        transformer.eval()

        with torch.no_grad():
            softmaxed_predictions, predictions = transformer(
                batch.src, batch.trg)

            flattened_predictions = predictions.view(
                -1, len(target_vocab.itos))
            flattened_target = batch.trg.view(-1)

            loss = loss_fn(flattened_predictions, flattened_target)

            # Accumulate decoded predictions and targets across batches so the
            # validation hook can compare them once the evaluator finishes.
            if not engine.state.output:
                predictions = softmaxed_predictions.argmax(
                    -1).cpu().numpy().tolist()
                targets = batch.trg.cpu().numpy().tolist()
            else:
                predictions = engine.state.output[
                    "predictions"] + softmaxed_predictions.argmax(
                        -1).cpu().numpy().tolist()
                targets = engine.state.output["targets"] + batch.trg.cpu(
                ).numpy().tolist()

            return {
                "loss": loss.cpu().item(),
                "predictions": predictions,
                "targets": targets
            }

    trainer = Engine(training_step)
    evaluator = Engine(validation_step)

    checkpoint_handler = ModelCheckpoint(model_dir,
                                         "Transformer",
                                         save_interval=save_interval,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # Attach training metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "train_loss")

    # Attach validation metrics
    RunningAverage(output_transform=lambda x: x["loss"]).attach(
        evaluator, "val_loss")

    pbar = ProgressBar()
    pbar.attach(trainer, ["train_loss"])

    # trainer.add_event_handler(Events.TRAINING_STARTED,
    #                           restore_checkpoint_hook(transformer, model_dir))

    # Run the evaluator every val_interval iterations and log its results.
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              handler=validation_result_hook(
                                  evaluator,
                                  val_iter,
                                  target_vocab,
                                  val_interval,
                                  logger=logging.info))
    # Checkpoint the model, optimizer, and scheduler together.
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "nmt": {
                                      "transformer": transformer,
                                      "opt": opt,
                                      "lr_decay": lr_decay
                                  }
                              })

    # Run the training loop
    trainer.run(train_iter, max_epochs=epochs)
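# Hypothetical usage sketch: every path and hyperparameter value below is a
# placeholder chosen for illustration, not the project's actual defaults.
if __name__ == "__main__":
    run(model_dir="checkpoints",
        max_len=100,
        source_train_path="data/train.src",
        target_train_path="data/train.trg",
        source_val_path="data/val.src",
        target_val_path="data/val.trg",
        enc_max_vocab=50000,
        dec_max_vocab=50000,
        encoder_emb_size=512,
        decoder_emb_size=512,
        encoder_units=512,
        decoder_units=512,
        batch_size=32,
        epochs=10,
        learning_rate=1e-4,
        decay_step=10000,
        decay_percent=0.9,
        val_interval=1000,
        save_interval=1000,
        compare_interval=1000)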