def as_recarray(info_blob):
    """
    Collect network output as structured arrays, ready to be written to h5.

    Intended for when network outputs are 2D, i.e. (batchsize, X).
    The blob's "y_pred" (and, if present, "ys") are dicts of arrays with
    shapes (batchsize, x_i), e.g. {"foo": ndarray, "bar": ndarray}; each is
    turned into one recarray with fields (foo_1, foo_2, ..., bar_1, ...)
    via ``misc.dict_to_recarray``.

    Parameters
    ----------
    info_blob : dict-like
        Read via ``.get`` for the keys "y_pred", "ys" and "y_values".

    Returns
    -------
    dict
        Always contains "pred". Contains "true" only when "ys" is not None,
        and "y_values" only when "y_values" is not None.
    """
    result = {"pred": misc.dict_to_recarray(info_blob.get("y_pred"))}

    true_values = info_blob.get("ys")
    if true_values is not None:
        result["true"] = misc.dict_to_recarray(true_values)

    structured_y = info_blob.get("y_values")
    if structured_y is not None:
        # Passed through unchanged: assumed to already be a structured array.
        result["y_values"] = structured_y

    return result
def test_dict_to_recarray_weird_shape(self):
    """An input with extra trailing dims yields one field per trailing element.

    "aa" has shape (5, 3, 2), so it contributes 3 * 2 = 6 fields
    (aa_1 .. aa_6); "bb" has shape (5, 1) and contributes one field.
    The batch dimension (5) becomes the recarray's length.
    """
    inp = {"aa": np.ones((5, 3, 2)), "bb": np.ones((5, 1))}
    output = misc.dict_to_recarray(inp)
    # assertEqual shows both shapes on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(output.shape, (5, ))
    self.assertTupleEqual(
        output.dtype.names,
        ('aa_1', 'aa_2', 'aa_3', 'aa_4', 'aa_5', 'aa_6', 'bb_1'))
def test_dict_to_recarray_dtype(self):
    """Each input array's dtype is kept on every field derived from it."""
    int_block = np.ones((5, 3), dtype="int16")
    float_block = np.ones((5, 1), dtype="float32") * 4
    output = misc.dict_to_recarray({"aa": int_block, "bb": float_block})

    self.assertTupleEqual(
        output.dtype.names, ('aa_1', 'aa_2', 'aa_3', 'bb_1'))

    # (field name, dtype it must carry), in field order.
    expected_dtypes = [
        ("aa_1", np.dtype("int16")),
        ("aa_2", np.dtype("int16")),
        ("aa_3", np.dtype("int16")),
        ("bb_1", np.dtype("float32")),
    ]
    for field, expected in expected_dtypes:
        np.testing.assert_equal(output[field].dtype, expected)

    # Spot-check that values are carried over as well as dtypes.
    np.testing.assert_equal(output["aa_1"], np.ones(5, dtype="int16"))
    np.testing.assert_equal(
        output["bb_1"], np.ones(5, dtype="float32") * 4)
def test_dict_to_recarray_shape(self):
    """The recarray has one row per batch entry (the shared first dim)."""
    inp = {"aa": np.ones((5, 3)), "bb": np.ones((5, 1))}
    output = misc.dict_to_recarray(inp)
    # assertEqual shows both shapes on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(output.shape, (5, ))
def test_dict_to_recarray_wrong_dim(self):
    """Arrays whose first (batch) dims disagree (4 vs 5) raise ValueError."""
    mismatched = {"aa": np.ones((4, 3)), "bb": np.ones((5, 1))}
    self.assertRaises(ValueError, misc.dict_to_recarray, mismatched)