  def test_unsupported_schema(self):
    """Test supported schema values.

    Note that not all valid BQ schema values are valid/used in the structured
    data package
    """

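    # Supported types include STRING, INTEGER, and FLOAT (the schema types
    # used by the other tests in this file).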
    unsupported_col_types = ['bytes', 'boolean', 'timestamp', 'date', 'time',
                             'datetime', 'record']
    for col_type in unsupported_col_types:
      schema = 'col_name:%s' % col_type

      job = core_sd.analyze_async(
          'some_dir',
          dlml.CsvDataSet(
              file_pattern=['file1.txt'],
              schema=schema),
          cloud=False).wait()
      self.assertIn('Schema contains an unsupported type %s.' % col_type,
                    job.fatal_error.message)

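      # The same schema validation should fail on the cloud code path as well.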
      job = core_sd.analyze_async(
          'gs://some_dir',
          dlml.CsvDataSet(
              file_pattern=['gs://file1.txt'],
              schema=schema),
          cloud=True,
          project_id='junk_project_id').wait()
      self.assertIn('Schema contains an unsupported type %s.' % col_type,
                    job.fatal_error.message)

  def test_cloud_but_local_files(self):
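    """Test cloud=True analysis with local (non-GCS) file paths."""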
    job = core_sd.analyze_async(
        'gs://some_dir',
        dlml.CsvDataSet(
            file_pattern=['file1.txt'],
            schema='col1:STRING,col2:INTEGER,col3:FLOAT'),
        project_id='project_id',
        cloud=True).wait()
    self.assertIn('File file1.txt is not a gcs path', job.fatal_error.message)

  def _run_train(self):
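    """Run reglinear training on the analysis output (helper, not a test)."""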
    reglinear.train(
        train_dataset=dlml.CsvDataSet(
            file_pattern=self._csv_train_filename,
            schema_file=self._schema_filename),
        eval_dataset=dlml.CsvDataSet(
            file_pattern=self._csv_eval_filename,
            schema_file=self._schema_filename),
        analysis_dir=self._preprocess_output,
        output_dir=self._train_output,
        features=self._input_features_filename,
        max_steps=100,
        train_batch_size=100)

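    # Training should export saved_model.pb under both the model and
    # evaluation_model directories.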
    self.assertTrue(os.path.isfile(
        os.path.join(self._train_output, 'model', 'saved_model.pb')))
    self.assertTrue(os.path.isfile(
        os.path.join(self._train_output, 'evaluation_model', 'saved_model.pb')))

  def test_projectid(self):
   """Test passing project id but cloud is false"""
    job = core_sd.analyze_async(
        'some_dir',
        dlml.CsvDataSet(
            file_pattern=['file1.txt'],
            schema='col1:STRING,col2:INTEGER,col3:FLOAT'),
        project_id='project_id').wait()
    self.assertIn('project_id only needed if cloud is True',
                  job.fatal_error.message)

  def test_csvdataset_one_file(self):
   """Test CsvDataSet has only one file/pattern"""
   # TODO(brandondutra) remove this restriction
    job = core_sd.analyze_async(
        'some_dir',
        dlml.CsvDataSet(
            file_pattern=['file1.txt', 'file2.txt'],
            schema='col1:STRING,col2:INTEGER,col3:FLOAT')).wait()
    self.assertIn('should be built with a file pattern',
                  job.fatal_error.message)

  def _run_analyze(self):
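    """Run reglinear analysis over the training CSV (helper, not a test)."""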
    reglinear.analyze(
        output_dir=self._preprocess_output,
        dataset=dlml.CsvDataSet(
            file_pattern=self._csv_train_filename,
            schema_file=self._schema_filename))

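    # Analysis should write the stats file and the vocabulary file for
    # column str1.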
    self.assertTrue(os.path.isfile(
        os.path.join(self._preprocess_output, 'stats.json')))
    self.assertTrue(os.path.isfile(
        os.path.join(self._preprocess_output, 'vocab_str1.csv')))
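
  # Note: these helpers are meant to run in sequence. _run_analyze writes the
  # stats and vocab files to self._preprocess_output, which _run_train then
  # consumes via analysis_dir. A test covering the full flow would call:
  #   self._run_analyze()
  #   self._run_train()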