Exemplo n.º 1
0
class PredictionOutputsProcessor(BasePredictionOutputsProcessor):
    """Send CIFAR-10 prediction outputs to ODPS, or log them as a fallback.

    If every ODPS connection setting is present in ``os.environ``, an
    ``ODPSWriter`` for the ``cifar10_prediction_outputs`` table is built with
    ten ``double`` columns named ``f0`` ... ``f9``. Otherwise
    ``self.odps_writer`` is ``None`` and predictions are only logged.
    """

    def __init__(self):
        # Passed positionally to ODPSWriter, so the order here matters.
        required_keys = (
            ODPSConfig.PROJECT_NAME,
            ODPSConfig.ACCESS_ID,
            ODPSConfig.ACCESS_KEY,
            ODPSConfig.ENDPOINT,
        )
        if all(key in os.environ for key in required_keys):
            project, access_id, access_key, endpoint = [
                os.environ[key] for key in required_keys
            ]
            writer = ODPSWriter(
                project,
                access_id,
                access_key,
                endpoint,
                "cifar10_prediction_outputs",
                # TODO: Print out helpful error message if the columns and
                # column_types do not match with the prediction outputs
                columns=["f{}".format(idx) for idx in range(10)],
                column_types=["double"] * 10,
            )
        else:
            writer = None
        self.odps_writer = writer

    def process(self, predictions, worker_id):
        """Write ``predictions`` through the ODPS writer, or log them.

        Args:
            predictions: object exposing ``.numpy()`` whose result supports
                ``.tolist()`` (presumably a tensor — TODO confirm); each
                element of that list is handed to the writer as one row.
            worker_id: forwarded unchanged to ``ODPSWriter.from_iterator``.
        """
        writer = self.odps_writer
        if not writer:
            # No writer was created in __init__ (ODPS env settings missing).
            logger.info(predictions.numpy())
            return
        rows = predictions.numpy().tolist()
        writer.from_iterator(iter(rows), worker_id)
Exemplo n.º 2
0
    def test_write_from_iterator(self):
        """Write the same rows via ODPSWriter twice and check the table.

        The first writer is given explicit ``columns``/``column_types`` (the
        table does not exist yet); the second omits them (the table already
        exists). After each write, the table's schema names/types and its
        ``to_df().count()`` are compared with the expected values.
        """
        columns = ["num", "num2"]
        column_types = ["bigint", "double"]
        connection = (
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
        )

        def write_then_check(writer, expected_count):
            # Same two rows and worker index (2) on every write.
            writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
            table = self._odps_client.get_table(
                self._test_write_table, self._project
            )
            self.assertEqual(table.schema.names, columns)
            self.assertEqual(table.schema.types, column_types)
            self.assertEqual(table.to_df().count(), expected_count)

        # If the table doesn't exist yet
        write_then_check(ODPSWriter(*connection, columns, column_types), 1)

        # If the table already exists
        write_then_check(ODPSWriter(*connection), 2)
Exemplo n.º 3
0
 def __init__(self):
     """Create an ODPS writer for CIFAR-10 predictions if ODPS is configured.

     Sets ``self.odps_writer`` to an ``ODPSWriter`` for the
     ``cifar10_prediction_outputs`` table (columns ``f0`` ... ``f9``, all of
     type ``double``) when ``is_odps_configured()`` is true, else to ``None``.
     """
     if not is_odps_configured():
         self.odps_writer = None
         return

     # Project, access id, access key, endpoint: ODPSWriter's positional order.
     settings = [
         os.environ[key]
         for key in (
             ODPSConfig.PROJECT_NAME,
             ODPSConfig.ACCESS_ID,
             ODPSConfig.ACCESS_KEY,
             ODPSConfig.ENDPOINT,
         )
     ]
     # TODO: Print out helpful error message if the columns and
     # column_types do not match with the prediction outputs
     column_names = ["f%d" % col for col in range(10)]
     self.odps_writer = ODPSWriter(
         *settings,
         "cifar10_prediction_outputs",
         columns=column_names,
         column_types=["double"] * 10,
     )