def model_fn(features, labels, mode, params):
    """Build the Mesh TensorFlow graph and return the training EstimatorSpec.

    The model is built by ``model_backbone`` on a single mesh, a layout is
    searched for with ``LayoutOptimizer`` (``scheduler_alg="NAIVE"``) for the
    mesh shape given by ``args_opt.mesh_shape``, and the graph is lowered onto
    ``args_opt.num_gpus`` GPU devices.

    Args:
        features: model inputs, forwarded unchanged to ``model_backbone``.
        labels: targets, forwarded to ``model_backbone`` and compared with
            the argmax of the logits for the training accuracy metric.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: not used by this function.

    Returns:
        A ``tf.estimator.EstimatorSpec`` when ``mode`` is
        ``tf.estimator.ModeKeys.TRAIN``.
        NOTE(review): no spec is built for any other mode, so the function
        falls through and returns ``None`` -- confirm whether EVAL/PREDICT
        support is expected here.
    """
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = model_backbone(features, labels, mesh)

    for v in graph._all_variables:
        logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
            v.name, v.shape, v.dtype.master_dtype))

    # Converted once; the result is reused for both the layout search and the
    # mesh implementation.
    mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)

    # Inlined equivalent of:
    # layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    # but with an explicit scheduler algorithm.
    estimator = memory_estimator.MemoryEstimator(
        graph, mesh_shape, [logits, loss])
    layout_search = layout_optimizer.LayoutOptimizer(
        estimator, scheduler_alg="NAIVE")
    layout_rules = mtf.convert_to_layout_rules(layout_search.solve())
    # layout_rules=[('batch', 'b1')]
    logger.info("[auto mtf search] strategy: {}".format(layout_rules))

    mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Presumably the update ops have to exist in the graph before the
        # Lowering below is constructed, hence this separate TRAIN block.
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        sgd = mtf.optimize.SgdOptimizer(0.01)
        # sgd = tf.train.experimental.enable_mixed_precision_graph_rewrite(sgd)
        update_ops = sgd.apply_grads(var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")
        logging_hook = tf.train.LoggingTensorHook(
            every_n_iter=100,
            tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
        # profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, logging_hook])
def layout(mtf_graph, mesh_shape, mtf_outputs=()):
    """Derive layout rules for a computational graph on a given mesh shape.

    Args:
        mtf_graph: a mtf.Graph.
        mesh_shape: an mtf.Shape, str, or listlike of mtf.Dimension.
        mtf_outputs: an optional iterable of mtf.Tensor, representing the
            outputs of the computation.

    Returns:
        a mtf.LayoutRules
    """
    # The mesh shape is normalised to an mtf.Shape before it reaches the
    # estimator.
    solver = layout_optimizer.LayoutOptimizer(
        memory_estimator.MemoryEstimator(
            mtf_graph, mtf.convert_to_shape(mesh_shape), mtf_outputs))
    solved_layout = solver.solve()
    return mtf.convert_to_layout_rules(solved_layout)
def layout_and_mesh_shape(mtf_graph, num_machines, mtf_outputs=(),
                          max_mesh_shape_dimensions=2):
    """Search mesh shapes and return the best (layout rules, mesh shape) pair.

    Every mesh shape produced by ``_mesh_shape_iterator`` is tried in turn: a
    layout is solved for it and scored with ``evaluate_layout``. The pair with
    the smallest score is kept; on equal scores the earlier candidate wins.
    The layout optimizer is more efficient when the mesh shape has fewer
    dimensions, so a smaller ``max_mesh_shape_dimensions`` makes this call
    faster.

    Args:
        mtf_graph: a mtf.Graph.
        num_machines: integer, a power of two, the number of machines
            available.
        mtf_outputs: an optional iterable of mtf.Tensor, representing the
            outputs of the computation.
        max_mesh_shape_dimensions: optional integer, the maximum number of
            dimensions to consider in any layout. For example,
            num_machines=1024 and max_mesh_shape_dimensions=2 results in
            testing the mesh shapes "mesh_0:1024", "mesh_0:512;mesh_1:2",
            "mesh_0:256;mesh_1:4", "mesh_0:128;mesh_1:8",
            "mesh_0:64;mesh_1:16", and "mesh_0:32;mesh_1:32". If set to None,
            there is no maximum.

    Returns:
        a (mtf.LayoutRules, mtf.Shape) tuple; ``(None, None)`` if no mesh
        shape was produced by the iterator.
    """
    winning_rules = None
    winning_shape = None
    lowest_value = None

    candidates = _mesh_shape_iterator(num_machines, max_mesh_shape_dimensions)
    for dim_sizes in candidates:
        # Mesh dimensions are named by position: mesh_0, mesh_1, ...
        dims = []
        for index, size in enumerate(dim_sizes):
            dims.append(mtf.Dimension("mesh_{}".format(index), size))
        candidate_shape = mtf.Shape(dims)
        tf.logging.info(
            "Computing layout for mesh shape: {}".format(candidate_shape))

        solver = layout_optimizer.LayoutOptimizer(
            memory_estimator.MemoryEstimator(
                mtf_graph, candidate_shape, mtf_outputs))
        candidate_layout = solver.solve()
        candidate_value = solver.evaluate_layout(candidate_layout)

        # Only a strictly smaller value replaces the current best.
        if lowest_value is not None and not candidate_value < lowest_value:
            continue

        lowest_value = candidate_value
        winning_rules = mtf.convert_to_layout_rules(candidate_layout)
        winning_shape = candidate_shape

    return (winning_rules, winning_shape)
def get_layout_optimizer(self):
    """Return a LayoutOptimizer for this object's graph and mesh shape.

    A MemoryEstimator is built from ``self.mtf_graph`` and
    ``self.mesh_shape`` and handed to the optimizer.
    """
    estimator = memory_estimator.MemoryEstimator(
        self.mtf_graph, self.mesh_shape)
    return layout_optimizer.LayoutOptimizer(estimator)