Example 1
  def grad_fn(inputs, variables, outputs, output_grads):
    """Custom gradient function."""
    del outputs  # recomputing below
    with common_layers.fn_device_dependency("diet_grad",
                                            output_grads[0].device) as out_dep:
      with tf.variable_scope(vs_ctr[0], reuse=True):
        outputs = fn(*inputs)

      variables = [common_layers.underlying_variable_ref(v) for v in variables]
      dequantized_variables = [
          params.dequantized[v.name][-1] for v in variables
      ]

      grads = tf.gradients(outputs, inputs + dequantized_variables,
                           output_grads)
      grad_inputs = grads[:len(inputs)]
      grad_variables = grads[len(inputs):]

      opt = _create_diet_optimizer(params)

      # Apply grad_variables here
      var_updates = []
      for v, dv in zip(variables, grad_variables):
        with tf.variable_scope(vs_ctr[0].name):
          opt.create_slots(v)
        update_op = opt.update_variable(v, dv)
        var_updates.append(update_op)

      with tf.control_dependencies(var_updates):
        grad_inputs = [tf.identity(dx) for dx in grad_inputs]

      out_dep.append(grad_inputs)

      return grad_inputs, [None] * len(variables)
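
In the surrounding tensor2tensor code (not shown here), a grad_fn with this (inputs, variables, outputs, output_grads) signature is plugged into a custom-gradient wrapper: the forward outputs are discarded and recomputed under the saved variable scope, gradients are taken with respect to dequantized float copies of the quantized variables, the diet optimizer applies those gradients immediately inside the backward pass, and the function returns None for every variable gradient so the outer optimizer does not update the variables a second time. The control_dependencies block forces the input gradients to wait for the update ops, which keeps those updates from being pruned out of the graph. Below is a minimal, self-contained toy sketch of that inline-update-and-gate idiom, with plain SGD standing in for the diet optimizer (assumed names, TF 1.x graph mode; not the library's actual wrapper):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable("w", [4, 2], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.matmul(x, w))

# Gradients w.r.t. both the input and the variable, as tf.gradients does above
# (the excerpt differentiates w.r.t. dequantized float copies instead).
grad_x, grad_w = tf.gradients(loss, [x, w])

# Apply the variable update inline and only release grad_x once it has run.
update = tf.assign_sub(w, 0.01 * grad_w)
with tf.control_dependencies([update]):
  grad_x = tf.identity(grad_x)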
Example 2
    def grad_fn(inputs, variables, outputs, output_grads):
        del outputs  # recomputing below
        with common_layers.fn_device_dependency(
                "diet_grad", output_grads[0].device) as out_dep:
            with tf.variable_scope(vs_ctr[0], reuse=True):
                outputs = fn(*inputs)

            variables = [
                common_layers.underlying_variable_ref(v) for v in variables
            ]
            dequantized_variables = [
                params.dequantized[v.name][-1] for v in variables
            ]

            grads = tf.gradients(outputs, inputs + dequantized_variables,
                                 output_grads)
            grad_inputs = grads[:len(inputs)]
            grad_variables = grads[len(inputs):]

            opt = _create_diet_optimizer(params)

            # Apply grad_variables here
            var_updates = []
            for v, dv in zip(variables, grad_variables):
                with tf.variable_scope(vs_ctr[0].name):
                    opt.create_slots(v)
                update_op = opt.update_variable(v, dv)
                var_updates.append(update_op)

            with tf.control_dependencies(var_updates):
                grad_inputs = [tf.identity(dx) for dx in grad_inputs]

            out_dep.append(grad_inputs)

            return grad_inputs, [None] * len(variables)
Example 3
    def custom_grad_fn(inputs, variables, ys, grad_ys):
        """Custom gradient fn for a block of reversible residual layers."""
        side_inputs = inputs[2:]
        f_side_idxs = [None] * len(f_side_input)
        g_side_idxs = [None] * len(g_side_input)
        assert len(side_inputs) == len(f_side_input) + len(g_side_input)

        for i, t in enumerate(side_inputs):
            if t in f_side_input:
                f_side_idxs[f_side_input.index(t)] = i
            elif t in g_side_input:
                g_side_idxs[g_side_input.index(t)] = i
            else:
                assert False

        f_vars = [[] for _ in range(num_layers)]
        g_vars = [[] for _ in range(num_layers)]
        f_vars_idxs = [[] for _ in range(num_layers)]
        g_vars_idxs = [[] for _ in range(num_layers)]

        for i, t in enumerate(variables):
            ref = common_layers.underlying_variable_ref(t)

            # Use the name to identify the layer number and function (f or g)
            regex = LAYER_RE.match(ref.name)
            layer_no = int(regex.group(1))
            fn_name = regex.group(2)
            if fn_name == "f":
                f_vars[layer_no].append(ref)
                f_vars_idxs[layer_no].append(i)
            else:
                assert fn_name == "g"
                g_vars[layer_no].append(ref)
                g_vars_idxs[layer_no].append(i)

        f_var_grads = []
        g_var_grads = []
        f_side_grads = []
        g_side_grads = []

        # Reverse variable containers to go backward
        layer_scopes.reverse()
        f_vars.reverse()
        g_vars.reverse()
        f.reverse()
        g.reverse()

        for i in xrange(num_layers):
            with tf.variable_scope(layer_scopes[i], reuse=True):

                ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
                    ys, grad_ys, f[i], g[i], f_vars[i], f_side_input,
                    g_vars[i], g_side_input)

                grad_f_vars, grad_f_side = f_ret
                grad_g_vars, grad_g_side = g_ret
                f_var_grads.append(grad_f_vars)
                g_var_grads.append(grad_g_vars)
                f_side_grads.append(grad_f_side)
                g_side_grads.append(grad_g_side)

        # Accumulate layer gradients for f_side_input and g_side_input
        acc_f_side_grads = _acc_grads(*f_side_grads)
        acc_g_side_grads = _acc_grads(*g_side_grads)

        # Use the stored idxs to put gradients in the passed-in order.
        side_input_grads = [None] * len(side_inputs)
        variable_grads = [None] * len(variables)

        # Variable gradients were collected in reverse layer order. Reverse to match
        # idxs.
        f_var_grads.reverse()
        g_var_grads.reverse()
        for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
                zip(g_vars_idxs, g_var_grads)):
            for i, grad in zip(idxs, grads):
                variable_grads[i] = grad

        for i, grad in zip(f_side_idxs, acc_f_side_grads):
            side_input_grads[i] = grad
        for i, grad in zip(g_side_idxs, acc_g_side_grads):
            side_input_grads[i] = grad

        grad_x1, grad_x2 = grad_ys
        return [grad_x1, grad_x2] + side_input_grads, variable_grads
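
The loop above calls _rev_layer_backward (defined elsewhere in the same module) once per layer, walking the block backward and reconstructing each layer's inputs from its outputs instead of keeping activations in memory. The underlying reversible-residual arithmetic is y1 = x1 + f(x2), y2 = x2 + g(y1), which can be inverted exactly. Below is a simplified sketch of one forward and one backward step, with assumed helper names and without the per-layer variables and side inputs that the real helper also threads through:

import tensorflow.compat.v1 as tf

def rev_layer_forward(x1, x2, f, g):
    # y1 = x1 + f(x2), y2 = x2 + g(y1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_layer_backward(y1, y2, dy1, dy2, f, g):
    # Rebuild g(y1) and f(x2) once each; reuse them both to reconstruct the
    # inputs and to take gradients through the two residual additions.
    gy1 = g(y1)
    x2 = y2 - gy1
    fx2 = f(x2)
    x1 = y1 - fx2
    # dL/dx1 = dL/dy1 + (dg/dy1)^T dL/dy2
    dx1 = dy1 + tf.gradients(gy1, y1, dy2)[0]
    # dL/dx2 = dL/dy2 + (df/dx2)^T dL/dx1
    dx2 = dy2 + tf.gradients(fx2, x2, dx1)[0]
    return (x1, x2), (dx1, dx2)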
Example 4
  def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
    """Custom gradient fn for a block of reversible residual layers."""
    side_inputs = inputs[2:]
    f_side_idxs = [None] * len(self.f_side_input)
    g_side_idxs = [None] * len(self.g_side_input)
    assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)

    for i, t in enumerate(side_inputs):
      if t in self.f_side_input:
        f_side_idxs[self.f_side_input.index(t)] = i
      elif t in self.g_side_input:
        g_side_idxs[self.g_side_input.index(t)] = i
      else:
        assert False

    f_vars = [[] for _ in range(self.num_layers)]
    g_vars = [[] for _ in range(self.num_layers)]
    f_vars_idxs = [[] for _ in range(self.num_layers)]
    g_vars_idxs = [[] for _ in range(self.num_layers)]

    for i, t in enumerate(variables):
      ref = common_layers.underlying_variable_ref(t)

      # Use the name to identify the layer number and function (f or g)
      regex = LAYER_RE.match(ref.name)
      layer_no = int(regex.group(1))
      fn_name = regex.group(2)
      if fn_name == "f":
        f_vars[layer_no].append(ref)
        f_vars_idxs[layer_no].append(i)
      else:
        assert fn_name == "g"
        g_vars[layer_no].append(ref)
        g_vars_idxs[layer_no].append(i)

    f_var_grads = []
    g_var_grads = []
    f_side_grads = []
    g_side_grads = []

    # Reverse variable containers to go backward
    f_vars.reverse()
    g_vars.reverse()
    f = list(self.f)
    g = list(self.g)
    f.reverse()
    g.reverse()

    for i in xrange(self.num_layers):
      ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
          ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
          self.g_side_input)

      grad_f_vars, grad_f_side = f_ret
      grad_g_vars, grad_g_side = g_ret
      f_var_grads.append(grad_f_vars)
      g_var_grads.append(grad_g_vars)
      f_side_grads.append(grad_f_side)
      g_side_grads.append(grad_g_side)

    # Accumulate layer gradients for f_side_input and g_side_input
    acc_f_side_grads = _acc_grads(*f_side_grads)
    acc_g_side_grads = _acc_grads(*g_side_grads)

    # Use the stored idxs to put gradients in the passed-in order.
    side_input_grads = [None] * len(side_inputs)
    variable_grads = [None] * len(variables)

    # Variable gradients were collected in reverse layer order. Reverse to match
    # idxs.
    f_var_grads.reverse()
    g_var_grads.reverse()
    for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
        zip(g_vars_idxs, g_var_grads)):
      for i, grad in zip(idxs, grads):
        variable_grads[i] = grad

    for i, grad in zip(f_side_idxs, acc_f_side_grads):
      side_input_grads[i] = grad
    for i, grad in zip(g_side_idxs, acc_g_side_grads):
      side_input_grads[i] = grad

    grad_x1, grad_x2 = grad_ys
    return [grad_x1, grad_x2] + side_input_grads, variable_grads
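
Both versions rely on an _acc_grads helper (not shown in the excerpts) to combine the per-layer gradients that flow into the shared f_side_input and g_side_input tensors. A plausible sketch of such a helper, assuming it simply sums corresponding gradients across layers and skips None entries:

import tensorflow.compat.v1 as tf

def _acc_grads(*lists_of_grads):
  """Sums corresponding gradients across layers, ignoring None entries."""
  acc_grads = []
  for grads in zip(*lists_of_grads):
    grads = [g for g in grads if g is not None]
    acc_grads.append(tf.add_n(grads) if grads else None)
  return acc_grads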