コード例 #1
0
  def _run_train_phase():
    """Builds the training graph and runs one loop of training steps.

    Reads `mesh_context`, `master` and `config` from the enclosing scope.
    On TPU the model step is replicated over `mesh_context.num_cores` cores
    and fed through a SIMD infeed queue; otherwise the step is built with
    `PlacementMeshImplInputReader.gpu_placement`. Runs
    `FLAGS.num_train_iterations_per_loop` steps inside a
    `tf.train.MonitoredTrainingSession`, flushing summaries after each step
    when `FLAGS.write_summary` is set.
    """
    # Input pipeline description for the 'train' split.
    dataset_fn = unet.get_dataset_creator('train')
    input_shapes = unet.get_input_mtf_shapes('train')

    step_fn, phase_hooks, _ = _get_model_fn('train', mesh_context)

    if not FLAGS.use_tpu:
      reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, dataset_fn, input_shapes, is_eval_mode=False)
      computation = reader.gpu_placement(step_fn)
    else:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, dataset_fn, input_shapes,
          external_worker=(not FLAGS.on_gcp), is_eval_mode=False)
      computation = tpu.replicate(
          computation=step_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

    # Session hooks: the model's master<->slice hooks, checkpoint loading and
    # a step counter reporting every 10 steps.
    to_slice_hook, to_master_hook = phase_hooks.get()
    loader_hook = _CkptLoaderHook()
    counter_hook = tf.train.StepCounterHook(every_n_steps=10)
    hooks = [loader_hook, to_slice_hook, to_master_hook, counter_hook]

    # None means summaries are disabled for this run.
    flush_op = contrib_summary.flush() if FLAGS.write_summary else None

    with tf.train.MonitoredTrainingSession(
        master=master,
        scaffold=_get_scaffold(additional_initializers=[]),
        hooks=hooks,
        config=config) as sess:
      if flush_op is not None:
        contrib_summary.initialize(session=sess)

      # Start data flow: a background infeed thread on TPU, an iterator
      # initialization otherwise.
      if FLAGS.use_tpu:
        reader.start_infeed_thread(sess, FLAGS.num_train_iterations_per_loop)
      else:
        reader.initialize(sess)

      for step in range(FLAGS.num_train_iterations_per_loop):
        sess.run(computation)
        if flush_op is not None:
          sess.run(flush_op)
        tf.logging.info('train steps: {}'.format(step))
コード例 #2
0
    def _run_train_phase():
        """Builds the replicated TPU training graph and runs one training loop.

        Uses `mesh_impl`, `cpu_devices`, `d_assignment`, `num_hosts` and
        `num_cores` from the enclosing scope. Runs
        `FLAGS.num_train_iterations_per_loop` steps in a
        `tf.train.MonitoredTrainingSession`, flushing summaries after every
        step when `FLAGS.write_summary` is set.
        """
        # Input pipeline for the 'train' split, delivered via TPU infeed.
        dataset_fn = unet.get_dataset_creator('train')
        input_shapes = unet.get_input_mtf_shapes('train')
        reader = input_reader.SimdMeshImplInputReader(
            mesh_impl, dataset_fn, input_shapes, is_eval_mode=False)

        step_fn, phase_hooks, _ = _get_model_fn(
            'train', cpu_devices, d_assignment, num_hosts)
        replicated_step = tpu.replicate(
            computation=step_fn,
            inputs=[[]] * num_cores,
            infeed_queue=reader.infeed_queue,
            device_assignment=d_assignment)

        # Checkpoint loader first, then the model's master<->slice hooks.
        to_slice_hook, to_master_hook = phase_hooks.get()
        hooks = [_CkptLoaderHook(), to_slice_hook, to_master_hook]

        # None means summaries are disabled for this run.
        flush_op = (tf.contrib.summary.flush()
                    if FLAGS.write_summary else None)

        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                scaffold=_get_scaffold(additional_initializers=[]),
                hooks=hooks,
                config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            if flush_op is not None:
                tf.contrib.summary.initialize(session=sess)
            reader.start_infeed_thread(
                sess, FLAGS.num_train_iterations_per_loop)

            for step in range(FLAGS.num_train_iterations_per_loop):
                sess.run(replicated_step)
                if flush_op is not None:
                    sess.run(flush_op)
                tf.logging.info('train steps: {}'.format(step))
コード例 #3
0
ファイル: model_executor.py プロジェクト: bruinxiong/mesh-1
  def _run_eval_phase():
    """The real function that runs the evaluation phase.

    Builds the eval graph for the 'eval' split and runs
    `FLAGS.num_eval_iterations_per_loop` steps in a `tf.train.MonitoredSession`
    connected to `FLAGS.master`. Each step's results are passed to
    `unet.PostProcessor.record` along with `FLAGS.pred_output_dir`, and
    `PostProcessor.finish()` is called once after the loop.

    On TPU, the eval step is replicated over `mesh_context.num_cores` cores and
    its outputs are read back through outfeed dequeue ops; only the tuple from
    the first dequeue op (host 0, device ordinal 0) is recorded. Otherwise the
    step is built with `PlacementMeshImplInputReader.gpu_placement` and its
    `sess.run` result is recorded directly.

    Relies on `mesh_context` from the enclosing scope.
    """
    # Setup input pipeline.
    ds_creator = unet.get_dataset_creator('eval')
    mtf_shapes = unet.get_input_mtf_shapes('eval')

    # output_dtypes_shapes is only consumed on the TPU path, to build outfeed
    # dequeue ops that match the eval step's outputs.
    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    if FLAGS.use_tpu:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
      eval_computation = tpu.replicate(
          computation=model_eval_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

      output_dtypes, output_shapes = output_dtypes_shapes.get()
      outfeed_dequeue_ops = []

      # Create outfeed_dequeue_ops.
      # One dequeue op per (host, core) pair, each placed on its host's device.
      # Ordering matters: index 0 is host 0 / device ordinal 0, the only entry
      # whose output is recorded below.
      for host_id in range(mesh_context.num_hosts):
        # NOTE(review): external_worker is hard-coded to True here, while the
        # SimdMeshImplInputReader above uses its default; confirm they agree.
        # pylint: disable=protected-access
        with ops.device(input_reader._host_id_to_tf_device(
            host_id, external_worker=True)):
          for device_ordinal in range(mesh_context.num_cores_per_host):
            outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
                dtypes=output_dtypes,
                shapes=output_shapes,
                device_ordinal=device_ordinal)

            # We don't need output other than from core 0.
            # Every other core's outputs are still dequeued but reduced to
            # scalar means (presumably so each outfeed is drained while host
            # transfer stays small -- confirm).
            if outfeed_dequeue_ops:
              outfeed_dequeue_ops.append(
                  [tf.reduce_mean(x) for x in outfeed_dequeue_op])
            else:
              outfeed_dequeue_ops.append(outfeed_dequeue_op)

    else:
      # NOTE(review): is_eval_mode=False here although the TPU branch uses
      # is_eval_mode=True; this looks copied from the train phase -- confirm
      # whether the placement reader should run in eval mode.
      placement_input_reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
      eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    # Only the master->slice hook is used for eval; the second hook returned
    # by get() is discarded.
    master_to_slice_hook, _ = eval_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    all_hooks = [ckpt_loader_hook, master_to_slice_hook]

    if FLAGS.write_summary:
      flush_summary = contrib_summary.flush()

    # Plain MonitoredSession with a ChiefSessionCreator on FLAGS.master and
    # soft device placement enabled.
    with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(
            master=FLAGS.master,
            config=tf.ConfigProto(allow_soft_placement=True)),
        hooks=all_hooks) as sess:

      if FLAGS.write_summary:
        contrib_summary.initialize(session=sess)

      # Start data flow: an infeed thread on TPU, an iterator initialization
      # otherwise.
      if FLAGS.use_tpu:
        simd_input_reader.start_infeed_thread(
            sess, FLAGS.num_eval_iterations_per_loop)
      else:
        placement_input_reader.initialize(sess)

      pprocessor = unet.PostProcessor()
      for step in range(FLAGS.num_eval_iterations_per_loop):
        # Only get results from the 0-th core.
        if FLAGS.use_tpu:
          # The replicated step's return value is discarded; results are read
          # back from the outfeed dequeue ops instead.
          sess.run(eval_computation)
          results = sess.run(outfeed_dequeue_ops)[0]
        else:
          results = sess.run(eval_computation)
        pprocessor.record(results, FLAGS.pred_output_dir)

        if FLAGS.write_summary:
          sess.run(flush_summary)
        tf.logging.info('eval steps: {}'.format(step))
      pprocessor.finish()
コード例 #4
0
    def test_get_laidout_tensors(self, is_eval_mode):
        """Checks how SimdMeshImplInputReader splits input across two cores.

        `stateful_ds_creator` returns a different 2-row slice of a 4x4
        identity matrix on each call. With one dataset object (eval mode) the
        creator is expected to be called once, so both cores see rows 0-1;
        with two dataset objects it is expected to be called twice, so core 1
        sees rows 2-3.

        Args:
          is_eval_mode: Forwarded to SimdMeshImplInputReader.
        """
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4
        rows_per_call = 2  # Rows of the identity matrix handed out per call.
        identity = np.eye(batch_io_dim, dtype=np.float32)

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)
            # Everything after initialization runs under try/finally so the
            # TPU system is shut down even when graph building, the replicated
            # run, or an assertion fails.
            try:
                # Get a device_assignment object for mtf.
                d_assignment = device_assignment.device_assignment(
                    topology,
                    computation_shape=[1] * mtf.utils.topology_rank(topology),
                    num_replicas=num_cores)

                # Hacked dataset creator: creates different datasets for the
                # first and second call, in order to test
                # SimdMeshImplInputReader.
                self.sub_batch_created_times = 0

                def stateful_ds_creator():
                    whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                    sub_batch = tf.slice(
                        whole_batch,
                        [self.sub_batch_created_times * rows_per_call, 0],
                        [rows_per_call, batch_io_dim])
                    self.sub_batch_created_times += 1
                    return tf.data.Dataset.from_tensors(
                        sub_batch).repeat().unbatch()

                batch_dim = mtf.Dimension("batch", batch_io_dim)
                io_dim = mtf.Dimension("io", batch_io_dim)
                mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

                # Get mesh_impl.
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mtf.convert_to_shape(mesh_shape),
                    mtf.convert_to_layout_rules(layout),
                    None,
                    d_assignment)

                simd_input_reader = input_reader.SimdMeshImplInputReader(
                    mesh_impl,
                    stateful_ds_creator,
                    mtf_input_shapes,
                    external_worker=False,
                    is_eval_mode=is_eval_mode)

                def model_fn(features):
                    return features

                replicated_computation = tpu.replicate(
                    computation=model_fn,
                    inputs=[[]] * num_cores,
                    infeed_queue=simd_input_reader.infeed_queue,
                    device_assignment=d_assignment)

                simd_input_reader.start_infeed_thread(sess, 1)
                results = sess.run(replicated_computation)
                print("results: {}".format(results))

                core_0_data = results[0][0]
                core_1_data = results[1][0]
                print("core_0_data: {}".format(core_0_data))
                print("core_1_data: {}".format(core_1_data))

                first_rows = identity[0:rows_per_call]
                second_rows = identity[rows_per_call:2 * rows_per_call]
                self.assertAllClose(first_rows, core_0_data)
                if is_eval_mode:
                    # One dataset object: stateful_ds_creator() is called
                    # once, so both cores get the first slice.
                    self.assertAllClose(first_rows, core_1_data)
                else:
                    # Two dataset objects: stateful_ds_creator() is called
                    # twice, so core 1 gets the second slice.
                    self.assertAllClose(second_rows, core_1_data)
            finally:
                sess.run(tf.tpu.shutdown_system())