def test_save_numpy_array(tmpdir):
    """Check ``utils.save_numpy_array`` for 2D and 3D inputs and a bad path.

    For each successful save, the returned paths must all exist on disk and
    live under the ``a/b`` subdirectory of ``tmpdir``. The expected file
    count is 1 for a 2D image without a channel axis and ``c`` otherwise.
    An output directory that cannot be written to must produce an empty
    result instead of raising.
    """
    tmpdir = str(tmpdir)
    h, w = 30, 30
    c = np.random.randint(low=1, high=4)
    z = np.random.randint(low=1, high=6)
    expected_prefix = os.path.join(tmpdir, 'a', 'b')

    def _check_files(files, expected_count):
        # Shared assertions for every successful save below.
        assert len(files) == expected_count
        for f in files:
            assert os.path.isfile(f)
            assert f.startswith(expected_prefix)

    # 2D images without channel axis
    img = np.squeeze(_get_image(h, w, 1))
    _check_files(utils.save_numpy_array(img, 'name', '/a/b/', tmpdir), 1)

    # 2D images
    img = _get_image(h, w, c)
    _check_files(utils.save_numpy_array(img, 'name', '/a/b/', tmpdir), c)

    # 3D images
    imgs = np.vstack([_get_image(h, w, c)[None, ...] for _ in range(z)])
    _check_files(utils.save_numpy_array(imgs, 'name', '/a/b/', tmpdir), c)

    # Bad path will not fail, but will log error.
    # A fixed absolute path such as '/does/not/exist/' can actually be
    # created when tests run with elevated permissions (e.g. as root), which
    # would make this check environment-dependent and write outside tmpdir.
    # A regular file can never contain subdirectories or files, so using one
    # as the output directory fails for every user.
    not_a_dir = os.path.join(tmpdir, 'not_a_directory')
    with open(not_a_dir, 'w'):
        pass
    img = _get_image(h, w, c)
    files = utils.save_numpy_array(img, 'name', '/a/b/', not_a_dir)
    assert not files
def test_grpc_response_to_dict():
    """Round-trip tensors through a ``PredictResponse`` and back to a dict.

    Both an array payload and a scalar payload must come back unchanged
    under the 'prediction' key. An unrecognized ``dtype`` on the output
    tensor must raise ``KeyError``.
    """
    # pylint: disable=E1101

    def _response_for(payload):
        # Wrap ``payload`` as the 'prediction' output of a fresh response.
        proto = grpc_clients.make_tensor_proto(payload, 'DT_FLOAT')
        resp = PredictResponse()
        resp.outputs['prediction'].CopyFrom(proto)
        return resp

    # test valid array response, then scalar input
    for payload in (_get_image(300, 300, 1), 3):
        converted = grpc_clients.grpc_response_to_dict(_response_for(payload))
        assert isinstance(converted, dict)
        np.testing.assert_allclose(converted['prediction'], payload)

    # test bad dtype
    # logs an error, but should throw a KeyError as well.
    bad_response = _response_for(_get_image(300, 300, 1))
    bad_response.outputs['prediction'].dtype = 32

    with pytest.raises(KeyError):
        grpc_clients.grpc_response_to_dict(bad_response)
    def test_get_image_label(self, mocker, redis_client):
        """Exercise ``SegmentationConsumer.get_image_label``.

        Covers four cases:
        - no label given, so the mocked ``detect_label`` result is used;
        - a valid label given, which must be returned as-is;
        - an unsupported numeric label, which must raise ``ValueError``;
        - a non-numeric label, which must raise ``ValueError``.
        """
        queue = 'q'
        stg = DummyStorage()
        consumer = consumers.SegmentationConsumer(redis_client, stg, queue)
        image = _get_image(256, 256, 1)

        # Keep the mocked detection result in its own name and never
        # reassign it. It must also differ from the provided label below.
        # Otherwise the lambda (which looks the name up at call time) would
        # return the provided label too. The "label provided" assertion
        # could then pass even if detect_label were wrongly called.
        detected = 1
        mocker.patch.object(consumer, 'detect_label', lambda *x: detected)

        # test no label provided
        label = consumer.get_image_label(None, image, 'some hash')
        assert label == detected

        # test label provided
        expected = 2
        label = consumer.get_image_label(expected, image, 'some hash')
        assert label == expected

        # Only the call under test sits inside ``pytest.raises``.

        # test label provided is invalid
        with pytest.raises(ValueError):
            consumer.get_image_label(-1, image, 'some hash')

        # test label provided is bad type
        with pytest.raises(ValueError):
            consumer.get_image_label('badval', image, 'some hash')
def _write_image(filepath, img_w=300, img_h=300):
    """Write an ``img_h`` x ``img_w`` single-channel test image to ``filepath``.

    The pixel data comes from ``_get_image``. The extension check is
    case-insensitive:
    - ``.tif``/``.tiff`` paths are written with ``tifffile`` after dropping
      the trailing channel axis;
    - any other path is converted with ``array_to_img`` and saved through
      the resulting image object.
    """
    pixels = _get_image(img_h, img_w, 1)
    extension = os.path.splitext(filepath.lower())[-1]
    if extension not in ('.tif', '.tiff'):
        pil_image = array_to_img(pixels, scale=False,
                                 data_format='channels_last')
        pil_image.save(filepath)
        return
    tifffile.imsave(filepath, pixels[..., 0])
def test_make_tensor_proto():
    """``make_tensor_proto`` must build a ``TensorProto`` from arrays and values.

    The dtype is passed once by name (``'DT_FLOAT'``) and once as the
    ``types_pb2`` enum constant.
    """
    cases = (
        (_get_image(300, 300, 1), 'DT_FLOAT'),  # numpy array
        (10.0, types_pb2.DT_FLOAT),  # plain value
    )
    for payload, dtype in cases:
        proto = grpc_clients.make_tensor_proto(payload, dtype)
        assert isinstance(proto, TensorProto)
    def test__consume(self, mocker, redis_client):
        """Exercise ``CalibanConsumer._consume`` with its collaborators mocked.

        The gRPC app, ``get_model_wrapper`` and ``_load_data`` are patched,
        so no model server or storage access is needed. Keys already in a
        finished status (failed or final) are expected to come back with
        that same status. A key with status ``'new'`` is expected to be
        processed to the final status, and that status must also be
        readable back from Redis.
        """
        queue = 'track'
        storage = DummyStorage()
        # Incremented before each case so every case uses a fresh Redis key.
        test_hash = 0

        # Canned result returned by both the mocked ``predict`` and ``track``.
        dummy_results = {
            'y_tracked': np.zeros((32, 32, 1)),
            'tracks': []
        }

        # NOTE(review): minimal stand-ins exposing only the attributes set
        # here; assumes _consume reads nothing else from them — confirm.
        mock_model = Bunch(
            get_batch_size=lambda *x: 1,
            input_shape=(1, 32, 32, 1)
        )
        mock_app = Bunch(
            predict=lambda *x, **y: dummy_results,
            track=lambda *x, **y: dummy_results,
            model_mpp=1,
            model=mock_model,
        )

        consumer = consumers.CalibanConsumer(redis_client, storage, queue)

        mocker.patch.object(consumer, 'get_grpc_app',
                            lambda *x, **y: mock_app)
        # mock get_model_wrapper for neighborhood encoder
        mocker.patch.object(consumer, 'get_model_wrapper',
                            lambda *x, **y: mock_model)

        # Loaded data: X stacks one _get_image(21, 21) per frame;
        # y holds random integer labels in [0, 9) shaped (frames, 21, 21).
        frames = 3
        dummy_data = {
            'X': np.array([_get_image(21, 21) for _ in range(frames)]),
            'y': np.random.randint(0, 9, size=(frames, 21, 21)),
        }

        mocker.patch.object(consumer, '_load_data', lambda *x: dummy_data)

        # test finished statuses are returned
        for status in (consumer.failed_status, consumer.final_status):
            test_hash += 1
            data = {'input_file_name': 'file.tiff', 'status': status}
            redis_client.hmset(test_hash, data)
            result = consumer._consume(test_hash)
            assert result == status

        # test new key is processed
        test_hash += 1
        data = {'input_file_name': 'file.tiff', 'status': 'new'}
        redis_client.hmset(test_hash, data)
        result = consumer._consume(test_hash)
        assert result == consumer.final_status
        # The final status must also be persisted on the Redis hash.
        assert redis_client.hget(test_hash, 'status') == consumer.final_status
    def test_detect_label(self, mocker, redis_client):
        """Check ``SegmentationConsumer.detect_label`` against the setting.

        When ``LABEL_DETECT_ENABLED`` is off the result must be 0. When it is
        on, the result must be the mocked app's prediction.
        """
        # pylint: disable=W0613
        _, rows, cols, chans = (1, 256, 256, 1)
        consumer = consumers.SegmentationConsumer(redis_client, None, 'q')

        expected_label = random.randint(1, 9)

        fake_model = Bunch(get_batch_size=lambda *x: 1)
        fake_app = Bunch(predict=lambda *x, **y: expected_label,
                         model=fake_model)
        mocker.patch.object(consumer, 'get_grpc_app', lambda *x: fake_app)

        # Image spans twice the 256x256 base shape in each spatial dimension.
        image = _get_image(rows * 2, cols * 2, chans)

        for enabled, wanted in ((False, 0), (True, expected_label)):
            mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', enabled)
            assert consumer.detect_label(image) == wanted
    def test_detect_scale(self, mocker, redis_client):
        """Check ``MesmerConsumer.detect_scale`` against the setting.

        When ``SCALE_DETECT_ENABLED`` is off the result must be 1. When it is
        on, the result must equal the mocked app's prediction.
        """
        # pylint: disable=W0613
        _, rows, cols, chans = (1, 256, 256, 1)
        consumer = consumers.MesmerConsumer(redis_client, None, 'q')

        image = _get_image(rows * 2, cols * 2, chans)

        # NOTE(review): the mocked scale equals the disabled-path value (1).
        # The enabled and disabled assertions below therefore cannot tell
        # the two code paths apart. Confirm the intended detect_scale
        # contract before switching to a scale other than 1.
        expected_scale = 1

        fake_app = Bunch(
            predict=lambda *x, **y: expected_scale,
            model=Bunch(get_batch_size=lambda *x: 1),
        )
        mocker.patch.object(consumer, 'get_grpc_app', lambda *x: fake_app)

        for enabled, wanted in ((False, 1), (True, expected_scale)):
            mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', enabled)
            assert consumer.detect_scale(image) == wanted
    def test__load_data(self, tmpdir, mocker, redis_client):
        """Exercise ``CalibanConsumer._load_data`` for each kind of input.

        Covers:
        - ``.trk``/``.trks`` files, via a mocked ``load_track_file``;
        - ``ValueError`` for an unsupported extension (``.npz``);
        - ``ValueError`` for a tiff with the wrong number of dimensions;
        - the successful tiff-stack workflow with mocked child jobs, with
          detection both enabled and disabled;
        - ``RuntimeError`` when a child job fails;
        - ``RuntimeError`` when the wrong number of child images is produced.
        """
        queue = 'track'
        storage = DummyStorage()
        consumer = consumers.CalibanConsumer(redis_client, storage, queue)
        tmpdir = str(tmpdir)
        exp = random.randint(0, 99)

        # test load trk files
        key = 'trk file test'
        mocker.patch('redis_consumer.utils.load_track_file', lambda x: exp)
        result = consumer._load_data(key, tmpdir, 'data.trk')
        assert result == exp
        result = consumer._load_data(key, tmpdir, 'data.trks')
        assert result == exp

        # test bad filetype
        key = 'invalid filetype test'
        with pytest.raises(ValueError):
            consumer._load_data(key, tmpdir, 'data.npz')

        # test bad ndim for tiffstack
        fname = 'test.tiff'
        filepath = os.path.join(tmpdir, fname)
        tifffile.imsave(filepath, _get_image())
        with pytest.raises(ValueError):
            consumer._load_data(key, tmpdir, fname)

        # test successful workflow
        def hget_successful_status(*_):
            return consumer.final_status

        def hget_failed_status(*_):
            return consumer.failed_status

        def write_child_tiff(*_, **__):
            # Stand-in for iter_image_archive: writes one randomly named tiff
            # built from _get_image(21, 21) into tmpdir and returns its path.
            letters = string.ascii_lowercase
            name = ''.join(random.choice(letters) for _i in range(12))
            path = os.path.join(tmpdir, '{}.tiff'.format(name))
            tifffile.imsave(path, _get_image(21, 21))
            return [path]

        mocker.patch.object(settings, 'INTERVAL', 0)
        mocker.patch.object(redis_client, 'hget', hget_successful_status)
        mocker.patch('redis_consumer.utils.iter_image_archive',
                     write_child_tiff)

        for enabled in (True, False):
            # Scale and label detection are toggled together.
            mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', enabled)
            mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', enabled)

            tifffile.imsave(filepath, np.random.random((3, 21, 21)))
            results = consumer._load_data(key, tmpdir, fname)
            X, y = results.get('X'), results.get('y')
            assert isinstance(X, np.ndarray)
            assert isinstance(y, np.ndarray)
            assert X.shape == y.shape

        # In both failure cases below, patching happens outside
        # ``pytest.raises``. Only ``_load_data`` itself can then satisfy the
        # expected RuntimeError. An error raised during mock setup fails the
        # test instead of silently passing it.

        # test failed child
        mocker.patch.object(redis_client, 'hget', hget_failed_status)
        with pytest.raises(RuntimeError):
            consumer._load_data(key, tmpdir, fname)

        # test wrong number of images in the test file
        mocker.patch.object(redis_client, 'hget', hget_successful_status)
        mocker.patch('redis_consumer.utils.iter_image_archive',
                     lambda *x: range(1, 3))
        with pytest.raises(RuntimeError):
            consumer._load_data(key, tmpdir, fname)
 # NOTE(review): this looks like a stray copy of the write_child_tiff helper
 # nested inside test__load_data. Its indentation matches no visible scope.
 # It also reads the free name ``tmpdir``, which is only bound locally in
 # that test. Confirm whether it should be removed.
 def write_child_tiff(*_, **__):
     """Write one randomly named tiff from ``_get_image(21, 21)`` and return it.

     Any positional/keyword arguments are accepted and ignored. The file is
     written to ``os.path.join(tmpdir, '<12 random lowercase letters>.tiff')``,
     and a one-element list holding that path is returned.
     """
     letters = string.ascii_lowercase
     # 12-character random lowercase file stem.
     name = ''.join(random.choice(letters) for i in range(12))
     path = os.path.join(tmpdir, '{}.tiff'.format(name))
     tifffile.imsave(path, _get_image(21, 21))
     return [path]