def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary of label tensors for the source domain.
      'classes' holds the class labels passed to the softmax cross-entropy;
      an optional 'quaternions' entry enables the pose estimation loss.
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError(
          'Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      # Control dependencies only attach to ops created inside this context.
      # A plain `quaternion_loss = loss` creates no op, so the assertion would
      # never gate the loss. tf.identity creates one that depends on it.
      quaternion_loss = tf.identity(loss)
    tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
def create_model(source_images, source_labels, domain_selection_mask,
                 source_val_images, source_val_labels, target_images,
                 target_labels, similarity_loss, params, basic_tower_name):
  """Creates a DSN model.

  Always builds the source task loss. Unless `similarity_loss` is 'none', it
  also builds:
    * the source-validation and target towers, sharing the source tower's
      variables;
    * evaluation summaries;
    * the similarity loss on the shared representation;
    * optionally, the separation autoencoders.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes; an optional 'quaternions' entry enables
      the pose estimation loss.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which
      denotes the labeled images that belong to the source domain.
    source_val_images: validation images from the source domain, a tensor of
      size [batch_size, height, width, channels]. Used here only to build
      evaluation summaries.
    source_val_labels: a dictionary with the name, tensor pairs for
      `source_val_images`. 'classes' is one-hot.
    target_images: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_labels: a dictionary with the name, tensor pairs. 'classes' is
      one-hot; an optional 'quaternions' entry adds a pose summary. Used here
      only to build evaluation summaries.
    similarity_loss: The type of method to use for encouraging the codes from
      the shared encoder to be similar. 'none' disables domain adaptation.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation',
      'domain_separation_startpoint', 'alpha_weight', 'beta_weight',
      'gamma_weight', 'recon_loss_name', 'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Returns:
    None. The model is built as a side effect in the current graph.

  Raises:
    AttributeError: if `basic_tower_name` is not an architecture defined in
      `models`.
    RuntimeError: propagated from `add_task_loss` when 'quaternions' labels
      are given but the tower does not support pose estimation.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # reuse=True: share the variables created under 'towers' by the source
  # tower in add_task_loss instead of creating new ones.
  with tf.variable_scope('towers', reuse=True):
    source_val_logits, _ = network(
        source_val_images, weight_decay=params['weight_decay'],
        prefix='source_val')
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot source-validation and target accuracy (and AUC for binary problems).
  source_val_accuracy = utils.accuracy(
      tf.argmax(source_val_logits, 1),
      tf.argmax(source_val_labels['classes'], 1))
  tf.summary.scalar('eval/Source validation accuracy', source_val_accuracy)
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))
  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  if num_classes == 2:
    # tf.metrics.auc returns (auc, update_op). Summarizing update_op ([1])
    # means each summary evaluation also accumulates the streaming AUC.
    score_val = tf.nn.softmax(source_val_logits)[:, 1]
    source_val_auc = tf.metrics.auc(
        tf.argmax(source_val_labels['classes'], 1), score_val)
    tf.summary.scalar('eval/Source validation AUC', source_val_auc[1])

    score = tf.nn.softmax(target_logits)[:, 1]
    target_auc = tf.metrics.auc(tf.argmax(target_labels['classes'], 1), score)
    tf.summary.scalar('eval/Target AUC', target_auc[1])

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain samples
  # when we use the similarity loss.
  # Use the runtime batch size: get_shape() yields None when the static batch
  # dimension is unknown, which would make tf.range fail at graph build time.
  indices = tf.range(0, tf.shape(source_shared)[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)