def test_override_required_output_keys():
    # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
    import torch.nn as nn

    from ignite.engine import create_supervised_evaluator

    counter = [0]

    class CustomMetric(Metric):
        required_output_keys = ("y_pred", "y", "x")

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def update(self, output):
            y_pred, y, x = output
            assert y_pred.shape == (4, 3)
            assert y.shape == (4,)
            assert x.shape == (4, 10)
            assert x.equal(data[counter[0]][0])
            assert y.equal(data[counter[0]][1])
            counter[0] += 1

        def reset(self):
            pass

        def compute(self):
            pass

    model = nn.Linear(10, 3)

    metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()}

    evaluator = create_supervised_evaluator(
        model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
    )

    data = [
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
    ]
    evaluator.run(data)
def test_override_required_output_keys():
    # https://github.com/pytorch/ignite/issues/1415
    from ignite.engine import create_supervised_evaluator

    counter = [0]

    class DummyLoss2(Loss):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def update(self, output):
            y_pred, y, criterion_kwargs = output
            assert y_pred.shape == (4, 3)
            assert y.shape == (4,)
            assert criterion_kwargs == c_kwargs
            assert y.equal(data[counter[0]][1])
            counter[0] += 1

        def reset(self):
            pass

        def compute(self):
            pass

    model = nn.Linear(10, 3)

    metrics = {"Precision": Precision(), "DummyLoss2": DummyLoss2(nll_loss)}

    # global criterion kwargs
    c_kwargs = {"reduction": "sum"}

    evaluator = create_supervised_evaluator(
        model,
        metrics=metrics,
        output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred, "criterion_kwargs": c_kwargs},
    )

    data = [
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
    ]
    evaluator.run(data)
def _test():
    y_true = np.arange(0, n_iters * batch_size * dist.get_world_size()) % n_classes
    y_pred = 0.2 * np.random.rand(n_iters * batch_size * dist.get_world_size(), n_classes)
    for i in range(n_iters * batch_size * dist.get_world_size()):
        if np.random.rand() > 0.4:
            y_pred[i, y_true[i]] = 1.0
        else:
            j = np.random.randint(0, n_classes)
            y_pred[i, j] = 0.7

    y_true = y_true.reshape(n_iters * dist.get_world_size(), batch_size)
    y_pred = y_pred.reshape(n_iters * dist.get_world_size(), batch_size, n_classes)

    def update_fn(engine, i):
        y_true_batch = y_true[i + rank * n_iters, ...]
        y_pred_batch = y_pred[i + rank * n_iters, ...]
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    evaluator = Engine(update_fn)

    precision = Precision(average=False, device=device)
    recall = Recall(average=False, device=device)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.attach(evaluator, "f1")

    another_f1 = (1.0 + precision * recall * 2 / (precision + recall + 1e-20)).mean().item()
    another_f1.attach(evaluator, "ff1")

    data = list(range(n_iters))
    state = evaluator.run(data, max_epochs=1)

    assert "f1" in state.metrics
    assert "ff1" in state.metrics
    f1_true = f1_score(y_true.ravel(), np.argmax(y_pred.reshape(-1, n_classes), axis=-1), average="macro")
    assert f1_true == approx(state.metrics["f1"])
    assert 1.0 + f1_true == approx(state.metrics["ff1"])
def _test_distrib_integration(device):
    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(p, r, average, n_epochs):
        n_iters = 60
        s = 16
        n_classes = 7

        offset = n_iters * s
        y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
        y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)

        def update(engine, i):
            return (
                y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
                y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
            )

        engine = Engine(update)

        fbeta = Fbeta(beta=2.5, average=average, device=device)
        fbeta.attach(engine, "f2.5")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "f2.5" in engine.state.metrics
        res = engine.state.metrics["f2.5"]
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()

        true_res = fbeta_score(
            y_true.cpu().numpy(),
            torch.argmax(y_preds, dim=1).cpu().numpy(),
            beta=2.5,
            average="macro" if average else None,
        )

        assert pytest.approx(res) == true_res

    _test(None, None, average=True, n_epochs=1)
    _test(None, None, average=True, n_epochs=2)
    precision = Precision(average=False)
    recall = Recall(average=False)
    _test(precision, recall, average=False, n_epochs=1)
    _test(precision, recall, average=False, n_epochs=2)
def test_integration():
    def _test(p, r, average):
        np.random.seed(1)

        n_iters = 10
        batch_size = 10
        n_classes = 10

        y_true = np.arange(0, n_iters * batch_size) % n_classes
        y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
        for i in range(n_iters * batch_size):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7

        y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
        y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

        def update_fn(engine, batch):
            y_true_batch = next(y_true_batch_values)
            y_pred_batch = next(y_pred_batch_values)
            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

        evaluator = Engine(update_fn)

        f2 = Fbeta(beta=2.0, average=average, precision=p, recall=r)
        f2.attach(evaluator, "f2")

        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)

        f2_true = fbeta_score(y_true, np.argmax(y_pred, axis=-1), average='macro' if average else None, beta=2.0)
        if isinstance(state.metrics['f2'], torch.Tensor):
            assert (f2_true == state.metrics['f2'].numpy()).all(), "{} vs {}".format(f2_true, state.metrics['f2'])
        else:
            assert f2_true == pytest.approx(state.metrics['f2']), "{} vs {}".format(f2_true, state.metrics['f2'])

    _test(None, None, False)
    _test(None, None, True)
    precision = Precision(average=False)
    recall = Recall(average=False)
    _test(precision, recall, False)
    _test(precision, recall, True)
def build_metrics():
    metrics = {
        'precision': Precision(lambda x: (x[0].argmax(dim=1), x[1])),
        'recall': Recall(lambda x: (x[0].argmax(dim=1), x[1])),
        'accuracy': Accuracy(lambda x: (x[0].argmax(dim=1), x[1])),
        'confusion_matrix': ConfusionMatrix(2, output_transform=prepare_confusion_matrix),
        'metric_output_results': MetricOutputResults(lambda x: (x[0].argmax(dim=1), x[1], x[0], x[2])),
        'metric_last_layer': MetricLastLayer(lambda x: (x[3])),
        # 'auroc': AUROC(lambda x: (x[0], x[1])),
    }
    return metrics
def _test(average):
    pr = Precision(average=average)

    y_pred = torch.softmax(torch.rand(4, 4), dim=1)
    y = torch.ones(4).type(torch.LongTensor)
    pr.update((y_pred, y))

    y_pred = torch.rand(4, 1)
    y = torch.ones(4).type(torch.LongTensor)
    with pytest.raises(RuntimeError):
        pr.update((y_pred, y))
def _test(average):
    pr = Precision(average=average)

    y_pred = torch.softmax(torch.rand(4, 4), dim=1)
    y = torch.ones(4).long()
    pr.update((y_pred, y))

    y_pred = torch.randint(0, 2, size=(4,))
    y = torch.ones(4).long()
    with pytest.raises(RuntimeError):
        pr.update((y_pred, y))
def test_multiclass_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).type(torch.LongTensor)))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).type(torch.LongTensor)))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).type(torch.LongTensor)))
def test_incorrect_shape():
    precision = Precision()

    y_pred = torch.zeros(2, 3, 2, 2)
    y = torch.zeros(2, 3)
    with pytest.raises(ValueError):
        precision.update((y_pred, y))

    y_pred = torch.zeros(2, 3, 2, 2)
    y = torch.zeros(2, 3, 4, 4)
    with pytest.raises(ValueError):
        precision.update((y_pred, y))
def test_integration_ingredients_not_attached():
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10

    y_true = np.arange(0, n_iters * batch_size) % n_classes
    y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
    for i in range(n_iters * batch_size):
        if np.random.rand() > 0.4:
            y_pred[i, y_true[i]] = 1.0
        else:
            j = np.random.randint(0, n_classes)
            y_pred[i, j] = 0.7

    y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
    y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.attach(evaluator, "f1")

    data = list(range(n_iters))
    state = evaluator.run(data, max_epochs=1)

    f1_true = f1_score(y_true, np.argmax(y_pred, axis=-1), average="macro")
    assert f1_true == approx(state.metrics["f1"]), "{} vs {}".format(f1_true, state.metrics["f1"])
def Fbeta(beta, output_transform=lambda x: x, average=True, precision=None, recall=None, device=None):
    """Calculates the F-beta score.

    Args:
        beta (float): weight of precision in the harmonic mean
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        average (bool, optional): if True, the F-beta score is computed as the unweighted average
            (across all classes in the multiclass case); otherwise, a tensor with the F-beta score
            for each class is returned in the multiclass case.
        precision (Precision, optional): precision metric object with `average=False` used to compute the F-beta score
        recall (Recall, optional): recall metric object with `average=False` used to compute the F-beta score
        device (str or torch.device, optional): device specification in case of distributed computation usage.
            In most cases, it can be defined as "cuda:local_rank" or "cuda" if
            `torch.cuda.set_device(local_rank)` has already been called. By default, if a distributed
            process group is initialized and available, the device is set to `cuda`.

    Returns:
        MetricsLambda, F-beta metric
    """
    if not (beta > 0):
        raise ValueError("Beta should be a positive number, but given {}".format(beta))

    if precision is None:
        precision = Precision(output_transform=output_transform, average=False, device=device)
    elif precision._average:
        raise ValueError("Input precision metric should have average=False")

    if recall is None:
        recall = Recall(output_transform=output_transform, average=False, device=device)
    elif recall._average:
        raise ValueError("Input recall metric should have average=False")

    fbeta = (1.0 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)

    if average:
        fbeta = fbeta.mean().item()

    return fbeta
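A minimal usage sketch for the Fbeta factory above, mirroring the test patterns already in this collection; the engine, update_fn and random data are illustrative placeholders, not part of the original code:

import torch
from ignite.engine import Engine
from ignite.metrics import Fbeta, Precision, Recall  # or use the Fbeta factory defined above


def update_fn(engine, batch):
    # batch is already (y_pred, y), so pass it through unchanged
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(update_fn)

# Explicit per-class precision/recall (average=False) so other metrics could reuse them.
precision = Precision(average=False)
recall = Recall(average=False)

# beta > 1 weights recall more heavily; average=True yields a single float.
f2 = Fbeta(beta=2.0, average=True, precision=precision, recall=recall)
f2.attach(evaluator, "f2")

# Three random multiclass batches: (probabilities of shape (8, 4), labels in [0, 4)).
data = [(torch.rand(8, 4), torch.randint(0, 4, size=(8,))) for _ in range(3)]
state = evaluator.run(data, max_epochs=1)
print(state.metrics["f2"])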
def test_metrics_lambda_update_and_attach_together():
    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    def update_fn(engine, batch):
        y_pred, y = batch
        return y_pred, y

    engine = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.attach(engine, "f1")

    with pytest.raises(ValueError, match=r"MetricsLambda is already attached to an engine"):
        F1.update((y_pred, y))

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.update((y_pred, y))

    engine = Engine(update_fn)

    with pytest.raises(ValueError, match=r"The underlying metrics are already updated"):
        F1.attach(engine, "f1")

    F1.reset()
    F1.attach(engine, "f1")
def test_predict(model, dataloader_test, use_cuda):
    if use_cuda:
        model = model.cuda()

    precision = Precision()
    recall = Recall()
    f1 = Fbeta(beta=1.0, average=True, precision=precision, recall=recall)

    for i, (img, label) in enumerate(dataloader_test):
        img, label = Variable(img), Variable(label)
        if use_cuda:
            img = img.cuda()
            label = label.cuda()

        pred = model(img)
        # Labels are one-hot encoded; recover the class index per sample.
        _, my_label = torch.max(label, dim=1)

        precision.update((pred, my_label))
        recall.update((pred, my_label))
        f1.update((pred, my_label))

    precision.compute()
    recall.compute()
    print("\tF1 Score: {:0.2f}".format(f1.compute() * 100))
def test_multilabel_wrong_inputs():
    pr = Precision(average=True, is_multilabel=True)

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible y_pred
        pr.update((torch.rand(10, 5), torch.randint(0, 2, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible y
        pr.update((torch.randint(0, 5, size=(10, 5, 6)), torch.rand(10)))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.randint(0, 2, size=(20, 5)), torch.randint(0, 2, size=(20, 5)).long()))
        pr.update((torch.randint(0, 2, size=(20, 6)), torch.randint(0, 2, size=(20, 6)).long()))
def _test(average):
    pr = Precision(average=average)

    y_pred = torch.randint(0, 2, size=(10,))
    y = torch.randint(0, 2, size=(10,)).long()
    pr.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    assert pr._type == "binary"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

    pr.reset()
    y_pred = torch.randint(0, 2, size=(10,))
    y = torch.randint(0, 2, size=(10,)).long()
    pr.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    assert pr._type == "binary"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

    pr.reset()
    y_pred = torch.Tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.51])
    y_pred = torch.round(y_pred)
    y = torch.randint(0, 2, size=(10,)).long()
    pr.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    assert pr._type == "binary"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

    # Batched Updates
    pr.reset()
    y_pred = torch.randint(0, 2, size=(100,))
    y = torch.randint(0, 2, size=(100,)).long()

    batch_size = 16
    n_iters = y.shape[0] // batch_size + 1

    for i in range(n_iters):
        idx = i * batch_size
        pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    assert pr._type == "binary"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)
def _test(average):
    pr = Precision(average=average, is_multilabel=True)

    y_pred = torch.randint(0, 2, size=(10, 5, 18, 16))
    y = torch.randint(0, 2, size=(10, 5, 18, 16)).long()
    pr.update((y_pred, y))
    np_y_pred = to_numpy_multilabel(y_pred)
    np_y = to_numpy_multilabel(y)
    assert pr._type == "multilabel"
    pr_compute = pr.compute() if average else pr.compute().mean().item()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)

    pr.reset()
    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
    pr.update((y_pred, y))
    np_y_pred = to_numpy_multilabel(y_pred)
    np_y = to_numpy_multilabel(y)
    assert pr._type == "multilabel"
    pr_compute = pr.compute() if average else pr.compute().mean().item()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)

    # Batched Updates
    pr.reset()
    y_pred = torch.randint(0, 2, size=(100, 5, 12, 14))
    y = torch.randint(0, 2, size=(100, 5, 12, 14)).long()

    batch_size = 16
    n_iters = y.shape[0] // batch_size + 1

    for i in range(n_iters):
        idx = i * batch_size
        pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

    np_y = to_numpy_multilabel(y)
    np_y_pred = to_numpy_multilabel(y_pred)
    assert pr._type == "multilabel"
    pr_compute = pr.compute() if average else pr.compute().mean().item()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)
def test_binary_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # y has not only 0 or 1 values
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.arange(0, 10).long()))

    with pytest.raises(ValueError):
        # y_pred values are not thresholded to 0, 1 values
        pr.update((torch.rand(10,), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10, 5, 6)).long(), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
def _test(average):
    pr = Precision(average=average)

    y_pred = torch.rand(10, 5, 18, 16)
    y = torch.randint(0, 5, size=(10, 18, 16)).long()
    pr.update((y_pred, y))
    num_classes = y_pred.shape[1]

    np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
    np_y = y.numpy().ravel()

    assert pr._type == "multiclass"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    sk_average_parameter = "macro" if average else None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
        assert sk_compute == pytest.approx(pr_compute)

    pr.reset()
    y_pred = torch.rand(10, 7, 20, 12)
    y = torch.randint(0, 7, size=(10, 20, 12)).long()
    pr.update((y_pred, y))
    num_classes = y_pred.shape[1]

    np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
    np_y = y.numpy().ravel()

    assert pr._type == "multiclass"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    sk_average_parameter = "macro" if average else None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
        assert sk_compute == pytest.approx(pr_compute)

    # Batched Updates
    pr.reset()
    y_pred = torch.rand(100, 8, 12, 14)
    y = torch.randint(0, 8, size=(100, 12, 14)).long()

    batch_size = 16
    n_iters = y.shape[0] // batch_size + 1

    for i in range(n_iters):
        idx = i * batch_size
        pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

    num_classes = y_pred.shape[1]
    np_y = y.numpy().ravel()
    np_y_pred = y_pred.argmax(dim=1).numpy().ravel()

    assert pr._type == "multiclass"
    assert isinstance(pr.compute(), float if average else torch.Tensor)
    pr_compute = pr.compute() if average else pr.compute().numpy()
    sk_average_parameter = "macro" if average else None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
        assert sk_compute == pytest.approx(pr_compute)
def test_indexing_metric():
    def _test(ignite_metric, sklearn_metric, sklearn_args, index, num_classes=5):
        y_pred = torch.rand(15, 10, num_classes).float()
        y = torch.randint(0, num_classes, size=(15, 10)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        metrics = {"metric": ignite_metric[index], "metric_wo_index": ignite_metric}

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1, epoch_length=y_pred.shape[0])

        sklearn_output = sklearn_metric(
            y.view(-1).numpy(), y_pred.view(-1, num_classes).argmax(dim=1).numpy(), **sklearn_args
        )

        assert (state.metrics["metric_wo_index"][index] == state.metrics["metric"]).all()
        assert np.allclose(state.metrics["metric"].numpy(), sklearn_output)

    num_classes = 5

    labels = list(range(0, num_classes, 2))
    _test(Precision(), precision_score, {"labels": labels, "average": None}, index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Precision(), precision_score, {"labels": labels, "average": None}, index=labels)
    labels = [1]
    _test(Precision(), precision_score, {"labels": labels, "average": None}, index=labels)

    labels = list(range(0, num_classes, 2))
    _test(Recall(), recall_score, {"labels": labels, "average": None}, index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Recall(), recall_score, {"labels": labels, "average": None}, index=labels)
    labels = [1]
    _test(Recall(), recall_score, {"labels": labels, "average": None}, index=labels)

    # np.ix_ is used to allow for a 2D slice of a matrix. This is required to get an accurate result from
    # ConfusionMatrix. ConfusionMatrix must be sliced the same way row-wise and column-wise.
    labels = list(range(0, num_classes, 2))
    _test(ConfusionMatrix(num_classes), confusion_matrix, {"labels": labels}, index=np.ix_(labels, labels))
    labels = list(range(num_classes - 1, 0, -2))
    _test(ConfusionMatrix(num_classes), confusion_matrix, {"labels": labels}, index=np.ix_(labels, labels))
    labels = [1]
    _test(ConfusionMatrix(num_classes), confusion_matrix, {"labels": labels}, index=np.ix_(labels, labels))
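A small standalone sketch (plain NumPy, names chosen here only for illustration) of the np.ix_ behaviour referenced in the comment above: a flat label list selects rows only, whereas np.ix_ applies the same label subset to both rows and columns of the confusion matrix.

import numpy as np

cm = np.arange(25).reshape(5, 5)  # stand-in for a 5-class confusion matrix
labels = [0, 2, 4]

rows_only = cm[labels]              # shape (3, 5): rows selected, columns untouched
sub_cm = cm[np.ix_(labels, labels)] # shape (3, 3): same labels on both axes

assert rows_only.shape == (3, 5)
assert sub_cm.shape == (3, 3)
assert sub_cm[1, 2] == cm[2, 4]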
def test_multiclass_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))

    pr = Precision(average=True)

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long()))
        pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
        pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))

    pr = Precision(average=False)

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long()))
        pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
        pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
def test_pytorch_operators():
    def _test(composed_metric, metric_name, compute_true_value_fn):
        metrics = {
            metric_name: composed_metric,
        }

        y_pred = torch.rand(15, 10, 5).float()
        y = torch.randint(0, 5, size=(15, 10)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1, epoch_length=y_pred.shape[0])

        assert set(state.metrics.keys()) == set([metric_name])
        np_y_pred = np.argmax(y_pred.numpy(), axis=-1).ravel()
        np_y = y.numpy().ravel()
        assert state.metrics[metric_name] == approx(compute_true_value_fn(np_y_pred, np_y))

    precision_1 = Precision(average=False)
    precision_2 = Precision(average=False)
    norm_summed_precision = (precision_1 + precision_2).norm(p=10)

    def compute_true_norm_summed_precision(y_pred, y):
        p1 = precision_score(y, y_pred, average=None)
        p2 = precision_score(y, y_pred, average=None)
        return np.linalg.norm(p1 + p2, ord=10)

    _test(norm_summed_precision, "mean summed precision", compute_true_value_fn=compute_true_norm_summed_precision)

    precision = Precision(average=False)
    recall = Recall(average=False)
    sum_precision_recall = (precision + recall).sum()

    def compute_sum_precision_recall(y_pred, y):
        p = precision_score(y, y_pred, average=None)
        r = recall_score(y, y_pred, average=None)
        return np.sum(p + r)

    _test(sum_precision_recall, "sum precision recall", compute_true_value_fn=compute_sum_precision_recall)

    precision = Precision(average=False)
    recall = Recall(average=False)
    f1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

    def compute_f1(y_pred, y):
        f1 = f1_score(y, y_pred, average="macro")
        return f1

    _test(f1, "f1", compute_true_value_fn=compute_f1)
def test_binary_input(average):
    pr = Precision(average=average)

    def _test(y_pred, y, batch_size):
        pr.reset()
        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
        else:
            pr.update((y_pred, y))

        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()

        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

    def get_test_cases():
        test_cases = [
            # Binary input of shape (N, 1) or (N, )
            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
            # Binary input of shape (N, L)
            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
            (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
            (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
            # Binary input of shape (N, H, W)
            (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
            (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
            (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
        ]
        return test_cases

    for _ in range(5):
        # check multiple random inputs as exact random matches are rare
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size)
def test_recursive_attachment():
    def _test(composed_metric, metric_name, compute_true_value_fn):
        metrics = {
            metric_name: composed_metric,
        }

        y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
        y = torch.randint(0, 2, size=(15, 10, 4)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1)

        assert set(state.metrics.keys()) == set([metric_name])
        np_y_pred = y_pred.numpy().ravel()
        np_y = y.numpy().ravel()
        assert state.metrics[metric_name] == approx(compute_true_value_fn(np_y_pred, np_y))

    precision_1 = Precision()
    precision_2 = Precision()
    summed_precision = precision_1 + precision_2

    def compute_true_summed_precision(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return p1 + p2

    _test(summed_precision, "summed precision", compute_true_value_fn=compute_true_summed_precision)

    precision_1 = Precision()
    precision_2 = Precision()
    mean_precision = (precision_1 + precision_2) / 2

    def compute_true_mean_precision(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return (p1 + p2) * 0.5

    _test(mean_precision, "mean precision", compute_true_value_fn=compute_true_mean_precision)

    precision_1 = Precision()
    precision_2 = Precision()
    some_metric = 2.0 + 0.2 * (precision_1 * precision_2 + precision_1 - precision_2) ** 0.5

    def compute_true_somemetric(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return 2.0 + 0.2 * (p1 * p2 + p1 - p2) ** 0.5

    _test(some_metric, "some metric", compute_true_somemetric)
def _test_distrib_integration_multilabel(device):
    from ignite.engine import Engine

    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(average, n_epochs):
        n_iters = 60
        s = 16
        n_classes = 7

        offset = n_iters * s
        y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
        y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)

        def update(engine, i):
            return (
                y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
                y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
            )

        engine = Engine(update)

        pr = Precision(average=average, is_multilabel=True)
        pr.attach(engine, "pr")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "pr" in engine.state.metrics
        res = engine.state.metrics["pr"]
        res2 = pr.compute()
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()
            res2 = res2.cpu().numpy()
            assert (res == res2).all()
        else:
            assert res == res2

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            true_res = precision_score(
                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
            )

        assert pytest.approx(res) == true_res

    for _ in range(2):
        _test(average=True, n_epochs=1)
        _test(average=True, n_epochs=2)

    if idist.get_world_size() > 1:
        with pytest.warns(
            RuntimeWarning,
            match="Precision/Recall metrics do not work in distributed setting when "
            "average=False and is_multilabel=True",
        ):
            pr = Precision(average=False, is_multilabel=True)

        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
        pr.update((y_pred, y))
        pr_compute1 = pr.compute()
        pr_compute2 = pr.compute()
        assert len(pr_compute1) == 4 * 6 * 8
        assert (pr_compute1 == pr_compute2).all()
def create_eval_engine(model, is_multilabel, n_classes, cpu):

    def process_function(engine, batch):
        X, y = batch

        if cpu:
            pred = model(X.cpu())
            gold = y.cpu()
        else:
            pred = model(X.cuda())
            gold = y.cuda()

        return pred, gold

    eval_engine = Engine(process_function)

    if is_multilabel:
        accuracy = MulticlassOverallAccuracy(n_classes=n_classes)
        accuracy.attach(eval_engine, "accuracy")
        per_class_accuracy = MulticlassPerClassAccuracy(n_classes=n_classes)
        per_class_accuracy.attach(eval_engine, "per class accuracy")
        recall = MulticlassRecall(n_classes=n_classes)
        recall.attach(eval_engine, "recall")
        precision = MulticlassPrecision(n_classes=n_classes)
        precision.attach(eval_engine, "precision")
        f1 = MulticlassF(n_classes=n_classes, f_n=1)
        f1.attach(eval_engine, "f1")
        f2 = MulticlassF(n_classes=n_classes, f_n=2)
        f2.attach(eval_engine, "f2")

        avg_recall = MulticlassRecall(n_classes=n_classes, average=True)
        avg_recall.attach(eval_engine, "average recall")
        avg_precision = MulticlassPrecision(n_classes=n_classes, average=True)
        avg_precision.attach(eval_engine, "average precision")
        avg_f1 = MulticlassF(n_classes=n_classes, average=True, f_n=1)
        avg_f1.attach(eval_engine, "average f1")
        avg_f2 = MulticlassF(n_classes=n_classes, average=True, f_n=2)
        avg_f2.attach(eval_engine, "average f2")
    else:
        accuracy = Accuracy()
        accuracy.attach(eval_engine, "accuracy")
        recall = Recall(average=False)
        recall.attach(eval_engine, "recall")
        precision = Precision(average=False)
        precision.attach(eval_engine, "precision")
        confusion_matrix = ConfusionMatrix(num_classes=n_classes)
        confusion_matrix.attach(eval_engine, "confusion_matrix")
        f1 = (precision * recall * 2 / (precision + recall))
        f1.attach(eval_engine, "f1")
        f2 = (precision * recall * 5 / ((4 * precision) + recall))
        f2.attach(eval_engine, "f2")

        def Fbeta(r, p, beta):
            return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

        avg_f1 = MetricsLambda(Fbeta, recall, precision, 1)
        avg_f1.attach(eval_engine, "average f1")
        avg_f2 = MetricsLambda(Fbeta, recall, precision, 2)
        avg_f2.attach(eval_engine, "average f2")

        avg_recall = Recall(average=True)
        avg_recall.attach(eval_engine, "average recall")
        avg_precision = Precision(average=True)
        avg_precision.attach(eval_engine, "average precision")

    if n_classes == 2:
        top_k = TopK(k=10, label_idx_of_interest=0)
        top_k.attach(eval_engine, "top_k")

    return eval_engine
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader = get_dataloaders(tb, vb)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # ------------------------------------
    # 4. Define metrics
    metrics = {
        'accuracy': Accuracy(),
        'loss': Loss(nn.functional.cross_entropy),
        'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
        'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(7), train_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model, optimizer, nn.functional.cross_entropy, device=device)

    # ------------------------------------
    # 6. Create evaluator
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 5
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        evaluator.run(train4val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = GetTemplate('default-log').format(
            'Training', engine.state.epoch, avg_accuracy, avg_loss,
            precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
        # writer.add_scalars('Aggregate/Score', {'Train avg precision': precision_recall['data'][0, -1], 'Train avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)
        # pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = GetTemplate('default-log').format(
            'Validating', engine.state.epoch, avg_accuracy, avg_loss,
            precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Val Loss': avg_loss}, engine.state.epoch)
        writer.add_scalars('Aggregate/Score', {
            'Val avg precision': precision_recall['data'][0, -1],
            'Val avg recall': precision_recall['data'][1, -1]
        }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
    # train_loader, train4val_loader, val_loader, num_of_images = get_dataloaders(tb, vb)
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics
    class DOCLoss(nn.Module):
        def __init__(self, weight):
            super(DOCLoss, self).__init__()
            self.class_weights = weight

        def forward(self, input, target):
            sigmoid = 1 - 1 / (1 + torch.exp(-input))
            sigmoid[range(0, sigmoid.shape[0]), target] = 1 - sigmoid[range(0, sigmoid.shape[0]), target]
            sigmoid = torch.log(sigmoid)
            if self.class_weights is not None:
                loss = -torch.sum(sigmoid * self.class_weights)
            else:
                loss = -torch.sum(sigmoid)
            return loss

    train_metrics = {
        'accuracy': Accuracy(),
        'loss': Loss(DOCLoss(weight=weights)),
        'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
        'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
    }

    def val_pred_transform(output):
        y_pred, y = output
        new_y_pred = torch.zeros((y_pred.shape[0], INFO['dataset-info']['num-of-classes'] + 1)).to(device=device)
        for ind, c in enumerate(train_loader.dataset.classes):
            new_col = val_loader.dataset.class_to_idx[c]
            new_y_pred[:, new_col] += y_pred[:, ind]
        ukn_ind = val_loader.dataset.class_to_idx['UNKNOWN']
        import math
        new_y_pred[:, ukn_ind] = -math.inf
        return new_y_pred, y

    val_metrics = {
        'accuracy': Accuracy(),
        'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform), Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes'] + 1, output_transform=val_pred_transform), val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model, optimizer, DOCLoss(weight=weights), device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    # Update progress bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # Compute metrics on train data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
        Training Results - Epoch: {}
        Avg accuracy: {:.4f}
        Avg loss: {:.4f}
        precision_recall: \n{}
        confusion matrix: \n{}
        """.format(engine.state.epoch, avg_accuracy, avg_loss, precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)

    # Compute metrics on val data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
        Validating Results - Epoch: {}
        Avg accuracy: {:.4f}
        precision_recall: \n{}
        confusion matrix: \n{}
        """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Score', {
            'Val avg precision': precision_recall['data'][0, -1],
            'Val avg recall': precision_recall['data'][1, -1]
        }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    # Save the model every N epochs.
    save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=50, n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

    # Update the learning rate according to the scheduler.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # Setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )

        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)

            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics
    train_metrics = {
        'accuracy': Accuracy(),
        'loss': Loss(nn.CrossEntropyLoss(weight=weights)),
        'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
        'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
    }

    def val_pred_transform(output):
        y_pred, y = output
        new_y_pred = torch.zeros((y_pred.shape[0], len(INFO['dataset-info']['known-classes']) + 1)).to(device=device)
        for c in range(y_pred.shape[1]):
            if c == 0:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
            elif mapping[c] == val_loader.dataset.class_to_idx['UNKNOWN']:
                new_y_pred[:, mapping[c]] = torch.where(
                    new_y_pred[:, mapping[c]] > y_pred[:, c],
                    new_y_pred[:, mapping[c]],
                    y_pred[:, c])
            else:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
        return new_y_pred, y

    val_metrics = {
        'accuracy': Accuracy(val_pred_transform),
        'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform), Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(len(INFO['dataset-info']['known-classes']) + 1, output_transform=val_pred_transform), val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(weight=weights), device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
        <Training> Results - Epoch: {}
        Avg accuracy: {:.4f}
        Avg loss: {:.4f}
        precision_recall: \n{}
        confusion matrix: \n{}
        """.format(engine.state.epoch, avg_accuracy, avg_loss, precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
        <Validating> Results - Epoch: {}
        Avg accuracy: {:.4f}
        precision_recall: \n{}
        confusion matrix: \n{}
        """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
        writer.add_scalars('Aggregate/Score', {
            'Val avg precision': precision_recall['data'][0, -1],
            'Val avg recall': precision_recall['data'][1, -1]
        }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()