Example #1
def _calculate_vmf_loss(embedding,
                        semantic_labels,
                        unique_instance_labels,
                        prototype_labels,
                        concentration,
                        memory=None,
                        memory_labels=None,
                        split=1):
    """Calculates the von-Mises Fisher loss given semantic and cluster labels.

    Args:
      embedding: A 2-D float tensor with shape `[num_pixels, embedding_dim]`.
      semantic_labels: A 1-D integer tensor with length `[num_pixels]`. It
        contains the semantic label for each pixel.
      unique_instance_labels: A 1-D integer tensor with length `[num_pixels]`.
        It contains the unique instance label for each pixel.
      prototype_labels: A 1-D integer tensor with length `[num_prototypes]`. It
        contains the semantic label for each prototype.
      concentration: A float that controls the sharpness of cosine similarities.
      memory: A 2-D float tensor for memory prototypes with shape
        `[num_prototypes, embedding_dim]`.
      memory_labels: A 1-D integer tensor for labels of memory prototypes with
        length `[num_prototypes]`.
      split: An integer for the number of splits of the matrix multiplication.

    Returns:
      loss: A float for the von-Mises Fisher loss.
      new_memory: A 2-D float tensor for the memory prototypes to update with
        shape `[num_prototypes, embedding_dim]`.
      new_memory_labels: A 1-D integer tensor for labels of memory prototypes to
        update with length `[num_prototypes]`.
    """
    prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, unique_instance_labels, tf.size(prototype_labels))

    if memory is not None:
        memory = common_utils.normalize_embedding(memory)
        rand_index = tf.random_shuffle(tf.range(tf.shape(prototype_labels)[0]))
        new_memory = tf.squeeze(tf.gather(prototypes, rand_index))
        new_memory_labels = tf.squeeze(tf.gather(prototype_labels, rand_index))
        prototypes = tf.concat([prototypes, memory], 0)
        prototype_labels = tf.concat([prototype_labels, memory_labels], 0)
    else:
        new_memory = new_memory_labels = None

    similarities = _calculate_similarities(embedding, prototypes,
                                           concentration, split)

    log_likelihood = _calculate_log_likelihood(similarities,
                                               unique_instance_labels,
                                               semantic_labels,
                                               prototype_labels)

    loss = tf.reduce_mean(log_likelihood)
    return loss, new_memory, new_memory_labels
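For intuition, the quantity this loss drives can be reproduced in plain NumPy: each L2-normalized pixel embedding should have a high concentration-scaled cosine similarity to its own cluster prototype relative to all prototypes. The snippet below is an illustrative sketch only, with made-up shapes and values; it is not part of the repository.

# Hypothetical NumPy sketch of the per-pixel vMF probability that the loss
# above maximizes; shapes, values, and variable names are made up.
import numpy as np

num_pixels, num_prototypes, dim = 4, 3, 8
concentration = 10.0

rng = np.random.RandomState(0)
embedding = rng.randn(num_pixels, dim)
prototypes = rng.randn(num_prototypes, dim)

# L2-normalize embeddings and prototypes onto the unit hypersphere.
embedding /= np.linalg.norm(embedding, axis=1, keepdims=True)
prototypes /= np.linalg.norm(prototypes, axis=1, keepdims=True)

# Concentration-scaled cosine similarities, then a softmax over prototypes.
logits = concentration * embedding.dot(prototypes.T)
probabilities = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# For a pixel assigned to prototype 0, the loss contribution is the negative
# log probability of that assignment.
print(-np.log(probabilities[0, 0]))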
Example #2
def add_unsupervised_segsort_loss(embedding,
                                  concentration,
                                  cluster_labels,
                                  num_banks=0,
                                  loss_scope=None):
    """Adds the unsupervised SegSort loss.

    For every pixel, computes its normalized similarity to its own cluster
    prototype relative to all prototypes, and returns the mean negative log of
    that ratio.

    Args:
      embedding: A 4-D float tensor with shape
        `[batch, height, width, embedding_dim]`.
      concentration: A float that controls the sharpness of cosine similarities.
      cluster_labels: A 4-D integer tensor of per-image cluster labels.
      num_banks: An integer for the number of memory banks.
      loss_scope: An optional name scope for the loss.

    Returns:
      A scalar float tensor for the unsupervised SegSort loss.
    """
    with tf.name_scope(loss_scope, 'unsupervised_segsort_loss',
                       (embedding, concentration, cluster_labels, num_banks)):

        # Normalize embedding.
        embedding = common_utils.normalize_embedding(embedding)
        shape = embedding.get_shape().as_list()
        batch_size = shape[0]
        embedding_dim = shape[3]

        # Add offset to cluster labels.
        max_clusters = 256
        offset = tf.range(0, max_clusters * batch_size, max_clusters)
        cluster_labels += tf.reshape(offset, [-1, 1, 1, 1])
        _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))

        # Calculate prototypes.
        embedding = tf.reshape(embedding, [-1, embedding_dim])
        prototypes = common_utils.calculate_prototypes_from_labels(
            embedding, cluster_labels)

        similarities = _calculate_similarities(embedding, prototypes,
                                               concentration, batch_size)

        # Calculate the unsupervised loss.
        self_indices = tf.concat(
            [tf.expand_dims(tf.range(tf.shape(similarities)[0]), 1),
             tf.expand_dims(cluster_labels, 1)],
            axis=1)
        numerator = tf.reshape(tf.gather_nd(similarities, self_indices), [-1])
        denominator = tf.reduce_sum(similarities, axis=1)

        probabilities = tf.divide(numerator, denominator)
        return tf.reduce_mean(-tf.log(probabilities))
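The offset added to cluster_labels above makes the per-image cluster ids unique across the batch before tf.unique compacts them into consecutive ids. A small NumPy analogue of that relabeling (illustrative only, with made-up values):

# Hypothetical NumPy analogue of the cluster-label offset trick used in
# add_unsupervised_segsort_loss; numbers are made up.
import numpy as np

max_clusters = 256
# Two images in a batch, each with cluster ids starting at 0.
cluster_labels = np.array([[0, 1, 1, 2],   # image 0
                           [0, 0, 1, 2]])  # image 1
offset = np.arange(0, max_clusters * cluster_labels.shape[0], max_clusters)
shifted = cluster_labels + offset[:, None]  # image 1 ids become 256, 257, 258
# Compact to consecutive ids, mirroring tf.unique(...)[1] on the flat labels.
_, compact = np.unique(shifted.reshape(-1), return_inverse=True)
print(compact)  # [0 1 1 2 3 3 4 5]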
Example #3
def main():
    """Creates the model and start the inference process."""
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    h, w = input_size
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image[:reader.height, :reader.width]
        image_list = reader.image_list
    image_batch = tf.expand_dims(image, dim=0)

    # Create multi-scale augmented data.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented data.
    if args.flip_aug:
        for i in range(len(scales)):
            img = rescale_image_batches[i]
            is_flip = is_flipped[i]
            img = tf.squeeze(img, axis=0)
            flip_img = tf.image.flip_left_right(img)
            flip_img = tf.expand_dims(flip_img, axis=0)
            rescale_image_batches.append(flip_img)
            is_flipped.append(True)

    # Create input tensor to the Network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features for k-means clustering.
    shape = embedding.get_shape().as_list()
    loc_feature = tf.expand_dims(
        common_utils.generate_location_features([shape[1], shape[2]], 'float'),
        0)
    embedding_with_location = tf.concat([embedding, loc_feature], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Perform Kmeans clustering and extract prototypes.
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
    similarities = tf.matmul(test_prototypes,
                             prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)
    k_predictions = tf.gather(prototype_labels, k_predictions)
    k_predictions = tf.gather(k_predictions, cluster_labels)
    k_predictions = tf.reshape(
        k_predictions,
        [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors])

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))
    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    pbar = tqdm(range(num_steps))
    for step in pbar:
        rescale_img_batches = sess.run(rescale_image_batches)
        # Final segmentation results (average across multiple scales).
        scale_ind = 2 if args.scale_aug else 0
        final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
        final_lab_size[-1] = args.num_classes
        final_lab_batch = np.zeros(final_lab_size)

        # Iterate over multiple scales.
        for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
            img_size = img_batch.shape
            padimg_size = list(img_size)  # copy so img_size stays unchanged

            padimg_h, padimg_w = padimg_size[1:3]
            input_h, input_w = input_size

            if input_h > padimg_h:
                padimg_h = input_h
            if input_w > padimg_w:
                padimg_w = input_w
            # Update padded image size.
            padimg_size[1] = padimg_h
            padimg_size[2] = padimg_w
            padimg_batch = np.zeros(padimg_size, dtype=np.float32)
            img_h, img_w = img_size[1:3]
            padimg_batch[:, :img_h, :img_w, :] = img_batch

            stride_h, stride_w = strides
            npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
            npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1

            # Create padded prediction array.
            pred_size = list(padimg_size)
            pred_size[-1] = args.num_classes
            predictions_np = np.zeros(pred_size, dtype=np.int32)

            # Create the ending index of each patch.
            patch_indh = np.linspace(input_h,
                                     padimg_h,
                                     npatches_h,
                                     dtype=np.int32)
            patch_indw = np.linspace(input_w,
                                     padimg_w,
                                     npatches_w,
                                     dtype=np.int32)

            pred_size[-1] = args.embedding_dim
            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start & end ind of H
                    sw, ew = indw - input_w, indw  # start & end ind of W
                    cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                    feed_dict[crop_image_batch] = cropimg_batch

                    k_predictions_np = sess.run(k_predictions,
                                                feed_dict=feed_dict)
                    # Sum up KNN votes.
                    # This is the speed bottleneck for multi-scale inference;
                    # use single-scale inference for fast results.
                    # TODO: Either compute on GPU or change the implementation.
                    for c in range(args.num_classes):
                        predictions_np[:, sh:eh, sw:ew, c] += np.sum(
                            (k_predictions_np == c).astype(np.int), axis=3)

            predictions_np = predictions_np[0, :img_h, :img_w, :]
            lab_batch = predictions_np.astype(np.float32)
            # Rescale prediction back to original resolution.
            lab_batch = cv2.resize(lab_batch,
                                   (final_lab_size[1], final_lab_size[0]),
                                   interpolation=cv2.INTER_LINEAR)
            if is_flip:
                # Flip prediction back to original orientation.
                lab_batch = lab_batch[:, ::-1, :]
            final_lab_batch += lab_batch

        final_lab_ind = np.argmax(final_lab_batch, axis=-1)
        final_lab_ind = final_lab_ind.astype(np.uint8)

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(final_lab_ind, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[final_lab_ind]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
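Two pieces of the inference loop above are easy to sanity-check in isolation: the np.linspace call that produces the ending index of each sliding-window patch, and the per-class accumulation of k-nearest-neighbor votes. Below is a standalone sketch with made-up sizes; it is not part of the script.

# Hypothetical standalone check of the sliding-window patch indexing and the
# KNN vote accumulation used in main() above; all sizes are made up.
import math
import numpy as np

input_h, input_w = 321, 321      # network crop size
padimg_h, padimg_w = 500, 700    # padded image size
stride_h, stride_w = 161, 161

npatches_h = int(math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
npatches_w = int(math.ceil(1.0 * (padimg_w - input_w) / stride_w)) + 1

# Ending index of each patch; the last patch always ends at the padded border,
# so every pixel is covered by at least one crop.
patch_indh = np.linspace(input_h, padimg_h, npatches_h, dtype=np.int32)
patch_indw = np.linspace(input_w, padimg_w, npatches_w, dtype=np.int32)
print(patch_indh)  # [321 410 500]
print(patch_indw)  # [321 447 573 700]

# Per-class vote accumulation from k nearest-neighbor predictions of one crop.
num_classes, k = 21, 15
k_predictions_np = np.random.randint(0, num_classes,
                                     size=(1, input_h, input_w, k))
votes = np.zeros((input_h, input_w, num_classes), dtype=np.int32)
for c in range(num_classes):
    votes[:, :, c] = np.sum(k_predictions_np[0] == c, axis=-1)
assert votes.sum() == input_h * input_w * k  # every vote lands in some class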
Example #4
def extract_trained_prototypes(embedding, location_features, cluster_labels,
                               num_clusters, kmeans_iterations,
                               panoptic_labels, panoptic_label_divisor,
                               ignore_label, evaluate_semantic_or_panoptic):
    """Extracts the trained prototypes in an image.

    Args:
      embedding: A 2-D float tensor with shape `[pixels, embedding_dim]`.
      location_features: A 2-D float tensor for location features with shape
        `[pixels, 2]`.
      cluster_labels: A 1-D integer tensor for cluster labels for all pixels.
      num_clusters: An integer scalar for the total number of clusters.
      kmeans_iterations: Number of iterations for the k-means clustering.
      panoptic_labels: A 1-D integer tensor for panoptic labels for all pixels.
      panoptic_label_divisor: An integer constant to separate semantic and
        instance labels from panoptic labels.
      ignore_label: The semantic label to ignore.
      evaluate_semantic_or_panoptic: A string, either 'semantic' or 'panoptic',
        that specifies which type of segmentation to evaluate.

    Returns:
      prototype_features: A 2-D float tensor for prototype features with shape
        `[num_prototypes, embedding_dim]`.
      prototype_semantic_labels: A 1-D integer tensor for prototype semantic
        labels.
      prototype_instance_labels: A 1-D integer tensor for prototype instance
        labels.
    """
    # Collect pixels of valid semantic classes.
    valid_pixels = tf.where(
        tf.not_equal(panoptic_labels // panoptic_label_divisor, ignore_label))
    panoptic_labels = tf.squeeze(tf.gather(panoptic_labels, valid_pixels),
                                 axis=1)
    cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels),
                                axis=1)
    embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
    location_features = tf.squeeze(tf.gather(location_features, valid_pixels),
                                   axis=1)

    # Generate cluster labels via kmeans clustering.
    embedding_with_location = tf.concat([embedding, location_features], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_labels = common_utils.kmeans_with_initial_labels(
        embedding_with_location, cluster_labels, num_clusters,
        kmeans_iterations)
    _, cluster_labels = tf.unique(cluster_labels)

    if evaluate_semantic_or_panoptic == 'panoptic':
        # Calculate semantic and unique instance labels for all pixels.
        label_mapping, unique_panoptic_labels = tf.unique(panoptic_labels)

        # Find pixels of majority classes.
        select_pixels, majority_labels = find_majority_label_index(
            unique_panoptic_labels, cluster_labels)
    else:
        # Find pixels of majority semantic classes.
        semantic_labels = panoptic_labels // panoptic_label_divisor
        select_pixels, majority_labels = find_majority_label_index(
            semantic_labels, cluster_labels)

    cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels),
                                axis=1)
    embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)

    # Calculate the majority semantic and instance label for each prototype.
    if evaluate_semantic_or_panoptic == 'panoptic':
        prototype_panoptic_labels = tf.gather(label_mapping, majority_labels)
        prototype_semantic_labels = (prototype_panoptic_labels //
                                     panoptic_label_divisor)
        prototype_instance_labels = majority_labels
    else:
        prototype_semantic_labels = majority_labels
        prototype_instance_labels = tf.zeros_like(majority_labels)

    # Calculate the prototype features.
    prototype_features = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    return (prototype_features, prototype_semantic_labels,
            prototype_instance_labels)
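find_majority_label_index is not shown here; based on how it is used above and the surrounding comments, it appears to return, for each cluster, the indices of pixels whose label matches the cluster's most frequent label, along with that majority label. The following NumPy sketch illustrates that idea under this assumption; it is not the repository implementation.

# Hypothetical sketch of a per-cluster majority vote; an illustration of the
# idea only, not the repository's find_majority_label_index.
import numpy as np

cluster_labels = np.array([0, 0, 0, 1, 1, 2])
semantic_labels = np.array([5, 5, 7, 3, 3, 3])

majority_labels = []
select_pixels = []
for c in np.unique(cluster_labels):
    members = np.where(cluster_labels == c)[0]
    values, counts = np.unique(semantic_labels[members], return_counts=True)
    majority = values[np.argmax(counts)]
    majority_labels.append(majority)
    # Keep only pixels whose label agrees with the cluster's majority label.
    select_pixels.extend(members[semantic_labels[members] == majority])

print(np.array(majority_labels))  # [5 3 3]
print(np.array(select_pixels))    # [0 1 3 4 5]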
Example #5
def predict_all_labels(embedding, num_clusters, kmeans_iterations,
                       prototype_features, prototype_semantic_labels,
                       prototype_instance_labels, k_in_nearest_neighbors,
                       panoptic_label_divisor, class_has_instances_list):
    """Predicts panoptic, semantic, and instance labels using the vMF embedding.

    Args:
      embedding: A 4-D float tensor with shape
        `[batch, height, width, embedding_dim]`.
      num_clusters: A list of 2 integers for the number of clusters in the y
        and x axes.
      kmeans_iterations: Number of iterations for the k-means clustering.
      prototype_features: A 2-D float tensor for trained prototype features
        with shape `[num_prototypes, embedding_dim]`.
      prototype_semantic_labels: A 1-D integer tensor for trained prototype
        semantic labels with length `[num_prototypes]`.
      prototype_instance_labels: A 1-D integer tensor for trained prototype
        instance labels with length `[num_prototypes]`.
      k_in_nearest_neighbors: The number of nearest neighbors to search,
        or k in k-nearest neighbors.
      panoptic_label_divisor: An integer constant to separate semantic and
        instance labels from panoptic labels.
      class_has_instances_list: A list of thing classes, which have instances.

    Returns:
      panoptic_predictions: A 1-D integer tensor for pixel panoptic predictions.
      semantic_predictions: A 1-D integer tensor for pixel semantic predictions.
      instance_predictions: A 1-D integer tensor for pixel instance predictions.
      cluster_labels: An integer tensor for per-pixel cluster labels from the
        k-means step.
    """
    # Generate location features and combine them with embedding features.
    shape = embedding.get_shape().as_list()
    location_features = common_utils.generate_location_features(
        [shape[1], shape[2]], 'float')
    location_features = tf.expand_dims(location_features, 0)
    embedding_with_location = tf.concat([embedding, location_features], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Kmeans clustering.
    cluster_labels = common_utils.kmeans(embedding_with_location, num_clusters,
                                         kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic and instance labels.
    semantic_predictions, instance_predictions = predict_semantic_instance_labels(
        cluster_labels, test_prototypes, prototype_features,
        prototype_semantic_labels, prototype_instance_labels,
        k_in_nearest_neighbors)

    # Refine instance labels.
    class_has_instances_list = tf.reshape(class_has_instances_list,
                                          [1, 1, 1, -1])
    instance_predictions = tf.where(
        tf.reduce_all(tf.not_equal(tf.expand_dims(semantic_predictions, 3),
                                   class_has_instances_list),
                      axis=3), tf.zeros_like(instance_predictions),
        instance_predictions)

    # Combine semantic and instance predictions into panoptic predictions.
    panoptic_predictions = (semantic_predictions * panoptic_label_divisor +
                            instance_predictions)

    return (panoptic_predictions, semantic_predictions, instance_predictions,
            cluster_labels)
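The last two steps of predict_all_labels are simple elementwise operations: instance ids are zeroed for pixels whose semantic class is not a "thing" class, and the panoptic id is the semantic id times the divisor plus the instance id. A NumPy sketch with made-up values follows; it is not part of the repository.

# Hypothetical NumPy analogue of the instance refinement and panoptic fusion
# performed at the end of predict_all_labels; all values are made up.
import numpy as np

panoptic_label_divisor = 256
class_has_instances_list = np.array([11, 12, 13])  # example "thing" classes

semantic_predictions = np.array([[11, 11, 4],
                                 [12,  4, 4]])
instance_predictions = np.array([[1, 1, 7],
                                 [2, 7, 7]])

# Instance ids only make sense for thing classes; zero them everywhere else.
is_thing = np.isin(semantic_predictions, class_has_instances_list)
instance_predictions = np.where(is_thing, instance_predictions, 0)

# Panoptic id = semantic id * divisor + instance id.
panoptic_predictions = (semantic_predictions * panoptic_label_divisor
                        + instance_predictions)
print(instance_predictions)  # stuff pixels now carry instance id 0
print(panoptic_predictions)  # e.g. 11 * 256 + 1 = 2817 for the first pixel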
def main():
    """Create the model and start the Inference process."""
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            parse_commastr(args.input_size),
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, dim=0)
        label_batch = tf.expand_dims(reader.label, dim=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0)
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features.
    embedding_with_location = tf.concat([embedding, loc_feature_batch], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Kmeans clustering.
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
    semantic_predictions, _ = eval_utils.predict_semantic_instance_labels(
        cluster_labels, test_prototypes, prototype_features, prototype_labels,
        None, args.k_in_nearest_neighbors)
    semantic_predictions = tf.cast(semantic_predictions, tf.uint8)
    semantic_predictions = tf.squeeze(semantic_predictions)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    for step in tqdm(range(num_steps)):
        semantic_predictions_np, height_np, width_np = sess.run(
            [semantic_predictions, height, width], feed_dict=feed_dict)

        semantic_predictions_np = (
            semantic_predictions_np[:height_np, :width_np])

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)