def _ReadAndCheckRowsUsingFeatures(self, num_rows):
        """Read `num_rows` rows via a `features`-configured reader and check them.

        Sets the handler's row count, builds a `BigQueryReader` pointed at
        `self.server.httpd`, parses every record with `parse_example`, and
        compares the parsed columns with the matching entry of `_ROWS`.
        Finally verifies that one more read fails because the queue is
        closed and empty.

        Args:
            num_rows: number of rows the handler should serve and the number
                of reads performed.
        """
        self.server.handler.num_rows = num_rows

        with self.test_session() as sess:
            int64_spec = parsing_ops.FixedLenFeature([1], dtype=dtypes.int64)
            string_spec = parsing_ops.FixedLenFeature(
                [1], dtype=dtypes.string, default_value="s_default")
            feature_configs = {
                "int64_col": int64_spec,
                "string_col": string_spec,
            }

            address = self.server.httpd.server_address
            end_point = "%s:%s" % (address[0], address[1])
            reader = cloud.BigQueryReader(
                project_id=_PROJECT,
                dataset_id=_DATASET,
                table_id=_TABLE,
                num_partitions=4,
                features=feature_configs,
                timestamp_millis=1,
                test_end_point=end_point)

            key, value = _SetUpQueue(reader)

            # parse_example is fed a batch of one, hence the reshape to [1].
            serialized_batch = array_ops.reshape(value, [1])
            parsed = parsing_ops.parse_example(serialized_batch,
                                               feature_configs)
            fetches = [parsed["int64_col"], parsed["string_col"]]

            observed_ids = []
            for _ in range(num_rows):
                int_batch, str_batch = sess.run(fetches)

                # Each fetch holds exactly one row with one value per column.
                self.assertEqual(int_batch.shape, (1, 1))
                self.assertEqual(str_batch.shape, (1, 1))
                row_id = int_batch[0][0]
                row_text = str_batch[0][0]
                observed_ids.append(row_id)

                # The int64 column doubles as the index into _ROWS.
                reference = _ROWS[row_id]
                self.assertEqual(row_id, reference[0])
                if reference[1]:
                    wanted_text = "s_%d" % row_id
                else:
                    # Falsy string column: the parser's default is expected.
                    wanted_text = "s_default"
                self.assertEqual(compat.as_str(row_text), wanted_text)

            # Every row id must have been seen exactly once, in any order.
            self.assertItemsEqual(observed_ids, range(num_rows))

            # A further read must fail: the queue is closed and drained.
            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                sess.run([key, value])
# Example no. 2
# 0
def input_fn_from_bigquery():
    """Build feature and label tensors from records read out of BigQuery.

    A `BigQueryReader` is created over the full column schema, its
    partitions are fed through a string input producer, and the serialized
    examples it yields are parsed twice: once for the training columns and
    once for the label column.

    Returns:
        A `(features, labels)` pair, where `features` is the dict returned by
        `tf.parse_example` for the training columns and `labels` is the
        parsed `detectedActivityINT2` entry.
    """
    # Complete column schema handed to the reader.
    reader_schema = {
        "uid": tf.FixedLenFeature([1], tf.string),
        "startStation": tf.FixedLenFeature([1], tf.string),
        "endStation": tf.FixedLenFeature(
            [1], dtype=tf.string, default_value="USA"),
        "longitude": tf.FixedLenFeature([1], tf.string),
        "latitude": tf.FixedLenFeature([1], tf.string),
        "locationAccuracy": tf.FixedLenFeature([1], tf.string),
        "locationAccuracyINT2": tf.FixedLenFeature([1], tf.int64),
        "timeLocationDetected": tf.FixedLenFeature([1], tf.string),
        "detectedActivity": tf.FixedLenFeature([1], tf.string),
        "detectedActivityINT": tf.FixedLenFeature([1], tf.string),
        "detectedActivityINT2": tf.FixedLenFeature(
            [1], tf.int64, default_value=0),
        "longitudeFloat": tf.FixedLenFeature(
            [1], tf.float32, default_value=12.0),
        "latitudeFloat": tf.FixedLenFeature(
            [1], tf.float32, default_value=14.0),
        "detectedActivityConfidenceINT": tf.FixedLenFeature(
            [1], tf.int64, default_value=5),
        "timeLocationDetectedINT": tf.FixedLenFeature(
            [1], tf.int64, default_value=6),
    }

    # Subset of columns parsed as model inputs.
    training_schema = {
        "locationAccuracyINT2": tf.FixedLenFeature(
            [1], tf.int64, default_value=3),
        "longitudeFloat": tf.FixedLenFeature(
            [1], tf.float32, default_value=12.0),
        "latitudeFloat": tf.FixedLenFeature(
            [1], tf.float32, default_value=14.0),
        "detectedActivityConfidenceINT": tf.FixedLenFeature(
            [1], tf.int64, default_value=5),
        "timeLocationDetectedINT": tf.FixedLenFeature(
            [1], tf.int64, default_value=6),
    }

    # Single column parsed as the label.
    label_schema = {
        "detectedActivityINT2": tf.FixedLenFeature(
            [1], tf.int64, default_value=2),
    }

    bq_reader = bigquery_reader_ops.BigQueryReader(
        project_id=PROJECT,
        dataset_id=DATASET,
        table_id=TABLE,
        timestamp_millis=TIME,
        num_partitions=NUM_PARTITIONS,
        features=reader_schema)

    # Queue of table partitions for the reader to consume.
    partition_queue = tf.train.string_input_producer(bq_reader.partitions())

    # Request up to 50 records per read; the row ids are not needed here.
    _, serialized_examples = bq_reader.read_up_to(partition_queue, 50)

    parsed_inputs = tf.parse_example(
        serialized_examples, features=training_schema)
    parsed_labels = tf.parse_example(
        serialized_examples, features=label_schema)
    return parsed_inputs, parsed_labels["detectedActivityINT2"]
    def testReadingMultipleRowsUsingColumns(self):
        """Read ten rows via a `columns`-configured reader and check each one.

        Sets the handler's row count to 10, builds a `BigQueryReader` that
        requests the `int64_col`, `float_col` and `string_col` columns, then
        reads one record per iteration. Every record is decoded into an
        `Example` proto and compared against the proto built from the
        matching `_ROWS` entry; the returned row id must equal the loop
        index. Finally verifies that one more read fails because the queue
        is closed and empty.
        """
        num_rows = 10
        self.server.handler.num_rows = num_rows

        with self.test_session() as sess:
            reader = cloud.BigQueryReader(
                project_id=_PROJECT,
                dataset_id=_DATASET,
                table_id=_TABLE,
                num_partitions=4,
                columns=["int64_col", "float_col", "string_col"],
                timestamp_millis=1,
                test_end_point=("%s:%s" %
                                (self.server.httpd.server_address[0],
                                 self.server.httpd.server_address[1])))
            key, value = _SetUpQueue(reader)
            seen_rows = []
            for row_index in range(num_rows):
                returned_row_id, example_proto = sess.run([key, value])

                # Decode the serialized record so its fields can be inspected.
                example = example_pb2.Example()
                example.ParseFromString(example_proto)
                self.assertIn("int64_col", example.features.feature)
                feature = example.features.feature["int64_col"]
                self.assertEqual(len(feature.int64_list.value), 1)
                int64_col = feature.int64_list.value[0]
                seen_rows.append(int64_col)

                # The int64 column doubles as the index into _ROWS; build the
                # expected Example directly from that row.
                expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])

                # Compare.
                self.assertProtoEquals(example, expected_example)
                self.assertEqual(row_index, int(returned_row_id))

            # Every row id must have been seen exactly once, in any order.
            self.assertItemsEqual(seen_rows, range(num_rows))

            # A further read must fail: the queue is closed and drained.
            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                sess.run([key, value])