def single_step_fn():
        """Function for a single TPU step."""
        # `multi_device_iterator`, `inputs_flattener`, `captured_scaffold_fn`
        # and `params` are defined in the enclosing scope.
        all_input_data = multi_device_iterator.get_next()
        per_host_sharded_inputs = []
        for core in range(FLAGS.num_cores):
            features_shape, features, labels = all_input_data[core]
            flattened_inputs = (inputs_flattener.flatten_features_and_labels(
                features, labels))
            per_host_sharded_inputs.append(flattened_inputs)

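            # Decide which image side is shorter; with transposed inputs the
            # height/width dimensions sit at indices 0 and 1 instead of 1 and 2.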
            if params['transpose_input']:
                is_height_short_side = tf.less(features_shape[0],
                                               features_shape[1])
            else:
                is_height_short_side = tf.less(features_shape[1],
                                               features_shape[2])

        def height_short_side_model_fn(*args):
            """Model function for input images with height on the short side."""
            features, labels = inputs_flattener.unflatten_features_and_labels(
                args)
            features, labels = _set_feature_and_label_shapes(
                features, labels, params)
            spec = mask_rcnn_model.mask_rcnn_model_fn(
                features, labels, tf.estimator.ModeKeys.TRAIN, params)
            captured_scaffold_fn.capture(spec.scaffold_fn)
            return spec.train_op

        def height_long_side_model_fn(*args):
            """Model function for input images with height on the long side."""
            features, labels = inputs_flattener.unflatten_features_and_labels(
                args)
            # Create a new params which has the reversed dynamic image shape.
            new_params = copy.deepcopy(params)
            new_params['dynamic_image_size'] = new_params[
                'dynamic_image_size'][::-1]
            features, labels = _set_feature_and_label_shapes(
                features, labels, new_params)
            spec = mask_rcnn_model.mask_rcnn_model_fn(
                features, labels, tf.estimator.ModeKeys.TRAIN, new_params)
            captured_scaffold_fn.capture(spec.scaffold_fn)
            return spec.train_op

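        # Both orientations get their own replicated computation; tf.cond
        # picks which one actually runs for this step's batch.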
        rewrite_computation = tf.cond(
            is_height_short_side,
            lambda: tpu.replicate(height_short_side_model_fn,
                                  per_host_sharded_inputs),
            lambda: tpu.replicate(height_long_side_model_fn,
                                  per_host_sharded_inputs))

        return rewrite_computation
Example 2
  def _run_train_phase():
    """The real function that runs the training phase."""
    # Set up the input pipeline.
    ds_creator = unet.get_dataset_creator('train')
    mtf_shapes = unet.get_input_mtf_shapes('train')

    model_train_fn, train_hooks, _ = _get_model_fn('train', mesh_context)

    if FLAGS.use_tpu:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes,
          external_worker=(not FLAGS.on_gcp), is_eval_mode=False)
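      # Per-replica inputs are empty lists: the training data reaches the TPU
      # cores through the infeed queue rather than as function arguments.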
      train_computation = tpu.replicate(
          computation=model_train_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

    else:
      placement_input_reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
      train_computation = placement_input_reader.gpu_placement(model_train_fn)

    ###########################################################
    # Training.
    master_to_slice_hook, slice_to_master_hook = train_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    step_counter_hook = tf.train.StepCounterHook(every_n_steps=10)
    all_hooks = [ckpt_loader_hook, master_to_slice_hook,
                 slice_to_master_hook, step_counter_hook]

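    # Summary writes are buffered; running the flush op each step forces them
    # out to disk.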
    if FLAGS.write_summary:
      flush_summary = contrib_summary.flush()

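    # `master` and `config` are expected to be defined in the enclosing scope.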
    with tf.train.MonitoredTrainingSession(
        master=master,
        scaffold=_get_scaffold(additional_initializers=[]),
        hooks=all_hooks,
        config=config) as sess:

      if FLAGS.write_summary:
        contrib_summary.initialize(session=sess)

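      # Start feeding before the first session.run; the TPU computation
      # blocks on the infeed queue until data arrives.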
      if FLAGS.use_tpu:
        simd_input_reader.start_infeed_thread(
            sess, FLAGS.num_train_iterations_per_loop)
      else:
        placement_input_reader.initialize(sess)

      for step in range(FLAGS.num_train_iterations_per_loop):
        sess.run(train_computation)
        if FLAGS.write_summary:
          sess.run(flush_summary)
        tf.logging.info('train steps: {}'.format(step))
Example 3
def freeze_graph_tpu(model_path):
    """Custom freeze_graph implementation for Cloud TPU."""

    assert model_path
    assert FLAGS.tpu_name
    if FLAGS.tpu_name.startswith('grpc://'):
        tpu_grpc_url = FLAGS.tpu_name
    else:
        tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = tpu_cluster_resolver.get_master()
    sess = tf.Session(tpu_grpc_url)

    output_names = []
    with sess.graph.as_default():
        # Replicate the inference function for each TPU core.
        replicated_features = []
        feature_type = tf.bool if FLAGS.bool_features else tf.float32
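        # Build one placeholder per core; contrib_tpu.replicate passes the
        # i-th input tuple to the i-th replica.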
        for i in range(FLAGS.num_tpu_cores):
            name = 'pos_tensor_%d' % i
            features = tf.placeholder(feature_type, [None], name=name)
            replicated_features.append((features,))
        outputs = contrib_tpu.replicate(tpu_model_inference_fn,
                                        replicated_features)

        # The replicate op assigns machine-generated names like
        # output_0_shard_0 to the outputs. Wrap them in identity ops with
        # human-readable names.
        for i, (policy_output, value_output, _) in enumerate(outputs):
            policy_name = 'policy_output_%d' % i
            value_name = 'value_output_%d' % i
            output_names.extend([policy_name, value_name])
            tf.identity(policy_output, policy_name)
            tf.identity(value_output, value_name)

        tf.train.Saver().restore(sess, model_path)

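    # Fold the restored variable values into the graph as constants so the
    # frozen GraphDef is self-contained.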
    out_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)

    metadata = make_model_metadata({
        'engine': 'tpu',
        'num_replicas': FLAGS.num_tpu_cores,
    })

    minigo_model.write_graph_def(out_graph, metadata, model_path + '.minigo')
Example 4
    def _run_train_phase():
        """The real function that runs the training phase."""
        # Set up the input pipeline.
        ds_creator = unet.get_dataset_creator('train')
        mtf_shapes = unet.get_input_mtf_shapes('train')
        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)

        model_train_fn, train_hooks, _ = _get_model_fn('train', cpu_devices,
                                                       d_assignment, num_hosts)
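        # As in the training setup above, per-replica inputs stay empty and
        # the infeed queue delivers the data.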
        tpu_train_computation = tpu.replicate(
            computation=model_train_fn,
            inputs=[[]] * num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=d_assignment)

        ###########################################################
        # Training.
        master_to_slice_hook, slice_to_master_hook = train_hooks.get()
        ckpt_loader_hook = _CkptLoaderHook()
        if FLAGS.write_summary:
            flush_summary = tf.contrib.summary.flush()

        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                scaffold=_get_scaffold(additional_initializers=[]),
                hooks=[
                    ckpt_loader_hook, master_to_slice_hook,
                    slice_to_master_hook
                ],
                config=tf.ConfigProto(allow_soft_placement=True)) as sess:

            if FLAGS.write_summary:
                tf.contrib.summary.initialize(session=sess)
            simd_input_reader.start_infeed_thread(
                sess, FLAGS.num_train_iterations_per_loop)

            for step in range(FLAGS.num_train_iterations_per_loop):
                sess.run(tpu_train_computation)
                if FLAGS.write_summary:
                    sess.run(flush_summary)
                tf.logging.info('train steps: {}'.format(step))
Example 5
  def _run_eval_phase():
    """The real function that runs the evaluation phase."""
    # Set up the input pipeline.
    ds_creator = unet.get_dataset_creator('eval')
    mtf_shapes = unet.get_input_mtf_shapes('eval')

    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    if FLAGS.use_tpu:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
      eval_computation = tpu.replicate(
          computation=model_eval_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

      output_dtypes, output_shapes = output_dtypes_shapes.get()
      outfeed_dequeue_ops = []

      # Create outfeed_dequeue_ops.
      for host_id in range(mesh_context.num_hosts):
        # pylint: disable=protected-access
        with ops.device(input_reader._host_id_to_tf_device(
            host_id, external_worker=True)):
          for device_ordinal in range(mesh_context.num_cores_per_host):
            outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
                dtypes=output_dtypes,
                shapes=output_shapes,
                device_ordinal=device_ordinal)

            # Only core 0's results are consumed; the other cores' dequeue
            # ops are reduced to scalars simply to drain their outfeeds.
            if outfeed_dequeue_ops:
              outfeed_dequeue_ops.append(
                  [tf.reduce_mean(x) for x in outfeed_dequeue_op])
            else:
              outfeed_dequeue_ops.append(outfeed_dequeue_op)

    else:
      placement_input_reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
      eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    master_to_slice_hook, _ = eval_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    all_hooks = [ckpt_loader_hook, master_to_slice_hook]

    if FLAGS.write_summary:
      flush_summary = contrib_summary.flush()

    with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(
            master=FLAGS.master,
            config=tf.ConfigProto(allow_soft_placement=True)),
        hooks=all_hooks) as sess:

      if FLAGS.write_summary:
        contrib_summary.initialize(session=sess)

      if FLAGS.use_tpu:
        simd_input_reader.start_infeed_thread(
            sess, FLAGS.num_eval_iterations_per_loop)
      else:
        placement_input_reader.initialize(sess)

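      # Eval loop: run the replicated computation, then fetch results from
      # the outfeed (TPU) or directly from the computation (CPU/GPU).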
      pprocessor = unet.PostProcessor()
      for step in range(FLAGS.num_eval_iterations_per_loop):
        # Only get results from the 0-th core.
        if FLAGS.use_tpu:
          sess.run(eval_computation)
          results = sess.run(outfeed_dequeue_ops)[0]
        else:
          results = sess.run(eval_computation)
        pprocessor.record(results, FLAGS.pred_output_dir)

        if FLAGS.write_summary:
          sess.run(flush_summary)
        tf.logging.info('eval steps: {}'.format(step))
      pprocessor.finish()