Esempio n. 1
0
def _batch_std(inputs,
               training,
               decay=MOVING_AVERAGE_DECAY,
               epsilon=EPSILON,
               data_format='channels_first',
               name='moving_variance'):
    """Return the per-channel batch standard deviation of `inputs`.

    While training, the variance is measured over the batch and spatial axes
    and folded into an exponential moving average stored in a non-trainable
    variable; at inference the stored moving variance is used instead. The
    result is cast back to the dtype of `inputs`.
    """
    if data_format == 'channels_last':
        axes = [0, 1, 2]
        var_shape = (1, 1, 1, inputs.shape[3])
    else:
        axes = [0, 2, 3]
        var_shape = (1, inputs.shape[1], 1, 1)
    moving_variance = tf.get_variable(
        name=name,
        shape=var_shape,
        initializer=tf.initializers.ones(),
        dtype=tf.float32,
        collections=[
            tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
            tf.GraphKeys.GLOBAL_VARIABLES
        ],
        trainable=False)
    if not training:
        variance = moving_variance
    else:
        variance = tf.cast(
            tf.nn.moments(inputs, axes, keep_dims=True)[1], tf.float32)
        # EMA update: mv <- mv - (1 - decay) * (mv - batch_variance).
        delta = (moving_variance - variance) * (1 - decay)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                             tf.assign_sub(moving_variance, delta))
    return tf.cast(tf.sqrt(variance + epsilon), inputs.dtype)
Esempio n. 2
0
File: newton.py Progetto: cthl/sqgn
    def _undo_update(self):
        """Subtract the stored step `dw` from every weight, reverting the
        most recent update. Returns a single grouped op."""
        undo_ops = [
            tf.assign_sub(weight, step)
            for weight, step in zip(self._weights, self._dws)
        ]
        return tf.group(undo_ops)
Esempio n. 3
0
    def apply_gradients(self, grads_and_vars):
        """Build Adam-style update ops for `grads_and_vars`.

        Creates the optimizer state (bias-correction power accumulators plus
        per-variable first/second moment variables), registers it in
        `self.all_state_vars`, and returns one grouped op that advances all
        state and applies the parameter updates.
        NOTE(review): new state variables are created on every call, so this
        is presumably invoked once per graph — confirm with callers.
        """
        with tf.name_scope(self.name):
            state_vars = []
            update_ops = []

            # Adjust learning rate to deal with startup bias.
            # control_dependencies(None) lifts the variables out of any
            # surrounding control-flow context.
            with tf.control_dependencies(None):
                b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                state_vars += [b1pow_var, b2pow_var]
            b1pow_new = b1pow_var * self.beta1
            b2pow_new = b2pow_var * self.beta2
            update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
            # Bias-corrected learning rate: lr * sqrt(1 - b2^t) / (1 - b1^t).
            lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)

            # Construct ops to update each variable.
            for grad, var in grads_and_vars:
                with tf.control_dependencies(None):
                    m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    state_vars += [m_var, v_var]
                # Standard Adam moment estimates.
                m_new = self.beta1 * m_var + (1 - self.beta1) * grad
                v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
                var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
                update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]

            # Group everything together.
            self.all_state_vars += state_vars
            return tf.group(*update_ops)
Esempio n. 4
0
    def _finish(self, caches):
        """Apply each cache's accumulated step `s_t` to its variable.

        Optionally clips all steps jointly by global norm, then subtracts each
        step from its variable — `tf.scatter_sub` for sparse updates (when the
        cache carries `idxs`), `tf.assign_sub` otherwise. When `self.chi > 0`
        a moving average of the updated variable is also maintained. Returns
        one grouped op named 'update'.
        """

        if self.clip > 0:
            # Clip the steps (not the gradients) by their joint global norm.
            S_t = [cache['s_t'] for cache in caches]
            S_t, _ = tf.clip_by_global_norm(S_t, self.clip)
            for cache, s_t in zip(caches, S_t):
                cache['s_t'] = s_t

        for cache in caches:
            x_tm1 = cache['x_tm1']
            s_t = cache['s_t']
            updates = cache['updates']
            # Place the update op on the same device as the variable.
            with tf.name_scope('update_' + x_tm1.op.name), tf.device(
                    x_tm1.device):
                if 'idxs' in cache:
                    # Sparse path: only the rows addressed by `idxs` change.
                    idxs = cache['idxs']
                    x_t = tf.scatter_sub(x_tm1, idxs, s_t)
                    if self.chi > 0:
                        x_t_ = tf.gather(x_t, idxs)
                        x_bar_t, t_x_bar = self._sparse_moving_average(
                            x_tm1, idxs, x_t_, 'x', beta=self.chi)
                else:
                    # Dense path: x_t = x_tm1 - s_t.
                    x_t = tf.assign_sub(x_tm1, s_t)
                    if self.chi > 0:
                        x_bar_t, t_x_bar = self._dense_moving_average(
                            x_tm1, x_t, 'x', beta=self.chi)
            updates.append(x_t)
            if self.chi > 0:
                updates.extend([x_bar_t, t_x_bar])

        update_ops = [tf.group(*cache['updates']) for cache in caches]
        return tf.group(*update_ops, name='update')
Esempio n. 5
0
  def build_trainer(self, child_model):
    """Build the train ops by connecting Controller with a Child.

    The controller is trained with REINFORCE: the reward is
    REWARD_CONSTANT / valid_ppl of the child (optionally plus an entropy
    bonus), and an exponential-moving-average baseline reduces variance.
    Sets `self.train_op`, `self.optimizer` and `self.grad_norm`.
    """
    # Actor: reward derived from the child's validation perplexity; the loss
    # is treated as a constant w.r.t. controller parameters.
    self.valid_loss = tf.to_float(child_model.rl_loss)
    self.valid_loss = tf.stop_gradient(self.valid_loss)
    self.valid_ppl = tf.exp(self.valid_loss)
    self.reward = REWARD_CONSTANT / self.valid_ppl

    if self.params.controller_entropy_weight:
      self.reward += self.params.controller_entropy_weight * self.sample_entropy

    # Baseline: EMA of the reward, updated as
    # baseline <- baseline - (1 - dec) * (baseline - reward).
    self.sample_log_probs = tf.reduce_sum(self.sample_log_probs)
    self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
    baseline_update = tf.assign_sub(self.baseline,
                                    ((1 - self.params.controller_baseline_dec) *
                                     (self.baseline - self.reward)))

    # Make sure the baseline is refreshed before the reward is consumed.
    with tf.control_dependencies([baseline_update]):
      self.reward = tf.identity(self.reward)
    self.loss = self.sample_log_probs * (self.reward - self.baseline)

    self.train_step = tf.Variable(
        0, dtype=tf.int32, trainable=False, name='train_step')
    # Only train variables that belong to this controller's scope.
    tf_vars = [var for var in tf.trainable_variables()
               if var.name.startswith(self.name)]

    self.train_op, self.optimizer, self.grad_norm = _build_train_op(
        loss=self.loss,
        tf_vars=tf_vars,
        learning_rate=self.params.controller_learning_rate,
        train_step=self.train_step,
        num_aggregate=self.params.controller_num_aggregate)
 def _apply_sparse_shared(self, grad, var, indices, scatter_add):
     """Sparse Adam update shared by IndexedSlices and resource variants.

     Decays the m/v slots by beta1/beta2, scatter-adds the new gradient
     contributions at `indices`, then applies the bias-corrected step
     `lr * m_t / (sqrt(v_t) + eps)` densely to `var`.
     """
     beta1_power, beta2_power = self._get_beta_accumulators()
     beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
     beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
     lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
     beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
     beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
     epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
     # Bias-corrected learning rate.
     lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
     # m_t = beta1 * m + (1 - beta1) * g_t
     m = self.get_slot(var, "m")
     m_scaled_g_values = grad * (1 - beta1_t)
     m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
     # The decay must run before the scatter-add touches the same slot.
     with tf.control_dependencies([m_t]):
         m_t = scatter_add(m, indices, m_scaled_g_values)
     # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
     v = self.get_slot(var, "v")
     v_scaled_g_values = (grad * grad) * (1 - beta2_t)
     v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
     with tf.control_dependencies([v_t]):
         v_t = scatter_add(v, indices, v_scaled_g_values)
     v_sqrt = tf.sqrt(v_t)
     var_update = tf.assign_sub(var,
                                lr * m_t / (v_sqrt + epsilon_t),
                                use_locking=self._use_locking)
     return tf.group(*[var_update, m_t, v_t])
Esempio n. 7
0
    def _sub_mixed_grad(self):
        """Remove the current gradients (evaluated with the reference
        weights) from the aggregated gradients; returns a grouped op."""
        sub_ops = [
            tf.assign_sub(aggregate, grad)
            for aggregate, grad in zip(self._grads_aggregated, self._grads)
        ]
        return tf.group(sub_ops)
Esempio n. 8
0
    def build_trainer(self, child_model):
        """Build REINFORCE training ops for the architecture controller.

        Reward is the child's shuffled-validation accuracy (optionally plus
        an entropy bonus), variance-reduced with an EMA baseline. Sets
        `self.train_op`, `self.lr`, `self.grad_norm` and `self.optimizer`.
        """
        child_model.build_valid_rl()
        # Accuracy as a fraction: correct count / batch size.
        self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) /
                          tf.to_float(child_model.batch_size))
        self.current_normal_arc = child_model.current_normal_arc
        self.current_reduce_arc = child_model.current_reduce_arc

        self.reward = self.valid_acc

        if self.entropy_weight is not None:
            # Entropy bonus encourages exploration of architectures.
            self.reward += self.entropy_weight * self.sample_entropy

        self.sample_log_prob = tf.reduce_sum(self.sample_log_prob)
        # EMA baseline: baseline <- baseline - (1 - bl_dec) * (baseline - reward).
        self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
        baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) *
                                        (self.baseline - self.reward))

        # Refresh the baseline before the reward tensor is consumed.
        with tf.control_dependencies([baseline_update]):
            self.reward = tf.identity(self.reward)

        self.loss = self.sample_log_prob * (self.reward - self.baseline)
        self.train_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name="train_step")

        # Only optimize variables that belong to this controller's scope.
        tf_variables = [
            var for var in tf.trainable_variables()
            if var.name.startswith(self.name)
        ]
        print("-" * 80)
        for var in tf_variables:
            print(var)

        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            self.loss,
            tf_variables,
            self.train_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)

        self.skip_rate = tf.constant(0.0, dtype=tf.float32)
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        """Sparse LAMB-style update shared by IndexedSlices variants.

        Maintains Adam moments sparsely via `scatter_add`, bias-corrects
        them, optionally adds decoupled weight decay, scales the step by the
        trust ratio ||w|| / ||update|| for layer adaptation, and applies it
        densely to `var`.
        """
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
        lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
        weight_decay_rate_t = tf.cast(self._weight_decay_rate_t,
                                      var.dtype.base_dtype)
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
        # The decay must complete before scatter-adding into the same slot.
        with tf.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)
        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        # ==== The following is with m_t_hat and v_t_hat
        # Bias correction of both moments.
        m_t_hat = m_t / (1. - beta1_power)
        v_t_hat = v_t / (1. - beta2_power)

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + epsilon_t)

        # ==== The following is the original LAMBOptimizer implementation
        # v_sqrt = tf.sqrt(v_t_hat)
        # update = m_t / (v_sqrt + epsilon_t)

        var_name = self._get_variable_name(var.name)
        if self._do_use_weight_decay(var_name):
            # Decoupled weight decay added to the update direction.
            update += weight_decay_rate_t * var

        ratio = 1.0
        if self._do_layer_adaptation(var_name):
            # Trust ratio; fall back to 1.0 when either norm is zero.
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
        var_update = tf.assign_sub(var,
                                   ratio * lr_t * update,
                                   use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t])
Esempio n. 10
0
File: krylov.py Progetto: cthl/sqgn
    def _update(self, rs, ps):
        """One conjugate-gradient step: alpha = rTr / pTAp, then update the
        solution (x += alpha*p) and residual (r -= alpha*Ap). Records whether
        the operator looked indefinite along p. Returns a grouped op."""
        update_ops = []

        # Accumulate the curvature term p^T A p over all tensor shards.
        # Each Az already holds the operator applied to the direction p.
        curvature = tf.zeros(shape=[], dtype=ps[0].dtype)
        for direction, Ap in zip(ps, self._Azs):
            curvature = curvature + tf.reduce_sum(direction * Ap)

        # Non-positive curvature means p is a direction of indefiniteness;
        # remember that and take a zero-length step.
        is_indefinite = curvature <= 0.0
        update_ops.append(tf.assign(self._indefinite, is_indefinite))
        step = tf.cond(is_indefinite, lambda: 0.0,
                       lambda: self._rTr / curvature)

        # x <- x + alpha * p ;  r <- r - alpha * A p.
        for sol, res, direction, Ap in zip(self._xs, rs, ps, self._Azs):
            update_ops.append(tf.assign_add(sol, step * direction))
            update_ops.append(tf.assign_sub(res, step * Ap))

        return tf.group(update_ops)
Esempio n. 11
0
    def _apply_dense(self, grad, var):
        """Dense Adam update with optional L2/AdamW decay and Nesterov step.

        L2-style decay is folded into the gradient before the moments are
        updated; AdamW-style decay is added to the step after. Bias
        correction and the Nesterov form are controlled by optimizer flags.
        """
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        lr = tf.cast(self._lr_t, grad.dtype.base_dtype)
        beta1 = tf.cast(self._beta1_t, grad.dtype.base_dtype)
        beta2 = tf.cast(self._beta2_t, grad.dtype.base_dtype)
        epsilon = tf.cast(self._epsilon_t, grad.dtype.base_dtype)

        # Coupled L2 weight decay: decay enters the moment estimates.
        # NOTE(review): sign is `grad - decay * var`, i.e. decay *grows* the
        # weights unless _l2_weight_decay is negative — confirm intent.
        grad = grad - var * self._l2_weight_decay

        # m_t = beta_1 * m_{t-1} + (1-beta_1) * g_t
        m_t = m.assign(beta1 * m + (1.0 - beta1) * grad)

        # v_t = beta_2 * v_{t-1} + (1-beta_2) * g_t ** 2
        v_t = v.assign(beta2 * v + (1.0 - beta2) * grad * grad)

        if self._use_bias_correction:
            # Fold both bias corrections into the learning rate.
            beta1_power, beta2_power = self._get_beta_accumulators()
            beta1_power = tf.cast(beta1_power, grad.dtype.base_dtype)
            beta2_power = tf.cast(beta2_power, grad.dtype.base_dtype)
            lr_t = lr * tf.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
        else:
            lr_t = lr

        if self._use_nesterov:
            # delta theta = lr_t * (
            #    (beta_1 * m_t + (1-beta1) * g_t) / (sqrt(v_t) + epsilon))
            step = lr_t * ((beta1 * m_t + (1.0 - beta1) * grad) /
                           (tf.sqrt(v_t) + epsilon))
        else:
            # delta theta = lr_t * m_t / (sqrt(v_t) + epsilon)
            step = lr_t * m_t / (tf.sqrt(v_t) + epsilon)

        # AdamW style weight decay term.
        step = step + lr_t * self._adamw_weight_decay * var

        theta_t = tf.assign_sub(var, step)

        return tf.group(*[theta_t, m_t, v_t])
Esempio n. 12
0
def setup_ema(params, name_scope=None):
  """Create exponential moving average for all variables under `name_scope`.

  Skips optimizer slots (momentum/rms/adam/lars), the global step and debug
  variables. EMA vars for variance/beta-named variables start at one, all
  others at zero. Returns one grouped op that advances every EMA.
  """
  logging.info(f'ema_decay with rate {params.ema_decay}')
  all_vars = tf.global_variables()
  ema_ops = []
  # Warm-up schedule: before `ema_start` the effective decay factor is 1
  # (EMA is overwritten with the raw value); afterwards it ramps toward
  # 1 - ema_decay via min(ema_decay, (t+1)/(t+10)).
  step = tf.cast(tf.train.get_or_create_global_step() - params.ema_start,
                 tf.float32)
  decay = 1. - tf.minimum(params.ema_decay, (step+1.) / (step+10.))
  decay = tf.cond(tf.train.get_or_create_global_step() < params.ema_start,
                  lambda: tf.constant(1, tf.float32), lambda: decay)

  def should_skip(v):
    # Exclude optimizer state and anything outside `name_scope`.
    key_words = ['momentum', 'rms', 'global_step', 'debug', 'adam', 'lars']
    conditions = [k in v.name.lower() for k in key_words]
    if name_scope is not None:
      conditions += [not v.name.lower().startswith(name_scope)]
    return any(conditions)

  def get_init(v_name):
    # Variance/beta statistics start at one; everything else at zero.
    key_words = ['variance', 'beta']
    if any([k in v_name for k in key_words]):
      return tf.initializers.ones()
    return tf.initializers.zeros()

  with tf.variable_scope('ema'):
    for v in all_vars:
      if not should_skip(v):
        v_name = strip_var_name(v.name)
        with tf.device(v.device):
          ema_var = tf.get_variable(
              name=v_name,
              shape=v.shape.as_list(),
              initializer=get_init(v_name),
              trainable=False)
          # NOTE(review): shard_weight presumably returns a (possibly
          # sharded) view used only to compute the delta — confirm that
          # assigning to the unsharded `ema_var` is intended.
          v = shard_weight(v, params.num_cores_per_replica)
          ema = shard_weight(ema_var, params.num_cores_per_replica)
          # ema <- ema - decay * (ema - v).
          ema_op = tf.assign_sub(ema_var, decay * (ema-v), use_locking=True)
        ema_ops.append(ema_op)
  ema_op = tf.group(*ema_ops)
  return ema_op
Esempio n. 13
0
File: krylov.py Progetto: cthl/sqgn
    def _update(self, rs, ps):
        """One trust-region CG (Steihaug-style) step.

        Computes alpha = rTr / pTHp; if the curvature is non-positive or the
        stepped iterate would leave the trust region, records that p should
        be followed to the boundary (resolved later in `solve`) and takes a
        zero step; otherwise updates the iterate dw and residual r.
        Returns a grouped op.
        """
        ops = []

        # Compute the coefficient alpha.
        pTHp = tf.zeros(shape=[], dtype=ps[0].dtype)
        for p, Hz in zip(ps, self._hessians):
            # Recall that p has already been assigned to z, and hence Hz = Hp.
            pTHp += tf.reduce_sum(p * Hz)

        # Compute the coefficient for the update.
        alpha = self._rTr / pTHp

        # Create a tensor that computes the norm of the iterate after the update
        # without actually modifying it.
        norm_dw_new = tf.zeros(shape=[], dtype=self._norm_dw.dtype)
        for dw, p in zip(self._dws, ps):
            dw_new = dw + alpha * p
            norm_dw_new += tf.reduce_sum(dw_new * dw_new)
        norm_dw_new = tf.sqrt(norm_dw_new)

        # Determine if we should follow the direction p until it intersects with the
        # boundary of the trust region.
        # This is the case if either p is a direction of indefiniteness or if dw + p
        # would be outside the trust region.
        follow_to_boundary = tf.logical_or(pTHp <= 0.0,
                                           norm_dw_new > self._radius_placeh)
        # NOTE(review): a fresh tf.Variable is created on every call;
        # presumably _update is built exactly once per graph — confirm.
        self._follow_to_boundary = tf.Variable(False)
        ops.append(tf.assign(self._follow_to_boundary, follow_to_boundary))

        # If we follow p up to the boundary, we do not update dw here.
        # Instead, we determine the final update dw in the 'solve' method.
        alpha_or_zero = tf.cond(follow_to_boundary, lambda: 0.0, lambda: alpha)

        # Update the solution and residual.
        for dw, r, p, Hz in zip(self._dws, rs, ps, self._hessians):
            ops.append(tf.assign_add(dw, alpha_or_zero * p))
            ops.append(tf.assign_sub(r, alpha_or_zero * Hz))

        return tf.group(ops)
Esempio n. 14
0
    def _apply_dense(self, grad, var):
        """Dense SM3 update for `var`.

        Materializes the approximate second-moment accumulator from the
        per-axis factors (rank > 1) or a single accumulator (rank <= 1),
        preconditions the gradient by its inverse square root, refreshes the
        per-axis factors by max-reduction, applies optional momentum, and
        subtracts the scaled step from `var`.
        """
        # SM3 upper bounds the gradient square sums:
        #
        # To illustrate:
        #
        # For a Tensor `T` of shape [M, N, K].
        #
        # `G` be its gradient of shape [M, N, K]
        #
        # SM3 keeps around three accumulators A1, A2, A3 of size M, N, K
        # respectively.
        #
        # `A` be the accumulator of shape [M, N, K]. `A` is not materialized until
        #   its needed for every step, and is approximated by A1, A2, A3.
        #
        # At every gradient update step the accumulators satisify:
        #   A1_t[i] >= Sum_{s <= t} G_t[i, j, k]^2 for all j, k.
        #   A2_t[j] >= Sum_{s <= t} G_t[i, j, k]^2 for all i, k.
        #   A3_t[k] >= Sum_{s <= t} G_t[i, j, k]^2 for all i, j.
        #
        # The RHS is the gradient sum squares.
        #
        # For every step we materialize the tensor `A` based on accumulated tensors
        # A1, A2 and A3.
        #
        #  A = min(A1[i], A2[j], A3[j]) + G[i, j, k]^2
        #
        # SM3 preconditioned gradient is
        #
        #  preconditioned G = A^{-0.5} * G
        #
        # We then update the individual accumulator factors as:
        #
        #  A1[i] = max_{j, k} A[i, j, k]
        #  A2[j] = max_{i, k} A[i, j, k]
        #  A3[k] = max_{i, j} A[i, j, k]
        #
        shape = np.array(var.get_shape())
        var_rank = len(shape)
        if var_rank > 1:
            accumulator_list = [
                self.get_slot(var, "accumulator_" + str(i))
                for i in range(var_rank)
            ]
            accumulator = self._compute_past_accumulator(
                accumulator_list, shape)
            accumulator += grad * grad
        else:
            # Rank-0/1 variables keep a full-size accumulator; no factoring.
            accumulator_var = self.get_slot(var, "accumulator")
            accumulator = tf.assign_add(accumulator_var, grad * grad)

        # 1e-30 guards against division by zero for untouched entries.
        accumulator_inv_sqrt = tf.rsqrt(accumulator + 1e-30)
        scaled_g = (1.0 - self._momentum_tensor) * (grad *
                                                    accumulator_inv_sqrt)
        accumulator_update_ops = []

        # Read the preconditioned gradient before the factors are refreshed.
        with tf.control_dependencies([scaled_g]):
            if var_rank > 1:
                # Updates individual accumulator factors as:
                #  A1[i] = max_{j, k} A[i, j, k]
                #  A2[j] = max_{i, k} A[i, j, k]
                #  A3[k] = max_{i, j} A[i, j, k]
                for i, accumulator_i in enumerate(accumulator_list):
                    axes = list(range(i)) + list(range(i + 1, var_rank))
                    new_accumulator_i = tf.reduce_max(accumulator, axis=axes)
                    accumulator_update_ops.append(
                        tf.assign(accumulator_i, new_accumulator_i))

        with tf.control_dependencies(accumulator_update_ops):
            if self._momentum > 0:
                # gbar <- momentum * gbar + scaled_g (rewritten as assign_add).
                gbar = self.get_slot(var, "momentum")
                update = tf.assign_add(
                    gbar,
                    gbar * (self._momentum_tensor - 1.0) + scaled_g)
            else:
                update = scaled_g
            return tf.assign_sub(var, self._learning_rate_tensor * update)
Esempio n. 15
0
 def u(moving, normal, name):
     """Fold the cross-replica mean of `normal` into the EMA var `moving`."""
     replica_mean = tf.tpu.cross_replica_sum(normal) / tf.cast(
         num_replicas, tf.float32)
     # moving <- moving - decay * (moving - mean_across_replicas)
     return tf.assign_sub(moving,
                          decay * (moving - replica_mean),
                          use_locking=True,
                          name=name)
    def apply_updates(self):
        """Construct and return one op that applies all registered gradients.

        Casts gradients to FP32 and partially sums them per device, all-sums
        across GPUs with NCCL, undoes loss scaling, skips the weight update
        whenever any gradient is non-finite, and adjusts the dynamic
        loss-scaling variable accordingly. May be called at most once.
        """
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []
        with absolute_name_scope(self.scope):

            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]
            for dev_idx, dev in enumerate(devices):
                with tf.name_scope('ProcessGrads%d' % dev_idx), tf.device(dev):
                    sums = []
                    for gv in zip(*self._dev_grads[dev]):
                        # All entries for one slot must refer to the same var.
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))
                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope('SumAcrossGPUs'), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]
                        if np.prod(
                                grad_shape
                        ):  # nccl does not support zero-sized tensors
                            g = tf.contrib.nccl.all_sum(g)
                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (
                                gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope('ApplyGrads%d' % dev_idx), tf.device(dev):

                    # Scale gradients as needed.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope('Scale'):
                            # Average over accumulated grads and undo the
                            # dynamic loss scale in one multiply.
                            coef = tf.constant(np.float32(1.0 / total_grads),
                                               name='coef')
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope('CheckOverflow'):
                        grad_ok = tf.reduce_all(
                            tf.stack([
                                tf.reduce_all(tf.is_finite(g))
                                for g, v in grads
                            ]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope('UpdateWeights'):
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)
                        if not self.use_loss_scaling:
                            # No loss scaling: just skip the step on overflow.
                            ops.append(
                                tf.cond(grad_ok,
                                        lambda: opt.apply_gradients(grads),
                                        tf.no_op))
                        else:
                            # Dynamic loss scaling: grow the (log2) scale on
                            # success, shrink it on overflow and skip the step.
                            ops.append(
                                tf.cond(
                                    grad_ok, lambda: tf.group(
                                        tf.assign_add(ls_var, self.
                                                      loss_scaling_inc),
                                        opt.apply_gradients(grads)),
                                    lambda: tf.group(
                                        tf.assign_sub(ls_var, self.
                                                      loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope('Statistics'):
                            ops.append(
                                autosummary(self.id + '/learning_rate',
                                            self.learning_rate))
                            ops.append(
                                autosummary(self.id + '/overflow_frequency',
                                            tf.where(grad_ok, 0, 1)))
                            if self.use_loss_scaling:
                                ops.append(
                                    autosummary(self.id + '/loss_scaling_log2',
                                                ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            init_uninited_vars(list(self._dev_ls_var.values()))
            return tf.group(*ops, name='TrainingOp')
Esempio n. 17
0
    def __init__(self, n_sample, minibatch_sz, m1_inp_shape, m2_inp_shape,
                 m1_layers, m2_layers, msi_layers, m1_cause_init,
                 m2_cause_init, msi_cause_init, reg_m1_causes, reg_m2_causes,
                 reg_msi_causes, lr_m1_causes, lr_m2_causes, lr_msi_causes,
                 reg_m1_filters, reg_m2_filters, reg_msi_filters,
                 lr_m1_filters, lr_m2_filters, lr_msi_filters):

        self.m1_inp_shape = m1_inp_shape
        self.m2_inp_shape = m2_inp_shape
        self.m1_layers = m1_layers
        self.m2_layers = m2_layers
        self.msi_layers = msi_layers

        # create placeholders
        self.x_m1 = tf.placeholder(tf.float32,
                                   shape=[minibatch_sz, m1_inp_shape])
        self.x_m2 = tf.placeholder(tf.float32,
                                   shape=[minibatch_sz, m2_inp_shape])
        self.batch = tf.placeholder(tf.int32, shape=[])

        # create filters and cause for m1
        self.m1_filters = []
        self.m1_causes = []
        for i in range(len(self.m1_layers)):
            filter_name = 'm1_filter_%d' % i
            cause_name = 'm1_cause_%d' % i

            if i == 0:
                self.m1_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.m1_layers[i], self.m1_inp_shape])
                ]
            else:
                self.m1_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.m1_layers[i], self.m1_layers[i - 1]])
                ]

            init = tf.constant_initializer(m1_cause_init[i])
            self.m1_causes += [
                tf.get_variable(cause_name,
                                shape=[n_sample, self.m1_layers[i]],
                                initializer=init)
            ]

        # create filters and cause for m2
        self.m2_filters = []
        self.m2_causes = []
        for i in range(len(self.m2_layers)):
            filter_name = 'm2_filter_%d' % i
            cause_name = 'm2_cause_%d' % i

            if i == 0:
                self.m2_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.m2_layers[i], self.m2_inp_shape])
                ]
            else:
                self.m2_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.m2_layers[i], self.m2_layers[i - 1]])
                ]

            init = tf.constant_initializer(m2_cause_init[i])
            self.m2_causes += [
                tf.get_variable(cause_name,
                                shape=[n_sample, self.m2_layers[i]],
                                initializer=init)
            ]

        # create filters and cause for msi
        self.msi_filters = []
        self.msi_causes = []
        for i in range(len(self.msi_layers)):
            if i == 0:
                # add filters for m1
                filter_name = 'msi_m1_filter'
                self.msi_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.msi_layers[i], self.m1_layers[-1]])
                ]
                # add filters for m2
                filter_name = 'msi_m2_filter'
                self.msi_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.msi_layers[i], self.m2_layers[-1]])
                ]
            else:
                filter_name = 'msi_filter_%d' % i
                self.msi_filters += [
                    tf.get_variable(
                        filter_name,
                        shape=[self.msi_layers[i], self.msi_layers[i - 1]])
                ]

            cause_name = 'msi_cause_%d' % i
            init = tf.constant_initializer(msi_cause_init[i])
            self.msi_causes += [
                tf.get_variable(cause_name,
                                shape=[n_sample, self.msi_layers[i]],
                                initializer=init)
            ]

        # compute predictions
        current_batch = tf.range(self.batch * minibatch_sz,
                                 (self.batch + 1) * minibatch_sz)
        # m1 predictions
        self.m1_minibatch = []
        self.m1_predictions = []
        for i in range(len(self.m1_layers)):
            self.m1_minibatch += [
                tf.gather(self.m1_causes[i], indices=current_batch, axis=0)
            ]
            self.m1_predictions += [
                tf.nn.leaky_relu(
                    tf.matmul(self.m1_minibatch[i], self.m1_filters[i]))
            ]

        # m2 predictions
        self.m2_minibatch = []
        self.m2_predictions = []
        for i in range(len(self.m2_layers)):
            self.m2_minibatch += [
                tf.gather(self.m2_causes[i], indices=current_batch, axis=0)
            ]
            self.m2_predictions += [
                tf.nn.leaky_relu(
                    tf.matmul(self.m2_minibatch[i], self.m2_filters[i]))
            ]

        # msi predictions
        self.msi_minibatch = []
        self.msi_predictions = []
        for i in range(len(self.msi_layers)):
            self.msi_minibatch += [
                tf.gather(self.msi_causes[i], indices=current_batch, axis=0)
            ]
            if i == 0:
                self.msi_predictions += [
                    tf.nn.leaky_relu(
                        tf.matmul(self.msi_minibatch[i], self.msi_filters[i]))
                ]  # m1 prediction
                self.msi_predictions += [
                    tf.nn.leaky_relu(
                        tf.matmul(self.msi_minibatch[i],
                                  self.msi_filters[i + 1]))
                ]  # m2 prediction
            else:
                self.msi_predictions += [
                    tf.nn.leaky_relu(
                        tf.matmul(self.msi_minibatch[i],
                                  self.msi_filters[i + 1]))
                ]

        # add ops for computing gradients for m1 causes and for updating weights
        self.m1_bu_error = []
        self.m1_update_filter = []
        self.m1_cause_grad = []
        for i in range(len(self.m1_layers)):
            if i == 0:
                self.m1_bu_error += [
                    tf.losses.mean_squared_error(
                        self.x_m1,
                        self.m1_predictions[i],
                        reduction=tf.losses.Reduction.NONE)
                ]
            else:
                self.m1_bu_error += [
                    tf.losses.mean_squared_error(
                        tf.stop_gradient(self.m1_minibatch[i - 1]),
                        self.m1_predictions[i],
                        reduction=tf.losses.Reduction.NONE)
                ]

            # compute top-down prediction error
            if len(self.m1_layers) > (i + 1):
                # there are more layers in this modality
                td_error = tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m1_predictions[i + 1]),
                    self.m1_minibatch[i],
                    reduction=tf.losses.Reduction.NONE)
            else:
                # this is the only layer in this modality
                td_error = tf.losses.mean_squared_error(
                    tf.stop_gradient(self.msi_predictions[0]),
                    self.m1_minibatch[i],
                    reduction=tf.losses.Reduction.NONE)

            reg_error = reg_m1_causes[i] * (self.m1_minibatch[i]**2)
            # reg_error = tf.keras.regularizers.l2(reg_m1_causes[i])(self.m1_minibatch[i])
            self.m1_cause_grad += [
                tf.gradients([self.m1_bu_error[i], td_error, reg_error],
                             self.m1_minibatch[i])[0]
            ]

            # ops for updating weights
            reg_error = reg_m1_filters[i] * (self.m1_filters[i]**2)
            m1_filter_grad = tf.gradients([self.m1_bu_error[i], reg_error],
                                          self.m1_filters[i])[0]
            self.m1_update_filter += [
                tf.assign_sub(self.m1_filters[i],
                              lr_m1_filters[i] * m1_filter_grad)
            ]

        # add ops for computing gradients for m2 causes and for updating weights
        self.m2_bu_error = []
        self.m2_update_filter = []
        self.m2_cause_grad = []
        for i in range(len(self.m2_layers)):
            if i == 0:
                self.m2_bu_error += [
                    tf.losses.mean_squared_error(
                        self.x_m2,
                        self.m2_predictions[i],
                        reduction=tf.losses.Reduction.NONE)
                ]
            else:
                self.m2_bu_error += [
                    tf.losses.mean_squared_error(
                        tf.stop_gradient(self.m2_minibatch[i - 1]),
                        self.m2_predictions[i],
                        reduction=tf.losses.Reduction.NONE)
                ]

            # compute top-down prediction error
            if len(self.m2_layers) > (i + 1):
                # there are more layers in this modality
                td_error = tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m2_predictions[i + 1]),
                    self.m2_minibatch[i],
                    reduction=tf.losses.Reduction.NONE)
            else:
                # this is the only layer in this modality
                td_error = tf.losses.mean_squared_error(
                    tf.stop_gradient(self.msi_predictions[1]),
                    self.m2_minibatch[i],
                    reduction=tf.losses.Reduction.NONE)

            reg_error = reg_m2_causes[i] * (self.m2_minibatch[i]**2)
            # reg_error = tf.keras.regularizers.l2(reg_m2_causes[i])(self.m2_minibatch[i])
            self.m2_cause_grad += [
                tf.gradients([self.m2_bu_error[i], td_error, reg_error],
                             self.m2_minibatch[i])[0]
            ]

            # add ops for updating weights
            reg_error = reg_m2_filters[i] * (self.m2_filters[i]**2)
            m2_filter_grad = tf.gradients([self.m2_bu_error[i], reg_error],
                                          self.m2_filters[i])[0]
            self.m1_update_filter += [
                tf.assign_sub(self.m2_filters[i],
                              lr_m2_filters[i] * m2_filter_grad)
            ]
            #else:
            #raise NotImplementedError

        # add ops for computing gradients for msi causes
        self.msi_bu_error = []
        self.msi_reg_error = []
        self.msi_update_filter = []
        self.msi_cause_grad = []
        for i in range(len(self.msi_layers)):
            if i == 0:
                self.msi_bu_error += [
                    tf.losses.mean_squared_error(
                        tf.stop_gradient(self.m1_minibatch[-1]),
                        self.msi_predictions[i],
                        reduction=tf.losses.Reduction.NONE)
                ]
                self.msi_bu_error += [
                    tf.losses.mean_squared_error(
                        tf.stop_gradient(self.m2_minibatch[-1]),
                        self.msi_predictions[i + 1],
                        reduction=tf.losses.Reduction.NONE)
                ]

                self.msi_reg_error += [
                    reg_msi_causes[i] * (self.msi_minibatch[i]**2)
                ]
                # self.msi_reg_error += [tf.keras.regularizers.l2(reg_msi_causes[i])(self.msi_minibatch[i])]
                if len(self.msi_layers) > 1:
                    raise NotImplementedError
                else:
                    self.msi_cause_grad += [
                        tf.gradients([
                            self.msi_bu_error[i], self.msi_bu_error[i + 1],
                            self.msi_reg_error[i]
                        ], self.msi_minibatch[i])[0]
                    ]

                # add ops for updating weights
                reg_error = reg_msi_filters[i] * (self.msi_filters[i]**2)
                msi_filter_grad = tf.gradients(
                    [self.msi_bu_error[i], reg_error], self.msi_filters[i])[0]
                self.msi_update_filter += [
                    tf.assign_sub(self.msi_filters[i],
                                  lr_msi_filters[i] * msi_filter_grad)
                ]
                reg_error = reg_msi_filters[i + 1] * (self.msi_filters[i + 1]**
                                                      2)
                msi_filter_grad = tf.gradients(
                    [self.msi_bu_error[i + 1], reg_error],
                    self.msi_filters[i + 1])[0]
                self.msi_update_filter += [
                    tf.assign_sub(self.msi_filters[i + 1],
                                  lr_msi_filters[i + 1] * msi_filter_grad)
                ]
            else:
                raise NotImplementedError

        # add ops for updating causes
        self.m1_update_cause = []
        self.m2_update_cause = []
        self.msi_update_cause = []
        with tf.control_dependencies(self.m1_cause_grad + self.m2_cause_grad +
                                     self.msi_cause_grad):
            # m1 modality
            for i in range(len(self.m1_layers)):
                self.m1_update_cause += [
                    tf.scatter_sub(self.m1_causes[i],
                                   indices=current_batch,
                                   updates=(lr_m1_causes[i] *
                                            self.m1_cause_grad[i]))
                ]

            # m2 modality
            for i in range(len(self.m2_layers)):
                self.m2_update_cause += [
                    tf.scatter_sub(self.m2_causes[i],
                                   indices=current_batch,
                                   updates=(lr_m2_causes[i] *
                                            self.m2_cause_grad[i]))
                ]

            # msi modality
            for i in range(len(self.msi_layers)):
                self.msi_update_cause += [
                    tf.scatter_sub(self.msi_causes[i],
                                   indices=current_batch,
                                   updates=(lr_msi_causes[i] *
                                            self.msi_cause_grad[i]))
                ]
  def step_fn(self, params, model):
    """One TPU training step for Meta Pseudo Labels (teacher + student).

    Dequeues a tuple of (labeled images, labeled labels, original unlabeled
    images, augmented unlabeled images) from infeed, runs the UDA teacher,
    trains the student on teacher pseudo-labels, then trains the teacher
    using the student's improvement on labeled data as the MPL signal.
    Returns an outfeed-enqueue op that exports scalar logs.
    """
    train_batch_size = params.train_batch_size
    num_replicas = params.num_replicas
    uda_data = params.uda_data
    batch_size = train_batch_size // num_replicas  # per-replica batch size

    # dtypes/shapes must match the host-side infeed enqueue exactly.
    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32]
    shapes = [
        [batch_size, params.image_size, params.image_size, 3],
        [batch_size, params.num_classes],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      # Spatial partitioning: split images along the width axis (dim 2)
      # across the cores of each replica.
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=4,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1],
                                [1, 1, params.num_cores_per_replica, 1],
                                [1, 1, params.num_cores_per_replica, 1],],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
      l_images = xla_sharding.split(l_images, 2,
                                    params.num_cores_per_replica)
      u_images_ori = xla_sharding.split(u_images_ori, 2,
                                        params.num_cores_per_replica)
      u_images_aug = xla_sharding.split(u_images_aug, 2,
                                        params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        (l_images, l_labels, u_images_ori,
         u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)
    global_step = tf.train.get_or_create_global_step()
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    # One teacher forward pass over labeled + unlabeled(orig) + unlabeled(aug).
    all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

    # all calls to teacher
    with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
      logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
          params, model, all_images, l_labels)

    # 1st call to student
    with tf.variable_scope(MODEL_SCOPE):
      u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
      logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
      logits['s_on_u'], logits['s_on_l_old'] = tf.split(
          logits['s_on_u_aug_and_l'],
          [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

    # for backprop: student trains on the teacher's soft pseudo-labels
    # (teacher logits are stop_gradient'ed so only the student updates here).
    cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
        logits=logits['s_on_u'],
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
        train_batch_size*uda_data)

    # for Taylor: student's labeled loss BEFORE its update, kept in a shadow
    # variable so it can be subtracted after the update.
    cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_old'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_old']) / float(train_batch_size)
    # NOTE(review): no explicit initializer here; tf.get_variable falls back
    # to its default initializer -- confirm this is intended.
    shadow = tf.get_variable(
        name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
    shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

    w_s = {}
    g_s = {}
    g_n = {}
    lr = {}
    optim = {}
    w_s['s'] = [w for w in tf.trainable_variables()
                if w.name.lower().startswith(MODEL_SCOPE)]
    g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
    # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

    lr['s'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_student_lr,
        num_warmup_steps=params.mpl_student_lr_warmup_steps,
        num_wait_steps=params.mpl_student_lr_wait_steps)
    lr['s'], optim['s'] = common_utils.get_optimizer(
        params, learning_rate=lr['s'])
    optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
    # Only the student's own batch-norm style update ops.
    update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                  if op.name.startswith(f'train/{MODEL_SCOPE}/')]

    # Student update must run after the shadow snapshot of its old loss.
    with tf.control_dependencies(update_ops + [shadow_update]):
      g_s['s_on_u'] = common_utils.add_weight_decay(
          params, w_s['s'], g_s['s_on_u'])
      g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
          g_s['s_on_u'], params.grad_bound)
      train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

      with tf.control_dependencies([train_op]):
        ema_train_op = common_utils.setup_ema(
            params, name_scope=f'{MODEL_SCOPE}/{model.name}')

    # 2nd call to student: labeled loss AFTER the student update.
    with tf.control_dependencies([ema_train_op]):
      with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
        logits['s_on_l_new'] = model(l_images, training=True)

    cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_new'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_new']) / float(train_batch_size)

    # First-order (Taylor) approximation of the MPL signal: change in the
    # student's labeled loss across its update, centered by a moving average.
    dot_product = cross_entropy['s_on_l_new'] - shadow
    # dot_product = tf.clip_by_value(
    #     dot_product,
    #     clip_value_min=-params.mpl_dot_product_bound,
    #     clip_value_max=params.mpl_dot_product_bound)
    # NOTE(review): also created without an explicit initializer -- verify.
    moving_dot_product = tf.get_variable(
        'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
    moving_dot_product_update = tf.assign_sub(
        moving_dot_product, 0.01 * (moving_dot_product - dot_product))
    with tf.control_dependencies([moving_dot_product_update]):
      dot_product = dot_product - moving_dot_product
      dot_product = tf.stop_gradient(dot_product)
    # MPL term: gradient of the teacher's own pseudo-label cross-entropy,
    # scaled below by dot_product.
    cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
        logits=logits['u_aug'],
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
        train_batch_size*uda_data)

    # teacher train op: UDA consistency + supervised + MPL feedback
    uda_weight = params.uda_weight * tf.minimum(
        1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
    teacher_loss = (cross_entropy['u'] * uda_weight +
                    cross_entropy['l'] +
                    cross_entropy['mpl'] * dot_product)
    w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
    g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
    g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
    g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)
    lr['t'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_teacher_lr,
        num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
    lr['t'], optim['t'] = common_utils.get_optimizer(params,
                                                     learning_rate=lr['t'])

    teacher_train_op = optim['t'].apply_gradients(zip(g_s['t'], w_s['t']),
                                                  global_step=global_step)

    with tf.control_dependencies([teacher_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)

      logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
      logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                            num_replicas)
      logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
      logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
      # Values are divided by num_replicas because the host sums outfeed
      # tensors across replicas when aggregating logs.
      logs['lr/student'] = tf.identity(lr['s']) / num_replicas
      logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
      logs['mpl/dot_product'] = dot_product / num_replicas
      logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
      logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
      logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
      logs['uda/weight'] = uda_weight / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      def outfeed(tensors):
        with tf.device(tf.tpu.core(params.num_cores_per_replica-1)):
          return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

      # Only enqueue logs on steps where logging is enabled.
      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)

      return outfeed_enqueue_op
# Esempio n. 19
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        Cleans and scales raw gradients per device, sums them across devices,
        optionally accumulates them over `minibatch_multiplier` calls, skips
        the apply step when any gradient is non-finite, and adjusts the loss
        scaling factor. May only be called once per optimizer instance.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]              # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)       # Multiple gradients => sum.

                    # Scale as needed.
                    # Average over both the per-variable gradient terms and
                    # the participating devices, then undo loss scaling.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                if platform.system() == "Windows":    # Windows => NCCL ops are not available.
                    self._broadcast_fallback()
                elif tf.VERSION.startswith("1.15."):  # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                    self._broadcast_fallback()
                else:                                 # Otherwise => NCCL ops are safe to use.
                    self._broadcast_nccl()

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # acc_ok flips to True on the call that completes a full
                    # minibatch; the counter then resets to zero.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        # Read-before-write: the assign must wait for acc_cur
                        # so the accumulated value is not clobbered early.
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Grow the (log2) scale on clean steps, shrink it on overflow.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")
# Esempio n. 20
    def build_trainer(self, child_model):
        """Build the controller's REINFORCE training ops.

        Uses the child model's shuffled validation accuracy as the reward.
        Depending on `self.use_critic`, the advantage comes either from a
        learned value function (critic) or from an exponential moving-average
        baseline. Side effects: sets `self.valid_acc`, `self.reward`,
        `self.baseline`, `self.loss`, `self.train_step`, `self.train_op`,
        `self.lr`, `self.grad_norm`, and `self.optimizer`.
        """
        # actor
        child_model.build_valid_rl()
        self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) /
                          tf.to_float(child_model.batch_size))
        self.reward = self.valid_acc

        if self.use_critic:
            # critic: fit a value function and use its advantage
            all_h = tf.concat(self.all_h, axis=0)
            value_function = tf.matmul(all_h, self.w_critic)
            advantage = value_function - self.reward
            critic_loss = tf.reduce_sum(advantage**2)
            self.baseline = tf.reduce_mean(value_function)
            self.loss = -tf.reduce_mean(self.sample_log_probs * advantage)

            critic_train_step = tf.Variable(0,
                                            dtype=tf.int32,
                                            trainable=False,
                                            name="critic_train_step")
            critic_train_op, _, _, _ = get_train_ops(critic_loss,
                                                     [self.w_critic],
                                                     critic_train_step,
                                                     clip_mode=None,
                                                     lr_init=1e-3,
                                                     lr_dec_start=0,
                                                     lr_dec_every=int(1e9),
                                                     optim_algo="adam",
                                                     sync_replicas=False)
        else:
            # or baseline: EMA of the reward with decay self.bl_dec
            self.sample_log_probs = tf.reduce_sum(self.sample_log_probs)
            self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
            baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) *
                                            (self.baseline - self.reward))
            # Force the baseline update to run whenever the reward is read.
            with tf.control_dependencies([baseline_update]):
                self.reward = tf.identity(self.reward)
            self.loss = self.sample_log_probs * (self.reward - self.baseline)

        self.train_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name="train_step")
        # Train every controller variable except the critic's weights.
        tf_variables = [
            var for var in tf.trainable_variables()
            if var.name.startswith(self.name) and "w_critic" not in var.name
        ]
        # BUG FIX: the original used Python 2 `print` statements, which are
        # syntax errors under Python 3 (the rest of this file uses f-strings).
        print("-" * 80)
        for var in tf_variables:
            print(var)
        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            self.loss,
            tf_variables,
            self.train_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)

        if self.use_critic:
            self.train_op = tf.group(self.train_op, critic_train_op)
    def step_fn(self, params, model):
        """A single step for supervised learning.

        Trains the main model on score-weighted cross-entropy, then trains
        the scorer using the (EMA-centered) dot product between training and
        validation gradients as its reward signal. Returns an outfeed-enqueue
        op that exports scalar logs on logging steps.
        """
        (train_images, train_labels, valid_images,
         valid_labels) = tf.raw_ops.InfeedDequeueTuple(
             dtypes=params.train_dtypes, shapes=params.train_shapes)

        # Sparse int labels are converted to one-hot; float labels pass through.
        if train_labels.dtype == tf.int32:
            train_labels = tf.one_hot(train_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        if valid_labels.dtype == tf.int32:
            valid_labels = tf.one_hot(valid_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        global_step = tf.train.get_or_create_global_step()

        num_replicas = tf.cast(params.num_replicas, tf.float32)

        with tf.variable_scope(MODEL_SCOPE):
            train_logits = model(train_images, training=True)

        with tf.variable_scope(SCORE_SCOPE):
            # Per-example scores -> a softmax over ALL examples across all
            # replicas (cross_replica_sum gives the global normalizer).
            score_logits = model(train_images,
                                 training=False,
                                 return_scores=True)
            # NOTE(review): score_m is a cross-replica SUM of score logits
            # divided by num_replicas (a mean of per-replica sums), used only
            # as a stop-gradient'ed shift for numerical stability -- confirm.
            score_m = tf.tpu.cross_replica_sum(tf.reduce_sum(score_logits))
            score_m = tf.stop_gradient(score_m) / float(params.num_replicas)
            score_e = tf.exp(score_logits - score_m)
            score_z = tf.tpu.cross_replica_sum(tf.reduce_sum(score_e))
            score_probs = score_e / score_z

        # train the main model: per-example losses weighted by the (frozen)
        # score distribution.
        cross_entropy = tf.losses.softmax_cross_entropy(
            onehot_labels=train_labels,
            logits=train_logits,
            label_smoothing=params.label_smoothing,
            reduction=tf.losses.Reduction.NONE)
        cross_entropy = tf.reduce_sum(cross_entropy *
                                      tf.stop_gradient(score_probs))

        l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas,
                              tf.float32)
        weight_dec = common_utils.get_l2_loss(excluded_keywords=[SCORE_SCOPE])
        total_loss = cross_entropy + weight_dec * l2_reg_rate

        model_variables = [
            v for v in tf.trainable_variables() if MODEL_SCOPE in v.name
        ]
        train_gradients = tf.gradients(total_loss, model_variables)
        train_gradients = [
            tf.tpu.cross_replica_sum(g) for g in train_gradients
        ]
        train_gradients, grad_norm = tf.clip_by_global_norm(
            train_gradients, params.grad_bound)

        learning_rate, optimizer = common_utils.get_optimizer(params)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Skip the update entirely if the global grad norm is NaN/Inf.
        train_op = tf.cond(
            tf.math.is_finite(grad_norm), lambda: optimizer.
            apply_gradients(zip(train_gradients, model_variables),
                            global_step=global_step), tf.no_op)
        with tf.control_dependencies(update_ops + [train_op]):
            ema_train_op = common_utils.setup_ema(
                params, f'{MODEL_SCOPE}/{model.name}')

        # Validation pass AFTER the model update; its gradients are compared
        # against the (pre-update) training gradients below.
        with tf.control_dependencies([ema_train_op]):
            with tf.variable_scope(MODEL_SCOPE, reuse=True):
                valid_logits = model(valid_images, training=False)
                valid_cross_entropy = tf.losses.softmax_cross_entropy(
                    onehot_labels=valid_labels,
                    logits=valid_logits,
                    reduction=tf.losses.Reduction.MEAN) / float(
                        params.num_replicas)
                valid_gradients = tf.gradients(valid_cross_entropy,
                                               model_variables)
                valid_gradients = [
                    tf.tpu.cross_replica_sum(g) for g in valid_gradients
                ]

            # Alignment between train and valid gradients, centered by an
            # exponential moving average (decay 0.99).
            dot_product = tf.add_n([
                tf.reduce_sum(g_t * g_v)
                for g_t, g_v in zip(train_gradients, valid_gradients)
            ])
            dot_product = tf.stop_gradient(dot_product)
            # NOTE(review): created without an explicit initializer; relies
            # on tf.get_variable's default -- confirm this is intended.
            dot_product_avg = tf.get_variable(name='dot_product_avg',
                                              shape=[],
                                              trainable=False)
            dot_product_update = tf.assign_sub(
                dot_product_avg, 0.01 * (dot_product_avg - dot_product))
            with tf.control_dependencies([dot_product_update]):
                dot_product = tf.identity(dot_product - dot_product_avg)

        # trains the scorer.
        score_entropy = tf.reduce_sum(-score_probs * tf.math.log(score_probs))
        # NOTE(review): normalized by the VALID batch size even though
        # score_probs is computed on train_images -- verify intent.
        score_entropy = tf.tpu.cross_replica_sum(score_entropy) / float(
            valid_images.shape[0].value)
        score_variables = [
            v for v in tf.trainable_variables() if SCORE_SCOPE in v.name
        ]
        score_gradients = tf.gradients(dot_product * score_entropy,
                                       score_variables)
        score_gradients = [
            tf.tpu.cross_replica_sum(g) for g in score_gradients
        ]
        score_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=params.scorer_lr, use_locking=True)
        # Scorer only starts training after a warm-up wait period.
        score_train_op = tf.cond(
            global_step < params.scorer_wait_steps, tf.no_op,
            lambda: score_optimizer.apply_gradients(
                zip(score_gradients, score_variables)))

        with tf.control_dependencies([score_train_op]):
            logs = collections.OrderedDict()
            logs['global_step'] = tf.cast(global_step, tf.float32)

            logs['model/total'] = total_loss
            logs['model/weight_decay'] = weight_dec / num_replicas
            logs['model/cross_entropy'] = cross_entropy
            logs['model/lr'] = tf.identity(learning_rate) / num_replicas
            logs['model/grad_norm'] = grad_norm / num_replicas

            logs['score/dot_product'] = dot_product / num_replicas
            logs['score/dot_product_avg'] = dot_product_avg / num_replicas
            logs['score/entropy'] = score_entropy
            logs['score/p_min'] = tf.reduce_min(score_probs) / num_replicas
            logs['score/p_max'] = tf.reduce_max(score_probs) / num_replicas

            tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
            self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
            outfeed_enqueue_op = tf.cond(
                common_utils.should_log(params),
                lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors),
                tf.no_op)
        return outfeed_enqueue_op