Example #1
0
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the Mesh-TensorFlow graph with ``model_backbone``, searches for
    layout rules with ``layout_optimizer.LayoutOptimizer`` (NAIVE scheduler)
    over a ``memory_estimator.MemoryEstimator`` of that graph, lowers the graph
    onto ``args_opt.num_gpus`` GPU devices, and returns an ``EstimatorSpec``
    for the requested mode.

    Args:
        features: input features, passed straight to ``model_backbone``.
        labels: labels, passed to ``model_backbone`` and used for the accuracy
            metric in TRAIN and EVAL mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: unused; kept to satisfy the Estimator ``model_fn`` signature.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, EVAL or PREDICT mode.

    Raises:
        ValueError: if ``mode`` is not TRAIN, EVAL or PREDICT.

    NOTE(review): PREDICT mode assumes ``model_backbone`` accepts
    ``labels=None`` and still returns a ``loss`` usable by the layout search.
    TODO confirm.
    """
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = model_backbone(features, labels, mesh)

    for v in graph.all_variables:
        logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
            v.name, v.shape, v.dtype.master_dtype))

    # Layout search: estimate the graph's cost on this mesh shape and let the
    # layout optimizer pick the layout rules. The mesh shape only needs to be
    # converted once.
    mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
    estimator = memory_estimator.MemoryEstimator(
        graph, mesh_shape, [logits, loss])
    layout_opt = layout_optimizer.LayoutOptimizer(
        estimator, scheduler_alg="NAIVE")
    layout_rules = mtf.convert_to_layout_rules(layout_opt.solve())
    logger.info("[auto mtf search] strategy: {}".format(layout_rules))

    mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Build gradients and update ops in the mtf graph *before* creating the
        # Lowering, so they are lowered too (looked up via lowered_operation).
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        sgd = mtf.optimize.SgdOptimizer(0.01)
        update_ops = sgd.apply_grads(var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors so LoggingTensorHook can find them by name.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        logging_hook = tf.train.LoggingTensorHook(
            every_n_iter=100,
            tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
        # restore_hook must stay ahead of any checkpoint-saving hook added here.
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, logging_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.EVAL, loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy": tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
            })

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Same class-axis convention (axis=1) as the accuracy metric above.
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.PREDICT, predictions=predictions,
            prediction_hooks=[restore_hook])

    raise ValueError("Unsupported mode: {}".format(mode))
Example #2
0
def layout(mtf_graph, mesh_shape, mtf_outputs=()):
  """Compute layout rules for a Mesh-TensorFlow graph on a given mesh shape.

  Wraps the graph in a MemoryEstimator for the (converted) mesh shape, runs a
  LayoutOptimizer with its default settings, and converts the solved layout
  string into layout rules.

  Args:
    mtf_graph: a mtf.Graph.
    mesh_shape: an mtf.Shape, str, or listlike of mtf.Dimension.
    mtf_outputs: an optional iterable of mtf.Tensor, representing the outputs
        of the computation.

  Returns:
    a mtf.LayoutRules
  """
  shape = mtf.convert_to_shape(mesh_shape)
  solver = layout_optimizer.LayoutOptimizer(
      memory_estimator.MemoryEstimator(mtf_graph, shape, mtf_outputs))
  solved_layout = solver.solve()
  return mtf.convert_to_layout_rules(solved_layout)
Example #3
0
def layout_and_mesh_shape(mtf_graph,
                          num_machines,
                          mtf_outputs=(),
                          max_mesh_shape_dimensions=2):
    """Compute layout rules and mesh shape based on computational graph.

    Brute-forces over every mesh shape produced by ``_mesh_shape_iterator`` and
    keeps the (layout, mesh_shape) pair whose solved layout has the lowest
    value according to ``LayoutOptimizer.evaluate_layout``. Ties keep the
    earliest candidate. The layout optimizer is more efficient when the mesh
    shape has fewer dimensions, so a smaller max_mesh_shape_dimensions makes
    this call faster.

    Args:
      mtf_graph: a mtf.Graph.
      num_machines: integer, a power of two, the number of machines available.
      mtf_outputs: an optional iterable of mtf.Tensor, representing the outputs
          of the computation.
      max_mesh_shape_dimensions: optional integer, the maximum number of
          dimensions to consider in any layout. For example, num_machines=1024
          and max_mesh_shape_dimensions=2 results in testing the mesh shapes
          "mesh_0:1024", "mesh_0:512;mesh_1:2", "mesh_0:256;mesh_1:4",
          "mesh_0:128;mesh_1:8", "mesh_0:64;mesh_1:16", and
          "mesh_0:32;mesh_1:32". If set to None, there is no maximum.

    Returns:
      a (mtf.LayoutRules, mtf.Shape) tuple; (None, None) if the iterator
      yields no candidate mesh shapes.
    """
    best_rules, best_shape, best_value = None, None, None
    candidates = _mesh_shape_iterator(num_machines, max_mesh_shape_dimensions)
    for dim_sizes in candidates:
        # Name the mesh dimensions mesh_0, mesh_1, ... in order.
        named_dims = [
            mtf.Dimension("mesh_{}".format(idx), sz)
            for idx, sz in enumerate(dim_sizes)
        ]
        candidate_shape = mtf.Shape(named_dims)
        tf.logging.info(
            "Computing layout for mesh shape: {}".format(candidate_shape))
        solver = layout_optimizer.LayoutOptimizer(
            memory_estimator.MemoryEstimator(
                mtf_graph, candidate_shape, mtf_outputs))
        candidate_layout = solver.solve()
        candidate_value = solver.evaluate_layout(candidate_layout)
        # Strict "<" so an equal-valued later shape never replaces the best.
        if best_value is not None and not candidate_value < best_value:
            continue
        best_value = candidate_value
        best_rules = mtf.convert_to_layout_rules(candidate_layout)
        best_shape = candidate_shape
    return best_rules, best_shape
Example #4
0
 def get_layout_optimizer(self):
     """Return a LayoutOptimizer over a MemoryEstimator of self.mtf_graph.

     The estimator is built from ``self.mtf_graph`` and ``self.mesh_shape``
     only; no mtf_outputs are passed.
     """
     estimator = memory_estimator.MemoryEstimator(
         self.mtf_graph, self.mesh_shape)
     return layout_optimizer.LayoutOptimizer(estimator)