def visualize(self, model_root, log_dir):
    """Save embedding metadata for ``self.embedding`` into ``log_dir``.

    Args:
      model_root: Root object handed to ``misc.get_variable_name`` so the
        name of ``self.embedding`` is resolved relative to it.
      log_dir: Directory forwarded to ``save_embeddings_metadata``.
    """
    embedding_name = misc.get_variable_name(self.embedding, model_root)
    save_embeddings_metadata(
        log_dir,
        embedding_name,
        self.vocabulary_file,
        num_oov_buckets=self.num_oov_buckets,
    )
def testGetVariableName(self, distributed_variables, dtype_policy):
    """Check variable naming and name mapping for nested Keras layers.

    The model is built under the global mixed-precision policy
    ``dtype_policy``. When ``distributed_variables`` is true, it is built
    inside a ``MirroredStrategy`` scope over the logical CPU devices.

    Args:
      distributed_variables: If true, create the model under a
        ``tf.distribute.MirroredStrategy`` scope.
      dtype_policy: Policy passed to
        ``tf.keras.mixed_precision.experimental.set_policy`` for this test.
    """
    tf.keras.mixed_precision.experimental.set_policy(dtype_policy)
    try:
        if distributed_variables:
            devices = tf.config.list_logical_devices(device_type="CPU")
            strategy = tf.distribute.MirroredStrategy(devices=devices)
        else:
            strategy = None

        class Layer(tf.keras.layers.Layer):
            def __init__(self):
                super().__init__()
                self.variable = self.add_weight("variable", [42])

        class Model(tf.keras.layers.Layer):
            def __init__(self):
                super().__init__()
                self.layers = [Layer()]

        if strategy is not None:
            with strategy.scope():
                model = Model()
        else:
            model = Model()

        variable = model.layers[0].variable
        expected_name = "model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE"
        variable_name = misc.get_variable_name(variable, model)
        self.assertEqual(variable_name, expected_name)

        variables = misc.get_variables_name_mapping(model, root_key="model")
        self.assertIs(variables[expected_name], variable)
    finally:
        # The policy is process-global state: reset it even when an assertion
        # or model construction fails, so later tests do not inherit
        # ``dtype_policy``.
        tf.keras.mixed_precision.experimental.set_policy("float32")
def testGetVariableName(self):
    """Check that a variable nested in a list attribute gets the expected path.

    ``misc.get_variable_name`` should name the variable held at
    ``Model.layers[0].variable`` as
    ``model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE``.
    """

    class Layer(tf.Module):
        def __init__(self):
            super().__init__()
            self.variable = tf.Variable(0)

    class Model(tf.Module):
        def __init__(self):
            super().__init__()
            self.layers = [Layer()]

    root = Model()
    nested_variable = root.layers[0].variable
    self.assertEqual(
        misc.get_variable_name(nested_variable, root),
        "model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE",
    )
def testGetVariableName(self):
    """Check name resolution and both name mappings for a nested variable.

    The model contains a single ``tf.Variable`` at ``layers[0].variable``.
    Its resolved name, and the two dicts returned by
    ``misc.get_variables_name_mapping``, must all agree on the same key.
    """

    class Layer(tf.Module):
        def __init__(self):
            super().__init__()
            self.variable = tf.Variable(0)

    class Model(tf.Module):
        def __init__(self):
            super().__init__()
            self.layers = [Layer()]

    root = Model()
    nested_variable = root.layers[0].variable
    key = "model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE"

    self.assertEqual(misc.get_variable_name(nested_variable, root), key)

    # The helper returns (variable ref -> name, name -> variable).
    by_ref, by_name = misc.get_variables_name_mapping(root, root_key="model")
    self.assertDictEqual(by_ref, {nested_variable.experimental_ref(): key})
    self.assertDictEqual(by_name, {key: nested_variable})