class PredictionOutputsProcessor(BasePredictionOutputsProcessor):
    """Write prediction outputs to an ODPS table, or log them.

    If all ODPS connection settings (project name, access id, access key,
    endpoint) are present in ``os.environ``, each batch of predictions is
    written row by row to the ``cifar10_prediction_outputs`` table. That
    table has ``NUM_OUTPUT_COLUMNS`` ``double`` columns named ``f0``,
    ``f1``, ... Otherwise the predictions are only logged.
    """

    # Number of prediction values per output row; one ODPS column each.
    NUM_OUTPUT_COLUMNS = 10

    def __init__(self):
        """Create the ODPS writer when ODPS is configured via the environment.

        Sets ``self.odps_writer`` to an ``ODPSWriter`` if every required
        ODPS environment variable is set, and to ``None`` otherwise.
        """
        if all(
            key in os.environ
            for key in (
                ODPSConfig.PROJECT_NAME,
                ODPSConfig.ACCESS_ID,
                ODPSConfig.ACCESS_KEY,
                ODPSConfig.ENDPOINT,
            )
        ):
            self.odps_writer = ODPSWriter(
                os.environ[ODPSConfig.PROJECT_NAME],
                os.environ[ODPSConfig.ACCESS_ID],
                os.environ[ODPSConfig.ACCESS_KEY],
                os.environ[ODPSConfig.ENDPOINT],
                "cifar10_prediction_outputs",
                columns=[
                    "f" + str(i) for i in range(self.NUM_OUTPUT_COLUMNS)
                ],
                column_types=["double"] * self.NUM_OUTPUT_COLUMNS,
            )
        else:
            self.odps_writer = None

    def process(self, predictions, worker_id):
        """Write one batch of predictions to ODPS, or log it.

        Args:
            predictions: A batch of model outputs exposing ``.numpy()``.
                Presumably a TensorFlow tensor; confirm against the caller.
                When writing to ODPS it must convert to a 2-D array of
                shape ``(batch_size, NUM_OUTPUT_COLUMNS)``.
            worker_id: Passed through to ``ODPSWriter.from_iterator`` as
                its worker index.

        Raises:
            ValueError: If an ODPS writer is configured and the predictions
                do not have one value per output table column.
        """
        outputs = predictions.numpy()
        if self.odps_writer is None:
            logger.info(outputs)
            return
        # Each row becomes one ODPS record, so it must carry exactly one
        # value per table column. Failing here gives a clear message
        # instead of an opaque error from inside the ODPS writer.
        if outputs.ndim != 2 or outputs.shape[1] != self.NUM_OUTPUT_COLUMNS:
            raise ValueError(
                "Prediction outputs of shape %s do not match the ODPS output "
                "table, which has %d columns; expected shape "
                "(batch_size, %d)."
                % (
                    outputs.shape,
                    self.NUM_OUTPUT_COLUMNS,
                    self.NUM_OUTPUT_COLUMNS,
                )
            )
        self.odps_writer.from_iterator(iter(outputs.tolist()), worker_id)
def test_write_from_iterator(self):
    """Exercise ``ODPSWriter.from_iterator`` on a new and an existing table.

    The first writer is built with an explicit column schema, for the case
    where the table does not exist yet. The second is built without one,
    for the case where the table already exists. After each write, the
    table's schema and row count are checked.
    """
    expected_names = ["num", "num2"]
    expected_types = ["bigint", "double"]
    connection_args = (
        self._project,
        self._access_id,
        self._access_key,
        self._endpoint,
        self._test_write_table,
    )

    def write_and_check(writer, expected_count):
        # Same two rows, written with worker index 2, on every call.
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        written = self._odps_client.get_table(
            self._test_write_table, self._project
        )
        self.assertEqual(written.schema.names, expected_names)
        self.assertEqual(written.schema.types, expected_types)
        # NOTE(review): two rows are written per call, yet the expected
        # counts are 1 and then 2. Confirm what ``to_df().count()``
        # measures here.
        self.assertEqual(written.to_df().count(), expected_count)

    # If the table doesn't exist yet
    write_and_check(
        ODPSWriter(*connection_args, expected_names, expected_types), 1
    )
    # If the table already exists
    write_and_check(ODPSWriter(*connection_args), 2)
def __init__(self):
    """Set up the ODPS writer for prediction outputs, if ODPS is available.

    When ``is_odps_configured()`` returns a truthy value,
    ``self.odps_writer`` becomes an ``ODPSWriter`` for the
    ``cifar10_prediction_outputs`` table, using connection settings read
    from ``os.environ``. The table has ten ``double`` columns named ``f0``
    through ``f9``. Otherwise ``self.odps_writer`` is ``None``.
    """
    if not is_odps_configured():
        self.odps_writer = None
        return

    # Connection settings, in the writer's positional-argument order.
    project, access_id, access_key, endpoint = (
        os.environ[env_key]
        for env_key in (
            ODPSConfig.PROJECT_NAME,
            ODPSConfig.ACCESS_ID,
            ODPSConfig.ACCESS_KEY,
            ODPSConfig.ENDPOINT,
        )
    )
    num_outputs = 10
    # TODO: Print out helpful error message if the columns and
    # column_types do not match with the prediction outputs
    self.odps_writer = ODPSWriter(
        project,
        access_id,
        access_key,
        endpoint,
        "cifar10_prediction_outputs",
        columns=["f{}".format(idx) for idx in range(num_outputs)],
        column_types=["double"] * num_outputs,
    )