def get_model_spec(
    model_zoo,
    model_def,
    model_params,
    dataset_fn,
    loss,
    optimizer,
    eval_metrics_fn,
    prediction_outputs_processor,
    custom_data_reader,
    callbacks,
):
    """Get the model spec items in a tuple.

    The model spec tuple contains the following items in order:

    * The model object instantiated with parameters specified
      in `model_params`,
    * The `dataset_fn`,
    * The `loss`,
    * The `optimizer`,
    * The `eval_metrics_fn`,
    * The `prediction_outputs_processor`. Note that it will print
      warning if it's not inherited from `BasePredictionOutputsProcessor`.
    * The `custom_data_reader`
    * The `callbacks`
    """
    # Resolve the user's model definition module and instantiate the model.
    module_path = get_module_file_path(model_zoo, model_def)
    module_dict = load_module(module_path).__dict__
    model = load_model_from_module(model_def, module_dict, model_params)

    outputs_processor = _get_spec_value(
        prediction_outputs_processor, model_zoo, module_dict
    )
    if outputs_processor and not isinstance(
        outputs_processor, BasePredictionOutputsProcessor
    ):
        # Warn but do not fail: a non-conforming processor may still work.
        logger.warning(
            "prediction_outputs_processor is not "
            "inherited from BasePredictionOutputsProcessor. "
            "Prediction outputs may not be processed correctly."
        )

    # If ODPS data source is used, dataset_fn is optional.
    need_dataset_fn = not is_odps_configured()

    callbacks_list = load_callbacks_from_module(callbacks, module_dict)

    return (
        model,
        _get_spec_value(
            dataset_fn, model_zoo, module_dict, required=need_dataset_fn
        ),
        _get_spec_value(loss, model_zoo, module_dict, required=True),
        _get_spec_value(optimizer, model_zoo, module_dict, required=True),
        _get_spec_value(
            eval_metrics_fn, model_zoo, module_dict, required=True
        ),
        outputs_processor,
        _get_spec_value(custom_data_reader, model_zoo, module_dict),
        callbacks_list,
    )
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    Args:
        data_origin: The origin of the data, e.g. location to files,
            table name in the database, etc.
        records_per_task: The number of records to create a task.
        kwargs: Data reader params, the supported keys are
            "columns", "partition", "reader_type".

    Returns:
        A reader matching ``reader_type``; when ``reader_type`` is not
        given, the type is inferred from the environment (ODPS
        configuration) and from ``data_origin``'s extension.

    Raises:
        ValueError: If ``reader_type`` is ``ODPS_READER`` but the
            MaxCompute environment variables are not configured, or if
            ``reader_type`` is not supported.
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        # No explicit type: infer from environment / file extension.
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None
                ),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return TextDataReader(
                filename=data_origin,
                records_per_task=records_per_task,
                **kwargs,
            )
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return TextDataReader(
            filename=data_origin,
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.ODPS_READER:
        # Fix: `is_odps_configured` must be called; the bare function
        # object is always truthy, so the previous check never raised.
        if not is_odps_configured():
            # Fix: a single message string (the old code passed two
            # args, producing a tuple as the exception message).
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type)
        )
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Build a data reader for `data_origin`.

    Returns an ODPS table reader when the MaxCompute environment is
    configured, otherwise a RecordIO file reader over `data_origin`.
    """
    if not is_odps_configured():
        return RecordIODataReader(data_dir=data_origin)

    env = os.environ
    return ODPSDataReader(
        project=env[ODPSConfig.PROJECT_NAME],
        access_id=env[ODPSConfig.ACCESS_ID],
        access_key=env[ODPSConfig.ACCESS_KEY],
        table=data_origin,
        endpoint=env.get(ODPSConfig.ENDPOINT),
        records_per_task=records_per_task,
        **kwargs,
    )
def __init__(self):
    # When MaxCompute (ODPS) credentials are absent, no writer is
    # created and `odps_writer` stays None.
    if not is_odps_configured():
        self.odps_writer = None
        return

    self.odps_writer = ODPSWriter(
        os.environ[ODPSConfig.PROJECT_NAME],
        os.environ[ODPSConfig.ACCESS_ID],
        os.environ[ODPSConfig.ACCESS_KEY],
        os.environ[ODPSConfig.ENDPOINT],
        "cifar10_prediction_outputs",
        # TODO: Print out helpful error message if the columns and
        # column_types do not match with the prediction outputs
        columns=["f%d" % i for i in range(10)],
        column_types=["double"] * 10,
    )
dataset = dataset.map(_parse_data) dataset = dataset.batch(10) return dataset dataset = tf.data.Dataset.from_generator( _gen, csv_data_reader.records_output_types) dataset = _dataset_fn(dataset, None, None) for features, labels in dataset: self.assertEqual(features.shape.as_list(), [10, 4]) self.assertEqual(labels.shape.as_list(), [10]) break @unittest.skipIf( not is_odps_configured(), "ODPS environment is not configured", ) class ODPSDataReaderTest(unittest.TestCase): def setUp(self): self.project = os.environ[MaxComputeConfig.PROJECT_NAME] access_id = os.environ[MaxComputeConfig.ACCESS_ID] access_key = os.environ[MaxComputeConfig.ACCESS_KEY] endpoint = os.environ.get(MaxComputeConfig.ENDPOINT) tunnel_endpoint = os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT, None) self.test_table = "test_odps_data_reader_%d_%d" % ( int(time.time()), random.randint(1, 101), ) self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
dataset = dataset.map(_parse_data) dataset = dataset.batch(10) return dataset dataset = tf.data.Dataset.from_generator( _gen, csv_data_reader.records_output_types ) dataset = _dataset_fn(dataset, None, None) for features, labels in dataset: self.assertEqual(features.shape.as_list(), [10, 4]) self.assertEqual(labels.shape.as_list(), [10]) break @unittest.skipIf( not is_odps_configured(), "ODPS environment is not configured", ) class ODPSDataReaderTest(unittest.TestCase): def setUp(self): self.project = os.environ[MaxComputeConfig.PROJECT_NAME] access_id = os.environ[MaxComputeConfig.ACCESS_ID] access_key = os.environ[MaxComputeConfig.ACCESS_KEY] endpoint = os.environ.get(MaxComputeConfig.ENDPOINT) tunnel_endpoint = os.environ.get( MaxComputeConfig.TUNNEL_ENDPOINT, None ) self.test_table = "test_odps_data_reader_%d_%d" % ( int(time.time()), random.randint(1, 101), ) self.odps_client = ODPS(access_id, access_key, self.project, endpoint)