Example #1

# Imports assumed by this snippet (the `model` module and the flag definitions
# are project-specific and not shown in the original listing):
import os

import numpy as np
import tensorflow as tf
from six.moves import xrange  # keeps the xrange() calls below Python 3 compatible

import model  # project module providing get_inputs, preprocess, get_model_fn, ...

slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS
def main(_):
    train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
    save_image_dir = os.path.join(train_dir, 'images')
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    if not os.path.exists(save_image_dir):
        os.makedirs(save_image_dir)

    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
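            # replica_device_setter places variables on the parameter servers and
            # the remaining ops on this worker (between-graph replication).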
            global_step = slim.get_or_create_global_step()
            ##########
            ## data ##
            ##########
            train_data = model.get_inputs(FLAGS.inp_dir,
                                          FLAGS.dataset_name,
                                          'train',
                                          FLAGS.batch_size,
                                          FLAGS.image_size,
                                          is_training=True)
            inputs = model.preprocess(train_data, FLAGS.step_size)
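            # preprocess() returns a dict of per-step tensors ('images_0' is the
            # input view, 'images_1' ... 'images_<step_size>' the targets), as
            # used for visualization further below.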
            ###########
            ## model ##
            ###########
            model_fn = model.get_model_fn(FLAGS, is_training=True)
            outputs = model_fn(inputs)
            ##########
            ## loss ##
            ##########
            task_loss = model.get_loss(inputs, outputs, FLAGS)
            regularization_loss = model.get_regularization_loss(
                ['encoder', 'rotator', 'decoder'], FLAGS)
            loss = task_loss + regularization_loss
            ###############
            ## optimizer ##
            ###############
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
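            # With sync_replicas, per-worker gradients are aggregated across
            # replicas before each update; backup workers absorb slow stragglers.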
            if FLAGS.sync_replicas:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.worker_replicas -
                    FLAGS.backup_workers,
                    total_num_replicas=FLAGS.worker_replicas)

            ##############
            ## train_op ##
            ##############
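            # Build a train op restricted to the encoder / rotator / decoder
            # variable scopes.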
            train_op = model.get_train_op_for_scope(
                loss, optimizer, ['encoder', 'rotator', 'decoder'], FLAGS)
            ###########
            ## saver ##
            ###########
            saver = tf.train.Saver(
                max_to_keep=np.minimum(5, FLAGS.worker_replicas + 1))

            if FLAGS.task == 0:
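                # Only the chief worker (task 0) builds the validation branch
                # and the image grids written to save_image_dir.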
                val_data = model.get_inputs(FLAGS.inp_dir,
                                            FLAGS.dataset_name,
                                            'val',
                                            FLAGS.batch_size,
                                            FLAGS.image_size,
                                            is_training=False)
                val_inputs = model.preprocess(val_data, FLAGS.step_size)
                # Note: don't compute loss here
                reused_model_fn = model.get_model_fn(FLAGS,
                                                     is_training=False,
                                                     reuse=True)
                val_outputs = reused_model_fn(val_inputs)
                with tf.device(tf.DeviceSpec(device_type='CPU')):
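                    # Assemble the visualization tensors on the CPU; they only
                    # feed the image-grid writer, not the training update.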
                    if FLAGS.step_size == 1:
                        vis_input_images = val_inputs['images_0'] * 255.0
                        vis_output_images = val_inputs['images_1'] * 255.0
                        vis_pred_images = val_outputs['images_1'] * 255.0
                        vis_pred_masks = (val_outputs['masks_1'] *
                                          (-1) + 1) * 255.0
                    else:
                        rep_times = int(np.ceil(32.0 / float(FLAGS.step_size)))
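                        # Gather rep_times * step_size (roughly 32) examples to
                        # fill one saved image grid.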
                        vis_list_1 = []
                        vis_list_2 = []
                        vis_list_3 = []
                        vis_list_4 = []
                        for j in xrange(rep_times):
                            for k in xrange(FLAGS.step_size):
                                vis_input_image = val_inputs['images_0'][j]
                                vis_output_image = val_inputs['images_%d' %
                                                              (k + 1)][j]
                                vis_pred_image = val_outputs['images_%d' %
                                                             (k + 1)][j]
                                vis_pred_mask = val_outputs['masks_%d' %
                                                            (k + 1)][j]
                                vis_list_1.append(
                                    tf.expand_dims(vis_input_image, 0))
                                vis_list_2.append(
                                    tf.expand_dims(vis_output_image, 0))
                                vis_list_3.append(
                                    tf.expand_dims(vis_pred_image, 0))
                                vis_list_4.append(
                                    tf.expand_dims(vis_pred_mask, 0))

                        vis_list_1 = tf.reshape(tf.stack(vis_list_1), [
                            rep_times * FLAGS.step_size, FLAGS.image_size,
                            FLAGS.image_size, 3
                        ])
                        vis_list_2 = tf.reshape(tf.stack(vis_list_2), [
                            rep_times * FLAGS.step_size, FLAGS.image_size,
                            FLAGS.image_size, 3
                        ])
                        vis_list_3 = tf.reshape(tf.stack(vis_list_3), [
                            rep_times * FLAGS.step_size, FLAGS.image_size,
                            FLAGS.image_size, 3
                        ])
                        vis_list_4 = tf.reshape(tf.stack(vis_list_4), [
                            rep_times * FLAGS.step_size, FLAGS.image_size,
                            FLAGS.image_size, 1
                        ])

                        vis_input_images = vis_list_1 * 255.0
                        vis_output_images = vis_list_2 * 255.0
                        vis_pred_images = vis_list_3 * 255.0
                        vis_pred_masks = (vis_list_4 * (-1) + 1) * 255.0
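                        # Note: the mask tensors are inverted (1 - mask) before
                        # scaling to [0, 255].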

                    write_disk_op = model.write_disk_grid(
                        global_step=global_step,
                        summary_freq=FLAGS.save_every,
                        log_dir=save_image_dir,
                        input_images=vis_input_images,
                        output_images=vis_output_images,
                        pred_images=vis_pred_images,
                        pred_masks=vis_pred_masks)
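                # Chaining write_disk_op onto train_op (below) makes each chief
                # training step also evaluate the image-writing op; the actual
                # disk writes are presumably throttled by summary_freq.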
                with tf.control_dependencies([write_disk_op]):
                    train_op = tf.identity(train_op)

            #############
            ## init_fn ##
            #############
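            # init_fn restores pretrained variables for these scopes, if any;
            # slim.learning.train calls it once on the chief after init_op.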
            init_fn = model.get_init_fn(['encoder', 'rotator', 'decoder'],
                                        FLAGS)

            ##############
            ## training ##
            ##############
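            # slim.learning.train drives the loop: it creates the supervised
            # session, runs init_fn once, and repeatedly evaluates train_op
            # until max_number_of_steps, checkpointing with the saver above.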
            slim.learning.train(train_op=train_op,
                                logdir=train_dir,
                                init_fn=init_fn,
                                master=FLAGS.master,
                                is_chief=(FLAGS.task == 0),
                                number_of_steps=FLAGS.max_number_of_steps,
                                saver=saver,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
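
The control-dependency pattern near the end of main() is worth isolating: re-wrapping train_op with tf.identity under tf.control_dependencies guarantees the auxiliary op runs whenever the training op is fetched. Below is a minimal sketch of that pattern in plain TF 1.x, independent of the model/FLAGS modules above; the names (aux_op, side_effect_count) are illustrative only.

import tensorflow as tf

x = tf.Variable(0.0)
side_effect_count = tf.Variable(0.0)

train_op = tf.assign_add(x, 1.0)                # stands in for the real train_op
aux_op = tf.assign_add(side_effect_count, 1.0)  # stands in for write_disk_op

# Re-wrap train_op so fetching it also triggers aux_op.
with tf.control_dependencies([aux_op]):
    train_op = tf.identity(train_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(train_op)
    print(sess.run(side_effect_count))  # 3.0: aux_op ran once per training step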