Example #1
def _test_gpu_info(device='cpu'):
    gpu_info = GpuInfo()

    # increase code cov
    gpu_info.reset()
    gpu_info.update(None)

    t = torch.rand(4, 10, 100, 100).to(device)
    data = gpu_info.compute()
    assert len(data) > 0
    assert "fb_memory_usage" in data[0]
    mem_report = data[0]['fb_memory_usage']
    assert 'used' in mem_report and 'total' in mem_report
    assert mem_report['total'] > 0.0
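    # Loose lower bound: the tensor's element count scaled by 1/1024/1024
    # (a float32 tensor actually occupies 4 bytes per element).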
    assert mem_report['used'] > t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3] / 1024.0 / 1024.0

    assert "utilization" in data[0]
    util_report = data[0]['utilization']
    assert 'gpu_util' in util_report

    # with Engine
    engine = Engine(lambda engine, batch: 0.0)
    engine.state = State(metrics={})

    gpu_info.completed(engine, name='gpu')

    assert 'gpu:0 mem(%)' in engine.state.metrics
    assert 'gpu:0 util(%)' in engine.state.metrics

    assert isinstance(engine.state.metrics['gpu:0 mem(%)'], int)
    assert int(mem_report['used'] * 100.0 / mem_report['total']) == engine.state.metrics['gpu:0 mem(%)']

    assert isinstance(engine.state.metrics['gpu:0 util(%)'], int)
    assert int(util_report['gpu_util']) == engine.state.metrics['gpu:0 util(%)']
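For reference, a minimal sketch of the attach pattern this test exercises (it assumes pynvml is installed, a GPU is visible, and trainer is an existing Engine):

from ignite.contrib.metrics import GpuInfo

# With name="gpu", the metric publishes "gpu:0 mem(%)" and "gpu:0 util(%)"
# into trainer.state.metrics at each completed iteration.
GpuInfo().attach(trainer, name="gpu")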
Example #2
def run(train_batch_size, val_batch_size, epochs, lr, momentum,
        display_gpu_info):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(F.nll_loss)
                                            },
                                            device=device)

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    if display_gpu_info:
        from ignite.contrib.metrics import GpuInfo

        GpuInfo().attach(trainer, name="gpu")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        pbar.log_message(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        pbar.log_message(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
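A hypothetical invocation of run above (the argument values are illustrative, not taken from the source):

run(train_batch_size=64, val_batch_size=1000, epochs=10, lr=0.01,
    momentum=0.5, display_gpu_info=True)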
Example #3
def _setup_common_training_handlers(trainer,
                                    to_save=None, save_every_iters=1000, output_path=None,
                                    lr_scheduler=None, with_gpu_stats=True,
                                    output_names=None, with_pbars=True, with_pbar_on_iters=True,
                                    log_every_iters=100, device='cuda'):
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError("If to_save argument is provided then output_path argument should be also defined")
        checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name='gpu', event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, torch.Tensor):
                return x
            else:
                raise ValueError("Unhandled type of update_function's output. "
                                 "It should either mapping or sequence, but given {}".format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n),
                           epoch_bound=False, device=device).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(trainer, metric_names='all',
                                              event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True, bar_format="").attach(trainer,
                                                        event_name=Events.EPOCH_STARTED,
                                                        closing_event_name=Events.COMPLETED)
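The Events.ITERATION_COMPLETED(every=log_every_iters) form used above is ignite's filtered-event syntax, which throttles a handler to every N-th occurrence of the event. A minimal sketch (the handler body is illustrative; trainer is an existing Engine):

from ignite.engine import Events

# Fires only on every 100th completed iteration.
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_progress(engine):
    print(f"iteration {engine.state.iteration}")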
Example #4
def _test_gpu_info(device="cpu"):
    gpu_info = GpuInfo()

    # increase code cov
    gpu_info.reset()
    gpu_info.update(None)

    t = torch.rand(4, 10, 100, 100).to(device)
    data = gpu_info.compute()
    assert len(data) > 0
    assert "fb_memory_usage" in data[0]
    mem_report = data[0]["fb_memory_usage"]
    assert "used" in mem_report and "total" in mem_report
    assert mem_report["total"] > 0.0
    assert mem_report["used"] > t.shape[0] * t.shape[1] * t.shape[2] * t.shape[
        3] / 1024.0 / 1024.0

    assert "utilization" in data[0]
    util_report = data[0]["utilization"]
    assert "gpu_util" in util_report

    # with Engine
    engine = Engine(lambda engine, batch: 0.0)
    engine.state = State(metrics={})

    gpu_info.completed(engine, name="gpu")

    assert "gpu:0 mem(%)" in engine.state.metrics

    assert isinstance(engine.state.metrics["gpu:0 mem(%)"], int)
    assert int(mem_report["used"] * 100.0 /
               mem_report["total"]) == engine.state.metrics["gpu:0 mem(%)"]

    if util_report["gpu_util"] != "N/A":
        assert "gpu:0 util(%)" in engine.state.metrics
        assert isinstance(engine.state.metrics["gpu:0 util(%)"], int)
        assert int(
            util_report["gpu_util"]) == engine.state.metrics["gpu:0 util(%)"]
    else:
        assert "gpu:0 util(%)" not in engine.state.metrics
Example #5
    def _test_with_custom_query(resp, warn_msg, check_compute=False):
        from pynvml.smi import nvidia_smi

        def query(*args, **kwargs):
            return resp

        def getInstance():
            nvsmi = Mock()
            nvsmi.DeviceQuery = Mock(side_effect=query)
            return nvsmi

        nvidia_smi.getInstance = Mock(side_effect=getInstance)
        gpu_info = GpuInfo()
        if check_compute:
            with pytest.warns(UserWarning, match=warn_msg):
                gpu_info.compute()

        # with Engine
        engine = Engine(lambda engine, batch: 0.0)
        engine.state = State(metrics={})

        with pytest.warns(UserWarning, match=warn_msg):
            gpu_info.completed(engine, name="gpu info")
Example #6
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please define only one of them."
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If the to_save argument is provided, then the output_path or save_handler argument should also be defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler],
                                             save_handler),
                                        filename_prefix="training",
                                        **kwargs)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(
                every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should be either a mapping or a sequence, but {type(x)} was given"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
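A hypothetical call to the helper above (the model, optimizer, and path names are illustrative, not taken from the source):

_setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    output_path="/tmp/checkpoints",
    with_gpu_stats=True,
    log_every_iters=50,
)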
Example #7
        archived=False,
        global_step_transform=global_step_from_engine(trainer))
    nan_handler = TerminateOnNan()
    coslr = CosineAnnealingScheduler(opt,
                                     "lr",
                                     start_value=LR,
                                     end_value=LR / 4,
                                     cycle_size=TOTAL_UPDATE_STEPS // 1)

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                                {'_': mude})

    trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_handler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, coslr)

    GpuInfo().attach(trainer, name='gpu')
    pbar.attach(trainer,
                output_transform=lambda output: {'loss': output['loss']},
                metric_names=[f"gpu:{args.gpu} mem(%)"])

    # FIRE
    tb_logger = TensorboardLogger(log_dir=TENSORBOARD_RUN_LOG_DIR_PATH)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag='training',
            output_transform=lambda output: {'loss': output['loss']}),
        event_name=Events.ITERATION_COMPLETED(
            every=LOG_TRAINING_PROGRESS_EVERY_N))
    tb_logger.attach(
        evaluator,
Example #8
def test_no_pynvml_package():
    with patch.dict("sys.modules", {"pynvml.smi": None}):
        with pytest.raises(
                RuntimeError,
                match="This contrib module requires pynvml to be installed."):
            GpuInfo()
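Setting an entry to None in sys.modules makes the import system raise ImportError for that name, which is how the test above simulates an uninstalled pynvml. A minimal standalone sketch of the same trick:

import pytest
from unittest.mock import patch

with patch.dict("sys.modules", {"pynvml.smi": None}):
    with pytest.raises(ImportError):
        from pynvml.smi import nvidia_smi  # blocked by the None entry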
Example #9
def test_no_gpu():
    with pytest.raises(RuntimeError,
                       match="This contrib module requires available GPU"):
        GpuInfo()
Example #10
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path", type=str, default='data/spolin-train-acl.json', help="Set data path")    
    parser.add_argument("--valid_path", type=str, default='data/spolin-valid.json', help="Set data path")     

    parser.add_argument("--correct_bias", type=bool, default=False, help="Set to true to correct bias for Adam optimizer")
    parser.add_argument("--lr", type=float, default=2e-5, help="Set learning rate")
    parser.add_argument("--n_epochs", type=int, default=4, help="Set number of epochs")
    parser.add_argument("--num_warmup_steps", type=float, default=1000, help="Set number of warm-up steps")
    parser.add_argument("--num_total_steps", type=float, default=10000, help="Set number of total steps")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Set maximum gradient normalization.")
    parser.add_argument("--pretrained_path", type=str, default='bert-base-uncased', help="Choose which pretrained model to use (bert-base-uncased, roberta-base, roberta-large, roberta-large-mnli)")    
    parser.add_argument("--batch_size", type=int, default=32, help="Provide the batch size")    
    parser.add_argument("--random_seed", type=int, default=42, help="Set the random seed")
    parser.add_argument("--test", action='store_true', help="If true, run with small dataset for testing code")
    parser.add_argument("--base", action='store_true', help="If true, run with base experiment configuration (training with spont only) for comparison")

    args = parser.parse_args() 

    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: {}".format(pformat(args)))

    if 'roberta' in args.pretrained_path: 
        # initialize tokenizer and model 
        logger.info("Initialize model and tokenizer.")
        tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')
        model = RobertaForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

        ### START MODEL MODIFICATION
        # The pretrained model was not trained with token type ids, so fix the
        # token type embeddings for fine-tuning. Without this, the model can only
        # take 0s as valid input for token_type_ids.
        model.config.type_vocab_size = 2 
        model.roberta.embeddings.token_type_embeddings = torch.nn.Embedding(2, model.config.hidden_size)
        model.roberta.embeddings.token_type_embeddings.weight.data.normal_(mean=0.0, std=model.config.initializer_range)

        ### END MOD
    elif 'bert' in args.pretrained_path: 
        model = BertForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')
        tokenizer = BertTokenizer.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

    model.to(args.device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.lr,
                      correct_bias=args.correct_bias)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.num_warmup_steps, t_total=args.num_total_steps) 

    logger.info("Prepare datasets")
    logger.info("Loading train set...")

    train_data = get_data(args.train_path)
    valid_data = get_data(args.valid_path)

    cornell_valid_data = {k: {'cornell': valid_data[k]['cornell']} for k in valid_data.keys()}
    spont_valid_data = {k: {'spont': valid_data[k]['spont']} for k in valid_data.keys()}

    train_loader, train_sampler = get_data_loaders(args, train_data, args.train_path, tokenizer)
    logger.info("Loading validation set...")
    valid_p = Path(args.valid_path)
    cornell_valid_loader, cornell_valid_sampler = get_data_loaders(args, cornell_valid_data, f"{str(valid_p.parent)}/cornell_{valid_p.name}",  tokenizer)
    spont_valid_loader, spont_valid_sampler = get_data_loaders(args, spont_valid_data, f"{str(valid_p.parent)}/spont_{valid_p.name}", tokenizer)


    # Training function and trainer 
    def update(engine, batch): 
        model.train() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch

        optimizer.zero_grad()
        #roberta has issues with token_type_ids 
        loss, logits = model(b_input_ids, token_type_ids=b_input_segment, attention_mask=b_input_mask, labels=b_labels)
        # loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)


        loss.backward() 
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        
        optimizer.step() 
        scheduler.step() 

        return loss.item(), logits, b_labels

    trainer = Engine(update)     

    # Evaluation function and evaluator 
    def inference(engine, batch): 
        model.eval() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch
        
        with torch.no_grad(): 
            #roberta has issues with token_type_ids 
            # loss, logits = model(b_input_ids, token_type_ids = None, attention_mask=b_input_mask, labels=b_labels)
            loss, logits = model(b_input_ids, token_type_ids=b_input_segment, attention_mask=b_input_mask, labels=b_labels)
            label_ids = b_labels

        return logits, label_ids, loss.item()
    cornell_evaluator = Engine(inference)
    spont_evaluator = Engine(inference)


    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: cornell_evaluator.run(cornell_valid_loader))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: spont_evaluator.run(spont_valid_loader))


    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") 
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, "accuracy")
    if torch.cuda.is_available(): 
        GpuInfo().attach(trainer, name='gpu')

    recall = Recall(output_transform=lambda x: (x[0], x[1]))
    precision = Precision(output_transform=lambda x: (x[0], x[1]))
    F1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy(output_transform=lambda x: (x[0], x[1]))
    metrics = {"recall": recall, "precision": precision, "f1": F1, "accuracy": accuracy, "loss": Average(output_transform=lambda x: x[2])}

    for name, metric in metrics.items(): 
        metric.attach(cornell_evaluator, name) 
        metric.attach(spont_evaluator, name) 


    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss', 'accuracy'])
    pbar.attach(trainer, metric_names=['gpu:0 mem(%)', 'gpu:0 util(%)'])
    
    cornell_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Cornell validation metrics:\n %s" % pformat(cornell_evaluator.state.metrics)))
    spont_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Spont validation metrics:\n %s" % pformat(spont_evaluator.state.metrics)))


    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
    tb_logger.attach(cornell_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(spont_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)


    # tb_logger.writer.log_dir -> tb_logger.writer.logdir (this is the correct attribute name as seen in: https://tensorboardx.readthedocs.io/en/latest/_modules/tensorboardX/writer.html#SummaryWriter)
    checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=5)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

    torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
    getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
    tokenizer.save_vocabulary(tb_logger.writer.logdir)

    trainer.run(train_loader, max_epochs = args.n_epochs)

    if args.n_epochs > 0: 
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
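Note that OutputHandler's another_engine argument used in this example is superseded in newer ignite releases by global_step_transform=global_step_from_engine(trainer), the form visible in Example #7.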
Example #11
def test_no_pynvml_package(no_site_packages):

    with pytest.raises(
            RuntimeError,
            match="This contrib module requires pynvml to be installed."):
        GpuInfo()
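Here no_site_packages is presumably a pytest fixture, defined elsewhere in the test suite, that hides installed packages such as pynvml so that constructing GpuInfo fails; its definition is not shown in this snippet.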
Example #12
def run(train_batch_size, val_batch_size, epochs, lr, momentum,
        display_gpu_info, eval):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(F.nll_loss)
                                            },
                                            device=device)

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    if display_gpu_info:
        from ignite.contrib.metrics import GpuInfo

        GpuInfo().attach(trainer, name="gpu")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    def score_function(engine):
        val_loss = engine.state.metrics['nll']
        return -val_loss

    stopping_handler = EarlyStopping(patience=3,
                                     score_function=score_function,
                                     trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, stopping_handler)

    saving_handler = ModelCheckpoint('models',
                                     'MNIST',
                                     n_saved=2,
                                     create_dir=True,
                                     require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), saving_handler,
                              {'mymodel': model})

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        pbar.log_message(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        pbar.log_message(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    if eval:
        model.load_state_dict(torch.load("models/MNIST_model.pth"))
        draw_image()
        image = Image.open("number.png")
        image = ImageOps.invert(image)

        class OneHotNormalization(object):
            def __call__(self, tensor):
                for i in range(tensor.shape[1]):
                    for j in range(tensor.shape[2]):
                        if tensor[0, i, j] > 0:
                            tensor[0, i, j] += .45
                return tensor

        image = Compose([
            Grayscale(),
            Resize((28, 28), interpolation=5),
            ToTensor(),
            OneHotNormalization(),
            Normalize((0.1307, ), (0.3081, ))
        ])(image)
        image = image.reshape(28, 28)
        plt.imshow(image)
        plt.show()
        model.eval()
        image = image.reshape(1, 1, 28, 28)
        print("Hai disegnato un: ", torch.argmax(model(image)).item())
    else:
        trainer.run(train_loader, max_epochs=epochs)