Ejemplo n.º 1
0
def test_data_logger_val_data_dicts(test_settings, live_mock_server):
    """Check the validation table built from dict-valued inputs and targets.

    Every key of ``inputs`` and ``targets`` is expected to show up as its own
    table column, every row is expected to hold the matching slice of the
    source arrays, and the table must resolve to an artifact reference entry.
    """

    def triples():
        # A fresh array per call so no two columns share the same object.
        return np.array([[i, i, i] for i in range(10)])

    def singles():
        return np.array([[i] for i in range(10)])

    with wandb.init(settings=test_settings) as run:
        vd = ValidationDataLogger(
            inputs={"ia": triples(), "ib": triples()},
            targets={"ta": singles(), "tb": singles()},
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=False,
        )

        table = vd.validation_indexes[0]._table
        tcols = table.columns
        assert set(tcols) == {"ia", "ib", "ta", "tb"}

        def row_matches(i):
            # Expected cell contents for row ``i``, checked in column order.
            expected = {"ia": [i, i, i], "ib": [i, i, i], "ta": [i], "tb": [i]}
            record = table.data[i]
            return all(
                record[tcols.index(name)].tolist() == want
                for name, want in expected.items()
            )

        assert np.all([row_matches(i) for i in range(10)])
        assert table._get_artifact_reference_entry() is not None
Ejemplo n.º 2
0
def test_data_logger_pred_user_proc(test_settings, live_mock_server):
    """Check that a user prediction row processor adds its column to the table.

    The fake model returns the first feature of each input row and the
    processor derives ``oa`` as ``output + 1``; the logged predictions table
    is expected to contain ``val_row``, ``output`` and ``oa``.
    """

    def add_one_to_output(ndx, row):
        return {"oa": row["output"] + 1}

    def first_feature(inputs):
        return inputs[:, 0]

    with wandb.init(settings=test_settings) as run:
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=add_one_to_output,
            class_labels=None,
            infer_missing_processors=False,
        )
        predictions = vd.make_predictions(first_feature)
        t = vd.log_predictions(predictions)

        assert set(t.columns) == {"val_row", "output", "oa"}
        rows_ok = [t.data[i] == [i, i, i + 1] for i in range(10)]
        assert np.all(rows_ok)
        assert t._get_artifact_reference_entry() is not None
Ejemplo n.º 3
0
def test_data_logger_val_invalid(test_settings, live_mock_server):
    """Expect an AssertionError when both ``targets`` and ``indexes`` are None."""
    with wandb.init(settings=test_settings) as run, pytest.raises(AssertionError):
        input_columns = {
            name: np.array([[i, i, i] for i in range(10)])
            for name in ("ia", "ib")
        }
        ValidationDataLogger(
            inputs=input_columns,
            targets=None,
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=False,
        )
Ejemplo n.º 4
0
def test_data_logger_val_indexes(test_settings, live_mock_server):
    """Smoke test: build the logger from explicit index references, no targets.

    There are no assertions; the test passes as long as construction does
    not raise.
    """
    with wandb.init(settings=test_settings) as run:
        table = wandb.Table(columns=["label"], data=[["cat"]])
        # Ten references, all pointing at the table's only row.
        row_refs = [table.index_ref(0) for _ in range(10)]
        input_columns = {
            "ia": np.array([[i, i, i] for i in range(10)]),
            "ib": np.array([[i, i, i] for i in range(10)]),
        }
        vd = ValidationDataLogger(
            inputs=input_columns,
            targets=None,
            indexes=row_refs,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=False,
        )
Ejemplo n.º 5
0
def test_data_logger_val_user_proc(test_settings, live_mock_server):
    """Check that a user validation row processor adds its derived columns.

    The processor emits ``ip_1`` and ``tp_1`` (input and target plus one);
    the validation table is expected to hold them next to the raw ``input``
    and ``target`` columns and to resolve to an artifact reference entry.
    """

    def plus_one_columns(ndx, row):
        return {
            "ip_1": row["input"] + 1,
            "tp_1": row["target"] + 1,
        }

    with wandb.init(settings=test_settings) as run:
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=plus_one_columns,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=False,
        )

        table = vd.validation_indexes[0]._table
        tcols = table.columns
        assert set(tcols) == {"input", "target", "ip_1", "tp_1"}

        def row_matches(i):
            # Expected cell contents for row ``i``, checked in column order.
            expected = {
                "input": [i, i, i],
                "target": [i],
                "ip_1": [i + 1, i + 1, i + 1],
                "tp_1": [i + 1],
            }
            record = table.data[i]
            return all(
                record[tcols.index(name)].tolist() == want
                for name, want in expected.items()
            )

        assert np.all([row_matches(i) for i in range(10)])
        assert table._get_artifact_reference_entry() is not None
Ejemplo n.º 6
0
 def on_train_begin(self, logs=None):
     """Create the validation data logger when evaluation logging is enabled.

     Validation data is taken from ``self.validation_data`` when that is set.
     Otherwise ``self.validation_steps`` batches are pulled from
     ``self.generator`` and joined along axis 0. Any exception raised along
     the way is reported through ``wandb.termwarn`` rather than propagated.

     Args:
         logs: Not used by this method.
     """
     if not self.log_evaluation:
         return
     try:
         validation_data = None
         if self.validation_data:
             validation_data = self.validation_data
         elif not self.generator:
             wandb.termwarn(
                 "WandbCallback is unable to read validation_data from trainer and therefore cannot log validation data. Ensure Keras is properly patched by calling `from wandb.keras import WandbCallback` at the top of your script."
             )
         elif not self.validation_steps:
             wandb.termwarn(
                 "WandbCallback is unable to log validation data. When using a generator for validation_data, you must pass validation_steps"
             )
         else:
             inputs, targets = None, None
             for _ in range(self.validation_steps):
                 batch_inputs, batch_targets = next(self.generator)
                 if inputs is None:
                     # First batch seeds the accumulators unchanged.
                     inputs, targets = batch_inputs, batch_targets
                     continue
                 inputs, targets = (
                     np.append(inputs, batch_inputs, axis=0),
                     np.append(targets, batch_targets, axis=0),
                 )
             validation_data = (inputs, targets)

         if validation_data:
             self._validation_data_logger = ValidationDataLogger(
                 inputs=validation_data[0],
                 targets=validation_data[1],
                 indexes=self._validation_indexes,
                 validation_row_processor=self._validation_row_processor,
                 prediction_row_processor=self._prediction_row_processor,
                 class_labels=self.labels,
                 infer_missing_processors=self._infer_missing_processors,
             )
     except Exception as e:
         wandb.termwarn(
             "Error initializing ValidationDataLogger in WandbCallback. Skipping logging validation data. Error: "
             + str(e)
         )
Ejemplo n.º 7
0
 def _init_validation_gen(self):
     """
     Helper method for initializing Validation data table
     """
     if self.log_evaluation:
         try:
             validation_data = None
             if self.validation_data:
                 validation_data = self.validation_data
                 self.validation_data_logger = ValidationDataLogger(
                     inputs=validation_data[0],
                     targets=validation_data[1],
                     indexes=None,
                     validation_row_processor=None,
                     prediction_row_processor=lambda ndx, row: {"output": np.argmax(row["output"])},
                     class_labels=self.labels,
                     infer_missing_processors=self.infer_missing_processors)
         except Exception as e:
             wandb.termwarn(
                 "Error initializing ValidationDataLogger in WandbCallback. Skipping logging validation data. Error: " + str(
                     e))
Ejemplo n.º 8
0
 def _init_testing_gen(self):
     """
     Helper method for initializing Testing data table
     """
     if self.log_evaluation:
         try:
             testing_data = None
             if self.testing_data:
                 testing_data = self.testing_data
                 self.testing_data_logger = ValidationDataLogger(
                     inputs=testing_data[0],
                     targets=testing_data[1],
                     indexes=None,
                     validation_row_processor=None,
                     prediction_row_processor=None,
                     class_labels=self.labels,
                     infer_missing_processors=self.infer_missing_processors)
         except Exception as e:
             wandb.termwarn(
                 "Error initializing ValidationDataLogger in WandbCallback. Skipping logging validation data. Error: " + str(
                     e))
Ejemplo n.º 9
0
def test_data_logger_val_data_lists(live_mock_server, test_settings):
    """Check the validation table built from plain array inputs and targets.

    Non-dict inputs and targets are expected to land in columns named
    ``input`` and ``target``, and the table must expose an artifact entry
    reference URL.
    """
    with wandb.init(settings=test_settings) as run:
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=False,
        )

        table = vd.validation_indexes[0]._table
        tcols = table.columns
        assert set(tcols) == {"input", "target"}

        def row_matches(i):
            record = table.data[i]
            if record[tcols.index("input")].tolist() != [i, i, i]:
                return False
            return record[tcols.index("target")].tolist() == [i]

        assert np.all([row_matches(i) for i in range(10)])
        assert table._get_artifact_entry_ref_url() is not None
        run.finish()
Ejemplo n.º 10
0
def test_data_logger_pred_inferred_proc_no_classes(test_settings, live_mock_server):
    """Check inferred prediction processors when no class labels are supplied.

    A fake model returns outputs of several ranks; with
    ``infer_missing_processors=True`` the predictions table is expected to
    contain the raw outputs plus derived ``:val``, ``:node``, ``:argmax`` and
    ``:argmin`` columns, and image/video columns when
    ``CAN_INFER_IMAGE_AND_VIDEO`` is set.
    """

    def fake_model(inputs):
        return {
            "simple": np.random.randint(5, size=(10)),
            "wrapped": np.random.randint(5, size=(10, 1)),
            "logits": np.random.randint(5, size=(10, 5)),
            "nodes": np.random.randint(5, size=(10, 10)),
            "2dimages": np.random.randint(255, size=(10, 5, 5)),
            "3dimages": np.random.randint(255, size=(10, 5, 5, 3)),
            "video": np.random.randint(255, size=(10, 5, 5, 3, 10)),
        }

    with wandb.init(settings=test_settings) as run:
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=True,
        )

        predictions = vd.make_predictions(fake_model)
        t = vd.log_predictions(predictions)

        expected_cols = {
            "val_row",
            "simple",
            "wrapped",
            "logits",
            "nodes",
            "2dimages",
            "3dimages",
            "video",
            "wrapped:val",
            "logits:node",
            "logits:argmax",
            "logits:argmin",
            "nodes:node",
            "nodes:argmax",
            "nodes:argmin",
        }
        media_cols = (
            ("2dimages:image", wandb.data_types.Image),
            ("3dimages:image", wandb.data_types.Image),
            ("video:video", wandb.data_types.Video),
        )
        if CAN_INFER_IMAGE_AND_VIDEO:
            expected_cols.update(name for name, _ in media_cols)

        tcols = t.columns
        row = t.data[0]

        def cell(name):
            return row[tcols.index(name)]

        assert set(tcols) == expected_cols
        assert isinstance(cell("val_row"), wandb.data_types._TableIndex)
        assert isinstance(cell("simple").tolist(), int)
        for name, length in (("wrapped", 1), ("logits", 5), ("nodes", 10)):
            assert len(cell(name)) == length
        for name, shape in (
            ("2dimages", (5, 5)),
            ("3dimages", (5, 5, 3)),
            ("video", (5, 5, 3, 10)),
        ):
            assert cell(name).shape == shape
        assert isinstance(cell("wrapped:val").tolist(), int)
        for prefix in ("logits", "nodes"):
            # The ":node" columns are checked as lists (an earlier, disabled
            # assertion in this test expected a dict instead).
            assert isinstance(cell(prefix + ":node"), list)
            assert isinstance(cell(prefix + ":argmax").tolist(), int)
            assert isinstance(cell(prefix + ":argmin").tolist(), int)

        if CAN_INFER_IMAGE_AND_VIDEO:
            for name, media_type in media_cols:
                assert isinstance(cell(name), media_type)
0
def test_data_logger_val_inferred_proc(test_settings, live_mock_server):
    """Check inferred validation processors when class labels are supplied.

    Targets of several ranks are logged with ``infer_missing_processors=True``
    and five class labels; the validation table is expected to contain the
    raw columns plus derived node/argmax/argmin, class and score columns, and
    image/video columns when ``CAN_INFER_IMAGE_AND_VIDEO`` is set.
    """
    with wandb.init(settings=test_settings) as run:
        np.random.seed(42)
        target_columns = {
            "simple": np.random.randint(5, size=(10)),
            "wrapped": np.random.randint(5, size=(10, 1)),
            "logits": np.random.randint(5, size=(10, 5)),
            "nodes": np.random.randint(5, size=(10, 10)),
            "2dimages": np.random.randint(255, size=(10, 5, 5)),
            "3dimages": np.random.randint(255, size=(10, 5, 5, 3)),
            "video": np.random.randint(255, size=(10, 5, 5, 3, 10)),
        }
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=target_columns,
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=["a", "b", "c", "d", "e"],
            infer_missing_processors=True,
        )

        expected_cols = {
            "input",
            "simple",
            "wrapped",
            "logits",
            "nodes",
            "2dimages",
            "3dimages",
            "video",
            "input:node",
            "input:argmax",
            "input:argmin",
            "wrapped:class",
            "logits:max_class",
            "logits:score",
            "nodes:node",
            "nodes:argmax",
            "nodes:argmin",
        }
        media_cols = (
            ("2dimages:image", wandb.data_types.Image),
            ("3dimages:image", wandb.data_types.Image),
            ("video:video", wandb.data_types.Video),
        )
        if CAN_INFER_IMAGE_AND_VIDEO:
            expected_cols.update(name for name, _ in media_cols)

        table = vd.validation_indexes[0]._table
        tcols = table.columns
        row = table.data[0]

        def cell(name):
            return row[tcols.index(name)]

        def assert_node_triplet(prefix):
            # The ":node" columns are checked as lists (an earlier, disabled
            # assertion in this test expected a dict instead).
            assert isinstance(cell(prefix + ":node"), list)
            assert isinstance(cell(prefix + ":argmax").tolist(), int)
            assert isinstance(cell(prefix + ":argmin").tolist(), int)

        assert set(tcols) == expected_cols
        assert np.all(cell("input") == [0, 0, 0])
        assert isinstance(cell("simple").tolist(), int)
        for name, length in (("wrapped", 1), ("logits", 5), ("nodes", 10)):
            assert len(cell(name)) == length
        for name, shape in (
            ("2dimages", (5, 5)),
            ("3dimages", (5, 5, 3)),
            ("video", (5, 5, 3, 10)),
        ):
            assert cell(name).shape == shape
        assert_node_triplet("input")
        for name in ("wrapped:class", "logits:max_class"):
            assert isinstance(cell(name), wandb.data_types._TableIndex)
        assert isinstance(cell("logits:score"), dict)
        assert_node_triplet("nodes")

        if CAN_INFER_IMAGE_AND_VIDEO:
            for name, media_type in media_cols:
                assert isinstance(cell(name), media_type)