Code example #1
File: model_executor.py Project: bruinxiong/mesh-1
def train_and_eval():
  """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """

  mesh_context = None
  tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
  tf.config.experimental_connect_to_cluster(resolver)
  config = tf.ConfigProto()
  config.allow_soft_placement = True
  cluster_spec = resolver.cluster_spec()
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  with tf.Session(target=resolver.master(), config=config) as sess:
    tf.tpu.experimental.initialize_tpu_system(resolver)
    mesh_context = MeshContext(
        sess, FLAGS.use_tpu, FLAGS.mesh_shape, unet.get_layout())

  for _ in range(FLAGS.num_training_loops):
    _train_phase(mesh_context)
    _eval_phase(mesh_context)

  if FLAGS.use_tpu:
    _shutdown()

  tf.logging.info('finished.')
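All of the snippets on this page are excerpts from the same Mesh TensorFlow experimental U-Net executor and rely on module-level imports and flag definitions that the excerpts do not show. A minimal sketch of that assumed setup follows; the flag names are taken from the snippets themselves, while the absl-based definitions, default values, and the unet import path are assumptions:

# Minimal sketch of the module-level setup the snippets assume (not shown
# in the excerpts). Flag names come from the code above; defaults are guesses.
import numpy as np
import tensorflow as tf  # TF 1.x: the snippets use tf.contrib and tf.Session.
import mesh_tensorflow as mtf
from absl import flags
from mesh_tensorflow.experimental import unet  # assumed location of unet

FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Address of the TPU master, e.g. grpc://...')
flags.DEFINE_boolean('use_tpu', True, 'Whether to run on TPU.')
flags.DEFINE_string('mesh_shape', 'rows:4,columns:4', 'Shape of the mtf mesh.')
flags.DEFINE_integer('num_training_loops', 10, 'Number of train/eval cycles.')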
Code example #2
def train_and_eval():
    """Trains and evaluates MeshTensorflow model without TPUEstimator.

    TODO(lehou): Pack everything nicely as a set of APIs.
    """

    mesh_context = None
    tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))
    with tf.Session(target=FLAGS.master,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        mesh_context = MeshContext(sess, FLAGS.use_tpu, FLAGS.mesh_shape,
                                   unet.get_layout())

    for _ in range(FLAGS.num_training_loops):
        _train_phase(mesh_context)
        _eval_phase(mesh_context)

    if FLAGS.use_tpu:
        _shutdown()

    tf.logging.info('finished.')
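Unlike example #1, this variant opens a session directly against FLAGS.master and never initializes the TPU system. If the target is a freshly started TPU worker, that initialization has to happen elsewhere first; a minimal sketch of such a one-off step on TF 1.x (not part of the original snippet):

# Assumed one-off TPU initialization, run before train_and_eval().
with tf.Session(target=FLAGS.master) as sess:
    sess.run(tf.tpu.initialize_system())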
Code example #3
def train_and_eval():
    """Trains and evaluates MeshTensorflow model without TPUEstimator.

    TODO(lehou): Pack everything nicely as a set of APIs.
    """
    tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))

    # Open a session to get the list of CPU devices to hold master variables.
    with tf.Session(target=FLAGS.master,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        topology = sess.run(tpu.initialize_system())
        cpu_devices = _list_cpu_devices(sess)

    topo_object = tf.contrib.tpu.Topology(serialized=topology)
    num_cores = int(np.prod(topo_object.mesh_shape))
    num_hosts = int(topo_object.num_tasks)
    num_cores_per_host = int(num_cores // num_hosts)
    assert num_cores_per_host == int(topo_object.num_tpus_per_task)

    # Get a device_assignment object for mtf.
    d_assignment = device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = unet.get_layout()
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules, None,
                                                d_assignment)

    for _ in range(FLAGS.num_training_loops):
        _train_phase(mesh_impl, cpu_devices, d_assignment, num_hosts,
                     num_cores)
        _eval_phase(mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)

    _shutdown()

    tf.logging.info('finished.')
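Example #3 calls a `_list_cpu_devices` helper that the excerpt does not include. A plausible sketch, assuming it simply filters the session's device list down to the CPU device of each host (the helper name comes from the snippet; the body is an assumption):

def _list_cpu_devices(sess):
    """Returns the sorted CPU device names visible to the session."""
    return sorted(d.name for d in sess.list_devices()
                  if d.device_type == 'CPU')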
Code example #4
    def _model_fn(input_fea, input_lab):
        """Creates a model, add summary, modes (train or eval), and hooks."""
        def _add_summary(lowering, train_or_eval, tf_loss, scalars,
                         global_step):
            """Add all summaries."""
            for k in scalars.keys():
                if not isinstance(scalars[k], tf.Tensor):
                    scalars[k] = tf.cast(
                        lowering.export_to_tf_tensor(scalars[k]), tf.float32)

            def _host_loss_summary(global_step, tf_loss, **scalars):
                """Add summary.scalar in host side."""
                gs = tf.cast(global_step, tf.int64)
                sum_loss = tf.contrib.summary.scalar(
                    '{}_loss'.format(train_or_eval), tf_loss, step=gs)
                sum_ops = [sum_loss.op]
                for description, tf_metric in scalars.items():
                    sum_metric = tf.contrib.summary.scalar('{}_{}'.format(
                        train_or_eval, description),
                                                           tf_metric,
                                                           step=gs)
                    sum_ops.append(sum_metric)
                with tf.control_dependencies(sum_ops):
                    return tf.identity(tf_loss)

            # Cast the global step to tf.int32, since
            # outside_compilation does not support tf.int64.
            tf_loss = tpu.outside_compilation(_host_loss_summary,
                                              tf.cast(global_step, tf.int32),
                                              tf_loss, **scalars)

            return tf_loss

        global_step = tf.train.get_or_create_global_step()
        graph = mtf.Graph()

        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

        tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
            cpu_devices, devices_memory_usage))
        var_placer = mtf.utils.BalancedVariablePlacer(cpu_devices,
                                                      devices_memory_usage)

        mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = unet.get_layout()
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    None, d_assignment)

        with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
            preds, loss, scalars, bn_update_ops = (
                unet.unet_with_spatial_partition(mesh, train_or_eval,
                                                 input_fea, input_lab))

        if train_or_eval == 'train':
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

            lr = FLAGS.lr * tf.pow(
                FLAGS.lr_drop_rate,
                tf.floor(
                    tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
            scalars['learning_rate'] = lr

            optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

            # This is where the actual tf graph gets built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf_update_ops.extend(
                [lowering.lowered_operation(op) for op in bn_update_ops])
            tf_update_ops_group = tf.group(tf_update_ops)

        else:  # train_or_eval == 'eval':
            preds = [mtf.anonymize(pred) for pred in preds]

            # This is where the actual tf graph gets built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_preds = [
                tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
                for pred in preds
            ]

        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        if FLAGS.write_summary:
            tf_loss = _add_summary(lowering, train_or_eval, tf_loss, scalars,
                                   global_step)
        master_to_slice_hook = mtf.MtfRestoreHook(lowering)

        if train_or_eval == 'train':
            with mtf.utils.outside_all_rewrites():
                saver = tf.train.Saver(tf.global_variables(),
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                slice_to_master_hook = tf.train.CheckpointSaverHook(
                    FLAGS.checkpoint_dir,
                    save_steps=FLAGS.save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                captured_hooks.capture(
                    [master_to_slice_hook, slice_to_master_hook])
                return tf_update_ops_group

        else:  # train_or_eval == 'eval':
            tf_preds.extend([tf_loss, global_step])
            tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
            tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
            captured_hooks.capture([master_to_slice_hook, None])
            captured_output_dtypes_shapes.capture(
                [tf_preds_dtypes, tf_preds_shapes])
            return tpu_ops.outfeed_enqueue_tuple(tf_preds)
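The eval branch above enqueues `tf_preds` on the TPU outfeed and captures their dtypes and shapes so the host side can build matching dequeue ops. A minimal sketch of that host-side counterpart, assuming one dequeue op per core (the device-string format and loop structure are assumptions; `outfeed_dequeue_tuple` is the standard counterpart of `outfeed_enqueue_tuple`):

def _get_outfeed_dequeue_ops(tf_preds_dtypes, tf_preds_shapes,
                             num_hosts, num_cores_per_host):
    """Builds one outfeed dequeue op per TPU core, grouped by host."""
    dequeue_ops = []
    for host_id in range(num_hosts):
        with tf.device('/job:worker/task:{}/device:CPU:0'.format(host_id)):
            for ordinal in range(num_cores_per_host):
                dequeue_ops.append(
                    tpu_ops.outfeed_dequeue_tuple(
                        dtypes=tf_preds_dtypes,
                        shapes=tf_preds_shapes,
                        device_ordinal=ordinal))
    return dequeue_ops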