def get_placement_mesh(hparams): graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") mesh_shape = mtf.parse_mesh_shape(hparams.mesh_shape) mesh_size = mtf.list_product(mesh_shape) mesh_devices = [""] * mesh_size mesh_impl = placement_mesh_impl.PlacementMeshImpl( mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices) return mesh, mesh_impl
def dense(x, output_dim, reduced_dims=None, expert_dims=None, use_bias=True, activation=None, name=None): """Dense layer doing (kernel*x + bias) computation. Args: x: a mtf.Tensor of shape [..., reduced_dims]. output_dim: a mtf.Dimension reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If omitted, we reduce the last dimension. expert_dims: an optional list of mtf.Dimension which represent different experts. Different experts get different weights. use_bias: a boolean, whether to add bias. activation: an optional function from mtf.Tensor to mtf.Tensor name: a string. variable scope. Returns: a mtf.Tensor of shape [..., output_dim]. """ if expert_dims is None: expert_dims = [] if reduced_dims is None: reduced_dims = x.shape.dims[-1:] w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim]) output_shape = mtf.Shape( [d for d in x.shape.dims if d not in reduced_dims] + [output_dim]) with tf.variable_scope(name, default_name="dense"): stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5 w = mtf.get_variable( x.mesh, "kernel", w_shape, initializer=tf.random_normal_initializer(stddev=stddev), activation_dtype=x.dtype) y = mtf.einsum([x, w], output_shape) if use_bias: b = mtf.get_variable(x.mesh, "bias", mtf.Shape(expert_dims + [output_dim]), initializer=tf.zeros_initializer(), activation_dtype=x.dtype) y += b if activation is not None: y = activation(y) return y
def allreduce_ring(xs, devices, reduction_fn_string="SUM"): """Compute the reduction of all Tensors and put the result everywhere. Performance-optimized for a ring of devices. Args: xs: a list of n tf.Tensors devices: a list of strings reduction_fn_string: "SUM" or "MAX" Returns: a list of n Tensors Raises: ValueError: if devices is not a list of n strings """ n = len(xs) if len(devices) != n: raise ValueError("devices must be a list of length len(xs)") if n == 1: return xs shape = xs[0].shape.as_list() # tf.logging.info("allreduce_ring shape = %s" % shape) size = None if None in shape else mtf.list_product(shape) if size is None or size < 1024 or size % n != 0: return allreduce_ring_single_shard(xs, devices, reduction_fn_string) def _circular_shift(l, n): n %= len(l) return l[-n:] + l[:-n] def _flatten_and_split(x): return tf.split(tf.reshape(x, [size]), n) def _concat_and_reshape(xs): return tf.reshape(tf.concat(xs, 0), shape) # [device, shard] x_split = mtf.parallel(devices, _flatten_and_split, xs) x_split_t = mtf.transpose_list_of_lists(x_split) y_split_t = [] for shard in xrange(n): shard_xs = _circular_shift(x_split_t[shard], shard) shard_devices = _circular_shift(devices, shard) shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices, reduction_fn_string) y_split_t.append(_circular_shift(shard_ys, -shard)) y_split = mtf.transpose_list_of_lists(y_split_t) ys = mtf.parallel(devices, _concat_and_reshape, y_split) return ys
def model_fn(features, labels, mode, params): """The model_fn argument for creating an Estimator.""" tf.logging.info("features = %s labels = %s mode = %s params=%s" % (features, labels, mode, params)) global_step = tf.train.get_global_step() graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") logits, loss = mnist_model(features, labels, mesh) mesh_shape = mtf.parse_mesh_shape(FLAGS.mesh_shape) mesh_size = mtf.list_product(mesh_shape) mesh_devices = [""] * mesh_size mesh_impl = placement_mesh_impl.PlacementMeshImpl( mesh_shape, mtf.parse_layout(FLAGS.layout), mesh_devices) if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf_optimize.AdafactorOptimizer() update_ops = [] for grad, var in zip(var_grads, graph.trainable_variables): update_ops.extend(optimizer.apply_grad(grad, var)) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) restore_hook = mtf.MtfRestoreHook(lowering) tf_logits = lowering.outfeed(logits) if mode != tf.estimator.ModeKeys.PREDICT: tf_loss = lowering.outfeed(loss) tf.summary.scalar("loss", tf_loss) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) train_op = tf.group(tf_update_ops) saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(tf_logits, axis=1)) # Name tensors to be logged with LoggingTensorHook. tf.identity(tf_loss, "cross_entropy") tf.identity(accuracy[1], name="train_accuracy") # Save accuracy scalar to Tensorboard output. tf.summary.scalar("train_accuracy", accuracy[1]) # restore_hook must come before saver_hook return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook, saver_hook]) if mode == tf.estimator.ModeKeys.PREDICT: predictions = { "classes": tf.argmax(tf_logits, axis=1), "probabilities": tf.nn.softmax(tf_logits), } return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions, prediction_hooks=[restore_hook], export_outputs={ "classify": tf.estimator.export.PredictOutput(predictions) }) if mode == tf.estimator.ModeKeys.EVAL: return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.EVAL, loss=tf_loss, evaluation_hooks=[restore_hook], eval_metric_ops={ "accuracy": tf.metrics.accuracy(labels=labels, predictions=tf.argmax(tf_logits, axis=1)), })
def estimator_model_fn(cls, hparams, features, labels, mode, config=None, params=None, decode_hparams=None, use_tpu=False, xla_compile=False): hparams = copy.deepcopy(hparams) hparams.use_tpu = use_tpu # merge decode_hparams into hparams if present if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None: for k, v in six.iteritems(decode_hparams.values()): if hasattr(hparams, k) and getattr(hparams, k) != v: tf.logging.warning( "Overriding hparams.%s with %s from decode_hparams" % (k, v)) setattr(hparams, k, v) # Instantiate model data_parallelism = None if not use_tpu and config: data_parallelism = config.data_parallelism model = cls(hparams, mode, data_parallelism=data_parallelism, decode_hparams=decode_hparams) global_step = tf.train.get_global_step() graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") mesh_shape = mtf.parse_mesh_shape(hparams.mesh_shape) mesh_size = mtf.list_product(mesh_shape) if use_tpu: mesh_devices = [""] * mesh_size mesh_impl = simd_mesh_impl.SimdMeshImpl( mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices, params["context"].device_assignment) else: if len(data_parallelism.ps_devices) == 1: mesh_devices = [""] * mesh_size else: assert len(data_parallelism.ps_devices) == mesh_size mesh_devices = data_parallelism.ps_devices mesh_impl = placement_mesh_impl.PlacementMeshImpl( mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices) # PREDICT mode if mode == tf.estimator.ModeKeys.PREDICT: return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu) logits, loss = model.mtf_model_fn(features, mesh) if use_tpu and logits is not None: logits = mtf.anonymize(logits) # TRAIN mode if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) lr = learning_rate.learning_rate_schedule(hparams) mtf_lr = mtf.infeed(mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.TensorShape([])) optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr) update_ops = [] for grad, var in zip(var_grads, graph.trainable_variables): update_ops.extend(optimizer.apply_grad(grad, var)) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) tf_loss = lowering.outfeed(loss) tf_loss = tf.to_float(tf_loss) if logits and mode != tf.estimator.ModeKeys.TRAIN: tf_logits = lowering.outfeed(logits) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [ lowering.lowered_operation(op) for op in update_ops ] tf_update_ops.append(tf.assign_add(global_step, 1)) # tf.logging.info("tf_update_ops: {}".format(tf_update_ops)) train_op = tf.group(tf_update_ops) with mtf_utils.outside_all_rewrites(): # Copy master variables to slices. Must be called first. restore_hook = mtf.MtfRestoreHook(lowering) saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook( hparams.model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) # EVAL mode if mode == tf.estimator.ModeKeys.EVAL: tf_logits = lowering.outfeed(logits) return model.estimator_spec_eval(features, tf_logits, labels, tf_loss, restore_hook, use_tpu) if use_tpu: _remove_summaries() return tpu_estimator.TPUEstimatorSpec( mode=tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_hooks=[restore_hook, saver_hook]) else: return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode, params): """A model is called by TpuEstimator.""" del labels global_step = tf.train.get_global_step() graph = mtf.Graph() mesh = mtf.Mesh(graph, 'my_mesh') mesh_shape = mtf.parse_mesh_shape(FLAGS.mesh_shape) mesh_size = mtf.list_product(mesh_shape) mesh_devices = [''] * mesh_size mesh_impl = SimdMeshImpl(mesh_shape, mtf.parse_layout(FLAGS.layout), mesh_devices, params['context'].device_assignment) with mtf_utils.outside_all_rewrites(): logits, loss = toy_model(features, mesh) # TRAIN mode if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf_optimize.AdafactorOptimizer() update_ops = [] for grad, var in zip(var_grads, graph.trainable_variables): update_ops.extend(optimizer.apply_grad(grad, var)) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) tf_loss = lowering.outfeed(loss) tf_logits = lowering.outfeed(logits) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) tf.logging.info('tf_update_ops: {}'.format(tf_update_ops)) train_op = tf.group(tf_update_ops) with mtf_utils.outside_all_rewrites(): # Copy master variables to slices. Must be called first. restore_hook = mtf.MtfRestoreHook(lowering) if mode == tf.estimator.ModeKeys.TRAIN: saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook( FLAGS.model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) return tpu_estimator.TPUEstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_hooks=[restore_hook, saver_hook]) elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(tf_logits): mean_logitss = tf.metrics.mean(tf_logits) return {'mean_logitss': mean_logitss} eval_metrics = (metric_fn, [tf_logits]) return tpu_estimator.TPUEstimatorSpec( tf.estimator.ModeKeys.EVAL, evaluation_hooks=[restore_hook], loss=tf_loss, eval_metrics=eval_metrics)