Exemplo n.º 1
0
  def _run_train_phase():
    """Builds the training computation and runs one loop of training steps.

    The input reader is chosen by FLAGS.use_tpu: a SIMD reader feeding a
    replicated TPU computation, or a placement reader otherwise. The
    computation is then run FLAGS.num_train_iterations_per_loop times inside
    a MonitoredTrainingSession; when FLAGS.write_summary is set, summaries
    are initialized once and flushed after every step.
    """
    # Input pipeline for the 'train' split.
    dataset_creator = unet.get_dataset_creator('train')
    input_shapes = unet.get_input_mtf_shapes('train')

    # The third value returned for the model function is not used here.
    train_fn, hooks_holder, _ = _get_model_fn('train', mesh_context)

    on_tpu = FLAGS.use_tpu
    if not on_tpu:
      reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, dataset_creator, input_shapes,
          is_eval_mode=False)
      train_op = reader.gpu_placement(train_fn)
    else:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, dataset_creator, input_shapes,
          external_worker=(not FLAGS.on_gcp), is_eval_mode=False)
      # One empty input list per core: all data arrives through the infeed.
      train_op = tpu.replicate(
          computation=train_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

    # Session hooks, in the order they are handed to the session.
    to_slice_hook, to_master_hook = hooks_holder.get()
    session_hooks = [
        _CkptLoaderHook(),
        to_slice_hook,
        to_master_hook,
        tf.train.StepCounterHook(every_n_steps=10),
    ]

    summaries_on = FLAGS.write_summary
    flush_op = contrib_summary.flush() if summaries_on else None

    scaffold = _get_scaffold(additional_initializers=[])
    with tf.train.MonitoredTrainingSession(
        master=master, scaffold=scaffold, hooks=session_hooks,
        config=config) as sess:

      if summaries_on:
        contrib_summary.initialize(session=sess)

      num_steps = FLAGS.num_train_iterations_per_loop
      if on_tpu:
        reader.start_infeed_thread(sess, num_steps)
      else:
        reader.initialize(sess)

      for step in range(num_steps):
        sess.run(train_op)
        if summaries_on:
          sess.run(flush_op)
        tf.logging.info('train steps: {}'.format(step))
Exemplo n.º 2
0
  def _run_eval_phase():
    """Builds the evaluation computation and runs one loop of eval steps.

    The input reader is chosen by FLAGS.use_tpu. On TPU the model runs as a
    replicated computation and its outputs are read back through outfeed
    dequeue ops; otherwise the placed computation is run directly and its
    return value is used. Each step's results are passed to a
    unet.PostProcessor together with FLAGS.pred_output_dir, and the
    post-processor is finished after FLAGS.num_eval_iterations_per_loop
    steps.
    """
    # Setup input pipeline for the 'eval' split.
    ds_creator = unet.get_dataset_creator('eval')
    mtf_shapes = unet.get_input_mtf_shapes('eval')

    # output_dtypes_shapes is only consumed in the TPU branch below, where it
    # supplies the dtypes/shapes for the outfeed dequeue ops.
    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    if FLAGS.use_tpu:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
      # One empty input list per core: all data arrives through the infeed.
      eval_computation = tpu.replicate(
          computation=model_eval_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

      output_dtypes, output_shapes = output_dtypes_shapes.get()
      outfeed_dequeue_ops = []

      # Create outfeed_dequeue_ops: one dequeue per (host, device ordinal).
      # NOTE(review): external_worker is hard-coded to True for the dequeue
      # device, and the reader above is built without an external_worker
      # argument -- confirm both match the deployment.
      for host_id in range(mesh_context.num_hosts):
        # pylint: disable=protected-access
        with ops.device(input_reader._host_id_to_tf_device(
            host_id, external_worker=True)):
          for device_ordinal in range(mesh_context.num_cores_per_host):
            outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
                dtypes=output_dtypes,
                shapes=output_shapes,
                device_ordinal=device_ordinal)

            # We don't need output other than from core 0. The first dequeue
            # op created is kept as-is; every later one has each of its
            # tensors collapsed with tf.reduce_mean. Only element [0] of the
            # fetched list is read in the eval loop below.
            # Presumably the other cores are still dequeued so their outfeed
            # gets drained -- TODO confirm.
            if outfeed_dequeue_ops:
              outfeed_dequeue_ops.append(
                  [tf.reduce_mean(x) for x in outfeed_dequeue_op])
            else:
              outfeed_dequeue_ops.append(outfeed_dequeue_op)

    else:
      # NOTE(review): is_eval_mode=False here, whereas the TPU branch above
      # passes is_eval_mode=True -- confirm this is intended for evaluation.
      placement_input_reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
      eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    # Only the first hook returned by eval_hooks.get() is used.
    master_to_slice_hook, _ = eval_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    all_hooks = [ckpt_loader_hook, master_to_slice_hook]

    # flush_summary is only defined (and only used) when summaries are on.
    if FLAGS.write_summary:
      flush_summary = contrib_summary.flush()

    # NOTE(review): the session is created from FLAGS.master and a fresh
    # ConfigProto with soft placement enabled -- confirm these are the
    # intended master/config for evaluation.
    with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(
            master=FLAGS.master,
            config=tf.ConfigProto(allow_soft_placement=True)),
        hooks=all_hooks) as sess:

      if FLAGS.write_summary:
        contrib_summary.initialize(session=sess)

      if FLAGS.use_tpu:
        simd_input_reader.start_infeed_thread(
            sess, FLAGS.num_eval_iterations_per_loop)
      else:
        placement_input_reader.initialize(sess)

      pprocessor = unet.PostProcessor()
      for step in range(FLAGS.num_eval_iterations_per_loop):
        # Only get results from the 0-th core.
        if FLAGS.use_tpu:
          # Run the replicated computation first, then fetch what it wrote to
          # the outfeed; element [0] is the un-reduced tuple.
          sess.run(eval_computation)
          results = sess.run(outfeed_dequeue_ops)[0]
        else:
          results = sess.run(eval_computation)
        pprocessor.record(results, FLAGS.pred_output_dir)

        if FLAGS.write_summary:
          sess.run(flush_summary)
        tf.logging.info('eval steps: {}'.format(step))
      # finish() is reached only if every step above completed without
      # raising.
      pprocessor.finish()
Exemplo n.º 3
0
    def __init__(self, threads=8, z_dim=10):
        """Build the autoencoder graph, its summaries, and a ready session.

        The graph is created in a fresh ``tf.Graph`` owned by
        ``self.session``. Global variables and the summary writer are
        initialized before the constructor returns.

        Args:
            threads: Value used for both ``inter_op_parallelism_threads``
                and ``intra_op_parallelism_threads`` of the session config.
                Defaults to 8.
            z_dim: Number of units in the encoder's final dense layer, i.e.
                the size of the latent code ``self.z``. Also embedded in the
                log directory name. Defaults to 10.

        Attributes set:
            session: The ``tf.Session`` bound to the new graph.
            images: Float32 input placeholder of shape [None, 16, 16, 3].
            z: Encoder output for ``images``.
            generated_logits: Decoder output for ``z``, before the sigmoid.
            generated_images: Sigmoid of ``generated_logits``.
            loss: Mean squared error between ``images`` and
                ``generated_images``.
            training: Adam minimize op for ``loss``; it also increments the
                global step.
            summaries: Loss scalar and latent histogram summary ops,
                recorded every 100 global steps.
            generated_images_summary_data: Float32 placeholder of shape
                [None, None, 3] holding one image to be logged.
            generated_images_summary: Image summary op fed from
                ``generated_images_summary_data``; always recorded.
        """
        graph = tf.Graph()
        self.session = tf.Session(graph=graph,
                                  config=tf.ConfigProto(
                                      inter_op_parallelism_threads=threads,
                                      intra_op_parallelism_threads=threads))

        with graph.as_default():
            self.images = tf.placeholder(tf.float32,
                                         shape=[None, 16, 16, 3],
                                         name="images")

            def encoder(img):
                # Flatten, then 500 -> 500 -> z_dim dense layers, all ReLU.
                out = tf.layers.flatten(img)
                out = tf.layers.dense(out, 500, activation=tf.nn.relu)
                out = tf.layers.dense(out, 500, activation=tf.nn.relu)
                out = tf.layers.dense(out, z_dim, activation=tf.nn.relu)

                return out

            def decoder(z):
                # 500 -> 500 ReLU layers, then a linear layer producing one
                # logit per input value, reshaped back to the image layout.
                out = tf.layers.dense(z, 500, activation=tf.nn.relu)
                out = tf.layers.dense(out, 500, activation=tf.nn.relu)
                out = tf.layers.dense(out, 16 * 16 * 3, activation=None)

                return tf.reshape(out, [-1, 16, 16, 3])

            self.z = encoder(self.images)
            self.generated_logits = decoder(self.z)
            self.generated_images = tf.nn.sigmoid(self.generated_logits,
                                                  name="generated_images")

            # The loss compares the input against the sigmoid output, not
            # against the raw logits.
            self.loss = tf.losses.mean_squared_error(self.images,
                                                     self.generated_images)

            global_step = tf.train.create_global_step()

            self.training = tf.train.AdamOptimizer().minimize(
                self.loss, global_step=global_step)

            # Each instance logs to its own timestamped directory.
            logdir = "logs/autoencoder-{}-{}".format(
                z_dim,
                datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"))

            summary_writer = tfsum.create_file_writer(logdir,
                                                      flush_millis=10 * 1000)
            with summary_writer.as_default(
            ), tfsum.record_summaries_every_n_global_steps(100):
                self.summaries = [
                    tfsum.scalar("loss", self.loss),
                    tfsum.histogram("latent", self.z)
                ]

            # The image summary is fed through its own placeholder and gets a
            # leading axis of size 1 added before being passed to tfsum.image.
            self.generated_images_summary_data = tf.placeholder(
                tf.float32, [None, None, 3])
            with summary_writer.as_default(), tfsum.always_record_summaries():
                self.generated_images_summary = tfsum.image(
                    "generated_image",
                    tf.expand_dims(self.generated_images_summary_data, axis=0))

            init = tf.global_variables_initializer()
            self.session.run(init)

            with summary_writer.as_default():
                tfsum.initialize(session=self.session,
                                 graph=self.session.graph)
Exemplo n.º 4
0
        summary.histogram("obj masks", obj_p[0])
        summary.histogram("flow_x_hist", flow[:, :, :, 0], family="flow")
        summary.histogram("flow_y_hist", flow[:, :, :, 1], family="flow")

        summary.image("frame0", cast_im(f0), max_images=3)
        summary.image("frame1", cast_im(f1), max_images=3)
        summary.image("frame1_t", cast_im(f1_t), max_images=3)
        summary.image("depth", cast_depth(depth), max_images=3)
        summary.image("optical_flow", cast_flow(flow), max_images=3)
        summary.image("object masks", cast_im(obj_p[0]), max_images=3)

        obj_summary(obj_p)
        cam_summary(cam_p)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        summary.initialize(graph=tf.get_default_graph())

        model.load_weights(os.path.join(models_path, "sfm.h5"))

        for s in range(S_max):
            l, *_ = sess.run(
                [loss, optimize, summary.all_summary_ops()])
            # beholder.update(session=sess)

            if s % 50 == 0:
                print("Iteration: {}  Loss: {}".format(s, l))

            if s % 5000 == 0 and not s == 0:
                model.save_weights(os.path.join(models_path, "sfm.h5"))