def set_project_id(self, project_id):
  """Store project_id as this context's project.

  If this context is the global default context, the same project id is also
  forwarded to the default context of the newer google.datalab package, when
  that package can be imported.
  """
  self._project_id = project_id
  is_global = self == Context._global_context
  if not is_global:
    return
  try:
    from google.datalab import Context as new_context
    new_context.default().set_project_id(project_id)
  except ImportError:
    # google.datalab is not available, so there is no second context to sync.
    pass
def test_default_context(self, mock_save_project_id, mock_get_default_project_id,
                         mock_get_credentials):
  """Context.default() keeps reporting the default project after a local change."""
  mock_get_credentials.return_value = ''
  mock_get_default_project_id.return_value = 'default_project'

  context = Context.default()
  self.assertEqual(context.project_id, 'default_project')

  # Deliberately change the project on the context we obtained; a fresh
  # Context.default() call is expected to report the default project again.
  context.set_project_id('test_project4')
  self.assertEqual(Context.default().project_id, 'default_project')
def _batch_predict(args, cell): if args['cloud_config'] and not args['cloud']: raise ValueError('"cloud_config" is provided but no "--cloud". ' 'Do you want local run or cloud run?') if args['cloud']: parts = args['model'].split('.') if len(parts) != 2: raise ValueError('Invalid model name for cloud prediction. Use "model.version".') version_name = ('projects/%s/models/%s/versions/%s' % (Context.default().project_id, parts[0], parts[1])) cloud_config = args['cloud_config'] or {} job_id = cloud_config.pop('job_id', None) job_request = { 'version_name': version_name, 'data_format': 'TEXT', 'input_paths': file_io.get_matching_files(args['prediction_data']['csv']), 'output_path': args['output'], } job_request.update(cloud_config) job = datalab_ml.Job.submit_batch_prediction(job_request, job_id) _show_job_link(job) else: print('local prediction...') _local_predict.local_batch_predict(args['model'], args['prediction_data']['csv'], args['output'], args['format'], args['batch_size']) print('done.')
def _show_job_link(job):
  """Display an HTML notice that `job` was submitted, linking to its cloud logs.

  Args:
    job: a submitted job; job.info['jobId'] identifies it in the log viewer URL.
  """
  # urlencode lives in urllib.parse on Python 3 and in urllib on Python 2;
  # plain urllib.urlencode raises AttributeError on Python 3.
  try:
    from urllib.parse import urlencode
  except ImportError:
    from urllib import urlencode

  job_id = job.info['jobId']
  log_url_query_strings = {
      'project': Context.default().project_id,
      'resource': 'ml.googleapis.com/job_id/' + job_id,
  }
  log_url = ('https://console.developers.google.com/logs/viewer?' +
             urlencode(log_url_query_strings))
  html = 'Job "%s" submitted.' % job_id
  html += '<p>Click <a href="%s" target="_blank">here</a> to view cloud log. <br/>' % log_url
  IPython.display.display_html(html, raw=True)
def _batch_predict(args, cell):
  """Run batch prediction configured by the cell body, locally or in the cloud.

  The cell body is parsed with parse_config (notebook variables available) and
  must contain 'prediction_data' with a 'csv_file_pattern' entry; a cloud run
  additionally requires a 'cloud' section.

  Args:
    args: parsed magic arguments. Reads 'cloud', 'model' and 'output_dir';
        local runs also read 'output_format' and 'batch_size'.
    cell: the cell body holding the prediction config.

  Raises:
    ValueError: if a cloud run's model name is not of the form "model.version".
        Missing config keys are reported by validate_config (exception type
        not visible here).
  """
  env = google.datalab.utils.commands.notebook_environment()
  cell_data = google.datalab.utils.commands.parse_config(cell, env)
  required_keys = ['prediction_data']
  if args['cloud']:
    required_keys.append('cloud')
  google.datalab.utils.commands.validate_config(cell_data, required_keys=required_keys)
  data = cell_data['prediction_data']
  google.datalab.utils.commands.validate_config(
      data, required_keys=['csv_file_pattern'])

  if args['cloud']:
    parts = args['model'].split('.')
    if len(parts) != 2:
      raise ValueError(
          'Invalid model name for cloud prediction. Use "model.version".')

    version_name = ('projects/%s/models/%s/versions/%s' %
                    (Context.default().project_id, parts[0], parts[1]))

    # Copy so popping 'job_id' cannot mutate a dict that may be shared with
    # the notebook environment; an empty 'cloud:' section (None) becomes {}.
    cloud_config = dict(cell_data['cloud'] or {})
    job_id = cloud_config.pop('job_id', None)
    job_request = {
        'version_name': version_name,
        'data_format': 'TEXT',
        'input_paths': file_io.get_matching_files(data['csv_file_pattern']),
        'output_path': args['output_dir'],
    }
    # Remaining cloud entries are passed through and may override the above.
    job_request.update(cloud_config)
    job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
    _show_job_link(job)
  else:
    print('local prediction...')
    _local_predict.local_batch_predict(args['model'],
                                       data['csv_file_pattern'],
                                       args['output_dir'],
                                       args['output_format'],
                                       args['batch_size'])
    print('done.')
def test_default_project(self, mock_save_project_id, mock_get_default_project_id,
                         mock_get_credentials):
  """Setting the project on the default Context updates the global default project."""
  # Fake persisted default project shared by the two fakes below. A local dict
  # (instead of a module-level global) keeps state from leaking between tests
  # and stays compatible with Python 2, which has no `nonlocal`.
  stored = {'project': ''}

  def save_project(project=None):
    stored['project'] = project

  def get_project():
    return stored['project']

  mock_save_project_id.side_effect = save_project
  mock_get_default_project_id.side_effect = get_project
  mock_get_credentials.return_value = ''

  c = Context.default()
  dummy_project = 'test_project3'
  c.set_project_id(dummy_project)
  self.assertEqual(du.get_default_project_id(), dummy_project)
def _default_project():
  """Return the project id of the default google.datalab Context."""
  # Imported inside the function rather than at module level; presumably to
  # defer the google.datalab dependency until it is needed (confirm).
  from google.datalab import Context
  default_context = Context.default()
  return default_context.project_id