Exemplo n.º 1
0
  def test_is_signed_in(self, mock_get_credentials):
    """_is_signed_in() follows whether fetching credentials succeeds."""
    # Credential lookup raises -> expected to be reported as signed out.
    mock_get_credentials.side_effect = Exception('No creds!')
    signed_in = Context._is_signed_in()
    self.assertFalse(signed_in)

    # Credential lookup returns (even an empty result) -> signed in.
    mock_get_credentials.side_effect = None
    mock_get_credentials.return_value = {}
    signed_in = Context._is_signed_in()
    self.assertTrue(signed_in)
Exemplo n.º 2
0
  def test_credentials(self):
    """Credentials given to the constructor or set later are exposed as-is."""
    initial_creds = {}
    ctx = Context('test_project', credentials=initial_creds)
    self.assertEqual(ctx.credentials, initial_creds)

    # Replacing the credentials must be reflected by the property.
    replacement_creds = {'test': 'test'}
    ctx.set_credentials(replacement_creds)
    self.assertEqual(ctx.credentials, replacement_creds)
Exemplo n.º 3
0
 def set_project_id(self, project_id):
   """Set the project_id for the context.

   If this context is the global context (``Context._global_context``), the
   project id is also forwarded to the default context of the newer
   ``google.datalab`` library, when that library can be imported.

   Args:
     project_id: the project id to store on this context.
   """
   self._project_id = project_id
   # Identity, not equality: only the global singleton propagates the change.
   if self is not Context._global_context:
     return
   try:
     from google.datalab import Context as new_context
   except ImportError:
     # If the new library is not loaded, then we have nothing to do.
     return
   new_context.default().set_project_id(project_id)
Exemplo n.º 4
0
  def test_config(self):
    """Config is exposed as given, can be replaced, and defaults when None."""
    initial_config = {}
    ctx = Context('test_project', credentials=None, config=initial_config)
    self.assertEqual(ctx.config, initial_config)

    replacement_config = {'test': 'test'}
    ctx.set_config(replacement_config)
    self.assertEqual(ctx.config, replacement_config)

    # With no config supplied, the context should fall back to the default.
    defaulted_ctx = Context('test_project', None, None)
    self.assertEqual(defaulted_ctx.config, Context._get_default_config())
Exemplo n.º 5
0
  def test_project(self):
    """project_id round-trips through the constructor and set_project_id."""
    first_project = 'test_project'
    ctx = Context(first_project, credentials=None, config=None)
    self.assertEqual(ctx.project_id, first_project)

    second_project = 'test_project2'
    ctx.set_project_id(second_project)
    self.assertEqual(ctx.project_id, second_project)

    # A context constructed without a project id must raise when it is read.
    no_project_ctx = Context(None, None, None)
    with self.assertRaises(Exception):
      print(no_project_ctx.project_id)
Exemplo n.º 6
0
  def test_default_context(self, mock_save_project_id, mock_get_default_project_id,
                           mock_get_credentials):
    """Context.default() resets the project id to the default project."""
    mock_get_credentials.return_value = ''
    mock_get_default_project_id.return_value = 'default_project'

    ctx = Context.default()
    self.assertEqual(ctx.project_id, 'default_project')

    # Deliberately change the default context's project; a fresh call to
    # Context.default() must reset it back to the default project.
    ctx.set_project_id('test_project4')
    self.assertEqual(Context.default().project_id, 'default_project')
Exemplo n.º 7
0
def _batch_predict(args, cell):
  if args['cloud_config'] and not args['cloud']:
    raise ValueError('"cloud_config" is provided but no "--cloud". '
                     'Do you want local run or cloud run?')

  if args['cloud']:
    parts = args['model'].split('.')
    if len(parts) != 2:
      raise ValueError('Invalid model name for cloud prediction. Use "model.version".')

    version_name = ('projects/%s/models/%s/versions/%s' %
                    (Context.default().project_id, parts[0], parts[1]))

    cloud_config = args['cloud_config'] or {}
    job_id = cloud_config.pop('job_id', None)
    job_request = {
      'version_name': version_name,
      'data_format': 'TEXT',
      'input_paths': file_io.get_matching_files(args['prediction_data']['csv']),
      'output_path': args['output'],
    }
    job_request.update(cloud_config)
    job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
    _show_job_link(job)
  else:
    print('local prediction...')
    _local_predict.local_batch_predict(args['model'],
                                       args['prediction_data']['csv'],
                                       args['output'],
                                       args['format'],
                                       args['batch_size'])
    print('done.')
Exemplo n.º 8
0
def _show_job_link(job):
    """Display an HTML message linking to the cloud log viewer for a job.

    Args:
      job: a submitted job; its ``info['jobId']`` is shown and used to build
          the log-viewer query.
    """
    # urllib.urlencode exists only on Python 2; Python 3 moved it to
    # urllib.parse.
    try:
        from urllib.parse import urlencode
    except ImportError:
        from urllib import urlencode

    job_id = job.info['jobId']
    log_url_query_strings = {
        'project': Context.default().project_id,
        'resource': 'ml.googleapis.com/job_id/' + job_id,
    }
    log_url = ('https://console.developers.google.com/logs/viewer?' +
               urlencode(log_url_query_strings))
    message = 'Job "%s" submitted.' % job_id
    message += '<p>Click <a href="%s" target="_blank">here</a> to view cloud log. <br/>' % log_url
    IPython.display.display_html(message, raw=True)
Exemplo n.º 9
0
def _show_job_link(job):
  """Display an HTML message linking to the cloud log viewer for a job.

  Args:
    job: a submitted job; its ``info['jobId']`` is shown and used to build the
        log-viewer query.
  """
  # urllib.urlencode exists only on Python 2; Python 3 moved it to
  # urllib.parse.
  try:
    from urllib.parse import urlencode
  except ImportError:
    from urllib import urlencode

  job_id = job.info['jobId']
  log_url_query_strings = {
    'project': Context.default().project_id,
    'resource': 'ml.googleapis.com/job_id/' + job_id,
  }
  log_url = ('https://console.developers.google.com/logs/viewer?' +
             urlencode(log_url_query_strings))
  message = 'Job "%s" submitted.' % job_id
  message += '<p>Click <a href="%s" target="_blank">here</a> to view cloud log. <br/>' % log_url
  IPython.display.display_html(message, raw=True)
Exemplo n.º 10
0
def _batch_predict(args, cell):
    """Run batch prediction locally or as a cloud ML job.

    The cell is parsed as a config against the notebook environment. It must
    contain 'prediction_data' holding a 'csv_file_pattern' entry; when
    '--cloud' is set it must also contain 'cloud'.

    Args:
      args: parsed magic arguments. Reads 'cloud', 'model' and 'output_dir';
          local runs additionally read 'output_format' and 'batch_size'.
      cell: the cell body holding the prediction config.

    Raises:
      ValueError: if the model name for a cloud run is not of the form
          "model.version". Missing config keys are reported by
          validate_config.
    """
    env = google.datalab.utils.commands.notebook_environment()
    cell_data = google.datalab.utils.commands.parse_config(cell, env)
    required_keys = ['prediction_data']
    if args['cloud']:
        required_keys.append('cloud')

    google.datalab.utils.commands.validate_config(cell_data,
                                                  required_keys=required_keys)

    data = cell_data['prediction_data']
    google.datalab.utils.commands.validate_config(
        data, required_keys=['csv_file_pattern'])

    if args['cloud']:
        parts = args['model'].split('.')
        if len(parts) != 2:
            raise ValueError(
                'Invalid model name for cloud prediction. Use "model.version".'
            )

        version_name = ('projects/%s/models/%s/versions/%s' %
                        (Context.default().project_id, parts[0], parts[1]))

        # Work on a copy so popping 'job_id' does not mutate the parsed config.
        # NOTE(review): parse_config is given the notebook env, so the 'cloud'
        # value may be a shared notebook object -- confirm.
        # An empty 'cloud:' value (None) is treated as no extra settings.
        cloud_config = dict(cell_data['cloud'] or {})
        job_id = cloud_config.pop('job_id', None)
        job_request = {
            'version_name': version_name,
            'data_format': 'TEXT',
            'input_paths':
            file_io.get_matching_files(data['csv_file_pattern']),
            'output_path': args['output_dir'],
        }
        # Remaining cloud settings are merged last and may override the above.
        job_request.update(cloud_config)
        job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
        _show_job_link(job)
    else:
        print('local prediction...')
        _local_predict.local_batch_predict(args['model'],
                                           data['csv_file_pattern'],
                                           args['output_dir'],
                                           args['output_format'],
                                           args['batch_size'])
        print('done.')
Exemplo n.º 11
0
  def test_default_project(self, mock_save_project_id, mock_get_default_project_id,
                           mock_get_credentials):
    """Setting the project on the default Context sets the global default project."""
    # Backing store for the faked saved project. This is a local dict rather
    # than a module-level `global`, so no state leaks out of this test. A dict
    # is used instead of `nonlocal` so it also runs on Python 2.
    store = {'project': ''}

    def save_project(project=None):
      store['project'] = project

    def get_project():
      return store['project']

    mock_save_project_id.side_effect = save_project
    mock_get_default_project_id.side_effect = get_project
    mock_get_credentials.return_value = ''

    c = Context.default()
    dummy_project = 'test_project3'
    c.set_project_id(dummy_project)
    self.assertEqual(du.get_default_project_id(), dummy_project)
Exemplo n.º 12
0
def _default_project():
  """Return the project id of the default google.datalab Context."""
  # Imported at call time rather than at module import time.
  from google.datalab import Context as datalab_context
  default_context = datalab_context.default()
  return default_context.project_id