def test_create_tfrecords_dataflow_runner(self, mock_beam):
        """Tests `create_tfrecords` Dataflow case."""
        # The mocked Dataflow pipeline reports a fixed job id.
        mock_beam.build_pipeline().run().job_id.return_value = 'foo_id'

        # Dataflow requires GCS inputs, so prefix every image URI with gs://.
        gcs_df = self.test_df.copy()
        gcs_df[constants.IMAGE_URI_KEY] = 'gs://' + gcs_df[constants.IMAGE_URI_KEY]

        output_path = '/tmp/dataflow_runner'
        os.makedirs(output_path, exist_ok=True)

        want = {
            'job_id': 'foo_id',
            'dataflow_url': ('https://console.cloud.google.com/dataflow/jobs/'
                             'us-central1/foo_id?project=foo'),
        }

        result = client.create_tfrecords(gcs_df,
                                         runner='DataflowRunner',
                                         output_dir=output_path,
                                         region=self.test_region,
                                         project=self.test_project,
                                         tfrecorder_wheel=self.test_wheel)
        self.assertEqual(result, want)
# --- Example No. 2 (scraped separator; original text "Esempio n. 2", score 0) ---
    def to_tfr(self,
               output_dir: str,
               schema_map: Dict[str,
                                schema.SchemaMap] = schema.image_csv_schema,
               runner: str = 'DirectRunner',
               project: Optional[str] = None,
               region: Optional[str] = None,
               tfrecorder_wheel: Optional[str] = None,
               dataflow_options: Union[Dict[str, Any], None] = None,
               job_label: str = 'to-tfr',
               compression: Optional[str] = 'gzip',
               num_shards: int = 0) -> Dict[str, Any]:
        """Creates TFRecords from the wrapped DataFrame (pandas accessor).

    Converts a DataFrame of image locations and labels into image-based
    TensorFlow records by delegating to `client.create_tfrecords`.

    Usage:
      import tfrecorder

      df.tfrecorder.to_tfr(
          output_dir='gcs://foo/bar/train',
          runner='DirectRunner',
          compression='gzip',
          num_shards=10)

    Args:
      output_dir: Local directory or GCS location to write TFRecords to.
        Note: GCS is required for DataflowRunner.
      schema_map: A dict mapping column names to supported types.
      runner: Beam runner; either DirectRunner or DataflowRunner.
      project: GCP project name (required for DataflowRunner).
      region: GCP region name (required for DataflowRunner).
      tfrecorder_wheel: Path to the tfrecorder wheel Dataflow will run.
        (create with 'python setup.py sdist' or
        'pip download tfrecorder --no-deps')
      dataflow_options: Optional dictionary of Dataflow options.
      job_label: User-supplied description used in the Beam job name.
      compression: 'gzip' or None for no compression.
      num_shards: Number of shards for the TFRecords; 0 (default) means
          no sharding.
    Returns:
      A dictionary of job results.
    """
        # Show where the run is being logged before starting the job.
        banner = display.HTML('<b>Logging output to /tmp/{} </b>'.format(
            constants.LOGFILE))
        display.display(banner)

        return client.create_tfrecords(self._df,
                                       output_dir=output_dir,
                                       schema_map=schema_map,
                                       runner=runner,
                                       project=project,
                                       region=region,
                                       tfrecorder_wheel=tfrecorder_wheel,
                                       dataflow_options=dataflow_options,
                                       job_label=job_label,
                                       compression=compression,
                                       num_shards=num_shards)
 def test_create_tfrecords_direct_runner(self, mock_beam):
     """Tests `create_tfrecords` Direct case."""
     # The mocked direct pipeline reports a row count when it finishes.
     mock_beam.build_pipeline().run().wait_until_finished.return_value = {
         'rows': 6
     }

     result = client.create_tfrecords(self.test_df,
                                      output_dir='/tmp/direct_runner',
                                      runner='DirectRunner')
     self.assertIn('metrics', result)
# --- Example No. 4 (scraped separator; original text "Esempio n. 4", score 0) ---
    def to_tfr(self,
               output_dir: str,
               runner: str = 'DirectRunner',
               project: Optional[str] = None,
               region: Optional[str] = None,
               dataflow_options: Union[Dict[str, Any], None] = None,
               job_label: str = 'to-tfr',
               compression: Optional[str] = 'gzip',
               num_shards: int = 0) -> Dict[str, Any]:
        """Creates TFRecords from the wrapped DataFrame (pandas accessor).

    Converts a DataFrame of GCS image locations and labels into image-based
    TensorFlow records by delegating to `client.create_tfrecords`.

    Usage:
      import tfrecorder

      df.tfrecorder.to_tfr(
          output_dir='gcs://foo/bar/train',
          runner='DirectRunner',
          compression='gzip',
          num_shards=10)

    Args:
      output_dir: Local directory or GCS location to write TFRecords to.
      runner: Beam runner; either DirectRunner or DataflowRunner.
      project: GCP project name (required for DataflowRunner).
      region: GCP region name (required for DataflowRunner).
      dataflow_options: Optional dictionary of Dataflow options.
      job_label: User-supplied description used in the Beam job name.
      compression: 'gzip' or None for no compression.
      num_shards: Number of shards for the TFRecords; 0 (default) means
          no sharding.
    Returns:
      A dictionary of job results.
    """
        # Show where the run is being logged before starting the job.
        banner = display.HTML('<b>Logging output to /tmp/{} </b>'.format(
            constants.LOGFILE))
        display.display(banner)

        return client.create_tfrecords(self._df,
                                       output_dir=output_dir,
                                       runner=runner,
                                       project=project,
                                       region=region,
                                       dataflow_options=dataflow_options,
                                       job_label=job_label,
                                       compression=compression,
                                       num_shards=num_shards)