def test_assert_when_redis_timeout_is_too_short(self):
    """make_redis_params must raise AssertionError when redis_timeout is too short (2 here)."""
    too_short_kwargs = dict(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host='0.0.0.0',
        redis_port=12345,
        redis_timeout=2,
    )
    with self.assertRaises(AssertionError):
        make_redis_params(**too_short_kwargs)
def test_lock_is_removed_after_func_is_finished(self):
    """After a successful call through with_lock, the redis lock key must no longer exist."""
    redis_params = make_redis_params(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host='0.0.0.0',
        redis_port=12345,
    )
    server = fakeredis.FakeServer()
    with patch('gokart.redis_lock.redis.Redis') as redis_mock:
        # Any redis.Redis(...) constructed via gokart.redis_lock now yields a fake client
        # bound to `server`, so the lock state can be inspected afterwards.
        redis_mock.return_value = fakeredis.FakeRedis(
            server=server, host=redis_params.redis_host, port=redis_params.redis_port)
        wrapped = MagicMock()
        locked_call = with_lock(func=wrapped, redis_params=redis_params)
        outcome = locked_call(123, b='abc')

        wrapped.assert_called_once()
        recorded = wrapped.call_args
        self.assertTupleEqual(recorded[0], (123, ))
        self.assertDictEqual(recorded[1], dict(b='abc'))
        self.assertEqual(outcome, wrapped())

        # Read the shared fake server directly: the lock key should be gone.
        inspector = fakeredis.FakeStrictRedis(server=server)
        with self.assertRaises(KeyError):
            inspector[redis_params.redis_key]
def make_model_target(self,
                      relative_file_path: str,
                      save_function: Callable[[Any, str], None],
                      load_function: Callable[[str], Any],
                      use_unique_id: bool = True):
    """
    Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on.

    :param relative_file_path: A file path to save, relative to ``self.workspace_directory``. It must end with ``.zip``.
    :param save_function: A function to save a model. This takes a model object and a file path.
    :param load_function: A function to load a model. This takes a file path and returns a model object.
    :param use_unique_id: If this is true, add a unique id to a file base name.
    :raises AssertionError: If ``relative_file_path`` does not have a ``.zip`` extension.
    """
    file_path = os.path.join(self.workspace_directory, relative_file_path)
    # Compare the full extension including the dot: a bare suffix check such as
    # ``path[-3:] == 'zip'`` would also accept ``model.gzip`` or ``modelzip``.
    # An explicit raise (instead of ``assert``) keeps the check active under ``python -O``
    # while preserving the AssertionError type that callers may already expect.
    if not relative_file_path.endswith('.zip'):
        raise AssertionError(f'extension must be zip, but {relative_file_path} is passed.')

    unique_id = self.make_unique_id() if use_unique_id else None
    redis_params = make_redis_params(file_path=file_path,
                                     unique_id=unique_id,
                                     redis_host=self.redis_host,
                                     redis_port=self.redis_port,
                                     redis_timeout=self.redis_timeout,
                                     redis_fail_on_collision=self.redis_fail_on_collision)
    return gokart.target.make_model_target(file_path=file_path,
                                           temporary_directory=self.local_temporary_directory,
                                           unique_id=unique_id,
                                           save_function=save_function,
                                           load_function=load_function,
                                           redis_params=redis_params)
def make_large_data_frame_target(self, relative_file_path: str = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
    """Build a zip-backed model target that saves/loads via ``gokart.target.LargeDataFrameProcessor``.

    :param relative_file_path: Path relative to ``self.workspace_directory``. When omitted,
        ``<module path with '/' separators>/<ClassName>.zip`` is used.
    :param use_unique_id: If true, ``self.make_unique_id()`` is passed on as the unique id.
    :param max_byte: Forwarded as ``max_byte`` to ``LargeDataFrameProcessor``.
    """
    if relative_file_path is None:
        module_dir = self.__module__.replace(".", "/")
        relative_file_path = os.path.join(module_dir, f"{type(self).__name__}.zip")
    file_path = os.path.join(self.workspace_directory, relative_file_path)

    unique_id = None
    if use_unique_id:
        unique_id = self.make_unique_id()

    redis_params = make_redis_params(file_path=file_path,
                                     unique_id=unique_id,
                                     redis_host=self.redis_host,
                                     redis_port=self.redis_port,
                                     redis_timeout=self.redis_timeout,
                                     redis_fail_on_collision=self.redis_fail_on_collision)
    saver = gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save
    return gokart.target.make_model_target(file_path=file_path,
                                           temporary_directory=self.local_temporary_directory,
                                           unique_id=unique_id,
                                           save_function=saver,
                                           load_function=gokart.target.LargeDataFrameProcessor.load,
                                           redis_params=redis_params)
def make_target(file_path: str,
                unique_id: Optional[str] = None,
                processor: Optional[FileProcessor] = None,
                redis_params: RedisParams = None,
                store_index_in_feather: bool = True) -> TargetOnKart:
    """Create a ``SingleFileTarget`` for ``file_path``.

    Redis params, when not supplied, are derived from the path *before* the unique id
    is applied. When ``processor`` is falsy, one is chosen by ``make_file_processor``
    from the final (unique-id-applied) path.
    """
    if redis_params is None:
        # Must use the original path here, ahead of the _make_file_path rewrite below.
        redis_params = make_redis_params(file_path=file_path, unique_id=unique_id)

    resolved_path = _make_file_path(file_path, unique_id)
    if not processor:
        processor = make_file_processor(resolved_path, store_index_in_feather=store_index_in_feather)

    inner_target = _make_file_system_target(resolved_path,
                                            processor=processor,
                                            store_index_in_feather=store_index_in_feather)
    return SingleFileTarget(target=inner_target, processor=processor, redis_params=redis_params)
def make_target(self, relative_file_path: str = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart:
    """Build a single-file target under ``self.workspace_directory``.

    :param relative_file_path: Path relative to the workspace. When omitted,
        ``<module path with '/' separators>/<ClassName>.pkl`` is used.
    :param use_unique_id: If true, ``self.make_unique_id()`` is passed on as the unique id.
    :param processor: Optional file processor forwarded to ``gokart.target.make_target``.
    """
    if relative_file_path is None:
        module_dir = self.__module__.replace(".", "/")
        relative_file_path = os.path.join(module_dir, f"{type(self).__name__}.pkl")
    file_path = os.path.join(self.workspace_directory, relative_file_path)

    unique_id = None
    if use_unique_id:
        unique_id = self.make_unique_id()

    redis_params = make_redis_params(file_path=file_path,
                                     unique_id=unique_id,
                                     redis_host=self.redis_host,
                                     redis_port=self.redis_port,
                                     redis_timeout=self.redis_timeout,
                                     redis_fail_on_collision=self.redis_fail_on_collision)
    return gokart.target.make_target(file_path=file_path,
                                     unique_id=unique_id,
                                     processor=processor,
                                     redis_params=redis_params)
def test_make_redis_params_with_no_host(self):
    """With redis_host=None, make_redis_params must produce should_redis_lock=False."""
    shared = dict(redis_host=None, redis_port='12345', redis_timeout=180, redis_fail_on_collision=False)
    actual = make_redis_params(file_path='gs://aaa.pkl', unique_id='123', **shared)
    expected = RedisParams(redis_key='aaa_123', should_redis_lock=False, **shared)
    self.assertEqual(actual, expected)
def make_model_target(file_path: str,
                      temporary_directory: str,
                      save_function,
                      load_function,
                      unique_id: Optional[str] = None,
                      redis_params: RedisParams = None) -> TargetOnKart:
    """Create a ``ModelTarget`` for ``file_path``.

    Redis params, when not supplied, are derived from the path before the unique id is
    applied. The model's working directory is a subdirectory of ``temporary_directory``
    named by the MD5 hex digest of the final (unique-id-applied) file path.
    """
    if redis_params is None:
        redis_params = make_redis_params(file_path=file_path, unique_id=unique_id)

    final_path = _make_file_path(file_path, unique_id)
    # The digest is used here only to derive a stable directory name from the path.
    path_digest = hashlib.md5(final_path.encode()).hexdigest()
    work_dir = os.path.join(temporary_directory, path_digest)

    return ModelTarget(file_path=final_path,
                       temporary_directory=work_dir,
                       save_function=save_function,
                       load_function=load_function,
                       redis_params=redis_params)
def test_make_redis_params_with_valid_host(self):
    """With a redis host set, make_redis_params must enable locking with lock_extend_seconds=10."""
    shared = dict(redis_host='0.0.0.0', redis_port='12345', redis_timeout=180, redis_fail_on_collision=False)
    actual = make_redis_params(file_path='gs://aaa.pkl', unique_id='123', **shared)
    expected = RedisParams(redis_key='aaa_123', should_redis_lock=True, lock_extend_seconds=10, **shared)
    self.assertEqual(actual, expected)
def test_check_lock_extended(self):
    """with_lock must return the wrapped result with redis_timeout=2 and lock_extend_seconds=1.

    NOTE(review): _sample_long_func presumably runs longer than redis_timeout — confirm.
    """
    params = make_redis_params(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host='0.0.0.0',
        redis_port=12345,
        redis_timeout=2,
        lock_extend_seconds=1,
    )
    with patch('gokart.redis_lock.redis.Redis') as redis_mock:
        # Each redis.Redis(...) call builds a fresh FakeRedis with the same arguments.
        redis_mock.side_effect = fakeredis.FakeRedis
        guarded = with_lock(func=self._sample_long_func, redis_params=params)
        self.assertEqual(guarded(123, b='abc'), dict(a=123, b='abc'))
def test_no_redis(self):
    """Without redis host/port, with_lock must call the function once with the given args and return its result."""
    params = make_redis_params(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host=None,
        redis_port=None,
    )
    target = MagicMock()
    outcome = with_lock(func=target, redis_params=params)(123, b='abc')

    target.assert_called_once()
    recorded = target.call_args
    self.assertTupleEqual(recorded[0], (123, ))
    self.assertDictEqual(recorded[1], dict(b='abc'))
    self.assertEqual(outcome, target())
def test_use_redis(self):
    """With a (faked) redis server, with_lock must call the function once with the given args and return its result."""
    params = make_redis_params(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host='0.0.0.0',
        redis_port=12345,
    )
    with patch('gokart.redis_lock.redis.Redis') as redis_mock:
        redis_mock.side_effect = fakeredis.FakeRedis
        target = MagicMock()
        outcome = with_lock(func=target, redis_params=params)(123, b='abc')

        target.assert_called_once()
        recorded = target.call_args
        self.assertTupleEqual(recorded[0], (123, ))
        self.assertDictEqual(recorded[1], dict(b='abc'))
        self.assertEqual(outcome, target())
def test_lock_is_removed_after_func_is_finished_with_error(self):
    """The redis lock key must be gone after the wrapped function raises through with_lock."""
    redis_params = make_redis_params(
        file_path='test_dir/test_file.pkl',
        unique_id='123abc',
        redis_host='0.0.0.0',
        redis_port=12345,
    )
    server = fakeredis.FakeServer()
    with patch('gokart.redis_lock.redis.Redis') as redis_mock:
        redis_mock.return_value = fakeredis.FakeRedis(
            server=server, host=redis_params.redis_host, port=redis_params.redis_port)

        # assertRaises (rather than try/except) makes the test fail if the error is
        # swallowed, instead of silently skipping the lock check below.
        with self.assertRaises(Exception):
            with_lock(func=self._sample_func_with_error, redis_params=redis_params)(123, b='abc')

        fake_redis = fakeredis.FakeStrictRedis(server=server)
        with self.assertRaises(KeyError):
            fake_redis[redis_params.redis_key]