def test_serialize_and_deserialize(self):
        """Round-trips the mobilenet_v3_large() spec through schema_io.

        The serialized form must be a `str`. Deserializing it must yield a
        `basic_specs.ConvTowerSpec` that compares equal to the original spec.
        """
        original_spec = mobile_search_space_v3.mobilenet_v3_large()

        text = schema_io.serialize(original_spec)
        self.assertIsInstance(text, str)

        round_tripped = schema_io.deserialize(text)
        self.assertIsInstance(round_tripped, basic_specs.ConvTowerSpec)
        self.assertEqual(round_tripped, original_spec)
Example #2
0
    def _run_serialization_test(self, structure, expected_type=None):
        """Serialize `structure` to a string, deserialize it, and compare.

        Args:
          structure: the value to round-trip through schema_io.
          expected_type: if not None, the restored value must also be an
            instance of this type.
        """
        # Take a deep copy before serializing so the comparison is made
        # against an independent snapshot of the input.
        snapshot = copy.deepcopy(structure)

        text = schema_io.serialize(structure)
        self.assertIsInstance(text, six.string_types)

        round_tripped = schema_io.deserialize(text)
        self.assertEqual(round_tripped, snapshot)

        if expected_type is None:
            return
        self.assertIsInstance(round_tripped, expected_type)
Example #3
0
    def test_deserialization_defaults(self):
        """Checks that omitted namedtuple fields fall back to their defaults.

        NamedTuple1 has a single field, foo, with no default value.
        NamedTuple3 has two fields, foo and bar, and both have default values.
        """
        successful_cases = [
            # Both foo and bar fall back to their defaults.
            ('["namedtuple:schema_io_test.NamedTuple3"]',
             NamedTuple3(foo=3, bar='hi')),
            # Only bar falls back to its default.
            ('["namedtuple:schema_io_test.NamedTuple3", ["foo", 42]]',
             NamedTuple3(foo=42, bar='hi')),
            # Only foo falls back to its default.
            ('["namedtuple:schema_io_test.NamedTuple3", ["bar", "bye"]]',
             NamedTuple3(foo=3, bar='bye')),
            # Every field is given explicitly, so no defaults are used.
            ('["namedtuple:schema_io_test.NamedTuple3",\n'
             '            ["foo", 9], ["bar", "x"]]',
             NamedTuple3(foo=9, bar='x')),
            # Defaults must still apply when the namedtuple is referenced by
            # a deprecated name.
            ('["namedtuple:schema_io_test.DeprecatedNamedTuple3"]',
             NamedTuple3(foo=3, bar='hi')),
        ]
        for serialized, expected in successful_cases:
            self.assertEqual(schema_io.deserialize(serialized), expected)

        failing_cases = [
            # The serialized value names a field the namedtuple doesn't have.
            ('Invalid field: baz',
             '["namedtuple:schema_io_test.NamedTuple3", ["baz", 10]]'),
            # The serialized value omits a field that has no default value.
            ('Missing field: foo',
             '["namedtuple:schema_io_test.NamedTuple1"]'),
        ]
        for pattern, serialized in failing_cases:
            with self.assertRaisesRegex(ValueError, pattern):
                schema_io.deserialize(serialized)
def _scan_directory(directory, output_format, ssd):
    """Report the architecture selected by the search run in `directory`.

    Reads `model_spec.json` and the logged path logits from `directory`. At
    the largest logged global step, it picks the highest-logit choice for
    every OneOf in the spec. It then estimates the cost of that choice with
    `mobile_cost_model.estimate_cost` and prints the result to stdout.
    Directories without the spec file or without event data are skipped
    with a message.

    Args:
      directory: directory holding `model_spec.json` and the event data.
      output_format: _OUTPUT_FORMAT_LINES prints one labelled line per field
        followed by a blank line; _OUTPUT_FORMAT_CSV prints a single CSV row.
      ssd: passed through unchanged to `mobile_cost_model.estimate_cost`.

    Raises:
      ValueError: if the OneOf paths in the event data differ from the OneOf
        paths in the model spec.
    """
    lines_mode = output_format == _OUTPUT_FORMAT_LINES

    def skip(message):
        # Explain why the directory is skipped. In line mode, also emit the
        # blank separator line that ends each directory's report.
        print(message)
        if lines_mode:
            print()

    if lines_mode:
        print('directory =', directory)

    spec_path = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(spec_path):
        skip('file {} not found; skipping'.format(spec_path))
        return

    with tf.io.gfile.GFile(spec_path, 'r') as spec_file:
        model_spec = schema_io.deserialize(spec_file.read())

    # Record each OneOf path in the order map_oneofs_with_paths visits it.
    # That order determines the order of the emitted indices. Also keep a
    # path -> OneOf lookup for the key comparison below.
    ordered_paths = []
    oneof_by_path = {}

    def record(path, oneof):
        ordered_paths.append(path)
        oneof_by_path[path] = oneof

    schema.map_oneofs_with_paths(record, model_spec)

    logits_by_step = analyze_mobile_search_lib.read_path_logits(directory)
    if not logits_by_step:
        skip('event data missing from directory {}; skipping'.format(
            directory))
        return

    latest_step = max(logits_by_step)
    if lines_mode:
        print('global_step = {:d}'.format(latest_step))

    latest_logits = logits_by_step[latest_step]
    logged_keys = six.viewkeys(latest_logits)
    spec_keys = six.viewkeys(oneof_by_path)
    if logged_keys != spec_keys:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                logged_keys - spec_keys,
                spec_keys - logged_keys))

    # For every OneOf, select the option with the largest logit.
    indices = [np.argmax(latest_logits[path]) for path in ordered_paths]
    indices_str = ':'.join(str(index) for index in indices)
    if lines_mode:
        print('indices = {:s}'.format(indices_str))

    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)
    if lines_mode:
        print('cost_model = {:f}'.format(cost_model_time))
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        row = [indices_str, latest_step, directory, cost_model_time]
        print(','.join(str(field) for field in row))
Example #5
0
 def test_namedtuple_deserialization_with_deprecated_names(self):
     """Deserializing a DeprecatedNamedTuple1 reference yields NamedTuple1.

     The restored value must equal NamedTuple1(51) and be an instance of it.
     """
     payload = ('["namedtuple:schema_io_test.DeprecatedNamedTuple1",'
                '["foo",51]]')
     value = schema_io.deserialize(payload)
     self.assertEqual(value, NamedTuple1(51))
     self.assertIsInstance(value, NamedTuple1)