示例#1
0
def create_eval_engine(model, device):
    """Build an evaluation ``Engine`` with accuracy, precision, recall and F-beta metrics.

    The per-batch step is whatever ``get_process_function(model, device)``
    returns. Metrics are attached under these names:

    * ``"accuracy"``
    * ``"recall"``, ``"precision"`` -- created with ``average=False``
    * ``"f1"``, ``"f2"`` -- F-beta combinations of the two metrics above
    * ``"average f1"``, ``"average f2"`` -- ``torch.mean`` of the F-beta values
    * ``"average recall"``, ``"average precision"`` -- created with ``average=True``

    Args:
        model: Forwarded to ``get_process_function``.
        device: Forwarded to ``get_process_function``.

    Returns:
        The configured ``Engine``.
    """
    # Guards every F-beta denominator. Without it, precision and recall both
    # being zero produces 0 / 0 = NaN; with it the score is 0 instead.
    # Non-degenerate values are unaffected by such a small term.
    eps = 1e-20

    eval_engine = Engine(get_process_function(model, device))

    Accuracy().attach(eval_engine, "accuracy")

    recall = Recall(average=False)
    recall.attach(eval_engine, "recall")
    precision = Precision(average=False)
    precision.attach(eval_engine, "precision")

    f1 = precision * recall * 2 / (precision + recall + eps)
    f1.attach(eval_engine, "f1")
    f2 = precision * recall * 5 / ((4 * precision) + recall + eps)
    f2.attach(eval_engine, "f2")

    def Fbeta(r, p, beta):
        """Return the mean F-beta score as a Python float."""
        return torch.mean(
            (1 + beta**2) * p * r / (beta**2 * p + r + eps)).item()

    MetricsLambda(Fbeta, recall, precision, 1).attach(eval_engine,
                                                      "average f1")
    MetricsLambda(Fbeta, recall, precision, 2).attach(eval_engine,
                                                      "average f2")

    Recall(average=True).attach(eval_engine, "average recall")
    Precision(average=True).attach(eval_engine, "average precision")

    return eval_engine
示例#2
0
    def _test(average, n_epochs):
        """Compare ``Precision`` on random multiclass data with ``precision_score``.

        ``device`` and ``rank`` are read from the enclosing scope. Each rank
        feeds its own contiguous slice of the shared random data to the
        engine, and the final metric is checked against scikit-learn computed
        over the whole data set.
        """
        n_classes = 7
        batch_size = 16
        n_iters = 60

        # Number of samples a single rank consumes during one epoch.
        per_rank = n_iters * batch_size
        y_true = torch.randint(0, n_classes, size=(per_rank * idist.get_world_size(),)).to(device)
        y_preds = torch.rand(per_rank * idist.get_world_size(), n_classes).to(device)

        def update(engine, i):
            # Slice out batch ``i`` of this rank's share.
            lo = rank * per_rank + i * batch_size
            hi = lo + batch_size
            return y_preds[lo:hi, :], y_true[lo:hi]

        engine = Engine(update)
        Precision(average=average).attach(engine, "pr")
        engine.run(data=list(range(n_iters)), max_epochs=n_epochs)

        metrics = engine.state.metrics
        assert "pr" in metrics
        res = metrics["pr"]
        if isinstance(res, torch.Tensor):
            assert res.device.type == "cpu"
            res = res.cpu().numpy()

        sk_average = "macro" if average else None
        predicted = torch.argmax(y_preds, dim=1).cpu().numpy()
        true_res = precision_score(y_true.cpu().numpy(), predicted, average=sk_average)

        assert pytest.approx(res) == true_res
示例#3
0
def test_state_metrics():
    """Precision, recall and a derived mean-F1 metric must all appear in ``state.metrics``."""
    preds = torch.randint(0, 2, size=(15, 10, 4)).float()
    targets = torch.randint(0, 2, size=(15, 10, 4)).long()

    def passthrough(engine, batch):
        # The batch already is the (prediction, target) pair.
        pred, target = batch
        return pred, target

    evaluator = Engine(passthrough)

    precision = Precision(average=False)
    recall = Recall(average=False)
    f1_values = precision * recall * 2 / (precision + recall + 1e-20)
    mean_f1 = MetricsLambda(lambda t: torch.mean(t).item(), f1_values)

    for metric, label in ((precision, "precision"), (recall, "recall"), (mean_f1, "f1")):
        metric.attach(evaluator, label)

    def batches():
        # One (prediction, target) pair per index along the leading dimension.
        for idx in range(preds.shape[0]):
            yield (preds[idx], targets[idx])

    state = evaluator.run(batches(), max_epochs=1)

    assert set(state.metrics.keys()) == {"precision", "recall", "f1"}
示例#4
0
def test_integration():
    """Check precision, recall and a macro F1 built with ``MetricsLambda`` against scikit-learn."""
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10
    n_samples = n_iters * batch_size

    y_true = np.arange(0, n_samples) % n_classes
    y_pred = 0.2 * np.random.rand(n_samples, n_classes)
    for row in range(n_samples):
        if np.random.rand() > 0.4:
            # Make the true class the clear winner.
            y_pred[row, y_true[row]] = 1.0
        else:
            # Boost a randomly chosen class instead.
            y_pred[row, np.random.randint(0, n_classes)] = 0.7

    true_batches = iter(y_true.reshape(n_iters, batch_size))
    pred_batches = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        # The engine's own batch is ignored; data comes from the two iterators.
        target = next(true_batches)
        output = next(pred_batches)
        return torch.from_numpy(output), torch.from_numpy(target)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta**2) * p * r / (beta**2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)

    precision.attach(evaluator, "precision")
    recall.attach(evaluator, "recall")
    F1.attach(evaluator, "f1")

    state = evaluator.run(list(range(n_iters)), max_epochs=1)

    predicted = np.argmax(y_pred, axis=-1)
    precision_true = precision_score(y_true, predicted, average=None)
    recall_true = recall_score(y_true, predicted, average=None)
    f1_true = f1_score(y_true, predicted, average='macro')

    precision_value = state.metrics['precision'].numpy()
    recall_value = state.metrics['recall'].numpy()
    f1_value = state.metrics['f1']

    assert precision_true == approx(precision_value), "{} vs {}".format(
        precision_true, precision_value)
    assert recall_true == approx(recall_value), "{} vs {}".format(
        recall_true, recall_value)
    assert f1_true == approx(f1_value), "{} vs {}".format(f1_true, f1_value)
示例#5
0
def test_integration():
    """Check precision, recall and an F1 built by metric arithmetic against scikit-learn."""
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10
    n_samples = n_iters * batch_size

    y_true = np.arange(0, n_samples, dtype="int64") % n_classes
    y_pred = 0.2 * np.random.rand(n_samples, n_classes)
    for row in range(n_samples):
        if np.random.rand() > 0.4:
            # Make the true class the clear winner.
            y_pred[row, y_true[row]] = 1.0
        else:
            # Boost a randomly chosen class instead.
            y_pred[row, np.random.randint(0, n_classes)] = 0.7

    true_batches = iter(y_true.reshape(n_iters, batch_size))
    pred_batches = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        # The engine's own batch is ignored; data comes from the two iterators.
        target = next(true_batches)
        output = next(pred_batches)
        return torch.from_numpy(output), torch.from_numpy(target)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)
    F1 = precision * recall * 2 / (precision + recall)

    for metric, label in ((precision, "precision"), (recall, "recall"), (F1, "f1")):
        metric.attach(evaluator, label)

    state = evaluator.run(list(range(n_iters)), max_epochs=1)

    predicted = np.argmax(y_pred, axis=-1)
    precision_true = precision_score(y_true, predicted, average=None)
    recall_true = recall_score(y_true, predicted, average=None)
    f1_true = f1_score(y_true, predicted, average=None)

    precision_value = state.metrics["precision"].numpy()
    recall_value = state.metrics["recall"].numpy()
    f1 = state.metrics["f1"].numpy()

    assert precision_true == approx(precision_value), "{} vs {}".format(
        precision_true, precision_value)
    assert recall_true == approx(recall_value), "{} vs {}".format(
        recall_true, recall_value)
    assert f1_true == approx(f1), "{} vs {}".format(f1_true, f1)
示例#6
0
    def _test(average, n_epochs, metric_device):
        """Compare multilabel ``Precision`` on random binary data with ``precision_score``.

        ``device`` and ``rank`` are read from the enclosing scope. The value
        stored in ``engine.state.metrics`` must equal a fresh ``pr.compute()``
        and must match scikit-learn's ``average="samples"`` precision.
        """
        n_classes = 7
        batch_size = 16
        n_iters = 60

        # Number of samples a single rank consumes during one epoch.
        per_rank = n_iters * batch_size
        y_true = torch.randint(
            0, 2,
            size=(per_rank * idist.get_world_size(), n_classes, 6, 8),
        ).to(device)
        y_preds = torch.randint(
            0, 2,
            size=(per_rank * idist.get_world_size(), n_classes, 6, 8),
        ).to(device)

        def update(engine, i):
            # Slice out batch ``i`` of this rank's share.
            lo = rank * per_rank + i * batch_size
            hi = lo + batch_size
            return y_preds[lo:hi, ...], y_true[lo:hi, ...]

        engine = Engine(update)

        pr = Precision(average=average, is_multilabel=True, device=metric_device)
        pr.attach(engine, "pr")

        engine.run(data=list(range(n_iters)), max_epochs=n_epochs)

        metrics = engine.state.metrics
        assert "pr" in metrics
        res = metrics["pr"]
        res2 = pr.compute()
        if not isinstance(res, torch.Tensor):
            assert res == res2
        else:
            res = res.cpu().numpy()
            res2 = res2.cpu().numpy()
            assert (res == res2).all()

        np_y_preds = to_numpy_multilabel(y_preds)
        np_y_true = to_numpy_multilabel(y_true)
        assert pr._type == "multilabel"
        if not average:
            # Unaveraged results are reduced to their mean before comparing.
            res = res.mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            expected = precision_score(np_y_true, np_y_preds, average="samples")
            assert expected == pytest.approx(res)
示例#7
0
def create_evaluator(model, criterion, cfg):
    """Build an evaluation ``Engine`` reporting accuracy, precision, recall, F1 and loss.

    Each step returns ``(y_pred, y, loss, hidden)``. The classification
    metrics read the first two entries and the ``'loss'`` average reads the
    third. When ``cfg.multi_label`` is truthy, predictions are binarised with
    ``y_pred > 0`` after the loss has been computed.

    Attached metric names: ``"acc"``, ``'precision'``, ``"MP"``, ``'recall'``,
    ``"MR"``, ``"F1"``, ``'loss'``.
    """

    def _validation_step(_, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch_to_tensor(batch, cfg)
            x = x.to(cfg.device)
            y = y.to(cfg.device)

            y_pred, hidden = model(x)
            loss = criterion(y_pred, y)

            if cfg.multi_label:
                y_pred = (y_pred > 0).float()

            return y_pred, y, loss, hidden

    def pred_and_target(output):
        # First two entries of the step output: (y_pred, y).
        return output[0:2]

    def mean_of(tensor):
        return torch.mean(tensor).item()

    evaluator = Engine(_validation_step)

    Accuracy(output_transform=pred_and_target,
             is_multilabel=cfg.multi_label).attach(evaluator, "acc")

    precision = Precision(output_transform=pred_and_target,
                          average=False,
                          is_multilabel=cfg.multi_label)
    precision.attach(evaluator, 'precision')
    MetricsLambda(mean_of, precision).attach(evaluator, "MP")

    recall = Recall(output_transform=pred_and_target,
                    average=False,
                    is_multilabel=cfg.multi_label)
    recall.attach(evaluator, 'recall')
    MetricsLambda(mean_of, recall).attach(evaluator, "MR")

    # The tiny constant keeps the denominator non-zero.
    f1_values = 2. * precision * recall / (precision + recall + 1e-20)
    MetricsLambda(mean_of, f1_values).attach(evaluator, "F1")

    Average(output_transform=lambda output: output[2]).attach(evaluator, 'loss')

    return evaluator
示例#8
0
def create_evaluator(model, cfg):
    """Build an evaluation ``Engine`` for the joint word / syllable model.

    Each step returns an 8-tuple::

        (word_pred, y_word, syllable_pred, y_syllable,
         loss, word_loss, syllable_loss, align_loss)

    where the predictions are the model outputs thresholded at ``0.5``.
    Word metrics read entries ``0:2``, syllable metrics read entries ``2:4``
    and the four loss averages read entries ``4`` to ``7``.
    """

    def _validation_step(_, batch):
        model.eval()
        with torch.no_grad():
            tensors = batch_to_tensor(batch, cfg)
            x_char, x_type, y_word, y_syllable = (
                t.to(cfg.device) for t in tensors)

            logits_word, logits_syllable = model(x_char, x_type)
            loss, word_loss, syllable_loss, align_loss = model.joint_loss(
                logits_word, y_word, logits_syllable, y_syllable)

            word_pred = (logits_word > 0.5).long()
            syllable_pred = (logits_syllable > 0.5).long()
            return (word_pred, y_word, syllable_pred, y_syllable, loss,
                    word_loss, syllable_loss, align_loss)

    def word_pair(output):
        return output[0:2]

    def syllable_pair(output):
        return output[2:4]

    def mean_of(tensor):
        return torch.mean(tensor).item()

    evaluator = Engine(_validation_step)

    Accuracy(word_pair).attach(evaluator, 'w_acc')
    Accuracy(syllable_pair).attach(evaluator, 's_acc')

    # Loss averages: positions 4..7 of the step output, in this order.
    for position, label in enumerate(('loss', 'w_loss', 's_loss', 'a_loss'),
                                     start=4):
        Average(lambda output, position=position: output[position]).attach(
            evaluator, label)

    Accuracy(word_pair).attach(evaluator, "acc")

    w_precision = Precision(word_pair)
    w_precision.attach(evaluator, 'WP')
    MetricsLambda(mean_of, w_precision).attach(evaluator, "WMP")

    s_precision = Precision(syllable_pair)
    s_precision.attach(evaluator, 'SP')
    MetricsLambda(mean_of, s_precision).attach(evaluator, "SMP")

    w_recall = Recall(word_pair)
    w_recall.attach(evaluator, 'WR')
    MetricsLambda(mean_of, w_recall).attach(evaluator, "WMR")

    s_recall = Recall(syllable_pair)
    s_recall.attach(evaluator, 'SR')
    MetricsLambda(mean_of, s_recall).attach(evaluator, "SMR")

    # The tiny constant keeps the F1 denominators non-zero.
    w_f1_values = 2. * w_precision * w_recall / (w_precision + w_recall + 1e-20)
    MetricsLambda(mean_of, w_f1_values).attach(evaluator, "WF1")

    s_f1_values = 2. * s_precision * s_recall / (s_precision + s_recall + 1e-20)
    MetricsLambda(mean_of, s_f1_values).attach(evaluator, "SF1")

    return evaluator
示例#9
0
文件: main.py 项目: mittagessen/blla
def train(name, load, lrate, weight_decay, workers, smooth, device, validation,
          ground_truth):
    """Train a ``ResUNet`` and report validation metrics after every epoch.

    Training and validation images are collected recursively from
    ``ground_truth`` and ``validation`` (files matching ``*.seeds.png``) and
    wrapped in ``BaselineSet``. The model is optimised with Adam against
    ``BCEWithLogitsLoss`` for 1000 epochs, and a checkpoint is written to the
    current directory at the end of each epoch.

    Args:
        name: Checkpoint filename prefix. If falsy, ``'{lrate}_{weight_decay}'``
            is used.
        load: Optional path to a previously saved model. When given, the
            loaded object replaces the freshly built network.
        lrate: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        workers: ``num_workers`` for both data loaders.
        smooth: Forwarded to ``BaselineSet``.
        device: Device used for the model, trainer and evaluator.
        validation: Root directory of the validation images.
        ground_truth: Root directory of the training images.
    """
    if not name:
        name = '{}_{}'.format(lrate, weight_decay)
    click.echo('model output name: {}'.format(name))

    torch.set_num_threads(1)

    train_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(ground_truth),
                                      recursive=True),
                            smooth=smooth)
    train_data_loader = DataLoader(dataset=train_set,
                                   num_workers=workers,
                                   batch_size=1,
                                   shuffle=True,
                                   pin_memory=True)
    val_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(validation),
                                    recursive=True),
                          smooth=smooth)
    val_data_loader = DataLoader(dataset=val_set,
                                 num_workers=workers,
                                 batch_size=1,
                                 pin_memory=True)

    click.echo('loading network')
    model = ResUNet(refine_encoder=False).to(device)

    if load:
        click.echo('loading weights')
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only pass paths to checkpoints from trusted sources.
        model = torch.load(load, map_location=device)

    criterion = nn.BCEWithLogitsLoss()
    opti = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=lrate,
                      weight_decay=weight_decay)

    # NOTE(review): defined but not registered with any handler in this
    # function -- presumably intended for early stopping; confirm or remove.
    def score_function(engine):
        val_loss = engine.state.metrics['loss']
        return -val_loss

    def output_preprocess(output):
        """Turn raw model output into a thresholded map for the metrics."""
        o, target = output
        o = torch.sigmoid(o)
        o = denoising_hysteresis_thresh(o.detach().squeeze().cpu().numpy(),
                                        0.8, 0.9, 2.5)
        # Restore the two leading dimensions removed by squeeze().
        return torch.from_numpy(o.astype('f')).unsqueeze(0).unsqueeze(0).to(
            device), target.double().to(device)

    trainer = create_supervised_trainer(model,
                                        opti,
                                        criterion,
                                        device=device,
                                        non_blocking=True)

    accuracy = Accuracy(output_transform=output_preprocess)
    # Precision and recall need BOTH the thresholding transform and
    # average=False. Previously the transformed instances were overwritten by
    # plain ``Precision(average=False)`` / ``Recall(average=False)``, so the
    # attached metrics were fed raw logits instead of thresholded predictions.
    precision = Precision(output_transform=output_preprocess, average=False)
    recall = Recall(output_transform=output_preprocess, average=False)
    loss = Loss(criterion)
    # The tiny constant avoids 0 / 0 = NaN when precision and recall are both zero.
    f1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

    evaluator = create_supervised_evaluator(model,
                                            device=device,
                                            non_blocking=True)

    accuracy.attach(evaluator, 'accuracy')
    precision.attach(evaluator, 'precision')
    recall.attach(evaluator, 'recall')
    loss.attach(evaluator, 'loss')
    f1.attach(evaluator, 'f1')

    ckpt_handler = ModelCheckpoint('.',
                                   name,
                                   save_interval=1,
                                   n_saved=10,
                                   require_empty=False)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(trainer, ['loss'])

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=ckpt_handler,
                              to_save={'net': model})
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=TerminateOnNan())

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        """Run the evaluator on the validation set and log its metrics."""
        evaluator.run(val_data_loader)
        metrics = evaluator.state.metrics
        progress_bar.log_message(
            'eval results - epoch {} loss: {:.4f} f1: {:.4f}, accuracy: {:.4f} recall: {:.4f} precision {:.4f}'
            .format(engine.state.epoch, metrics['loss'], metrics['f1'],
                    metrics['accuracy'], metrics['recall'],
                    metrics['precision']))

    trainer.run(train_data_loader, max_epochs=1000)
    def create_zero_shot_eval_engine(self, model, zero_shot_label,
                                     model_mapping, label_mapping,
                                     is_test_multilabel, cpu):
        """Build an ``Engine`` that scores a model on one zero-shot target label.

        The model's softmax output is collapsed into a two-column
        "out of class / in class" distribution for the target label, and the
        gold labels are collapsed into the matching 0/1 targets.

        Args:
            model: Callable returning raw logits with classes along dim 1.
            zero_shot_label: Substring identifying the unifying label. It is
                matched against the lower-cased keys of both mappings.
            model_mapping: Mapping from label name to class index for the model.
            label_mapping: Mapping from label name to class index (or column,
                for multilabel data) in the test set.
            is_test_multilabel: If true, gold labels are read from column
                ``label_target_int`` of ``y``; otherwise they are compared
                with ``label_target_int``.
            cpu: If true, inputs and targets are moved with ``.cpu()``,
                otherwise with ``.cuda()``.

        Returns:
            The configured ``Engine`` with ``"accuracy"``, ``"recall"``,
            ``"precision"``, ``"f1"``, ``"f2"`` and their ``"average ..."``
            counterparts attached.

        Raises:
            AssertionError: If either mapping has zero or several labels
                containing ``zero_shot_label``.
        """
        # Find the labels in both the train and test sets that correspond to
        # the zero shot label (the unifying label). The loop variable used to
        # be called ``int``, shadowing the builtin.
        model_target_int = [
            index for label, index in model_mapping.items()
            if zero_shot_label in label.lower()
        ]
        label_target_int = [
            index for label, index in label_mapping.items()
            if zero_shot_label in label.lower()
        ]

        # There should only be one unifying label in each dataset (Possible TODO: Allow multiple labels to map to one unifying label)
        assert len(
            model_target_int
        ) == 1, f"Ambiguous or empty model label list when trying to map {zero_shot_label} to {model_target_int}"
        assert len(
            label_target_int
        ) == 1, f"Ambiguous or empty gold label list when trying to map {zero_shot_label} to {label_target_int}"

        model_target_int = model_target_int[0]
        label_target_int = label_target_int[0]

        def process_function(engine, batch):
            X, y = batch

            # Evaluation only: do not build an autograd graph for the forward pass.
            with torch.no_grad():
                if cpu:
                    pred = model(X.cpu())
                    gold = y.cpu()
                else:
                    pred = model(X.cuda())
                    gold = y.cuda()

                # Get the softmax of the raw model output (logits)
                pred = torch.softmax(pred, dim=1)

                # Probability that the prediction is the target class; the
                # list index keeps the column dimension.
                pred_in_class_prob = pred[:, [model_target_int]]

                # Probability of every other class is the complement.
                pred_out_class_prob = 1 - pred_in_class_prob

                # Column 0 is "out of class", column 1 is "in class".
                pred = torch.cat((pred_out_class_prob, pred_in_class_prob),
                                 dim=1)

                if is_test_multilabel:
                    # Multilabel test task: take the target label's column.
                    gold = gold[:, label_target_int]
                else:
                    # 1 (True) if the gold label is the zero-shot label, else 0.
                    gold = (gold == label_target_int).long()

            return pred, gold

        eval_engine = Engine(process_function)

        # Keeps the F-beta denominators non-zero.
        really_small_number = 1e-10

        accuracy = Accuracy()
        accuracy.attach(eval_engine, "accuracy")
        recall = Recall()
        recall.attach(eval_engine, "recall")
        precision = Precision()
        precision.attach(eval_engine, "precision")
        f1 = (precision * recall * 2 /
              (precision + recall + really_small_number))
        f1.attach(eval_engine, "f1")
        f2 = (precision * recall * 5 /
              ((4 * precision) + recall + really_small_number))
        f2.attach(eval_engine, "f2")

        avg_recall = Recall(average=True)
        avg_recall.attach(eval_engine, "average recall")
        avg_precision = Precision(average=True)
        avg_precision.attach(eval_engine, "average precision")
        avg_f1 = (avg_precision * avg_recall * 2 /
                  (avg_precision + avg_recall + really_small_number))
        avg_f1.attach(eval_engine, "average f1")
        avg_f2 = (avg_precision * avg_recall * 5 /
                  ((4 * avg_precision) + avg_recall + really_small_number))
        avg_f2.attach(eval_engine, "average f2")

        return eval_engine
示例#11
0
def run(model, criterion, optimizer, epochs=100, log_interval=10):
    """Train ``model`` with ignite, evaluating and plotting to Visdom every epoch.

    Relies on two names defined outside this function: ``dataloaders`` (with
    ``'train'`` and ``'test'`` entries) and ``device``.

    Args:
        model: Callable returning ``(class_output, structured_output)``.
        criterion: Mapping with ``'label'`` and ``'structured'`` loss
            callables; the training loss is their sum.
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of training epochs.
        log_interval: Log and plot the training loss every this many
            iterations within an epoch.
    """
    vis = visdom.Visdom(env='ft_lift_ignite')

    train_loader = dataloaders['train']
    val_loader = dataloaders['test']

    def _forward(batch):
        """Move the batch to ``device``, run the model and compute the summed loss."""
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        class_output, structured_output = model(inputs)
        loss = (criterion['label'](class_output, labels)
                + criterion['structured'](structured_output, labels))
        return loss, class_output, structured_output, labels

    def _as_output(loss, class_output, structured_output, labels):
        """Package a step result in the dict layout the metrics expect."""
        return {'loss': loss.item(),
                'class_output': class_output,
                'structured_output': structured_output,
                'labels': labels}

    def update_model(trainer, batch):
        # The evaluator puts the model into eval mode at the end of every
        # epoch; switch back, otherwise every epoch after the first would
        # train with dropout / batch-norm in inference mode.
        model.train()
        optimizer.zero_grad()
        loss, class_output, structured_output, labels = _forward(batch)
        loss.backward()
        optimizer.step()
        return _as_output(loss, class_output, structured_output, labels)

    trainer = Engine(update_model)

    def _inference(evaluator, batch):
        model.eval()
        with torch.no_grad():
            return _as_output(*_forward(batch))

    evaluator = Engine(_inference)

    output_transform1 = lambda data: (data['class_output'], data['labels'])

    metric_accuracy = Accuracy(output_transform=output_transform1)
    metric_accuracy.attach(evaluator, 'accuracy')

    metric_nll = Loss(criterion['label'], output_transform=output_transform1)
    metric_nll.attach(evaluator, 'nll')

    metric_precision = Precision(average=True, output_transform=output_transform1)
    metric_precision.attach(evaluator, 'precision')

    handler = ModelCheckpoint('/1116/tmp/lift_models', 'myprefix', save_interval=1, n_saved=150, require_empty=False, create_dir=True)

    train_loss_window = create_plot_window(vis, '#Iterations', 'Loss', 'Training Loss')
    train_avg_loss_window = create_plot_window(vis, '#Iterations', 'Loss', 'Training Average Loss')
    train_avg_accuracy_window = create_plot_window(vis, '#Iterations', 'Accuracy', 'Training Average Accuracy')
    train_avg_precision_window = create_plot_window(vis, '#Iterations', 'Precision', 'Training Average Precision')
    val_avg_loss_window = create_plot_window(vis, '#Epochs', 'Loss', 'Validation Average Loss')
    val_avg_accuracy_window = create_plot_window(vis, '#Epochs', 'Accuracy', 'Validation Average Accuracy')
    val_avg_precision_window = create_plot_window(vis, '#Epochs', 'Precision', 'Validation Average Precison')

    def _evaluate_and_plot(engine, loader, label, accuracy_window, loss_window, precision_window):
        """Run the evaluator on ``loader``, print the metrics and append them to the plots."""
        evaluator.run(loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        avg_precision = metrics['precision']
        print("{} Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f} Avg Precision: {:.2f}"
              .format(label, engine.state.epoch, avg_accuracy, avg_nll, avg_precision))
        epoch = np.array([engine.state.epoch])
        vis.line(X=epoch, Y=np.array([avg_accuracy]),
                 win=accuracy_window, update='append')
        vis.line(X=epoch, Y=np.array([avg_nll]),
                 win=loss_window, update='append')
        vis.line(X=epoch, Y=np.array([avg_precision]),
                 win=precision_window, update='append')

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based position of this iteration within the current epoch.
        iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration_in_epoch % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                  "".format(engine.state.epoch, iteration_in_epoch, len(train_loader), engine.state.output['loss']))
            vis.line(X=np.array([engine.state.iteration]),
                     Y=np.array([engine.state.output['loss']]),
                     update='append', win=train_loss_window)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        _evaluate_and_plot(engine, train_loader, "Training",
                           train_avg_accuracy_window,
                           train_avg_loss_window,
                           train_avg_precision_window)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        _evaluate_and_plot(engine, val_loader, "Validation",
                           val_avg_accuracy_window,
                           val_avg_loss_window,
                           val_avg_precision_window)

    # kick everything off
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})
    trainer.run(train_loader, max_epochs=epochs)