Example #1
def _analyze(args, cell):
  # For now, always run python2. If needed, we can run python3 when the current kernel
  # is py3. Since our transform cannot work on py3 yet anyway, run everything with
  # python2 for consistency.
  cmd_args = ['python', 'analyze.py', '--output', _abs_path(args['output'])]
  if args['cloud']:
    cmd_args.append('--cloud')

  training_data = args['training_data']
  if args['cloud']:
    tmpdir = os.path.join(args['output'], 'tmp')
  else:
    tmpdir = tempfile.mkdtemp()

  try:
    if isinstance(training_data, dict):
      if 'csv' in training_data and 'schema' in training_data:
        schema = training_data['schema']
        schema_file = _create_json_file(tmpdir, schema, 'schema.json')
        cmd_args.append('--csv=' + _abs_path(training_data['csv']))
        cmd_args.extend(['--schema', schema_file])
      elif 'bigquery_table' in training_data:
        cmd_args.extend(['--bigquery', training_data['bigquery_table']])
      elif 'bigquery_sql' in training_data:
        # see https://cloud.google.com/bigquery/querying-data#temporary_and_permanent_tables
        print('Creating temporary table that will be deleted in 24 hours')
        r = bq.Query(training_data['bigquery_sql']).execute().result()
        cmd_args.extend(['--bigquery', r.full_name])
      else:
        raise ValueError('Invalid training_data dict. '
                         'Requires either "csv" and "schema", or "bigquery_table", '
                         'or "bigquery_sql".')
    elif isinstance(training_data, google.datalab.ml.CsvDataSet):
      schema_file = _create_json_file(tmpdir, training_data.schema, 'schema.json')
      for file_name in training_data.input_files:
        cmd_args.append('--csv=' + _abs_path(file_name))

      cmd_args.extend(['--schema', schema_file])
    elif isinstance(training_data, google.datalab.ml.BigQueryDataSet):
      # TODO: Support query too once command line supports query.
      cmd_args.extend(['--bigquery', training_data.table])
    else:
      raise ValueError('Invalid training data. Requires either a dict, '
                       'a google.datalab.ml.CsvDataSet, or a google.datalab.ml.BigQueryDataSet.')

    features = args['features']
    features_file = _create_json_file(tmpdir, features, 'features.json')
    cmd_args.extend(['--features', features_file])

    if args['package']:
      code_path = os.path.join(tmpdir, 'package')
      _archive.extract_archive(args['package'], code_path)
    else:
      code_path = MLTOOLBOX_CODE_PATH

    _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
  finally:
    file_io.delete_recursively(tmpdir)
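
A minimal sketch of how _analyze might be invoked, assuming the surrounding magic-command plumbing hands in a plain dict of parsed arguments. The keys mirror the lookups in the function body; the paths, schema, and feature definitions below are placeholders, not values from the original source.

# Hypothetical invocation of _analyze; keys match the args lookups above, values are
# placeholders for illustration only.
analyze_args = {
    'output': './analysis_output',      # where analyze.py writes its results
    'cloud': False,                     # run locally rather than in the cloud
    'training_data': {
        'csv': './data/train.csv',
        'schema': [{'name': 'target', 'type': 'FLOAT'},
                   {'name': 'text', 'type': 'STRING'}],
    },
    'features': {'target': {'transform': 'target'},
                 'text': {'transform': 'bag_of_words'}},
    'package': None,                    # fall back to MLTOOLBOX_CODE_PATH
}
_analyze(analyze_args, cell=None)       # cell is unused in the body shown above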
Example #2
def extract_archive(archive_path, dest):
    """Extract a local or GCS archive file to a folder.

    Args:
      archive_path: local or gcs path to a *.tar.gz or *.tar file
      dest: local folder the archive will be extracted to
    """
    # Make the dest folder if it does not exist
    if not os.path.isdir(dest):
        os.makedirs(dest)

    try:
        tmpfolder = None

        if (not tf.gfile.Exists(archive_path)
            ) or tf.gfile.IsDirectory(archive_path):
            raise ValueError('archive path %s is not a file' % archive_path)

        if archive_path.startswith('gs://'):
            # Copy the file to a local temp folder
            tmpfolder = tempfile.mkdtemp()
            cmd_args = ['gsutil', 'cp', archive_path, tmpfolder]
            _shell_process.run_and_monitor(cmd_args, os.getpid())
            archive_path = os.path.join(tmpfolder, os.path.basename(archive_path))

        if archive_path.lower().endswith('.tar.gz'):
            flags = '-xzf'
        elif archive_path.lower().endswith('.tar'):
            flags = '-xf'
        else:
            raise ValueError('Only .tar.gz or .tar files are supported.')

        cmd_args = ['tar', flags, archive_path, '-C', dest]
        _shell_process.run_and_monitor(cmd_args, os.getpid())
    finally:
        if tmpfolder:
            shutil.rmtree(tmpfolder)
Example #3
def extract_archive(archive_path, dest):
  """Extract a local or GCS archive file to a folder.

  Args:
    archive_path: local or gcs path to a *.tar.gz or *.tar file
    dest: local folder the archive will be extracted to
  """
  # Make the dest folder if it does not exist
  if not os.path.isdir(dest):
    os.makedirs(dest)

  try:
    tmpfolder = None

    if (not tf.gfile.Exists(archive_path)) or tf.gfile.IsDirectory(archive_path):
      raise ValueError('archive path %s is not a file' % archive_path)

    if archive_path.startswith('gs://'):
      # Copy the file to a local temp folder
      tmpfolder = tempfile.mkdtemp()
      cmd_args = ['gsutil', 'cp', archive_path, tmpfolder]
      _shell_process.run_and_monitor(cmd_args, os.getpid())
      archive_path = os.path.join(tmpfolder, os.path.basename(archive_path))

    if archive_path.lower().endswith('.tar.gz'):
      flags = '-xzf'
    elif archive_path.lower().endswith('.tar'):
      flags = '-xf'
    else:
      raise ValueError('Only .tar.gz or .tar files are supported.')

    cmd_args = ['tar', flags, archive_path, '-C', dest]
    _shell_process.run_and_monitor(cmd_args, os.getpid())
  finally:
    if tmpfolder:
      shutil.rmtree(tmpfolder)
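
A usage sketch for extract_archive, assuming a tarball already exists at the given location; both the local and GCS paths below are hypothetical.

# Hypothetical calls; paths are placeholders.
extract_archive('./dist/trainer.tar.gz', './extracted')       # local tarball
extract_archive('gs://my-bucket/staging/trainer.tar.gz',      # GCS tarball, copied to a
                '/tmp/trainer_code')                          # local temp folder via gsutil first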
Example #4
def _train(args, cell):
  if args['cloud_config'] and not args['cloud']:
    raise ValueError('"cloud_config" is provided but no "--cloud". '
                     'Do you want local run or cloud run?')

  job_args = ['--job-dir', _abs_path(args['output']),
              '--analysis', _abs_path(args['analysis'])]

  def _process_train_eval_data(data, arg_name, job_args):
    if isinstance(data, dict):
      if 'csv' in data:
        job_args.append(arg_name + '=' + _abs_path(data['csv']))
        if '--transform' not in job_args:
          job_args.append('--transform')
      elif 'transformed' in data:
        job_args.append(arg_name + '=' + _abs_path(data['transformed']))
      else:
        raise ValueError('Invalid training_data dict. '
                         'Requires either "csv" or "transformed".')
    elif isinstance(data, google.datalab.ml.CsvDataSet):
      for file_name in data.input_files:
        job_args.append(arg_name + '=' + _abs_path(file_name))
    else:
      raise ValueError('Invalid training data. Requires either a dict, or '
                       'a google.datalab.ml.CsvDataSet')

  _process_train_eval_data(args['training_data'], '--train', job_args)
  _process_train_eval_data(args['evaluation_data'], '--eval', job_args)

  # TODO(brandondutra) document that any model_args that are file paths must
  # be given as an absolute path
  if args['model_args']:
    for k, v in six.iteritems(args['model_args']):
      job_args.extend(['--' + k, str(v)])

  try:
    tmpdir = None
    if args['package']:
      tmpdir = tempfile.mkdtemp()
      code_path = os.path.join(tmpdir, 'package')
      _archive.extract_archive(args['package'], code_path)
    else:
      code_path = MLTOOLBOX_CODE_PATH

    if args['cloud']:
      # cloud_config may be None when "--cloud" is set without a config; default to {}.
      cloud_config = args['cloud_config'] or {}
      if not args['output'].startswith('gs://'):
        raise ValueError('Cloud training requires a GCS (starting with "gs://") output.')

      staging_tarball = os.path.join(args['output'], 'staging', 'trainer.tar.gz')
      datalab_ml.package_and_copy(code_path,
                                  os.path.join(code_path, 'setup.py'),
                                  staging_tarball)
      job_request = {
          'package_uris': [staging_tarball],
          'python_module': 'trainer.task',
          'job_dir': args['output'],
          'args': job_args,
      }
      job_request.update(cloud_config)
      job_id = cloud_config.get('job_id', None)
      job = datalab_ml.Job.submit_training(job_request, job_id)
      _show_job_link(job)
    else:
      cmd_args = ['python', '-m', 'trainer.task'] + job_args
      _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
  finally:
    if tmpdir:
      shutil.rmtree(tmpdir)
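
A sketch of a local _train call, again assuming args is the parsed-argument dict the magic command would hand in; the paths, data locations, and model arguments are placeholders.

# Hypothetical local training run; keys mirror the args lookups in _train.
train_args = {
    'output': './training_output',
    'analysis': './analysis_output',               # output of the analyze step
    'training_data': {'csv': './data/train.csv'},  # raw CSV, so --transform is appended
    'evaluation_data': {'csv': './data/eval.csv'},
    'model_args': {'model': 'linear_regression'},  # forwarded as '--model linear_regression'
    'package': None,                               # use MLTOOLBOX_CODE_PATH
    'cloud': False,
    'cloud_config': None,
}
_train(train_args, cell=None)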
Example #5
def _transform(args, cell):
  if args['cloud_config'] and not args['cloud']:
    raise ValueError('"cloud_config" is provided but no "--cloud". '
                     'Do you want local run or cloud run?')

  cmd_args = ['python', 'transform.py',
              '--output', _abs_path(args['output']),
              '--analysis', _abs_path(args['analysis']),
              '--prefix', args['prefix']]
  if args['cloud']:
    cmd_args.append('--cloud')
    cmd_args.append('--async')
  if args['shuffle']:
    cmd_args.append('--shuffle')
  if args['batch_size']:
    cmd_args.extend(['--batch-size', str(args['batch_size'])])

  training_data = args['training_data']
  if isinstance(training_data, dict):
    if 'csv' in training_data:
      cmd_args.append('--csv=' + _abs_path(training_data['csv']))
    elif 'bigquery_table' in training_data:
      cmd_args.extend(['--bigquery', training_data['bigquery_table']])
    elif 'bigquery_sql' in training_data:
      # see https://cloud.google.com/bigquery/querying-data#temporary_and_permanent_tables
      print('Creating temporary table that will be deleted in 24 hours')
      r = bq.Query(training_data['bigquery_sql']).execute().result()
      cmd_args.extend(['--bigquery', r.full_name])
    else:
      raise ValueError('Invalid training_data dict. '
                       'Requires either "csv", or "bigquery_table", or '
                       '"bigquery_sql".')
  elif isinstance(training_data, google.datalab.ml.CsvDataSet):
    for file_name in training_data.input_files:
      cmd_args.append('--csv=' + _abs_path(file_name))
  elif isinstance(training_data, google.datalab.ml.BigQueryDataSet):
    cmd_args.extend(['--bigquery', training_data.table])
  else:
    raise ValueError('Invalid training data. Requires either a dict, '
                     'a google.datalab.ml.CsvDataSet, or a google.datalab.ml.BigQueryDataSet.')

  cloud_config = args['cloud_config']
  if cloud_config:
    google.datalab.utils.commands.validate_config(
        cloud_config,
        required_keys=[],
        optional_keys=['num_workers', 'worker_machine_type', 'project_id', 'job_name'])
    if 'num_workers' in cloud_config:
      cmd_args.extend(['--num-workers', str(cloud_config['num_workers'])])
    if 'worker_machine_type' in cloud_config:
      cmd_args.extend(['--worker-machine-type', cloud_config['worker_machine_type']])
    if 'project_id' in cloud_config:
      cmd_args.extend(['--project-id', cloud_config['project_id']])
    else:
      cmd_args.extend(['--project-id', google.datalab.Context.default().project_id])
    if 'job_name' in cloud_config:
      cmd_args.extend(['--job-name', cloud_config['job_name']])

  try:
    tmpdir = None
    if args['package']:
      tmpdir = tempfile.mkdtemp()
      code_path = os.path.join(tmpdir, 'package')
      _archive.extract_archive(args['package'], code_path)
    else:
      code_path = MLTOOLBOX_CODE_PATH
    _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path)
  finally:
    if tmpdir:
      shutil.rmtree(tmpdir)
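
And a sketch of a cloud _transform call; the bucket, prefix, and job settings are placeholders, and cloud_config only needs keys from the optional set that validate_config allows.

# Hypothetical cloud transform run; values are placeholders.
transform_args = {
    'output': 'gs://my-bucket/transform_output',
    'analysis': 'gs://my-bucket/analysis_output',
    'prefix': 'train',                       # output file prefix
    'cloud': True,                           # adds --cloud and --async
    'shuffle': True,
    'batch_size': 100,
    'training_data': {'csv': 'gs://my-bucket/data/train-*.csv'},
    'cloud_config': {'num_workers': 5, 'job_name': 'my-transform-job'},
    'package': None,                         # use MLTOOLBOX_CODE_PATH
}
_transform(transform_args, cell=None)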