Example 1
  def meta_optimize(self):
    """Meta optimization step."""

    probe_images, probe_labels = self.probe_images, self.probe_labels
    labels = self.labels
    net = self.net
    logits = self.logits
    gate_gradients = 1

    batch_size = int(self.batch_size / self.strategy.num_replicas_in_sync)
    init_eps_val = 1.0 / batch_size

    meta_net = networks.MetaImage(self.net, name='meta_model')
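    # meta_net mirrors self.net; below, its variables are aliased to the
    # lookahead weights so the probe loss can be differentiated with respect
    # to the data weights and eps.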

    if FLAGS.meta_momentum and not self.optimizer.variables():
      # Initialize the optimizer's momentum state for the meta momentum update.
      # This is a hacky workaround: applying a throwaway gradient forces the
      # optimizer to create its slot variables.
      logging.info('Pre-initialize optimizer momentum states.')
      idle_net_cost = tf.losses.sparse_softmax_cross_entropy(
          self.labels, logits)
      tmp_var_grads = self.optimizer.compute_gradients(
          tf.reduce_mean(idle_net_cost), net.trainable_variables)
      self.optimizer.apply_gradients(tmp_var_grads)

    with tf.name_scope('coefficient'):
      # Data weight coefficient
      target = tf.constant(
          [init_eps_val] * batch_size,
          shape=(batch_size,),
          dtype=tf.float32,
          name='weight')
      # Data re-labeling coefficient
      eps = tf.constant(
          [FLAGS.grad_eps_init] * batch_size,
          shape=(batch_size,),
          dtype=tf.float32,
          name='eps')

    onehot_labels = tf.one_hot(labels, self.dataset.num_classes)
    onehot_labels = tf.cast(onehot_labels, tf.float32)
    eps_k = tf.reshape(eps, [batch_size, 1])

    mixed_labels = eps_k * onehot_labels + (1 - eps_k) * self.guessed_label
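    # eps close to 1 keeps the provided (possibly noisy) label; eps close to 0
    # replaces it with the model's guessed label.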
    # Per-example softmax cross-entropy against the mixed labels.
    log_softmax = tf.nn.log_softmax(logits)
    net_cost = -tf.reduce_sum(mixed_labels * log_softmax, 1)

    lookahead_loss = tf.reduce_sum(tf.multiply(target, net_cost))
    lookahead_loss = lookahead_loss + net.regularization_loss

    with tf.control_dependencies([lookahead_loss]):
      train_vars = net.trainable_variables
      var_grads = tf.gradients(
          lookahead_loss, train_vars, gate_gradients=gate_gradients)
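      # One virtual SGD step: static_vars hold theta - stepsize * grad as
      # tensors rather than variables, so the graph stays differentiable with
      # respect to `target` and `eps`.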

      static_vars = []
      for i in range(len(train_vars)):
        if FLAGS.meta_momentum > 0:
          actual_grad = self.meta_momentum_update(var_grads[i],
                                                  train_vars[i].name,
                                                  self.optimizer)
          static_vars.append(
              tf.math.subtract(train_vars[i],
                               FLAGS.meta_stepsize * actual_grad))
        else:
          static_vars.append(
              tf.math.subtract(train_vars[i],
                               FLAGS.meta_stepsize * var_grads[i]))
        # Register the lookahead weight under the original variable name.
        meta_net.add_variable_alias(
            static_vars[-1], var_name=train_vars[i].name)

      for uv in net.updates_variables:
        meta_net.add_variable_alias(
            uv, var_name=uv.name, var_type='updates_variables')
      meta_net.verbose()

    with tf.control_dependencies(static_vars):
      g_logits = meta_net(
          probe_images, name='meta_model', reuse=True, training=True)
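      # Forward pass on the clean probe batch using the lookahead weights
      # registered above through the variable aliases.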

      desired_y = tf.one_hot(probe_labels, self.dataset.num_classes)
      meta_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=desired_y, logits=g_logits)
      meta_loss = tf.reduce_mean(meta_loss, name='meta_loss')
      meta_loss = meta_loss + meta_net.get_regularization_loss(net.wd)
      meta_acc, meta_acc_op = tf.metrics.accuracy(probe_labels,
                                                  tf.argmax(g_logits, axis=1))

    with tf.control_dependencies([meta_loss] + [meta_acc_op]):
      meta_train_vars = meta_net.trainable_variables
      grad_meta_vars = tf.gradients(
          meta_loss, meta_train_vars, gate_gradients=gate_gradients)
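      # Chain rule through the lookahead step: grad_ys=grad_meta_vars supplies
      # d(meta_loss)/d(static_vars), so this call yields d(meta_loss)/d(target)
      # and d(meta_loss)/d(eps).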
      grad_target, grad_eps = tf.gradients(
          static_vars, [target, eps],
          grad_ys=grad_meta_vars,
          gate_gradients=gate_gradients)
    # Update the per-example data weights: one gradient step on `target`,
    # then subtract the uniform initial value.
    raw_weight = target - grad_target
    raw_weight = raw_weight - init_eps_val
    unorm_weight = tf.clip_by_value(
        raw_weight, clip_value_min=0, clip_value_max=float('inf'))
    norm_c = tf.reduce_sum(unorm_weight)
    weight = tf.divide(unorm_weight, norm_c + 0.00001)

    # Pick the new label-mixing coefficient from the sign of its gradient:
    # keep the original label (eps = 1) only if increasing eps decreases the
    # meta loss.
    new_eps = tf.where(grad_eps < 0, x=tf.ones_like(eps), y=tf.zeros_like(eps))

    return tf.stop_gradient(weight), tf.stop_gradient(
        new_eps), meta_loss, meta_acc
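
The returned `weight` and `new_eps` re-weight and re-label the noisy training
batch in the outer update. Below is a minimal, hypothetical sketch of how they
might be consumed, reusing the attributes referenced above (`self.logits`,
`self.labels`, `self.guessed_label`, `self.net.regularization_loss`).

  def train_step_loss(self):
    """Hypothetical outer-loop loss built from meta_optimize outputs."""
    weight, new_eps, meta_loss, meta_acc = self.meta_optimize()
    # meta_loss and meta_acc could be added to summaries for monitoring.
    onehot = tf.one_hot(self.labels, self.dataset.num_classes)
    eps_k = tf.reshape(new_eps, [-1, 1])
    # Re-label: keep the original label where new_eps is 1, else use the guess.
    relabeled = eps_k * onehot + (1 - eps_k) * self.guessed_label
    per_example = -tf.reduce_sum(relabeled * tf.nn.log_softmax(self.logits), 1)
    # Re-weight with the meta-learned per-example weights.
    return tf.reduce_sum(weight * per_example) + self.net.regularization_loss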
Example 2
  def meta_optimize(self, net_cost):
    """Meta optimization step."""
    probe_images, probe_labels = self.probe_images, self.probe_labels
    net = self.net
    gate_gradients = 1

    batch_size = int(self.batch_size / self.strategy.num_replicas_in_sync)
    # The initial data weights are zero.
    init_eps_val = 0.0

    meta_net = networks.MetaImage(self.net, name='meta_model')

    target = tf.constant(
        [init_eps_val] * batch_size, dtype=tf.float32, name='weight')
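    # `target` starts at zero, so the lookahead loss value reduces to the
    # regularization term; what matters is its symbolic dependence on `target`,
    # which grad_target differentiates through below.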

    lookahead_loss = tf.reduce_sum(tf.multiply(target, net_cost))
    lookahead_loss = lookahead_loss + net.regularization_loss

    with tf.control_dependencies([lookahead_loss]):
      train_vars = net.trainable_variables
      var_grads = tf.gradients(
          lookahead_loss, train_vars, gate_gradients=gate_gradients)

      static_vars = []
      for i in range(len(train_vars)):
        static_vars.append(
            tf.math.subtract(train_vars[i], FLAGS.meta_stepsize * var_grads[i]))
        meta_net.add_variable_alias(
            static_vars[-1], var_name=train_vars[i].name)

      for uv in net.updates_variables:
        meta_net.add_variable_alias(
            uv, var_name=uv.name, var_type='updates_variables')
      meta_net.verbose()

    with tf.control_dependencies(static_vars):
      g_logits = meta_net(
          probe_images, name='meta_model', reuse=True, training=True)

      desired_y = tf.one_hot(probe_labels, self.dataset.num_classes)
      meta_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=desired_y, logits=g_logits)
      meta_loss = tf.reduce_mean(meta_loss, name='meta_loss')
      meta_loss = meta_loss + meta_net.get_regularization_loss(net.wd)
      meta_acc, meta_acc_op = tf.metrics.accuracy(probe_labels,
                                                  tf.argmax(g_logits, axis=1))

    with tf.control_dependencies([meta_loss] + [meta_acc_op]):
      meta_train_vars = meta_net.trainable_variables
      # Backpropagate only through the partial graph to save memory.
      grad_meta_vars = tf.gradients(
          meta_loss, meta_train_vars, gate_gradients=gate_gradients)
      grad_target = tf.gradients(
          static_vars,
          target,
          grad_ys=grad_meta_vars,
          gate_gradients=gate_gradients)[0]
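      # grad_target[i] equals -meta_stepsize times the inner product between
      # the gradient of example i's training loss and the probe-loss gradient;
      # examples whose gradients align with the clean probe data get positive
      # weight below.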

    unorm_weight = tf.clip_by_value(
        -grad_target, clip_value_min=0, clip_value_max=float('inf'))
    norm_c = tf.reduce_sum(unorm_weight)
    weight = tf.divide(unorm_weight, norm_c + 0.00001)

    return tf.stop_gradient(weight), meta_loss, meta_acc
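
As in Example 1, the returned `weight` would scale the per-example cost in the
outer update. A hypothetical sketch, assuming `net_cost` is the same
per-example loss tensor passed into meta_optimize:

    weight, meta_loss, meta_acc = self.meta_optimize(net_cost)
    train_loss = tf.reduce_sum(weight * net_cost) + self.net.regularization_loss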