def __init__(self,
             dims=10,
             seed=None,
             initial_dist=None,
             A_dist=None,
             A_noise_dist=None,
             B_dist=None,
             B_noise_dist=None,
             C_dist=None,
             C_noise_dist=None,
             grad_noise_dist=None,
             output_fn=None,
             weight_rescale=1.0):
  """Initializer for a sampled quadratic task.

  The loss for this task is described by:
    X = param * weight_rescale
    output_fn((AX - B)^2 + C)
  where param is initialized by:
    param = initial_dist.sample() / weight_rescale

  A, B, and C are sampled once at task creation, using either a random seed
  or the seed specified by `seed`. Each iteration, {A, B, C}_noise_dist is
  sampled and added to the fixed values. Gradients are computed using
  backprop, but have additional noise sampled from grad_noise_dist added to
  them.

  Args:
    dims: int
      Number of dims of the base problem.
    seed: optional int
      Seed passed into the sample function of the different distributions.
    initial_dist: tf.distributions.Distribution
      Distribution returning a tensor of size dims. This is used as the
      initial value for the task.
    A_dist: tf.distributions.Distribution
      Distribution over the quadratic term. This should return a
      dims x dims matrix when sampled. This is sampled once at task
      construction.
    A_noise_dist: tf.distributions.Distribution
      Distribution over noise added to the quadratic term.
    B_dist: tf.distributions.Distribution
      Distribution over the linear term. This should be of size dims. This
      is sampled once at task construction.
    B_noise_dist: tf.distributions.Distribution
      Distribution over noise added to the linear term.
    C_dist: tf.distributions.Distribution
      Distribution over the scalar term. This should be a scalar. This is
      sampled once at task construction.
    C_noise_dist: tf.distributions.Distribution
      Distribution over noise added to the scalar term.
    grad_noise_dist: tf.distributions.Distribution
      Distribution over noise added to the gradient.
    output_fn: Callable[[tf.Tensor], tf.Tensor]
      Callable applied just before returning the loss.
    weight_rescale: float
      Weight rescaling to change step size dynamics.
  """
  super(QuadraticBasedTask, self).__init__()

  # Default all noise distributions to a constant zero (i.e. no noise).
  if not A_noise_dist:
    A_noise_dist = ConstantDistribution(0.)
  if not B_noise_dist:
    B_noise_dist = ConstantDistribution(0.)
  if not C_noise_dist:
    C_noise_dist = ConstantDistribution(0.)
  if not grad_noise_dist:
    grad_noise_dist = ConstantDistribution(0.)

  self.A_noise_dist = A_noise_dist
  self.B_noise_dist = B_noise_dist
  self.C_noise_dist = C_noise_dist
  self.grad_noise_dist = grad_noise_dist
  self.output_fn = output_fn
  self.seed = seed
  self.dims = dims
  self.A_dist = A_dist
  self.weight_rescale = weight_rescale

  with self._enter_variable_scope():
    self.initial_dist = initial_dist
    # Offset the seed per distribution so the samples are independent.
    init = initial_dist.sample(seed=seed + 1 if seed else None)
    self.weight = tf.get_variable(
        name="weight", initializer=init, trainable=True)

    A = A_dist.sample(seed=seed + 2 if seed else None)
    self.A = tf.get_variable("A", initializer=A, trainable=False)

    B = B_dist.sample(seed=seed + 3 if seed else None)
    self.B = tf.get_variable("B", initializer=B, trainable=False)

    C = C_dist.sample(seed=seed + 4 if seed else None)
    self.C = tf.get_variable("C", initializer=C, trainable=False)
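
# A minimal construction sketch for the quadratic task above. This is a
# hypothetical usage example, not part of the original source: it assumes the
# TF1-style `tf.distributions` API and that this __init__ belongs to a
# `QuadraticBasedTask` class as named in the super() call.
import tensorflow as tf

dims = 10
task = QuadraticBasedTask(
    dims=dims,
    seed=123,
    initial_dist=tf.distributions.Normal(tf.zeros([dims]), tf.ones([dims])),
    A_dist=tf.distributions.Normal(
        tf.zeros([dims, dims]), tf.ones([dims, dims])),
    B_dist=tf.distributions.Normal(tf.zeros([dims]), tf.ones([dims])),
    C_dist=tf.distributions.Normal(0., 1.),
    # output_fn is applied just before the loss is returned; reduce_mean
    # collapsing the per-dim quadratic to a scalar is an assumption here.
    output_fn=tf.reduce_mean)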
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """See base class."""
  assignments = []
  background_lr = distill_util.get_background_lr(
      global_step=global_step, steps_per_phase=self.steps_per_phase)
  for (grad, param) in grads_and_vars:
    if grad is None or param is None:
      continue

    param_name = self._get_variable_name(param.name)

    m = tf.get_variable(
        name=param_name + "/adam_m",
        shape=param.shape.as_list(),
        dtype=tf.float32,
        trainable=False,
        initializer=tf.zeros_initializer())
    v = tf.get_variable(
        name=param_name + "/adam_v",
        shape=param.shape.as_list(),
        dtype=tf.float32,
        trainable=False,
        initializer=tf.zeros_initializer())

    if self.use_layer_wise_warmup:
      # Use model-specific name scopes to get the layer id.
      if param_name.startswith("bert/encoder/layer_"):
        layer_id = int(
            param_name[len("bert/encoder/layer_"):].split("/", 1)[0])
        layer_wise_lr = distill_util.layer_wise_learning_rate(
            layer_id=layer_id,
            steps_per_phase=self.steps_per_phase,
            background_lr=background_lr)
        layer_wise_gate = tf.where(
            tf.math.greater(layer_wise_lr, 0.0), 1.0, 0.0)
      else:
        layer_wise_lr = 0.0
        layer_wise_gate = 0.0
    else:
      layer_wise_lr = 1.0
      layer_wise_gate = 1.0

    # Standard Adam update.
    next_m = layer_wise_gate * (
        tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
    next_v = layer_wise_gate * (
        tf.multiply(self.beta_2, v) +
        tf.multiply(1.0 - self.beta_2, tf.square(grad)))
    update = next_m / (tf.sqrt(next_v) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square of
    # the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
      update += layer_wise_gate * self.weight_decay_rate * param

    ratio = 1.0
    if self._do_layer_adaptation(param_name):
      w_norm = tf.linalg.norm(param, ord=2)
      g_norm = tf.linalg.norm(update, ord=2)
      ratio = tf.where(
          tf.math.greater(w_norm, 0),
          tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

    update_with_lr = layer_wise_lr * ratio * self.learning_rate * update
    next_param = param - update_with_lr

    assignments.extend(
        [param.assign(next_param),
         m.assign(next_m),
         v.assign(next_v)])
  return tf.group(*assignments, name=name)
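
# A NumPy sketch (illustrative only, not part of the optimizer class above) of
# a single update mirroring the math in apply_gradients: an Adam step without
# bias correction, decoupled weight decay applied to the update rather than
# the loss, and LAMB-style layer adaptation via the weight/update norm ratio.
import numpy as np

def adam_wd_lamb_step(param, grad, m, v, learning_rate=1e-3, beta_1=0.9,
                      beta_2=0.999, epsilon=1e-6, weight_decay_rate=0.01):
  # First and second moment estimates (no bias correction, as above).
  m = beta_1 * m + (1.0 - beta_1) * grad
  v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
  update = m / (np.sqrt(v) + epsilon)
  # Decoupled weight decay: added to the update so it does not flow through
  # the m/v statistics.
  update += weight_decay_rate * param
  # Layer adaptation: rescale by ||w|| / ||update||, guarding zero norms.
  w_norm = np.linalg.norm(param.ravel())
  g_norm = np.linalg.norm(update.ravel())
  ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
  return param - learning_rate * ratio * update, m, v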
def main(_):
  assert FLAGS.model_dirs_or_checkpoints

  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  if (FLAGS.operation == "average_last_n" and
      len(FLAGS.model_dirs_or_checkpoints) > 1):
    raise ValueError("Need only 1 directory for %s operation" %
                     FLAGS.operation)

  checkpoints = []
  for path in FLAGS.model_dirs_or_checkpoints:
    if tf.gfile.IsDirectory(path):
      # Grab the latest checkpoint for all the provided model dirs.
      checkpoint_state = tf.train.get_checkpoint_state(path)
      if FLAGS.operation == "average_last_n":
        ckpt_paths = tf.io.gfile.glob(os.path.join(path, "model.ckpt*index"))

        def sort_fn(ckpt):
          return int(re.sub(".*ckpt-", "", ckpt))

        ckpts = sorted([c.replace(".index", "") for c in ckpt_paths],
                       key=sort_fn)
        checkpoints.extend(ckpts[-FLAGS.number_of_checkpoints:])
      else:
        checkpoints.append(checkpoint_state.all_model_checkpoint_paths[-1])
    else:
      if FLAGS.operation == "average_last_n":
        raise ValueError("need a directory while running %s operation" %
                         FLAGS.operation)
      checkpoints.append(path)

  logging.info("Using checkpoints %s", checkpoints)

  if FLAGS.operation in ["ensemble", "average", "average_last_n"]:
    if len(checkpoints) == 1:
      raise ValueError("no point in ensembling/averaging one checkpoint")
  else:
    if len(checkpoints) != 1:
      raise ValueError(
          "operation %s requires exactly one checkpoint" % FLAGS.operation)

  var_values = {}
  var_dtypes = {}

  for i in range(0, len(checkpoints)):
    checkpoint = checkpoints[i]
    logging.info("loading checkpoint %s", checkpoint)
    reader = tf.train.load_checkpoint(checkpoint)
    var_list = tf.train.list_variables(checkpoint)
    for (name, _) in var_list:
      if i:
        assert name in var_values
        tensor = reader.get_tensor(name)
        assert tensor.dtype == var_dtypes[name]
        var_values[name].append(tensor)
      else:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        var_values[name] = [tensor]
        if not FLAGS.global_step:
          if name == "global_step":
            FLAGS.global_step = tensor

    logging.info("Read from checkpoint %s", checkpoint)

  # Stack the list of tensors along the 0th dimension.
  for name, tensors in var_values.items():
    tensor = tensors[0]
    if name == "global_step":
      new_val = np.int32(FLAGS.global_step)
    elif FLAGS.operation == "ensemble":
      new_val = np.stack(tensors)
    elif FLAGS.operation == "autoensemble":
      new_val = np.stack([tensor] * FLAGS.autoensemble_size)
    elif FLAGS.operation == "average" or FLAGS.operation == "average_last_n":
      new_val = average_tensors(tensors)
    elif FLAGS.operation == "extract_first":
      new_val = tensor[0]
    else:
      raise ValueError("unknown FLAGS.operation=%s" % FLAGS.operation)
    var_values[name] = new_val

  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    tf_vars = [
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
        for v in var_values
    ]

  placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  saver = tf.train.Saver(tf.all_variables())

  output_file = "model.ckpt-" + str(FLAGS.global_step)
  output_path = os.path.join(FLAGS.output_dir, output_file)

  # Build a model consisting only of variables, set them to the average values.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                           six.iteritems(var_values)):
      sess.run(assign_op, {p: value})
    # Use the built saver to save the averaged checkpoint.
    saver.save(sess, output_path)

  logging.info("Transformed checkpoints saved in %s", output_path)
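
# `average_tensors` is called above but not defined in this excerpt; a minimal
# sketch of the elementwise mean it presumably computes is below. The
# command-line invocation is likewise hypothetical (the script name is an
# assumption; the flag names come from the FLAGS referenced above):
#
#   python average_checkpoints.py \
#     --model_dirs_or_checkpoints=/path/to/model_dir \
#     --output_dir=/path/to/output \
#     --operation=average_last_n \
#     --number_of_checkpoints=5
import numpy as np

def average_tensors(tensors):
  """Elementwise mean of a list of same-shaped numpy arrays."""
  # Accumulate in float64 for stability, then cast back to the input dtype.
  acc = np.zeros_like(tensors[0], dtype=np.float64)
  for t in tensors:
    acc += t
  return (acc / len(tensors)).astype(tensors[0].dtype)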