Esempio n. 1
0
def test_data_logger_pred_user_proc(test_settings, live_mock_server):
    """A user-supplied ``prediction_row_processor`` adds its column to the table.

    The model passed to ``make_predictions`` returns the first input column, so
    for validation row ``i`` the ``output`` cell is ``i``. The processor derives
    ``oa = output + 1``. Every logged row is therefore expected to read
    ``val_row == i``, ``output == i``, ``oa == i + 1``.
    """
    with wandb.init(settings=test_settings):
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=lambda ndx, row: {"oa": row["output"] + 1},
            class_labels=None,
            infer_missing_processors=False,
        )
        t = vd.log_predictions(vd.make_predictions(lambda inputs: inputs[:, 0]))
        cols = ["val_row", "output", "oa"]
        tcols = t.columns

        assert set(tcols) == set(cols)

        # The set comparison above does not pin column order, so look cells up
        # by column name rather than comparing rows positionally.
        val_ndx, out_ndx, oa_ndx = (tcols.index(c) for c in cols)
        for i in range(10):
            row = t.data[i]
            got = [row[val_ndx], row[out_ndx], row[oa_ndx]]
            assert got == [i, i, i + 1], (i, row)
        assert t._get_artifact_reference_entry() is not None
Esempio n. 2
0
def test_data_logger_pred_inferred_proc_no_classes(test_settings,
                                                   live_mock_server):
    """Inferred prediction processors add derived columns when no labels exist.

    The fake model ignores its inputs and returns random predictions of several
    shapes: scalars, wrapped scalars, logits, node vectors, 2D/3D images and
    video. With ``infer_missing_processors=True`` and ``class_labels=None`` the
    logged table must contain each raw key plus the derived ``:val``,
    ``:node``, ``:argmax`` and ``:argmin`` columns. When
    ``CAN_INFER_IMAGE_AND_VIDEO`` is true, it must also contain the
    ``:image`` and ``:video`` columns. Only shapes and types are checked,
    so the random values do not need to be seeded.
    """
    with wandb.init(settings=test_settings):
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=True,
        )

        t = vd.log_predictions(
            vd.make_predictions(
                lambda inputs: {
                    "simple": np.random.randint(5, size=10),
                    "wrapped": np.random.randint(5, size=(10, 1)),
                    "logits": np.random.randint(5, size=(10, 5)),
                    "nodes": np.random.randint(5, size=(10, 10)),
                    "2dimages": np.random.randint(255, size=(10, 5, 5)),
                    "3dimages": np.random.randint(255, size=(10, 5, 5, 3)),
                    "video": np.random.randint(255, size=(10, 5, 5, 3, 10)),
                }))

        cols = [
            "val_row",
            "simple",
            "wrapped",
            "logits",
            "nodes",
            "2dimages",
            "3dimages",
            "video",
            "wrapped:val",
            "logits:node",
            "logits:argmax",
            "logits:argmin",
            "nodes:node",
            "nodes:argmax",
            "nodes:argmin",
        ]
        if CAN_INFER_IMAGE_AND_VIDEO:
            cols.append("2dimages:image")
            cols.append("3dimages:image")
            cols.append("video:video")

        tcols = t.columns
        row = t.data[0]

        def cell(name):
            # Look up a cell of the first logged row by column name.
            return row[tcols.index(name)]

        assert set(tcols) == set(cols)
        assert isinstance(cell("val_row"), wandb.data_types._TableIndex)

        # Raw prediction columns keep their per-row numpy values.
        assert isinstance(cell("simple").tolist(), int)
        assert len(cell("wrapped")) == 1
        assert len(cell("logits")) == 5
        assert len(cell("nodes")) == 10
        assert cell("2dimages").shape == (5, 5)
        assert cell("3dimages").shape == (5, 5, 3)
        assert cell("video").shape == (5, 5, 3, 10)

        # Derived columns: scalars for :val/:argmax/:argmin, plain lists for
        # :node.
        assert isinstance(cell("wrapped:val").tolist(), int)
        assert isinstance(cell("logits:node"), list)
        assert isinstance(cell("logits:argmax").tolist(), int)
        assert isinstance(cell("logits:argmin").tolist(), int)
        assert isinstance(cell("nodes:node"), list)
        assert isinstance(cell("nodes:argmax").tolist(), int)
        assert isinstance(cell("nodes:argmin").tolist(), int)

        if CAN_INFER_IMAGE_AND_VIDEO:
            assert isinstance(cell("2dimages:image"), wandb.data_types.Image)
            assert isinstance(cell("3dimages:image"), wandb.data_types.Image)
            assert isinstance(cell("video:video"), wandb.data_types.Video)