Example #1
@contextlib.contextmanager
def notify_about_new_variables(callback):
    """Calls `callback(var)` for all newly created variables.

    Callback should not modify the variable passed in. Use cases that require
    variables to be modified should use `variable_creator_scope` directly and
    sit within the variable creator stack.

        >>> variables = []
        >>> with notify_about_new_variables(variables.append):
        ...   v = tf.Variable(1.0, name='v')
        ...   w = tf.get_variable('w', [])
        >>> assert variables == [v, w]

    Args:
      callback: a callable taking a single argument which is a tf.Variable.

    Yields:
      `None` - used for contextmanager API.
    """
    def _tracking_creator(getter, **kwargs):
        v = getter(**kwargs)
        callback(v)
        return v

    with tf.variable_creator_scope(_tracking_creator):
        yield
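The same mechanism generalizes: any callable with the signature `creator(next_creator, **kwargs)` can be pushed with `tf.variable_creator_scope`. A minimal sketch (illustrative names, not part of the example above) that counts variable creations:

import tensorflow as tf

created_count = 0

def counting_creator(next_creator, **kwargs):
    # Forward the request unchanged, but record that a variable was created.
    global created_count
    created_count += 1
    return next_creator(**kwargs)

with tf.variable_creator_scope(counting_creator):
    v = tf.Variable(1.0)
    w = tf.Variable([2.0, 3.0])

assert created_count == 2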
Example #2
  def build_network(self, inputs, phase_train=True, nclass=1001):
    try:
      from official.recommendation import neumf_model  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      if 'neumf_model' not in str(e):
        raise
      raise ImportError('To use the experimental NCF model, you must clone the '
                        'repo https://github.com/tensorflow/models and add '
                        'tensorflow/models to the PYTHONPATH.')
    del nclass

    users, items, _ = inputs
    params = {
        'num_users': _NUM_USERS_20M,
        'num_items': _NUM_ITEMS_20M,
        'model_layers': (256, 256, 128, 64),
        'mf_dim': 64,
        'mf_regularization': 0,
        'mlp_reg_layers': (0, 0, 0, 0),
        'use_tpu': False
    }
    if self.data_type == tf.float32:
      keras_model = neumf_model.construct_model(users, items, params)
      logits = keras_model.output
    else:
      assert self.data_type == tf.float16
      tf.keras.backend.set_floatx('float16')
      # We cannot rely on the variable_scope's fp16 custom getter here, because
      # the NCF model uses keras layers, which ignore variable scopes. So we use
      # a variable_creator_scope instead.
      with tf.variable_creator_scope(_fp16_variable_creator):
        keras_model = neumf_model.construct_model(users, items, params)
      logits = tf.cast(keras_model.output, tf.float32)

    return model.BuildNetworkResult(logits=logits, extra_info=None)
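`_fp16_variable_creator` is referenced but not shown in this snippet. A hedged sketch of what such a creator might do, storing variables in float32 even when float16 initial values are requested (this is an assumption for illustration, not the benchmark's actual helper):

import tensorflow as tf

def _fp16_variable_creator(next_creator, **kwargs):
    # Sketch only: force float32 storage. A complete implementation would also
    # cast reads back to float16, e.g. via a wrapper around the variable.
    initial_value = kwargs['initial_value']
    if callable(initial_value):
        initial_value = initial_value()
    initial_value = tf.convert_to_tensor(initial_value)
    if initial_value.dtype == tf.float16:
        kwargs['initial_value'] = tf.cast(initial_value, tf.float32)
    else:
        kwargs['initial_value'] = initial_value
    if kwargs.get('dtype') == tf.float16:
        kwargs['dtype'] = tf.float32
    return next_creator(**kwargs)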
Example #3
    def __call__(self, image, probe=True, aug_image_key='image'):
        # Create a local variable that will store a copy of the CTA log probabilities.
        with tf.variable_creator_scope(_skip_mirrored_creator):
            local_log_prob = tf.Variable(
                lambda: tf.ones(self.state_shape, dtype=tf.float32),
                trainable=False,
                name='cta_log_probs')
        self.log_probs.append(local_log_prob)

        output_dict = {}
        if probe:
            probe_op_indices, probe_op_args = self._sample_ops_uniformly()
            probe_image = self._apply_ops(image, probe_op_indices,
                                          probe_op_args)
            output_dict['probe_op_indices'] = probe_op_indices
            output_dict['probe_op_args'] = probe_op_args
            output_dict['probe_image'] = probe_image

        if aug_image_key is not None:
            op_indices, op_args = self._sample_ops(local_log_prob)
            aug_image = self._apply_ops(image,
                                        op_indices,
                                        op_args,
                                        prob_to_apply=self.prob_to_apply)
            output_dict[aug_image_key] = aug_image

        if aug_image_key != 'image':
            output_dict['image'] = image

        return output_dict
Example #4
    def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
        if (not tf.executing_eagerly()
                and tf.compat.v1.control_flow_v2_enabled()):
            self.skipTest("b/138751864")
        created_variables = []
        trainable_variables = []

        def appending_creator(next_creator, **kwargs):
            v = next_creator(**kwargs)
            created_variables.append(v.name)
            if "trainable" in kwargs and kwargs["trainable"]:
                trainable_variables.append(v.name)
            return v

        # Creator scope needs to be set before it's used inside
        # `distribution.scope`.
        with tf.variable_creator_scope(
                appending_creator), distribution.scope():
            optimizer = optimizer_fn()
            model_fn, dataset_fn, _ = minimize_loss_example(
                optimizer, use_bias=True, use_callable_loss=True)

            def step_fn(ctx, inputs):
                del ctx  # Unused
                return distribution.group(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs, )))

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=1).run_op

            if not tf.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(tf.compat.v1.global_variables_initializer())
            run_step()

            def get_expected_variables(num_parameter_devices):
                name = optimizer._name

                if isinstance(optimizer, optimizer_v2.OptimizerV2):
                    variables = VAR_MAP_V2[name]
                else:
                    variables = VAR_MAP_V1[name]

                extended_variables = [
                    v + "/replica_{}".format(replica) for v in variables
                    for replica in range(1, num_parameter_devices)
                ]
                variables = list(variables) + extended_variables
                return set(v + ":0" for v in variables)

            self.assertEqual(
                get_expected_variables(
                    len(distribution.extended.parameter_devices)),
                set(created_variables))
Example #5
def test_demo():
    with tf.variable_scope('chart1'):
        a = tf.Variable(40)
        b = tf.Variable(50)
    with tf.variable_scope('chart2'):
        c = tf.add(a, b)
    init = tf.global_variables_initializer()
    print(a)
    print(b)
    print(c)
    with tf.Session() as sess:
        sess.run(init)
        a_value, b_value, c_value = sess.run([a, b, c])
        print("a_value", a_value)
        print("b_value", b_value)
        print("c_value", c_value)
    return None
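For comparison, the same demo in TF2 eager style needs neither a session nor explicit initialization:

import tensorflow as tf

def test_demo_eager():
    # TF2 eager equivalent of the demo above.
    a = tf.Variable(40, name='a')
    b = tf.Variable(50, name='b')
    c = tf.add(a, b)
    print("a_value", a.numpy())
    print("b_value", b.numpy())
    print("c_value", c.numpy())

test_demo_eager()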
Example #6
def anchor_target_layer(cls_pre, bbox, img_info, scope_name):
    """Get the anchor targets.

    Args:
        cls_pre (float): classifier predictions
        bbox: bounding boxes
        img_info: image information
    """

    with tf.variable_scope(scope_name) as scope:
        pass  # the remainder of this example is not included in the snippet
Example #7
    def update(state, y_true, y_pred, sample_weight=None):
        del sample_weight  # Unused.

        def update(metric):
            metric.update_state(y_true, y_pred, sample_weight=None)
            return metric.variables

        with tf.variable_creator_scope(
                build_replace_variable_with_parameter_creator(state)):
            return tf.nest.map_structure(update, metrics_constructor())
Example #8
    def predict_on_batch(model_weights: ModelWeights,
                         x: Any,
                         training: bool = True) -> Any:
        with tf.init_scope():
            if tf.executing_eagerly():
                raise KerasFunctionalModelError(
                    'tf.keras.Model used as a FunctionalModel is only usable inside a '
                    'tff.tf_computation decorated callable or a graph context.'
                )
        # Make a copy of the weights container; can't mutate Python containers
        # inside a tf.function.
        trainable, non_trainable = (list(w) for w in model_weights)

        # Here we intercept variable creation requests during the
        # `tf.keras.models.clone_model()` call.
        #
        # Instead of forwarding the variable request to TF core and getting a
        # `tf.Variable` back, we skip that and return only the `tf.Tensor` that
        # corresponds to the `tf.Variable` recreation request (avoiding any variable
        # creation). This works because TF operations that accept `tf.Variable`
        # inputs automatically call `variable.read_value()` and then operate on that
        # resulting tensor. We're relying on shortcutting that and providing the
        # tensor straight away.
        #
        # For example, `tf.matmul` doesn't care whether its input is a `tf.Variable`
        # or a `tf.Tensor`:
        #
        #   v = tf.Variable([[1], [2], [3]])
        #   tf.matmul(v, [[4, 5, 6]])
        #
        #   and
        #
        #   v = tf.constant([[1], [2], [3]])
        #   tf.matmul(v, [[4, 5, 6]])
        #
        #   both result in:
        #
        #   <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
        #   array([[ 4,  5,  6],
        #          [ 8, 10, 12],
        #          [12, 15, 18]], dtype=int32)>
        def swap_tensor_parameter_for_variable(_, **kwargs):
            if kwargs.get('trainable', True):
                return trainable.pop(0)
            else:
                return non_trainable.pop(0)

        with tf.variable_creator_scope(swap_tensor_parameter_for_variable):
            if isinstance(keras_model, tf.keras.Model):
                variableless_model = tf.keras.models.clone_model(keras_model)
            else:
                variableless_model = keras_model()
        return variableless_model(x, training)
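The same swap can be illustrated outside Keras with a toy creator (illustrative only, not part of the model above): pre-built tensors are handed out in creation order, and the downstream math is unchanged because TF ops read variables and plain tensors alike.

import tensorflow as tf

params = [tf.constant(2.0), tf.constant(0.5)]  # stands in for the weights list

def tensor_creator(next_creator, **kwargs):
    del next_creator, kwargs  # a stored tensor replaces the requested variable
    return params.pop(0)

@tf.function
def affine(x):
    with tf.variable_creator_scope(tensor_creator):
        w = tf.Variable(1.0)  # becomes tf.constant(2.0)
        b = tf.Variable(0.0)  # becomes tf.constant(0.5)
    return w * x + b          # no variables are ever created

# Single trace only: the params list is consumed while tracing.
print(affine(tf.constant(3.0)))  # tf.Tensor(6.5, ...)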
Example #9
def write_v2_saved_model(tf_function: function.Function, name: str,
                         saved_model_dir: str) -> function.ConcreteFunction:
    """Writes `tf_function` under attr `name` to `saved_model_dir`."""
    module = tf.Module()

    resource_tracker = tracking.ResourceTracker()
    object_tracker = annotators.ObjectTracker()
    created_variables = []

    def _variable_creator(next_creator, **kwargs):
        var = next_creator(**kwargs)
        created_variables.append(var)
        return var

    # TODO(b/164921571): Handle generic Trackable objects.
    # Trace `tf_function` to gather any resources in it using the
    # resource_tracker. These are then assigned to `module.resources` and tracked
    # before exporting to SavedModel.
    with tracking.resource_tracker_scope(resource_tracker), \
         annotators.object_tracker_scope(object_tracker), \
         tf.variable_creator_scope(_variable_creator):
        concrete_fn = tf_function.get_concrete_function()

    # Prior to 2020/10/08, saving a tf.function with a concrete function signature
    # would ensure that the function was not re-traced in a round-trip to a
    # SavedModel. Since this is no longer the case, we save the concrete function
    # directly.
    if tf.compat.forward_compatible(2020, 10, 8):
        pruned_function = optimize_concrete_function(concrete_fn)
        module.pruned_variables = pruned_function.variables
        setattr(module, name, pruned_function)
    else:
        setattr(module, name, tf_function)

    # Any variables created need to be explicitly tracked.
    module.created_variables = created_variables
    # Resources need to be explicitly tracked.
    module.resources = resource_tracker.resources
    module.trackable_objects = object_tracker.trackable_objects
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    initializers = []
    for resource in module.resources:
        if isinstance(resource, lookup_ops.InitializableLookupTableBase):
            initializers.append(resource._initializer)  # pylint: disable=protected-access
    module.initializers = initializers
    module.assets = [
        common_types.Asset(asset_filepath)
        for asset_filepath in concrete_fn.graph.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    ]
    tf.saved_model.save(module, saved_model_dir)
    return concrete_fn
Example #10
@contextlib.contextmanager
def record_variable_creation_scope():
    """Creates a single-use contextmanager for capturing variable creation calls."""
    variable_list = []

    def logging_variable_creator(next_creator, **kwargs):
        variable = next_creator(**kwargs)
        variable_list.append(variable)
        return variable

    with contextlib.ExitStack() as stack:
        stack.enter_context(
            tf.variable_creator_scope(logging_variable_creator))
        yield variable_list
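A hedged usage sketch for the context manager above (assuming it is decorated with `@contextlib.contextmanager`, as shown; the exact variable names printed will vary):

import tensorflow as tf

with record_variable_creation_scope() as captured:
    layer = tf.keras.layers.Dense(4)
    layer.build((None, 2))  # triggers creation of the kernel and bias variables

print([v.name for v in captured])  # e.g. ['dense/kernel:0', 'dense/bias:0']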
Example #11
    def __call__(self,
                 data: tf.Tensor,
                 probe: bool = True,
                 aug_key: str = 'data') -> dict:
        """
        When training on labeled data, use `probe=True` to update the CTA weights.

        When training on unlabeled data, use `probe=False, aug_key='aug_data'` to augment the data.

        Args:
            data (tf.Tensor): input data.
            probe (bool, optional): Defaults to True.
            aug_key (str, optional): Defaults to 'data'.

        Returns:
            dict: data_dict
        """
        # Create a local variable that will store a copy of the CTA log probabilities.
        with tf.variable_creator_scope(_skip_mirrored_creator):
            local_log_prob = tf.Variable(
                lambda: tf.ones(self.state_shape, dtype=tf.float32),
                trainable=False,
                name='cta_log_probs')
        self.log_probs.append(local_log_prob)

        output_dict = {}
        if probe:
            # Sample [num_layers] op_indices and op_args uniformly.
            probe_op_indices, probe_op_args = self._sample_ops_uniformly()
            probe_data = self._apply_ops(data, probe_op_indices, probe_op_args)
            output_dict['probe_op_indices'] = probe_op_indices
            output_dict['probe_op_args'] = probe_op_args
            output_dict['probe_data'] = probe_data

        if aug_key is not None:
            op_indices, op_args = self._sample_ops(local_log_prob)
            aug_data = self._apply_ops(data,
                                       op_indices,
                                       op_args,
                                       prob_to_apply=self.prob_to_apply)
            output_dict[aug_key] = aug_data

        if aug_key != 'data':
            output_dict['data'] = data

        return output_dict
Example #12
    def test_keras_layer_add_weight(self):
        class Layer(base_layer.Layer):
            def __init__(self):
                super().__init__()
                self.w = self.add_weight(
                    shape=(2, ),
                    initializer=lambda shape, dtype: [0, 1],
                    trainable=True)
                self.b = self.add_weight(
                    shape=(2, ),
                    initializer=lambda shape, dtype: [2, 3],
                    trainable=False)

        def sharded_variable_creator(next_creator, **kwargs):
            v1_value = kwargs['initial_value']()[0:1]
            v2_value = kwargs['initial_value']()[1:]

            kwargs['initial_value'] = v1_value
            kwargs['shape'] = (1, )
            v1 = next_creator(**kwargs)

            kwargs['initial_value'] = v2_value
            kwargs['shape'] = (1, )
            v2 = next_creator(**kwargs)

            return sharded_variable.ShardedVariable([v1, v2])

        with tf.variable_creator_scope(sharded_variable_creator):
            layer = Layer()

        self.assertLen(layer.trainable_weights, 2)
        self.assertEqual(layer.trainable_weights[0], [0])
        self.assertEqual(layer.trainable_weights[1], [1])
        self.assertLen(layer.non_trainable_weights, 2)
        self.assertEqual(layer.non_trainable_weights[0], [2])
        self.assertEqual(layer.non_trainable_weights[1], [3])
        self.assertAllEqual(
            layer.weights,
            layer.trainable_weights + layer.non_trainable_weights)
        self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
        self.assertAllEqual(layer.weights, layer.variables)

        checkpoint_deps = set(dep.ref
                              for dep in layer._checkpoint_dependencies)
        self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))
Example #13
def network_initializer(self):
    """Initializes the convolution network."""
    with tf.variable_scope('convolution_network') as scope:
        # Truncated in the original example: `output = self.`
        pass
Example #14
def initialize():
    with tf.variable_creator_scope(variable_utils.create_tensor_variable):
        return tf.nest.map_structure(lambda m: m.variables,
                                     metrics_constructor())
Example #15
def finalize(state):
    with tf.variable_creator_scope(
            build_replace_variable_with_parameter_creator(state)):
        return tf.nest.map_structure(lambda metric: metric.result(),
                                     metrics_constructor())
Example #16
def trace_and_update_module(
    module: tf.Module, tf_function: function.Function, name: str,
    strip_control_dependencies: bool) -> function.ConcreteFunction:
  """Traces `tf_function` and saves under attr `name` of `module`.

  Args:
    module: A saveable module which will contain the traced `tf_function` under
      attr `name`.
    tf_function: A tf.function to trace.
    name: A name to save the traced `tf_function` to.
    strip_control_dependencies: Boolean. If True, automatic control dependencies
      will be stripped from the outputs of `tf_function`. This should almost
      always be False. It is useful only if you want to use the structure of the
      TF graph to perform any graph manipulations.

  Returns:
    The concrete function obtained from tracing `tf_function`.
  """
  resource_tracker = tracking.ResourceTracker()
  object_tracker = annotators.ObjectTracker()
  created_variables = []

  def _variable_creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # Trace `tf_function` to gather any resources in it using the
  # resource_tracker. These are then assigned to `module.resources` and tracked
  # before exporting to SavedModel.
  with tracking.resource_tracker_scope(resource_tracker), \
       annotators.object_tracker_scope(object_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_fn = tf_function.get_concrete_function()

  # Prior to 2020/10/08, saving a tf.function with a concrete function signature
  # would ensure that the function was not re-traced in a round-trip to a
  # SavedModel. Since this is no longer the case, we save the concrete function
  # directly.
  if tf.compat.forward_compatible(2020, 10, 8):
    pruned_function = optimize_concrete_function(concrete_fn,
                                                 strip_control_dependencies)
    module.pruned_variables = pruned_function.variables
    setattr(module, name, pruned_function)
  else:
    setattr(module, name, tf_function)

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  module.trackable_objects = object_tracker.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  initializers = []
  for resource in module.resources:
    if isinstance(resource, lookup_ops.InitializableLookupTableBase):
      initializers.append(resource._initializer)  # pylint: disable=protected-access
  module.initializers = initializers
  module.assets = [
      common_types.Asset(asset_filepath) for asset_filepath in
      concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  return concrete_fn
Example #17
def trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn,
                                   input_signature, base_temp_dir,
                                   tensor_replacement_map,
                                   output_keys_to_name_map):
    """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

  The SavedModel written contains a method called `transform_fn` that
  represents the traced `preprocessing_fn`. Additionally, if this is the final
  SavedModel being written out, it will contain a method called `metadata_fn`
  that provides deferred schema annotations.

  Args:
    saved_model_dir: Path to write SavedModel to.
    preprocessing_fn: A user defined python function to be traced.
    input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`.
    base_temp_dir: Base path to write temporary artifacts to.
    tensor_replacement_map: A map from placeholder tensor names to their
      evaluated replacement tensors.
    output_keys_to_name_map: A map from output dictionary keys to the names of
      the tensors that they represent.

  Returns:
    A tuple containing a pair of `tf.ConcreteFunction`s:
      1. The traced preprocessing_fn.
      2. A metadata_fn that returns a dictionary containing the deferred
      annotations added to the graph when invoked with any valid input.
  """

    module = tf.Module()
    transform_fn = get_traced_transform_fn(
        preprocessing_fn,
        input_signature,
        base_temp_dir,
        tensor_replacement_map=tensor_replacement_map,
        output_keys_to_name_map=output_keys_to_name_map)
    metadata_fn = None

    resource_tracker = tracking.ResourceTracker()
    created_variables = []

    def _variable_creator(next_creator, **kwargs):
        var = next_creator(**kwargs)
        created_variables.append(var)
        return var

    # TODO(b/164921571): Handle generic Trackable objects.
    # Trace the `transform_fn` and `metadata_fn` to gather any resources in it
    # using the resource_tracker. These are then assigned to `module.resources`
    # and tracked before exporting to SavedModel.
    with tracking.resource_tracker_scope(
            resource_tracker), tf.variable_creator_scope(_variable_creator):
        concrete_transform_fn = transform_fn.get_concrete_function()
        concrete_metadata_fn = None
        # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers
        # in the `preprocessing_fn` have already been evaluated.
        if not concrete_transform_fn.graph.get_collection(
                analyzer_nodes.TENSOR_REPLACEMENTS):
            metadata_fn = schema_inference.get_traced_metadata_fn(
                tensor_replacement_map,
                preprocessing_fn,
                input_signature,
                base_temp_dir,
                evaluate_schema_overrides=True)
            concrete_metadata_fn = metadata_fn.get_concrete_function()

    # Save ConcreteFunction when possible since the above workaround won't work if
    # the tf.function is retraced.
    if tf.compat.forward_compatible(2020, 10, 8):
        module.transform_fn = concrete_transform_fn
        module.metadata_fn = concrete_metadata_fn
    else:
        module.transform_fn = transform_fn
        module.metadata_fn = metadata_fn

    # Any variables created need to be explicitly tracked.
    module.created_variables = created_variables
    # Resources need to be explicitly tracked.
    module.resources = resource_tracker.resources
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    initializers = []
    for resource in module.resources:
        if isinstance(resource, lookup_ops.InitializableLookupTableBase):
            initializers.append(resource._initializer)  # pylint: disable=protected-access
    module.initializers = initializers
    module.assets = [
        common_types.Asset(asset_filepath)
        for asset_filepath in concrete_transform_fn.graph.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    ]
    tf.saved_model.save(module, saved_model_dir)
    return concrete_transform_fn, concrete_metadata_fn
Example #18
def define_vggish(waveform):
    with tf.variable_creator_scope(var_tracker):
        features = waveform_to_features(waveform)
        return vggish_slim.define_vggish_slim(features, training=False)

    def __init__(self,
                 batch_size=64,
                 input_space=28 * 28,
                 latent_space=10,
                 p=3.,
                 middle_layers=None,
                 activation_fn=tf.nn.tanh,
                 learning_rate=0.001,
                 l2_lambda=0.001,
                 initializer_fn=he_initializer):
        self.batch_size = batch_size
        self.input_space = input_space
        self.latent_space = latent_space
        self.p = p
        self.middle_layers = middle_layers or [1024, 1024]
        self.activation_fn = activation_fn
        self.learning_rate = learning_rate
        self.initializer_fn = initializer_fn

        tf.reset_default_graph()

        self.input_x = tf.placeholder(tf.float32, [None, input_space])
        self.z_tensor = tf.placeholder(tf.float32, [None, latent_space])

        with tf.variable_scope('encoder'):
            self._encoder()
        self.encoded = self.encoder_layers[-1]

        with tf.variable_scope('decoder'):
            self.decoder_layers = self._decoder(self.encoded)
            self.decoded = self.decoder_layers[-1]
            tf.get_variable_scope().reuse_variables()
            self.generator_layers = self._decoder(self.z_tensor)
            self.generated = tf.nn.sigmoid(self.generator_layers[-1],
                                           name='generated')
        sizes = [64, 64, 1]
        with tf.variable_scope('discriminator'):
            self.disc_layers_neg = self._discriminator(self.encoded, sizes)
            self.disc_neg = self.disc_layers_neg[-1]
            tf.get_variable_scope().reuse_variables()
            self.disc_layers_pos = self._discriminator(self.z_tensor, sizes)
            self.disc_pos = self.disc_layers_pos[-1]
        self.pos_loss = tf.nn.relu(self.disc_pos) - self.disc_pos + tf.log(
            1.0 + tf.exp(-tf.abs(self.disc_pos)))
        self.neg_loss = tf.nn.relu(self.disc_neg) - self.disc_neg + tf.log(
            1.0 + tf.exp(-tf.abs(self.disc_neg)))
        self.disc_loss = tf.reduce_mean(tf.add(self.pos_loss, self.neg_loss))
        self.enc_loss = tf.reduce_mean(
            tf.subtract(self.neg_loss, self.disc_neg))
        batch_logloss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.decoded,
                                                    labels=self.input_x), 1)
        self.ae_loss = tf.reduce_mean(batch_logloss)
        disc_ws = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope='discriminator')
        ae_ws = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='encoder') + tf.get_collection(
                                      tf.GraphKeys.GLOBAL_VARIABLES,
                                      scope='decoder')
        self.l2_loss = tf.multiply(
            tf.reduce_sum([tf.nn.l2_loss(ws) for ws in ae_ws]), l2_lambda)
        self.gen_loss = tf.add(tf.add(self.enc_loss, self.ae_loss),
                               self.l2_loss)
        with tf.variable_scope('optimizer'):
            self.train_discriminator = tf.train.RMSPropOptimizer(
                self.learning_rate).minimize(self.disc_loss, var_list=disc_ws)
            self.train_generator = tf.train.RMSPropOptimizer(
                self.learning_rate).minimize(self.gen_loss, var_list=ae_ws)
        self.sess = tf.Session()
        init = tf.global_variables_initializer()
        self.sess.run(init)