示例#1
0
class BanditSegNet:
    ''' Network described by
            https://arxiv.org/pdf/1511.00561.pdf '''
    def load_vgg_weights(self):
        """ Use the VGG model trained on
            imagent dataset as a starting point for training """
        vgg_path = "models/imagenet-vgg-verydeep-19.mat"
        vgg_mat = scipy.io.loadmat(vgg_path)

        self.vgg_params = np.squeeze(vgg_mat['layers'])
        self.layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
                       'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                       'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
                       'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
                       'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
                       'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
                       'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4',
                       'relu5_4')

    def __init__(self, num_classes=11):
        """Build the graph, open a session and prepare checkpointing/logging.

        Args:
            num_classes: stored as ``self.num_classes``; ``build`` uses it as
                the channel count of the final score layer.
        """
        self.num_classes = num_classes
        self.load_vgg_weights()
        self.build()

        # Open a session, then create and run the global initializer followed
        # by the local one (the streaming metric keeps local variables).
        session_config = tf.ConfigProto(allow_soft_placement=True)
        self.session = tf.Session(config=session_config)
        for make_init_op in (tf.global_variables_initializer,
                             tf.local_variables_initializer):
            self.session.run(make_init_op())

        # Checkpointing: the saver is created after build() so that it sees
        # the variables of the graph.
        self.checkpoint_directory = './checkpoints/'
        self.saver = tf.train.Saver(max_to_keep=5,
                                    keep_checkpoint_every_n_hours=1)

        # Text/graph logging helper
        self.logger = Logger()

    def vgg_weight_and_bias(self, name, W_shape, b_shape):
        """ 
            Initializes weights and biases to the pre-trained VGG model.
            
            Args:
                name: name of the layer for which you want to initialize weights
                W_shape: shape of weights tensor exkpected
                b_shape: shape of bias tensor expected
            returns:
                w_var: Initialized weight variable
                b_var: Initialized bias variable
        """
        if name not in self.layers:
            return self.weight_variable(W_shape), self.weight_variable(b_shape)
        else:
            w, b = self.vgg_params[self.layers.index(name)][0][0][0][0]
            init_w = tf.constant(value=np.transpose(w, (1, 0, 2, 3)),
                                 dtype=tf.float32,
                                 shape=W_shape)
            init_b = tf.constant(value=b.reshape(-1),
                                 dtype=tf.float32,
                                 shape=b_shape)
            w_var = tf.Variable(init_w)
            b_var = tf.Variable(init_b)
            return w_var, b_var

    def weight_variable(self, shape):
        """Return a variable of ``shape`` drawn from a truncated normal
        distribution with standard deviation 0.1."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

    def bias_variable(self, shape):
        """Return a variable of ``shape`` filled with the constant 0.1."""
        return tf.Variable(tf.constant(0.1, shape=shape))

    def pool_layer(self, x):
        """Apply 2x2 max pooling with stride 2 and 'SAME' padding.

        Returns whatever ``tf.nn.max_pool_with_argmax`` returns; callers
        unpack it as ``(pooled, argmax)``.
        """
        window = [1, 2, 2, 1]
        return tf.nn.max_pool_with_argmax(x,
                                          ksize=window,
                                          strides=window,
                                          padding='SAME')

    def unpool(self, pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
        """ Unpooling layer after max_pool_with_argmax.
        Args:
            pool: max pooled output tensor
            ind: argmax indices
            ksize: ksize is the same as for the pool
        Return:
            unpool: unpooling tensor
        """
        with tf.variable_scope(scope):
            # Pool shape: BATCH_SIZE * ENCODED_WIDTH * ENCODED_HEIGHT * NUM_CLASSES
            input_shape = tf.shape(pool)
            output_shape = [
                input_shape[0], input_shape[1] * ksize[1],
                input_shape[2] * ksize[2], input_shape[3]
            ]

            flat_input_size = tf.cumprod(input_shape)[-1]
            flat_output_shape = tf.stack([
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3]
            ])

            pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
            batch_range = tf.range(tf.cast(output_shape[0], tf.int64),
                                   dtype=ind.dtype)
            reshape_shape = tf.stack([input_shape[0], 1, 1, 1])
            reshaped_batch_range = tf.reshape(batch_range, shape=reshape_shape)

            b = tf.ones_like(ind) * reshaped_batch_range
            b = tf.reshape(b, tf.stack([flat_input_size, 1]))
            ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
            ind_ = tf.concat([b, ind_], 1)

            ret = tf.scatter_nd(ind_,
                                pool_,
                                shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, tf.stack(output_shape))
            return ret

    def conv_layer_with_bn(self,
                           x,
                           W_shape,
                           train_phase,
                           name,
                           padding='SAME'):
        """Stride-1 convolution + bias, followed by batch norm and ReLU.

        Args:
            x: input tensor
            W_shape: kernel shape; its last entry is used as the bias length
            train_phase: forwarded to batch norm as ``is_training``
            name: layer name handed to ``vgg_weight_and_bias``
            padding: padding mode for ``tf.nn.conv2d``
        Returns:
            The ReLU-activated, batch-normalized convolution output.
        """
        out_channels = W_shape[3]
        kernel, bias = self.vgg_weight_and_bias(name, W_shape, [out_channels])
        conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding)
        normalized = tf.contrib.layers.batch_norm(conv + bias,
                                                  is_training=train_phase)
        return tf.nn.relu(normalized)

    def build(self):
        """Construct the SegNet graph, the bandit loss, optimizer and metrics.

        Attributes created:
            x: float32 image placeholder, shape (None, None, None, 3)
            y: int64 label placeholder, shape (None, None, None)
            propensity: float32 placeholder with three unknown dimensions
            delta, lagrange_mult, rate: float32 scalar placeholders
            train_phase: bool placeholder forwarded to batch norm
            loss: mean of ``lagrange_mult + softmax * (delta - lagrange_mult)
                / propensity``
            train_step: Adam op minimizing ``loss``; it now also triggers the
                batch-norm moving-average updates created by this method
            prediction, accuracy, mean_IoU: evaluation tensors (``mean_IoU``
                is the pair returned by ``streaming_mean_iou``)
        """
        # Remember how many update ops already exist so that only the ones
        # created by this model are attached to its train step.
        num_prior_update_ops = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

        # Declare input placeholders
        self.x = tf.placeholder(tf.float32, shape=(None, None, None, 3))
        self.y = tf.placeholder(tf.int64, shape=(None, None, None))
        self.propensity = tf.placeholder(tf.float32, shape=(None, None, None))
        self.delta = tf.placeholder(tf.float32, shape=[])
        self.lagrange_mult = tf.placeholder(tf.float32, shape=[])
        self.train_phase = tf.placeholder(tf.bool, name='train_phase')
        self.rate = tf.placeholder(tf.float32, shape=[])

        # (layer name, input channels, output channels) for every 3x3
        # convolution. Layers are created strictly in this order so that the
        # automatically generated variable names stay the same.
        encoder_stages = (
            (('conv1_1', 3, 64), ('conv1_2', 64, 64)),
            (('conv2_1', 64, 128), ('conv2_2', 128, 128)),
            (('conv3_1', 128, 256), ('conv3_2', 256, 256),
             ('conv3_3', 256, 256)),
            (('conv4_1', 256, 512), ('conv4_2', 512, 512),
             ('conv4_3', 512, 512)),
            (('conv5_1', 512, 512), ('conv5_2', 512, 512),
             ('conv5_3', 512, 512)),
        )
        decoder_stages = (
            (('deconv5_3', 512, 512), ('deconv5_2', 512, 512),
             ('deconv5_1', 512, 512)),
            (('deconv4_3', 512, 512), ('deconv4_2', 512, 512),
             ('deconv4_1', 512, 256)),
            (('deconv3_3', 256, 256), ('deconv3_2', 256, 256),
             ('deconv3_1', 256, 128)),
            (('deconv2_2', 128, 128), ('deconv2_1', 128, 64)),
            (('deconv1_2', 64, 64), ('deconv1_1', 64, 32)),
        )

        # Encoders: convolutions followed by max pooling; the argmax of each
        # pooling step is kept for the matching decoder.
        net = self.x
        pool_argmaxes = []
        for stage in encoder_stages:
            for name, in_channels, out_channels in stage:
                net = self.conv_layer_with_bn(
                    net, [3, 3, in_channels, out_channels],
                    self.train_phase, name)
            net, argmax = self.pool_layer(net)
            pool_argmaxes.append(argmax)

        # Decoders: unpool with the argmax of the mirrored encoder (last
        # pooling first), then convolve.
        for stage, argmax in zip(decoder_stages, reversed(pool_argmaxes)):
            net = self.unpool(net, argmax)
            for name, in_channels, out_channels in stage:
                net = self.conv_layer_with_bn(
                    net, [3, 3, in_channels, out_channels],
                    self.train_phase, name)

        # Produce class scores
        # score_1 dimensions: BATCH_SIZE * WIDTH * HEIGHT * NUM_CLASSES
        score_1 = self.conv_layer_with_bn(net,
                                          [1, 1, 32, self.num_classes],
                                          self.train_phase, 'score_1')

        # Compute Empirical Risk Minimization loss
        logits = tf.reshape(score_1, (-1, self.num_classes))
        softmaxed = tf.nn.softmax(logits)
        numerator = tf.multiply(softmaxed, (self.delta - self.lagrange_mult))

        # Collapse the first two propensity dimensions so the tensor lines
        # up with `logits`, e.g. (5, 153600, 11) -> (768000, 11).
        # NOTE(review): presumably propensity arrives as
        # (batch, pixels, num_classes) -- confirm against the feedback reader.
        propensity_shape = tf.shape(self.propensity)
        flat_prop_size = tf.cumprod(propensity_shape)[-2]
        propensity_ = tf.reshape(self.propensity,
                                 tf.stack([flat_prop_size, self.num_classes]))

        self.loss = tf.reduce_mean(self.lagrange_mult +
                                   tf.divide(numerator, propensity_))

        # Declare optimizer. Batch norm registers its moving-average updates
        # in UPDATE_OPS; they only run if the train op depends on them.
        optimizer = tf.train.AdamOptimizer(self.rate)
        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)[num_prior_update_ops:]
        with tf.control_dependencies(update_ops):
            self.train_step = optimizer.minimize(self.loss)

        # Metrics
        self.prediction = tf.argmax(score_1, axis=3, name="prediction")
        self.accuracy = tf.contrib.metrics.accuracy(self.prediction,
                                                    self.y,
                                                    name='accuracy')
        self.mean_IoU = tf.contrib.metrics.streaming_mean_iou(self.prediction,
                                                              self.y,
                                                              self.num_classes,
                                                              name='mean_IoU')

    def restore_session(self):
        global_step = 0

        if not os.path.exists(self.checkpoint_directory):
            raise IOError(self.checkpoint_directory + ' does not exist.')
        else:
            path = tf.train.get_checkpoint_state(self.checkpoint_directory)
            if path is None:
                pass
            else:
                self.saver.restore(self.session, path.model_checkpoint_path)
                global_step = int(path.model_checkpoint_path.split('-')[-1])

        return global_step

    def train(self,
              dataset_dir,
              feedback_dir,
              lagrange=0.8,
              num_iterations=10000,
              learning_rate=0.1,
              batch_size=5):
        """Train on logged bandit feedback, resuming from the last checkpoint.

        Args:
            dataset_dir: directory handed to ``BatchDatasetReader``
            feedback_dir: directory handed to ``BanditFeedbackReader``
            lagrange: value fed to the ``lagrange_mult`` placeholder
            num_iterations: iteration index at which training stops
            learning_rate: value fed to the ``rate`` placeholder
            batch_size: batch size handed to ``BatchDatasetReader``; also
                used for the logged epoch estimate
        """
        current_step = self.restore_session()

        bdr = BatchDatasetReader(dataset_dir,
                                 480,
                                 320,
                                 current_step,
                                 batch_size,
                                 trainval_only=True)
        bfr = BanditFeedbackReader(feedback_dir, current_step)

        # Begin Training
        for i in range(current_step, num_iterations):

            # One training step
            images, ground_truths = bdr.next_training_batch()
            deltas, propensities = bfr.next_item_batch()

            feed_dict = {
                self.x: images,
                self.y: ground_truths,
                self.propensity: propensities,
                self.delta: np.mean(deltas),
                self.lagrange_mult: lagrange,
                self.train_phase: 1,
                self.rate: learning_rate
            }

            print('run train step: ' + str(i))
            self.train_step.run(session=self.session, feed_dict=feed_dict)

            # Print loss every 10 iterations
            if i % 10 == 0:
                train_loss = self.session.run(self.loss, feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (i, train_loss))

            # Evaluate, log and checkpoint every 100 iterations
            if i % 100 == 0:
                # NOTE(review): a further *training* batch is drawn here even
                # though the numbers are reported as validation metrics;
                # confirm whether a validation batch was intended.
                images, ground_truths = bdr.next_training_batch()

                # Same feed as the training step, with the new images/labels.
                eval_feed = dict(feed_dict)
                eval_feed[self.x] = images
                eval_feed[self.y] = ground_truths

                # One forward pass for all metrics instead of three.
                val_loss, val_accuracy, (val_mean_IoU, _) = self.session.run(
                    [self.loss, self.accuracy, self.mean_IoU],
                    feed_dict=eval_feed)

                print("%s ---> Validation_loss: %g" %
                      (datetime.datetime.now(), val_loss))
                print("%s ---> Validation_accuracy: %g" %
                      (datetime.datetime.now(), val_accuracy))

                self.logger.log("%s ---> Number of epochs: %g\n" %
                                (datetime.datetime.now(),
                                 math.floor((i * batch_size) / bdr.num_train)))
                self.logger.log("%s ---> Number of iterations: %g\n" %
                                (datetime.datetime.now(), i))
                self.logger.log("%s ---> Validation_loss: %g\n" %
                                (datetime.datetime.now(), val_loss))
                self.logger.log("%s ---> Validation_accuracy: %g\n" %
                                (datetime.datetime.now(), val_accuracy))
                self.logger.log_for_graphing(i, val_loss, val_accuracy,
                                             val_mean_IoU)

                # Save the model variables
                self.saver.save(self.session,
                                self.checkpoint_directory + 'segnet',
                                global_step=i)

            # Print outputs every 1000 iterations
            if i % 1000 == 0:
                self.logger.graph_training_stats()
示例#2
0
class SegNetLogger:
    ''' Network described by
            https://arxiv.org/pdf/1511.00561.pdf '''

    def load_vgg_weights(self):
        """ Use the VGG model trained on
            imagent dataset as a starting point for training """
        vgg_path = "models/imagenet-vgg-verydeep-19.mat"
        vgg_mat = scipy.io.loadmat(vgg_path)

        self.vgg_params = np.squeeze(vgg_mat['layers'])
        self.layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
                        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
                        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
                        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
                        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
                        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
                        'relu5_3', 'conv5_4', 'relu5_4')

    def __init__(self, num_classes=11):
        """Build the graph, open a session and prepare checkpointing/logging.

        Args:
            num_classes: stored as ``self.num_classes``; ``build`` uses it as
                the channel count of the final score layer.
        """
        self.num_classes = num_classes
        self.load_vgg_weights()
        self.build()

        # Open a session, then create and run the global initializer followed
        # by the local one (the streaming metric keeps local variables).
        session_config = tf.ConfigProto(allow_soft_placement=True)
        self.session = tf.Session(config=session_config)
        for make_init_op in (tf.global_variables_initializer,
                             tf.local_variables_initializer):
            self.session.run(make_init_op())

        # Checkpointing: the saver is created after build() so that it sees
        # the variables of the graph.
        self.checkpoint_directory = './checkpoints/'
        self.saver = tf.train.Saver(max_to_keep=5,
                                    keep_checkpoint_every_n_hours=1)

        self.logger = Logger()

        # Summary logging is currently disabled:
        # self.train_writer = tf.summary.FileWriter('./summaries' + '/train',
        #                                           self.session.graph)
        # self.val_writer = tf.summary.FileWriter('./summaries' + '/val',
        #                                          self.session.graph)


    def vgg_weight_and_bias(self, name, W_shape, b_shape):
        """ 
            Initializes weights and biases to the pre-trained VGG model.
            
            Args:
                name: name of the layer for which you want to initialize 
                      weights
                W_shape: shape of weights tensor exkpected
                b_shape: shape of bias tensor expected
            returns:
                w_var: Initialized weight variable
                b_var: Initialized bias variable
        """
        if name not in self.layers:
            return self.weight_variable(W_shape), self.weight_variable(b_shape)
        else:
            w, b = self.vgg_params[self.layers.index(name)][0][0][0][0]
            init_w = tf.constant(value=np.transpose(w, (1, 0, 2, 3)), 
                                 dtype=tf.float32, shape=W_shape)
            init_b = tf.constant(value=b.reshape(-1), dtype=tf.float32, 
                                 shape=b_shape)
            w_var = tf.Variable(init_w)
            b_var = tf.Variable(init_b)
            return w_var, b_var 

    def weight_variable(self, shape):
        """Return a variable of ``shape`` drawn from a truncated normal
        distribution with standard deviation 0.1."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

    def bias_variable(self, shape):
        """Return a variable of ``shape`` filled with the constant 0.1."""
        return tf.Variable(tf.constant(0.1, shape=shape))

    def pool_layer(self, x):
        """Apply 2x2 max pooling with stride 2 and 'SAME' padding.

        Returns whatever ``tf.nn.max_pool_with_argmax`` returns; callers
        unpack it as ``(pooled, argmax)``.
        """
        window = [1, 2, 2, 1]
        return tf.nn.max_pool_with_argmax(x, ksize=window, strides=window,
                                          padding='SAME')

    def unpool(self, pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
        """ Unpooling layer after max_pool_with_argmax.
        Args:
            pool: max pooled output tensor
            ind: argmax indices
            ksize: ksize is the same as for the pool
        Return:
            unpool: unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape =  tf.shape(pool)
            output_shape = [input_shape[0], 
                            input_shape[1] * ksize[1], 
                            input_shape[2] * ksize[2], 
                            input_shape[3]]

            flat_input_size = tf.cumprod(input_shape)[-1]
            flat_output_shape = tf.stack([output_shape[0], output_shape[1] 
                                                           * output_shape[2] 
                                                           * output_shape[3]])

            pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
            batch_range = tf.range(tf.cast(output_shape[0], tf.int64), 
                                   dtype=ind.dtype)
            reshape_shape = tf.stack([input_shape[0], 1, 1, 1])
            reshaped_batch_range = tf.reshape(batch_range, 
                                              shape=reshape_shape)

            b = tf.ones_like(ind) * reshaped_batch_range
            b = tf.reshape(b, tf.stack([flat_input_size, 1]))
            ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
            ind_ = tf.concat([b, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, 
                                                           tf.int64))
            ret = tf.reshape(ret, tf.stack(output_shape))
            return ret

    def conv_layer_with_bn(self, x, W_shape, train_phase, name,
                           padding='SAME'):
        """Stride-1 convolution + bias, followed by batch norm and ReLU.

        Args:
            x: input tensor
            W_shape: kernel shape; its last entry is used as the bias length
            train_phase: forwarded to batch norm as ``is_training``
            name: layer name handed to ``vgg_weight_and_bias``
            padding: padding mode for ``tf.nn.conv2d``
        Returns:
            The ReLU-activated, batch-normalized convolution output.
        """
        out_channels = W_shape[3]
        kernel, bias = self.vgg_weight_and_bias(name, W_shape, [out_channels])
        conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding)
        normalized = tf.contrib.layers.batch_norm(conv + bias,
                                                  is_training=train_phase)
        return tf.nn.relu(normalized)

    def build(self):
        """Construct the SegNet graph, cross-entropy loss, optimizer, metrics.

        Attributes created:
            x: float32 image placeholder, shape (None, None, None, 3)
            y: int64 label placeholder, shape (None, None, None)
            train_phase: bool placeholder forwarded to batch norm
            rate: float32 scalar placeholder used as the Adam learning rate
            loss: mean sparse softmax cross entropy between labels and logits
            train_step: Adam op minimizing ``loss``; it now also triggers the
                batch-norm moving-average updates created by this method
            prediction, accuracy, mean_IoU: evaluation tensors (``mean_IoU``
                is the pair returned by ``streaming_mean_iou``)
            propensity: softmax of the flattened logits
        """
        # Remember how many update ops already exist so that only the ones
        # created by this model are attached to its train step.
        num_prior_update_ops = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

        # Declare placeholders
        self.x = tf.placeholder(tf.float32, shape=(None, None, None, 3))
        # self.y dimensions = BATCH_SIZE * WIDTH * HEIGHT
        self.y = tf.placeholder(tf.int64, shape=(None, None, None))
        expected = tf.expand_dims(self.y, -1)
        self.train_phase = tf.placeholder(tf.bool, name='train_phase')
        self.rate = tf.placeholder(tf.float32, shape=[])

        # (layer name, input channels, output channels) for every 3x3
        # convolution. Layers are created strictly in this order so that the
        # automatically generated variable names stay the same.
        encoder_stages = (
            (('conv1_1', 3, 64), ('conv1_2', 64, 64)),
            (('conv2_1', 64, 128), ('conv2_2', 128, 128)),
            (('conv3_1', 128, 256), ('conv3_2', 256, 256),
             ('conv3_3', 256, 256)),
            (('conv4_1', 256, 512), ('conv4_2', 512, 512),
             ('conv4_3', 512, 512)),
            (('conv5_1', 512, 512), ('conv5_2', 512, 512),
             ('conv5_3', 512, 512)),
        )
        decoder_stages = (
            (('deconv5_3', 512, 512), ('deconv5_2', 512, 512),
             ('deconv5_1', 512, 512)),
            (('deconv4_3', 512, 512), ('deconv4_2', 512, 512),
             ('deconv4_1', 512, 256)),
            (('deconv3_3', 256, 256), ('deconv3_2', 256, 256),
             ('deconv3_1', 256, 128)),
            (('deconv2_2', 128, 128), ('deconv2_1', 128, 64)),
            (('deconv1_2', 64, 64), ('deconv1_1', 64, 32)),
        )

        # Encoders: convolutions followed by max pooling; the argmax of each
        # pooling step is kept for the matching decoder.
        net = self.x
        pool_argmaxes = []
        for stage in encoder_stages:
            for name, in_channels, out_channels in stage:
                net = self.conv_layer_with_bn(
                    net, [3, 3, in_channels, out_channels],
                    self.train_phase, name)
            net, argmax = self.pool_layer(net)
            pool_argmaxes.append(argmax)

        # Decoders: unpool with the argmax of the mirrored encoder (last
        # pooling first), then convolve.
        for stage, argmax in zip(decoder_stages, reversed(pool_argmaxes)):
            net = self.unpool(net, argmax)
            for name, in_channels, out_channels in stage:
                net = self.conv_layer_with_bn(
                    net, [3, 3, in_channels, out_channels],
                    self.train_phase, name)

        # Produce class scores
        # score_1 dimensions: BATCH_SIZE * WIDTH * HEIGHT * NUM_CLASSES
        score_1 = self.conv_layer_with_bn(net,
                                          [1, 1, 32, self.num_classes],
                                          self.train_phase,
                                          'score_1')
        logits = tf.reshape(score_1, (-1, self.num_classes))

        # Prepare network outputs
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(expected, [-1]),
            logits=logits,
            name='x_entropy')
        self.loss = tf.reduce_mean(cross_entropy, name='x_entropy_mean')

        # Batch norm registers its moving-average updates in UPDATE_OPS;
        # they only run if the train op depends on them.
        optimizer = tf.train.AdamOptimizer(self.rate)
        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)[num_prior_update_ops:]
        with tf.control_dependencies(update_ops):
            self.train_step = optimizer.minimize(self.loss)

        # Metrics
        self.prediction = tf.argmax(score_1, axis=3, name="prediction")
        self.accuracy = tf.contrib.metrics.accuracy(self.prediction,
                                                    self.y,
                                                    name='accuracy')
        self.mean_IoU = tf.contrib.metrics.streaming_mean_iou(self.prediction,
                                                              self.y,
                                                              self.num_classes,
                                                              name='mean_IoU')

        # Per-pixel class probabilities, logged as bandit propensities.
        self.propensity = tf.nn.softmax(logits)


    def restore_session(self):
        global_step = 0

        if not os.path.exists(self.checkpoint_directory):
            raise IOError(self.checkpoint_directory + ' does not exist.')
        else:
            path = tf.train.get_checkpoint_state(self.checkpoint_directory)
            if path is None:
                pass
            else:
                self.saver.restore(self.session, path.model_checkpoint_path)
                global_step = int(path.model_checkpoint_path.split('-')[-1])

        return global_step


    def train(self, dataset_directory, num_iterations=10000, learning_rate=0.1,
              batch_size=5):
        """Train with cross-entropy loss, resuming from the last checkpoint.

        Args:
            dataset_directory: directory handed to ``BatchDatasetReader``
            num_iterations: iteration index at which training stops
            learning_rate: value fed to the ``rate`` placeholder
            batch_size: batch size handed to ``BatchDatasetReader``; also
                used for the logged epoch estimate
        """
        current_step = self.restore_session()

        bdr = BatchDatasetReader(dataset_directory, 480, 320, current_step,
                                 batch_size, trainval_only=True)

        # Begin Training
        for i in range(current_step, num_iterations):

            # One training step
            images, ground_truths = bdr.next_training_batch()
            feed_dict = {self.x: images, self.y: ground_truths,
                         self.train_phase: 1, self.rate: learning_rate}
            print('run train step: ' + str(i))
            self.train_step.run(session=self.session, feed_dict=feed_dict)

            # Print loss every 10 iterations
            if i % 10 == 0:
                train_loss = self.session.run(self.loss, feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (i, train_loss))

            # Run against the validation dataset every 100 iterations
            if i % 100 == 0:
                # Feed into network
                images, ground_truths = bdr.next_val_batch()
                feed_dict = {self.x: images, self.y: ground_truths,
                             self.train_phase: 1, self.rate: learning_rate}

                # One forward pass for all metrics instead of three.
                val_loss, val_accuracy, (val_mean_IoU, _) = self.session.run(
                    [self.loss, self.accuracy, self.mean_IoU],
                    feed_dict=feed_dict)

                print("%s ---> Validation_loss: %g" %
                      (datetime.datetime.now(), val_loss))
                print("%s ---> Validation_accuracy: %g" %
                      (datetime.datetime.now(), val_accuracy))

                # Log stuff
                self.logger.log("%s ---> Number of epochs: %g\n" %
                                (datetime.datetime.now(),
                                 math.floor((i * batch_size) / bdr.num_train)))
                self.logger.log("%s ---> Number of iterations: %g\n" %
                                (datetime.datetime.now(), i))
                self.logger.log("%s ---> Validation_loss: %g\n" %
                                (datetime.datetime.now(), val_loss))
                self.logger.log("%s ---> Validation_accuracy: %g\n" %
                                (datetime.datetime.now(), val_accuracy))
                self.logger.log_for_graphing(i, val_loss, val_accuracy,
                                             val_mean_IoU)

                # Save the model variables
                self.saver.save(self.session,
                                self.checkpoint_directory + 'segnet',
                                global_step=i)

            # Print outputs every 1000 iterations
            if i % 1000 == 0:
                self.logger.graph_training_stats()


    def build_contextual_feedback_log(self, dataset_directory,
                                      learning_rate=0.1):
        """Run the restored model over the validation set and pickle a log.

        For every validation pair an ``(index, loss, propensity)`` tuple is
        recorded. The log is written to ``./logged_bandit_feedback/log-<k>``
        whenever the index reaches a non-zero multiple of 100; entries left
        over after the last such index are written to one more file so that
        no sample is dropped. Finally the total number of samples is appended
        to the ``meta`` file in the same directory.

        Args:
            dataset_directory: directory handed to ``ValDatasetReader``
            learning_rate: value fed to the ``rate`` placeholder
        """
        # Get trained weights and biases
        self.restore_session()

        DESIRED_LOG_SIZE_MULTIPLIER = 1

        output_dir = "./logged_bandit_feedback/"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        log = []
        index = -1  # index of the most recently logged sample
        dr = ValDatasetReader(480, 320, dataset_directory)
        for j in range(DESIRED_LOG_SIZE_MULTIPLIER):
            for offset in range(dr.val_data_size):
                index = (dr.val_data_size * j) + offset

                image, ground_truth, _ = dr.next_val_pair()

                feed_dict = {self.x: [image], self.y: [ground_truth],
                             self.train_phase: 1, self.rate: learning_rate}

                # Loss and propensity come from the same forward pass.
                loss, propensity = self.session.run(
                    [self.loss, self.propensity], feed_dict=feed_dict)

                log.append((index, loss, propensity))
                print(index)

                if index % 100 == 0 and index != 0:
                    chunk_path = output_dir + 'log-' + str(index // 100)
                    with open(chunk_path, 'wb') as fp:
                        pickle.dump(log, fp)
                    log = []

        # Flush the entries collected after the last full chunk; previously
        # they were silently discarded.
        if log:
            tail_path = output_dir + 'log-' + str(index // 100 + 1)
            with open(tail_path, 'wb') as fp:
                pickle.dump(log, fp)

        with open(output_dir + 'meta', 'a') as metafile:
            metafile.write("size, " + str(dr.val_data_size *
                                          DESIRED_LOG_SIZE_MULTIPLIER))
示例#3
0
class BatchSegDeconvNet:
  ''' Network described by,
  https://arxiv.org/pdf/1505.04366v1.pdf
  and https://arxiv.org/pdf/1505.07293.pdf
  and https://arxiv.org/pdf/1511.00561.pdf 

  Hybrid network: SegNet + DeconvNet

  The encoder convolutions are initialised from pre-trained VGG-19 weights;
  the decoder unpools with the encoder's argmax indices and applies
  transposed convolutions. Constructing an instance builds the graph, opens
  a session and initialises all variables. '''

  def load_vgg_weights(self):
    """ Use the VGG model trained on
      imagenet dataset as a starting point for training.

      Populates self.vgg_params (the squeezed 'layers' array of the .mat
      file) and self.layers (layer names, in the order used to index
      self.vgg_params). """
    vgg_path = "models/imagenet-vgg-verydeep-19.mat"
    vgg_mat = scipy.io.loadmat(vgg_path)

    self.vgg_params = np.squeeze(vgg_mat['layers'])
    self.layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
            'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
            'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
            'relu5_3', 'conv5_4', 'relu5_4')

  def __init__(self, dataset_directory, num_classes=11):
    """
      Args:
        dataset_directory: passed to BatchDatasetReader when training
        num_classes: number of output classes of the score layer
    """
    self.dataset_directory = dataset_directory

    self.num_classes = num_classes

    self.load_vgg_weights()

    self.build()

    # Begin a TensorFlow session
    config = tf.ConfigProto(allow_soft_placement=True)
    self.session = tf.Session(config=config)
    self.session.run(tf.global_variables_initializer())

    # Make saving trained weights and biases possible
    self.saver = tf.train.Saver(max_to_keep = 5, keep_checkpoint_every_n_hours = 1)
    self.checkpoint_directory = './checkpoints/'
    self.logger = Logger()

  def vgg_weight_and_bias(self, name, W_shape, b_shape):
    """ 
      Initializes weights and biases to the pre-trained VGG model.
      
      Args:
        name: name of the layer for which you want to initialize weights
        W_shape: shape of weights tensor expected
        b_shape: shape of bias tensor expected
      returns:
        w_var: Initialized weight variable
        b_var: Initialized bias variable
      raises:
        KeyError: if name is not one of self.layers
    """
    if name not in self.layers:
      raise KeyError("Layer missing in VGG model or mispelled. ")

    w, b = self.vgg_params[self.layers.index(name)][0][0][0][0]
    # The first two kernel axes are swapped before use.
    # NOTE(review): presumably the .mat file stores kernels width-first - confirm.
    init_w = tf.constant(value=np.transpose(w, (1, 0, 2, 3)), dtype=tf.float32, shape=W_shape)
    init_b = tf.constant(value=b.reshape(-1), dtype=tf.float32, shape=b_shape)
    w_var = tf.Variable(init_w)
    b_var = tf.Variable(init_b)
    return w_var, b_var

  def weight_variable(self, shape):
    """ Returns a variable of the given shape drawn from a truncated normal (stddev 0.1). """
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

  def bias_variable(self, shape):
    """ Returns a variable of the given shape filled with 0.1. """
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

  def pool_layer(self, x):
    """ 2x2, stride-2 max pooling; returns (pooled, argmax) so the indices can be reused by unpool. """
    return tf.nn.max_pool_with_argmax(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

  def unpool(self, pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:   max pooled output tensor
           ind:      argmax indices
           ksize:     ksize is the same as for the pool (any indexable
                      sequence; an immutable tuple is used as the default)
       Return:
           unpool:    unpooling tensor
    """
    with tf.variable_scope(scope):
      input_shape =  tf.shape(pool)
      output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

      # Total number of elements in the pooled tensor
      flat_input_size = tf.cumprod(input_shape)[-1]
      # Output is scattered into a [batch, height * width * channels] tensor
      flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

      pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
      # Batch index of every pooled element, same dtype as the argmax indices
      batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype), 
                                        shape=tf.stack([input_shape[0], 1, 1, 1]))
      b = tf.ones_like(ind) * batch_range
      b = tf.reshape(b, tf.stack([flat_input_size, 1]))
      ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
      # Each row becomes (batch index, flattened argmax index)
      ind_ = tf.concat([b, ind_], 1)

      ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
      ret = tf.reshape(ret, tf.stack(output_shape))
      return ret

  
  def unravel_argmax(self, argmax, shape):
    ''' Implementation idea from: 
        https://github.com/tensorflow/tensorflow/issues/2169

        Splits flattened argmax indices into two stacked coordinate tensors:
        argmax // (shape[2] * shape[3]) and
        (argmax % (shape[2] * shape[3])) // shape[3]. '''
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.stack(output_list)
  
  def unpool_layer2x2_batch(self, x, argmax):
    '''
    Args:
        x: 4D tensor of shape [batch_size x height x width x channels]
        argmax: A Tensor of type Targmax. 4-D. The flattened indices of the max
        values chosen for each output.
    Return:
        4D output tensor of shape [batch_size x 2*height x 2*width x channels]
    '''
    bottom_shape = tf.shape(x)
    top_shape = [bottom_shape[0], bottom_shape[1]*2, bottom_shape[2]*2, bottom_shape[3]]

    batch_size = top_shape[0]
    height = top_shape[1]
    width = top_shape[2]
    channels = top_shape[3]

    argmax_shape = tf.to_int64([batch_size, height, width, channels])
    argmax = self.unravel_argmax(argmax, argmax_shape)

    # t1: channel index of every element
    t1 = tf.to_int64(tf.range(channels))
    t1 = tf.tile(t1, [batch_size*(width//2)*(height//2)])
    t1 = tf.reshape(t1, [-1, channels])
    t1 = tf.transpose(t1, perm=[1, 0])
    t1 = tf.reshape(t1, [channels, batch_size, height//2, width//2, 1])
    t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])

    # t2: batch index of every element
    t2 = tf.to_int64(tf.range(batch_size))
    t2 = tf.tile(t2, [channels*(width//2)*(height//2)])
    t2 = tf.reshape(t2, [-1, batch_size])
    t2 = tf.transpose(t2, perm=[1, 0])
    t2 = tf.reshape(t2, [batch_size, channels, height//2, width//2, 1])

    # t3: the two unravelled spatial coordinates, moved to the last axis
    t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])

    # One 4-tuple (batch, coord, coord, channel) per pooled value
    t = tf.concat([t2, t3, t1], 4)
    indices = tf.reshape(t, [(height//2)*(width//2)*channels*batch_size, 4])

    # Values are flattened channel-major to line up with the indices above
    x1 = tf.transpose(x, perm=[0, 3, 1, 2])
    values = tf.reshape(x1, [-1])

    delta = tf.SparseTensor(indices, values, tf.to_int64(top_shape))
    return tf.sparse_tensor_to_dense(tf.sparse_reorder(delta))
        
  def conv_layer(self, x, W_shape, b_shape, name, padding='SAME'):
    """ Stride-1 convolution + bias + ReLU, with weights taken from VGG layer `name`. """
    # Pass b_shape as list because need the object to be iterable for the constant initializer
    W, b = self.vgg_weight_and_bias(name, W_shape, [b_shape])

    output = tf.nn.conv2d(x, W, strides=[1,1,1,1], padding=padding) + b
    return tf.nn.relu(output)

  def deconv_layer(self, x, W_shape, b_shape, name, padding='SAME'):
    """ Stride-1 transposed convolution + bias with freshly initialised weights.

        The spatial size of x is preserved and the output has W_shape[2]
        channels. `name` is accepted for symmetry with conv_layer but is
        not used. """
    W = self.weight_variable(W_shape)
    b = self.bias_variable([b_shape])
    x_shape = tf.shape(x)
    out_shape = tf.stack([x_shape[0], x_shape[1], x_shape[2], W_shape[2]])
    return tf.nn.conv2d_transpose(x, W, out_shape, [1, 1, 1, 1], padding=padding) + b

  def build(self):
    """ Builds the graph: placeholders, five encoder stages, five decoder
        stages, the per-pixel score layer, the loss, the Adam train step and
        the prediction / accuracy ops. """
    with tf.device('/gpu:0'):
      # Declare placeholders
      self.x = tf.placeholder(tf.float32, shape=(None, None, None, 3))
      # self.y holds one integer label per pixel: rank 3, batch first
      self.y = tf.placeholder(tf.int64, shape=(None, None, None))
      expected = tf.expand_dims(self.y, -1)
      self.rate = tf.placeholder(tf.float32, shape=[])

      # First encoder
      conv_1_1 = self.conv_layer(self.x, [3, 3, 3, 64], 64, 'conv1_1')
      conv_1_2 = self.conv_layer(conv_1_1, [3, 3, 64, 64], 64, 'conv1_2')
      pool_1, pool_1_argmax = self.pool_layer(conv_1_2)

      # Second encoder
      conv_2_1 = self.conv_layer(pool_1, [3, 3, 64, 128], 128, 'conv2_1')
      conv_2_2 = self.conv_layer(conv_2_1, [3, 3, 128, 128], 128, 'conv2_2')
      pool_2, pool_2_argmax = self.pool_layer(conv_2_2)

      # Third encoder
      conv_3_1 = self.conv_layer(pool_2, [3, 3, 128, 256], 256, 'conv3_1')
      conv_3_2 = self.conv_layer(conv_3_1, [3, 3, 256, 256], 256, 'conv3_2')
      conv_3_3 = self.conv_layer(conv_3_2, [3, 3, 256, 256], 256, 'conv3_3')
      pool_3, pool_3_argmax = self.pool_layer(conv_3_3)

      # Fourth encoder
      conv_4_1 = self.conv_layer(pool_3, [3, 3, 256, 512], 512, 'conv4_1')
      conv_4_2 = self.conv_layer(conv_4_1, [3, 3, 512, 512], 512, 'conv4_2')
      conv_4_3 = self.conv_layer(conv_4_2, [3, 3, 512, 512], 512, 'conv4_3')
      pool_4, pool_4_argmax = self.pool_layer(conv_4_3)

      # Fifth encoder
      conv_5_1 = self.conv_layer(pool_4, [3, 3, 512, 512], 512, 'conv5_1')
      conv_5_2 = self.conv_layer(conv_5_1, [3, 3, 512, 512], 512, 'conv5_2')
      conv_5_3 = self.conv_layer(conv_5_2, [3, 3, 512, 512], 512, 'conv5_3')
      pool_5, pool_5_argmax = self.pool_layer(conv_5_3)

      # First decoder
      unpool_5 = self.unpool(pool_5, pool_5_argmax)
      deconv_5_3 = self.deconv_layer(unpool_5, [3, 3, 512, 512], 512, 'deconv5_3')
      deconv_5_2 = self.deconv_layer(deconv_5_3, [3, 3, 512, 512], 512, 'deconv5_2')
      deconv_5_1 = self.deconv_layer(deconv_5_2, [3, 3, 512, 512], 512, 'deconv5_1')

      # Second decoder
      unpool_4 = self.unpool(deconv_5_1, pool_4_argmax)
      deconv_4_3 = self.deconv_layer(unpool_4, [3, 3, 512, 512], 512, 'deconv4_3')
      deconv_4_2 = self.deconv_layer(deconv_4_3, [3, 3, 512, 512], 512, 'deconv4_2')
      deconv_4_1 = self.deconv_layer(deconv_4_2, [3, 3, 256, 512], 256, 'deconv4_1')

      # Third decoder
      unpool_3 = self.unpool(deconv_4_1, pool_3_argmax)
      deconv_3_3 = self.deconv_layer(unpool_3, [3, 3, 256, 256], 256, 'deconv3_3')
      deconv_3_2 = self.deconv_layer(deconv_3_3, [3, 3, 256, 256], 256, 'deconv3_2')
      deconv_3_1 = self.deconv_layer(deconv_3_2, [3, 3, 128, 256], 128, 'deconv3_1')

      # Fourth decoder
      unpool_2 = self.unpool(deconv_3_1, pool_2_argmax)
      deconv_2_2 = self.deconv_layer(unpool_2, [3, 3, 128, 128], 128, 'deconv2_2')
      deconv_2_1 = self.deconv_layer(deconv_2_2, [3, 3, 64, 128], 64, 'deconv2_1')

      # Fifth decoder
      unpool_1 = self.unpool(deconv_2_1, pool_1_argmax)
      deconv_1_2 = self.deconv_layer(unpool_1, [3, 3, 64, 64], 64, 'deconv1_2')
      deconv_1_1 = self.deconv_layer(deconv_1_2, [3, 3, 32, 64], 32, 'deconv1_1')

      # Produce class scores
      # score_1 has one score per class for every pixel (classes on the last axis)
      score_1 = self.deconv_layer(deconv_1_1, [1, 1, self.num_classes, 32], self.num_classes, 'score_1')
      logits = tf.reshape(score_1, (-1, self.num_classes))

      # Prepare network for training
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(expected, [-1]), logits=logits, name='x_entropy')
      self.loss = tf.reduce_mean(cross_entropy, name='x_entropy_mean')
      self.train_step = tf.train.AdamOptimizer(self.rate).minimize(self.loss)

      # Metrics
      self.prediction = tf.argmax(score_1, axis=3, name="prediction")
      self.accuracy = tf.contrib.metrics.accuracy(self.prediction, self.y, name='accuracy')

  def restore_session(self):
    """ Restores the latest checkpoint from self.checkpoint_directory, if any.

        Returns:
          The global step parsed from the checkpoint file name (the text
          after the last '-'), or 0 when the directory holds no checkpoint.
        Raises:
          IOError: if self.checkpoint_directory does not exist.
    """
    if not os.path.exists(self.checkpoint_directory):
      raise IOError(self.checkpoint_directory + ' does not exist.')

    path = tf.train.get_checkpoint_state(self.checkpoint_directory)
    if path is None:
      # Nothing saved yet: start from step 0
      return 0

    self.saver.restore(self.session, path.model_checkpoint_path)
    return int(path.model_checkpoint_path.split('-')[-1])

  
  def train(self, num_iterations=10000, learning_rate=1e-6, batch_size=5):
    """ Trains the network, resuming from the latest checkpoint if one exists.

        Args:
          num_iterations: step index at which training stops
          learning_rate: value fed to the Adam optimizer
          batch_size: number of samples per batch
    """
    current_step = self.restore_session()

    bdr = BatchDatasetReader(self.dataset_directory, 480, 320, current_step, batch_size)

    # Begin Training
    for i in range(current_step, num_iterations):

      # One training step
      images, ground_truths = bdr.next_training_batch()
      feed_dict = {self.x: images, self.y: ground_truths, self.rate: learning_rate}
      print('run train step: ' + str(i))
      self.train_step.run(session=self.session, feed_dict=feed_dict)

      # Print loss every 10 iterations (measured after the update above)
      if i % 10 == 0:
        train_loss = self.session.run(self.loss, feed_dict=feed_dict)
        print("Step: %d, Train_loss:%g" % (i, train_loss))

      # Run against the validation dataset every 100 iterations
      if i % 100 == 0:
        images, ground_truths = bdr.next_val_batch()
        feed_dict = {self.x: images, self.y: ground_truths, self.rate: learning_rate}
        # Fetch both metrics in one run: a single forward pass instead of two
        val_loss, val_accuracy = self.session.run(
          [self.loss, self.accuracy], feed_dict=feed_dict)
        print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), val_loss))
        print("%s ---> Validation_accuracy: %g" % (datetime.datetime.now(), val_accuracy))

        self.logger.log("%s ---> Number of epochs: %g\n" % (datetime.datetime.now(), math.floor((i * batch_size)/bdr.num_train)))
        self.logger.log("%s ---> Number of iterations: %g\n" % (datetime.datetime.now(), i))
        self.logger.log("%s ---> Validation_loss: %g\n" % (datetime.datetime.now(), val_loss))
        self.logger.log("%s ---> Validation_accuracy: %g\n" % (datetime.datetime.now(), val_accuracy))


        # Save the model variables
        self.saver.save(self.session, self.checkpoint_directory + 'segnet', global_step = i)
class Trainer:
    """Trains a ResNet classifier jointly with an ``IntraClsInfoMax`` module.

    The main optimizer updates the classifier together with the module's
    ``global_d`` and ``local_d`` parts; a second optimizer updates its
    ``prior_d`` part. After every epoch the model is evaluated on the
    validation and test loaders, and the weights with the best test accuracy
    seen so far are saved as ``best_test.pth`` in the logger's directory.
    """

    def __init__(self, args, device):
        """Build the model, data loaders, optimizers and logger.

        Args:
            args: parsed options; reads ``network``, ``n_classes``,
                ``alpha``, ``beta``, ``gamma``, ``epochs``,
                ``learning_rate``, ``train_all``, ``nesterov``, ``source``
                and ``target``.
            device: torch device the models and batches are moved to.
        """
        self.args = args
        self.device = device
        # Any network name other than 'resnet50' falls back to ResNet-18.
        backbone = resnet50 if args.network == 'resnet50' else resnet18
        model = backbone(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        self.D_model = IntraClsInfoMax(alpha=args.alpha,
                                       beta=args.beta,
                                       gamma=args.gamma).to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))

        # Main optimizer: classifier plus the global/local parts of D_model.
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            [self.model, self.D_model.global_d, self.D_model.local_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        # Separate optimizer for the prior part of D_model.
        self.dis_optimizer, self.dis_scheduler = get_optim_and_scheduler(
            [self.D_model.prior_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        # Best test accuracy seen so far; gates saving of 'best_test.pth'.
        self.max_test_acc = 0.0
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

    def _do_epoch(self, device='cuda'):
        """Run one training epoch, then evaluate on every test loader.

        Reads ``self.current_epoch`` (set by ``do_training``) to index
        ``self.results``. ``device`` is unused; batches are moved to
        ``self.device``.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        self.D_model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)

            self.optimizer.zero_grad()

            # Append a copy of the batch flipped along dim 3; labels are
            # duplicated to match.
            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            y, M = self.model(data, feature_flag=True)

            # Classification Loss
            class_logit = self.model.class_classifier(y)
            class_loss = criterion(class_logit, class_l)

            # G loss - DIM Loss - P_loss
            # M_prime / class_prime are the batch rotated by one position,
            # so each sample is paired with its neighbour's features/label.
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            class_prime = torch.cat((class_l[1:], class_l[0].unsqueeze(0)),
                                    dim=0)
            class_ll = (class_l, class_prime)

            DIM_loss = self.D_model(y, M, M_prime, class_ll)
            P_loss = self.D_model.prior_loss(y)

            DIM_loss = DIM_loss - P_loss
            loss = class_loss + DIM_loss
            loss.backward()
            self.optimizer.step()

            # Second update: the prior part, on features detached from the
            # classifier so only dis_optimizer's parameters get gradients.
            self.dis_optimizer.zero_grad()
            P_loss = self.D_model.prior_loss(y.detach())
            P_loss.backward()
            self.dis_optimizer.step()

            # Prediction
            _, cls_pred = class_logit.max(dim=1)

            losses = {
                'class': class_loss.detach().item(),
                'DIM': DIM_loss.detach().item(),
                'P_loss': P_loss.detach().item()
            }
            self.logger.log(
                it, len(self.source_loader), losses, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit, DIM_loss

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)

                class_correct = self.do_test(loader)

                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc
                if phase == 'test' and class_acc > self.max_test_acc:
                    # Record the new best; without this the checkpoint was
                    # overwritten on every epoch with non-zero accuracy.
                    self.max_test_acc = class_acc
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.logger.log_path,
                                     'best_{}.pth'.format(phase)))

    def do_test(self, loader):
        """Return the number of correctly classified samples in ``loader``.

        Each batch is expected as ``((data, _, class_l), _)``. The result
        is a tensor, or the int 0 when the loader yields no batches.
        """
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)

            class_logit = self.model(data, feature_flag=False)
            _, cls_pred = class_logit.max(dim=1)

            class_correct += torch.sum(cls_pred == class_l.data)

        return class_correct

    def do_training(self):
        """Train for ``args.epochs`` epochs and report the best results.

        Returns:
            ``(logger, model)``.
        """
        for self.current_epoch in range(self.args.epochs):
            # Both schedulers are stepped at the start of each epoch.
            self.scheduler.step()
            self.dis_scheduler.step()
            self.logger.new_epoch(
                [*self.scheduler.get_lr(), *self.dis_scheduler.get_lr()])
            self._do_epoch()  # use self.current_epoch
        val_res = self.results["val"]
        test_res = self.results["test"]
        # Report the test accuracy at the epoch with the best validation
        # accuracy, next to the overall best test accuracy.
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
示例#5
0
class DRAM(object):
    def __init__(self, config):
        self.config = config
        self.data_init()
        self.model_init()

    def data_init(self):
        """Create the dataset and the batch generator from ``self.config``."""
        print("\nData init")
        cfg = self.config
        dataset = Dataset(cfg)
        self.dataset = dataset
        self.generator = Generator(cfg, dataset)

    def model_init(self):
        """Build the full graph, the optimizer and the session.

        In order: placeholders, the glimpse and location networks, the LSTM
        unrolled for ``config.num_glimpses`` steps, the per-step baseline
        and classification heads, the hybrid loss with clipped gradients,
        the optimizer (``setup_optimization``), and finally a session with
        all variables initialised.

        Uses ``self.config`` as set in ``__init__``; the method previously
        re-assigned it from an undefined name ``config``.
        """
        self.rnn_cell = tf.contrib.rnn
        self.regularizer = tf.contrib.layers.l2_regularizer(
            scale=self.config.regularizer)
        self.initializer = tf.contrib.layers.xavier_initializer()
        self.images_ph = tf.placeholder(
            tf.float32,
            [None, self.config.input_shape, self.config.input_shape, 3])
        self.labels_ph = tf.placeholder(tf.int64, [None])
        # Dynamic batch size, taken from the fed images
        self.N = tf.shape(self.images_ph)[0]

        # ------- GlimpseNet / LocNet -------

        with tf.variable_scope('glimpse_net'):
            self.gl = ConvGlimpseNetwork(self.config, self.images_ph)

        with tf.variable_scope('loc_net'):
            self.loc_net = LocNet(self.config)

        # The first glimpse is taken at location (0, 0) for every sample.
        # It is built under the same scope name the decoder uses for later
        # glimpses, with AUTO_REUSE.
        self.init_loc = tf.zeros(shape=[self.N, 2], dtype=tf.float32)
        with tf.variable_scope("rnn_decoder/loop_function",
                               reuse=tf.AUTO_REUSE):
            self.init_glimpse = self.gl(self.init_loc)

        # Only the first input is real; the remaining entries are dummies
        # that the loop function replaces at every later step.
        self.inputs = [self.init_glimpse]
        self.inputs.extend([0] * (self.config.num_glimpses - 1))

        # ------- Recurrent network -------

        def get_next_input(output, i):
            """Map the previous LSTM output to the next glimpse.

            Also records the location mean, the sampled location and the
            glimpse for this step. ``i`` is unused.
            """
            loc, loc_mean = self.loc_net(output)
            gl_next = self.gl(loc)

            self.loc_mean_arr.append(loc_mean)
            self.sampled_loc_arr.append(loc)
            self.glimpses.append(self.gl.glimpse)

            return gl_next

        def rnn_decoder(decoder_inputs,
                        initial_state,
                        cell,
                        loop_function=None):
            """Unroll ``cell`` over ``decoder_inputs``.

            From the second step on, when ``loop_function`` is given, the
            input is ``loop_function(previous_output, i)`` instead of the
            entry in ``decoder_inputs``. Returns ``(outputs, final_state)``.
            """
            with tf.variable_scope("rnn_decoder"):
                state = initial_state
                outputs = []
                prev = None

                for i, inp in enumerate(decoder_inputs):
                    if loop_function is not None and prev is not None:
                        with tf.variable_scope("loop_function",
                                               reuse=tf.AUTO_REUSE):
                            inp = loop_function(prev, i)

                    # Share the cell variables across time steps
                    if i > 0:
                        tf.get_variable_scope().reuse_variables()

                    output, state = cell(inp, state)
                    outputs.append(output)

                    if loop_function is not None:
                        prev = output

            return outputs, state

        # Per-step records, seeded with the initial location and glimpse
        self.loc_mean_arr = [self.init_loc]
        self.sampled_loc_arr = [self.init_loc]
        self.glimpses = [self.gl.glimpse]

        self.lstm_cell = self.rnn_cell.LSTMCell(self.config.cell_size,
                                                state_is_tuple=True,
                                                activation=tf.nn.tanh,
                                                forget_bias=1.)
        self.init_state = self.lstm_cell.zero_state(self.N, tf.float32)
        self.outputs, self.rnn_state = rnn_decoder(
            self.inputs,
            self.init_state,
            self.lstm_cell,
            loop_function=get_next_input)

        # ------- Classification -------

        # One baseline prediction per step, sharing weights across steps.
        # NOTE(review): units=2 makes the stacked baselines rank 3, so the
        # subtraction from the [batch, glimpses] rewards below relies on
        # broadcasting - confirm this is intended rather than units=1.
        baselines = []
        for t, output in enumerate(self.outputs):
            with tf.variable_scope('baseline', reuse=tf.AUTO_REUSE):
                baseline_t = tf.layers.dense(
                    inputs=output,
                    units=2,
                    kernel_initializer=self.initializer)
            baseline_t = tf.squeeze(baseline_t)
            baselines.append(baseline_t)

        baselines = tf.stack(baselines)
        self.baselines = tf.transpose(baselines)

        # Class probabilities after every glimpse, computed with the shared
        # 'FCCN' layer. Gradients are stopped, so these do not train it.
        with tf.variable_scope('classification', reuse=tf.AUTO_REUSE):
            self.class_prob_arr = []
            for t, op in enumerate(self.outputs):
                self.glimpse_logit = tf.layers.dense(
                    inputs=op,
                    units=self.config.num_classes,
                    kernel_initializer=self.initializer,
                    name='FCCN',
                    reuse=tf.AUTO_REUSE)
                self.glimpse_logit = tf.stop_gradient(self.glimpse_logit)
                self.glimpse_logit = tf.nn.softmax(self.glimpse_logit)
                self.class_prob_arr.append(self.glimpse_logit)
            self.class_prob_arr = tf.stack(self.class_prob_arr, axis=1)

        # The trainable prediction uses the last LSTM output only
        self.output = self.outputs[-1]
        with tf.variable_scope('classification', reuse=tf.AUTO_REUSE):
            self.logits = tf.layers.dense(inputs=self.output,
                                          units=self.config.num_classes,
                                          kernel_initializer=self.initializer,
                                          name='FCCN',
                                          reuse=tf.AUTO_REUSE)

            self.softmax = tf.nn.softmax(self.logits)

        # Reshape the recorded locations to batch-major and prepend the
        # initial location once more
        self.sampled_locations = tf.concat(self.sampled_loc_arr, axis=0)
        self.mean_locations = tf.concat(self.loc_mean_arr, axis=0)
        self.sampled_locations = tf.reshape(
            self.sampled_locations, (self.config.num_glimpses, self.N, 2))
        self.sampled_locations = tf.transpose(self.sampled_locations,
                                              [1, 0, 2])
        self.mean_locations = tf.reshape(self.mean_locations,
                                         (self.config.num_glimpses, self.N, 2))
        self.mean_locations = tf.transpose(self.mean_locations, [1, 0, 2])
        prefix = tf.expand_dims(self.init_loc, 1)
        self.sampled_locations = tf.concat([prefix, self.sampled_locations],
                                           axis=1)
        self.mean_locations = tf.concat([prefix, self.mean_locations], axis=1)
        self.glimpses = tf.stack(self.glimpses, axis=1)

        # Losses/reward

        def loglikelihood(mean_arr, sampled_arr, sigma):
            """Log-probability of the sampled locations under Normal(mean, sigma).

            The log-probabilities are summed over axis 2 and the result is
            transposed.
            """
            mu = tf.stack(mean_arr)
            sampled = tf.stack(sampled_arr)
            gaussian = tf.contrib.distributions.Normal(mu, sigma)
            logll = gaussian.log_prob(sampled)
            logll = tf.reduce_sum(logll, 2)
            logll = tf.transpose(logll)
            return logll

        self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.logits, labels=self.labels_ph)
        self.xent = tf.reduce_mean(self.xent)

        # Reward is 1 for a correct final prediction and 0 otherwise,
        # repeated for every glimpse
        self.pred_labels = tf.argmax(self.logits, 1)
        self.reward = tf.cast(tf.equal(self.pred_labels, self.labels_ph),
                              tf.float32)
        self.rewards = tf.expand_dims(self.reward, 1)
        self.rewards = tf.tile(self.rewards, [1, self.config.num_glimpses])
        self.logll = loglikelihood(self.loc_mean_arr, self.sampled_loc_arr,
                                   self.config.loc_std)
        # The baseline is not trained through the policy term
        self.advs = self.rewards - tf.stop_gradient(self.baselines)
        self.logllratio = tf.reduce_mean(self.logll * self.advs)

        self.reward = tf.reduce_mean(self.reward)

        self.baselines_mse = tf.reduce_mean(
            tf.square((self.rewards - self.baselines)))
        self.var_list = tf.trainable_variables()

        # Hybrid loss: policy term + classification + baseline regression
        self.loss = -self.logllratio + self.xent + self.baselines_mse
        self.grads = tf.gradients(self.loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(self.grads,
                                               self.config.max_grad_norm)

        self.setup_optimization()

        # session
        self.session_config = tf.ConfigProto()
        self.session_config.gpu_options.visible_device_list = self.config.gpu
        self.session_config.gpu_options.allow_growth = True
        self.session = tf.Session(config=self.session_config)
        self.session.run(tf.global_variables_initializer())

    def setup_optimization(self):
        """Create the step counter, learning-rate schedule and train op.

        The learning rate decays exponentially (factor 0.70 per
        ``training_steps_per_epoch`` steps, non-staircase) and is floored at
        ``config.lr_min``. The train op applies the already clipped
        ``self.grads`` to ``self.var_list`` with Nesterov momentum.
        """
        cfg = self.config

        # Non-trainable counter, incremented by apply_gradients below
        self.global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0),
            trainable=False)

        num_training_ids = len(self.generator.training_ids)
        self.training_steps_per_epoch = int(num_training_ids // cfg.batch_size)
        print('Training Step Per Epoch:', self.training_steps_per_epoch)

        self.starter_learning_rate = cfg.lr_start
        decayed_rate = tf.train.exponential_decay(
            learning_rate=self.starter_learning_rate,
            global_step=self.global_step,
            decay_steps=self.training_steps_per_epoch,
            decay_rate=0.70,
            staircase=False)
        self.learning_rate = tf.maximum(decayed_rate, cfg.lr_min)

        self.optimizer = tf.train.MomentumOptimizer(
            self.learning_rate, momentum=0.90, use_nesterov=True)

        grads_and_vars = zip(self.grads, self.var_list)
        self.train_op = self.optimizer.apply_gradients(
            grads_and_vars, global_step=self.global_step)

    def setup_logger(self):
        """Create the summary ops, the evaluation ops and the logger."""

        # Scalars that are both written as summaries and fetched for evaluation
        tracked_scalars = [
            ('reward', self.reward),
            ('hybrid_loss', self.loss),
            ('cross_entropy', self.xent),
            ('baseline_mse', self.baselines_mse),
            ('logllratio', self.logllratio),
            ('lr', self.learning_rate),
        ]

        self.summary_ops = {
            tag: tf.summary.scalar(tag, tensor)
            for tag, tensor in tracked_scalars
        }

        # Evaluation additionally fetches the labels and the predictions
        self.eval_ops = dict(
            [('labels', self.labels_ph), ('pred_labels', self.pred_labels)]
            + tracked_scalars)

        self.logger = Logger(self.config.logdir,
                             sess=self.session,
                             summary_ops=self.summary_ops,
                             global_step=self.global_step,
                             eval_ops=self.eval_ops,
                             n_verbose=self.config.n_verbose,
                             var_list=self.var_list)

    def train(self):
        """Run ``config.steps + 1`` training steps, then test once.

        On the first step the four output directories under
        ``config.logdir`` are wiped and recreated. Every step runs one
        optimizer update on a generated batch, prints the fetched metrics
        and logs them under 'train'. ``self.test`` is called on the last
        step. Periodic validation is currently disabled.
        """
        banner = ('\n\n\n------------ Starting training ------------  \nT -- %s x %s \n'
                  'Model:  %s glimpses, glimpse size %s x %s \n\n\n')
        print(banner % (
            self.config.input_shape, self.config.input_shape,
            self.config.num_glimpses, self.config.glimpse_size,
            self.config.glimpse_size))

        self.setup_logger()

        for i in range(self.config.steps + 1):

            loc_dir_name = self.config.logdir + '/image/locations'
            traj_dir_name = self.config.logdir + '/image/trajectories'
            ROCs_dir_name = self.config.logdir + '/metrics/ROCs_AUCs/'
            PRs_dir_name = self.config.logdir + '/metrics/PRs/'

            if i == 0:
                # Start from empty output directories
                for dir_name in (loc_dir_name, traj_dir_name,
                                 ROCs_dir_name, PRs_dir_name):
                    if os.path.exists(dir_name):
                        shutil.rmtree(dir_name)
                    os.makedirs(dir_name)

            self.logger.step = i

            images, labels = self.generator.generate()
            images = images.reshape(
                (-1, self.config.input_shape, self.config.input_shape, 3))
            labels = labels[0]
            feed_dict = {self.images_ph: images, self.labels_ph: labels}

            # train_op is fetched with the metrics, so this run also
            # performs the parameter update
            fetches = [
                self.output, self.rewards, self.reward, self.labels_ph,
                self.pred_labels, self.logits, self.train_op, self.loss,
                self.xent, self.baselines_mse, self.logllratio,
                self.learning_rate, self.loc_mean_arr
            ]
            fetched = self.session.run(fetches, feed_dict)
            (output, rewards, reward, real_labels, pred_labels, logits, _,
             hybrid_loss, cross_entropy, baselines_mse, logllratio, lr,
             locations) = fetched

            print('\n------ Step %s ------' % (i))
            report = (
                ('reward', reward),
                ('labels', real_labels),
                ('pred_labels', pred_labels),
                ('hybrid_loss', hybrid_loss),
                ('cross_entropy', cross_entropy),
                ('baseline_mse', baselines_mse),
                ('logllratio', logllratio),
                ('lr', lr),
                ('locations', locations[-1]),
                ('logits', logits),
            )
            for label, value in report:
                print(label, value)
            self.logger.log('train', feed_dict=feed_dict)

            if i == self.config.steps:

                self.test(i)

            #if i == self.config.steps:
        # if i > 0 and i % 100 == 0:

        #    glimpse_images = self.session.run(self.glimpses, feed_dict)
        #   mean_locations = self.session.run(self.mean_locations, feed_dict)
        #  probs = self.session.run(self.class_prob_arr, feed_dict)

        # plot_glimpses(config=self.config, glimpse_images=glimpse_images, pred_labels=pred_labels, probs=probs,
        #   sampled_loc=mean_locations, X=images, labels=real_labels, file_name=loc_dir_name, step=i)

        #plot_trajectories(config=self.config, locations=mean_locations, X=images, labels=real_labels,
        #   pred_labels=pred_labels, file_name=traj_dir_name, step=i)

        #self.logger.save()

    def eval(self, step):
        """Evaluate on the validation split using this model's own session.

        Args:
            step: step index forwarded to ``evaluate`` (used there to name
                the metric output files).

        Returns:
            Whatever ``self.evaluate`` returns.
        """
        session, images_ph, labels_ph = (self.session, self.images_ph,
                                         self.labels_ph)
        return self.evaluate(session, images_ph, labels_ph, self.softmax,
                             step)

    def evaluate(self, sess, images_ph, labels_ph, softmax, step):
        """Score the validation partition and write ROC and summary metrics.

        The validation arrays are rebuilt from ``self.dataset`` on every call
        and stored on ``self.X_val`` / ``self.y_val``.  Batches are pushed
        through ``softmax``; the collected scores and labels are handed to
        ``metrics_ROCs`` and ``metrics``.

        Args:
            sess: session used to run ``softmax``.
            images_ph: placeholder fed with the image batch.
            labels_ph: placeholder fed with the label batch.
            softmax: tensor whose value is collected as the per-example score.
            step: step index forwarded to the metric writers.
        """
        cfg = self.config
        print('Evaluating (%s x %s) using %s glimpses' %
              (cfg.input_shape, cfg.input_shape, cfg.num_glimpses))

        val_partition = self.dataset._partition[0]['val']
        self.X_val, self.y_val = self.dataset.convert_to_arrays(
            val_partition, size=cfg.sampling_size_val)
        print('Validation set has %s patients' % len(self.y_val))

        X_val, y_val = self.X_val, self.y_val

        # Only complete batches are evaluated; a trailing partial batch is
        # left out by the floor division.
        n_batches = X_val.shape[0] // cfg.eval_batch_size

        y_trues, y_scores = [], []
        for batch_idx in tqdm(iter(range(n_batches))):
            images, labels_val = self.dataset.next_batch(
                X_val, y_val[0], cfg.eval_batch_size, batch_idx)
            batch_feed = {images_ph: images, labels_ph: labels_val}
            softmax_val = sess.run(softmax, feed_dict=batch_feed)
            y_trues.extend(labels_val)
            y_scores.extend(softmax_val)

        # Predicted class is the arg-max over the score axis.
        y_preds = np.argmax(y_scores, 1)
        y_scores = np.array(y_scores)

        self.metrics_ROCs(y_trues, y_preds, y_scores, step)
        self.metrics(y_trues, y_preds, step)

    def count_params(self):
        """Print and return the parameter count using this model's session.

        Returns:
            Whatever ``self.count_parameters`` returns for ``self.session``.
        """
        active_session = self.session
        return self.count_parameters(active_session)

    def count_parameters(self, sess):
        """Print every trainable variable with its shape and size.

        Args:
            sess: session used to fetch the current variable values.

        Returns:
            Total number of elements across all trainable variables.
        """
        separator = '-'.center(140, '-')
        names = [var.name for var in tf.trainable_variables()]
        # Fetching by tensor name yields the current value of each variable.
        arrays = sess.run(names)

        total = 0
        for name, array in zip(names, arrays):
            print(separator)
            print('%s \t Shape: %s \t %s parameters' %
                  (name, array.shape, array.size))
            total += array.size

        print(separator)
        print('Total # parameters:\t\t %s \n\n' % (total))
        return total

    def metrics_ROCs(self, y_trues, y_preds, y_scores, step, stage=None):
        """Plot the ROC curve with its AUC and save it under the log directory.

        Args:
            y_trues: ground-truth labels passed to ``roc_curve``.
            y_preds: predicted labels; not used by the plot, kept so existing
                callers keep working.
            y_scores: scores passed to ``roc_curve``.  When ``stage`` is
                ``'test'`` they are used as given, otherwise column 1 of the
                array is taken as the score.
            step: step index used as the output file name.
            stage: ``'test'`` selects the un-sliced score path (see above).
        """
        scores = y_scores if stage == 'test' else y_scores[:, 1]
        fpr, tpr, _ = roc_curve(y_trues, scores)
        roc_auc = auc(fpr, tpr)

        fig = plt.figure()
        try:
            plt.plot(fpr,
                     tpr,
                     label='ROC curve (AUC = {0:0.2f})'.format(roc_auc),
                     color='navy',
                     linestyle=':',
                     linewidth=4)

            # Diagonal reference line.
            plt.plot([0, 1], [0, 1], 'k--', lw=2)
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiving Operating Characteristic Curves')
            plt.legend(loc="lower right")
            plt.savefig(self.config.logdir + '/metrics/ROCs_AUCs/%i' % step)
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate open figures.
            plt.close(fig)

    def metrics(self, y_trues, y_preds, step):
        """Write accuracy, F1, recall and precision for one step to a CSV file.

        Args:
            y_trues: ground-truth labels.
            y_preds: predicted labels.
            step: step index used in the output file name
                (``<logdir>/metrics/metrics_<step>.csv``).
        """
        scorers = (
            ('accuracy', accuracy_score),
            ('f1_score', f1_score),
            ('recall', recall_score),
            ('precision', precision_score),
        )
        values = [score_fn(y_trues, y_preds) for _, score_fn in scorers]
        names = [label for label, _ in scorers]

        table = pd.DataFrame(data=np.array(values), index=names)
        table.to_csv(self.config.logdir + '/metrics/metrics_%i.csv' % step)

    def load(self, checkpoint_dir):
        """Restore ``self.var_list`` from ``<checkpoint_dir>/checkpoints``.

        A new saver over ``self.var_list`` replaces ``self.saver``.  If no
        checkpoint state (or no checkpoint path) is found in the folder the
        session is left untouched.

        Args:
            checkpoint_dir: directory that contains the ``checkpoints``
                sub-folder.
        """
        folder = os.path.join(checkpoint_dir, 'checkpoints')
        print('\nLoading model from <<{}>>.\n'.format(folder))

        self.saver = tf.train.Saver(self.var_list)
        state = tf.train.get_checkpoint_state(folder)

        # Nothing to restore.
        if not state or not state.model_checkpoint_path:
            return

        print(state)
        self.saver.restore(self.session, state.model_checkpoint_path)

    def patch_to_image(self, y_patches, proba=True):
        """Collapse per-patch values into one value per image.

        Consecutive runs of ``config.sampling_size_test`` entries of
        ``y_patches`` are treated as the patches of one image.  Entries left
        over after the last complete run are ignored.

        Args:
            y_patches: sequence of per-patch values (score vectors when
                ``proba`` is true, scalar labels otherwise).
            proba: if true, average each run along axis 0; otherwise average
                the whole run and threshold the mean at 0.5.

        Returns:
            ``np.ndarray`` holding one averaged entry per image when ``proba``
            is true, else a 1-D integer array of 0/1 labels.
        """
        patches_per_image = self.config.sampling_size_test
        n_images = len(y_patches) // patches_per_image
        groups = [
            y_patches[k * patches_per_image:(k + 1) * patches_per_image]
            for k in range(n_images)
        ]

        if proba:
            return np.array([np.mean(group, axis=0) for group in groups])

        # Majority vote: an image is positive when more than half of its
        # patches are.
        votes = np.array([np.mean(group) > 0.5 for group in groups])
        return votes.astype(int).flatten()

    def test(self, step):
        """Run the test-split evaluation using this model's own session.

        Args:
            step: step index forwarded to ``testing`` (used there to name
                the metric output files).

        Returns:
            Whatever ``self.testing`` returns.
        """
        session, images_ph, labels_ph = (self.session, self.images_ph,
                                         self.labels_ph)
        return self.testing(session, images_ph, labels_ph, self.softmax,
                            step)

    def testing(self, sess, images_ph, labels_ph, softmax, step):
        """Score the test partition per image and write ROC and summary metrics.

        The test arrays are rebuilt from ``self.dataset`` on every call and
        stored on ``self.X_test`` / ``self.y_test``.  Patch-level labels and
        scores are collected batch by batch, reduced to image level with
        ``patch_to_image`` and then handed to ``metrics_ROCs`` and
        ``metrics``.

        Args:
            sess: session used to run ``softmax``.
            images_ph: placeholder fed with the image batch.
            labels_ph: placeholder fed with the label batch.
            softmax: tensor whose value is collected as the per-patch score.
            step: step index forwarded to the metric writers.
        """
        cfg = self.config
        print('Testing (%s x %s) using %s glimpses' %
              (cfg.input_shape, cfg.input_shape, cfg.num_glimpses))

        test_partition = self.dataset._partition[0]['test']
        print(test_partition)
        self.X_test, self.y_test = self.dataset.convert_to_arrays(
            test_partition, size=cfg.sampling_size_test)
        X_test, y_test = self.X_test, self.y_test
        print('y_test', y_test)

        # Only complete batches are evaluated; a trailing partial batch is
        # left out by the floor division.
        n_batches = X_test.shape[0] // cfg.test_batch_size

        y_trues, y_scores = [], []
        for batch_idx in tqdm(iter(range(n_batches))):
            images, labels_test = self.dataset.next_batch(
                X_test, y_test[0], cfg.test_batch_size, batch_idx)

            print(labels_test)

            batch_feed = {images_ph: images, labels_ph: labels_test}
            softmax_test = sess.run(softmax, feed_dict=batch_feed)
            y_trues.extend(labels_test)
            y_scores.extend(softmax_test)

        # Reduce patch-level results to one label / score row per image.
        y_trues = self.patch_to_image(y_trues, proba=False)
        y_scores = self.patch_to_image(y_scores, proba=True)

        y_preds = np.argmax(y_scores, 1)

        print('Test Set', self.dataset._partition[0]['test'])
        print(y_trues)
        print(y_preds)

        self.metrics_ROCs(y_trues, y_preds, y_scores, step)
        self.metrics(y_trues, y_preds, step)
示例#6
0
class InvariantNet:
    """SegNet-style encoder/decoder with an optional dynamic-filter stage.

    The graph built by :meth:`build` consists of a five-stage convolutional
    encoder (max pooling with argmax), an optional dynamic filtering step on
    the deepest feature map, and a mirrored five-stage decoder that unpools
    with the stored argmax indices.  Layers whose names appear in
    ``self.layers`` are initialised from the pre-trained VGG-19 weights.
    """

    def __init__(self, dataset_directory, num_classes=11):
        """Build the graph, start a session and set up saving and logging.

        Args:
            dataset_directory: dataset location; stored on the instance.
                ``train`` and ``test`` take the directory they read from as
                an explicit argument.
            num_classes: number of classes scored per pixel.
        """
        self.dataset_directory = dataset_directory
        self.num_classes = num_classes
        self.load_vgg_weights()
        self.build()

        # Begin a TensorFlow session
        config = tf.ConfigProto(allow_soft_placement=True)
        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())
        self.session.run(tf.local_variables_initializer())

        # Make saving trained weights and biases possible
        self.saver = tf.train.Saver(max_to_keep=5,
                                    keep_checkpoint_every_n_hours=1)
        self.checkpoint_directory = './checkpoints/'

        self.logger = Logger()

    def load_vgg_weights(self):
        """Load the ImageNet-trained VGG-19 parameters used as a starting point.

        Reads ``models/imagenet-vgg-verydeep-19.mat`` and stores the squeezed
        ``layers`` entry on ``self.vgg_params``; ``self.layers`` lists the
        layer names in the order used to index into it.
        """
        vgg_path = "models/imagenet-vgg-verydeep-19.mat"
        vgg_mat = scipy.io.loadmat(vgg_path)

        self.vgg_params = np.squeeze(vgg_mat['layers'])
        self.layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
                       'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                       'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
                       'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
                       'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
                       'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
                       'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4',
                       'relu5_4')

    def vgg_weight_and_bias(self, name, W_shape, b_shape):
        """Create the weight and bias variables for one convolution layer.

        Layers named in ``self.layers`` are initialised from the pre-trained
        VGG parameters; every other layer (decoder and score layers) gets a
        random weight tensor and a constant bias.

        Args:
            name: name of the layer for which to initialise weights
            W_shape: shape of the weights tensor expected
            b_shape: shape of the bias tensor expected
        Returns:
            w_var: initialised weight variable
            b_var: initialised bias variable
        """
        if name not in self.layers:
            return self.weight_variable(W_shape), self.bias_variable(b_shape)

        w, b = self.vgg_params[self.layers.index(name)][0][0][0][0]
        # The stored kernel has its first two axes swapped relative to the
        # layout used here.
        init_w = tf.constant(value=np.transpose(w, (1, 0, 2, 3)),
                             dtype=tf.float32,
                             shape=W_shape)
        init_b = tf.constant(value=b.reshape(-1),
                             dtype=tf.float32,
                             shape=b_shape)
        w_var = tf.Variable(init_w)
        b_var = tf.Variable(init_b)
        return w_var, b_var

    def weight_variable(self, shape, is_trainable=True):
        """Return a variable drawn from a truncated normal (stddev 0.1).

        Args:
            shape: shape of the variable.
            is_trainable: whether the variable is created as trainable.
        """
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial, trainable=is_trainable)

    def bias_variable(self, shape, is_trainable=True):
        """Return a variable filled with the constant 0.1.

        Args:
            shape: shape of the variable.
            is_trainable: whether the variable is created as trainable.
        """
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial, trainable=is_trainable)

    def pool_layer(self, x):
        """2x2, stride-2 max pooling that also returns the argmax indices."""
        return tf.nn.max_pool_with_argmax(x,
                                          ksize=[1, 2, 2, 1],
                                          strides=[1, 2, 2, 1],
                                          padding='SAME')

    def unpool(self, pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
        """Unpooling layer after max_pool_with_argmax.

        Every pooled value is written back to the position recorded in
        ``ind`` inside an output that is ``ksize[1]`` x ``ksize[2]`` times
        larger spatially; all other positions are zero.

        Args:
            pool: max pooled output tensor
            ind: argmax indices
            ksize: ksize is the same as for the pool
            scope: variable scope wrapping the created ops
        Returns:
            unpool: unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)
            output_shape = [
                input_shape[0], input_shape[1] * ksize[1],
                input_shape[2] * ksize[2], input_shape[3]
            ]

            # Total number of pooled elements (last cumulative product).
            flat_input_size = tf.cumprod(input_shape)[-1]
            flat_output_shape = tf.stack([
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3]
            ])

            pool_ = tf.reshape(pool, tf.stack([flat_input_size]))

            # Pair every argmax index with the index of its batch element so
            # the scatter targets a (batch, flat position) buffer.
            # NOTE(review): this treats `ind` as a per-sample flat index;
            # confirm against the max_pool_with_argmax version in use.
            batch_range = tf.range(tf.cast(output_shape[0], tf.int64),
                                   dtype=ind.dtype)
            reshape_shape = tf.stack([input_shape[0], 1, 1, 1])
            reshaped_batch_range = tf.reshape(batch_range, shape=reshape_shape)

            b = tf.ones_like(ind) * reshaped_batch_range
            b = tf.reshape(b, tf.stack([flat_input_size, 1]))
            ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
            ind_ = tf.concat([b, ind_], 1)

            ret = tf.scatter_nd(ind_,
                                pool_,
                                shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, tf.stack(output_shape))
            return ret

    def conv_layer_with_bn(self, x, W_shape, name, padding='SAME'):
        """Stride-1 convolution + bias, followed by batch norm and ReLU.

        Args:
            x: input tensor.
            W_shape: kernel shape; its last entry is also the bias length.
            name: layer name used to look up pre-trained VGG weights.
            padding: padding mode passed to ``tf.nn.conv2d``.
        """
        b_shape = W_shape[3]
        W, b = self.vgg_weight_and_bias(name, W_shape, [b_shape])
        convolved_output = tf.nn.conv2d(
            x, W, strides=[1, 1, 1, 1], padding=padding) + b
        # NOTE(review): batch norm is always built in training mode, also
        # for the graph used by test().
        batch_norm = tf.contrib.layers.batch_norm(convolved_output,
                                                  is_training=True)
        return tf.nn.relu(batch_norm)

    def dynamic_filtering(self, pool_5):
        """Apply the theta-dependent dynamic 3x3 convolution to ``pool_5``."""
        filter_shape = [3, 3, 512, 512]
        df = self.gen_dynamic_filter(self.theta,
                                     pool_5,
                                     filter_shape=filter_shape)
        return self.dynamic_conv_layer(pool_5,
                                       filter_shape=filter_shape,
                                       dynamic_filter=df,
                                       name="conv_d")

    def gen_dynamic_filter(self, theta, pooled_layer, filter_shape):
        """Generate a filter from the pooled features and the view angle.

        Args:
            theta: scalar view-perspective value.
            pooled_layer: tensor of shape NUM_BATCHES * HEIGHT * WIDTH * 512
            filter_shape: target filter shape, e.g. [3, 3, 512, 512]

        Returns:
            Tensor of shape ``filter_shape``: one
            ``filter_shape[0]`` x ``filter_shape[1]`` kernel, averaged over
            the batch and repeated across the two channel axes.
        """
        # Average over channels:
        # feature_map shape = NUM_BATCHES * HEIGHT * WIDTH
        feature_map = tf.reduce_mean(pooled_layer, axis=3)

        # length = HEIGHT * WIDTH (10 * 15 = 150 for the 320 x 480 input
        # after five 2x2 poolings)
        length = feature_map.get_shape()[1] * feature_map.get_shape()[2]

        # features shape = NUM_BATCHES * length
        features = tf.reshape(feature_map, [-1, int(length)])

        # "Theta with append": the scaled angle is appended to every
        # feature row as one extra column.  (An earlier variant that
        # averaged theta into the features instead was dropped.)
        theta = theta / 20
        theta = tf.expand_dims(theta, 0)
        theta = tf.expand_dims(theta, 1)
        theta = tf.tile(theta, tf.stack([tf.shape(features)[0], 1]))
        features = tf.concat([features, theta], 1)

        fc1 = tf.contrib.layers.fully_connected(features, 64)
        fc2 = tf.contrib.layers.fully_connected(fc1, 128)
        fc3 = tf.contrib.layers.fully_connected(fc2,
                                                filter_shape[0] *
                                                filter_shape[1],
                                                activation_fn=None)
        # One kernel for the whole batch.
        fc3 = tf.reduce_mean(fc3, axis=0)
        filt = tf.reshape(fc3, filter_shape[0:2])
        filt = tf.expand_dims(filt, 2)
        filt = tf.expand_dims(filt, 3)
        filt = tf.tile(filt, [1, 1, filter_shape[2], filter_shape[3]])
        return filt

    def dynamic_conv_layer(self,
                           bottom,
                           filter_shape,
                           dynamic_filter,
                           name,
                           strides=(1, 1, 1, 1),
                           padding="SAME"):
        """Convolve ``bottom`` with a learned filter plus ``dynamic_filter``.

        Args:
            bottom: input tensor.
            filter_shape: shape of the learned filter variable.
            dynamic_filter: tensor added to the learned filter.
            name: prefix for the created variables and the conv op.
            strides: convolution strides.
            padding: padding mode passed to ``tf.nn.conv2d``.

        Returns:
            The convolution output with a learned bias added.
        """
        init_b = tf.constant_initializer(value=0.0, dtype=tf.float32)
        filt = tf.get_variable(name="%s_w" % name,
                               shape=filter_shape,
                               initializer=tf.truncated_normal_initializer,
                               dtype=tf.float32)
        filt = tf.add(filt, dynamic_filter)
        conv = tf.nn.conv2d(bottom,
                            filter=filt,
                            strides=list(strides),
                            padding=padding,
                            name=name)
        bias = tf.get_variable(name="%s_b" % name,
                               initializer=init_b,
                               shape=[filter_shape[-1]],
                               dtype=tf.float32)
        return tf.nn.bias_add(conv, bias)

    def build(self):
        """Construct placeholders, the network, the loss and the metrics."""
        with tf.device('/gpu:0'):
            # Declare placeholders
            # self.x dimensions = BATCH_SIZE * HEIGHT * WIDTH * NUM_CHANNELS
            # TODO: Make flexible for image sizes
            self.x = tf.placeholder(tf.float32, shape=[None, 320, 480, 3])
            # self.y dimensions = BATCH_SIZE * HEIGHT * WIDTH
            self.y = tf.placeholder(tf.int64, shape=[None, 320, 480])
            self.is_trainable = tf.placeholder(tf.bool, name='is_trainable')
            self.rate = tf.placeholder(tf.float32, shape=[])
            self.theta = tf.placeholder(tf.float32, shape=[], name='theta')

            # First encoder
            # conv_1_1 shape = BATCH_SIZE * HEIGHT * WIDTH * 64
            conv_1_1 = self.conv_layer_with_bn(self.x, [3, 3, 3, 64],
                                               'conv1_1')
            conv_1_2 = self.conv_layer_with_bn(conv_1_1, [3, 3, 64, 64],
                                               'conv1_2')
            pool_1, pool_1_argmax = self.pool_layer(conv_1_2)

            # Second encoder
            conv_2_1 = self.conv_layer_with_bn(pool_1, [3, 3, 64, 128],
                                               'conv2_1')
            conv_2_2 = self.conv_layer_with_bn(conv_2_1, [3, 3, 128, 128],
                                               'conv2_2')
            pool_2, pool_2_argmax = self.pool_layer(conv_2_2)

            # Third encoder
            conv_3_1 = self.conv_layer_with_bn(pool_2, [3, 3, 128, 256],
                                               'conv3_1')
            conv_3_2 = self.conv_layer_with_bn(conv_3_1, [3, 3, 256, 256],
                                               'conv3_2')
            conv_3_3 = self.conv_layer_with_bn(conv_3_2, [3, 3, 256, 256],
                                               'conv3_3')
            pool_3, pool_3_argmax = self.pool_layer(conv_3_3)

            # Fourth encoder
            conv_4_1 = self.conv_layer_with_bn(pool_3, [3, 3, 256, 512],
                                               'conv4_1')
            conv_4_2 = self.conv_layer_with_bn(conv_4_1, [3, 3, 512, 512],
                                               'conv4_2')
            conv_4_3 = self.conv_layer_with_bn(conv_4_2, [3, 3, 512, 512],
                                               'conv4_3')
            pool_4, pool_4_argmax = self.pool_layer(conv_4_3)

            # Fifth encoder
            conv_5_1 = self.conv_layer_with_bn(pool_4, [3, 3, 512, 512],
                                               'conv5_1')
            conv_5_2 = self.conv_layer_with_bn(conv_5_1, [3, 3, 512, 512],
                                               'conv5_2')
            conv_5_3 = self.conv_layer_with_bn(conv_5_2, [3, 3, 512, 512],
                                               'conv5_3')
            # pool_5 shape = BATCH_SIZE * HEIGHT * WIDTH * 512
            pool_5, pool_5_argmax = self.pool_layer(conv_5_3)

            # Dynamic filtering when on non-street view: the encoder output
            # passes through unchanged when is_trainable is true, and goes
            # through the dynamic filter otherwise.
            encoded = pool_5
            pool_5 = tf.cond(self.is_trainable,
                             lambda: encoded,
                             lambda: self.dynamic_filtering(encoded))

            # First decoder
            unpool_5 = self.unpool(pool_5, pool_5_argmax)
            deconv_5_3 = self.conv_layer_with_bn(unpool_5, [3, 3, 512, 512],
                                                 'deconv5_3')
            deconv_5_2 = self.conv_layer_with_bn(deconv_5_3, [3, 3, 512, 512],
                                                 'deconv5_2')
            deconv_5_1 = self.conv_layer_with_bn(deconv_5_2, [3, 3, 512, 512],
                                                 'deconv5_1')

            # Second decoder
            unpool_4 = self.unpool(deconv_5_1, pool_4_argmax)
            deconv_4_3 = self.conv_layer_with_bn(unpool_4, [3, 3, 512, 512],
                                                 'deconv4_3')
            deconv_4_2 = self.conv_layer_with_bn(deconv_4_3, [3, 3, 512, 512],
                                                 'deconv4_2')
            deconv_4_1 = self.conv_layer_with_bn(deconv_4_2, [3, 3, 512, 256],
                                                 'deconv4_1')

            # Third decoder
            unpool_3 = self.unpool(deconv_4_1, pool_3_argmax)
            deconv_3_3 = self.conv_layer_with_bn(unpool_3, [3, 3, 256, 256],
                                                 'deconv3_3')
            deconv_3_2 = self.conv_layer_with_bn(deconv_3_3, [3, 3, 256, 256],
                                                 'deconv3_2')
            deconv_3_1 = self.conv_layer_with_bn(deconv_3_2, [3, 3, 256, 128],
                                                 'deconv3_1')

            # Fourth decoder
            unpool_2 = self.unpool(deconv_3_1, pool_2_argmax)
            deconv_2_2 = self.conv_layer_with_bn(unpool_2, [3, 3, 128, 128],
                                                 'deconv2_2')
            deconv_2_1 = self.conv_layer_with_bn(deconv_2_2, [3, 3, 128, 64],
                                                 'deconv2_1')

            # Fifth decoder
            unpool_1 = self.unpool(deconv_2_1, pool_1_argmax)
            deconv_1_2 = self.conv_layer_with_bn(unpool_1, [3, 3, 64, 64],
                                                 'deconv1_2')
            deconv_1_1 = self.conv_layer_with_bn(deconv_1_2, [3, 3, 64, 32],
                                                 'deconv1_1')

            # Produce class scores
            # score_1 dimensions: BATCH_SIZE * HEIGHT * WIDTH * NUM_CLASSES
            # NOTE(review): the scores also pass through batch norm + ReLU
            # because conv_layer_with_bn is reused here.
            score_1 = self.conv_layer_with_bn(deconv_1_1,
                                              [1, 1, 32, self.num_classes],
                                              'score_1')
            logits = tf.reshape(score_1, (-1, self.num_classes))

            # Prepare network outputs: per-pixel cross entropy against the
            # flattened label map.
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.reshape(self.y, [-1]),
                logits=logits,
                name='x_entropy')
            self.loss = tf.reduce_mean(cross_entropy, name='x_entropy_mean')
            optimizer = tf.train.AdamOptimizer(self.rate)
            self.train_step = optimizer.minimize(self.loss)

            # Metrics
            self.prediction = tf.argmax(score_1, axis=3, name="prediction")
            self.accuracy = tf.contrib.metrics.accuracy(self.prediction,
                                                        self.y,
                                                        name='accuracy')
            self.mean_IoU = tf.contrib.metrics.streaming_mean_iou(
                self.prediction, self.y, self.num_classes, name='mean_IoU')

    def restore_session(self):
        """Restore the latest checkpoint, if there is one.

        Returns:
            The global step parsed from the checkpoint file name (the part
            after the last ``-``), or 0 when the directory holds no
            checkpoint.

        Raises:
            IOError: if ``self.checkpoint_directory`` does not exist.
        """
        if not os.path.exists(self.checkpoint_directory):
            raise IOError(self.checkpoint_directory + ' does not exist.')

        state = tf.train.get_checkpoint_state(self.checkpoint_directory)
        if state is None:
            return 0

        self.saver.restore(self.session, state.model_checkpoint_path)
        return int(state.model_checkpoint_path.split('-')[-1])

    def train(self,
              num_iterations,
              theta,
              is_trainable,
              dataset_directory,
              learning_rate=0.1,
              batch_size=5):
        """Train from the last checkpointed step up to ``num_iterations``.

        Every 10 iterations the training loss is printed, every 100
        iterations a validation batch is scored, logged and a checkpoint is
        saved, and every 1000 iterations ``test`` is run and the training
        statistics are graphed.

        Args:
            num_iterations: iteration index at which training stops
                (exclusive).
            theta: View perspective number
            is_trainable: value fed to the ``is_trainable`` placeholder;
                when true the dynamic filtering stage is bypassed.
            dataset_directory: directory the batch reader loads from.
            learning_rate: learning rate fed to the optimizer.
            batch_size: number of images per training batch.
        """
        current_step = self.restore_session()

        bdr = BatchDatasetReader(dataset_directory, 480, 320, current_step,
                                 batch_size)

        # Begin Training
        for i in range(current_step, num_iterations):

            # One training step
            images, ground_truths = bdr.next_training_batch()

            # Train Phase Guide
            # is_trainable = True : street view training
            # is_trainable = False : non-street view training
            # NOTE(review): nothing in this class freezes variables based
            # on this flag; it only selects the tf.cond branch in build().
            feed_dict = {
                self.x: images,
                self.y: ground_truths,
                self.is_trainable: is_trainable,
                self.theta: theta,
                self.rate: learning_rate
            }
            print('run train step: ' + str(i))
            self.train_step.run(session=self.session, feed_dict=feed_dict)

            # Print loss every 10 iterations
            if i % 10 == 0:
                train_loss = self.session.run(self.loss, feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (i, train_loss))

            # Run against validation dataset every 100 iterations
            if i % 100 == 0:
                images, ground_truths = bdr.next_val_batch()
                num_training_images = bdr.num_train

                # Make a validation prediction
                feed_dict = {
                    self.x: images,
                    self.y: ground_truths,
                    self.is_trainable: is_trainable,
                    self.theta: theta,
                    self.rate: learning_rate
                }
                val_loss = self.session.run(self.loss, feed_dict=feed_dict)
                val_accuracy = self.session.run(self.accuracy,
                                                feed_dict=feed_dict)
                # self.mean_IoU is a (value, update_op) pair.
                val_mean_IoU, update_op = self.session.run(self.mean_IoU,
                                                           feed_dict=feed_dict)

                print("%s ---> Validation_loss: %g" %
                      (datetime.datetime.now(), val_loss))
                print("%s ---> Validation_accuracy: %g" %
                      (datetime.datetime.now(), val_accuracy))

                self.logger.log("%s ---> Number of epochs: %g\n" %
                                (datetime.datetime.now(),
                                 math.floor(
                                     (i * batch_size) / num_training_images)))
                self.logger.log("%s ---> Number of iterations: %g\n" %
                                (datetime.datetime.now(), i))
                self.logger.log("%s ---> Validation_loss: %g\n" %
                                (datetime.datetime.now(), val_loss))
                self.logger.log("%s ---> Validation_accuracy: %g\n" %
                                (datetime.datetime.now(), val_accuracy))
                self.logger.log_for_graphing(i, val_loss, val_accuracy,
                                             val_mean_IoU)

                # Save the model variables
                self.saver.save(self.session,
                                self.checkpoint_directory + 'DFSegNet',
                                global_step=i)

            # Print outputs every 1000 iterations
            if i % 1000 == 0:
                self.test(theta, is_trainable, dataset_directory, 1e-2)
                self.logger.graph_training_stats()

    def test(self, theta, is_trainable, dataset_directory, learning_rate):
        """Write out predictions for up to ten test image pairs.

        The latest checkpoint is restored first; its step number is passed
        on to the post-processor together with each result.

        Args:
            theta: view perspective number fed to the graph.
            is_trainable: value fed to the ``is_trainable`` placeholder.
            dataset_directory: directory the test reader loads from.
            learning_rate: value fed to the ``rate`` placeholder.
        """
        current_step = self.restore_session()

        dr = DatasetReader(480, 320, dataset_directory)

        for i in range(min(dr.test_data_size, 10)):
            image, ground_truth = dr.next_test_pair()

            feed_dict = {
                self.x: [image],
                self.y: [ground_truth],
                self.is_trainable: is_trainable,
                self.theta: theta,
                self.rate: learning_rate
            }
            segmentation = np.squeeze(
                self.session.run(self.prediction, feed_dict=feed_dict))

            dp = DataPostprocessor()
            dp.write_out(i, image, segmentation, ground_truth, current_step)
示例#7
0
class SegNet:
  ''' Network described by,
  https://arxiv.org/pdf/1511.00561.pdf '''

  def load_vgg_weights(self):
    """ Use the VGG model trained on
      imagent dataset as a starting point for training """

    # REMINDER: Download model if not existing
    vgg_path = "models/imagenet-vgg-verydeep-19.mat"
    vgg_mat = scipy.io.loadmat(vgg_path)

    self.vgg_params = np.squeeze(vgg_mat['layers'])
    self.layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
            'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
            'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
            'relu5_3', 'conv5_4', 'relu5_4')

  def __init__(self, dataset_directory, num_classes=11):
    """ Build the network graph and open a session for it.

      Args:
        dataset_directory: kept on the instance as self.dataset_directory
        num_classes: number of classes scored for every pixel
    """
    self.dataset_directory = dataset_directory

    self.num_classes = num_classes

    self.load_vgg_weights()

    self.build()

    # Begin a TensorFlow session. The config has to be handed to the session;
    # building it without passing it on leaves allow_soft_placement unused.
    config = tf.ConfigProto(allow_soft_placement=True)
    self.session = tf.Session(config=config)
    self.session.run(tf.global_variables_initializer())

    # Make saving trained weights and biases possible
    self.saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
    self.checkpoint_directory = './checkpoints/'
    self.logger = Logger()

  def vgg_weight_and_bias(self, name, W_shape, b_shape):
    """
      Create the weight and bias variables of one layer, seeded from the
      pre-trained VGG parameters when the layer is part of that model.

      Args:
        name: name of the layer for which you want to initialize weights
        W_shape: shape of weights tensor expected
        b_shape: shape of bias tensor expected
      returns:
        w_var: Initialized weight variable
        b_var: Initialized bias variable
    """
    if name in self.layers:
      # Pre-trained layer: its parameters sit at the layer's position in
      # self.layers, nested inside the loaded .mat structure.
      layer_entry = self.vgg_params[self.layers.index(name)]
      w, b = layer_entry[0][0][0][0]
      # Swap the first two kernel axes of the stored weights
      swapped_w = np.transpose(w, (1, 0, 2, 3))
      init_w = tf.constant(value=swapped_w, dtype=tf.float32, shape=W_shape)
      init_b = tf.constant(value=b.reshape(-1), dtype=tf.float32, shape=b_shape)
      return tf.Variable(init_w), tf.Variable(init_b)

    # Layer unknown to VGG: fall back to the default initialisers
    return self.weight_variable(W_shape), self.bias_variable(b_shape)

  def weight_variable(self, shape):
    """ Return a variable of the given shape drawn from a truncated normal
      distribution with standard deviation 0.1. """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

  def bias_variable(self, shape):
    """ Return a variable of the given shape with every entry set to 0.1. """
    return tf.Variable(tf.constant(0.1, shape=shape))

  def batch_norm_layer(self, inputT, is_training, scope):
    """ Batch-normalise inputT, switching between the training and the
      inference flavour of the layer at run time.

      Args:
        inputT: tensor to normalise
        is_training: boolean tensor selecting the branch
        scope: name prefix; the layer lives in the scope `scope + "_bn"`
      Returns:
        the normalised tensor
    """
    bn_scope = scope + "_bn"

    def training_branch():
      return tf.contrib.layers.batch_norm(inputT, is_training=True,
                       center=False, updates_collections=None, scope=bn_scope)

    def inference_branch():
      # reuse=True: share the variables created by the training branch
      return tf.contrib.layers.batch_norm(inputT, is_training=False,
                       updates_collections=None, center=False, scope=bn_scope, reuse = True)

    return tf.cond(is_training, training_branch, inference_branch)

  def pool_layer(self, x):
    """ Apply 2x2 max pooling with stride 2 and 'SAME' padding.

      Returns:
        the (pooled tensor, argmax indices) pair produced by
        tf.nn.max_pool_with_argmax
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool_with_argmax(x, ksize=window, strides=[1, 2, 2, 1], padding='SAME')

  def unpool(self, pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
    """ Unpooling layer after max_pool_with_argmax.

      The pooled tensor is enlarged by nearest-neighbour resizing and then
      multiplied by a mask built from the argmax indices.

      Args:
        pool: max pooled output tensor
        ind: argmax indices (produced by tf.nn.max_pool_with_argmax)
        ksize: ksize is the same as for the pool
        scope: variable scope the ops are created in
      Return:
        unpooled: unpooling tensor
      Footnote:
        Implementation idea from: https://github.com/tensorflow/tensorflow/issues/2169
    """
    with tf.variable_scope(scope):
      pooled_shape = tf.shape(pool)
      batch = pooled_shape[0]
      channels = pooled_shape[3]
      out_height = pooled_shape[1] * ksize[1]
      out_width = pooled_shape[2] * ksize[2]

      # One row of flat indices per batch entry
      flat_indices = tf.reshape(ind, (batch, pooled_shape[1] * pooled_shape[2] * channels))

      # Turn the sparse indices into a dense mask over the unpooled positions
      mask = tf.one_hot(flat_indices, out_height * out_width * channels, on_value=1., off_value=0., axis=-1)
      mask = tf.reduce_sum(mask, axis=1)
      mask = tf.reshape(mask, (batch, out_height, out_width, channels))

      # Resize the input to the output size by nearest neighbour, then mask it
      enlarged = tf.image.resize_nearest_neighbor(pool, [out_height, out_width])
      return tf.multiply(enlarged, tf.cast(mask, enlarged.dtype))

  def unravel_argmax(self, argmax, shape):
    """ Split flat argmax indices into two stacked index tensors.

      Args:
        argmax: flat indices
        shape: shape whose entries 2 and 3 are used as strides
      Returns:
        tf.stack of [argmax // (shape[2] * shape[3]),
                     argmax % (shape[2] * shape[3]) // shape[3]]
    """
    plane_stride = shape[2] * shape[3]
    first_index = argmax // plane_stride
    second_index = (argmax % plane_stride) // shape[3]
    return tf.stack([first_index, second_index])
  
  def unpool_layer2x2(self, x, raveled_argmax, out_shape):
    ''' Undo a 2x2 max pooling: place every value of x at the position
        recorded in raveled_argmax and fill the rest of the output with zeros.

        Implementation idea from: 
        https://github.com/tensorflow/tensorflow/issues/2169

        Args:
          x: pooled tensor to scatter back (assumes batch size 1, since the
             batch axis is removed with tf.squeeze -- TODO confirm for callers
             other than build())
          raveled_argmax: flat argmax indices belonging to x
          out_shape: shape of the tensor before pooling; entries 1..3 give
             the dense output shape
        Returns:
          dense tensor with a leading axis of size 1 re-added
    '''

    # Split the flat indices into two index planes (stacked on a new axis 0)
    argmax = self.unravel_argmax(raveled_argmax, tf.to_int64(out_shape))
    # `output` only serves as a carrier for the target shape
    output = tf.zeros([out_shape[1], out_shape[2], out_shape[3]])
    height = tf.shape(output)[0]
    width = tf.shape(output)[1]
    channels = tf.shape(output)[2]
    # t1: the channel index of every pooled element, laid out as
    # [channels, pooled_height, pooled_width, 1]; (n + 1) // 2 is the pooled
    # size along an axis of size n
    t1 = tf.to_int64(tf.range(channels))
    t1 = tf.tile(t1, [((width + 1) // 2) * ((height + 1) // 2)])
    t1 = tf.reshape(t1, [-1, channels])
    t1 = tf.transpose(t1, perm=[1, 0])
    t1 = tf.reshape(t1, [channels, (height + 1) // 2, (width + 1) // 2, 1])
    # t2: the two unravelled index planes, moved to the last axis so they line
    # up with t1 (channel axis first)
    t2 = tf.squeeze(argmax)
    t2 = tf.stack((t2[0], t2[1]), axis=0)
    t2 = tf.transpose(t2, perm=[3, 1, 2, 0])
    # One 3-component index per pooled element: the two spatial indices
    # followed by the channel index
    t = tf.concat([t2, t1], 3)
    indices = tf.reshape(t, [((height + 1) // 2) * ((width + 1) // 2) * channels, 3])
    # Flatten x in the same channel-first order as `indices`
    x1 = tf.squeeze(x)
    x1 = tf.reshape(x1, [-1, channels])
    x1 = tf.transpose(x1, perm=[1, 0])
    values = tf.reshape(x1, [-1])
    # Scatter via a sparse tensor; it is reordered before being densified
    delta = tf.SparseTensor(indices, values, tf.to_int64(tf.shape(output)))
    return tf.expand_dims(tf.sparse_tensor_to_dense(tf.sparse_reorder(delta)), 0)
        
  def conv_layer(self, x, W_shape, name, padding='SAME'):
    """ Stride-1 convolution followed by bias addition and ReLU.

      Args:
        x: input tensor
        W_shape: kernel shape; its last entry is the number of output channels
        name: layer name used to look up pre-trained parameters
        padding: padding mode handed to tf.nn.conv2d
      Returns:
        the activated output tensor
    """
    # The bias shape is wrapped in a list because the constant initializer
    # needs an iterable shape
    bias_shape = [W_shape[3]]
    kernel, bias = self.vgg_weight_and_bias(name, W_shape, bias_shape)

    pre_activation = tf.nn.conv2d(x, kernel, strides=[1,1,1,1], padding=padding) + bias
    return tf.nn.relu(pre_activation)

  def conv_layer_with_bn(self, x, W_shape, train_phase, name, padding='SAME'):
    """ Stride-1 convolution with bias, then batch normalisation, then ReLU.

      Args:
        x: input tensor
        W_shape: kernel shape; its last entry is the number of output channels
        train_phase: boolean tensor forwarded to batch_norm_layer
        name: layer name; also used as the variable scope of the layer
        padding: padding mode handed to tf.nn.conv2d
      Returns:
        the activated output tensor
    """
    with tf.variable_scope(name) as scope:
      kernel, bias = self.vgg_weight_and_bias(name, W_shape, [W_shape[3]])
      convolved = tf.nn.conv2d(x, kernel, strides=[1,1,1,1], padding=padding) + bias
      normalised = self.batch_norm_layer(convolved, train_phase, scope.name)
      return tf.nn.relu(normalised)

  def deconv_layer(self, x, W_shape, b_shape, name, padding='SAME'):
    """ Stride-1 transposed convolution with freshly initialised parameters.

      Args:
        x: input tensor
        W_shape: kernel shape; entry 2 is the number of output channels
        b_shape: bias length (wrapped in a list to form the bias shape)
        name: not used by this method
        padding: padding mode handed to tf.nn.conv2d_transpose
      Returns:
        the transposed-convolution output plus bias
    """
    kernel = self.weight_variable(W_shape)
    bias = self.bias_variable([b_shape])

    # Same batch and spatial size as the input, channel count from the kernel
    in_shape = tf.shape(x)
    target_shape = tf.stack([in_shape[0], in_shape[1], in_shape[2], W_shape[2]])

    transposed = tf.nn.conv2d_transpose(x, kernel, target_shape, [1, 1, 1, 1], padding=padding)
    return transposed + bias

  def build(self):
    """ Assemble the encoder/decoder graph together with its loss and
      accuracy ops.

      Sets:
        self.x, self.y, self.train_phase: input placeholders
        self.logits: class scores flattened to (-1, num_classes)
        self.loss: mean sparse softmax cross entropy
        self.accuracy: accuracy of the per-pixel argmax against self.y
    """
    # Declare placeholders
    self.x = tf.placeholder(tf.float32, shape=(1, None, None, 3))
    self.y = tf.placeholder(tf.int64, shape=(1, None, None))
    expected = tf.expand_dims(self.y, -1)
    self.train_phase = tf.placeholder(tf.bool, name='train_phase')

    # (layer name, input channels, output channels) for every 3x3 convolution,
    # grouped into stages. Each encoder stage ends with a pooling step.
    encoder_stages = (
      (('conv1_1', 3, 64), ('conv1_2', 64, 64)),
      (('conv2_1', 64, 128), ('conv2_2', 128, 128)),
      (('conv3_1', 128, 256), ('conv3_2', 256, 256), ('conv3_3', 256, 256)),
      (('conv4_1', 256, 512), ('conv4_2', 512, 512), ('conv4_3', 512, 512)),
      (('conv5_1', 512, 512), ('conv5_2', 512, 512), ('conv5_3', 512, 512)),
    )
    # Each decoder stage starts with an unpooling step.
    decoder_stages = (
      (('deconv5_3', 512, 512), ('deconv5_2', 512, 512), ('deconv5_1', 512, 512)),
      (('deconv4_3', 512, 512), ('deconv4_2', 512, 512), ('deconv4_1', 512, 256)),
      (('deconv3_3', 256, 256), ('deconv3_2', 256, 256), ('deconv3_1', 256, 128)),
      (('deconv2_2', 128, 128), ('deconv2_1', 128, 64)),
      (('deconv1_2', 64, 64), ('deconv1_1', 64, 32)),
    )

    # Encoders: remember, per stage, the pooling indices and the tensor that
    # was pooled so the matching decoder can restore its shape.
    net = self.x
    pooling_records = []
    for stage in encoder_stages:
      for layer_name, in_channels, out_channels in stage:
        net = self.conv_layer_with_bn(net, [3, 3, in_channels, out_channels], self.train_phase, layer_name)
      before_pool = net
      net, pool_argmax = self.pool_layer(before_pool)
      pooling_records.append((pool_argmax, before_pool))

    # Decoders: walk the pooling records from the deepest stage outwards
    for stage, (pool_argmax, before_pool) in zip(decoder_stages, reversed(pooling_records)):
      net = self.unpool_layer2x2(net, pool_argmax, tf.shape(before_pool))
      for layer_name, in_channels, out_channels in stage:
        net = self.conv_layer_with_bn(net, [3, 3, in_channels, out_channels], self.train_phase, layer_name)

    # Produce class scores
    # NOTE(review): conv_layer ends in a ReLU, so the scores fed to the
    # softmax loss are non-negative -- confirm this is intended.
    preds = self.conv_layer(net, [1, 1, 32, self.num_classes], 'preds')
    self.logits = tf.reshape(preds, (-1, self.num_classes))

    # Prepare network for training
    flat_labels = tf.reshape(expected, [-1])
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=flat_labels, logits=self.logits, name='x_entropy')
    self.loss = tf.reduce_mean(cross_entropy, name='x_entropy_mean')

    # Metrics
    predicted_image = tf.argmax(preds, axis=3)
    self.accuracy = tf.contrib.metrics.accuracy(tf.cast(predicted_image, tf.int64), self.y, name='accuracy')

  def restore_session(self):
    global_step = 0

    if not os.path.exists(self.checkpoint_directory):
      raise IOError(self.checkpoint_directory + ' does not exist.')
    else:
      path = tf.train.get_checkpoint_state(self.checkpoint_directory)
      if path is None:
        pass
      else:
        self.saver.restore(self.session, path.model_checkpoint_path)
        global_step = int(path.model_checkpoint_path.split('-')[-1])

    return global_step

  
  def train(self, num_iterations=10000, learning_rate=1e-6):
    """ Train the network with Adam, resuming from the latest checkpoint.

      Every 10 iterations the training loss is printed. Every 100 iterations
      one validation pair is evaluated, the results are printed and logged,
      and a checkpoint is saved.

      Args:
        num_iterations: iteration number at which training stops
        learning_rate: learning rate handed to the Adam optimizer
      Raises:
        IOError: from restore_session when the checkpoint directory is missing
    """
    self.rate = learning_rate

    # Restore previous session if exists
    current_step = self.restore_session()

    dataset = DatasetReader()

    # The placeholders fix the batch dimension to 1 and every feed below wraps
    # a single image in a list, so one iteration processes one image.
    batch_size = 1
    # NOTE(review): the epoch log line used to read an undefined
    # `bdr.num_train`; this assumes the reader exposes the training-set size
    # as `num_train` -- confirm. Without it the epoch line is skipped.
    num_train = getattr(dataset, 'num_train', None)

    # Begin Training
    # The optimizer adds its own variables to the graph after the global
    # initializer has already run, so they must be initialised here; otherwise
    # the first training step fails on uninitialised values.
    known_names = {v.name for v in tf.global_variables()}
    self.train_step = tf.train.AdamOptimizer(self.rate).minimize(self.loss)
    optimizer_variables = [v for v in tf.global_variables()
                           if v.name not in known_names]
    if optimizer_variables:
      self.session.run(tf.variables_initializer(optimizer_variables))

    for i in range(current_step, num_iterations):

      # One training step
      image, ground_truth = dataset.next_train_pair()
      feed_dict = {self.x: [image], self.y: [ground_truth], self.train_phase: True}
      print('run train step: '+str(i))
      self.train_step.run(session=self.session, feed_dict=feed_dict)

      # Print loss every 10 iterations
      if i % 10 == 0:
        train_loss = self.session.run(self.loss, feed_dict=feed_dict)
        print("Step: %d, Train_loss:%g" % (i, train_loss))

      # Run against validation dataset every 100 iterations
      if i % 100 == 0:
        image, ground_truth = dataset.next_val_pair()
        # Evaluate in inference mode so the validation image does not feed
        # into the batch-norm statistics.
        feed_dict = {self.x: [image], self.y: [ground_truth], self.train_phase: False}
        # One forward pass yields both metrics
        val_loss, val_accuracy = self.session.run(
          [self.loss, self.accuracy], feed_dict=feed_dict)
        print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), val_loss))
        print("%s ---> Validation_accuracy: %g" % (datetime.datetime.now(), val_accuracy))

        if num_train:
          self.logger.log("%s ---> Number of epochs: %g\n" % (datetime.datetime.now(), math.floor((i * batch_size) / num_train)))
        self.logger.log("%s ---> Number of iterations: %g\n" % (datetime.datetime.now(), i))
        self.logger.log("%s ---> Validation_loss: %g\n" % (datetime.datetime.now(), val_loss))
        self.logger.log("%s ---> Validation_accuracy: %g\n" % (datetime.datetime.now(), val_accuracy))

        # Save the model variables
        self.saver.save(self.session, self.checkpoint_directory + 'segnet', global_step = i)

  def test(self, learning_rate=1e-6):
    """ Segment one test image and save the class map as 'prediction.png'.

      Each pixel of the saved greyscale image holds the index of the
      highest-scoring class at that position.

      Args:
        learning_rate: not used; kept so existing callers keep working
    """
    dataset = DatasetReader()
    image, ground_truth = dataset.next_test_pair()
    feed_dict = {self.x: [image], self.y: [ground_truth], self.train_phase: False}

    # A session object is not callable; tensors are evaluated through run().
    logits = self.session.run(self.logits, feed_dict=feed_dict)

    # self.logits is flattened to (pixels, num_classes): choose the best class
    # per pixel and fold the result back into the image's height and width.
    # Assumes the network output has the spatial size of the input image.
    height, width = np.shape(image)[:2]
    class_map = np.argmax(logits, axis=1).reshape(height, width).astype(np.uint8)

    img = Image.fromarray(class_map, 'L')
    img.save('prediction.png')