def _test(n_epochs, metric_device):
    """Check TopKCategoricalAccuracy on a distributed Engine against a
    reference top-k accuracy computed over the full (all-ranks) data.

    Relies on outer-scope names: ``rank``, ``device``, ``idist``,
    ``Engine``, ``TopKCategoricalAccuracy``, ``top_k_accuracy``, ``pytest``.
    """
    num_iters = 100
    batch_size = 16
    num_classes = 10
    # Samples handled by one rank; each rank reads its own contiguous shard.
    shard = num_iters * batch_size
    total = shard * idist.get_world_size()
    y_true = torch.randint(0, num_classes, size=(total,)).to(device)
    y_preds = torch.rand(total, num_classes).to(device)

    def update(engine, i):
        lo = i * batch_size + rank * shard
        hi = (i + 1) * batch_size + rank * shard
        return y_preds[lo:hi, :], y_true[lo:hi]

    engine = Engine(update)
    top_k = 5
    metric = TopKCategoricalAccuracy(k=top_k, device=metric_device)
    metric.attach(engine, "acc")

    engine.run(data=list(range(num_iters)), max_epochs=n_epochs)

    assert "acc" in engine.state.metrics
    res = engine.state.metrics["acc"]
    if isinstance(res, torch.Tensor):
        res = res.cpu().numpy()

    # Reference value over the complete dataset (all ranks combined).
    true_res = top_k_accuracy(y_true.cpu().numpy(), y_preds.cpu().numpy(), k=top_k)
    assert pytest.approx(res) == true_res
def multiclass_train_lstm(
    model: LstmClassifier,
    dataloader_train: DataLoader,
    dataloader_val: DataLoader,
    filename_prefix: str,
):
    """Train an LSTM multiclass classifier with ignite engines.

    Runs up to 20 epochs of Adam training, evaluating on both the training
    and validation loaders after every epoch. Early-stops (patience 5) on
    validation top-3 accuracy and checkpoints the 2 best models by the same
    score under ``DIR_MODELS`` with the given ``filename_prefix``.

    :param model: LSTM classifier to train (moved to CUDA if available).
    :param dataloader_train: loader yielding ``(x, y)`` training batches.
    :param dataloader_val: loader yielding ``(x, y)`` validation batches.
    :param filename_prefix: prefix for saved checkpoint filenames.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
    criterion = CrossEntropyLossOneHot()

    def process_function(_engine, batch):
        # One optimization step; returns (y_pred, y, loss) for the metrics.
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return y_pred, y, loss.item()

    def eval_function(_engine, batch):
        # Forward pass only; metrics attached to the evaluator consume (y_pred, y).
        model.eval()
        with torch.no_grad():
            x, y = batch
            y = y.to(device)
            x = x.to(device)
            y_pred = model(x)
            return y_pred, y

    def score_function(engine):
        # Higher is better for both EarlyStopping and ModelCheckpoint.
        return engine.state.metrics['top3-accuracy']

    model.to(device)
    trainer = Engine(process_function)
    train_evaluator = Engine(eval_function)
    validation_evaluator = Engine(eval_function)

    # NOTE(review): is_multilabel=True on a multiclass (one-hot CE) pipeline
    # looks suspicious — confirm the label format actually is multilabel.
    accuracy_top1 = Accuracy(output_transform=lambda x: (x[0], x[1]), device=device, is_multilabel=True)
    accuracy_top3 = TopKCategoricalAccuracy(output_transform=lambda x: (x[0], x[1]), k=3, device=device)

    # NOTE(review): the same metric instances are attached to the trainer (via
    # RunningAverage) and to both evaluators; ignite metrics keep internal
    # state per instance, so sharing across engines may cross-contaminate —
    # consider separate instances per engine.
    RunningAverage(accuracy_top1).attach(trainer, 'accuracy')
    RunningAverage(accuracy_top3).attach(trainer, 'top3-accuracy')
    RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'loss')

    accuracy_top1.attach(train_evaluator, 'accuracy')
    accuracy_top3.attach(train_evaluator, 'top3-accuracy')
    Loss(criterion).attach(train_evaluator, 'loss')

    accuracy_top1.attach(validation_evaluator, 'accuracy')
    accuracy_top3.attach(validation_evaluator, 'top3-accuracy')
    Loss(criterion).attach(validation_evaluator, 'loss')

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(engine=trainer, metric_names='all')

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(dataloader_train)
        message = f'Training results - Epoch: {engine.state.epoch}.'
        for metric_name, score in train_evaluator.state.metrics.items():
            message += f' {metric_name}: {score:.2f}.'
        pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(dataloader_val)
        message = f'Validation results - Epoch: {engine.state.epoch}.'
        # BUGFIX: previously iterated train_evaluator.state.metrics here,
        # so "Validation results" logged training metrics.
        for metric_name, score in validation_evaluator.state.metrics.items():
            message += f' {metric_name}: {score:.2f}.'
        pbar.log_message(message)
        pbar.n = pbar.last_print_n = 0

    # Stop when validation top-3 accuracy has not improved for 5 evaluations.
    validation_evaluator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(patience=5, score_function=score_function, trainer=trainer))

    # Keep the 2 best checkpoints by validation top-3 accuracy.
    checkpointer = ModelCheckpoint(
        dirname=DIR_MODELS,
        filename_prefix=filename_prefix,
        score_function=score_function,
        score_name='top3-accuracy',
        n_saved=2,
        create_dir=True,
        save_as_state_dict=True,
        require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'v2': model})

    trainer.run(dataloader_train, max_epochs=20)