Exemplo n.º 1
0
def get_model_spec(
    model_zoo,
    model_def,
    model_params,
    dataset_fn,
    loss,
    optimizer,
    eval_metrics_fn,
    prediction_outputs_processor,
    custom_data_reader,
    callbacks,
):
    """Resolve and return the model spec items as a tuple.

    The returned tuple contains, in order: the model instantiated with
    `model_params`, `dataset_fn`, `loss`, `optimizer`, `eval_metrics_fn`,
    `prediction_outputs_processor`, `custom_data_reader` and the loaded
    callbacks list. A warning is logged when the resolved
    `prediction_outputs_processor` is not an instance of
    `BasePredictionOutputsProcessor`.
    """
    module_path = get_module_file_path(model_zoo, model_def)
    module_dict = load_module(module_path).__dict__
    model = load_model_from_module(model_def, module_dict, model_params)

    outputs_processor = _get_spec_value(
        prediction_outputs_processor, model_zoo, module_dict)
    if outputs_processor and not isinstance(
            outputs_processor, BasePredictionOutputsProcessor):
        logger.warning("prediction_outputs_processor is not "
                       "inherited from BasePredictionOutputsProcessor. "
                       "Prediction outputs may not be processed correctly.")

    # dataset_fn may be omitted when an ODPS data source is configured.
    require_dataset_fn = not is_odps_configured()
    callbacks_list = load_callbacks_from_module(callbacks, module_dict)

    resolved_dataset_fn = _get_spec_value(
        dataset_fn, model_zoo, module_dict, required=require_dataset_fn)
    resolved_loss = _get_spec_value(
        loss, model_zoo, module_dict, required=True)
    resolved_optimizer = _get_spec_value(
        optimizer, model_zoo, module_dict, required=True)
    resolved_metrics_fn = _get_spec_value(
        eval_metrics_fn, model_zoo, module_dict, required=True)
    resolved_reader = _get_spec_value(
        custom_data_reader, model_zoo, module_dict)

    return (
        model,
        resolved_dataset_fn,
        resolved_loss,
        resolved_optimizer,
        resolved_metrics_fn,
        outputs_processor,
        resolved_reader,
        callbacks_list,
    )
Exemplo n.º 2
0
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    Args:
        data_origin: The origin of the data, e.g. location to files,
            table name in the database, etc.
        records_per_task: The number of records to create a task.
        kwargs: Data reader params; the supported keys are
            "columns", "partition", "reader_type".

    Returns:
        A reader instance matching the requested (or inferred) type.

    Raises:
        ValueError: If `reader_type` is unsupported, or the ODPS reader
            is requested without the required environment variables set.
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        # No explicit type: infer from the environment and data origin.
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return TextDataReader(
                filename=data_origin,
                records_per_task=records_per_task,
                **kwargs,
            )
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return TextDataReader(filename=data_origin,
                              records_per_task=records_per_task,
                              **kwargs)
    elif reader_type == ReaderType.ODPS_READER:
        # BUG FIX: the function must be *called*; the bare function
        # object is always truthy, so the check never fired before.
        if not is_odps_configured():
            # BUG FIX: pass one message string instead of two
            # comma-separated strings, which made ValueError carry a
            # tuple and render the message with quotes and parentheses.
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type))
Exemplo n.º 3
0
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Build a reader for `data_origin`.

    Returns an ODPS table reader when the ODPS environment variables
    are configured; otherwise a RecordIO directory reader.
    """
    if not is_odps_configured():
        # Fall back to reading RecordIO files from a directory.
        return RecordIODataReader(data_dir=data_origin)
    env = os.environ
    return ODPSDataReader(
        project=env[ODPSConfig.PROJECT_NAME],
        access_id=env[ODPSConfig.ACCESS_ID],
        access_key=env[ODPSConfig.ACCESS_KEY],
        table=data_origin,
        endpoint=env.get(ODPSConfig.ENDPOINT),
        records_per_task=records_per_task,
        **kwargs,
    )
Exemplo n.º 4
0
 def __init__(self):
     """Create an ODPS writer for prediction outputs when configured.

     `self.odps_writer` stays None when the ODPS environment variables
     are not set, so callers must check it before writing.
     """
     self.odps_writer = None
     if is_odps_configured():
         self.odps_writer = ODPSWriter(
             os.environ[ODPSConfig.PROJECT_NAME],
             os.environ[ODPSConfig.ACCESS_ID],
             os.environ[ODPSConfig.ACCESS_KEY],
             os.environ[ODPSConfig.ENDPOINT],
             "cifar10_prediction_outputs",
             # TODO: Print out helpful error message if the columns and
             # column_types do not match with the prediction outputs
             columns=["f%d" % i for i in range(10)],
             column_types=["double"] * 10,
         )
Exemplo n.º 5
0
                dataset = dataset.map(_parse_data)
                dataset = dataset.batch(10)
                return dataset

            dataset = tf.data.Dataset.from_generator(
                _gen, csv_data_reader.records_output_types)
            dataset = _dataset_fn(dataset, None, None)
            for features, labels in dataset:
                self.assertEqual(features.shape.as_list(), [10, 4])
                self.assertEqual(labels.shape.as_list(), [10])
                break


@unittest.skipIf(
    not is_odps_configured(),
    "ODPS environment is not configured",
)
class ODPSDataReaderTest(unittest.TestCase):
    """Tests for the ODPS data reader.

    The whole class is skipped unless the MaxCompute (ODPS) credentials
    are present in the environment.
    """
    def setUp(self):
        """Read credentials from the environment, create the ODPS client
        and pick a unique table name for this test run."""
        self.project = os.environ[MaxComputeConfig.PROJECT_NAME]
        access_id = os.environ[MaxComputeConfig.ACCESS_ID]
        access_key = os.environ[MaxComputeConfig.ACCESS_KEY]
        endpoint = os.environ.get(MaxComputeConfig.ENDPOINT)
        # NOTE(review): tunnel_endpoint is unused in the visible code —
        # presumably consumed by test methods outside this excerpt; verify.
        tunnel_endpoint = os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT,
                                         None)
        # Timestamp + random suffix keeps table names unique across runs.
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
Exemplo n.º 6
0
                dataset = dataset.map(_parse_data)
                dataset = dataset.batch(10)
                return dataset

            dataset = tf.data.Dataset.from_generator(
                _gen, csv_data_reader.records_output_types
            )
            dataset = _dataset_fn(dataset, None, None)
            for features, labels in dataset:
                self.assertEqual(features.shape.as_list(), [10, 4])
                self.assertEqual(labels.shape.as_list(), [10])
                break


@unittest.skipIf(
    not is_odps_configured(), "ODPS environment is not configured",
)
class ODPSDataReaderTest(unittest.TestCase):
    """Integration tests against a live ODPS project.

    Skipped entirely when the MaxCompute credentials are absent from
    the environment.
    """
    def setUp(self):
        """Build the ODPS client and a per-run unique test table name."""
        env = os.environ
        self.project = env[MaxComputeConfig.PROJECT_NAME]
        ak = env[MaxComputeConfig.ACCESS_ID]
        sk = env[MaxComputeConfig.ACCESS_KEY]
        service_endpoint = env.get(MaxComputeConfig.ENDPOINT)
        data_tunnel_endpoint = env.get(MaxComputeConfig.TUNNEL_ENDPOINT, None)
        # Timestamp plus random suffix avoids collisions between runs.
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()), random.randint(1, 101))
        self.odps_client = ODPS(ak, sk, self.project, service_endpoint)