Ejemplo n.º 1
0
def soft_ncut(image, image_segment, image_weights):
    """
    Build the soft normalized-cut term for a batch of soft segmentations.

    Args:
        image: [B, H, W, C]; only its batch size and static H, W are read.
        image_segment: [B, H, W, K] per-pixel class scores.
        image_weights: [B, H*W, H*W] pixel affinities; consumed by sparse ops
            (tf.sparse_reduce_sum), so it is expected to be a sparse tensor.
    Returns:
        Soft_Ncut: K - sum_k(dis_assoc_k / assoc_k), reduced over the class
            axis; per the shape comments below this is one value per batch
            element.
    """

    n_batch = tf.shape(image)[0]
    n_class = tf.shape(image_segment)[-1]
    static_shape = image.get_shape()
    n_pixels = static_shape[1].value * static_shape[2].value

    # Flatten the spatial axes with the class axis in front of them.
    seg_class_first = tf.transpose(image_segment, [0, 3, 1, 2])  # [B, K, H, W]
    flat_shape = tf.stack([n_batch, n_class, n_pixels])
    seg_flat = tf.reshape(seg_class_first, flat_shape)  # [B, K, H*W]

    # Dis-association
    # [B0, H*W, H*W] @ [B1, K1, H*W] contract on [[2],[2]] = [B0, H*W, B1, K1]
    weighted_seg = sparse_tensor_dense_tensordot(image_weights,
                                                 seg_flat,
                                                 axes=[[2], [2]])
    weighted_seg = tf.transpose(weighted_seg,
                                [0, 2, 3, 1])  # [B0, B1, K1, H*W]
    weighted_seg = sycronize_axes(weighted_seg, [0, 1],
                                  tensor_dims=4)  # [B0=B1, K1, H*W]
    # [B1, K1, H*W] @ [B2, K2, H*W] contract on [[2],[2]] = [B1, K1, B2, K2]
    cut = tf.tensordot(weighted_seg, seg_flat, axes=[[2], [2]])
    cut = sycronize_axes(cut, [0, 2], tensor_dims=4)  # [B1=B2, K1, K2]
    cut = sycronize_axes(cut, [1, 2], tensor_dims=3)  # [K1=K2, B1=B2]
    cut = tf.transpose(cut, [1, 0])  # [B1=B2, K1=K2]
    cut = tf.identity(cut, name="dis_assoc")

    # Association
    # seg_flat: [B0, K0, H*W]
    row_weight = tf.sparse_reduce_sum(image_weights, axis=2)  # [B1, W*H]
    volume = tf.tensordot(seg_flat, row_weight, axes=[2, 1])  # [B0, K0, B1]
    volume = sycronize_axes(volume, [0, 2], tensor_dims=3)  # [B0=B1, K0]
    volume = tf.identity(volume, name="assoc")

    utils.add_activation_summary(cut)
    utils.add_activation_summary(volume)

    # Soft NCut; eps keeps the ratio finite when a class has no mass.
    eps = 1e-6
    n_class_float = tf.cast(n_class, tf.float32)
    ratio = (cut + eps) / (volume + eps)
    return n_class_float - tf.reduce_sum(ratio, axis=1)
Ejemplo n.º 2
0
    def train(cls, loss_val, var_list, flags):
        """
        Build an Adam update op for `loss_val` restricted to `var_list`.

        The learning rate lives in a non-trainable variable initialised from
        `flags.learning_rate`. When `flags.debug` is set, every
        (gradient, variable) pair is handed to utils.add_gradient_summary.

        Returns:
            (learning_rate, train_op)
        """

        lr_var = tf.Variable(flags.learning_rate, trainable=False)
        adam = tf.train.AdamOptimizer(lr_var)
        grads_and_vars = adam.compute_gradients(loss_val, var_list=var_list)
        if flags.debug:
            for gradient, variable in grads_and_vars:
                utils.add_gradient_summary(gradient, variable)
        update_op = adam.apply_gradients(grads_and_vars)
        return lr_var, update_op
Ejemplo n.º 3
0
    def plot_segmentation_under_test_dir(self):
        """
        Segment every file under `self.flags.test_dir` and save comparison plots.

        Each file is read, resized to `image_size` x `image_size`, and passed
        through `self.predict_segmentation`. For every image a two-panel
        figure (input on the left, colorized prediction on the right) is
        written to `<logs_dir>/Figure_<i>.png`.

        Returns:
            (test_images, test_preds): the stacked resized inputs and the raw
            output of `self.predict_segmentation`. When no file matches, a
            message is printed and two empty arrays are returned.
        """

        image_pattern = os.path.join(self.flags.test_dir, '*')
        # glob returns files in arbitrary order; sort so that the index in
        # Figure_<i>.png maps to the same input on every run.
        image_lst = sorted(glob(image_pattern))
        if not image_lst:
            print('No files found')
            # Nothing to predict: return empty results instead of falling
            # through to a return of names that were never assigned.
            return np.empty(0), np.empty(0)

        target_size = [self.flags.image_size, self.flags.image_size]
        test_images = np.stack([
            misc.imresize(imageio.imread(path), target_size, interp='bilinear')
            for path in image_lst
        ])
        test_preds = self.predict_segmentation(test_images)
        # Keep only the first three channels of the colorized output.
        colorized_test_preds = utils.batch_colorize_ndarray(
            test_preds, 0, self.flags.num_class,
            self.flags.cmap)[:, :, :, :3]
        for i, (imag,
                pred) in enumerate(zip(test_images, colorized_test_preds)):
            fig, axes = plt.subplots(1, 2)
            axes[0].imshow(imag)
            axes[1].imshow(pred)
            axes[0].axis('off')
            axes[1].axis('off')
            filename = os.path.join(self.flags.logs_dir,
                                    'Figure_%d.png' % i)
            plt.savefig(filename, dpi=300, format="png", transparent=False)
            # Release the figure; otherwise one figure per image stays alive.
            plt.close(fig)
        return test_images, test_preds
Ejemplo n.º 4
0
    def visualize_pred(self, dataset_reader):
        """
        Run the model on one random batch and save the results as images.

        For every item of the batch three files are written to
        `self.flags.logs_dir`: the input ("inp_"), the reconstructed image
        ("gt_") and the colorized prediction ("pred_").

        Returns:
            (valid_images, pred): the sampled inputs and the colorized
            predictions (first three channels only).
        """

        n_images = self.flags.batch_size
        valid_images, _ = dataset_reader.get_random_batch(n_images)

        # Inference mode: keep probability 1.0 and phase_train off.
        fetches = [self.reconstruct_image, self.pred_annotation]
        feeds = {
            self.image: valid_images,
            self.keep_probability: 1.0,
            self.phase_train: False
        }
        reconst_image, raw_pred = self.sess.run(fetches, feed_dict=feeds)

        label_map = np.squeeze(raw_pred, axis=3)
        colorized = utils.batch_colorize_ndarray(label_map, 0,
                                                 self.flags.num_class,
                                                 self.flags.cmap)
        pred = colorized[:, :, :, :3]

        for itr in range(n_images):
            suffix = str(5 + itr)
            outputs = (("inp_", valid_images),
                       ("gt_", reconst_image),
                       ("pred_", pred))
            for prefix, batch in outputs:
                utils.save_image(batch[itr].astype(np.uint8),
                                 self.flags.logs_dir,
                                 name=prefix + suffix)
            print("Saved image: %d" % itr)

        return valid_images, pred
Ejemplo n.º 5
0
    def __init__(self, flags):
        """
        Build the supervised segmentation graph and open a session.

        Sets up, in order: input placeholders, the prediction and its
        cross-entropy loss, the train op, the summaries, the session, the
        saver and two summary file writers. All variables are initialised
        and then overwritten from the latest checkpoint under
        `flags.logs_dir` when one exists.

        Args:
            flags: configuration object. This constructor reads
                `image_size`, `num_class`, `cmap`, `debug` and `logs_dir`;
                the object is also forwarded to `self.inference` and
                `self.train`.
        """

        self.flags = flags
        image_size = int(self.flags.image_size)
        num_class = int(self.flags.num_class)

        # Place holder
        self.image = tf.placeholder(tf.float32,
                                    shape=[None, image_size, image_size, 3],
                                    name="input_image")
        # One channel per class along the last axis.
        self.annotation = tf.placeholder(
            tf.int32,
            shape=[None, image_size, image_size, num_class],
            name="annotation")
        self.phase_train = tf.placeholder(tf.bool, name='phase_train')
        self.keep_probability = tf.placeholder(tf.float32,
                                               name="keep_probabilty")

        # Prediction and loss
        self.pred_annotation, self.image_segment_logits = \
            self.inference(self.image, self.keep_probability, self.phase_train, self.flags)
        # Not used further in this constructor.
        image_segment = tf.nn.softmax(self.image_segment_logits)
        # Collapse the class axis to a label index, then restore a trailing
        # axis of size 1 before colorizing.
        colorized_annotation = tf.argmax(self.annotation, axis=3)
        # `axis` replaces the deprecated `dim` keyword of tf.expand_dims.
        colorized_annotation = tf.expand_dims(colorized_annotation, axis=3)
        self.colorized_annotation = utils.batch_colorize(
            colorized_annotation, 0, num_class, self.flags.cmap)
        self.colorized_pred_annotation = utils.batch_colorize(
            self.pred_annotation, 0, num_class, self.flags.cmap)
        self.loss = tf.reduce_mean((tf.nn.softmax_cross_entropy_with_logits(
            logits=self.image_segment_logits,
            labels=self.annotation,
            name="entropy")))

        # Train var and op
        trainable_var = tf.trainable_variables()
        if self.flags.debug:
            for var in trainable_var:
                utils.add_to_regularization_and_summary(var)
        self.learning_rate, self.train_op = self.train(self.loss,
                                                       trainable_var,
                                                       self.flags)
        self.learning_rate_summary = tf.summary.scalar("learning_rate",
                                                       self.learning_rate)

        # Summary
        print("Setting up summary op...")
        tf.summary.image("input_image", self.image, max_outputs=2)
        tf.summary.image("annotation",
                         self.colorized_annotation,
                         max_outputs=2)
        tf.summary.image("pred_annotation",
                         self.colorized_pred_annotation,
                         max_outputs=2)
        self.loss_summary = tf.summary.scalar("total_loss", self.loss)
        # merge_all must come after every summary above has been created.
        self.summary_op = tf.summary.merge_all()

        # Session ,saver, and writer
        print("Setting up Session and Saver...")
        self.sess = tf.Session()
        self.saver = tf.train.Saver(max_to_keep=2)
        # create two summary writers to show training loss and validation loss in the same graph
        # need to create two folders 'train' and 'validation' inside FLAGS.logs_dir
        self.train_writer = tf.summary.FileWriter(
            os.path.join(self.flags.logs_dir, 'train'), self.sess.graph)
        self.validation_writer = tf.summary.FileWriter(
            os.path.join(self.flags.logs_dir, 'validation'))

        # Initialise first, then let a checkpoint (if any) overwrite values.
        print("Initialize tf variables")
        self.sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(self.flags.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print("Model restored...")
Ejemplo n.º 6
0
    def unet(cls,
             image,
             keep_prob,
             phase_train,
             output_channel,
             num_layers,
             is_debug=False):
        """
        Assemble a U-Net style stack of modules and return its tensors by name.

        Modules with index below `num_layers // 2` are encoders (encode, then
        2x2 max-pool), the module at `num_layers // 2` encodes and up-convolves,
        and the remaining modules are decoders that first concatenate the
        stored output of the mirrored module `num_layers - 1 - index` along
        axis 3. A final 1x1 convolution maps to `output_channel` channels.

        Returns:
            dict with 'image', one 'mod<i>_out' entry per module and
            'segment' for the final 1x1 convolution output.
        """

        net = {'image': image}
        # Not used below; kept so the set of graph ops stays the same.
        batch_size = tf.shape(image)[0]
        middle = num_layers // 2
        last = num_layers - 1
        current = image
        for index_module in range(num_layers):
            name = 'mod%d_out' % index_module

            if index_module < middle:
                # Encoder: store the pre-pool activation for the skip link.
                current = cls.unet_encode(current, keep_prob, phase_train,
                                          index_module)
                net[name] = current
                current = slim.max_pool2d(current, [2, 2],
                                          stride=2,
                                          padding='SAME')
            elif index_module == middle:
                # Bottom of the "U".
                current = cls.unet_encode(current, keep_prob, phase_train,
                                          index_module)
                net[name] = current
                current = cls.upconv(current, index_module)
            else:
                # Decoder: fuse with the mirrored module's stored output.
                skip_name = 'mod%d_out' % (last - index_module)
                current = tf.concat([current, net[skip_name]],
                                    axis=3,
                                    name="fuse_%d" % index_module)
                current = cls.unet_decode(current, keep_prob, phase_train,
                                          index_module)
                net[name] = current
                if index_module != last:
                    current = cls.upconv(current, index_module)

            if is_debug:
                print(name)
                print(net[name].get_shape())
                utils.add_activation_summary(current)

        # conv1x1
        name = 'segment'
        current = slim.conv2d(current, output_channel, 1)
        net[name] = current
        if is_debug:
            print(name)
            print(net[name].get_shape())
            print('unet complete')
        return net
Ejemplo n.º 7
0
    def __init__(self, flags):
        """
        Build the reconstruction + soft-Ncut graph and open a session.

        Sets up, in order: placeholders (including the three pieces of the
        sparse neighbor filter), the brightness-based pixel weights, the
        prediction / reconstruction and both losses, two train ops,
        summaries, the session, the saver and two summary file writers.
        Variables are initialised and then overwritten from the latest
        checkpoint under `flags.logs_dir` when one exists.

        Args:
            flags: configuration object. This constructor reads
                `image_size`, `num_class`, `cmap`, `debug` and `logs_dir`;
                the object is also forwarded to `self.inference` and
                `self.train`.
        """

        self.flags = flags
        image_size = int(self.flags.image_size)
        num_class = int(self.flags.num_class)

        # Place holder
        self.keep_probability = tf.placeholder(tf.float32,
                                               name="keep_probabilty")
        self.image = tf.placeholder(tf.float32,
                                    shape=[None, image_size, image_size, 3],
                                    name="input_image")
        self.phase_train = tf.placeholder(tf.bool, name='phase_train')
        # The neighbor filter is fed as three separate placeholders
        # (indices, values, shape) and passed on as a tuple.
        self.neighbor_indeces = tf.placeholder(tf.int64,
                                               name="neighbor_indeces")
        self.neighbor_vals = tf.placeholder(tf.float32, name="neighbor_vals")
        self.neighbor_shape = tf.placeholder(tf.int64, name="neighbor_shape")
        neighbor_filter = (self.neighbor_indeces, self.neighbor_vals,
                           self.neighbor_shape)
        _image_weights = brightness_weight(self.image,
                                           neighbor_filter,
                                           sigma_I=0.05)
        image_weights = convert_to_batchTensor(*_image_weights)

        # Prediction and loss
        self.pred_annotation, self.image_segment_logits, self.reconstruct_image = \
            self.inference(self.image, self.keep_probability, self.phase_train, self.flags)
        image_segment = tf.nn.softmax(self.image_segment_logits)
        self.colorized_pred_annotation = utils.batch_colorize(
            self.pred_annotation, 0, num_class, self.flags.cmap)
        # Mean squared error between the input and its reconstruction.
        self.reconstruct_loss = tf.reduce_mean(
            tf.reshape((self.image - self.reconstruct_image)**2, shape=[-1]))
        # soft_ncut yields a batch of values; average them into one loss.
        batch_soft_ncut = soft_ncut(self.image, image_segment, image_weights)
        self.soft_ncut = tf.reduce_mean(batch_soft_ncut)
        # The sum is only reported as a summary below; the two terms are
        # optimised by separate train ops.
        self.loss = self.reconstruct_loss + self.soft_ncut

        # Train var and op
        trainable_var = tf.trainable_variables()
        # Only the variables under the "infer_encode" scope.
        encode_trainable_var = tf.trainable_variables("infer_encode")
        if self.flags.debug:
            for var in trainable_var:
                utils.add_to_regularization_and_summary(var)
        # Reconstruction loss updates all trainable variables ...
        self.reconst_learning_rate, self.train_reconst_op = \
            self.train(self.reconstruct_loss, trainable_var, self.flags)
        # ... while the soft-Ncut loss updates only the "infer_encode" ones.
        self.softNcut_learning_rate, self.train_softNcut_op = \
            self.train(self.soft_ncut, encode_trainable_var, self.flags)
        self.reconst_learning_rate_summary = tf.summary.scalar(
            "reconst_learning_rate", self.reconst_learning_rate)
        self.softNcut_learning_rate_summary = tf.summary.scalar(
            "softNcut_learning_rate", self.softNcut_learning_rate)

        # Summary
        tf.summary.image("input_image", self.image, max_outputs=2)
        tf.summary.image("reconstruct_image",
                         self.reconstruct_image,
                         max_outputs=2)
        tf.summary.image("pred_annotation",
                         self.colorized_pred_annotation,
                         max_outputs=2)
        reconstLoss_summary = tf.summary.scalar("reconstruct_loss",
                                                self.reconstruct_loss)
        softNcutLoss_summary = tf.summary.scalar("soft_ncut_loss",
                                                 self.soft_ncut)
        totLoss_summary = tf.summary.scalar("total_loss", self.loss)
        # loss_summary groups just the three loss scalars; summary_op merges
        # every summary created so far, so it must stay after them.
        self.loss_summary = tf.summary.merge(
            [reconstLoss_summary, softNcutLoss_summary, totLoss_summary])
        self.summary_op = tf.summary.merge_all()

        # Session ,saver, and writer
        print("Setting up Session and Saver...")
        self.sess = tf.Session()
        self.saver = tf.train.Saver(max_to_keep=2)
        # create two summary writers to show training loss and validation loss in the same graph
        # need to create two folders 'train' and 'validation' inside FLAGS.logs_dir
        self.train_writer = tf.summary.FileWriter(
            os.path.join(self.flags.logs_dir, 'train'), self.sess.graph)
        self.validation_writer = tf.summary.FileWriter(
            os.path.join(self.flags.logs_dir, 'validation'))

        # Initialise first, then let a checkpoint (if any) overwrite values.
        print("Initialize tf variables")
        self.sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(self.flags.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print("Model restored...")
        return
Ejemplo n.º 8
0
    class TestBrightWeight(unittest.TestCase):
        """
        Compare the sparse brightness-weight / soft-Ncut graph with the dense
        reference functions, and check that one training step runs.

        The whole TensorFlow graph, the session and the input data are
        created in the class body, i.e. once when the class is defined, and
        are shared by all test methods.
        """

        # Global setting
        NUM_OF_CLASSESS = 4
        IMAGE_SIZE = 10

        # Tf placeholder
        image = tf.placeholder(tf.float32,
                               shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                               name="input_image")
        # A single 3x3 convolution stands in for the segmentation network.
        kernels = tf.cast(
            utils.weight_variable([3, 3, 3, NUM_OF_CLASSESS], name="weight"),
            tf.float32)
        bias = tf.cast(utils.bias_variable([NUM_OF_CLASSESS], name="bias"),
                       tf.float32)
        image_segment = utils.conv2d_basic(image, kernels, bias)

        # Neighbor filter fed as (indices, values, shape).
        neighbor_indeces = tf.placeholder(tf.int64, name="neighbor_indeces")
        neighbor_vals = tf.placeholder(tf.float32, name="neighbor_vals")
        neighbor_shape = tf.placeholder(tf.int64, name="neighbor_shape")
        neighbor_filter = (neighbor_indeces, neighbor_vals, neighbor_shape)
        _image_weights = brightness_weight(image, neighbor_filter, sigma_I=10)
        image_weights = convert_to_batchTensor(*_image_weights)
        # Dense copy of the sparse weights, used to compare with the
        # dense reference in test_image_weight.
        dense_image_weights = tf.sparse_to_dense(image_weights.indices,
                                                 image_weights.dense_shape,
                                                 image_weights.values)

        soft_ncuts = soft_ncut(image, image_segment, image_weights)
        loss = tf.reduce_sum(soft_ncuts)

        # Optimizer
        trainable_var = tf.trainable_variables()
        optimizer = tf.train.AdamOptimizer(1e-4)
        grads = optimizer.compute_gradients(loss, var_list=trainable_var)
        trainer = optimizer.apply_gradients(grads)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # Data
        # Values 0..IMAGE_SIZE**2-1 laid out as a square, repeated over three
        # channels and given a leading batch axis:
        # [1, IMAGE_SIZE, IMAGE_SIZE, 3].
        x = np.arange(IMAGE_SIZE * IMAGE_SIZE).reshape(IMAGE_SIZE, IMAGE_SIZE)
        x = np.moveaxis(np.tile(x, [3, 1, 1]), [0, 1, 2], [2, 0, 1])
        x = x[np.newaxis, :, :, :]

        def test_image_weight(self):
            """
            The densified sparse image weights must match
            dense_brightness_weight within rtol=1e-6 / atol=1e-6.
            """
            image_shape = self.image.get_shape().as_list()[1:3]
            gauss_indeces, gauss_vals = gaussian_neighbor(image_shape,
                                                          sigma_X=4,
                                                          r=5)
            # Number of pixels; the neighbor filter is square in this size.
            weight_shapes = np.prod(image_shape)
            sparse_bright_weight = self.sess.run(
                self.dense_image_weights,
                feed_dict={
                    self.image: self.x,
                    self.neighbor_indeces: gauss_indeces,
                    self.neighbor_vals: gauss_vals,
                    self.neighbor_shape: [weight_shapes, weight_shapes]
                })

            # Compare with dense version
            dense_bright_weight = dense_brightness_weight(self.x)
            max_err = np.max(np.abs(sparse_bright_weight -
                                    dense_bright_weight))
            print('max error of image_weights = %.4e' % max_err)
            np.testing.assert_allclose(sparse_bright_weight,
                                       dense_bright_weight,
                                       rtol=1e-6,
                                       atol=1e-6)

        def test_soft_ncut(self):
            """
            The sparse soft_ncut output must match dense_soft_ncut, evaluated
            on the same segmentation, within rtol=1e-4 / atol=1e-6.
            """
            image_shape = self.image.get_shape().as_list()[1:3]
            gauss_indeces, gauss_vals = gaussian_neighbor(image_shape,
                                                          sigma_X=4,
                                                          r=5)
            weight_shapes = np.prod(image_shape)
            # Fetch the segmentation too so the dense reference sees exactly
            # the same values.
            sp_soft_ncut, image_segment = self.sess.run(
                [self.soft_ncuts, self.image_segment],
                feed_dict={
                    self.image: self.x,
                    self.neighbor_indeces: gauss_indeces,
                    self.neighbor_vals: gauss_vals,
                    self.neighbor_shape: [weight_shapes, weight_shapes]
                })

            # Compare with dense version
            dn_soft_ncut = dense_soft_ncut(self.x, image_segment)
            max_err = np.max(np.abs(sp_soft_ncut - dn_soft_ncut))
            print('max error of %s = %.4e' % ('soft_ncut', max_err))
            np.testing.assert_allclose(sp_soft_ncut,
                                       dn_soft_ncut,
                                       rtol=1e-4,
                                       atol=1e-6)

        def test_train(self):
            """
            Run one optimizer step on the soft-Ncut loss. Nothing is asserted;
            the test passes when the step executes without raising.
            """
            image_shape = self.image.get_shape().as_list()[1:3]
            gauss_indeces, gauss_vals = gaussian_neighbor(image_shape,
                                                          sigma_X=4,
                                                          r=5)
            weight_shapes = np.prod(image_shape)
            result, _ = self.sess.run(
                [self.soft_ncuts, self.trainer],
                feed_dict={
                    self.image: self.x,
                    self.neighbor_indeces: gauss_indeces,
                    self.neighbor_vals: gauss_vals,
                    self.neighbor_shape: [weight_shapes, weight_shapes]
                })