def test_serialize_and_deserialize(self):
    """Round-trip the MobileNet V3 Large spec through schema_io.

    The serialized form must be a `str`, and deserializing it must produce a
    `basic_specs.ConvTowerSpec` that compares equal to the original spec.
    """
    original_spec = mobile_search_space_v3.mobilenet_v3_large()

    encoded = schema_io.serialize(original_spec)
    self.assertIsInstance(encoded, str)

    decoded = schema_io.deserialize(encoded)
    self.assertIsInstance(decoded, basic_specs.ConvTowerSpec)
    self.assertEqual(decoded, original_spec)
def _run_serialization_test(self, structure, expected_type=None):
    """Check that `structure` survives a serialize/deserialize round trip.

    The structure is serialized with schema_io, the result is checked to be a
    string, and the deserialized value is compared against a deep copy of the
    input taken before serialization.

    Args:
      structure: Value to hand to `schema_io.serialize()`.
      expected_type: Optional type. When it is not None, the deserialized
        value must additionally be an instance of it.
    """
    # Snapshot the input up front so the comparison below is made against the
    # value as it was on entry to this helper.
    reference = copy.deepcopy(structure)

    payload = schema_io.serialize(structure)
    self.assertIsInstance(payload, six.string_types)

    round_tripped = schema_io.deserialize(payload)
    self.assertEqual(round_tripped, reference)

    if expected_type is None:
        return
    self.assertIsInstance(round_tripped, expected_type)
def test_deserialization_defaults(self):
    """Check how deserialization fills in namedtuple default values."""
    # NamedTuple1 accepts one argument: foo. It has no default value.
    # NamedTuple3 accepts two arguments: foo and bar. Both have default values.
    #
    # Each case is (serialized payload, expected foo, expected bar).
    cases = [
        # Use default arguments for both foo and bar.
        ('["namedtuple:schema_io_test.NamedTuple3"]', 3, 'hi'),
        # Use default argument for bar only.
        ('["namedtuple:schema_io_test.NamedTuple3", ["foo", 42]]', 42, 'hi'),
        # Use default argument for foo only.
        ('["namedtuple:schema_io_test.NamedTuple3", ["bar", "bye"]]',
         3, 'bye'),
        # Don't use any default arguments.
        ('["namedtuple:schema_io_test.NamedTuple3", ["foo", 9], ["bar", "x"]]',
         9, 'x'),
        # Default values should also work when we refer to a namedtuple by a
        # deprecated name.
        ('["namedtuple:schema_io_test.DeprecatedNamedTuple3"]', 3, 'hi'),
    ]
    for payload, expected_foo, expected_bar in cases:
        restored = schema_io.deserialize(payload)
        self.assertEqual(
            restored, NamedTuple3(foo=expected_foo, bar=expected_bar))

    # Serialized value references a field that doesn't exist in the namedtuple.
    unknown_field_payload = (
        '["namedtuple:schema_io_test.NamedTuple3", ["baz", 10]]')
    with self.assertRaisesRegex(ValueError, 'Invalid field: baz'):
        schema_io.deserialize(unknown_field_payload)

    # Serialized value is missing a field that should exist in the namedtuple.
    missing_field_payload = '["namedtuple:schema_io_test.NamedTuple1"]'
    with self.assertRaisesRegex(ValueError, 'Missing field: foo'):
        schema_io.deserialize(missing_field_payload)
def _scan_directory(directory, output_format, ssd):
    """Report the architecture selected by the logs in one directory.

    Loads `model_spec.json` from `directory`, collects every OneOf in the spec
    together with its path, reads the logged path logits, and for the largest
    logged step takes the argmax of each OneOf's logits. The resulting indices
    are passed to the mobile cost model and the outcome is printed to stdout.

    Args:
      directory: Directory holding `model_spec.json` and the event data.
      output_format: `_OUTPUT_FORMAT_LINES` prints one `key = value` line per
        quantity followed by a blank line; `_OUTPUT_FORMAT_CSV` prints a
        single comma-joined row (indices, global step, directory, cost).
      ssd: Forwarded unchanged to `mobile_cost_model.estimate_cost()`.

    Raises:
      ValueError: If the OneOf paths found in the model spec differ from the
        paths present in the logged logits.
    """
    as_lines = output_format == _OUTPUT_FORMAT_LINES

    if as_lines:
        print('directory =', directory)

    spec_path = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(spec_path):
        # The skip notice is printed regardless of the output format.
        print('file {} not found; skipping'.format(spec_path))
        if as_lines:
            print()
        return

    with tf.io.gfile.GFile(spec_path, 'r') as spec_file:
        model_spec = schema_io.deserialize(spec_file.read())

    # `ordered_paths` remembers the order in which OneOfs are visited, which
    # fixes the order of the reported indices.
    ordered_paths = []
    oneof_by_path = {}

    def _record_oneof(path, oneof):
        ordered_paths.append(path)
        oneof_by_path[path] = oneof

    schema.map_oneofs_with_paths(_record_oneof, model_spec)

    logits_by_step = analyze_mobile_search_lib.read_path_logits(directory)
    if not logits_by_step:
        print(
            'event data missing from directory {}; skipping'.format(directory))
        if as_lines:
            print()
        return

    # Only the largest logged step is used.
    global_step = max(logits_by_step)
    if as_lines:
        print('global_step = {:d}'.format(global_step))

    final_logits = logits_by_step[global_step]
    logged_keys = six.viewkeys(final_logits)
    spec_keys = six.viewkeys(oneof_by_path)
    if logged_keys != spec_keys:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                logged_keys - spec_keys, spec_keys - logged_keys))

    indices = [np.argmax(final_logits[path]) for path in ordered_paths]
    indices_str = ':'.join(str(index) for index in indices)
    if as_lines:
        print('indices = {:s}'.format(indices_str))

    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)

    if as_lines:
        print('cost_model = {:f}'.format(cost_model_time))
        # Blank line separates this directory's report from the next one.
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        row = [indices_str, global_step, directory, cost_model_time]
        print(','.join(str(field) for field in row))
def test_namedtuple_deserialization_with_deprecated_names(self):
    """A payload using the deprecated tuple name deserializes to NamedTuple1."""
    payload = '["namedtuple:schema_io_test.DeprecatedNamedTuple1",["foo",51]]'

    result = schema_io.deserialize(payload)

    self.assertEqual(result, NamedTuple1(51))
    self.assertIsInstance(result, NamedTuple1)