def test_serialize():
    """Round-trips a StridedWindowData through its Protobuf form.

    Every serialized field is compared against the value passed to the
    constructor. Then the message is deserialized and re-serialized, and the
    result must equal the original message.
    """
    # Keep the same draw order as before: input channels first, then output.
    in_shape = (32, 64, np.random.randint(1, 128))
    out_channels = np.random.randint(1, 128)

    window_data = StridedWindowData(
        in_shape,
        (4, 2),  # window (height, width)
        (3, 1),  # strides (height, width)
        (1, 3),  # padding (height, width)
        out_channels,
    )
    message = window_data.serialize()

    expected_fields = {
        "in_height": 32,
        "in_width": 64,
        "in_channels": in_shape[2],
        "window_height": 4,
        "window_width": 2,
        "out_channels": out_channels,
        "stride_height": 3,
        "stride_width": 1,
        "pad_height": 1,
        "pad_width": 3,
    }
    for field_name, expected in expected_fields.items():
        assert getattr(message, field_name) == expected

    round_tripped = StridedWindowData.deserialize(message)
    assert round_tripped.serialize() == message
def deserialize(cls, serialized):
    """Rebuilds the layer from its Protobuf message.

    Args:
        cls: The layer class to instantiate.
        serialized: A message with a ``layer_data`` oneof.

    Returns:
        ``cls(window_data)`` when ``layer_data`` holds ``maxpool_data``.
        ``window_data`` is restored by
        ``StridedWindowData.deserialize(maxpool_data.window_data)``.
        Any other payload returns ``None``.
    """
    if serialized.WhichOneof("layer_data") != "maxpool_data":
        return None
    payload = serialized.maxpool_data
    return cls(StridedWindowData.deserialize(payload.window_data))
def deserialize(cls, serialized):
    """Rebuilds the layer from its Protobuf message.

    Args:
        cls: The layer class to instantiate.
        serialized: A message with a ``layer_data`` oneof.

    Returns:
        ``cls(window_data, filters, biases)`` when ``layer_data`` holds
        ``conv2d_data``. Any other payload returns ``None``.

        - ``window_data`` is restored by ``StridedWindowData.deserialize``.
        - ``filters`` is the flat ``filters`` field reshaped to
          ``window_shape + (input_shape[2], out_channels)``.
        - ``biases`` is ``np.array(biases)``.
    """
    if serialized.WhichOneof("layer_data") != "conv2d_data":
        return None
    conv = serialized.conv2d_data
    window_data = StridedWindowData.deserialize(conv.window_data)
    # Filters arrive as a flat list; restore the window dims followed by
    # the input-channel and output-channel dims.
    kernel_shape = window_data.window_shape + (
        window_data.input_shape[2],
        window_data.out_channels,
    )
    filters = np.array(conv.filters).reshape(kernel_shape)
    biases = np.array(conv.biases)
    return cls(window_data, filters, biases)