def test_create_lr_scheduler_with_warmup_with_real_model(dummy_model_factory):
    model = dummy_model_factory(with_grads=False, with_frozen_layer=False)
    init_lr = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=init_lr)
    scaled_lr = 0.02
    warmup_duration = 5
    step_size = 2
    gamma = 0.97

    output_simulated_values = [None] * 50

    create_lr_scheduler_with_warmup(
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma),
        warmup_start_value=0.0,
        warmup_end_value=scaled_lr,
        warmup_duration=warmup_duration,
        output_simulated_values=output_simulated_values,
    )

    assert output_simulated_values[0] == [0, 0.0]
    assert output_simulated_values[warmup_duration - 1] == [warmup_duration - 1, scaled_lr]
    assert output_simulated_values[warmup_duration] == [warmup_duration, init_lr]

    v = [warmup_duration + step_size, init_lr * gamma]
    assert output_simulated_values[warmup_duration + step_size] == v
def test_create_lr_scheduler_with_warmup():

    with pytest.raises(TypeError):
        create_lr_scheduler_with_warmup(12, warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=10)

    def _test(lr_scheduler, optimizer, warmup_end_value_to_check=None):
        num_iterations = 10
        max_epochs = 20
        warmup_duration = 10
        warmup_end_value = 0.1

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=warmup_end_value,
            warmup_duration=warmup_duration,
            output_simulated_values=simulated_values,
        )
        lrs = []

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]['lr'])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * num_iterations
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in simulated_values])

        if warmup_end_value_to_check is None:
            warmup_end_value_to_check = warmup_end_value
        assert lrs[warmup_duration] == warmup_end_value_to_check

    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer)

    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.2)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer, param_name='lr', start_value=1.0, end_value=0.0, cycle_size=10
    )
    _test(lr_scheduler, optimizer, 1.0)
def _test(lr_scheduler, optimizer):
    num_iterations = 10
    max_epochs = 20

    simulated_values = [None] * (num_iterations * max_epochs)
    scheduler = create_lr_scheduler_with_warmup(
        lr_scheduler,
        warmup_start_value=0.0,
        warmup_end_value=0.1,
        warmup_duration=10,
        output_simulated_values=simulated_values,
    )
    lrs = []

    trainer = Engine(lambda engine, batch: None)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    data = [0] * num_iterations
    trainer.run(data, max_epochs=max_epochs)

    assert lrs == pytest.approx([v for i, v in simulated_values])
def _test(save_history):
    tensor = torch.ones([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.001)

    max_epochs = 25
    lr_max_value = 0.4
    num_iterations_per_epoch = 128
    num_iterations = max_epochs * num_iterations_per_epoch
    warmup_duration = 5 * num_iterations_per_epoch
    cooldown_duration = 5 * num_iterations_per_epoch

    scheduler_1 = LinearCyclicalScheduler(
        optimizer,
        "lr",
        start_value=lr_max_value,
        end_value=lr_max_value * 0.9,
        cycle_size=(num_iterations - warmup_duration - cooldown_duration) * 2,
    )

    scheduler_2 = LinearCyclicalScheduler(
        optimizer, "lr", start_value=lr_max_value, end_value=0.0, cycle_size=cooldown_duration * 2
    )

    lr_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2],
        durations=[num_iterations - warmup_duration - cooldown_duration],
        save_history=False,
    )
    lr_values = [None] * num_iterations
    scheduler = create_lr_scheduler_with_warmup(
        lr_scheduler,
        warmup_start_value=0.0,
        warmup_end_value=lr_max_value,
        warmup_duration=warmup_duration,
        save_history=save_history,
        output_simulated_values=lr_values,
    )
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    data = [0] * num_iterations_per_epoch

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in lr_values])

        if save_history:
            param_history = trainer.state.param_history["lr"]
            assert lrs == pytest.approx([v[0] for v in param_history])

        scheduler.load_state_dict(state_dict)
def _test(lr_scheduler, optimizer, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value):
    num_iterations = 10
    max_epochs = 20

    simulated_values = [None] * (num_iterations * max_epochs)
    scheduler = create_lr_scheduler_with_warmup(
        lr_scheduler,
        warmup_start_value=warmup_start_value,
        warmup_end_value=warmup_end_value,
        warmup_duration=warmup_duration,
        output_simulated_values=simulated_values,
    )
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    @trainer.on(Events.ITERATION_STARTED)
    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    data = [0] * num_iterations

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in simulated_values])

        assert lrs[0] == pytest.approx(warmup_start_value), "lrs={}".format(lrs[: warmup_duration + num_iterations])
        assert lrs[warmup_duration - 1] == pytest.approx(warmup_end_value), "lrs={}".format(
            lrs[: warmup_duration + num_iterations]
        )
        assert lrs[warmup_duration] == pytest.approx(warmup_end_next_value), "lrs={}".format(
            lrs[: warmup_duration + num_iterations]
        )
        scheduler.load_state_dict(state_dict)
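# Illustrative sketch (not from the original tests): output_simulated_values, used throughout
# the tests above, can also be used to inspect the warmup + scheduler trajectory before any
# training runs, e.g. to plot the schedule. The import path is an assumption and depends on
# the ignite release (ignite.contrib.handlers in older versions, ignite.handlers in newer ones).
import matplotlib.pyplot as plt
import torch
from ignite.contrib.handlers import create_lr_scheduler_with_warmup

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

simulated_values = [None] * 100  # one [event_index, lr] pair per simulated event
create_lr_scheduler_with_warmup(
    torch_lr_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=10,
    output_simulated_values=simulated_values,
)

steps, lrs = zip(*simulated_values)
plt.plot(steps, lrs)
plt.xlabel("event index")
plt.ylabel("lr")
plt.show()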
def attach_lr_warmup(trainer, config, lr_scheduler):
    warmup_duration = (
        config['warmup_duration']
        if config['warmup_duration'] > 0
        else config['steps_per_epoch'] * -config['warmup_duration']
    )
    warmup_end_value = (
        config['warmup_end_value']
        if config['warmup_end_value'] != -1
        else config['learning_rate']
    )
    scheduler_with_warmup = create_lr_scheduler_with_warmup(
        lr_scheduler,
        config['warmup_start_value'],
        warmup_end_value,
        warmup_duration,
    )
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_with_warmup)
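# Hypothetical usage of the helper above (not from the original source). The config keys and
# values are illustrative, and 'trainer' and 'optimizer' are assumed to already exist. A negative
# 'warmup_duration' is read as a number of epochs and converted to iterations via
# 'steps_per_epoch'; a 'warmup_end_value' of -1 means "warm up to the base learning rate".
config = {
    'warmup_start_value': 0.0,
    'warmup_duration': -2,     # i.e. warm up over 2 epochs
    'steps_per_epoch': 500,
    'warmup_end_value': -1,    # fall back to config['learning_rate']
    'learning_rate': 0.1,
}
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
attach_lr_warmup(trainer, config, lr_scheduler)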
model.to(device)

# multi-gpus
if torch.cuda.device_count():
    print('==================== Use {} GPUs ===================='.format(torch.cuda.device_count()))
    model = nn.DataParallel(model)

# loss function
loss_fn = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=5e-4)

# scheduler
scheduler = CosineAnnealingScheduler(optimizer, 'lr', init_lr, end_lr, 4 * len(trainloader),
                                     cycle_mult=1.5, start_value_mult=0.1)
scheduler = create_lr_scheduler_with_warmup(scheduler,
                                            warmup_start_value=0.,
                                            warmup_end_value=init_lr,
                                            warmup_duration=len(trainloader))

# create trainer
trainer = create_trainer(model, optimizer, loss_fn, device=device)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# add timer for each iteration
timer = Timer(average=False)

# logging training loss
def log_loss(engine):
    i = engine.state.iteration
    e = engine.state.epoch
    if i % 100 == 0:
        print('[Iters {:0>7d}/{:0>2d}, {:.2f}s/100 iters, lr={:.4E}] loss={:.4f}'.format(
            i, e, timer.value(), optimizer.param_groups[0]['lr'], engine.state.output))
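# Not part of the original snippet: one plausible way to wire up the timer and the log_loss
# handler defined above (the snippet creates them but does not attach them). The resume/pause/step
# pattern follows ignite's documented Timer usage; max_epochs here is an illustrative value.
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_loss)
trainer.run(trainloader, max_epochs=200)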
def test_create_lr_scheduler_with_warmup():

    with pytest.raises(TypeError, match=r"Argument lr_scheduler should be a subclass of"):
        create_lr_scheduler_with_warmup(12, warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=10)

    t1 = torch.zeros([1], requires_grad=True)

    # A) opt lr != warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

    with pytest.raises(ValueError, match=r"Argument warmup_duration should be at least 2 events"):
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler, warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=1
        )

    with pytest.raises(ValueError, match=r"Argument warmup_duration should be at least 2 events"):
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler, warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration="abc"
        )

    with pytest.raises(TypeError, match=r"Argument output_simulated_values should be a list of None"):
        simulated_values = ()
        create_lr_scheduler_with_warmup(
            torch_lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=0.1,
            warmup_duration=10,
            output_simulated_values=simulated_values,
        )

    def _test(lr_scheduler, optimizer, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value):
        num_iterations = 10
        max_epochs = 20

        simulated_values = [None] * (num_iterations * max_epochs)
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=warmup_start_value,
            warmup_end_value=warmup_end_value,
            warmup_duration=warmup_duration,
            output_simulated_values=simulated_values,
        )
        state_dict = scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        @trainer.on(Events.ITERATION_STARTED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == pytest.approx([v for i, v in simulated_values])

            assert lrs[0] == pytest.approx(warmup_start_value), "lrs={}".format(
                lrs[: warmup_duration + num_iterations]
            )
            assert lrs[warmup_duration - 1] == pytest.approx(warmup_end_value), "lrs={}".format(
                lrs[: warmup_duration + num_iterations]
            )
            assert lrs[warmup_duration] == pytest.approx(warmup_end_next_value), "lrs={}".format(
                lrs[: warmup_duration + num_iterations]
            )
            scheduler.load_state_dict(state_dict)

    t1 = torch.zeros([1], requires_grad=True)

    # A) opt lr != warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.05, 10, 0.2)

    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.05, 2, 0.2)

    # B) opt lr == warmup_end_value
    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.2, 10, 0.2 * 0.98)

    optimizer = torch.optim.SGD([t1], lr=0.2)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)
    _test(torch_lr_scheduler, optimizer, 0.01, 0.2, 2, 0.2 * 0.98)

    # C) lr_scheduler start_value != warmup_end_value
    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer, param_name="lr", start_value=0.8, end_value=0.0, cycle_size=10
    )
    _test(lr_scheduler, optimizer, 0.01, 0.05, 10, 0.8)

    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer, param_name="lr", start_value=0.8, end_value=0.0, cycle_size=10
    )
    _test(lr_scheduler, optimizer, 0.01, 0.05, 2, 0.8)

    # D) lr_scheduler start_value == warmup_end_value
    t1 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer, param_name="lr", start_value=0.8, end_value=0.0, cycle_size=10
    )
    _test(lr_scheduler, optimizer, 0.01, 0.8, 10, 0.8 - (0.8 / 5.0))

    optimizer = torch.optim.SGD([t1], lr=0.0)
    lr_scheduler = LinearCyclicalScheduler(
        optimizer=optimizer, param_name="lr", start_value=0.8, end_value=0.0, cycle_size=10
    )
    _test(lr_scheduler, optimizer, 0.01, 0.8, 2, 0.8 - (0.8 / 5.0))
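# Added sketch, not from the source: the test above exercises the hand-off after warmup. When
# warmup_end_value differs from the wrapped scheduler's own starting lr (cases A and C), the lr
# jumps to that starting value on the first post-warmup event; when they coincide (cases B and D),
# the schedule continues with the next scheduled value. This can be checked without an Engine by
# simulating the schedule (import path depends on the ignite release).
import torch
from ignite.contrib.handlers import create_lr_scheduler_with_warmup

t1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([t1], lr=0.2)
torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

simulated = [None] * 30
create_lr_scheduler_with_warmup(
    torch_lr_scheduler,
    warmup_start_value=0.01,
    warmup_end_value=0.05,
    warmup_duration=10,
    output_simulated_values=simulated,
)
print(simulated[9])   # end of warmup: [9, 0.05]
print(simulated[10])  # wrapped scheduler takes over at the optimizer lr: [10, 0.2]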
def main(parser_args):
    """Main function to create trainer engine, add handlers to train and validation engines.
    Then runs train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)

    criterion = nn.CrossEntropyLoss()
    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels, parser_args.depth,
                         parser_args.laplacian_type, parser_args.kernel_size)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)

    def trainer(engine, batch):
        """Train function to define the train engine. Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            :obj:`torch.tensor`: train loss for that batch and epoch
        """
        unet.train()
        data, labels = batch
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)
        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)
    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform,
    )

    engine_train.add_event_handler(
        Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch))
    )
    engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch.

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch.

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1],
            },
            engine_train.state.epoch,
        )
        writer.close()

    step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path)

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

    torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt")
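# Added note, not from the source: in main() above, the warmup scheduler is stepped by
# engine_validate's EPOCH_COMPLETED event, so warmuplr_warmup_duration is counted in validation
# epochs (one scheduler event per epoch). Attaching it to the training engine instead, e.g.
#
#     engine_train.add_event_handler(Events.ITERATION_STARTED, scheduler)
#
# would make warmup_duration count training iterations, as in the other snippets in this section.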
def train(cfg):
    print(cfg.pretty())

    ###################################################################
    # Dataset
    ###################################################################
    wt = Dataset(batch_size=cfg.train.batch_size,
                 bptt_len=cfg.train.bptt_len,
                 dataset_cls=hydra.utils.get_class(cfg.dataset.name))

    ###################################################################
    # Models
    ###################################################################
    base_embedding = hydra.utils.instantiate(cfg.embedding, ntokens=len(wt.text_field.vocab) + 3)
    embedding = TransformerEmbedding(embedding=base_embedding,
                                     max_length=cfg.train.bptt_len,
                                     embedding_size=base_embedding.embedding_size,
                                     use_positional_embedding=False)

    encoder = TransformerEncoder(query_dim=cfg.encoder.query_dim,
                                 att_num_units=cfg.encoder.att_num_units,
                                 ffn_num_unit=cfg.encoder.ffn_num_unit,
                                 max_ext=cfg.encoder.max_ext)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()
    # wandb.watch(model)

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=cfg.encoder.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)

    opt = optim.Adam(parameters, lr=cfg.train.lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()
        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()
        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]
        loss.backward()
        nn.utils.clip_grad_norm_(parameters, cfg.train.clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()
        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None

        target_sample = []
        result_sample = []
        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()
            out, out_past = model(text, engine.state.eval_past)

            vocab = wt.text_field.vocab
            idx = list(range(32))
            sample = random.choices(idx, k=5)
            for id_sample in sample:
                s = []
                for target_id in target[id_sample]:
                    s.append(vocab.itos[target_id])
                target_sample.append(" ".join(s))

                s = []
                for result_id in out.max(-1)[1][id_sample]:
                    s.append(vocab.itos[result_id])
                result_sample.append(" ".join(s))
            # engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {"val_loss": loss.item(), "sample": (target_sample, result_sample)}

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ", np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0], "lr", 0.0, 2.5e-4,
                                                cycle_size=len(wt.train_iter))
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0, 2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    # tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger = WandbLogger(project="language_model", entity="akurniawan")
    tb_logger.watch(model)

    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)

        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(cfg.train.log_steps,
                                              OutputHandler(tag="training",
                                                            output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(tag="validation",
                                               output_transform=lambda loss: loss,
                                               another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps, OptimizerParamsHandler(opt)),
    #                  event_name=Events.ITERATION_STARTED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps, WeightsScalarHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(log_steps, GradsScalarHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(500, WeightsHistHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(train_engine,
    #                  log_handler=stepn_logger(500, GradsHistHandler(model)),
    #                  event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=cfg.train.epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()
def train(epochs=500,
          batch_size=32,
          bptt_len=70,
          lr=0.00025,
          log_steps=200,
          clip_grad=0.25,
          log_dir="experiments"):
    ###################################################################
    # Dataset
    ###################################################################
    wt = wikitext103(batch_size=batch_size, bptt_len=bptt_len)
    # wt = wikitext2(batch_size=batch_size, bptt_len=bptt_len)

    ###################################################################
    # Configs
    ###################################################################
    embedding_config = DropEmbedding.Hyperparams(len(wt.text_field.vocab) + 3, ninp=512)
    encoder_config = TransformerEncoder.Hyperparams(
        att_num_units=[512, 512, 512, 512, 512, 512], max_ext=384)

    ###################################################################
    # Models
    ###################################################################
    base_embedding = DropEmbedding(embedding_config)
    embedding = TransformerEmbedding(embedding=base_embedding,
                                     max_length=bptt_len,
                                     embedding_size=embedding_config.ninp,
                                     use_positional_embedding=False)
    encoder = TransformerEncoder(encoder_config)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=encoder_config.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)

    opt = optim.Adam(parameters, lr=lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()
        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()
        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]
        loss.backward()
        nn.utils.clip_grad_norm_(parameters, clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()
        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None
        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()
            out, out_past = model(text, engine.state.eval_past)
            engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {"val_loss": loss.item()}

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ", np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0], "lr", 0.0, 2.5e-4,
                                                cycle_size=len(wt.train_iter))
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0, 2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    tb_logger = TensorboardLogger(log_dir=log_dir)

    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)

        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              OutputHandler(tag="training",
                                                            output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(tag="validation",
                                               output_transform=lambda loss: loss,
                                               another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, OptimizerParamsHandler(opt)),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, WeightsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, GradsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, WeightsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, GradsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()