def test_data_logger_pred_user_proc(test_settings, live_mock_server):
    """A user-supplied ``prediction_row_processor`` contributes its own column.

    The model returns the first input feature, which appears in the table as
    the "output" column. The custom row processor derives an "oa" column
    equal to that output plus one. The resulting table must contain exactly
    those columns, with the expected per-row values, and must expose an
    artifact reference entry.
    """
    n_rows = 10

    def first_feature(inputs):
        # Model stand-in: take column 0 of each input row.
        return inputs[:, 0]

    def add_one(ndx, row):
        # Custom prediction processor: derive "oa" from the model output.
        return {"oa": row["output"] + 1}

    with wandb.init(settings=test_settings):
        logger = ValidationDataLogger(
            inputs=np.array([[k, k, k] for k in range(n_rows)]),
            targets=np.array([[k] for k in range(n_rows)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=add_one,
            class_labels=None,
            infer_missing_processors=False,
        )

        predictions = logger.make_predictions(first_feature)
        table = logger.log_predictions(predictions)

        expected_columns = {"val_row", "output", "oa"}
        assert set(table.columns) == expected_columns

        # Each row is [validation index, model output, output + 1].
        row_matches = [table.data[k] == [k, k, k + 1] for k in range(n_rows)]
        assert np.all(row_matches)

        assert table._get_artifact_reference_entry() is not None
def test_data_logger_pred_inferred_proc_no_classes(test_settings, live_mock_server):
    """Prediction processors are inferred from output shapes when no labels exist.

    With ``infer_missing_processors=True`` and ``class_labels=None``, each named
    model output is expected to gain derived columns: ``:val`` for the
    single-element "wrapped" output, and ``:node``/``:argmax``/``:argmin`` for
    the vector outputs. When ``CAN_INFER_IMAGE_AND_VIDEO`` is true, the
    image-like and video-like outputs also gain ``:image``/``:video`` columns.
    """
    with wandb.init(settings=test_settings):
        vd = ValidationDataLogger(
            inputs=np.array([[i, i, i] for i in range(10)]),
            targets=np.array([[i] for i in range(10)]),
            indexes=None,
            validation_row_processor=None,
            prediction_row_processor=None,
            class_labels=None,
            infer_missing_processors=True,
        )

        def predict(inputs):
            # Ignores ``inputs``. Returns one random array per output kind,
            # each with a leading dimension of 10 to match the 10 validation
            # rows above.
            return {
                "simple": np.random.randint(5, size=10),
                "wrapped": np.random.randint(5, size=(10, 1)),
                "logits": np.random.randint(5, size=(10, 5)),
                "nodes": np.random.randint(5, size=(10, 10)),
                "2dimages": np.random.randint(255, size=(10, 5, 5)),
                "3dimages": np.random.randint(255, size=(10, 5, 5, 3)),
                "video": np.random.randint(255, size=(10, 5, 5, 3, 10)),
            }

        t = vd.log_predictions(vd.make_predictions(predict))

        cols = [
            "val_row",
            "simple",
            "wrapped",
            "logits",
            "nodes",
            "2dimages",
            "3dimages",
            "video",
            "wrapped:val",
            "logits:node",
            "logits:argmax",
            "logits:argmin",
            "nodes:node",
            "nodes:argmax",
            "nodes:argmin",
        ]
        if CAN_INFER_IMAGE_AND_VIDEO:
            cols.append("2dimages:image")
            cols.append("3dimages:image")
            cols.append("video:video")

        tcols = t.columns
        row = t.data[0]

        def cell(name):
            # Value of column ``name`` in the first logged row.
            return row[tcols.index(name)]

        assert set(tcols) == set(cols)

        # Raw model outputs keep their per-row shapes.
        assert isinstance(cell("val_row"), wandb.data_types._TableIndex)
        assert isinstance(cell("simple").tolist(), int)
        assert len(cell("wrapped")) == 1
        assert len(cell("logits")) == 5
        assert len(cell("nodes")) == 10
        assert cell("2dimages").shape == (5, 5)
        assert cell("3dimages").shape == (5, 5, 3)
        assert cell("video").shape == (5, 5, 3, 10)

        # Inferred columns. The ``:node`` columns are expected as plain lists.
        assert isinstance(cell("wrapped:val").tolist(), int)
        assert isinstance(cell("logits:node"), list)
        assert isinstance(cell("logits:argmax").tolist(), int)
        assert isinstance(cell("logits:argmin").tolist(), int)
        assert isinstance(cell("nodes:node"), list)
        assert isinstance(cell("nodes:argmax").tolist(), int)
        assert isinstance(cell("nodes:argmin").tolist(), int)

        if CAN_INFER_IMAGE_AND_VIDEO:
            assert isinstance(cell("2dimages:image"), wandb.data_types.Image)
            assert isinstance(cell("3dimages:image"), wandb.data_types.Image)
            assert isinstance(cell("video:video"), wandb.data_types.Video)