Beispiel #1
0
def add_task_loss(source_images, source_labels, basic_tower, params):
    """Adds a classification and/or pose estimation loss to the model.

    Args:
      source_images: images from the source domain, a tensor of size
        [batch_size, height, width, channels].
      source_labels: a dictionary of label tensors for the source domain. It
        must contain 'classes' (used as the one-hot labels of
        `tf.losses.softmax_cross_entropy`) and may contain 'quaternions', in
        which case a pose estimation loss is added as well.
      basic_tower: a function that creates the single tower of the model. It
        is called as `basic_tower(images, weight_decay=..., prefix=...)` and
        must return a `(logits, endpoints)` pair where `endpoints` is a
        dictionary.
      params: A dictionary of parameters. 'weight_decay' is always read;
        'pose_weight' is read when 'quaternions' labels are present, and the
        whole dictionary is forwarded to `losses.log_quaternion_loss`.

    Returns:
      The source endpoints.

    Raises:
      RuntimeError: if basic tower does not support pose estimation.
    """
    with tf.variable_scope('towers'):
        source_logits, source_endpoints = basic_tower(
            source_images,
            weight_decay=params['weight_decay'],
            prefix='Source')

    if 'quaternions' in source_labels:  # We have pose estimation as well
        if 'quaternion_pred' not in source_endpoints:
            raise RuntimeError(
                'Please use a model for estimation e.g. pose_mini')

        loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                          source_endpoints['quaternion_pred'],
                                          params)

        # reduce_all keeps the Assert condition scalar even if the loss is not
        # (it is a no-op for a scalar loss).
        assert_op = tf.Assert(tf.reduce_all(tf.is_finite(loss)), [loss])
        with tf.control_dependencies([assert_op]):
            # Control dependencies only attach to ops *created* inside this
            # context. tf.identity creates such an op, so everything that
            # consumes quaternion_loss -- including the training loss added
            # below -- now forces the finiteness check. A plain alias
            # (`quaternion_loss = loss`) would gate only the histogram summary.
            quaternion_loss = tf.identity(loss)
            tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
        slim.losses.add_loss(quaternion_loss * params['pose_weight'])
        tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

    # tf.losses.* registers its result in the default loss collection, so the
    # returned tensor is only needed here for the summary.
    classification_loss = tf.losses.softmax_cross_entropy(
        source_labels['classes'], source_logits)

    tf.summary.scalar('losses/classification_loss', classification_loss)
    return source_endpoints
Beispiel #2
0
def create_model(source_images, source_labels, domain_selection_mask,
                 source_val_images, source_val_labels,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    source_val_images: validation images from the source domain. They are run
      through the shared tower (reusing its variables) and only feed the
      evaluation summaries.
    source_val_labels: a dictionary with the name, tensor pairs for the
      validation images. 'classes' is one-hot.
    target_images: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_labels: a dictionary with the name, tensor pairs. 'classes' is
      one-hot; 'quaternions' is optional.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar. 'none' disables domain
      adaptation: only the source task loss is added.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder. It
      is looked up as an attribute of `models`.

  Returns:
    None. Losses and summaries are added to the default graph.

  Raises:
    AttributeError: if `basic_tower_name` is not an attribute of `models`.
    RuntimeError: propagated from `add_task_loss` when the source labels
      contain 'quaternions' but the tower has no 'quaternion_pred' endpoint.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # reuse=True: these towers share the variables created by add_task_loss.
  with tf.variable_scope('towers', reuse=True):
    source_val_logits, _ = network(
        source_val_images, weight_decay=params['weight_decay'],
        prefix='source_val')
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot source-validation and target accuracy.
  source_val_classes = tf.argmax(source_val_labels['classes'], 1)
  target_classes = tf.argmax(target_labels['classes'], 1)

  source_val_accuracy = utils.accuracy(
      tf.argmax(source_val_logits, 1), source_val_classes)
  tf.summary.scalar('eval/Source validation accuracy', source_val_accuracy)
  target_accuracy = utils.accuracy(tf.argmax(target_logits, 1), target_classes)
  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  if num_classes == 2:
    # Column 1 of the softmax is the score of the positive class. tf.metrics.auc
    # returns (value, update_op); index 1 logs the update op, so each summary
    # evaluation also accumulates the streaming statistics.
    score_val = tf.nn.softmax(source_val_logits)[:, 1]
    source_val_auc = tf.metrics.auc(source_val_classes, score_val)
    tf.summary.scalar('eval/Source validation AUC', source_val_auc[1])
    score = tf.nn.softmax(target_logits)[:, 1]
    target_auc = tf.metrics.auc(target_classes, score)
    tf.summary.scalar('eval/Target AUC', target_auc[1])

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  # The batch size is read dynamically: with an unknown static batch dimension
  # get_shape().as_list()[0] is None, and tf.range(0, None) degenerates to
  # tf.range(0) -- an empty range -- so the similarity loss would silently see
  # no examples.
  indices = tf.range(0, tf.shape(source_shared)[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params)