def test__consume(self, mocker, redis_client):
        """Run ``MesmerConsumer._consume`` on a fresh job with stubbed deps.

        The gRPC app, image-scale lookup, model-input validation and
        dimension-order detection are all replaced with stubs; the job must
        end in ``final_status``, both as the return value and as the
        ``status`` field written back to redis.
        """
        # pylint: disable=W0613
        consumer = consumers.MesmerConsumer(
            redis_client, DummyStorage(), 'multiplex')

        fake_shape = (1, 256, 256, 2)
        stub_model = Bunch(get_batch_size=lambda *x: 1,
                           input_shape=(1, 32, 32, 1))
        stub_app = Bunch(
            predict=lambda *x, **y: np.random.randint(1, 5, size=fake_shape),
            model_mpp=1,
            model=stub_model)

        # Every collaborator that would touch a model server or inspect real
        # image data is swapped for a trivial stand-in.
        stubs = {
            'get_grpc_app': lambda *x, **_: stub_app,
            'get_image_scale': lambda *x, **_: 1,
            'validate_model_input': lambda *x, **_: x[0],
            'detect_dimension_order': lambda *x, **_: 'YXC',
        }
        for attr_name, replacement in stubs.items():
            mocker.patch.object(consumer, attr_name, replacement)

        job_hash = 'some hash'
        redis_client.hmset(job_hash, {'input_file_name': 'file.tiff'})

        assert consumer._consume(job_hash) == consumer.final_status
        assert redis_client.hget(job_hash, 'status') == consumer.final_status
    def test_get_image_label(self, mocker, redis_client):
        """Check ``get_image_label`` with missing, valid and invalid labels.

        With no label the stubbed ``detect_label`` value is used; an explicit
        label is returned as given; ``-1`` and the string ``'badval'`` must
        both raise ``ValueError``.
        """
        queue = 'q'
        stg = DummyStorage()
        consumer = consumers.SegmentationConsumer(redis_client, stg, queue)
        image = _get_image(256, 256, 1)

        # test no label provided: the detected label is returned
        detected = 1
        mocker.patch.object(consumer, 'detect_label', lambda *x: detected)
        label = consumer.get_image_label(None, image, 'some hash')
        assert label == detected

        # test label provided: it must win over the detected one.
        # A separate name is used on purpose: the stub above reads
        # ``detected`` lazily, so rebinding that name would make the stub
        # return the same value and hide a re-detection bug.
        provided = 2
        label = consumer.get_image_label(provided, image, 'some hash')
        assert label == provided

        # test label provided is invalid (-1) or a bad type ('badval');
        # only the call under test sits inside pytest.raises
        for bad_label in (-1, 'badval'):
            with pytest.raises(ValueError):
                consumer.get_image_label(bad_label, image, 'some hash')
    def test__consume(self, mocker, redis_client):
        """Exercise ``CalibanConsumer._consume`` with tracking stubbed out.

        Jobs whose stored status is already ``failed_status`` or
        ``final_status`` must be returned with that status; a job marked
        ``'new'`` must be driven to ``final_status`` in redis.
        """
        tracking_output = {
            'y_tracked': np.zeros((32, 32, 1)),
            'tracks': [],
        }
        fake_model = Bunch(get_batch_size=lambda *x: 1,
                           input_shape=(1, 32, 32, 1))
        fake_app = Bunch(
            predict=lambda *x, **y: tracking_output,
            track=lambda *x, **y: tracking_output,
            model_mpp=1,
            model=fake_model,
        )

        consumer = consumers.CalibanConsumer(
            redis_client, DummyStorage(), 'track')

        mocker.patch.object(consumer, 'get_grpc_app',
                            lambda *x, **y: fake_app)
        # get_model_wrapper supplies the neighborhood encoder model
        mocker.patch.object(consumer, 'get_model_wrapper',
                            lambda *x, **y: fake_model)

        n_frames = 3
        movie = {
            'X': np.array([_get_image(21, 21) for _ in range(n_frames)]),
            'y': np.random.randint(0, 9, size=(n_frames, 21, 21)),
        }
        mocker.patch.object(consumer, '_load_data', lambda *x: movie)

        # Jobs already in a terminal state come back unchanged (hashes 1, 2).
        terminal = (consumer.failed_status, consumer.final_status)
        for job_hash, status in enumerate(terminal, start=1):
            redis_client.hmset(
                job_hash, {'input_file_name': 'file.tiff', 'status': status})
            assert consumer._consume(job_hash) == status

        # A brand-new job is processed to completion (hash 3).
        new_hash = len(terminal) + 1
        redis_client.hmset(
            new_hash, {'input_file_name': 'file.tiff', 'status': 'new'})
        assert consumer._consume(new_hash) == consumer.final_status
        assert redis_client.hget(new_hash, 'status') == consumer.final_status
    def test_is_valid_hash(self, mocker, redis_client):
        """Table-driven check of ``CalibanConsumer.is_valid_hash`` verdicts.

        ``hget`` is stubbed to return the last ``:``-separated segment of the
        key, so the filename embedded in each key is what gets inspected.
        """
        consumer = consumers.CalibanConsumer(
            redis_client, DummyStorage(), 'track')

        mocker.patch.object(redis_client, 'hget',
                            lambda x, y: x.split(':')[-1])

        cases = [
            (None, False),
            ('predict:123456789:file.png', False),
            ('predict:1234567890:file.tiff', True),
            ('predict:1234567890:file.png', False),
            ('track:1234567890:file.ZIp', False),
            ('track:123456789:file.zip', False),
            ('track:1234567890:file.png', False),
            ('track:1234567890:file.tiff', True),
            ('track:1234567890:file.trk', True),
            ('track:1234567890:file.trks', True),
        ]
        for candidate, expected in cases:
            assert consumer.is_valid_hash(candidate) is expected
    def test__consume_finished_status(self, redis_client):
        """Jobs already in a terminal state are returned unchanged.

        For both ``failed_status`` and ``final_status``, ``_consume`` must
        return the stored status and the redis ``status`` field must still
        hold that same value afterwards.
        """
        queue = 'q'
        storage = DummyStorage()

        consumer = consumers.MesmerConsumer(redis_client, storage, queue)

        empty_data = {'input_file_name': 'file.tiff'}

        # test finished statuses are returned; each status gets its own
        # hash (1, 2) so the cases cannot interfere with each other
        terminal = (consumer.failed_status, consumer.final_status)
        for test_hash, status in enumerate(terminal, start=1):
            data = dict(empty_data, status=status)
            redis_client.hmset(test_hash, data)
            assert consumer._consume(test_hash) == status
            assert redis_client.hget(test_hash, 'status') == status
    def test__load_data(self, tmpdir, mocker, redis_client):
        """Exercise ``CalibanConsumer._load_data`` across its input cases.

        Covers ``.trk``/``.trks`` files (delegated to the stubbed
        ``utils.load_track_file``), an unsupported extension, a tiff with the
        wrong number of dimensions, the successful tiff-stack workflow with
        scale/label detection both on and off, a failed child job, and a
        stubbed archive iterator yielding an unexpected result.
        """
        queue = 'track'
        storage = DummyStorage()
        consumer = consumers.CalibanConsumer(redis_client, storage, queue)
        tmpdir = str(tmpdir)
        exp = random.randint(0, 99)

        # test load trk files
        key = 'trk file test'
        mocker.patch('redis_consumer.utils.load_track_file', lambda x: exp)
        for trk_name in ('data.trk', 'data.trks'):
            assert consumer._load_data(key, tmpdir, trk_name) == exp

        # test bad filetype
        key = 'invalid filetype test'
        with pytest.raises(ValueError):
            consumer._load_data(key, tmpdir, 'data.npz')

        # test bad ndim for tiffstack
        fname = 'test.tiff'
        filepath = os.path.join(tmpdir, fname)
        tifffile.imsave(filepath, _get_image())
        with pytest.raises(ValueError):
            consumer._load_data(key, tmpdir, fname)

        # test successful workflow
        def hget_successful_status(*_):
            return consumer.final_status

        def hget_failed_status(*_):
            return consumer.failed_status

        def write_child_tiff(*_, **__):
            # Stand-in for utils.iter_image_archive: write one randomly named
            # 21x21 tiff into tmpdir and return a single-item list with its path.
            letters = string.ascii_lowercase
            name = ''.join(random.choice(letters) for _ in range(12))
            path = os.path.join(tmpdir, '{}.tiff'.format(name))
            tifffile.imsave(path, _get_image(21, 21))
            return [path]

        mocker.patch.object(settings, 'INTERVAL', 0)
        mocker.patch.object(redis_client, 'hget', hget_successful_status)
        mocker.patch('redis_consumer.utils.iter_image_archive',
                     write_child_tiff)

        # toggle scale and label detection together
        for detect_enabled in (True, False):
            mocker.patch.object(settings, 'SCALE_DETECT_ENABLED',
                                detect_enabled)
            mocker.patch.object(settings, 'LABEL_DETECT_ENABLED',
                                detect_enabled)

            tifffile.imsave(filepath, np.random.random((3, 21, 21)))
            results = consumer._load_data(key, tmpdir, fname)
            X, y = results.get('X'), results.get('y')
            assert isinstance(X, np.ndarray)
            assert isinstance(y, np.ndarray)
            assert X.shape == y.shape

        # test failed child. Stubs are installed before entering
        # pytest.raises so only the call under test can satisfy it.
        mocker.patch.object(redis_client, 'hget', hget_failed_status)
        with pytest.raises(RuntimeError):
            consumer._load_data(key, tmpdir, fname)

        # test wrong number of images in the test file
        mocker.patch.object(redis_client, 'hget', hget_successful_status)
        mocker.patch('redis_consumer.utils.iter_image_archive',
                     lambda *x: range(1, 3))
        with pytest.raises(RuntimeError):
            consumer._load_data(key, tmpdir, fname)