def test_non_user_provided_inputs_never_shape_tensors(self):
    """An input without user-provided metadata is generated as an ordinary tensor.

    With no user-supplied metadata there are no values that could be read as a
    shape, so X must keep the (3, ) shape given by the assigned metadata.
    """
    model_meta = TensorMetadata().add("X", dtype=np.int32, shape=(3,))
    loader = DataLoader()
    loader.input_metadata = model_meta
    generated = loader[0]["X"]
    # Ordinary tensor: its shape is taken directly from the metadata.
    assert generated.shape == (3,)
def test_shape_tensor_detected(self):
    """Verify that a user-provided shape for a 1-D int32 input is used as the input's values.

    Setup:
        - The assigned (model-side) metadata declares X as int32 with shape (3, ).
        - The user metadata supplies INPUT_DATA as the "shape".

    Expected: X is treated as a shape tensor, so its contents are INPUT_DATA and
    its shape is (3, ).
    """
    INPUT_DATA = (1, 2, 3)
    input_meta = TensorMetadata().add("X", dtype=np.int32, shape=(3,))
    # The user metadata carries the shape-tensor values in its `shape` field.
    overridden_meta = TensorMetadata().add("X", dtype=np.int32, shape=INPUT_DATA)
    data_loader = DataLoader(input_metadata=overridden_meta)
    data_loader.input_metadata = input_meta
    feed_dict = data_loader[0]
    # np.array_equal checks the shape as well as the values. A plain
    # `np.all(a == b)` would broadcast, so a differently-shaped tensor could
    # still be compared element-wise against INPUT_DATA.
    assert np.array_equal(feed_dict["X"], INPUT_DATA)  # values become INPUT_DATA
def test_can_override_shape(self):
    """A shape given in user-provided input metadata overrides the model's input shape.

    The user metadata leaves dtype as None and specifies only the shape. The
    generated X must have exactly that shape.
    """
    identity_model = ONNX_MODELS["dynamic_identity"]
    expected_shape = (1, 1, 4, 5)
    user_meta = TensorMetadata().add("X", dtype=None, shape=expected_shape)
    loader = DataLoader(input_metadata=user_meta)
    # Mirror the comparator, which assigns the model's metadata after construction.
    loader.input_metadata = identity_model.input_metadata
    generated = loader[0]["X"]
    assert tuple(generated.shape) == expected_shape
def test_no_shape_tensor_false_positive_float(self):
    """Verify that a float32 input is not treated as a shape tensor.

    Setup:
        - The assigned (model-side) metadata declares X as float32 with shape (3, ).
        - The user metadata supplies INPUT_DATA as the "shape".

    Expected: X keeps shape (3, ) and its contents are not INPUT_DATA.
    """
    INPUT_DATA = (-100, -50, 0)
    # NOTE(review): INPUT_DATA contains negative values. Per
    # test_no_shape_tensor_false_positive_negative_dims, negatives alone already
    # prevent shape-tensor detection, so this test may not isolate the float
    # dtype check. Consider non-negative values here — confirm they do not
    # trigger a shape override first.
    # Float cannot be a shape tensor
    input_meta = TensorMetadata().add("X", dtype=np.float32, shape=(3,))
    overridden_meta = TensorMetadata().add("X", dtype=np.float32, shape=INPUT_DATA)
    data_loader = DataLoader(input_metadata=overridden_meta)
    data_loader.input_metadata = input_meta
    feed_dict = data_loader[0]
    assert feed_dict["X"].shape == (3,)  # Shape IS (3, ), because this is NOT a shape tensor
    assert np.any(feed_dict["X"] != INPUT_DATA)  # Contents are not INPUT_DATA, since it's not treated as a shape value
def test_no_shape_tensor_false_positive_negative_dims(self):
    """A user-provided "shape" containing a negative value is not used as shape-tensor data.

    X is int32 with shape (3, ) in the assigned metadata. Because INPUT_DATA
    contains -100, X must be generated as an ordinary tensor.
    """
    INPUT_DATA = (-100, 2, 4)  # This should NOT be detected as a shape tensor
    model_meta = TensorMetadata().add("X", dtype=np.int32, shape=(3,))
    user_meta = TensorMetadata().add("X", dtype=np.int32, shape=INPUT_DATA)
    loader = DataLoader(input_metadata=user_meta)
    loader.input_metadata = model_meta
    tensor = loader[0]["X"]
    # Not a shape tensor, so the shape comes from the assigned metadata...
    assert tensor.shape == (3,)
    # ...and the contents are generated data rather than INPUT_DATA.
    assert np.any(tensor != INPUT_DATA)