def testBuildModelMmd(self):
    """Build the model with the 'mmd_loss' similarity loss; expect 6 losses."""
    images, labels, params = self._testBuildDefaultModel()

    with self.test_session():
        # Every one of the 32 entries of the domain selection mask is True.
        domain_mask = tf.cast(tf.ones([32]), tf.bool)
        dsn.create_model(
            images,
            labels,
            domain_mask,
            images,
            labels,
            'mmd_loss',
            params,
            'dann_mnist')
        registered_losses = tf.contrib.losses.get_losses()

    self.assertEqual(len(registered_losses), 6)
def testBuildModelDannMultiPSTasks(self):
    """Build the 'dann_loss' model with ps_tasks set to 10; expect 6 losses."""
    images, labels, params = self._testBuildDefaultModel()
    params['ps_tasks'] = 10

    with self.test_session():
        # Every one of the 32 entries of the domain selection mask is True.
        domain_mask = tf.cast(tf.ones([32]), tf.bool)
        dsn.create_model(
            images,
            labels,
            domain_mask,
            images,
            labels,
            'dann_loss',
            params,
            'dann_mnist')
        registered_losses = tf.contrib.losses.get_losses()

    self.assertEqual(len(registered_losses), 6)
def testBuildModelNoSeparation(self):
    """Build the 'dann_loss' model with use_separation off; expect 2 losses."""
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False

    with self.test_session():
        # Every one of the 32 entries of the domain selection mask is True.
        domain_mask = tf.cast(tf.ones([32]), tf.bool)
        dsn.create_model(
            images,
            labels,
            domain_mask,
            images,
            labels,
            'dann_loss',
            params,
            'dann_mnist')
        registered_losses = tf.contrib.losses.get_losses()

    self.assertEqual(len(registered_losses), 2)
def main(_):
    """Assemble the training graph from FLAGS and run the training loop.

    Loads a source and a target batch, builds the model through
    dsn.create_model, and optimizes the total loss with Adam under a
    staircase exponentially-decayed learning rate.

    Args:
      _: Ignored.
    """
    # Flags that are forwarded verbatim to the model under the same key.
    forwarded_flags = (
        'layers_to_regularize',
        'alpha_weight',
        'beta_weight',
        'gamma_weight',
        'recon_loss_name',
        'decoder_name',
        'encoder_name',
        'weight_decay',
        'batch_size',
    )
    model_params = {name: getattr(FLAGS, name) for name in forwarded_flags}

    graph = tf.Graph()
    with graph.as_default():
        source_images, source_labels = get_source(FLAGS.batch_size)
        target_images, target_labels = get_target(FLAGS.batch_size)

        slim.get_or_create_global_step()
        dsn.create_model(
            source_images,
            source_labels,
            target_images,
            target_labels,
            FLAGS.similarity_loss,
            model_params,
            basic_tower_name=FLAGS.basic_tower)

        # Optimization scheme.
        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            slim.get_or_create_global_step(),
            FLAGS.decay_steps,
            FLAGS.decay_rate,
            staircase=True,
            name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('total_loss', tf.losses.get_total_loss())

        # Alternative optimizer:
        #   tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)

        tf.logging.set_verbosity(tf.logging.INFO)

        # Run training.
        total_loss = slim.losses.get_total_loss()
        train_op = slim.learning.create_train_op(
            total_loss,
            optimizer,
            summarize_gradients=True,
            colocate_gradients_with_ops=True)

        slim.learning.train(
            train_op=train_op,
            logdir=FLAGS.train_log_dir,
            master="",
            is_chief=1,
            number_of_steps=FLAGS.max_number_of_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
def testBuildModelNoDomainAdaptation(self):
    """Build with similarity loss 'none' and use_separation off.

    Expects exactly one registered loss and no regularization losses.
    """
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False

    with self.test_session():
        # Every one of the 32 entries of the domain selection mask is True.
        domain_mask = tf.cast(tf.ones([32]), tf.bool)
        dsn.create_model(
            images,
            labels,
            domain_mask,
            images,
            labels,
            'none',
            params,
            'dann_mnist')
        registered_losses = tf.contrib.losses.get_losses()
        self.assertEqual(len(registered_losses), 1)

        regularizers = tf.contrib.losses.get_regularization_losses()
        self.assertEqual(len(regularizers), 0)
def testBuildModelNoAdaptationWeightDecay(self):
    """Build with similarity loss 'none', no separation, weight_decay=1e-5.

    Expects exactly one registered loss and at least one regularization loss.
    """
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    params['weight_decay'] = 1e-5

    with self.test_session():
        # Every one of the 32 entries of the domain selection mask is True.
        domain_mask = tf.cast(tf.ones([32]), tf.bool)
        dsn.create_model(
            images,
            labels,
            domain_mask,
            images,
            labels,
            'none',
            params,
            'dann_mnist')
        registered_losses = tf.contrib.losses.get_losses()
        self.assertEqual(len(registered_losses), 1)

        regularizers = tf.contrib.losses.get_regularization_losses()
        self.assertTrue(len(regularizers) >= 1)
def main(_):
    """Build the training graph from FLAGS and run distributed training.

    Steps, all driven by FLAGS:
      1. Load a source batch and a target batch.
      2. If FLAGS.target_labeled_dataset is not 'none' (semi-supervised
         mode), randomly replace source samples with labelled target samples
         and reshuffle; the domain selection mask records which samples came
         from the source domain.
      3. Build the model through dsn.create_model.
      4. Train with a momentum optimizer under a staircase exponentially
         decayed learning rate.

    Args:
      _: Unused.
    """
    model_params = {
        'use_separation': FLAGS.use_separation,
        'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
        'layers_to_regularize': FLAGS.layers_to_regularize,
        'alpha_weight': FLAGS.alpha_weight,
        'beta_weight': FLAGS.beta_weight,
        'gamma_weight': FLAGS.gamma_weight,
        'pose_weight': FLAGS.pose_weight,
        'recon_loss_name': FLAGS.recon_loss_name,
        'decoder_name': FLAGS.decoder_name,
        'encoder_name': FLAGS.encoder_name,
        'weight_decay': FLAGS.weight_decay,
        'batch_size': FLAGS.batch_size,
        'use_logging': FLAGS.use_logging,
        'ps_tasks': FLAGS.ps_tasks,
        'task': FLAGS.task,
    }

    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Load the data.
            source_images, source_labels = provide_batch_fn()(
                FLAGS.source_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)
            target_images, target_labels = provide_batch_fn()(
                FLAGS.target_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)

            # In the unsupervised case all the samples in the labeled
            # domain are from the source domain.
            domain_selection_mask = tf.fill(
                (source_images.get_shape().as_list()[0],), True)

            # When using the semisupervised model we include labeled target
            # data in the source labelled data.
            if FLAGS.target_labeled_dataset != 'none':
                # 1000 is the maximum number of labelled target samples that
                # exists in the datasets.
                #
                # Called with the same argument list as the two loaders
                # above. The previous three-argument call placed
                # FLAGS.batch_size in the position where the sibling calls
                # pass FLAGS.dataset_dir.
                target_semi_images, target_semi_labels = provide_batch_fn()(
                    FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
                    FLAGS.num_readers, FLAGS.batch_size,
                    FLAGS.num_preprocessing_threads)

                # Calculate the proportion of source domain samples in the
                # semi-supervised setting, so that the proportion is set
                # accordingly in the batches.
                proportion = float(source_labels['num_train_samples']) / (
                    source_labels['num_train_samples'] +
                    target_semi_labels['num_train_samples'])

                rnd_tensor = tf.random_uniform(
                    (target_semi_images.get_shape().as_list()[0],))

                # True -> keep the source sample, False -> take the labelled
                # target sample at the same batch position.
                domain_selection_mask = rnd_tensor < proportion
                source_images = tf.where(domain_selection_mask, source_images,
                                         target_semi_images)
                source_class_labels = tf.where(domain_selection_mask,
                                               source_labels['classes'],
                                               target_semi_labels['classes'])

                # Remember whether pose labels exist BEFORE source_labels is
                # rebuilt below. Testing the freshly rebuilt dict (as the code
                # used to) is always False and silently drops the pose labels.
                has_pose_labels = 'quaternions' in source_labels

                if has_pose_labels:
                    source_pose_labels = tf.where(
                        domain_selection_mask,
                        source_labels['quaternions'],
                        target_semi_labels['quaternions'])
                    # Positional 50000 / 5000 are shuffle_batch's capacity and
                    # min_after_dequeue.
                    (source_images, source_class_labels, source_pose_labels,
                     domain_selection_mask) = tf.train.shuffle_batch(
                         [
                             source_images, source_class_labels,
                             source_pose_labels, domain_selection_mask
                         ],
                         FLAGS.batch_size,
                         50000,
                         5000,
                         num_threads=1,
                         enqueue_many=True)
                else:
                    (source_images, source_class_labels,
                     domain_selection_mask) = tf.train.shuffle_batch(
                         [
                             source_images, source_class_labels,
                             domain_selection_mask
                         ],
                         FLAGS.batch_size,
                         50000,
                         5000,
                         num_threads=1,
                         enqueue_many=True)

                # Rebuild the label dict from the mixed, reshuffled tensors.
                source_labels = {'classes': source_class_labels}
                if has_pose_labels:
                    source_labels['quaternions'] = source_pose_labels

            slim.get_or_create_global_step()
            tf.summary.image('source_images', source_images, max_outputs=3)
            tf.summary.image('target_images', target_images, max_outputs=3)

            dsn.create_model(
                source_images,
                source_labels,
                domain_selection_mask,
                target_images,
                target_labels,
                FLAGS.similarity_loss,
                model_params,
                basic_tower_name=FLAGS.basic_tower)

            # Configure the optimization scheme:
            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.get_or_create_global_step(),
                FLAGS.decay_steps,
                FLAGS.decay_rate,
                staircase=True,
                name='learning_rate')

            tf.summary.scalar('learning_rate', learning_rate)
            tf.summary.scalar('total_loss', tf.losses.get_total_loss())

            opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            tf.logging.set_verbosity(tf.logging.INFO)

            # Run training.
            train_op = slim.learning.create_train_op(
                slim.losses.get_total_loss(),
                opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True)

            slim.learning.train(
                train_op=train_op,
                logdir=FLAGS.train_log_dir,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                number_of_steps=FLAGS.max_number_of_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)
def main(_):
    """Build the training graph from FLAGS and run distributed training.

    Loads source and target batches, optionally mixes labelled target
    samples into the source batch (semi-supervised mode, active when
    FLAGS.target_labeled_dataset is not 'none'), builds the model through
    dsn.create_model, and trains it with a momentum optimizer under a
    staircase exponentially decayed learning rate.

    Args:
      _: Unused.
    """
    model_params = {
        'use_separation': FLAGS.use_separation,
        'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
        'layers_to_regularize': FLAGS.layers_to_regularize,
        'alpha_weight': FLAGS.alpha_weight,
        'beta_weight': FLAGS.beta_weight,
        'gamma_weight': FLAGS.gamma_weight,
        'pose_weight': FLAGS.pose_weight,
        'recon_loss_name': FLAGS.recon_loss_name,
        'decoder_name': FLAGS.decoder_name,
        'encoder_name': FLAGS.encoder_name,
        'weight_decay': FLAGS.weight_decay,
        'batch_size': FLAGS.batch_size,
        'use_logging': FLAGS.use_logging,
        'ps_tasks': FLAGS.ps_tasks,
        'task': FLAGS.task,
    }

    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Load the data.
            source_images, source_labels = provide_batch_fn()(
                FLAGS.source_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)
            target_images, target_labels = provide_batch_fn()(
                FLAGS.target_dataset, 'train', FLAGS.dataset_dir,
                FLAGS.num_readers, FLAGS.batch_size,
                FLAGS.num_preprocessing_threads)

            # In the unsupervised case all the samples in the labeled
            # domain are from the source domain.
            domain_selection_mask = tf.fill(
                (source_images.get_shape().as_list()[0],), True)

            # When using the semisupervised model we include labeled target
            # data in the source labelled data.
            if FLAGS.target_labeled_dataset != 'none':
                # 1000 is the maximum number of labelled target samples that
                # exists in the datasets.
                #
                # Uses the same argument list as the two loaders above; the
                # earlier three-argument form put FLAGS.batch_size in the
                # slot where the sibling calls pass FLAGS.dataset_dir.
                target_semi_images, target_semi_labels = provide_batch_fn()(
                    FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
                    FLAGS.num_readers, FLAGS.batch_size,
                    FLAGS.num_preprocessing_threads)

                # Calculate the proportion of source domain samples in the
                # semi-supervised setting, so that the proportion is set
                # accordingly in the batches.
                num_source = source_labels['num_train_samples']
                num_target_semi = target_semi_labels['num_train_samples']
                proportion = float(num_source) / (num_source + num_target_semi)

                rnd_tensor = tf.random_uniform(
                    (target_semi_images.get_shape().as_list()[0],))

                # Where the mask is True the source sample is kept, otherwise
                # the labelled target sample at that position is used.
                domain_selection_mask = rnd_tensor < proportion
                source_images = tf.where(domain_selection_mask, source_images,
                                         target_semi_images)
                source_class_labels = tf.where(domain_selection_mask,
                                               source_labels['classes'],
                                               target_semi_labels['classes'])

                # Decide once, from the ORIGINAL label dict, whether pose
                # labels are present. The old code re-tested membership after
                # replacing source_labels with a fresh dict, so the test was
                # always False and the pose labels never reached the model.
                has_pose_labels = 'quaternions' in source_labels

                # Tensors are shuffled together so rows stay aligned; order is
                # images, classes, [poses], mask.
                to_shuffle = [source_images, source_class_labels]
                if has_pose_labels:
                    to_shuffle.append(
                        tf.where(domain_selection_mask,
                                 source_labels['quaternions'],
                                 target_semi_labels['quaternions']))
                to_shuffle.append(domain_selection_mask)

                # Positional 50000 / 5000 are shuffle_batch's capacity and
                # min_after_dequeue.
                shuffled = tf.train.shuffle_batch(
                    to_shuffle,
                    FLAGS.batch_size,
                    50000,
                    5000,
                    num_threads=1,
                    enqueue_many=True)

                source_images = shuffled[0]
                domain_selection_mask = shuffled[-1]

                # Rebuild the label dict from the mixed, reshuffled tensors.
                source_labels = {'classes': shuffled[1]}
                if has_pose_labels:
                    source_labels['quaternions'] = shuffled[2]

            slim.get_or_create_global_step()
            tf.summary.image('source_images', source_images, max_outputs=3)
            tf.summary.image('target_images', target_images, max_outputs=3)

            dsn.create_model(
                source_images,
                source_labels,
                domain_selection_mask,
                target_images,
                target_labels,
                FLAGS.similarity_loss,
                model_params,
                basic_tower_name=FLAGS.basic_tower)

            # Configure the optimization scheme:
            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.get_or_create_global_step(),
                FLAGS.decay_steps,
                FLAGS.decay_rate,
                staircase=True,
                name='learning_rate')

            tf.summary.scalar('learning_rate', learning_rate)
            tf.summary.scalar('total_loss', tf.losses.get_total_loss())

            opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            tf.logging.set_verbosity(tf.logging.INFO)

            # Run training.
            train_op = slim.learning.create_train_op(
                slim.losses.get_total_loss(),
                opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True)

            slim.learning.train(
                train_op=train_op,
                logdir=FLAGS.train_log_dir,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                number_of_steps=FLAGS.max_number_of_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)