def _batch_std(inputs, training, decay=MOVING_AVERAGE_DECAY, epsilon=EPSILON, data_format='channels_first', name='moving_variance'):
    """Batch standard deviation with a moving-variance estimate for inference.

    Args:
        inputs: 4-D activation tensor; NHWC when `data_format` is
            'channels_last', otherwise NCHW.
        training: Python bool. When True, compute the batch variance and
            register an op that updates the moving variance; when False, use
            the stored moving variance instead.
        decay: moving-average decay rate for the variance update.
        epsilon: small constant added under the square root for stability.
        data_format: 'channels_first' or 'channels_last'.
        name: name of the moving-variance variable.

    Returns:
        Per-channel standard deviation, cast back to `inputs.dtype`.
    """
    # Reduce over every axis except the channel axis.
    if data_format == 'channels_last':
        var_shape, axes = (1, 1, 1, inputs.shape[3]), [0, 1, 2]
    else:
        var_shape, axes = (1, inputs.shape[1], 1, 1), [0, 2, 3]
    moving_variance = tf.get_variable(
        name=name,
        shape=var_shape,
        initializer=tf.initializers.ones(),
        dtype=tf.float32,
        collections=[
            tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
            tf.GraphKeys.GLOBAL_VARIABLES
        ],
        trainable=False)
    if training:
        _, variance = tf.nn.moments(inputs, axes, keep_dims=True)
        variance = tf.cast(variance, tf.float32)
        # Equivalent to: moving_var <- decay * moving_var + (1 - decay) * var.
        update_op = tf.assign_sub(moving_variance,
                                  (moving_variance - variance) * (1 - decay))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
    else:
        variance = moving_variance
    std = tf.sqrt(variance + epsilon)
    return tf.cast(std, inputs.dtype)
def _undo_update(self):
    """Revert the most recent step by subtracting each stored delta."""
    undo_ops = [
        tf.assign_sub(weight, delta)
        for weight, delta in zip(self._weights, self._dws)
    ]
    return tf.group(undo_ops)
def apply_gradients(self, grads_and_vars):
    """Build Adam update ops for `grads_and_vars` (a custom Adam variant).

    Creates the beta-power accumulators and per-variable (m, v) state lazily,
    registers them in `self.all_state_vars`, and returns a single grouped op
    that advances the accumulators and applies the bias-corrected updates.

    Args:
        grads_and_vars: iterable of (gradient tensor, tf.Variable) pairs.

    Returns:
        A `tf.group` op performing one optimization step.
    """
    with tf.name_scope(self.name):
        state_vars = []
        update_ops = []

        # Adjust learning rate to deal with startup bias.
        # control_dependencies(None) detaches variable creation from any
        # surrounding control-flow context.
        with tf.control_dependencies(None):
            b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
            b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
        state_vars += [b1pow_var, b2pow_var]
        b1pow_new = b1pow_var * self.beta1
        b2pow_new = b2pow_var * self.beta2
        update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
        # Bias-corrected learning rate: lr * sqrt(1 - b2^t) / (1 - b1^t).
        lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)

        # Construct ops to update each variable.
        for grad, var in grads_and_vars:
            with tf.control_dependencies(None):
                m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
            state_vars += [m_var, v_var]
            # Standard Adam first/second-moment estimates.
            m_new = self.beta1 * m_var + (1 - self.beta1) * grad
            v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
            var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
            update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]

        # Group everything together.
        self.all_state_vars += state_vars
        return tf.group(*update_ops)
def _finish(self, caches):
    """Apply the final step `s_t` from each cache to its variable.

    Optionally clips all steps jointly by global norm (`self.clip`), then
    subtracts each step from its variable (scatter-sub for sparse updates
    keyed by 'idxs', dense assign-sub otherwise). When `self.chi > 0`, also
    maintains a moving average of the iterate.

    Args:
        caches: list of per-variable dicts with keys 'x_tm1' (the variable),
            's_t' (the step), 'updates' (op list), and optionally 'idxs'.

    Returns:
        A grouped op named 'update' covering all per-variable update groups.
    """
    if self.clip > 0:
        # Jointly rescale all steps so their global norm is at most clip.
        S_t = [cache['s_t'] for cache in caches]
        S_t, _ = tf.clip_by_global_norm(S_t, self.clip)
        for cache, s_t in zip(caches, S_t):
            cache['s_t'] = s_t
    for cache in caches:
        x_tm1 = cache['x_tm1']
        s_t = cache['s_t']
        updates = cache['updates']
        # Keep each variable's update ops on that variable's device.
        with tf.name_scope('update_' + x_tm1.op.name), tf.device(x_tm1.device):
            if 'idxs' in cache:
                # Sparse path: only rows listed in idxs are touched.
                idxs = cache['idxs']
                x_t = tf.scatter_sub(x_tm1, idxs, s_t)
                if self.chi > 0:
                    x_t_ = tf.gather(x_t, idxs)
                    x_bar_t, t_x_bar = self._sparse_moving_average(
                        x_tm1, idxs, x_t_, 'x', beta=self.chi)
            else:
                x_t = tf.assign_sub(x_tm1, s_t)
                if self.chi > 0:
                    x_bar_t, t_x_bar = self._dense_moving_average(
                        x_tm1, x_t, 'x', beta=self.chi)
            updates.append(x_t)
            if self.chi > 0:
                updates.extend([x_bar_t, t_x_bar])
    update_ops = [tf.group(*cache['updates']) for cache in caches]
    return tf.group(*update_ops, name='update')
def build_trainer(self, child_model): """Build the train ops by connecting Controller with a Child.""" # actor self.valid_loss = tf.to_float(child_model.rl_loss) self.valid_loss = tf.stop_gradient(self.valid_loss) self.valid_ppl = tf.exp(self.valid_loss) self.reward = REWARD_CONSTANT / self.valid_ppl if self.params.controller_entropy_weight: self.reward += self.params.controller_entropy_weight * self.sample_entropy # or baseline self.sample_log_probs = tf.reduce_sum(self.sample_log_probs) self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False) baseline_update = tf.assign_sub(self.baseline, ((1 - self.params.controller_baseline_dec) * (self.baseline - self.reward))) with tf.control_dependencies([baseline_update]): self.reward = tf.identity(self.reward) self.loss = self.sample_log_probs * (self.reward - self.baseline) self.train_step = tf.Variable( 0, dtype=tf.int32, trainable=False, name='train_step') tf_vars = [var for var in tf.trainable_variables() if var.name.startswith(self.name)] self.train_op, self.optimizer, self.grad_norm = _build_train_op( loss=self.loss, tf_vars=tf_vars, learning_rate=self.params.controller_learning_rate, train_step=self.train_step, num_aggregate=self.params.controller_num_aggregate)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    """Shared sparse Adam update used by both IndexedSlices and resource paths.

    Args:
        grad: gradient values for the rows listed in `indices`.
        var: variable to update.
        indices: row indices corresponding to `grad`.
        scatter_add: callable (ref, indices, updates) performing a sparse add;
            lets the caller choose the ref/resource-variable flavor.

    Returns:
        A grouped op applying the update and advancing the m/v slots.
    """
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
    lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
    # Bias-corrected learning rate.
    lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    # Decay first (dense), then scatter in the new gradient contribution; the
    # control dependency guarantees the decay happens before the scatter.
    m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
    with tf.control_dependencies([m_t]):
        m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
    with tf.control_dependencies([v_t]):
        v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = tf.sqrt(v_t)
    var_update = tf.assign_sub(var,
                               lr * m_t / (v_sqrt + epsilon_t),
                               use_locking=self._use_locking)
    return tf.group(*[var_update, m_t, v_t])
def _sub_mixed_grad(self):
    """Subtract the current reference-weight gradients from the aggregates."""
    sub_ops = [
        tf.assign_sub(aggregate, current)
        for aggregate, current in zip(self._grads_aggregated, self._grads)
    ]
    return tf.group(sub_ops)
def build_trainer(self, child_model):
    """Wire the controller's REINFORCE training ops to a child model.

    Reward is the child's shuffled-validation accuracy (optionally plus an
    entropy bonus); an EMA baseline reduces gradient variance.

    Args:
        child_model: child network providing `build_valid_rl()`,
            `valid_shuffle_acc`, `batch_size`, and the current architectures.
    """
    child_model.build_valid_rl()
    self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) /
                      tf.to_float(child_model.batch_size))
    self.current_normal_arc = child_model.current_normal_arc
    self.current_reduce_arc = child_model.current_reduce_arc
    self.reward = self.valid_acc
    if self.entropy_weight is not None:
        # Entropy bonus encourages exploration.
        self.reward += self.entropy_weight * self.sample_entropy
    self.sample_log_prob = tf.reduce_sum(self.sample_log_prob)
    self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
    # EMA baseline: baseline <- bl_dec * baseline + (1 - bl_dec) * reward.
    baseline_update = tf.assign_sub(self.baseline,
                                    (1 - self.bl_dec) * (self.baseline - self.reward))
    # Ensure the baseline is refreshed whenever the reward is read.
    with tf.control_dependencies([baseline_update]):
        self.reward = tf.identity(self.reward)
    # Policy-gradient surrogate loss.
    self.loss = self.sample_log_prob * (self.reward - self.baseline)
    self.train_step = tf.Variable(0, dtype=tf.int32, trainable=False,
                                  name="train_step")
    tf_variables = [
        var for var in tf.trainable_variables()
        if var.name.startswith(self.name)
    ]
    print("-" * 80)
    for var in tf_variables:
        print(var)
    self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
        self.loss,
        tf_variables,
        self.train_step,
        clip_mode=self.clip_mode,
        grad_bound=self.grad_bound,
        l2_reg=self.l2_reg,
        lr_init=self.lr_init,
        lr_dec_start=self.lr_dec_start,
        lr_dec_every=self.lr_dec_every,
        lr_dec_rate=self.lr_dec_rate,
        optim_algo=self.optim_algo,
        sync_replicas=self.sync_replicas,
        num_aggregate=self.num_aggregate,
        num_replicas=self.num_replicas)
    self.skip_rate = tf.constant(0.0, dtype=tf.float32)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    """Sparse LAMB update shared across IndexedSlices/resource paths.

    Maintains Adam-style m/v slots updated sparsely, applies bias correction,
    optional decoupled weight decay, and the LAMB layer-wise trust ratio
    (||w|| / ||update||) before the final assign_sub.

    Args:
        grad: gradient values for the rows in `indices`.
        var: variable to update.
        indices: row indices corresponding to `grad`.
        scatter_add: callable (ref, indices, updates) doing a sparse add.

    Returns:
        A grouped op applying the update and advancing the m/v slots.
    """
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
    lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
    weight_decay_rate_t = tf.cast(self._weight_decay_rate_t,
                                  var.dtype.base_dtype)
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    # Dense decay first, then sparse add; ordering enforced by control dep.
    m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
    with tf.control_dependencies([m_t]):
        m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
    with tf.control_dependencies([v_t]):
        v_t = scatter_add(v, indices, v_scaled_g_values)
    # ==== The following is with m_t_hat and v_t_hat (bias-corrected moments)
    m_t_hat = m_t / (1. - beta1_power)
    v_t_hat = v_t / (1. - beta2_power)
    v_sqrt = tf.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + epsilon_t)
    # ==== The following is the original LAMBOptimizer implementation
    # v_sqrt = tf.sqrt(v_t_hat)
    # update = m_t / (v_sqrt + epsilon_t)
    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
        # Decoupled (AdamW-style) weight decay folded into the update.
        update += weight_decay_rate_t * var
    ratio = 1.0
    if self._do_layer_adaptation(var_name):
        # LAMB trust ratio; guarded so zero norms fall back to 1.0.
        w_norm = tf.norm(var, ord=2)
        g_norm = tf.norm(update, ord=2)
        ratio = tf.where(
            tf.greater(w_norm, 0),
            tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
    var_update = tf.assign_sub(var,
                               ratio * lr_t * update,
                               use_locking=self._use_locking)
    return tf.group(*[var_update, m_t, v_t])
def _update(self, rs, ps): ops = [] # Compute the coefficient alpha. pTAp = tf.zeros(shape=[], dtype=ps[0].dtype) for p, Az in zip(ps, self._Azs): # Recall that p has already been assigned to z, and hence Az = Ap. pTAp += tf.reduce_sum(p * Az) indefinite = pTAp <= 0.0 ops.append(tf.assign(self._indefinite, indefinite)) alpha = tf.cond(indefinite, lambda: 0.0, lambda: self._rTr / pTAp) # Update the solution and residual. for x, r, p, Az in zip(self._xs, rs, ps, self._Azs): ops.append(tf.assign_add(x, alpha * p)) ops.append(tf.assign_sub(r, alpha * Az)) return tf.group(ops)
def _apply_dense(self, grad, var):
    """Dense Adam-family update with optional Nesterov, bias correction,
    L2 and AdamW-style weight decay.

    Args:
        grad: dense gradient tensor for `var`.
        var: variable to update.

    Returns:
        A grouped op applying the step and advancing the m/v slots.
    """
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    lr = tf.cast(self._lr_t, grad.dtype.base_dtype)
    beta1 = tf.cast(self._beta1_t, grad.dtype.base_dtype)
    beta2 = tf.cast(self._beta2_t, grad.dtype.base_dtype)
    epsilon = tf.cast(self._epsilon_t, grad.dtype.base_dtype)
    # Coupled L2 weight decay: folded into the gradient itself.
    grad = grad - var * self._l2_weight_decay
    # m_t = beta_1 * m_{t-1} + (1-beta_1) * g_t
    m_t = m.assign(beta1 * m + (1.0 - beta1) * grad)
    # v_t = beta_2 * v_{t-1} + (1-beta_2) * g_t ** 2
    v_t = v.assign(beta2 * v + (1.0 - beta2) * grad * grad)
    if self._use_bias_correction:
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = tf.cast(beta1_power, grad.dtype.base_dtype)
        beta2_power = tf.cast(beta2_power, grad.dtype.base_dtype)
        lr_t = lr * tf.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    else:
        lr_t = lr
    if self._use_nesterov:
        # delta theta = lr_t * (
        #     (beta_1 * m_t + (1-beta1) * g_t) / (sqrt(v_t) + epsilon))
        step = lr_t * ((beta1 * m_t + (1.0 - beta1) * grad) /
                       (tf.sqrt(v_t) + epsilon))
    else:
        # delta theta = lr_t * m_t / (sqrt(v_t) + epsilon)
        step = lr_t * m_t / (tf.sqrt(v_t) + epsilon)
    # AdamW style weight decay term (decoupled; added to the step directly).
    step = step + lr_t * self._adamw_weight_decay * var
    theta_t = tf.assign_sub(var, step)
    return tf.group(*[theta_t, m_t, v_t])
def setup_ema(params, name_scope=None):
    """Create exponential moving average for all variables under `name_scope`.

    Builds one shadow variable per eligible global variable (under the 'ema'
    variable scope) and returns a grouped op that pulls each shadow toward its
    source with a warmup-aware decay. Before `params.ema_start` steps the
    effective decay forces a direct copy (decay factor 1).

    Args:
        params: object with `ema_decay`, `ema_start`, and
            `num_cores_per_replica` attributes.
        name_scope: if given, only variables whose name starts with it are
            tracked.

    Returns:
        A `tf.group` op updating all EMA shadow variables.
    """
    logging.info(f'ema_decay with rate {params.ema_decay}')
    all_vars = tf.global_variables()
    ema_ops = []
    step = tf.cast(tf.train.get_or_create_global_step() - params.ema_start,
                   tf.float32)
    # Warmup schedule: ramp (step+1)/(step+10) until it exceeds ema_decay.
    # Note `decay` here is the *update fraction* (1 - conventional decay).
    decay = 1. - tf.minimum(params.ema_decay, (step+1.) / (step+10.))
    decay = tf.cond(tf.train.get_or_create_global_step() < params.ema_start,
                    lambda: tf.constant(1, tf.float32), lambda: decay)

    def should_skip(v):
        # Skip optimizer slots and bookkeeping variables.
        key_words = ['momentum', 'rms', 'global_step', 'debug', 'adam', 'lars']
        conditions = [k in v.name.lower() for k in key_words]
        if name_scope is not None:
            conditions += [not v.name.lower().startswith(name_scope)]
        return any(conditions)

    def get_init(v_name):
        # Variance/beta-like statistics start at one, everything else at zero.
        key_words = ['variance', 'beta']
        if any([k in v_name for k in key_words]):
            return tf.initializers.ones()
        return tf.initializers.zeros()

    with tf.variable_scope('ema'):
        for v in all_vars:
            if not should_skip(v):
                v_name = strip_var_name(v.name)
                with tf.device(v.device):
                    ema_var = tf.get_variable(
                        name=v_name,
                        shape=v.shape.as_list(),
                        initializer=get_init(v_name),
                        trainable=False)
                # XLA sharding for model parallelism (project helper).
                v = shard_weight(v, params.num_cores_per_replica)
                ema = shard_weight(ema_var, params.num_cores_per_replica)
                # ema_var <- ema_var - decay * (ema - v).
                ema_op = tf.assign_sub(ema_var, decay * (ema-v),
                                       use_locking=True)
                ema_ops.append(ema_op)
    ema_op = tf.group(*ema_ops)
    return ema_op
def _update(self, rs, ps): ops = [] # Compute the coefficient alpha. pTHp = tf.zeros(shape=[], dtype=ps[0].dtype) for p, Hz in zip(ps, self._hessians): # Recall that p has already been assigned to z, and hence Hz = Hp. pTHp += tf.reduce_sum(p * Hz) # Compute the coefficient for the update. alpha = self._rTr / pTHp # Create a tensor that computes the norm of the iterate after the update # without actually modifying it. norm_dw_new = tf.zeros(shape=[], dtype=self._norm_dw.dtype) for dw, p in zip(self._dws, ps): dw_new = dw + alpha * p norm_dw_new += tf.reduce_sum(dw_new * dw_new) norm_dw_new = tf.sqrt(norm_dw_new) # Determine if we should follow the direction p until it intersects with the # boundary of the trust region. # This is the case if either p is a direction of indefiniteness or if dw + p # would be outside the trust region. follow_to_boundary = tf.logical_or(pTHp <= 0.0, norm_dw_new > self._radius_placeh) self._follow_to_boundary = tf.Variable(False) ops.append(tf.assign(self._follow_to_boundary, follow_to_boundary)) # If we follow p up to the boundary, we do not update dw here. # Instead, we determine the final update dw in the 'solve' method. alpha_or_zero = tf.cond(follow_to_boundary, lambda: 0.0, lambda: alpha) # Update the solution and residual. for dw, r, p, Hz in zip(self._dws, rs, ps, self._hessians): ops.append(tf.assign_add(dw, alpha_or_zero * p)) ops.append(tf.assign_sub(r, alpha_or_zero * Hz)) return tf.group(ops)
def _apply_dense(self, grad, var):
    """SM3 dense update: memory-efficient adaptive preconditioning.

    SM3 upper bounds the gradient square sums:

    To illustrate, for a Tensor `T` of shape [M, N, K]:

      `G` is its gradient of shape [M, N, K].

      SM3 keeps around three accumulators A1, A2, A3 of size M, N, K
      respectively.

      `A` is the accumulator of shape [M, N, K]. `A` is not materialized until
      it's needed for every step, and is approximated by A1, A2, A3.

    At every gradient update step the accumulators satisfy:
      A1_t[i] >= Sum_{s <= t} G_t[i, j, k]^2 for all j, k.
      A2_t[j] >= Sum_{s <= t} G_t[i, j, k]^2 for all i, k.
      A3_t[k] >= Sum_{s <= t} G_t[i, j, k]^2 for all i, j.

    The RHS is the gradient sum squares.

    For every step we materialize the tensor `A` based on accumulated tensors
    A1, A2 and A3:

      A = min(A1[i], A2[j], A3[k]) + G[i, j, k]^2

    SM3 preconditioned gradient is

      preconditioned G = A^{-0.5} * G

    We then update the individual accumulator factors as:

      A1[i] = max_{j, k} A[i, j, k]
      A2[j] = max_{i, k} A[i, j, k]
      A3[k] = max_{i, j} A[i, j, k]
    """
    shape = np.array(var.get_shape())
    var_rank = len(shape)
    if var_rank > 1:
        # Multi-dim variable: reconstruct the full accumulator from factors.
        accumulator_list = [
            self.get_slot(var, "accumulator_" + str(i))
            for i in range(var_rank)
        ]
        accumulator = self._compute_past_accumulator(accumulator_list, shape)
        accumulator += grad * grad
    else:
        # Rank-1 variable: keep the full accumulator directly (Adagrad-like).
        accumulator_var = self.get_slot(var, "accumulator")
        accumulator = tf.assign_add(accumulator_var, grad * grad)
    # 1e-30 guards against division by zero for never-updated entries.
    accumulator_inv_sqrt = tf.rsqrt(accumulator + 1e-30)
    scaled_g = (1.0 - self._momentum_tensor) * (grad * accumulator_inv_sqrt)
    accumulator_update_ops = []
    with tf.control_dependencies([scaled_g]):
        if var_rank > 1:
            # Updates individual accumulator factors as:
            #   A1[i] = max_{j, k} A[i, j, k]
            #   A2[j] = max_{i, k} A[i, j, k]
            #   A3[k] = max_{i, j} A[i, j, k]
            for i, accumulator_i in enumerate(accumulator_list):
                axes = list(range(i)) + list(range(i + 1, var_rank))
                new_accumulator_i = tf.reduce_max(accumulator, axis=axes)
                accumulator_update_ops.append(
                    tf.assign(accumulator_i, new_accumulator_i))
    with tf.control_dependencies(accumulator_update_ops):
        if self._momentum > 0:
            # gbar <- momentum * gbar + scaled_g (rewritten as assign_add).
            gbar = self.get_slot(var, "momentum")
            update = tf.assign_add(
                gbar, gbar * (self._momentum_tensor - 1.0) + scaled_g)
        else:
            update = scaled_g
        return tf.assign_sub(var, self._learning_rate_tensor * update)
def u(moving, normal, name):
    """Moving-average update toward the cross-replica mean of `normal`."""
    replica_count = tf.cast(num_replicas, tf.float32)
    cross_replica_mean = tf.tpu.cross_replica_sum(normal) / replica_count
    # moving <- moving - decay * (moving - mean) == EMA toward the mean.
    return tf.assign_sub(moving, decay * (moving - cross_replica_mean),
                         use_locking=True, name=name)
def apply_updates(self):
    """Construct the multi-device training op from accumulated gradients.

    Casts per-device gradients to FP32 and sums them, all-reduces across
    devices (NCCL), undoes loss scaling, skips the step entirely when any
    gradient is non-finite, and adjusts the dynamic loss-scaling variable
    up on success / down on overflow. May only be called once.

    Returns:
        A single grouped op named 'TrainingOp'.
    """
    assert not self._updates_applied
    self._updates_applied = True
    devices = list(self._dev_grads.keys())
    total_grads = sum(len(grads) for grads in self._dev_grads.values())
    assert len(devices) >= 1 and total_grads >= 1
    ops = []
    with absolute_name_scope(self.scope):
        # Cast gradients to FP32 and calculate partial sum within each device.
        dev_grads = OrderedDict()  # device => [(grad, var), ...]
        for dev_idx, dev in enumerate(devices):
            with tf.name_scope('ProcessGrads%d' % dev_idx), tf.device(dev):
                sums = []
                for gv in zip(*self._dev_grads[dev]):
                    # All entries in gv must refer to the same variable.
                    assert all(v is gv[0][1] for g, v in gv)
                    g = [tf.cast(g, tf.float32) for g, v in gv]
                    g = g[0] if len(g) == 1 else tf.add_n(g)
                    sums.append((g, gv[0][1]))
                dev_grads[dev] = sums

        # Sum gradients across devices.
        if len(devices) > 1:
            with tf.name_scope('SumAcrossGPUs'), tf.device(None):
                for var_idx, grad_shape in enumerate(self._grad_shapes):
                    g = [dev_grads[dev][var_idx][0] for dev in devices]
                    if np.prod(
                            grad_shape
                    ):  # nccl does not support zero-sized tensors
                        g = tf.contrib.nccl.all_sum(g)
                    for dev, gg in zip(devices, g):
                        dev_grads[dev][var_idx] = (
                            gg, dev_grads[dev][var_idx][1])

        # Apply updates separately on each device.
        for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
            with tf.name_scope('ApplyGrads%d' % dev_idx), tf.device(dev):
                # Scale gradients as needed.
                if self.use_loss_scaling or total_grads > 1:
                    with tf.name_scope('Scale'):
                        # Average over contributions and undo loss scaling.
                        coef = tf.constant(np.float32(1.0 / total_grads),
                                           name='coef')
                        coef = self.undo_loss_scaling(coef)
                        grads = [(g * coef, v) for g, v in grads]

                # Check for overflows.
                with tf.name_scope('CheckOverflow'):
                    grad_ok = tf.reduce_all(
                        tf.stack([
                            tf.reduce_all(tf.is_finite(g)) for g, v in grads
                        ]))

                # Update weights and adjust loss scaling.
                with tf.name_scope('UpdateWeights'):
                    opt = self._dev_opt[dev]
                    ls_var = self.get_loss_scaling_var(dev)
                    if not self.use_loss_scaling:
                        # Apply only when all gradients are finite.
                        ops.append(
                            tf.cond(grad_ok,
                                    lambda: opt.apply_gradients(grads),
                                    tf.no_op))
                    else:
                        # Dynamic loss scaling: grow on success, shrink and
                        # skip the step on overflow.
                        ops.append(
                            tf.cond(
                                grad_ok,
                                lambda: tf.group(
                                    tf.assign_add(ls_var, self.
                                                  loss_scaling_inc),
                                    opt.apply_gradients(grads)),
                                lambda: tf.group(
                                    tf.assign_sub(ls_var, self.
                                                  loss_scaling_dec))))

                # Report statistics on the last device.
                if dev == devices[-1]:
                    with tf.name_scope('Statistics'):
                        ops.append(
                            autosummary(self.id + '/learning_rate',
                                        self.learning_rate))
                        ops.append(
                            autosummary(self.id + '/overflow_frequency',
                                        tf.where(grad_ok, 0, 1)))
                        if self.use_loss_scaling:
                            ops.append(
                                autosummary(self.id + '/loss_scaling_log2',
                                            ls_var))

        # Initialize variables and group everything into a single op.
        self.reset_optimizer_state()
        init_uninited_vars(list(self._dev_ls_var.values()))
        return tf.group(*ops, name='TrainingOp')
def __init__(self, n_sample, minibatch_sz, m1_inp_shape, m2_inp_shape,
             m1_layers, m2_layers, msi_layers, m1_cause_init, m2_cause_init,
             msi_cause_init, reg_m1_causes, reg_m2_causes, reg_msi_causes,
             lr_m1_causes, lr_m2_causes, lr_msi_causes, reg_m1_filters,
             reg_m2_filters, reg_msi_filters, lr_m1_filters, lr_m2_filters,
             lr_msi_filters):
    """Build a two-modality (m1, m2) predictive-coding network with a
    multisensory-integration (msi) layer on top.

    Each modality has a stack of linear filters with leaky-ReLU predictions
    and per-sample "cause" variables; gradient ops are built for updating
    both the filters (assign_sub) and the causes (scatter_sub over the
    current minibatch rows).

    Bug fix vs. the previous revision: m2 filter-update ops were appended to
    `self.m1_update_filter`, leaving `self.m2_update_filter` always empty;
    they now go to `self.m2_update_filter`.

    Args:
        n_sample: total number of samples (rows of each cause variable).
        minibatch_sz: minibatch size; `self.batch` selects which slice.
        m1_inp_shape, m2_inp_shape: input dimensionality per modality.
        m1_layers, m2_layers, msi_layers: layer widths per stack.
        *_cause_init: per-layer initial values for the cause variables.
        reg_*: per-layer L2 regularization weights (causes / filters).
        lr_*: per-layer learning rates (causes / filters).
    """
    self.m1_inp_shape = m1_inp_shape
    self.m2_inp_shape = m2_inp_shape
    self.m1_layers = m1_layers
    self.m2_layers = m2_layers
    self.msi_layers = msi_layers

    # Create placeholders.
    self.x_m1 = tf.placeholder(tf.float32, shape=[minibatch_sz, m1_inp_shape])
    self.x_m2 = tf.placeholder(tf.float32, shape=[minibatch_sz, m2_inp_shape])
    self.batch = tf.placeholder(tf.int32, shape=[])

    # Create filters and causes for m1.
    self.m1_filters = []
    self.m1_causes = []
    for i in range(len(self.m1_layers)):
        filter_name = 'm1_filter_%d' % i
        cause_name = 'm1_cause_%d' % i
        if i == 0:
            self.m1_filters += [
                tf.get_variable(filter_name,
                                shape=[self.m1_layers[i], self.m1_inp_shape])
            ]
        else:
            self.m1_filters += [
                tf.get_variable(
                    filter_name,
                    shape=[self.m1_layers[i], self.m1_layers[i - 1]])
            ]
        init = tf.constant_initializer(m1_cause_init[i])
        self.m1_causes += [
            tf.get_variable(cause_name,
                            shape=[n_sample, self.m1_layers[i]],
                            initializer=init)
        ]

    # Create filters and causes for m2.
    self.m2_filters = []
    self.m2_causes = []
    for i in range(len(self.m2_layers)):
        filter_name = 'm2_filter_%d' % i
        cause_name = 'm2_cause_%d' % i
        if i == 0:
            self.m2_filters += [
                tf.get_variable(filter_name,
                                shape=[self.m2_layers[i], self.m2_inp_shape])
            ]
        else:
            self.m2_filters += [
                tf.get_variable(
                    filter_name,
                    shape=[self.m2_layers[i], self.m2_layers[i - 1]])
            ]
        init = tf.constant_initializer(m2_cause_init[i])
        self.m2_causes += [
            tf.get_variable(cause_name,
                            shape=[n_sample, self.m2_layers[i]],
                            initializer=init)
        ]

    # Create filters and causes for msi. Layer 0 has two filters (one
    # predicting the top m1 cause, one predicting the top m2 cause).
    self.msi_filters = []
    self.msi_causes = []
    for i in range(len(self.msi_layers)):
        if i == 0:
            # Add filters for m1.
            filter_name = 'msi_m1_filter'
            self.msi_filters += [
                tf.get_variable(
                    filter_name,
                    shape=[self.msi_layers[i], self.m1_layers[-1]])
            ]
            # Add filters for m2.
            filter_name = 'msi_m2_filter'
            self.msi_filters += [
                tf.get_variable(
                    filter_name,
                    shape=[self.msi_layers[i], self.m2_layers[-1]])
            ]
        else:
            filter_name = 'msi_filter_%d' % i
            self.msi_filters += [
                tf.get_variable(
                    filter_name,
                    shape=[self.msi_layers[i], self.msi_layers[i - 1]])
            ]
        cause_name = 'msi_cause_%d' % i
        init = tf.constant_initializer(msi_cause_init[i])
        self.msi_causes += [
            tf.get_variable(cause_name,
                            shape=[n_sample, self.msi_layers[i]],
                            initializer=init)
        ]

    # Compute predictions. current_batch holds the row indices of this
    # minibatch inside the full cause variables.
    current_batch = tf.range(self.batch * minibatch_sz,
                             (self.batch + 1) * minibatch_sz)

    # m1 predictions.
    self.m1_minibatch = []
    self.m1_predictions = []
    for i in range(len(self.m1_layers)):
        self.m1_minibatch += [
            tf.gather(self.m1_causes[i], indices=current_batch, axis=0)
        ]
        self.m1_predictions += [
            tf.nn.leaky_relu(
                tf.matmul(self.m1_minibatch[i], self.m1_filters[i]))
        ]

    # m2 predictions.
    self.m2_minibatch = []
    self.m2_predictions = []
    for i in range(len(self.m2_layers)):
        self.m2_minibatch += [
            tf.gather(self.m2_causes[i], indices=current_batch, axis=0)
        ]
        self.m2_predictions += [
            tf.nn.leaky_relu(
                tf.matmul(self.m2_minibatch[i], self.m2_filters[i]))
        ]

    # msi predictions.
    self.msi_minibatch = []
    self.msi_predictions = []
    for i in range(len(self.msi_layers)):
        self.msi_minibatch += [
            tf.gather(self.msi_causes[i], indices=current_batch, axis=0)
        ]
        if i == 0:
            self.msi_predictions += [
                tf.nn.leaky_relu(
                    tf.matmul(self.msi_minibatch[i], self.msi_filters[i]))
            ]  # m1 prediction
            self.msi_predictions += [
                tf.nn.leaky_relu(
                    tf.matmul(self.msi_minibatch[i], self.msi_filters[i + 1]))
            ]  # m2 prediction
        else:
            self.msi_predictions += [
                tf.nn.leaky_relu(
                    tf.matmul(self.msi_minibatch[i], self.msi_filters[i + 1]))
            ]

    # Add ops for computing gradients for m1 causes and for updating weights.
    self.m1_bu_error = []
    self.m1_update_filter = []
    self.m1_cause_grad = []
    for i in range(len(self.m1_layers)):
        # Bottom-up error: prediction vs. input (layer 0) or lower cause.
        if i == 0:
            self.m1_bu_error += [
                tf.losses.mean_squared_error(
                    self.x_m1,
                    self.m1_predictions[i],
                    reduction=tf.losses.Reduction.NONE)
            ]
        else:
            self.m1_bu_error += [
                tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m1_minibatch[i - 1]),
                    self.m1_predictions[i],
                    reduction=tf.losses.Reduction.NONE)
            ]
        # Top-down prediction error: from the next layer if it exists,
        # otherwise from the msi prediction of this modality.
        if len(self.m1_layers) > (i + 1):
            td_error = tf.losses.mean_squared_error(
                tf.stop_gradient(self.m1_predictions[i + 1]),
                self.m1_minibatch[i],
                reduction=tf.losses.Reduction.NONE)
        else:
            td_error = tf.losses.mean_squared_error(
                tf.stop_gradient(self.msi_predictions[0]),
                self.m1_minibatch[i],
                reduction=tf.losses.Reduction.NONE)
        reg_error = reg_m1_causes[i] * (self.m1_minibatch[i]**2)
        self.m1_cause_grad += [
            tf.gradients([self.m1_bu_error[i], td_error, reg_error],
                         self.m1_minibatch[i])[0]
        ]
        # Ops for updating weights.
        reg_error = reg_m1_filters[i] * (self.m1_filters[i]**2)
        m1_filter_grad = tf.gradients([self.m1_bu_error[i], reg_error],
                                      self.m1_filters[i])[0]
        self.m1_update_filter += [
            tf.assign_sub(self.m1_filters[i],
                          lr_m1_filters[i] * m1_filter_grad)
        ]

    # Add ops for computing gradients for m2 causes and for updating weights.
    self.m2_bu_error = []
    self.m2_update_filter = []
    self.m2_cause_grad = []
    for i in range(len(self.m2_layers)):
        if i == 0:
            self.m2_bu_error += [
                tf.losses.mean_squared_error(
                    self.x_m2,
                    self.m2_predictions[i],
                    reduction=tf.losses.Reduction.NONE)
            ]
        else:
            self.m2_bu_error += [
                tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m2_minibatch[i - 1]),
                    self.m2_predictions[i],
                    reduction=tf.losses.Reduction.NONE)
            ]
        if len(self.m2_layers) > (i + 1):
            td_error = tf.losses.mean_squared_error(
                tf.stop_gradient(self.m2_predictions[i + 1]),
                self.m2_minibatch[i],
                reduction=tf.losses.Reduction.NONE)
        else:
            td_error = tf.losses.mean_squared_error(
                tf.stop_gradient(self.msi_predictions[1]),
                self.m2_minibatch[i],
                reduction=tf.losses.Reduction.NONE)
        reg_error = reg_m2_causes[i] * (self.m2_minibatch[i]**2)
        self.m2_cause_grad += [
            tf.gradients([self.m2_bu_error[i], td_error, reg_error],
                         self.m2_minibatch[i])[0]
        ]
        # Ops for updating weights.
        reg_error = reg_m2_filters[i] * (self.m2_filters[i]**2)
        m2_filter_grad = tf.gradients([self.m2_bu_error[i], reg_error],
                                      self.m2_filters[i])[0]
        # FIX: previously appended to self.m1_update_filter, which left
        # self.m2_update_filter permanently empty.
        self.m2_update_filter += [
            tf.assign_sub(self.m2_filters[i],
                          lr_m2_filters[i] * m2_filter_grad)
        ]

    # Add ops for computing gradients for msi causes.
    self.msi_bu_error = []
    self.msi_reg_error = []
    self.msi_update_filter = []
    self.msi_cause_grad = []
    for i in range(len(self.msi_layers)):
        if i == 0:
            # Two bottom-up errors: against the top m1 and m2 causes.
            self.msi_bu_error += [
                tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m1_minibatch[-1]),
                    self.msi_predictions[i],
                    reduction=tf.losses.Reduction.NONE)
            ]
            self.msi_bu_error += [
                tf.losses.mean_squared_error(
                    tf.stop_gradient(self.m2_minibatch[-1]),
                    self.msi_predictions[i + 1],
                    reduction=tf.losses.Reduction.NONE)
            ]
            self.msi_reg_error += [
                reg_msi_causes[i] * (self.msi_minibatch[i]**2)
            ]
            if len(self.msi_layers) > 1:
                raise NotImplementedError
            else:
                self.msi_cause_grad += [
                    tf.gradients([
                        self.msi_bu_error[i], self.msi_bu_error[i + 1],
                        self.msi_reg_error[i]
                    ], self.msi_minibatch[i])[0]
                ]
            # Ops for updating weights (one per msi filter of layer 0).
            reg_error = reg_msi_filters[i] * (self.msi_filters[i]**2)
            msi_filter_grad = tf.gradients(
                [self.msi_bu_error[i], reg_error], self.msi_filters[i])[0]
            self.msi_update_filter += [
                tf.assign_sub(self.msi_filters[i],
                              lr_msi_filters[i] * msi_filter_grad)
            ]
            reg_error = reg_msi_filters[i + 1] * (self.msi_filters[i + 1]**2)
            msi_filter_grad = tf.gradients(
                [self.msi_bu_error[i + 1], reg_error],
                self.msi_filters[i + 1])[0]
            self.msi_update_filter += [
                tf.assign_sub(self.msi_filters[i + 1],
                              lr_msi_filters[i + 1] * msi_filter_grad)
            ]
        else:
            raise NotImplementedError

    # Add ops for updating causes: gradient descent on the minibatch rows
    # only, after all cause gradients have been computed.
    self.m1_update_cause = []
    self.m2_update_cause = []
    self.msi_update_cause = []
    with tf.control_dependencies(self.m1_cause_grad + self.m2_cause_grad +
                                 self.msi_cause_grad):
        # m1 modality.
        for i in range(len(self.m1_layers)):
            self.m1_update_cause += [
                tf.scatter_sub(self.m1_causes[i],
                               indices=current_batch,
                               updates=(lr_m1_causes[i] *
                                        self.m1_cause_grad[i]))
            ]
        # m2 modality.
        for i in range(len(self.m2_layers)):
            self.m2_update_cause += [
                tf.scatter_sub(self.m2_causes[i],
                               indices=current_batch,
                               updates=(lr_m2_causes[i] *
                                        self.m2_cause_grad[i]))
            ]
        # msi modality.
        for i in range(len(self.msi_layers)):
            self.msi_update_cause += [
                tf.scatter_sub(self.msi_causes[i],
                               indices=current_batch,
                               updates=(lr_msi_causes[i] *
                                        self.msi_cause_grad[i]))
            ]
def step_fn(self, params, model):
  """Builds one TPU training step for the teacher/student (MPL-style) setup.

  Dequeues a tuple of (labeled images, labeled labels, original unlabeled
  images, augmented unlabeled images) from infeed, runs the teacher to get
  UDA losses and pseudo-labels, trains the student on the pseudo-labels, and
  then trains the teacher using the change in the student's labeled-data loss
  as a feedback signal.

  Args:
    params: hyperparameter object; fields used include train_batch_size,
      num_replicas, uda_data, image_size, num_classes, use_bfloat16,
      use_xla_sharding, num_cores_per_replica, label_smoothing, grad_bound,
      uda_weight, uda_steps, and the mpl_* learning-rate settings.
    model: callable network; called as model(images, training=...).

  Returns:
    An outfeed-enqueue op (or tf.no_op) that, via control dependencies,
    triggers the student update, its EMA update, and the teacher update.
  """
  train_batch_size = params.train_batch_size
  num_replicas = params.num_replicas
  uda_data = params.uda_data
  # Per-replica batch size for the labeled stream; unlabeled streams are
  # `uda_data` times larger (see `shapes` below).
  batch_size = train_batch_size // num_replicas

  # Infeed tuple layout: labeled images, labeled labels (always float32),
  # unlabeled original images, unlabeled augmented images.
  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32]
  shapes = [
      [batch_size, params.image_size, params.image_size, 3],
      [batch_size, params.num_classes],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    # Spatial partitioning: split each image tensor along its width axis
    # (dim 2) across the cores of one replica.
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=4,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1],
                              [1, 1, params.num_cores_per_replica, 1],
                              [1, 1, params.num_cores_per_replica, 1],],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
    l_images = xla_sharding.split(l_images, 2,
                                  params.num_cores_per_replica)
    u_images_ori = xla_sharding.split(u_images_ori, 2,
                                      params.num_cores_per_replica)
    u_images_aug = xla_sharding.split(u_images_aug, 2,
                                      params.num_cores_per_replica)
  else:
    # No sharding: dequeue the whole tuple on core 0.
    with tf.device(tf.tpu.core(0)):
      (l_images, l_labels, u_images_ori,
       u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)
  global_step = tf.train.get_or_create_global_step()
  # NOTE: rebinds `num_replicas` from int to a float tensor; used only for
  # normalizing logged values below.
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

  # All calls to the teacher: one forward pass over labeled + unlabeled data
  # producing logits, (pseudo-)labels, confidence masks, and UDA losses.
  with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
    logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
        params, model, all_images, l_labels)

  # 1st call to student: forward on augmented-unlabeled + labeled images in
  # one batch, then split the logits back apart.
  with tf.variable_scope(MODEL_SCOPE):
    u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
    logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
    logits['s_on_u'], logits['s_on_l_old'] = tf.split(
        logits['s_on_u_aug_and_l'],
        [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

  # For backprop: student learns from the teacher's (stop-gradient) soft
  # pseudo-labels on the augmented unlabeled batch.
  cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
      onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
      logits=logits['s_on_u'],
      label_smoothing=params.label_smoothing,
      reduction=tf.losses.Reduction.NONE)
  cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
      train_batch_size*uda_data)

  # For the Taylor-style feedback signal: the student's labeled-data loss
  # *before* its update, snapshotted into the `shadow` variable.
  cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
      onehot_labels=labels['l'],
      logits=logits['s_on_l_old'],
      reduction=tf.losses.Reduction.SUM)
  cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
      cross_entropy['s_on_l_old']) / float(train_batch_size)
  shadow = tf.get_variable(
      name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
  shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

  w_s = {}      # weights, keyed 's' (student) / 't' (teacher)
  g_s = {}      # gradients
  g_n = {}      # gradient norms
  lr = {}       # learning rates
  optim = {}    # optimizers
  w_s['s'] = [w for w in tf.trainable_variables()
              if w.name.lower().startswith(MODEL_SCOPE)]
  g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
  # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

  lr['s'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_student_lr,
      num_warmup_steps=params.mpl_student_lr_warmup_steps,
      num_wait_steps=params.mpl_student_lr_wait_steps)
  lr['s'], optim['s'] = common_utils.get_optimizer(
      params, learning_rate=lr['s'])
  optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
  # Only batch-norm style update ops that belong to the student's scope.
  update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                if op.name.startswith(f'train/{MODEL_SCOPE}/')]

  # `shadow_update` must run before the student update so `shadow` holds the
  # pre-update labeled loss when `dot_product` is computed below.
  with tf.control_dependencies(update_ops + [shadow_update]):
    g_s['s_on_u'] = common_utils.add_weight_decay(
        params, w_s['s'], g_s['s_on_u'])
    g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
        g_s['s_on_u'], params.grad_bound)
    train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

    with tf.control_dependencies([train_op]):
      ema_train_op = common_utils.setup_ema(
          params, name_scope=f'{MODEL_SCOPE}/{model.name}')

  # 2nd call to student: labeled loss *after* the update; the improvement
  # (new - old) is the teacher's feedback signal.
  with tf.control_dependencies([ema_train_op]):
    with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
      logits['s_on_l_new'] = model(l_images, training=True)

  cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
      onehot_labels=labels['l'],
      logits=logits['s_on_l_new'],
      reduction=tf.losses.Reduction.SUM)
  cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
      cross_entropy['s_on_l_new']) / float(train_batch_size)

  dot_product = cross_entropy['s_on_l_new'] - shadow
  # dot_product = tf.clip_by_value(
  #     dot_product,
  #     clip_value_min=-params.mpl_dot_product_bound,
  #     clip_value_max=params.mpl_dot_product_bound)
  # Center the feedback signal with a moving average (decay 0.99) to reduce
  # variance; the centered value is treated as a constant for the teacher.
  moving_dot_product = tf.get_variable(
      'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
  moving_dot_product_update = tf.assign_sub(
      moving_dot_product, 0.01 * (moving_dot_product - dot_product))
  with tf.control_dependencies([moving_dot_product_update]):
    dot_product = dot_product - moving_dot_product
    dot_product = tf.stop_gradient(dot_product)

  # Teacher's pseudo-labeling loss, weighted below by the feedback signal.
  cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
      onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
      logits=logits['u_aug'],
      reduction=tf.losses.Reduction.NONE)
  cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
      train_batch_size*uda_data)

  # Teacher train op: UDA consistency loss (linearly ramped over
  # `uda_steps`), supervised loss, and the feedback-weighted MPL loss.
  uda_weight = params.uda_weight * tf.minimum(
      1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
  teacher_loss = (cross_entropy['u'] * uda_weight +
                  cross_entropy['l'] +
                  cross_entropy['mpl'] * dot_product)
  w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
  g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
  g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
  g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)

  lr['t'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_teacher_lr,
      num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
  lr['t'], optim['t'] = common_utils.get_optimizer(params,
                                                   learning_rate=lr['t'])
  # Only the teacher op advances `global_step`.
  teacher_train_op = optim['t'].apply_gradients(zip(g_s['t'], w_s['t']),
                                                global_step=global_step)

  with tf.control_dependencies([teacher_train_op]):
    # Scalar logs; per-replica values are divided by num_replicas because
    # the host sums outfeed tensors across replicas (presumably — confirm
    # against the outfeed-consuming host loop).
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)
    logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
    logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                          num_replicas)
    logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
    logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
    logs['lr/student'] = tf.identity(lr['s']) / num_replicas
    logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
    logs['mpl/dot_product'] = dot_product / num_replicas
    logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
    logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
    logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
    logs['uda/weight'] = uda_weight / num_replicas

    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    # Record the outfeed signature so the host side can dequeue.
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    def outfeed(tensors):
      with tf.device(tf.tpu.core(params.num_cores_per_replica-1)):
        return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

    # Only enqueue on logging steps to save outfeed bandwidth.
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)
    return outfeed_enqueue_op
def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
    """Construct training op to update the registered variables based on their gradients.

    Builds, per device: gradient cleanup/scaling, optional cross-device
    summation, optional gradient accumulation over `minibatch_multiplier`
    minibatches, an overflow-guarded optimizer step, and optional dynamic
    loss-scaling adjustment. Can only be called once per optimizer instance.

    Args:
        allow_no_op: If True and no gradients were registered, return a no-op
            instead of raising.

    Returns:
        A single tf.Operation that performs one training step when run.
    """
    tfutil.assert_tf_initialized()
    assert not self._updates_applied  # one-shot: graph is built exactly once
    self._updates_applied = True
    all_ops = []

    # Check for no-op.
    if allow_no_op and len(self._devices) == 0:
        with tfutil.absolute_name_scope(self.scope):
            return tf.no_op(name='TrainingOp')

    # Clean up gradients.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
            for var, grad in device.grad_raw.items():

                # Filter out disconnected gradients and convert to float32.
                grad = [g for g in grad if g is not None]
                grad = [tf.cast(g, tf.float32) for g in grad]

                # Sum within the device.
                if len(grad) == 0:
                    grad = tf.zeros(var.shape)  # No gradients => zero.
                elif len(grad) == 1:
                    grad = grad[0]  # Single gradient => use as is.
                else:
                    grad = tf.add_n(grad)  # Multiple gradients => sum.

                # Scale as needed: average over registered gradient terms and
                # devices, undo gradient accumulation and loss scaling.
                scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                scale = tf.constant(scale, dtype=tf.float32, name="scale")
                if self.minibatch_multiplier is not None:
                    scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                scale = self.undo_loss_scaling(scale)
                device.grad_clean[var] = grad * scale

    # Sum gradients across devices.
    if len(self._devices) > 1:
        with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
            if platform.system() == "Windows":    # Windows => NCCL ops are not available.
                self._broadcast_fallback()
            elif tf.VERSION.startswith("1.15."):  # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                self._broadcast_fallback()
            else:                                 # Otherwise => NCCL ops are safe to use.
                self._broadcast_nccl()

    # Apply updates separately on each device.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
            # pylint: disable=cell-var-from-loop

            # Accumulate gradients over time. `acc_ok` is True on the step
            # where the accumulated gradients should actually be applied.
            if self.minibatch_multiplier is None:
                acc_ok = tf.constant(True, name='acc_ok')
                device.grad_acc = OrderedDict(device.grad_clean)
            else:
                # Create variables (outside any control-dependency context).
                with tf.control_dependencies(None):
                    for var in device.grad_clean.keys():
                        device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                    device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                # Track counter: reset when the multiplier is reached,
                # otherwise increment.
                count_cur = device.grad_acc_count + 1.0
                count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                # Track gradients: `grad_acc` holds the running sum including
                # this step's contribution; the backing variable is reset on
                # apply steps and updated otherwise.
                for var, grad in device.grad_clean.items():
                    acc_var = device.grad_acc_vars[var]
                    acc_cur = acc_var + grad
                    device.grad_acc[var] = acc_cur
                    with tf.control_dependencies([acc_cur]):
                        acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                        acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                        all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

            # No overflow => apply gradients. Skips the optimizer step when
            # any accumulated gradient contains NaN/Inf.
            all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
            apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
            all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

            # Adjust loss scaling: grow on clean steps, shrink on overflow.
            if self.use_loss_scaling:
                ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

            # Last device => report statistics.
            if device_idx == len(self._devices) - 1:
                all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                if self.use_loss_scaling:
                    all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

    # Initialize variables.
    self.reset_optimizer_state()
    if self.use_loss_scaling:
        tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
    if self.minibatch_multiplier is not None:
        tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

    # Group everything into a single op.
    with tfutil.absolute_name_scope(self.scope):
        return tf.group(*all_ops, name="TrainingOp")
def build_trainer(self, child_model): # actor child_model.build_valid_rl() self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) / tf.to_float(child_model.batch_size)) self.reward = self.valid_acc if self.use_critic: # critic all_h = tf.concat(self.all_h, axis=0) value_function = tf.matmul(all_h, self.w_critic) advantage = value_function - self.reward critic_loss = tf.reduce_sum(advantage**2) self.baseline = tf.reduce_mean(value_function) self.loss = -tf.reduce_mean(self.sample_log_probs * advantage) critic_train_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="critic_train_step") critic_train_op, _, _, _ = get_train_ops(critic_loss, [self.w_critic], critic_train_step, clip_mode=None, lr_init=1e-3, lr_dec_start=0, lr_dec_every=int(1e9), optim_algo="adam", sync_replicas=False) else: # or baseline self.sample_log_probs = tf.reduce_sum(self.sample_log_probs) self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False) baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) * (self.baseline - self.reward)) with tf.control_dependencies([baseline_update]): self.reward = tf.identity(self.reward) self.loss = self.sample_log_probs * (self.reward - self.baseline) self.train_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="train_step") tf_variables = [ var for var in tf.trainable_variables() if var.name.startswith(self.name) and "w_critic" not in var.name ] print "-" * 80 for var in tf_variables: print var self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops( self.loss, tf_variables, self.train_step, clip_mode=self.clip_mode, grad_bound=self.grad_bound, l2_reg=self.l2_reg, lr_init=self.lr_init, lr_dec_start=self.lr_dec_start, lr_dec_every=self.lr_dec_every, lr_dec_rate=self.lr_dec_rate, optim_algo=self.optim_algo, sync_replicas=self.sync_replicas, num_aggregate=self.num_aggregate, num_replicas=self.num_replicas) if self.use_critic: self.train_op = tf.group(self.train_op, critic_train_op)
def step_fn(self, params, model):
  """A single step for supervised learning.

  Trains the main model on per-example losses weighted by a learned scorer,
  then trains the scorer to maximize agreement between train-batch and
  valid-batch gradients (plus an entropy term on the score distribution).

  Args:
    params: hyperparameter object; fields used include train_dtypes,
      train_shapes, num_classes, num_replicas, label_smoothing,
      weight_decay, grad_bound, scorer_lr, and scorer_wait_steps.
    model: callable network; called as model(images, training=...,
      return_scores=...).

  Returns:
    An outfeed-enqueue op (or tf.no_op) that, via control dependencies,
    triggers the model update, its EMA update, and the scorer update.
  """
  (train_images, train_labels, valid_images,
   valid_labels) = tf.raw_ops.InfeedDequeueTuple(
       dtypes=params.train_dtypes, shapes=params.train_shapes)

  # Sparse int labels from infeed are converted to one-hot.
  if train_labels.dtype == tf.int32:
    train_labels = tf.one_hot(train_labels, depth=params.num_classes,
                              dtype=tf.float32)
  if valid_labels.dtype == tf.int32:
    valid_labels = tf.one_hot(valid_labels, depth=params.num_classes,
                              dtype=tf.float32)
  global_step = tf.train.get_or_create_global_step()
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  with tf.variable_scope(MODEL_SCOPE):
    train_logits = model(train_images, training=True)

  with tf.variable_scope(SCORE_SCOPE):
    # Numerically-stable cross-replica softmax over per-example scores:
    # subtract a (stop-gradient) global mean before exponentiating, then
    # normalize by the global sum.
    score_logits = model(train_images, training=False, return_scores=True)
    score_m = tf.tpu.cross_replica_sum(tf.reduce_sum(score_logits))
    score_m = tf.stop_gradient(score_m) / float(params.num_replicas)
    score_e = tf.exp(score_logits - score_m)
    score_z = tf.tpu.cross_replica_sum(tf.reduce_sum(score_e))
    score_probs = score_e / score_z

  # train the main model: per-example cross-entropy weighted by the
  # (stop-gradient) score distribution.
  cross_entropy = tf.losses.softmax_cross_entropy(
      onehot_labels=train_labels,
      logits=train_logits,
      label_smoothing=params.label_smoothing,
      reduction=tf.losses.Reduction.NONE)
  cross_entropy = tf.reduce_sum(cross_entropy *
                                tf.stop_gradient(score_probs))

  l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas,
                        tf.float32)
  # Scorer variables are excluded from weight decay.
  weight_dec = common_utils.get_l2_loss(excluded_keywords=[SCORE_SCOPE])
  total_loss = cross_entropy + weight_dec * l2_reg_rate
  model_variables = [
      v for v in tf.trainable_variables() if MODEL_SCOPE in v.name
  ]

  train_gradients = tf.gradients(total_loss, model_variables)
  train_gradients = [
      tf.tpu.cross_replica_sum(g) for g in train_gradients
  ]
  train_gradients, grad_norm = tf.clip_by_global_norm(
      train_gradients, params.grad_bound)

  learning_rate, optimizer = common_utils.get_optimizer(params)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Skip the update entirely when the gradient norm is NaN/Inf.
  train_op = tf.cond(
      tf.math.is_finite(grad_norm),
      lambda: optimizer.apply_gradients(
          zip(train_gradients, model_variables), global_step=global_step),
      tf.no_op)

  with tf.control_dependencies(update_ops + [train_op]):
    ema_train_op = common_utils.setup_ema(
        params, f'{MODEL_SCOPE}/{model.name}')

  with tf.control_dependencies([ema_train_op]):
    # Validation loss on the *updated* model (reuse=True, no training ops).
    with tf.variable_scope(MODEL_SCOPE, reuse=True):
      valid_logits = model(valid_images, training=False)
      valid_cross_entropy = tf.losses.softmax_cross_entropy(
          onehot_labels=valid_labels,
          logits=valid_logits,
          reduction=tf.losses.Reduction.MEAN) / float(
              params.num_replicas)
      valid_gradients = tf.gradients(valid_cross_entropy,
                                     model_variables)
      valid_gradients = [
          tf.tpu.cross_replica_sum(g) for g in valid_gradients
      ]

    # Alignment between train and valid gradients: the scorer's reward.
    # Centered by a moving average (decay 0.99) before use.
    dot_product = tf.add_n([
        tf.reduce_sum(g_t * g_v)
        for g_t, g_v in zip(train_gradients, valid_gradients)
    ])
    dot_product = tf.stop_gradient(dot_product)
    dot_product_avg = tf.get_variable(name='dot_product_avg',
                                      shape=[],
                                      trainable=False)
    dot_product_update = tf.assign_sub(
        dot_product_avg, 0.01 * (dot_product_avg - dot_product))
    with tf.control_dependencies([dot_product_update]):
      dot_product = tf.identity(dot_product - dot_product_avg)

  # trains the scorer.
  score_entropy = tf.reduce_sum(-score_probs * tf.math.log(score_probs))
  score_entropy = tf.tpu.cross_replica_sum(score_entropy) / float(
      valid_images.shape[0].value)
  score_variables = [
      v for v in tf.trainable_variables() if SCORE_SCOPE in v.name
  ]
  score_gradients = tf.gradients(dot_product * score_entropy,
                                 score_variables)
  score_gradients = [
      tf.tpu.cross_replica_sum(g) for g in score_gradients
  ]
  score_optimizer = tf.train.GradientDescentOptimizer(
      learning_rate=params.scorer_lr, use_locking=True)
  # The scorer is frozen for the first `scorer_wait_steps` steps.
  score_train_op = tf.cond(
      global_step < params.scorer_wait_steps,
      tf.no_op,
      lambda: score_optimizer.apply_gradients(
          zip(score_gradients, score_variables)))

  with tf.control_dependencies([score_train_op]):
    # Scalar logs; per-replica values divided by num_replicas (presumably
    # summed across replicas on the host side — confirm against the
    # outfeed consumer).
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)

    logs['model/total'] = total_loss
    logs['model/weight_decay'] = weight_dec / num_replicas
    logs['model/cross_entropy'] = cross_entropy
    logs['model/lr'] = tf.identity(learning_rate) / num_replicas
    logs['model/grad_norm'] = grad_norm / num_replicas

    logs['score/dot_product'] = dot_product / num_replicas
    logs['score/dot_product_avg'] = dot_product_avg / num_replicas
    logs['score/entropy'] = score_entropy
    logs['score/p_min'] = tf.reduce_min(score_probs) / num_replicas
    logs['score/p_max'] = tf.reduce_max(score_probs) / num_replicas

    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    # Record the outfeed signature so the host side can dequeue.
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    # Only enqueue on logging steps to save outfeed bandwidth.
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params),
        lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
    return outfeed_enqueue_op