def test_linear_transformation_md5_01():
    """
    Test LinearTransformation op with a valid transformation_matrix / mean_vector pair.

    Runs Decode -> CenterCrop -> ToTensor -> LinearTransformation over the "image"
    column and compares the pipeline output against the stored md5 golden file.
    Expected to pass.
    """
    logger.info("test_linear_transformation_md5_01")

    # Crop size and the length used for both the square matrix and the mean vector.
    crop_h, crop_w = 50, 50
    flat_len = 3 * crop_h * crop_w
    matrix = np.ones([flat_len, flat_len])
    mean = np.zeros(flat_len)

    # Build the dataset and attach the composed python transforms.
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    pipeline = py_vision.ComposeOp([
        py_vision.Decode(),
        py_vision.CenterCrop([crop_h, crop_w]),
        py_vision.ToTensor(),
        py_vision.LinearTransformation(matrix, mean),
    ])
    dataset = dataset.map(input_columns=["image"], operations=pipeline())

    # Check the result against (or regenerate) the golden md5 file.
    save_and_check_md5(dataset, "linear_transformation_01_result.npz", generate_golden=GENERATE_GOLDEN)
def test_linear_transformation_md5_05():
    """
    Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
    Expected to raise ValueError

    The test fails explicitly if no ValueError is raised, so a missing
    validation cannot make it pass silently.
    """
    logger.info("test_linear_transformation_md5_05")

    # Initialize parameters
    height = 50
    weight = 50
    dim = 3 * height * weight
    transformation_matrix = np.ones([dim, dim])
    # One element short on purpose: length differs from the matrix dimension.
    mean_vector = np.zeros(dim - 1)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, weight]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, mean_vector)
        ]
        transform = py_vision.ComposeOp(transforms)
        data1 = data1.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "should match" in str(e)
    else:
        # Reaching here means the mismatched shapes were accepted.
        raise AssertionError(
            "Expected ValueError for mean_vector not matching transformation_matrix, but none was raised")
def test_linear_transformation_exception_02():
    """
    Test LinearTransformation op: mean_vector is not provided
    Expected to raise TypeError

    The test fails explicitly if no TypeError is raised, so a missing
    argument check cannot make it pass silently.
    """
    logger.info("test_linear_transformation_exception_02")

    # Initialize parameters
    height = 50
    weight = 50
    dim = 3 * height * weight
    transformation_matrix = np.ones([dim, dim])

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, weight]),
            py_vision.ToTensor(),
            # None is passed in place of mean_vector to trigger the type check.
            py_vision.LinearTransformation(transformation_matrix, None)
        ]
        transform = py_vision.ComposeOp(transforms)
        data1 = data1.map(input_columns=["image"], operations=transform())
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
    else:
        # Reaching here means None was accepted as mean_vector.
        raise AssertionError("Expected TypeError for mean_vector=None, but none was raised")
def test_linear_transformation_op(plot=False):
    """
    Test LinearTransformation op: verify if images transform correctly

    Applies LinearTransformation with an identity matrix and a zero mean vector
    to one dataset, and compares every resulting image against the same
    pipeline without the transformation.

    Args:
        plot (bool): when True, display the original and transformed images
            via visualize_list after the comparison.
    """
    logger.info("test_linear_transformation_op")

    # Initialize parameters
    height = 50
    weight = 50
    dim = 3 * height * weight
    transformation_matrix = np.eye(dim)
    mean_vector = np.zeros(dim)

    # Define operations
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, weight]),
        py_vision.ToTensor()
    ]
    transform = py_vision.ComposeOp(transforms)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(input_columns=["image"], operations=transform())
    # Note: if the transformation matrix is the identity matrix (all ones on the
    # diagonal), the output is expected to be the same as the input.
    data1 = data1.map(input_columns=["image"],
                      operations=py_vision.LinearTransformation(transformation_matrix, mean_vector))

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=transform())

    image_transformed = []
    image = []
    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        # Move the first axis last and rescale by 255 into uint8 for comparison/plotting.
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_transformed.append(image1)
        image.append(image2)

        # Check each pair so a mismatch in any image fails the test.
        mse = diff_mse(image1, image2)
        assert mse == 0
    if plot:
        visualize_list(image, image_transformed)