def _build(cmd_flags: str, mock_download_and_prepare: bool = True) -> List[str]:
  """Runs `tfds build {cmd_flags}` and collects the names of the built datasets.

  Args:
    cmd_flags: Flags appended to `tfds build`; the full command is split on
      whitespace before parsing.
    mock_download_and_prepare: When True, the real
      `DatasetBuilder.download_and_prepare` is never executed. When False, it
      is still called after the dataset name has been recorded.

  Returns:
    For every builder whose `download_and_prepare` was invoked, its
    `info.full_name` with the trailing version component removed, in call
    order.
  """
  parsed_args = main._parse_flags(f'tfds build {cmd_flags}'.split())
  real_download_and_prepare = tfds.core.DatasetBuilder.download_and_prepare

  # `mock.Mock` drops `self` from `call_args`. A plain function is patched in
  # instead, because it receives the builder instance and can record its name.
  # See:
  # https://stackoverflow.com/questions/64792295/how-to-get-self-instance-in-mock-mock-call-args
  built_names = []

  def _recording_download_and_prepare(self, *args, **kwargs):
    # Strip the version: only the latest version can be generated, so the
    # name without it identifies the dataset.
    name_parts = self.info.full_name.split('/')
    built_names.append('/'.join(name_parts[:-1]))
    if not mock_download_and_prepare:
      return real_download_and_prepare(self, *args, **kwargs)
    return None

  with mock.patch(
      'tensorflow_datasets.core.DatasetBuilder.download_and_prepare',
      _recording_download_and_prepare,
  ):
    main.main(parsed_args)
  return built_names
def test_main():
  """Passing `--version` must make the CLI exit with a zero status."""

  def _assert_clean_exit(status=0, message=None):
    # Only the exit status matters for this test.
    del message
    assert status == 0

  # argparse calls `sys.exit(0)` when `--version` is passed. The stand-in
  # checks the status instead of terminating. Flag parsing happens inside
  # the patch because that is where argparse triggers the exit.
  with mock.patch('sys.exit', _assert_clean_exit):
    version_args = main._parse_flags(['', '--version'])
    main.main(version_args)
def _build(cmd_flags: str) -> mock.Mock:
  """Runs `tfds build {cmd_flags}` with `download_and_prepare` mocked out.

  Args:
    cmd_flags: Flags appended to `tfds build`; the full command is split on
      whitespace before parsing.

  Returns:
    The mock that replaced `DatasetBuilder.download_and_prepare` while the
    command ran, so callers can inspect its calls.
  """
  cli_args = main._parse_flags(f'tfds build {cmd_flags}'.split())
  patch_target = 'tensorflow_datasets.core.DatasetBuilder.download_and_prepare'
  with mock.patch(patch_target) as patched_download_and_prepare:
    main.main(cli_args)
  return patched_download_and_prepare
def main(args: argparse.Namespace) -> None:
  """Applies flag overrides to `args`, then delegates to `main_cli.main`.

  When the module-level `_display_warning` is truthy, a warning recommending
  `tfds build` is logged first.

  Args:
    args: Parsed CLI arguments. `args.imports` and `args.config_idx` may be
      overwritten in place before delegation.
  """
  if _display_warning:
    logging.warning(
        '***`tfds build` should be used instead of `download_and_prepare`.***'
    )

  # NOTE(review): `module_import` and `builder_config_id` look like flag
  # holders exposing `.value`; confirm against their definitions. When set,
  # their values take precedence over what is already in `args`.
  imports_override = module_import.value
  if imports_override:
    args.imports = imports_override

  config_idx_override = builder_config_id.value
  if config_idx_override is not None:
    args.config_idx = config_idx_override

  main_cli.main(args)
def _run_cli(cmd: str) -> None:
  """Splits `cmd` on whitespace, parses it as CLI flags and runs `main.main`."""
  # The leading '' fills the first argv slot (the program name).
  argv = ['', *cmd.split()]
  parsed = main._parse_flags(argv)
  main.main(parsed)
def test_main():
  """Smoke test: `main.main` runs with the `--version` flag."""
  version_flags = main._parse_flags(['', '--version'])
  main.main(version_flags)