Example #1
  def diet_var_initializer(shape, dtype, partition_info=None):
    """Initialize a diet variable: unit-variance uniform, optionally quantized."""
    del dtype  # storage dtype is decided by params, not by the caller
    del partition_info

    # params, fn_device_dependency, and _quantize are free variables from
    # the enclosing module (tensor2tensor's diet.py).
    with fn_device_dependency("diet_init") as out_deps:
      # Uniform(-sqrt(3), sqrt(3)) has mean 0 and variance 1.
      float_range = math.sqrt(3)
      ret = tf.random_uniform(shape, -float_range, float_range)
      if params.quantize:
        ret = _quantize(ret, params, randomize=False)
      out_deps.append(ret)
      return ret
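
The (shape, dtype, partition_info) signature above is the initializer contract that tf.get_variable expects in the TF 1.x API. As a minimal self-contained sketch, here is the same unit-variance uniform initialization with the quantization step left out; unit_variance_initializer is a hypothetical name, and params and _quantize above belong to the enclosing module:

import math

import tensorflow as tf

def unit_variance_initializer(shape, dtype=None, partition_info=None):
  del dtype, partition_info  # always emits float32, no partitioning
  # Uniform(-sqrt(3), sqrt(3)) has mean 0 and variance 1.
  float_range = math.sqrt(3)
  return tf.random_uniform(shape, -float_range, float_range)

w = tf.get_variable("w", shape=[128, 128],
                    initializer=unit_variance_initializer)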
Example #2
def _fn_with_diet_vars(fn, args, params):
  """Call function with args; use diet variables according to params."""

  vs_ctr = []  # holds the variable scope created in forward()

  def grad_fn(inputs, variables, outputs, output_grads):
    """Custom gradient: recompute activations, update diet variables in place."""
    del outputs  # recomputing below
    with fn_device_dependency("diet_grad",
                              output_grads[0].device) as out_dep:
      # Rerun the forward pass under the saved scope so activations need
      # not be kept in memory between the forward and backward passes.
      with tf.variable_scope(vs_ctr[0], reuse=True):
        outputs = fn(*inputs)

      variables = [variable_ref(v) for v in variables]
      # Differentiate through the latest float (dequantized) copy of each
      # quantized variable.
      dequantized_variables = [
          params.dequantized[v.name][-1] for v in variables
      ]

      grads = tf.gradients(outputs, inputs + dequantized_variables,
                           output_grads)
      grad_inputs = grads[:len(inputs)]
      grad_variables = grads[len(inputs):]

      opt = _create_diet_optimizer(params)

      # Apply grad_variables here
      var_updates = []
      for v, dv in zip(variables, grad_variables):
        with tf.variable_scope(vs_ctr[0].name):
          opt.create_slots(v)
        update_op = opt.update_variable(v, dv)
        var_updates.append(update_op)

      # Tie the returned input gradients to the variable updates so the
      # updates are not pruned from the graph.
      with tf.control_dependencies(var_updates):
        grad_inputs = [tf.identity(dx) for dx in grad_inputs]

      out_dep.append(grad_inputs)

      # Return None for the variable gradients: the optimizer has already
      # applied the updates above.
      return grad_inputs, [None] * len(variables)

  @fn_with_custom_grad(grad_fn, use_global_vars=True)
  def forward(*inputs):
    # Route all variable creation through make_diet_var_getter so the
    # variables are stored (and optionally quantized) as diet variables.
    with tf.variable_scope(
        None, default_name="diet",
        custom_getter=make_diet_var_getter(params)) as vs:
      vs_ctr.append(vs)
      outputs = fn(*inputs)
      return outputs

  with fn_device_dependency("diet_forward", args[0].device) as out_dep:
    outputs = forward(*args)
    out_dep.append(outputs)
  return outputs
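
In tensor2tensor's diet.py, callers reach this function through a decorator rather than calling it directly. A hedged usage sketch, assuming the fn_with_diet_vars(params) decorator (which forwards to _fn_with_diet_vars) and default hyperparameters from diet_adam_optimizer_params(); both names come from that module, though exact fields may differ across versions:

# Sketch only, not a verbatim API reference.
params = diet_adam_optimizer_params()

@fn_with_diet_vars(params)
def diet_ffn(x):
  # Variables created here go through make_diet_var_getter(params);
  # their gradients are handled by grad_fn above.
  h = tf.layers.dense(x, 512, activation=tf.nn.relu, use_bias=False)
  return tf.layers.dense(h, 256, use_bias=False)

x = tf.random_normal([8, 256])
y = diet_ffn(x)  # forward pass runs under the "diet" variable scope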