Example #1
  def _check_same_graph(self):
    """Checks that the module is not being connect to multiple Graphs.

    An instance of a Sonnet module 'owns' the variables it contains, and permits
    seamless variable sharing. As such, connecting a single module instance to
    multiple Graphs is not possible - this function will raise an error should
    that occur.

    Raises:
      DifferentGraphError: if the module is connected to a different Graph than
        it was previously used in.
    """
    with tf.init_scope():
      # We need `init_scope` in case we're running inside a defun. In that
      # case what we want is information about where the function will be
      # called, not where the function is being built.
      current_graph = tf.get_default_graph()
      will_call_in_eager_context = tf.executing_eagerly()

    if self._graph is None:
      self._graph = current_graph
      self._set_module_info()

    if not will_call_in_eager_context:
      # Same graph checks only make sense when calling from graph mode (in eager
      # mode there is a single process level context where all modules are
      # created).
      if self._graph != current_graph:
        raise DifferentGraphError("Cannot connect module to multiple Graphs.")
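The check above is easiest to see in action by connecting one module instance to two separate Graphs. The following is a minimal sketch, assuming a TF1 environment with Sonnet 1.x installed; the module and shapes are illustrative, not taken from the example above.

```python
import sonnet as snt
import tensorflow as tf  # TF1

linear = snt.Linear(output_size=4)   # one module instance owns its variables

with tf.Graph().as_default():
    linear(tf.zeros([1, 2]))         # first connection records this Graph

with tf.Graph().as_default():
    linear(tf.zeros([1, 2]))         # second Graph: _check_same_graph raises
                                     # DifferentGraphError
```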
Example #2
def compute_sgd_updates(self, grads_and_vars):
    """Constructs and returns tensors of SGD-updated parameters."""
    grads_and_vars = tuple(grads_and_vars)
    var_list = [v for _, v in grads_and_vars]

    # Calculate SGD updates.
    with tf.name_scope(self._name):
        # Initialize.
        with tf.init_scope():
            self._create_slots(var_list)
        self._prepare()

        # Compute updates for each variable.
        var_updates = []
        for grad, var in grads_and_vars:
            if grad is None:
                var_updates.append(None)
                continue

            lr = tf.cast(self._learning_rate_tensor, var.dtype.base_dtype)
            var_updates.append(lr * grad)

        # Compute updated variables.
        updated_vars = []
        for var, var_update in zip(var_list, var_updates):
            assert var_updates is not None
            updated_vars.append((
                var - var_update) if var_update is not None else var)

    return updated_vars
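For reference, the arithmetic the loop above performs per variable is just the plain SGD step: the update is `lr * grad`, and the returned tensor is `var - update` (or `var` unchanged when the gradient is `None`). A minimal NumPy sketch with made-up values:

```python
import numpy as np

var = np.array([1.0, 2.0])
grad = np.array([0.5, -0.5])
lr = 0.1

update = lr * grad            # SGD update for this variable
updated_var = var - update
print(updated_var)            # [0.95 2.05]
```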
Example #3
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    with tf.init_scope():
      self._create_slots([v for (_, v) in grads_and_vars])

    accums = []
    variables = []

    for g, v in grads_and_vars:
      accum = self.get_slot(v, 'grad_accum')
      variables.append(v)
      if isinstance(g, tf.IndexedSlices):
        scaled_grad = tf.IndexedSlices(
            g.values / self._grad_steps, g.indices, dense_shape=g.dense_shape)
        accums.append(accum.assign_add(scaled_grad))  # pytype: disable=attribute-error
      else:
        accums.append(accum.assign_add(g / self._grad_steps))  # pytype: disable=attribute-error

    def _apply_and_zero():
      apply_op = self._opt.apply_gradients(list(zip(accums, variables)))
      with tf.control_dependencies([apply_op]):
        zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
      return tf.group(zero_op, tf.assign_add(self._counter, 1))

    def _accum():
      return tf.group(accums)

    accum_step = tf.cond(
        tf.equal(tf.mod(global_step, self._grad_steps), self._grad_steps - 1),
        _apply_and_zero, _accum)

    with tf.control_dependencies([accum_step]):
      global_step = tf.assign_add(global_step, 1)
      return tf.group(global_step)
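The `tf.cond` above means gradients are folded into the accumulators on every step, but they are only applied (and the accumulators zeroed) once every `_grad_steps` steps. A purely illustrative sketch of that cadence, with a made-up `grad_steps` of 4:

```python
grad_steps = 4
for global_step in range(8):
    if global_step % grad_steps == grad_steps - 1:
        action = "apply accumulated grads, then zero accumulators"
    else:
        action = "accumulate grad / grad_steps"
    print(global_step, action)
# Steps 0-2 and 4-6 only accumulate; steps 3 and 7 apply and reset.
```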
Example #4
    def _generate(self, feature_map_shape_list):
        """Generates a collection of bounding boxes to be used as anchors.

    Args:
      feature_map_shape_list: list of pairs of convnet layer resolutions in the
        format [(height_0, width_0)].  For example, setting
        feature_map_shape_list=[(8, 8)] asks for anchors that correspond
        to an 8x8 layer.  For this anchor generator, only lists of length 1 are
        allowed.

    Returns:
      boxes_list: a list of BoxLists each holding anchor boxes corresponding to
        the input feature map shapes.

    Raises:
      ValueError: if feature_map_shape_list, box_specs_list do not have the same
        length.
      ValueError: if feature_map_shape_list does not consist of pairs of
        integers
    """
        if not (isinstance(feature_map_shape_list, list)
                and len(feature_map_shape_list) == 1):
            raise ValueError(
                'feature_map_shape_list must be a list of length 1.')
        if not all([
                isinstance(list_item, tuple) and len(list_item) == 2
                for list_item in feature_map_shape_list
        ]):
            raise ValueError('feature_map_shape_list must be a list of pairs.')

        # Create constants in init_scope so they can be created in tf.functions
        # and accessed from outside of the function.
        with tf.init_scope():
            self._base_anchor_size = tf.cast(tf.convert_to_tensor(
                self._base_anchor_size),
                                             dtype=tf.float32)
            self._anchor_stride = tf.cast(tf.convert_to_tensor(
                self._anchor_stride),
                                          dtype=tf.float32)
            self._anchor_offset = tf.cast(tf.convert_to_tensor(
                self._anchor_offset),
                                          dtype=tf.float32)

        grid_height, grid_width = feature_map_shape_list[0]
        scales_grid, aspect_ratios_grid = ops.meshgrid(self._scales,
                                                       self._aspect_ratios)
        scales_grid = tf.reshape(scales_grid, [-1])
        aspect_ratios_grid = tf.reshape(aspect_ratios_grid, [-1])
        anchors = tile_anchors(grid_height, grid_width, scales_grid,
                               aspect_ratios_grid, self._base_anchor_size,
                               self._anchor_stride, self._anchor_offset)

        num_anchors = anchors.num_boxes_static()
        if num_anchors is None:
            num_anchors = anchors.num_boxes()
        anchor_indices = tf.zeros([num_anchors])
        anchors.add_field('feature_map_index', anchor_indices)
        #print(anchors)
        return [anchors]
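The comment about `init_scope` is the key point of this example: tensors created under `tf.init_scope` inside a `tf.function` are lifted out of the traced graph, so they remain usable after tracing. A minimal TF2 sketch (the names here are illustrative, not from the anchor generator above):

```python
import tensorflow as tf

created = {}

@tf.function
def build():
    with tf.init_scope():
        # Created in the outermost (eager) context, not in the function graph.
        created["base_anchor_size"] = tf.constant([256.0, 256.0])
    return tf.zeros([1])

build()
print(created["base_anchor_size"].numpy())  # accessible outside the function
```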
Example #5
def _get_beta_accumulators(self):
    with tf.init_scope():
        if tf.executing_eagerly():
            graph = None
        else:
            graph = tf.get_default_graph()
        return (self._get_non_slot_variable("beta1_power", graph=graph),
                self._get_non_slot_variable("beta2_power", graph=graph))
Example #6
    def _create_optimizer(self):
        """Initializes the hyperparameters and sets the self._optimizer property."""
        if self._optimizer:
            return
        if not self._layer_collection:
            self.register_layers(self._model, self._loss)

        if self._config['adapt_damping']:
            if 'train_batch' not in self._kfac_kwargs:
                raise ValueError(
                    'Must provide a train_batch tuple to use adaptive '
                    'damping. Use register_train_batch or pass it in '
                    'during optimizer construction.')
            if 'loss_fn' not in self._kfac_kwargs:
                self._kfac_kwargs['loss_fn'] = utils.get_loss_fn(
                    self._model,
                    self._loss,
                    loss_weights=self._config['loss_weights'])

        with tf.name_scope(self._name):
            with tf.init_scope():
                # "iterations" property will create iterations if necessary.
                _ = self.iterations
                self._create_hypers()

        self._kfac_kwargs.update(self._hyper)
        try:
            # We use the TF 1 variable_scope instead of the TF 2 recommended
            # name_scope because we need to recover the variables created in this
            # scope, which is not possible with name_scope.
            with tf.variable_scope(self._tf_var_scope):
                self._optimizer = _KFAC_OPT_CLASS(
                    layer_collection=self._layer_collection,
                    **self._kfac_kwargs)
        except ValueError as e:
            msg = str(e)
            if re.search('Variable .* already exists', msg):
                raise ValueError(
                    'You may have instantiated a KFAC Optimizer with the same name as '
                    'an existing one. Try resetting the default graph, instantiating '
                    'the optimizer with a different name, or changing the optimizer\'s '
                    'name.\nHere is the original ValueError:\n ' + msg)
            elif re.search(
                    'Found the following errors with variable registration'
                    '.*gamma.*registered with wrong number of uses.*', msg):
                # We don't regex the name batch_normalization because the user could
                # have renamed the layer. We don't regex beta because they could have
                # used BatchNorm without the shift.
                raise ValueError(
                    'There may have been an issue registering BatchNormalization. Try '
                    'using tf.keras.backend.set_learning_phase before model '
                    'construction. An alternative solution is to use the unfused '
                    'batchnorm implementation (pass the argument fused=False to '
                    'BatchNormalization).\nHere is the original ValueError:\n '
                    + msg)
            else:
                raise e
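The comment above about preferring `tf.variable_scope` over `name_scope` deserves a concrete illustration: variables created under a variable scope carry its name as a prefix, so they can be recovered later from the global collection, which a pure name scope does not allow. A minimal TF1-style sketch using `tf.compat.v1` (the scope and variable names are made up):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.variable_scope("kfac_opt_scope"):
    tf.get_variable("damping", shape=[], initializer=tf.zeros_initializer())

# A name_scope would not prefix variable names, so this lookup would fail.
recovered = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope="kfac_opt_scope")
print([v.name for v in recovered])  # ['kfac_opt_scope/damping:0']
```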
Example #7
    def getOrCreateSparseLinear(self,
                                x_shape,
                                x_dtype,
                                sparsity,
                                dense_length,
                                block_size,
                                use_bias,
                                override_partials_type=None):
        x_dtype = tf.as_dtype(x_dtype)

        # Each layer should have a unique scope name
        scope_name = tf.get_default_graph().get_name_scope()
        logger.info(f"Sparse layer with scope name: {scope_name}")

        # Construct the layer if it does not exist
        if scope_name not in self.sparse_layers:
            layer_matmul_options = self.sparse_matmul_options
            if override_partials_type:
                layer_matmul_options['partialsType'] = override_partials_type
            limit = np.sqrt(6 / ((x_shape[-1] + dense_length) *
                                 (1 - sparsity)))
            uniform_gen = partial(self.random.uniform, -limit, limit)
            indices_random_gen = np.random.default_rng(seed=self.random_seed)
            sparse_layer = layers.SparseFcLayer.from_random_generator(
                dense_length,
                x_shape,
                1 - sparsity,
                block_size=block_size,
                values_initialiser_gen=uniform_gen,
                indices_initialiser_gen=indices_random_gen,
                name="sparse_layer",
                dtype=x_dtype,
                matmul_options=layer_matmul_options,
                use_bias=use_bias,
                relu=False,
                disable_updating=self.disable_updating,
                pooling_type=self.pooling_type)

            # Create placeholders on the host, outside XLA
            with tf.init_scope():  # escapes XLA
                with tf.device("cpu"):
                    sparse_layer.create_placeholders()
                self.sparse_layers[scope_name] = sparse_layer
        else:
            # Re-use a previously defined layer
            sparse_layer = self.sparse_layers[scope_name]

        return sparse_layer
Example #8
def get_global_variables_safely():
  """If not executing eagerly, returns tf.global_variables().

  Raises a ValueError if eager execution is enabled,
  because the variables are not tracked when executing eagerly.

  If executing eagerly, use a Keras model's .variables property instead.

  Returns:
    The result of tf.global_variables()
  """
  with tf.init_scope():
    if tf.executing_eagerly():
      raise ValueError("Global variables collection is not tracked when "
                       "executing eagerly. Use a Keras model's `.variables` "
                       "attribute instead.")
  return tf.global_variables()
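A minimal usage sketch, assuming a TF1-style graph-mode setup via `tf.compat.v1` (the variable name is made up): in graph mode the helper simply returns the global collection, while calling it under eager execution raises the `ValueError` above.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
_ = tf.get_variable("w", shape=[2, 2])

print(get_global_variables_safely())  # [<tf.Variable 'w:0' shape=(2, 2) ...>]
```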
Example #9
    def sparseLinear(self, x, sparsity, dense_length, compute_dense_grad,
                     use_bias):
        # The underlying API requires a 2D tensor, so collapse the batch dimensions
        *batch_dimensions, input_length = x.shape.as_list()
        x = tf.reshape(x, [-1, input_length])
        x_shape = x.shape.with_rank(2).as_list()

        # Each layer should have a unique scope name
        scope_name = tf.get_default_graph().get_name_scope()
        logger.info(f"Sparse layer with scope name: {scope_name}")

        # Construct the layer if it does not exist
        if scope_name not in self.sparse_layers:
            limit = np.sqrt(6 / ((x_shape[-1] + dense_length) *
                                 (1 - sparsity)))
            uniform_gen = partial(self.random.uniform, -limit, limit)
            indices_random_gen = np.random.default_rng(seed=self.random_seed)
            sparse_layer = layers.SparseFcLayer.from_random_generator(
                dense_length,
                x_shape,
                1 - sparsity,
                uniform_gen,
                indices_random_gen,
                name="sparse_layer",
                dtype=x.dtype,
                matmul_options=self.sparse_matmul_options,
                bias=use_bias,
                relu=False,
                disable_updating=self.disable_updating)

            # Create placeholders on the host, outside XLA
            with tf.init_scope():  # escapes XLA
                with tf.device("cpu"):
                    sparse_layer.create_placeholders()
                self.sparse_layers[scope_name] = sparse_layer
        else:
            # Re-use a previously defined layer
            sparse_layer = self.sparse_layers[scope_name]

        # Call the layer with the provided input
        x = sparse_layer(x, compute_dense_grad)

        # Recover the original batch dimensions
        x = tf.reshape(x, batch_dimensions + [dense_length])
        return x
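The batch-flattening trick at the top of this example (collapse all leading dimensions to 2D, run the op, then restore them) can be seen in isolation in the sketch below; the zeros tensor stands in for the sparse matmul, which would need the full layer setup.

```python
import tensorflow as tf

x = tf.zeros([3, 5, 7])                         # arbitrary leading batch dims
*batch_dims, input_len = x.shape.as_list()
flat = tf.reshape(x, [-1, input_len])           # shape [15, 7]

dense_length = 4
out = tf.zeros([flat.shape[0], dense_length])   # stand-in for sparse_layer(flat)

restored = tf.reshape(out, batch_dims + [dense_length])
print(restored.shape)                           # (3, 5, 4)
```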
Example #10
    def record_slot_var(self, slot_name: str, optimizer: tf.train.Optimizer):
        """
        Used by the optimiser to record a slot with this layer.
        Returns the TensorFlow slot variable that was recorded.
        """
        if self.values_var is None:
            raise AttributeError(
                f"This sparse layer '{self.name}' is being asked to record a "
                "slot variable but it has not yet been called! "
                "Make sure you call this layer upstream of the loss op or "
                "remove it from the sparse_layers list.")

        slot_var = optimizer.get_slot(self.values_var, slot_name)
        if slot_var is None:
            raise ValueError(
                f"This sparse layer '{self.name}' is being asked to record "
                f"a slot variable for '{self.values_var.name}' but no such "
                "slot exists! Make sure the loss op is actually dependent "
                "on this layer or remove it from the sparse_layers list.")

        # Only read slot_var attributes once we know the slot exists.
        internal_name = slot_var.name
        logger.debug(f"Recording slot variable {internal_name}")

        if slot_var.shape != self.weights.get_values().shape:
            raise ValueError(
                f"Shape mismatch between variable {slot_var.shape} "
                f"and slot {self.weights.get_values().shape}")

        with tf.init_scope():  # escapes XLA, so placeholders can be created
            with tf.device("cpu"):
                placeholder = tf.placeholder(dtype=slot_var.dtype,
                                             shape=slot_var.shape)

        self.sparse_slots[internal_name] = SparseSlot(
            placeholder=placeholder,
            tf_variable=slot_var,
            np_variable=np.zeros_like(self.weights.get_values()))
        return slot_var
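For context, the `optimizer.get_slot` call this method depends on behaves as in the minimal TF1-style sketch below (`tf.compat.v1`, with a made-up variable and a `MomentumOptimizer` whose slot is named "momentum"): once the optimizer has created its slots, `get_slot` returns a variable with the same shape as the trained variable, or `None` if no such slot exists.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = tf.get_variable("values", shape=[4], initializer=tf.zeros_initializer())
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
train_op = opt.minimize(tf.reduce_sum(v * v), var_list=[v])

slot = opt.get_slot(v, "momentum")
print(slot.shape == v.shape)            # True
print(opt.get_slot(v, "no_such_slot"))  # None
```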
Example #11
    def sparseLinear(x, dense_length, opts):
        x_shape = x.shape.with_rank(2).as_list()
        limit = np.sqrt(6 / ((x_shape[-1] + dense_length) * opts.density))
        uniform_gen = partial(np.random.uniform, -limit, limit)
        indices_random_gen = np.random.default_rng(seed=0)

        sparse_layer = layers.SparseFcLayer.from_random_generator(
            hidden_size=dense_length,
            input_shape=x_shape,
            density=opts.density,
            block_size=1,
            values_initialiser_gen=uniform_gen,
            indices_initialiser_gen=indices_random_gen,
            name="sparse_layer",
            dtype=x.dtype,
            matmul_options=opts.sparse_matmul_options,
            use_bias=opts.use_bias,
            relu=True,
            disable_updating=opts.disable_updating,
            pooling_type="NONE")

        # Create placeholders on the host, outside XLA
        with tf.init_scope():  # escapes XLA
            with tf.device("cpu"):
                sparse_layer.create_placeholders()
        sparse_layers.append(sparse_layer)

        if use_ipu_function:
            @ipu.outlined_function
            def f(x):
                # Call the layer with the provided input
                x = sparse_layer(x, opts.compute_dense_grad)
                return x
            return f(x)
        else:
            return sparse_layer(x, opts.compute_dense_grad)
Example #12
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):

    # Emulate a global step inside the optimizer: we increment it on every
    # "apply_gradients" call and use it for warmup ramping and weight decay.
    internal_global_step_name = self._get_variable_name("global_step_tf")
    with tf.init_scope():
      internal_global_step = self._get_or_make_slot(
          global_step, tf.constant(self.initial_step, dtype=tf.float32),
          internal_global_step_name, internal_global_step_name)

    print('self.initial_step = ' + str(self.initial_step))

    learning_rate = tf.train.polynomial_decay(
        self.learning_rate,
        internal_global_step,
        self.num_train_steps,
        end_learning_rate=0.0,
        power=self.lr_decay_power,
        cycle=False)

    warmup_steps = max(self.num_train_steps * self.warmup_proportion, self.warmup_steps)

    '''
    if layerwise_lr_decay_power > 0:
      learning_rate = _get_layer_lrs(learning_rate, layerwise_lr_decay_power,
                                    n_transformer_layers)
    '''

    internal_global_step = self.get_slot(global_step, internal_global_step_name)
    # global_step_print = tf.Print(internal_global_step, ['internal_global_step', tf.shape(internal_global_step), internal_global_step], summarize=32)
    global_step_update_op = internal_global_step.assign(internal_global_step + 1)

    print('warmup_steps = ' + str(warmup_steps))
    print('num_train_steps = ' + str(self.num_train_steps))

    learning_rate *= tf.minimum(
        1.0, tf.cast(internal_global_step, tf.float32) / tf.cast(warmup_steps, tf.float32))
    #lr_print = tf.Print(learning_rate, ['learning_rate', tf.shape(learning_rate), learning_rate], summarize=32)

    lr_update_op = self.learning_rate_tensor.assign(learning_rate)

    # Clip the gradient to be at most 1.0 (from original BERT implementation)
    grads, tvars = zip(*grads_and_vars)
    (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    grads_and_vars = list(zip(clipped_grads, tvars))

    if isinstance(learning_rate, dict):
      key_to_grads_and_vars = {}
      for grad, var in grads_and_vars:
        update_for_var = False
        for key in self.learning_rate:
          if key in var.name:
            update_for_var = True
            if key not in key_to_grads_and_vars:
              key_to_grads_and_vars[key] = []
            key_to_grads_and_vars[key].append((grad, var))
        if not update_for_var:
          raise ValueError("No learning rate specified for variable", var)
      assignments = []
      for key, key_grads_and_vars in key_to_grads_and_vars.items():
        assignments += self._apply_gradients(key_grads_and_vars,
                                             learning_rate[key])
    else:
      assignments = self._apply_gradients(grads_and_vars, learning_rate)
    return tf.group([*assignments, global_step_update_op, lr_update_op], name=name)
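The warmup applied above scales the polynomially decayed learning rate by `min(1, step / warmup_steps)`, i.e. a linear ramp from zero that becomes a no-op once the step count passes `warmup_steps`. A tiny sketch with made-up numbers:

```python
def warmup_scale(step, warmup_steps):
    # Linear ramp during warmup, 1.0 afterwards.
    return min(1.0, step / warmup_steps)

print([round(warmup_scale(s, 100), 2) for s in (0, 25, 50, 100, 200)])
# [0.0, 0.25, 0.5, 1.0, 1.0]
```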
Example #13
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        grad_list = []
        var_list = []
        for g, v in grads_and_vars:
            grad_list.append(g)
            var_list.append(v)
        with tf.init_scope():
            self._create_slots(var_list)

        # accumulate gradients
        accums = []
        for g, v in zip(grad_list, var_list):
            accum = self.get_slot(v, 'grad_accum')
            # pytype: disable=attribute-error
            if isinstance(g, tf.IndexedSlices):
                scaled_grad = tf.IndexedSlices(g.values / self._grad_steps,
                                               g.indices,
                                               dense_shape=g.dense_shape)
                accums.append(
                    accum.assign(
                        self._sharding(accum.read_value()) + scaled_grad))
            else:
                accums.append(
                    accum.assign(
                        self._sharding(accum.read_value()) +
                        g / self._grad_steps))
            # pytype: enable=attribute-error

        if self._use_tpu:

            def _apply_and_zero_tpu2():
                normalized_accums = accums
                if self._apply_crs_to_grad:
                    normalized_accums = [
                        tf.tpu.cross_replica_sum(accum.read_value())
                        for accum in accums
                    ]
                apply_op = self._opt.apply_gradients(
                    list(zip(normalized_accums, var_list)))
                with tf.control_dependencies([apply_op]):
                    zero_op = [
                        tf.assign(accum, tf.zeros_like(accum))
                        for accum in accums
                    ]
                return tf.group(zero_op, tf.assign_add(global_step, 1))

            def _accum_tpu2():
                return tf.group(tf.no_op(), tf.assign_add(global_step, 1))

            accum_step = tf.cond(
                tf.equal(tf.mod(self._counter, self._grad_steps),
                         self._grad_steps - 1), _apply_and_zero_tpu2,
                _accum_tpu2)

            with tf.control_dependencies([tf.group(accums)]):
                return tf.group(accum_step, tf.assign_add(self._counter, 1))

        # for GPUs, use merge_call outside tf.cond to avoid issues
        with tf.control_dependencies([tf.group(accums)]):
            merge_return = tf.distribute.get_replica_context().merge_call(
                self._maybe_apply_grads_and_zero,
                args=(global_step, accums, var_list))

        return merge_return
Example #14
    def call_method(method, obj, args, kwargs):
        """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(None, default_name=name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because _build is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule("reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      method: The method to wrap.
      obj: The object instance passed to the wrapped method.
      args: The positional arguments (Tensors) passed to the wrapped method.
      kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """

        # If @reuse_variables is combined with @property, obj is passed in args
        # and method is still unbound at this stage.
        if obj is None:
            obj = args[0]

        def default_context_manager(reuse=None):
            variable_scope = obj.variable_scope
            return tf.variable_scope(variable_scope, reuse=reuse)

        variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                                 default_context_manager)

        with tf.init_scope():
            # We need `init_scope` in case we're running inside a defun. In
            # that case what we want is information about where the function
            # will be called, not where the function is being built.
            graph = tf.get_default_graph()
            will_call_in_eager_context = tf.executing_eagerly()

        if will_call_in_eager_context:
            initialized_variable_scopes = initialized_variable_scopes_eager
        else:
            if graph not in initialized_variable_scopes_graph:
                initialized_variable_scopes_graph[graph] = set()
            initialized_variable_scopes = initialized_variable_scopes_graph[
                graph]

        # Temporarily enter the variable scope to capture it
        with variable_scope_context_manager() as tmp_variable_scope:
            variable_scope = tmp_variable_scope

        reuse = variable_scope.name in initialized_variable_scopes

        # Enter the pure variable scope with reuse correctly set
        with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
                variable_scope, reuse=reuse) as pure_variable_scope:
            current_name_scope = tf.get_default_graph().get_name_scope()
            # Force tf.name_scope to treat current_name_scope as an "absolute" scope
            # so we can re-enter it.
            if current_name_scope and current_name_scope[-1] != "/":
                current_name_scope += "/"
            with tf.name_scope(current_name_scope):
                module_name = pure_variable_scope.name
                method_name = to_snake_case(method.__name__)
                method_name_scope = "{}/{}".format(module_name, method_name)
                with tf.name_scope(method_name_scope) as scope:
                    if hasattr(obj, "_capture_variables"):
                        with obj._capture_variables():  # pylint: disable=protected-access
                            out_ops = method(*args, **kwargs)
                    else:
                        out_ops = method(*args, **kwargs)
            initialized_variable_scopes.add(pure_variable_scope.name)
            try:
                # If `obj` is a Sonnet module, let it know it's been connected
                # to the TF graph.
                obj._is_connected = True  # pylint: disable=protected-access
                if not tf.executing_eagerly():
                    obj._add_connected_subgraph(  # pylint: disable=protected-access
                        method, out_ops, scope, args, kwargs)
            except AttributeError:
                pass
        return out_ops
Example #15
def _get_beta_accumulators(self):
    with tf.init_scope():
        graph = tf.get_default_graph()
        return (self._get_non_slot_variable("beta1_power", graph=graph),
                self._get_non_slot_variable("beta2_power", graph=graph))
Example #16
def compute_adam_updates(self, grads_and_vars):
    """Constructs and returns tensors of Adam-updated parameters."""
    grads_and_vars = tuple(grads_and_vars)
    var_list = [v for _, v in grads_and_vars]

    # Calculate Adam updates.
    with tf.name_scope(self._name):
        # Initialize.
        with tf.init_scope():
            self._create_slots(var_list)
        self._prepare()

        # Compute updates for each variable.
        var_updates = []
        for grad, var in grads_and_vars:
            if grad is None:
                var_updates.append(None)
                continue

            # Setup.
            beta1_power, beta2_power = self._get_beta_accumulators()
            beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
            beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
            lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
            beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
            beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
            epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
            lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)

            # m_t = beta1 * m + (1 - beta1) * g_t.
            m = self.get_slot(var, "m")
            m_t = tf.assign(m,
                            beta1_t * m + grad * (1 - beta1_t),
                            use_locking=self._use_locking)

            # v_t = beta2 * v + (1 - beta2) * (g_t * g_t).
            v = self.get_slot(var, "v")
            v_t = tf.assign(
                v,
                beta2_t * v + (grad * grad) * (1 - beta2_t),
                use_locking=self._use_locking,
            )

            var_updates.append(lr * m_t / (tf.sqrt(v_t) + epsilon_t))

        # Update power accumulators.
        with tf.control_dependencies(var_updates):
            beta1_power, beta2_power = self._get_beta_accumulators()
            update_beta1 = tf.assign(beta1_power,
                                     beta1_power * self._beta1_t,
                                     use_locking=self._use_locking)
            update_beta2 = tf.assign(beta2_power,
                                     beta2_power * self._beta2_t,
                                     use_locking=self._use_locking)

        # Compute updated variables.
        updated_vars = []
        with tf.control_dependencies([update_beta1, update_beta2]):
            for var, var_update in zip(var_list, var_updates):
                assert var_updates is not None
                updated_vars.append((
                    var - var_update) if var_update is not None else var)

    return updated_vars
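For reference, the per-variable arithmetic of this loop matches the standard Adam update. The minimal NumPy sketch below reproduces it outside the class, with the epsilon added to `sqrt(v_t)` exactly as above; all names and values here are illustrative.

```python
import numpy as np

def adam_update(var, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # m_t = beta1 * m + (1 - beta1) * g;  v_t = beta2 * v + (1 - beta2) * g^2
    m_t = beta1 * m + (1 - beta1) * grad
    v_t = beta2 * v + (1 - beta2) * grad * grad
    # Bias-corrected step size, as computed from the power accumulators above.
    lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    update = lr_t * m_t / (np.sqrt(v_t) + eps)
    return var - update, m_t, v_t
```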