Code example #1
    def testBuildModelMmd(self):
        images, labels, params = self._testBuildDefaultModel()

        with self.test_session():
            dsn.create_model(images, labels,
                             tf.cast(tf.ones([32,]), tf.bool), images, labels,
                             'mmd_loss', params, 'dann_mnist')
            loss_tensors = tf.contrib.losses.get_losses()
        self.assertEqual(len(loss_tensors), 6)
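Every test on this page starts from a _testBuildDefaultModel helper that is not shown here. The sketch below shows what such a helper plausibly returns, inferred from how the tests use it: a batch of 32 images (matching the 32-element boolean mask), a labels dict with one-hot classes, and a params dict with the same keys the training scripts further down pass as model_params. The concrete values are illustrative assumptions, not taken from the repository.

    def _testBuildDefaultModel(self):
        # Assumes TF 1.x and `import numpy as np` at module level.
        # 32 random 28x28 grayscale images, matching the 32-element mask
        # passed to dsn.create_model in every test.
        images = tf.to_float(np.random.rand(32, 28, 28, 1))
        labels = {
            'classes': tf.one_hot(tf.to_int32(np.random.randint(0, 9, (32,))), 10)
        }
        # Same keys as the model_params dicts in the training examples below;
        # the values are placeholders.
        params = {
            'use_separation': True,
            'domain_separation_startpoint': 1,
            'layers_to_regularize': 'fc3',
            'alpha_weight': 1.0,
            'beta_weight': 1.0,
            'gamma_weight': 1.0,
            'recon_loss_name': 'sum_of_square_loss',
            'decoder_name': 'small_decoder',
            'encoder_name': 'default_encoder',
            'weight_decay': 0.0,
            'batch_size': 32,
            'ps_tasks': 1,
        }
        return images, labels, params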
Code example #2
File: dsn_test.py Project: 812864539/models
 def testBuildModelDannMultiPSTasks(self):
   images, labels, params = self._testBuildDefaultModel()
   params['ps_tasks'] = 10
   with self.test_session():
     dsn.create_model(images, labels,
                      tf.cast(tf.ones([32,]), tf.bool), images, labels,
                      'dann_loss', params, 'dann_mnist')
     loss_tensors = tf.contrib.losses.get_losses()
   self.assertEqual(len(loss_tensors), 6)
Code example #3
File: dsn_test.py Project: 812864539/models
 def testBuildModelNoSeparation(self):
   images, labels, params = self._testBuildDefaultModel()
   params['use_separation'] = False
   with self.test_session():
     dsn.create_model(images, labels,
                      tf.cast(tf.ones([32,]), tf.bool), images, labels,
                      'dann_loss', params, 'dann_mnist')
     loss_tensors = tf.contrib.losses.get_losses()
   self.assertEqual(len(loss_tensors), 2)
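The asserted counts track the model configuration: with use_separation left at its default the MMD and DANN variants register six losses, while disabling separation leaves only two. To see exactly which loss tensors a given create_model call registered, the collection the assertions read can be printed directly (sketch; would run inside any of the test sessions above):

     # Print the name of every loss currently registered in the graph.
     for loss in tf.contrib.losses.get_losses():
       print(loss.op.name)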
Code example #4
def main(_):
    model_params = {
        'layers_to_regularize': FLAGS.layers_to_regularize,
        'alpha_weight': FLAGS.alpha_weight,
        'beta_weight': FLAGS.beta_weight,
        'gamma_weight': FLAGS.gamma_weight,
        'recon_loss_name': FLAGS.recon_loss_name,
        'decoder_name': FLAGS.decoder_name,
        'encoder_name': FLAGS.encoder_name,
        'weight_decay': FLAGS.weight_decay,
        'batch_size': FLAGS.batch_size,
    }

    g = tf.Graph()
    with g.as_default():
        Xs, Ys = get_source(FLAGS.batch_size)
        Xt, Yt = get_target(FLAGS.batch_size)

        slim.get_or_create_global_step()
        dsn.create_model(Xs,
                         Ys,
                         Xt,
                         Yt,
                         FLAGS.similarity_loss,
                         model_params,
                         basic_tower_name=FLAGS.basic_tower)

        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            slim.get_or_create_global_step(),
            FLAGS.decay_steps,
            FLAGS.decay_rate,
            staircase=True,
            name='learning_rate')

        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('total_loss', tf.losses.get_total_loss())

        # Adam is used here in place of the commented-out Momentum optimizer
        # (cf. code example #7).
        # opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        opt = tf.train.AdamOptimizer(learning_rate)
        tf.logging.set_verbosity(tf.logging.INFO)
        # Run training.
        loss_tensor = slim.learning.create_train_op(
            slim.losses.get_total_loss(),
            opt,
            summarize_gradients=True,
            colocate_gradients_with_ops=True)
        slim.learning.train(train_op=loss_tensor,
                            logdir=FLAGS.train_log_dir,
                            master="",
                            is_chief=True,
                            number_of_steps=FLAGS.max_number_of_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
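Unlike every other call on this page, this create_model invocation passes no boolean domain-selection mask between the labels and the target batch, so it presumably targets a variant of dsn.create_model with a shorter signature. If the signature were the same one the tests and dsn_train.py use, the unsupervised case would pass an all-True mask as the third argument, along these lines (sketch):

        # Every sample in the labeled batch comes from the source domain
        # (compare the mask construction in code example #7).
        domain_selection_mask = tf.fill((FLAGS.batch_size,), True)
        dsn.create_model(Xs, Ys, domain_selection_mask, Xt, Yt,
                         FLAGS.similarity_loss, model_params,
                         basic_tower_name=FLAGS.basic_tower)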
Code example #5
 def testBuildModelNoDomainAdaptation(self):
     images, labels, params = self._testBuildDefaultModel()
     params['use_separation'] = False
     with self.test_session():
         dsn.create_model(images, labels,
                          tf.cast(tf.ones([32,]), tf.bool), images, labels,
                          'none', params, 'dann_mnist')
         loss_tensors = tf.contrib.losses.get_losses()
         self.assertEqual(len(loss_tensors), 1)
         self.assertEqual(
             len(tf.contrib.losses.get_regularization_losses()), 0)
Code example #6
File: dsn_test.py Project: 812864539/models
 def testBuildModelNoAdaptationWeightDecay(self):
   images, labels, params = self._testBuildDefaultModel()
   params['use_separation'] = False
   params['weight_decay'] = 1e-5
   with self.test_session():
     dsn.create_model(images, labels,
                      tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
                      params, 'dann_mnist')
     loss_tensors = tf.contrib.losses.get_losses()
     self.assertEqual(len(loss_tensors), 1)
     self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)
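The last assertion only passes if the towers record their weight regularizers in the graph's regularization-loss collection, which is what a positive weight_decay does when the layers are built with slim. The following self-contained snippet illustrates the mechanism; it is independent of the dsn code and only shows where the entries returned by tf.contrib.losses.get_regularization_losses() come from (sketch, assuming TF 1.x):

import tensorflow as tf

slim = tf.contrib.slim

# Each slim layer built under this arg_scope attaches an L2 penalty to its
# weights and records it in tf.GraphKeys.REGULARIZATION_LOSSES.
with slim.arg_scope([slim.fully_connected],
                    weights_regularizer=slim.l2_regularizer(1e-5)):
  net = slim.fully_connected(tf.zeros([32, 100]), 10, scope='fc_demo')

# The same collection the assertion above inspects.
print(len(tf.contrib.losses.get_regularization_losses()))  # >= 1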
Code example #7
File: dsn_train.py Project: 812864539/models
def main(_):
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      # In the unsupervised case all the samples in the labeled
      # domain are from the source domain.
      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
                                      True)

      # When using the semi-supervised model we include labeled target data in
      # the source labeled data.
      if FLAGS.target_labeled_dataset != 'none':
        # 1000 is the maximum number of labeled target samples that exist in
        # the datasets.
        target_semi_images, target_semi_labels = provide_batch_fn()(
            FLAGS.target_labeled_dataset, 'train', FLAGS.batch_size)

        # Calculate the proportion of source domain samples in the semi-
        # supervised setting, so that the proportion is set accordingly in the
        # batches.
        proportion = float(source_labels['num_train_samples']) / (
            source_labels['num_train_samples'] +
            target_semi_labels['num_train_samples'])

        rnd_tensor = tf.random_uniform(
            (target_semi_images.get_shape().as_list()[0],))

        domain_selection_mask = rnd_tensor < proportion
        source_images = tf.where(domain_selection_mask, source_images,
                                 target_semi_images)
        source_class_labels = tf.where(domain_selection_mask,
                                       source_labels['classes'],
                                       target_semi_labels['classes'])

        if 'quaternions' in source_labels:
          source_pose_labels = tf.where(domain_selection_mask,
                                        source_labels['quaternions'],
                                        target_semi_labels['quaternions'])
          (source_images, source_class_labels, source_pose_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [
                   source_images, source_class_labels, source_pose_labels,
                   domain_selection_mask
               ],
               FLAGS.batch_size,
               50000,  # capacity
               5000,  # min_after_dequeue
               num_threads=1,
               enqueue_many=True)

        else:
          (source_images, source_class_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [source_images, source_class_labels, domain_selection_mask],
               FLAGS.batch_size,
               50000,  # capacity
               5000,  # min_after_dequeue
               num_threads=1,
               enqueue_many=True)
        source_labels = {}
        source_labels['classes'] = source_class_labels
        if 'quaternions' in source_labels:
          source_labels['quaternions'] = source_pose_labels

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme:
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')

      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('total_loss', tf.losses.get_total_loss())

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)
      # Run training.
      loss_tensor = slim.learning.create_train_op(
          slim.losses.get_total_loss(),
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=loss_tensor,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
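The replica_device_setter wrapper is what lets the same script run both locally and distributed: with FLAGS.ps_tasks equal to 0 it returns None and placement is unchanged, while with parameter servers configured it places the variables on the ps jobs and leaves the ops on the local worker. A minimal illustration of the placement behaviour (sketch, independent of the dsn flags):

# ps_tasks == 0: replica_device_setter returns None, so tf.device is a no-op.
# ps_tasks  > 0: variables go round-robin to /job:ps/task:i while ops stay on
# the local worker.
with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
  weights = tf.get_variable('w', shape=[10])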
Code example #8
def main(_):
    model_params = {
        'use_separation': FLAGS.use_separation,
        'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
        'layers_to_regularize': FLAGS.layers_to_regularize,
        'alpha_weight': FLAGS.alpha_weight,
        'beta_weight': FLAGS.beta_weight,
        'gamma_weight': FLAGS.gamma_weight,
        'pose_weight': FLAGS.pose_weight,
        'recon_loss_name': FLAGS.recon_loss_name,
        'decoder_name': FLAGS.decoder_name,
        'encoder_name': FLAGS.encoder_name,
        'weight_decay': FLAGS.weight_decay,
        'batch_size': FLAGS.batch_size,
        'use_logging': FLAGS.use_logging,
        'ps_tasks': FLAGS.ps_tasks,
        'task': FLAGS.task,
    }
    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Load the data.
            source_images, source_labels = provide_batch_fn()(
                FLAGS.source_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)
            target_images, target_labels = provide_batch_fn()(
                FLAGS.target_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)

            # In the unsupervised case all the samples in the labeled
            # domain are from the source domain.
            domain_selection_mask = tf.fill(
                (source_images.get_shape().as_list()[0], ), True)

            # When using the semi-supervised model we include labeled target data in
            # the source labeled data.
            if FLAGS.target_labeled_dataset != 'none':
                # 1000 is the maximum number of labeled target samples that exist in
                # the datasets.
                target_semi_images, target_semi_labels = provide_batch_fn()(
                    FLAGS.target_labeled_dataset, 'train', FLAGS.batch_size)

                # Calculate the proportion of source domain samples in the semi-
                # supervised setting, so that the proportion is set accordingly in the
                # batches.
                proportion = float(source_labels['num_train_samples']) / (
                    source_labels['num_train_samples'] +
                    target_semi_labels['num_train_samples'])

                rnd_tensor = tf.random_uniform(
                    (target_semi_images.get_shape().as_list()[0], ))

                domain_selection_mask = rnd_tensor < proportion
                source_images = tf.where(domain_selection_mask, source_images,
                                         target_semi_images)
                source_class_labels = tf.where(domain_selection_mask,
                                               source_labels['classes'],
                                               target_semi_labels['classes'])

                if 'quaternions' in source_labels:
                    source_pose_labels = tf.where(
                        domain_selection_mask, source_labels['quaternions'],
                        target_semi_labels['quaternions'])
                    (source_images, source_class_labels, source_pose_labels,
                     domain_selection_mask) = tf.train.shuffle_batch(
                         [
                             source_images, source_class_labels,
                             source_pose_labels, domain_selection_mask
                         ],
                         FLAGS.batch_size,
                         50000,  # capacity
                         5000,  # min_after_dequeue
                         num_threads=1,
                         enqueue_many=True)

                else:
                    (source_images, source_class_labels,
                     domain_selection_mask) = tf.train.shuffle_batch(
                         [
                             source_images, source_class_labels,
                             domain_selection_mask
                         ],
                         FLAGS.batch_size,
                         50000,  # capacity
                         5000,  # min_after_dequeue
                         num_threads=1,
                         enqueue_many=True)
                source_labels = {}
                source_labels['classes'] = source_class_labels
                if 'quaternions' in source_labels:
                    source_labels['quaternions'] = source_pose_labels

            slim.get_or_create_global_step()
            tf.summary.image('source_images', source_images, max_outputs=3)
            tf.summary.image('target_images', target_images, max_outputs=3)

            dsn.create_model(source_images,
                             source_labels,
                             domain_selection_mask,
                             target_images,
                             target_labels,
                             FLAGS.similarity_loss,
                             model_params,
                             basic_tower_name=FLAGS.basic_tower)

            # Configure the optimization scheme:
            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.get_or_create_global_step(),
                FLAGS.decay_steps,
                FLAGS.decay_rate,
                staircase=True,
                name='learning_rate')

            tf.summary.scalar('learning_rate', learning_rate)
            tf.summary.scalar('total_loss', tf.losses.get_total_loss())

            opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            tf.logging.set_verbosity(tf.logging.INFO)
            # Run training.
            loss_tensor = slim.learning.create_train_op(
                slim.losses.get_total_loss(),
                opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True)
            slim.learning.train(train_op=loss_tensor,
                                logdir=FLAGS.train_log_dir,
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.max_number_of_steps,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
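The test methods shown in the examples above live in a tf.test.TestCase subclass inside dsn_test.py. To run them standalone, the usual wrapper looks roughly like this; the class name and the import are assumptions, not taken from this page:

import tensorflow as tf

import dsn  # module under test, assumed to be on the path


class DSNTest(tf.test.TestCase):
  # Paste _testBuildDefaultModel and the test methods from the examples
  # above into this class.
  pass


if __name__ == '__main__':
  tf.test.main()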