Beispiel #1
0
    def test_non_user_provided_inputs_never_shape_tensors(self):
        """Metadata assigned without user-provided metadata yields an ordinary tensor.

        The loader is built with no user metadata, so the int32 input "X" can never
        be interpreted as a shape tensor; it must simply take the (3,) shape from the
        metadata assigned afterwards.
        """
        metadata = TensorMetadata().add("X", dtype=np.int32, shape=(3, ))
        loader = DataLoader()
        loader.input_metadata = metadata

        generated = loader[0]["X"]
        # Plain tensor: the shape comes straight from the assigned metadata.
        assert generated.shape == (3, )
Beispiel #2
0
    def test_shape_tensor_detected(self):
        """An int32 input whose user-provided "shape" carries values is fed as a shape tensor.

        The user metadata overrides X's shape with INPUT_DATA. After the (3,) int32
        metadata is assigned to the loader, X must be treated as a shape tensor, so
        the generated values are exactly INPUT_DATA.
        """
        INPUT_DATA = (1, 2, 3)
        assigned_meta = TensorMetadata().add("X", dtype=np.int32, shape=(3, ))
        # The user-supplied "shape" here actually holds the shape-tensor values.
        user_meta = TensorMetadata().add("X", dtype=np.int32, shape=INPUT_DATA)
        loader = DataLoader(input_metadata=user_meta)
        loader.input_metadata = assigned_meta

        values = loader[0]["X"]
        assert np.all(values == INPUT_DATA)  # X's values are INPUT_DATA
Beispiel #3
0
    def test_can_override_shape(self):
        """A user-specified input shape wins over the model's own input metadata.

        The loader is constructed with user metadata giving "X" the shape
        (1, 1, 4, 5) and no dtype. The "dynamic_identity" model's metadata is
        assigned afterwards. The generated "X" must still have the user's shape.
        """
        model = ONNX_MODELS["dynamic_identity"]

        expected_shape = (1, 1, 4, 5)
        user_meta = TensorMetadata().add("X", dtype=None, shape=expected_shape)
        loader = DataLoader(input_metadata=user_meta)
        # Mirror the comparator: hand the model's metadata to the loader after construction.
        loader.input_metadata = model.input_metadata

        generated_shape = tuple(loader[0]["X"].shape)
        assert generated_shape == expected_shape
Beispiel #4
0
    def test_no_shape_tensor_false_positive_float(self):
        """A float32 input is never treated as a shape tensor, even with an overridden shape.

        Float inputs cannot be shape tensors, so the user-provided "shape"
        INPUT_DATA must NOT be fed as X's values. Instead, X must be generated as an
        ordinary tensor with the (3,) shape from the assigned metadata.
        """
        INPUT_DATA = (-100, -50, 0)
        # Float cannot be a shape tensor
        input_meta = TensorMetadata().add("X", dtype=np.float32, shape=(3, ))
        overridden_meta = TensorMetadata().add("X", dtype=np.float32, shape=INPUT_DATA)
        data_loader = DataLoader(input_metadata=overridden_meta)
        data_loader.input_metadata = input_meta

        feed_dict = data_loader[0]
        assert feed_dict["X"].shape == (3, )  # Shape IS (3, ): X is a normal tensor
        assert np.any(feed_dict["X"] != INPUT_DATA)  # Values are NOT INPUT_DATA
Beispiel #5
0
    def test_no_shape_tensor_false_positive_negative_dims(self):
        """An int32 input whose overriding "shape" has a negative entry is not a shape tensor.

        INPUT_DATA contains -100, so it must not be mistaken for shape-tensor values.
        X should be an ordinary (3,) tensor whose contents differ from INPUT_DATA.
        """
        INPUT_DATA = (-100, 2, 4)
        assigned_meta = TensorMetadata().add("X", dtype=np.int32, shape=(3, ))
        user_meta = TensorMetadata().add("X", dtype=np.int32, shape=INPUT_DATA)
        loader = DataLoader(input_metadata=user_meta)
        loader.input_metadata = assigned_meta

        x = loader[0]["X"]
        # Not a shape tensor, so X keeps the (3,) shape from the assigned metadata...
        assert x.shape == (3, )
        # ...and its contents are not INPUT_DATA, since it is not treated as shape values.
        assert np.any(x != INPUT_DATA)