Example #1
  def test_extract_archive_tar(self):
    """Tests extracting tar.gz files."""

    # Make a tar file
    archive_path = os.path.join(self._root_folder, 'test.tar')
    cmd = ['tar', '-cf', archive_path, '-C', self._src_folder, self._filename1, self._filename2]
    subprocess.check_call(cmd)

    # Extract the archive
    dest = os.path.join(self._root_folder, 'output')
    _archive.extract_archive(archive_path, dest)

    expected_file1 = os.path.join(dest, self._filename1)
    expected_file2 = os.path.join(dest, self._filename2)
    self.assertTrue(os.path.isfile(expected_file1))
    self.assertTrue(os.path.isfile(expected_file2))

    with open(expected_file2, 'r') as f:
      file_contents = f.read()
    self.assertEqual(file_contents, 'and this is file2')
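
For reference, the extraction this test expects from _archive.extract_archive can be reproduced with Python's standard tarfile module. The snippet below is only a sketch of the assumed behaviour for a tar input, not the library's actual implementation:

import os
import tarfile

def extract_tar(archive_path, dest):
    # Hypothetical stand-in for _archive.extract_archive on tar input (assumption,
    # not the library's code): create dest and unpack the archive into it.
    if not os.path.isdir(dest):
        os.makedirs(dest)
    # 'r:*' lets tarfile auto-detect plain tar vs. gzipped tar.
    with tarfile.open(archive_path, 'r:*') as tar:
        tar.extractall(dest)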
Example #2
def _train(args, cell):
    env = google.datalab.utils.commands.notebook_environment()
    cell_data = google.datalab.utils.commands.parse_config(cell, env)
    required_keys = ['training_data', 'evaluation_data']
    if args['cloud']:
        required_keys.append('cloud')

    google.datalab.utils.commands.validate_config(cell_data,
                                                  required_keys=required_keys,
                                                  optional_keys=['model_args'])
    job_args = [
        '--job-dir',
        _abs_path(args['output_dir']), '--output-dir-from-analysis-step',
        _abs_path(args['output_dir_from_analysis_step'])
    ]

    def _process_train_eval_data(data, arg_name, job_args):
        if isinstance(data, dict):
            if 'csv_file_pattern' in data:
                job_args.extend(
                    [arg_name, _abs_path(data['csv_file_pattern'])])
                if '--run-transforms' not in job_args:
                    job_args.append('--run-transforms')
            elif 'transformed_file_pattern' in data:
                job_args.extend(
                    [arg_name,
                     _abs_path(data['transformed_file_pattern'])])
            else:
                raise ValueError(
                    'Invalid training_data dict. ' +
                    'Requires either "csv_file_pattern" or "transformed_file_pattern".'
                )
        elif isinstance(data, google.datalab.ml.CsvDataSet):
            for file_name in data.input_files:
                job_args.append(arg_name + '=' + _abs_path(file_name))
        else:
            raise ValueError(
                'Invalid training data. Requires either a dict or ' +
                'a google.datalab.ml.CsvDataSet.')

    _process_train_eval_data(cell_data['training_data'], '--train-data-paths',
                             job_args)
    _process_train_eval_data(cell_data['evaluation_data'], '--eval-data-paths',
                             job_args)

    # TODO(brandondutra) document that any model_args that are file paths must
    # be given as an absolute path
    if 'model_args' in cell_data:
        for k, v in six.iteritems(cell_data['model_args']):
            job_args.extend(['--' + k, str(v)])

    try:
        tmpdir = None
        if args['package']:
            tmpdir = tempfile.mkdtemp()
            code_path = os.path.join(tmpdir, 'package')
            _archive.extract_archive(args['package'], code_path)
        else:
            code_path = MLTOOLBOX_CODE_PATH

        if args['cloud']:
            cloud_config = cell_data['cloud']
            if not args['output_dir'].startswith('gs://'):
                raise ValueError(
                    'Cloud training requires a GCS (starting with "gs://") output_dir.'
                )

            staging_tarball = os.path.join(args['output_dir'], 'staging',
                                           'trainer.tar.gz')
            datalab_ml.package_and_copy(code_path,
                                        os.path.join(code_path, 'setup.py'),
                                        staging_tarball)
            job_request = {
                'package_uris': [staging_tarball],
                'python_module': 'trainer.task',
                'job_dir': args['output_dir'],
                'args': job_args,
            }
            job_request.update(cloud_config)
            job_id = cloud_config.get('job_id', None)
            job = datalab_ml.Job.submit_training(job_request, job_id)
            _show_job_link(job)
        else:
            cmd_args = ['python', '-m', 'trainer.task'] + job_args
            _shell_process.run_and_monitor(cmd_args,
                                           os.getpid(),
                                           cwd=code_path)
    finally:
        if tmpdir:
            shutil.rmtree(tmpdir)
Example #3
def _transform(args, cell):
    env = google.datalab.utils.commands.notebook_environment()
    cell_data = google.datalab.utils.commands.parse_config(cell, env)
    google.datalab.utils.commands.validate_config(
        cell_data, required_keys=['training_data'], optional_keys=['cloud'])
    training_data = cell_data['training_data']
    cmd_args = [
        'python', 'transform.py', '--output-dir',
        _abs_path(args['output_dir']), '--output-dir-from-analysis-step',
        _abs_path(args['output_dir_from_analysis_step']),
        '--output-filename-prefix', args['output_filename_prefix']
    ]
    if args['cloud']:
        cmd_args.append('--cloud')
        cmd_args.append('--async')
    if args['shuffle']:
        cmd_args.append('--shuffle')
    if args['batch_size']:
        cmd_args.extend(['--batch-size', str(args['batch_size'])])

    if isinstance(training_data, dict):
        if 'csv_file_pattern' in training_data:
            cmd_args.extend([
                '--csv-file-pattern',
                _abs_path(training_data['csv_file_pattern'])
            ])
        elif 'bigquery_table' in training_data:
            cmd_args.extend(
                ['--bigquery-table', training_data['bigquery_table']])
        elif 'bigquery_sql' in training_data:
            # see https://cloud.google.com/bigquery/querying-data#temporary_and_permanent_tables
            print('Creating temporary table that will be deleted in 24 hours')
            r = bq.Query(training_data['bigquery_sql']).execute().result()
            cmd_args.extend(['--bigquery-table', r.full_name])
        else:
            raise ValueError(
                'Invalid training_data dict. ' +
                'Requires either "csv_file_pattern" and "csv_schema", or "bigquery".'
            )
    elif isinstance(training_data, google.datalab.ml.CsvDataSet):
        for file_name in training_data.input_files:
            cmd_args.append('--csv-file-pattern=' + _abs_path(file_name))
    elif isinstance(training_data, google.datalab.ml.BigQueryDataSet):
        cmd_args.extend(['--bigquery-table', training_data.table])
    else:
        raise ValueError(
            'Invalid training data. Requires either a dict, ' +
            'a google.datalab.ml.CsvDataSet, or a google.datalab.ml.BigQueryDataSet.'
        )

    if 'cloud' in cell_data:
        cloud_config = cell_data['cloud']
        google.datalab.utils.commands.validate_config(
            cloud_config,
            required_keys=[],
            optional_keys=['num_workers', 'worker_machine_type', 'project_id'])
        if 'num_workers' in cloud_config:
            cmd_args.extend(
                ['--num-workers',
                 str(cloud_config['num_workers'])])
        if 'worker_machine_type' in cloud_config:
            cmd_args.extend(
                ['--worker-machine-type', cloud_config['worker_machine_type']])
        if 'project_id' in cloud_config:
            cmd_args.extend(['--project-id', cloud_config['project_id']])

    if '--project-id' not in cmd_args:
        cmd_args.extend(
            ['--project-id',
             google.datalab.Context.default().project_id])

    try:
        tmpdir = None
        if args['package']:
            tmpdir = tempfile.mkdtemp()
            code_path = os.path.join(tmpdir, 'package')
            _archive.extract_archive(args['package'], code_path)
        else:
            code_path = MLTOOLBOX_CODE_PATH

        _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
    finally:
        if tmpdir:
            shutil.rmtree(tmpdir)
Example #4
def _analyze(args, cell):
    env = google.datalab.utils.commands.notebook_environment()
    cell_data = google.datalab.utils.commands.parse_config(cell, env)
    google.datalab.utils.commands.validate_config(
        cell_data, required_keys=['training_data', 'features'])
    # For now, always run python2. If needed, we can run python3 when the current kernel
    # is py3. Since our transform cannot currently run on py3 anyway, it is simpler to run
    # everything with python2.
    cmd_args = [
        'python', 'analyze.py', '--output-dir',
        _abs_path(args['output_dir'])
    ]
    if args['cloud']:
        cmd_args.append('--cloud')

    training_data = cell_data['training_data']
    if args['cloud']:
        tmpdir = os.path.join(args['output_dir'], 'tmp')
    else:
        tmpdir = tempfile.mkdtemp()

    try:
        if isinstance(training_data, dict):
            if 'csv_file_pattern' in training_data and 'csv_schema' in training_data:
                schema = training_data['csv_schema']
                schema_file = _create_json_file(tmpdir, schema, 'schema.json')
                cmd_args.extend([
                    '--csv-file-pattern',
                    _abs_path(training_data['csv_file_pattern'])
                ])
                cmd_args.extend(['--csv-schema-file', schema_file])
            elif 'bigquery_table' in training_data:
                cmd_args.extend(
                    ['--bigquery-table', training_data['bigquery_table']])
            elif 'bigquery_sql' in training_data:
                # see https://cloud.google.com/bigquery/querying-data#temporary_and_permanent_tables
                print(
                    'Creating temporary table that will be deleted in 24 hours'
                )
                r = bq.Query(training_data['bigquery_sql']).execute().result()
                cmd_args.extend(['--bigquery-table', r.full_name])
            else:
                raise ValueError(
                    'Invalid training_data dict. ' +
                    'Requires either "csv_file_pattern" and "csv_schema", or "bigquery".'
                )
        elif isinstance(training_data, google.datalab.ml.CsvDataSet):
            schema_file = _create_json_file(tmpdir, training_data.schema,
                                            'schema.json')
            for file_name in training_data.input_files:
                cmd_args.append('--csv-file-pattern=' + _abs_path(file_name))

            cmd_args.extend(['--csv-schema-file', schema_file])
        elif isinstance(training_data, google.datalab.ml.BigQueryDataSet):
            # TODO: Support query too once command line supports query.
            cmd_args.extend(['--bigquery-table', training_data.table])
        else:
            raise ValueError(
                'Invalid training data. Requires either a dict, ' +
                'a google.datalab.ml.CsvDataSet, or a google.datalab.ml.BigQueryDataSet.'
            )

        features = cell_data['features']
        features_file = _create_json_file(tmpdir, features, 'features.json')
        cmd_args.extend(['--features-file', features_file])

        if args['package']:
            code_path = os.path.join(tmpdir, 'package')
            _archive.extract_archive(args['package'], code_path)
        else:
            code_path = MLTOOLBOX_CODE_PATH

        _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
    finally:
        file_io.delete_recursively(tmpdir)
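
The _train, _transform, and _analyze examples above all resolve the code directory the same way: extract a user-supplied package archive into a temporary directory, otherwise fall back to the bundled MLTOOLBOX_CODE_PATH. The sketch below condenses that shared pattern; _resolve_code_path is a hypothetical helper name, and _archive, _shell_process, and MLTOOLBOX_CODE_PATH are the module-level names used in the examples:

import os
import shutil
import tempfile

def _resolve_code_path(package_path):
    # Hypothetical helper distilled from the examples above: returns the directory
    # to run from, plus the temporary directory to clean up (None if nothing to clean).
    if package_path:
        tmpdir = tempfile.mkdtemp()
        code_path = os.path.join(tmpdir, 'package')
        _archive.extract_archive(package_path, code_path)
        return code_path, tmpdir
    return MLTOOLBOX_CODE_PATH, None

# Usage mirroring the examples:
# code_path, tmpdir = _resolve_code_path(args['package'])
# try:
#     _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
# finally:
#     if tmpdir:
#         shutil.rmtree(tmpdir)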