def run_training(run_dir, checkpoint_dir, hparams):
    """Runs the PixelDA training loop.

    Args:
      run_dir: The directory where training specific logs are placed.
      checkpoint_dir: The directory where the checkpoints and log files are
        stored.
      hparams: The hyperparameters struct.

    Raises:
      ValueError: if the source and target datasets have a different number
        of classes.
    """
    for path in [run_dir, checkpoint_dir]:
        if not tf.gfile.Exists(path):
            tf.gfile.MakeDirs(path)

    # Serialize hparams to the checkpoint dir so the run is reproducible.
    hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
    with tf.gfile.FastGFile(hparams_filename, 'w') as f:
        f.write(hparams.to_json())

    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Created for its side effect: the train ops and hooks below read
            # the global step from the graph collection.
            global_step = slim.get_or_create_global_step()

            #########################
            # Preprocess the inputs #
            #########################
            target_dataset = dataset_factory.get_dataset(
                FLAGS.target_dataset,
                split_name='train',
                dataset_dir=FLAGS.dataset_dir)
            target_images, _ = dataset_factory.provide_batch(
                FLAGS.target_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, hparams.batch_size,
                FLAGS.num_preprocessing_threads)
            num_target_classes = target_dataset.num_classes

            # Purely generative architectures (dcgan) train without a source
            # domain; everything else needs source images and labels.
            if hparams.arch not in ['dcgan']:
                source_dataset = dataset_factory.get_dataset(
                    FLAGS.source_dataset,
                    split_name='train',
                    dataset_dir=FLAGS.dataset_dir)
                num_source_classes = source_dataset.num_classes
                source_images, source_labels = dataset_factory.provide_batch(
                    FLAGS.source_dataset, 'train', FLAGS.dataset_dir,
                    FLAGS.num_readers, hparams.batch_size,
                    FLAGS.num_preprocessing_threads)
                # Data provider provides 1-hot labels, but we expect categorical.
                source_labels['class'] = tf.argmax(source_labels['classes'], 1)
                del source_labels['classes']
                if num_source_classes != num_target_classes:
                    raise ValueError(
                        'Source and Target datasets must have same number of classes. '
                        'Are %d and %d' %
                        (num_source_classes, num_target_classes))
            else:
                source_images = None
                source_labels = None

            ####################
            # Define the model #
            ####################
            end_points = pixelda_model.create_model(
                hparams,
                target_images,
                source_images=source_images,
                source_labels=source_labels,
                is_training=True,
                num_classes=num_target_classes)

            #################################
            # Get the variables to optimize #
            #################################
            generator_vars, generator_update_ops = _get_vars_and_update_ops(
                hparams, 'generator')
            discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
                hparams, 'discriminator')

            ########################
            # Configure the losses #
            ########################
            generator_loss = pixelda_losses.g_step_loss(
                source_images,
                source_labels,
                end_points,
                hparams,
                num_classes=num_target_classes)
            discriminator_loss = pixelda_losses.d_step_loss(
                end_points, source_labels, num_target_classes, hparams)

            ###########################
            # Create the training ops #
            ###########################
            learning_rate = hparams.learning_rate
            if hparams.lr_decay_steps:
                learning_rate = tf.train.exponential_decay(
                    learning_rate,
                    slim.get_or_create_global_step(),
                    decay_steps=hparams.lr_decay_steps,
                    decay_rate=hparams.lr_decay_rate,
                    staircase=True)
            tf.summary.scalar('Learning_rate', learning_rate)

            # A zero step count disables that player's updates entirely.
            if hparams.discriminator_steps == 0:
                discriminator_train_op = tf.no_op()
            else:
                discriminator_optimizer = tf.train.AdamOptimizer(
                    learning_rate, beta1=hparams.adam_beta1)

                discriminator_train_op = slim.learning.create_train_op(
                    discriminator_loss,
                    discriminator_optimizer,
                    update_ops=discriminator_update_ops,
                    variables_to_train=discriminator_vars,
                    clip_gradient_norm=hparams.clip_gradient_norm,
                    summarize_gradients=FLAGS.summarize_gradients)

            if hparams.generator_steps == 0:
                generator_train_op = tf.no_op()
            else:
                generator_optimizer = tf.train.AdamOptimizer(
                    learning_rate, beta1=hparams.adam_beta1)
                generator_train_op = slim.learning.create_train_op(
                    generator_loss,
                    generator_optimizer,
                    update_ops=generator_update_ops,
                    variables_to_train=generator_vars,
                    clip_gradient_norm=hparams.clip_gradient_norm,
                    summarize_gradients=FLAGS.summarize_gradients)

            #############
            # Summaries #
            #############
            pixelda_utils.summarize_model(end_points)
            pixelda_utils.summarize_transferred_grid(
                end_points['transferred_images'],
                source_images,
                name='Transferred')
            if 'source_images_recon' in end_points:
                pixelda_utils.summarize_transferred_grid(
                    end_points['source_images_recon'],
                    source_images,
                    name='Source Reconstruction')
            pixelda_utils.summaries_color_distributions(
                end_points['transferred_images'], 'Transferred')
            pixelda_utils.summaries_color_distributions(
                target_images, 'Target')

            if source_images is not None:
                pixelda_utils.summarize_transferred(
                    source_images, end_points['transferred_images'])
                pixelda_utils.summaries_color_distributions(
                    source_images, 'Source')
                pixelda_utils.summaries_color_distributions(
                    tf.abs(source_images - end_points['transferred_images']),
                    'Abs(Source_minus_Transferred)')

            number_of_steps = None
            if hparams.num_training_examples:
                # Want to control by amount of data seen, not # steps.
                # Integer division: tf.train.StopAtStepHook requires an
                # integral last_step, and `/` yields a float under Python 3.
                number_of_steps = (
                    hparams.num_training_examples // hparams.batch_size)

            hooks = [
                tf.train.StepCounterHook(),
            ]

            # Only the chief writes checkpoints; checkpointing via the hook,
            # so _train is called with save_checkpoint_secs=None.
            chief_only_hooks = [
                tf.train.CheckpointSaverHook(
                    saver=tf.train.Saver(),
                    checkpoint_dir=run_dir,
                    save_secs=FLAGS.save_interval_secs)
            ]

            if number_of_steps:
                hooks.append(
                    tf.train.StopAtStepHook(last_step=number_of_steps))

            _train(discriminator_train_op,
                   generator_train_op,
                   logdir=run_dir,
                   master=FLAGS.master,
                   is_chief=FLAGS.task == 0,
                   hooks=hooks,
                   chief_only_hooks=chief_only_hooks,
                   save_checkpoint_secs=None,
                   save_summaries_steps=FLAGS.save_summaries_steps,
                   hparams=hparams)
# Example 2
def main(unused_argv):
    """Builds and trains a task-specific classifier on a single dataset."""
    tf.logging.set_verbosity(tf.logging.INFO)

    hparams = tf.contrib.training.HParams()
    hparams.weight_decay_task_classifier = FLAGS.weight_decay

    # Only the MNIST-family datasets are supported.
    if FLAGS.dataset_name not in ['mnist', 'mnist_m', 'usps']:
        raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
    hparams.task_tower = 'mnist'

    with tf.Graph().as_default():
        device_fn = tf.train.replica_device_setter(FLAGS.num_ps_tasks,
                                                   merge_devices=True)
        with tf.device(device_fn):
            train_dataset = dataset_factory.get_dataset(
                FLAGS.dataset_name, FLAGS.split_name, FLAGS.dataset_dir)
            class_count = train_dataset.num_classes

            # Built but intentionally not forwarded: provide_batch below is
            # called without a preprocess_fn argument.
            preprocess_fn = partial(
                pixelda_preprocess.preprocess_classification, is_training=True)

            image_batch, label_batch = dataset_factory.provide_batch(
                FLAGS.dataset_name,
                FLAGS.split_name,
                dataset_dir=FLAGS.dataset_dir,
                num_readers=FLAGS.num_readers,
                batch_size=FLAGS.batch_size,
                num_preprocessing_threads=FLAGS.num_readers)

            # Task tower producing per-class scores.
            task_logits, _ = pixelda_task_towers.add_task_specific_model(
                image_batch, hparams, num_classes=class_count,
                is_training=True)

            # Classification loss; only one-hot labels are supported.
            if 'classes' not in label_batch:
                raise ValueError('Only support classification for now.')
            xent_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=label_batch['classes'], logits=task_logits)
            tf.summary.scalar('losses/Classification_Loss', xent_loss)

            full_loss = tf.losses.get_total_loss()
            tf.summary.scalar('losses/Total_Loss', full_loss)

            # Maintain exponential moving averages of the model variables;
            # the apply op piggybacks on the UPDATE_OPS collection so it runs
            # with every training step.
            ema_vars = slim.get_model_variables()
            ema = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, slim.get_or_create_global_step())
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(ema_vars))

            # Staircase-decayed learning rate driven by the global step.
            decayed_lr = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.get_or_create_global_step(),
                FLAGS.learning_rate_decay_steps,
                FLAGS.learning_rate_decay_factor,
                staircase=True)

            opt = tf.train.AdamOptimizer(decayed_lr, beta1=FLAGS.adam_beta1)
            step_op = slim.learning.create_train_op(full_loss, opt)

            slim.learning.train(step_op,
                                FLAGS.logdir,
                                master=FLAGS.master,
                                is_chief=(FLAGS.task == 0),
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
# Example 3
def run_eval(run_dir, checkpoint_dir, hparams):
    """Runs the eval loop.

    Args:
      run_dir: The directory where eval specific logs are placed.
      checkpoint_dir: The directory where the checkpoints are stored.
      hparams: The hyperparameters struct.

    Raises:
      ValueError: if the source and target datasets have a different number
        of classes.
    """
    # Blocks until a new checkpoint appears, then evaluates it.
    for checkpoint_path in slim.evaluation.checkpoints_iterator(
            checkpoint_dir, FLAGS.eval_interval_secs):
        with tf.Graph().as_default():
            #########################
            # Preprocess the inputs #
            #########################
            target_dataset = dataset_factory.get_dataset(
                FLAGS.target_dataset,
                split_name=FLAGS.target_split_name,
                dataset_dir=FLAGS.dataset_dir)
            target_images, target_labels = dataset_factory.provide_batch(
                FLAGS.target_dataset, FLAGS.target_split_name,
                FLAGS.dataset_dir, FLAGS.num_readers, hparams.batch_size,
                FLAGS.num_preprocessing_threads)
            num_target_classes = target_dataset.num_classes
            # Data provider provides 1-hot labels; convert to categorical.
            target_labels['class'] = tf.argmax(target_labels['classes'], 1)
            del target_labels['classes']

            # Purely generative architectures (dcgan) have no source domain.
            if hparams.arch not in ['dcgan']:
                source_dataset = dataset_factory.get_dataset(
                    FLAGS.source_dataset,
                    split_name=FLAGS.source_split_name,
                    dataset_dir=FLAGS.dataset_dir)
                num_source_classes = source_dataset.num_classes
                source_images, source_labels = dataset_factory.provide_batch(
                    FLAGS.source_dataset, FLAGS.source_split_name,
                    FLAGS.dataset_dir, FLAGS.num_readers, hparams.batch_size,
                    FLAGS.num_preprocessing_threads)
                source_labels['class'] = tf.argmax(source_labels['classes'], 1)
                del source_labels['classes']
                if num_source_classes != num_target_classes:
                    raise ValueError(
                        'Input and output datasets must have same number of classes'
                    )
            else:
                source_images = None
                source_labels = None

            ####################
            # Define the model #
            ####################
            end_points = pixelda_model.create_model(
                hparams,
                target_images,
                source_images=source_images,
                source_labels=source_labels,
                is_training=False,
                num_classes=num_target_classes)

            #######################
            # Metrics & Summaries #
            #######################
            names_to_values, names_to_updates = create_metrics(
                end_points, source_labels, target_labels, hparams)
            pixelda_utils.summarize_model(end_points)
            pixelda_utils.summarize_transferred_grid(
                end_points['transferred_images'],
                source_images,
                name='Transferred')
            if 'source_images_recon' in end_points:
                pixelda_utils.summarize_transferred_grid(
                    end_points['source_images_recon'],
                    source_images,
                    name='Source Reconstruction')
            pixelda_utils.summarize_images(target_images, 'Target')

            # Python 3: dicts have no iteritems(); items() works on both.
            for name, value in names_to_values.items():
                tf.summary.scalar(name, value)

            # Use the entire split by default.
            num_examples = target_dataset.num_samples

            # Round up so the final partial batch is still evaluated;
            # num_evals must be an integer.
            num_batches = int(
                math.ceil(num_examples / float(hparams.batch_size)))

            # Created for its side effect: evaluate_once reads the global
            # step from the graph.
            slim.get_or_create_global_step()

            # Python 3: eval_op must be a list, not a dict view.
            slim.evaluation.evaluate_once(
                master=FLAGS.master,
                checkpoint_path=checkpoint_path,
                logdir=run_dir,
                num_evals=num_batches,
                eval_op=list(names_to_updates.values()),
                final_op=names_to_values)
# Example 4
def run_eval(run_dir, checkpoint_dir, hparams):
  """Runs the eval loop.

  Args:
    run_dir: The directory where eval specific logs are placed.
    checkpoint_dir: The directory where the checkpoints are stored.
    hparams: The hyperparameters struct.

  Raises:
    ValueError: if the source and target datasets have a different number of
      classes.
  """
  # Blocks until a new checkpoint appears, then evaluates it.
  for checkpoint_path in slim.evaluation.checkpoints_iterator(
      checkpoint_dir, FLAGS.eval_interval_secs):
    with tf.Graph().as_default():
      #########################
      # Preprocess the inputs #
      #########################
      target_dataset = dataset_factory.get_dataset(
          FLAGS.target_dataset,
          split_name=FLAGS.target_split_name,
          dataset_dir=FLAGS.dataset_dir)
      target_images, target_labels = dataset_factory.provide_batch(
          FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
          FLAGS.num_readers, hparams.batch_size,
          FLAGS.num_preprocessing_threads)
      num_target_classes = target_dataset.num_classes
      # Data provider provides 1-hot labels; convert to categorical.
      target_labels['class'] = tf.argmax(target_labels['classes'], 1)
      del target_labels['classes']

      # Purely generative architectures (dcgan) have no source domain.
      if hparams.arch not in ['dcgan']:
        source_dataset = dataset_factory.get_dataset(
            FLAGS.source_dataset,
            split_name=FLAGS.source_split_name,
            dataset_dir=FLAGS.dataset_dir)
        num_source_classes = source_dataset.num_classes
        source_images, source_labels = dataset_factory.provide_batch(
            FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
            FLAGS.num_readers, hparams.batch_size,
            FLAGS.num_preprocessing_threads)
        source_labels['class'] = tf.argmax(source_labels['classes'], 1)
        del source_labels['classes']
        if num_source_classes != num_target_classes:
          raise ValueError(
              'Input and output datasets must have same number of classes')
      else:
        source_images = None
        source_labels = None

      ####################
      # Define the model #
      ####################
      end_points = pixelda_model.create_model(
          hparams,
          target_images,
          source_images=source_images,
          source_labels=source_labels,
          is_training=False,
          num_classes=num_target_classes)

      #######################
      # Metrics & Summaries #
      #######################
      names_to_values, names_to_updates = create_metrics(end_points,
                                                         source_labels,
                                                         target_labels, hparams)
      pixelda_utils.summarize_model(end_points)
      pixelda_utils.summarize_transferred_grid(
          end_points['transferred_images'], source_images, name='Transferred')
      if 'source_images_recon' in end_points:
        pixelda_utils.summarize_transferred_grid(
            end_points['source_images_recon'],
            source_images,
            name='Source Reconstruction')
      pixelda_utils.summarize_images(target_images, 'Target')

      # Python 3: dicts have no iteritems(); items() works on both.
      for name, value in names_to_values.items():
        tf.summary.scalar(name, value)

      # Use the entire split by default.
      num_examples = target_dataset.num_samples

      # Round up so the final partial batch is still evaluated; num_evals
      # must be an integer.
      num_batches = int(math.ceil(num_examples / float(hparams.batch_size)))

      # Created for its side effect: evaluate_once reads the global step
      # from the graph.
      slim.get_or_create_global_step()

      # Python 3: eval_op must be a list, not a dict view.
      slim.evaluation.evaluate_once(
          master=FLAGS.master,
          checkpoint_path=checkpoint_path,
          logdir=run_dir,
          num_evals=num_batches,
          eval_op=list(names_to_updates.values()),
          final_op=names_to_values)
# Example 5
def main(unused_argv):
    """Evaluates a task-specific classifier on a single dataset split.

    Runs a continuous evaluation loop over checkpoints found in
    FLAGS.checkpoint_dir, reporting mean loss, accuracy and recall@5.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    hparams = tf.contrib.training.HParams()
    # No regularization at eval time.
    hparams.weight_decay_task_classifier = 0.0

    if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
        hparams.task_tower = 'mnist'
    else:
        raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)

    if not tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.MakeDirs(FLAGS.eval_dir)

    with tf.Graph().as_default():
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.split_name,
                                              FLAGS.dataset_dir)
        num_classes = dataset.num_classes
        num_samples = dataset.num_samples

        # NOTE: currently unused — provide_batch below is not given a
        # preprocess_fn argument.
        preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
                                is_training=False)

        images, labels = dataset_factory.provide_batch(
            FLAGS.dataset_name,
            FLAGS.split_name,
            dataset_dir=FLAGS.dataset_dir,
            num_readers=FLAGS.num_readers,
            batch_size=FLAGS.batch_size,
            num_preprocessing_threads=FLAGS.num_readers)

        # Define the model. Bug fix: evaluation must build the inference
        # graph (is_training=False) so batch norm uses its moving averages
        # and dropout is disabled; the original passed is_training=True,
        # inconsistent with the is_training=False preprocessing above.
        logits, _ = pixelda_task_towers.add_task_specific_model(
            images, hparams, num_classes=num_classes, is_training=False)

        #####################
        # Define the losses #
        #####################
        if 'classes' in labels:
            one_hot_labels = labels['classes']
            loss = tf.losses.softmax_cross_entropy(
                onehot_labels=one_hot_labels, logits=logits)
            tf.summary.scalar('losses/Classification_Loss', loss)
        else:
            raise ValueError('Only support classification for now.')

        total_loss = tf.losses.get_total_loss()

        # Flatten predictions and labels to rank-1 for the streaming metrics.
        predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
        class_labels = tf.argmax(labels['classes'], 1)

        metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map(
            {
                'Mean_Loss':
                tf.contrib.metrics.streaming_mean(total_loss),
                'Accuracy':
                tf.contrib.metrics.streaming_accuracy(
                    predictions, tf.reshape(class_labels, shape=[-1])),
                'Recall_at_5':
                tf.contrib.metrics.streaming_recall_at_k(
                    logits, class_labels, 5),
            })

        tf.summary.histogram('outputs/Predictions', predictions)
        tf.summary.histogram('outputs/Ground_Truth', class_labels)

        for name, value in metrics_to_values.items():
            tf.summary.scalar(name, value)

        # Round up so the final partial batch is still evaluated.
        num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(metrics_to_updates.values()),
            eval_interval_secs=FLAGS.eval_interval_secs)
# Example 6
def main(unused_argv):
  """Trains the MNIST-family task classifier on a single dataset."""
  tf.logging.set_verbosity(tf.logging.INFO)
  hparams = tf.contrib.training.HParams()
  hparams.weight_decay_task_classifier = FLAGS.weight_decay

  # Only the MNIST-family datasets are supported.
  if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
    hparams.task_tower = 'mnist'
  else:
    raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)

  with tf.Graph().as_default():
    with tf.device(
        tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
      data = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
                                         FLAGS.dataset_dir)
      n_classes = data.num_classes

      # Constructed but intentionally not forwarded; provide_batch is called
      # without a preprocess_fn argument here.
      preprocess_fn = partial(
          pixelda_preprocess.preprocess_classification, is_training=True)

      batch_images, batch_labels = dataset_factory.provide_batch(
          FLAGS.dataset_name,
          FLAGS.split_name,
          dataset_dir=FLAGS.dataset_dir,
          num_readers=FLAGS.num_readers,
          batch_size=FLAGS.batch_size,
          num_preprocessing_threads=FLAGS.num_readers)

      # Build the task tower that produces the per-class scores.
      preds, _ = pixelda_task_towers.add_task_specific_model(
          batch_images, hparams, num_classes=n_classes, is_training=True)

      # Losses: only one-hot classification labels are supported.
      if 'classes' in batch_labels:
        one_hot = batch_labels['classes']
        class_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot, logits=preds)
        tf.summary.scalar('losses/Classification_Loss', class_loss)
      else:
        raise ValueError('Only support classification for now.')

      loss_total = tf.losses.get_total_loss()
      tf.summary.scalar('losses/Total_Loss', loss_total)

      # Maintain exponential moving averages of the model variables; the
      # apply op runs with the other UPDATE_OPS during training.
      model_vars = slim.get_model_variables()
      averager = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, slim.get_or_create_global_step())
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, averager.apply(model_vars))

      # Staircase-decayed learning rate driven by the global step.
      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.learning_rate_decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)

      adam = tf.train.AdamOptimizer(lr, beta1=FLAGS.adam_beta1)
      training_op = slim.learning.create_train_op(loss_total, adam)

      slim.learning.train(
          training_op,
          FLAGS.logdir,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)