def dataset_from_files(self,
                           train_imgs,
                           train_lbls,
                           is_training,
                           repeat=True,
                           batch_size=None):
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image_decoded = tf.image.decode_jpeg(image_string,
                                                 channels=3)  ## uint8 image
            return image_decoded, tf.one_hot(label,
                                             config.num_classes,
                                             dtype=tf.int64)

        filenames = tf.constant(train_imgs)
        labels = tf.constant(train_lbls, dtype=tf.int32)

        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

        if repeat:
            dataset = dataset.shuffle(len(train_imgs))
            dataset = dataset.repeat(None)
        else:
            #dataset = dataset.shuffle(len(train_imgs)) ## Avoid shuffle to test batch normalization
            dataset = dataset.repeat(1)

        dataset = dataset.map(_parse_function, num_parallel_calls=3)

        ## Image-level augmentation. It is possible to use it, but I find batch-level augmentation works better
        # preprocess_mod = locate(config.preprocessing_module)
        # func_name = preprocess_mod.preprocess_for_train_simple
        # if not is_training:
        #     func_name = preprocess_mod.preprocess_for_eval_simple
        # dataset = dataset.map(lambda im, lbl,weight: (func_name (im,const.frame_height,const.frame_width), lbl,weight))

        if batch_size is None:
            batch_size = config.batch_size
        else:
            print('Eval Batch used')

        dataset = dataset.batch(batch_size)

        ## Batch Level Augmentation
        if is_training:
            dataset = dataset.map(lambda im_batch, lbl_batch: (
                nn_utils.augment(im_batch,
                                 resize=(const.frame_height,
                                         const.frame_width),
                                 horizontal_flip=True,
                                 vertical_flip=False,
                                 rotate=0,
                                 crop_probability=0,
                                 crop_min_percent=0), lbl_batch))
        else:
            dataset = dataset.map(lambda im_batch, lbl_batch:
                                  (nn_utils.center_crop(im_batch), lbl_batch))

        dataset = dataset.prefetch(1)
        return dataset
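A minimal usage sketch (an assumption, not part of the source): wire the returned dataset into a model through a one-shot iterator instead of placeholders. `loader` stands for whatever object exposes dataset_from_files above.

imgs_t, lbls_t = loader.dataset_from_files(
    train_imgs, train_lbls, is_training=True,
    repeat=True).make_one_shot_iterator().get_next()
# The two tensors can then be handed to a model constructor as
# images_ph / lbls_ph, avoiding feed_dict for the input pipeline.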
Example #2
    def __init__(self,
                 cfg,
                 weight_decay=0.0001,
                 data_format='NHWC',
                 is_training=False,
                 reuse=None,
                 images_ph=None,
                 lbls_ph=None):
        self.cfg = cfg
        batch_size = None
        num_classes = cfg.num_classes
        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size, num_classes),
                                          name='class_lbls')

        self.augment_input = tf.placeholder(tf.bool, name='augment_input')

        ## Check whether to use a placeholder for training (images_ph is None),
        # or whether the calling training procedure already provides an images dataset pipeline
        if images_ph is not None:
            ## If training uses an images TF dataset pipeline, there is no need to augment here;
            #  just make sure the input is in the correct shape

            ## This alternative is more efficient because it avoids the discouraged TF placeholder usage
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
        else:

            # If no images TF dataset pipeline is provided,
            # revert to the traditional placeholder usage
            self.input = tf.placeholder(
                tf.float32,
                shape=(batch_size, const.max_frame_size, const.max_frame_size,
                       const.frame_channels),
                name='context_input')

            ## Training procedure controls whether to augment placeholder images
            #  using self.augment_input bool tensor
            aug_imgs = tf.cond(
                self.augment_input,
                lambda: nn_utils.augment(self.input,
                                         horizontal_flip=True,
                                         vertical_flip=False,
                                         rotate=0,
                                         crop_probability=0,
                                         color_aug_probability=0),
                lambda: nn_utils.center_crop(self.input))

        with tf.contrib.slim.arg_scope(
                densenet_arg_scope(weight_decay=weight_decay,
                                   data_format=data_format)):
            nets, train_end_points = densenet(aug_imgs,
                                              num_classes=num_classes,
                                              reduction=0.5,
                                              growth_rate=48,
                                              num_filters=96,
                                              num_layers=[6, 12, 36, 24],
                                              data_format=data_format,
                                              is_training=True,
                                              reuse=None,
                                              scope='densenet161')

            val_nets, val_end_points = densenet(
                aug_imgs,
                num_classes=num_classes,
                reduction=0.5,
                growth_rate=48,
                num_filters=96,
                num_layers=[6, 12, 36, 24],
                data_format=data_format,
                is_training=False,  ## Always set to False for the evaluation branch
                reuse=True,
                scope='densenet161')

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['densenet161/logits'],
                                [-1, num_classes])
            pre_logits = end_points['densenet161/dense_block4']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['predictions'],
                                     [-1, num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.confusion_matrix(gt,
                                                class_prediction,
                                                num_classes=num_classes)
            _, accumulated_accuracy = tf.metrics.accuracy(gt, class_prediction)

            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy

        self.train_loss, self.train_pre_logits, self.train_accuracy, self.train_confusion_mat, self.train_accumulated_accuracy = cal_metrics(
            train_end_points)
        self.val_loss, self.val_pre_logits, self.val_accuracy, self.val_confusion_mat, self.val_accumulated_accuracy = cal_metrics(
            val_end_points)
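Because the training branch above is built with is_training=True, the batch-norm moving averages are refreshed through the UPDATE_OPS collection, and the train op should depend on them. A minimal sketch, assuming a model instance `net` built as above (the optimizer choice is illustrative):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):  # run BN updates with each step
    train_op = tf.train.MomentumOptimizer(0.01, 0.9).minimize(net.train_loss)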
    def __init__(self,
                 cfg,
                 weight_decay=0.0001,
                 data_format='NHWC',
                 reuse=None,
                 images_ph=None,
                 lbls_ph=None,
                 weights_ph=None):

        self.cfg = cfg
        num_classes = cfg.num_classes
        filter_type = cfg.filter_type
        verbose = cfg.print_filter_name
        batch_size = None
        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size, num_classes),
                                          name='class_lbls')

        self.do_augmentation = tf.placeholder(tf.bool, name='do_augmentation')
        self.loss_class_weight = tf.placeholder(tf.float32,
                                                shape=(num_classes,
                                                       num_classes),
                                                name='weights')
        self.input = tf.placeholder(tf.float32,
                                    shape=(batch_size, const.max_frame_size,
                                           const.max_frame_size,
                                           const.num_channels),
                                    name='context_input')

        # if is_training:
        if images_ph is not None:
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
            print('No nnutils Augmentation')
        else:
            if cfg.db_name == 'honda':
                aug_imgs = self.input
            else:
                aug_imgs = tf.cond(
                    self.do_augmentation,
                    lambda: nn_utils.augment(self.input,
                                             cfg.preprocess_func,
                                             horizontal_flip=True,
                                             vertical_flip=False,
                                             rotate=0,
                                             crop_probability=0,
                                             color_aug_probability=0), lambda:
                    nn_utils.center_crop(self.input, cfg.preprocess_func))
            # aug_imgs = self.input ## Already augmented

            # else:

        with tf.contrib.slim.arg_scope(
                densenet_arg_scope(weight_decay=weight_decay,
                                   data_format=data_format)):

            val_nets, val_end_points = densenet(
                aug_imgs,
                num_classes=num_classes,
                reduction=0.5,
                growth_rate=48,
                num_filters=96,
                num_layers=[6, 12, 36, 24],
                data_format=data_format,
                is_training=False,  ## Always set to False for the evaluation branch
                reuse=True,
                filter_type=filter_type,
                verbose=verbose,
                scope='densenet161')

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['densenet161/logits'],
                                [-1, num_classes])
            pre_logits = end_points['densenet161/dense_block4']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['predictions'],
                                     [-1, num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.math.confusion_matrix(gt,
                                                     class_prediction,
                                                     num_classes=num_classes)
            _, accumulated_accuracy = tf.compat.v1.metrics.accuracy(
                gt, class_prediction)
            _, per_class_acc_acc = tf.compat.v1.metrics.mean_per_class_accuracy(
                gt, class_prediction, num_classes=num_classes)
            per_class_acc_acc = tf.reduce_mean(per_class_acc_acc)

            class_prediction = tf.nn.softmax(logits)  # expose class probabilities rather than the argmax
            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy, per_class_acc_acc, class_prediction, logits

        # self.train_loss,self.train_pre_logits,self.train_accuracy,self.train_confusion_mat,\
        # self.train_accumulated_accuracy,self.train_per_class_acc_acc ,self.train_class_prediction,self.train_logits = cal_metrics(train_end_points);

        (self.val_loss, self.val_pre_logits, self.val_accuracy,
         self.val_confusion_mat, self.val_accumulated_accuracy,
         self.val_per_class_acc_acc, self.val_class_prediction,
         self.val_logits) = cal_metrics(val_end_points)
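Note that this constructor builds the DenseNet with reuse=True only, so the variables must already exist in the graph (e.g., created by a companion training network); their values are then typically restored from a checkpoint. A minimal restore sketch (the path is a placeholder, not from the source):

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # for the streaming metrics
    saver.restore(sess, '/path/to/model.ckpt')  # hypothetical checkpoint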
Example #4
    def __init__(self,
                 cfg=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='resnet_v2_50',
                 images_ph=None,
                 lbls_ph=None):
        self.cfg = cfg
        filter_type = cfg.filter_type
        verbose = cfg.print_filter_name
        batch_size = None

        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, cfg.num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size, cfg.num_classes),
                                          name='class_lbls')

        self.do_augmentation = tf.placeholder(tf.bool, name='do_augmentation')
        self.loss_class_weight = tf.placeholder(tf.float32,
                                                shape=(cfg.num_classes,
                                                       cfg.num_classes),
                                                name='weights')
        self.input = tf.placeholder(tf.float32,
                                    shape=(batch_size, const.max_frame_size,
                                           const.max_frame_size,
                                           const.num_channels),
                                    name='context_input')

        # if is_training:
        if images_ph is not None:
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
            print('No nnutils Augmentation')
        else:
            aug_imgs = tf.cond(
                self.do_augmentation,
                lambda: nn_utils.augment(self.input,
                                         cfg.preprocess_func,
                                         horizontal_flip=True,
                                         vertical_flip=False,
                                         rotate=0,
                                         crop_probability=0,
                                         color_aug_probability=0),
                lambda: nn_utils.center_crop(self.input, cfg.preprocess_func))

        with slim.arg_scope(resnet_arg_scope()):

            _, val_end_points = resnet_v2_50(aug_imgs,
                                             cfg.num_classes,
                                             is_training=False,
                                             global_pool=global_pool,
                                             output_stride=output_stride,
                                             spatial_squeeze=spatial_squeeze,
                                             reuse=True,
                                             scope=scope,
                                             filter_type=filter_type,
                                             verbose=verbose)

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['resnet_v2_50/logits'],
                                [-1, cfg.num_classes])
            pre_logits = None  #end_points['resnet_v2_50/block4/unit_3/bottleneck_v2']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['predictions'],
                                     [-1, cfg.num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.confusion_matrix(gt,
                                                class_prediction,
                                                num_classes=cfg.num_classes)
            _, accumulated_accuracy = tf.compat.v1.metrics.accuracy(
                gt, class_prediction)
            _, per_class_acc_acc = tf.compat.v1.metrics.mean_per_class_accuracy(
                gt, class_prediction, num_classes=cfg.num_classes)
            per_class_acc_acc = tf.reduce_mean(per_class_acc_acc)
            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy, per_class_acc_acc, class_prediction, logits

        (self.val_loss, self.val_pre_logits, self.val_accuracy,
         self.val_confusion_mat, self.val_accumulated_accuracy,
         self.val_per_class_acc_acc, self.val_class_prediction,
         self.val_logits) = cal_metrics(val_end_points)
    def __init__(self,
                 cfg=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='resnet_v2_50',
                 images_ph=None,
                 lbls_ph=None):
        self.cfg = cfg
        batch_size = None
        num_classes = cfg.num_classes

        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size, num_classes),
                                          name='class_lbls')

        self.augment_input = tf.placeholder(tf.bool, name='augment_input')

        ## Check whether to use a placeholder for training (images_ph is None),
        # or whether the calling training procedure already provides an images dataset pipeline
        if images_ph is not None:
            ## If training uses an images TF dataset pipeline, there is no need to augment here;
            #  just make sure the input is in the correct shape

            ## This alternative is more efficient because it avoids the discouraged TF placeholder usage
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
        else:

            # If no images TF dataset pipeline is provided,
            # revert to the traditional placeholder usage
            self.input = tf.placeholder(
                tf.float32,
                shape=(batch_size, const.max_frame_size, const.max_frame_size,
                       const.frame_channels),
                name='context_input')

            ## Training procedure controls whether to augment placeholder images
            #  using self.augment_input bool tensor
            aug_imgs = tf.cond(
                self.augment_input,
                lambda: nn_utils.augment(self.input,
                                         horizontal_flip=True,
                                         vertical_flip=False,
                                         rotate=0,
                                         crop_probability=0,
                                         color_aug_probability=0),
                lambda: nn_utils.center_crop(self.input))

        with slim.arg_scope(resnet_arg_scope()):

            ## Why are there two sets of endpoints? Short answer: to do batch-normalization correctly.
            #
            ## Long answer:
            # first check this out https://github.com/tensorflow/tensorflow/issues/5987

            ### During training, a network *learns* how to normalize input,
            ### i.e. tf.layers.batch_normalization params like beta and gamma are *Learned*
            ### During evaluation, a network *uses* the learned batch-normalization params to normalize input

            ## A lot of iterations are required to learn correct tf.layers.batch_normalization params (beta and gamma).
            ## Really a lot of iterations. Remember this; I will revisit it later

            ## During training, a normalized image is dependent on other images within the mini-batch.
            ## During evaluation, a normalized image is *not* dependent on other images within the mini-batch.

            ## While training a network, periodic evaluation on a validation set is typical, e.g. after every 1000 iterations.
            ## If evaluation on validation images is performed while learning batch-normalization params,
            # the quantitative performance is not accurate.
            # Proof? Using a set of images, run evaluation twice, feeding images in a different order: alphabetical vs random.
            # The quantitative performance will change.

            ## The right way is to evaluate using the already *learned* batch_normalization params (beta and gamma).
            # Thus, it is important to make sure batch_normalization params are *not* updated during evaluation,
            ## and that each image is normalized independently, without any conditioning on the mini-batch

            ## To this end, I replicate my network, similar to siamese networks.
            # The weights of both networks are identical and change together during training, i.e. backpropagation.
            # Yet, one network learns batch_normalization params (is_training=True),
            # while the other uses the already learned params (is_training=False)
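            ## A minimal sketch of the same twin pattern with a bare
            ## tf.layers.batch_normalization (illustrative, not from this repo):
            #   with tf.variable_scope('bn_demo'):
            #       train_out = tf.layers.batch_normalization(x, training=True)
            #   with tf.variable_scope('bn_demo', reuse=True):
            #       val_out = tf.layers.batch_normalization(x, training=False)
            ## Both branches share beta/gamma and the moving statistics, but
            ## only the training branch updates them (via tf.GraphKeys.UPDATE_OPS).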

            _, train_end_points = resnet_v2_50(aug_imgs,
                                               num_classes,
                                               is_training=True,
                                               global_pool=global_pool,
                                               output_stride=output_stride,
                                               spatial_squeeze=spatial_squeeze,
                                               reuse=reuse,
                                               scope=scope)

            _, val_end_points = resnet_v2_50(aug_imgs,
                                             num_classes,
                                             is_training=False,
                                             global_pool=global_pool,
                                             output_stride=output_stride,
                                             spatial_squeeze=spatial_squeeze,
                                             reuse=True,
                                             scope=scope)

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['resnet_v2_50/logits'],
                                [-1, num_classes])
            pre_logits = end_points['resnet_v2_50/block4/unit_3/bottleneck_v2']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['predictions'],
                                     [-1, num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.confusion_matrix(gt,
                                                class_prediction,
                                                num_classes=num_classes)
            _, accumulated_accuracy = tf.metrics.accuracy(gt, class_prediction)
            _, per_class_acc_acc = tf.metrics.mean_per_class_accuracy(
                gt, class_prediction, num_classes=num_classes)
            per_class_acc_acc = tf.reduce_mean(per_class_acc_acc)
            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy, per_class_acc_acc

        (self.train_loss, self.train_pre_logits, self.train_accuracy,
         self.train_confusion_mat, self.train_accumulated_accuracy,
         self.train_per_class_acc_acc) = cal_metrics(train_end_points)


        (self.val_loss, self.val_pre_logits, self.val_accuracy,
         self.val_confusion_mat, self.val_accumulated_accuracy,
         self.val_per_class_acc_acc) = cal_metrics(val_end_points)
    def __init__(self,
                 cfg=None,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 scope='InceptionV1',
                 images_ph=None,
                 lbls_ph=None):
        self.cfg = cfg
        batch_size = None
        filter_type = cfg.filter_type
        verbose = cfg.print_filter_name
        num_classes = cfg.num_classes
        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size, num_classes),
                                          name='class_lbls')

        self.do_augmentation = tf.placeholder(tf.bool, name='do_augmentation')
        self.loss_class_weight = tf.placeholder(tf.float32,
                                                shape=(num_classes,
                                                       num_classes),
                                                name='weights')
        self.input = tf.placeholder(tf.float32,
                                    shape=(batch_size, const.max_frame_size,
                                           const.max_frame_size,
                                           const.num_channels),
                                    name='context_input')

        # if is_training:
        if images_ph is not None:
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
            print('No nnutils Augmentation')
        else:
            if cfg.db_name == 'honda':
                aug_imgs = self.input
            else:
                aug_imgs = tf.cond(
                    self.do_augmentation,
                    lambda: nn_utils.augment(self.input,
                                             cfg.preprocess_func,
                                             horizontal_flip=True,
                                             vertical_flip=False,
                                             rotate=0,
                                             crop_probability=0,
                                             color_aug_probability=0), lambda:
                    nn_utils.center_crop(self.input, cfg.preprocess_func))

        with slim.arg_scope(inception_v1_arg_scope()):
            # _, train_end_points = inception_v1(aug_imgs, num_classes,
            #                                    dropout_keep_prob=dropout_keep_prob,
            #                                    is_training=True,reuse=reuse, scope=scope)

            _, self.val_end_points = inception_v1(
                aug_imgs,
                num_classes,
                dropout_keep_prob=dropout_keep_prob,
                is_training=False,
                reuse=True,
                scope=scope,
                filter_type=filter_type,
                verbose=verbose)

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['Logits'], [-1, num_classes])
            pre_logits = end_points['Mixed_5c']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['Predictions'],
                                     [-1, num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.confusion_matrix(gt,
                                                class_prediction,
                                                num_classes=num_classes)
            _, accumulated_accuracy = tf.compat.v1.metrics.accuracy(
                gt, class_prediction)
            _, per_class_acc_acc = tf.compat.v1.metrics.mean_per_class_accuracy(
                gt, class_prediction, num_classes=num_classes)
            per_class_acc_acc = tf.reduce_mean(per_class_acc_acc)
            class_prediction = tf.nn.softmax(logits)  # expose class probabilities rather than the argmax
            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy, per_class_acc_acc, class_prediction, logits

        # self.train_loss,self.train_pre_logits,self.train_accuracy,self.train_confusion_mat,\
        #                 self.train_accumulated_accuracy,self.train_per_class_acc_acc ,self.train_class_prediction = cal_metrics(train_end_points);


        (self.val_loss, self.val_pre_logits, self.val_accuracy,
         self.val_confusion_mat, self.val_accumulated_accuracy,
         self.val_per_class_acc_acc, self.val_class_prediction,
         self.val_logits) = cal_metrics(self.val_end_points)
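A hedged inference sketch for this eval-only wrapper; the session, checkpoint restore, and `img_batch` are assumptions:

probs = sess.run(net.val_class_prediction,
                 feed_dict={net.input: img_batch,
                            net.do_augmentation: False})
top1 = probs.argmax(axis=1)  # predicted class per image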
Example #7
    def __init__(self,
                 num_classes,
                 weight_decay=0.0001,
                 data_format='NHWC',
                 is_training=False,
                 reuse=None,
                 images_ph=None,
                 lbls_ph=None):

        batch_size = None
        if lbls_ph is not None:
            self.gt_lbls = tf.reshape(lbls_ph, [-1, config.num_classes])
        else:
            self.gt_lbls = tf.placeholder(tf.int32,
                                          shape=(batch_size,
                                                 config.num_classes),
                                          name='class_lbls')

        self.is_training = tf.placeholder(tf.bool, name='training')
        self.input = tf.placeholder(tf.float32,
                                    shape=(batch_size, const.max_frame_size,
                                           const.max_frame_size,
                                           const.frame_channels),
                                    name='context_input')

        ## Both the is_training=True and is_training=False paths built the same
        ## graph, so the Python-level branch is collapsed into one.
        if images_ph is not None:
            self.input = images_ph
            _, h, w, c = self.input.shape  # NHWC
            aug_imgs = tf.reshape(self.input, [-1, h, w, c])
            # print('No nnutils Augmentation')
        else:
            aug_imgs = tf.cond(
                self.is_training,
                lambda: nn_utils.augment(self.input,
                                         horizontal_flip=True,
                                         vertical_flip=False,
                                         rotate=0,
                                         crop_probability=0,
                                         color_aug_probability=0),
                lambda: nn_utils.center_crop(self.input))

        self.batch_norm_enabled = tf.Variable(True,
                                              name='is_training',
                                              dtype=tf.bool,
                                              trainable=False)
        with tf.contrib.slim.arg_scope(
                densenet_arg_scope(weight_decay=weight_decay,
                                   data_format=data_format)):
            nets, train_end_points = densenet(aug_imgs,
                                              num_classes=num_classes,
                                              reduction=0.5,
                                              growth_rate=48,
                                              num_filters=96,
                                              num_layers=[6, 12, 36, 24],
                                              data_format=data_format,
                                              is_training=True,
                                              reuse=None,
                                              scope='densenet161')

            val_nets, val_end_points = densenet(
                aug_imgs,
                num_classes=num_classes,
                reduction=0.5,
                growth_rate=48,
                num_filters=96,
                num_layers=[6, 12, 36, 24],
                data_format=data_format,
                is_training=False,  ## Always set to False for the evaluation branch
                reuse=True,
                scope='densenet161')

        def cal_metrics(end_points):
            gt = tf.argmax(self.gt_lbls, 1)
            logits = tf.reshape(end_points['densenet161/logits'],
                                [-1, num_classes])
            pre_logits = end_points['densenet161/dense_block4']

            center_supervised_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.gt_lbls, logits=logits, name='xentropy_center')
            loss = tf.reduce_mean(center_supervised_cross_entropy,
                                  name='xentropy_mean')
            predictions = tf.reshape(end_points['predictions'],
                                     [-1, num_classes])
            class_prediction = tf.argmax(predictions, 1)
            supervised_correct_prediction = tf.equal(gt, class_prediction)
            supervised_correct_prediction_cast = tf.cast(
                supervised_correct_prediction, tf.float32)
            accuracy = tf.reduce_mean(supervised_correct_prediction_cast)
            confusion_mat = tf.confusion_matrix(gt,
                                                class_prediction,
                                                num_classes=num_classes)
            _, accumulated_accuracy = tf.metrics.accuracy(gt, class_prediction)

            return loss, pre_logits, accuracy, confusion_mat, accumulated_accuracy

        self.train_loss, self.train_pre_logits, self.train_accuracy, self.train_confusion_mat, self.train_accumulated_accuracy = cal_metrics(
            train_end_points)
        self.val_loss, self.val_pre_logits, self.val_accuracy, self.val_confusion_mat, self.val_accumulated_accuracy = cal_metrics(
            val_end_points)
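The accumulated accuracies come from tf.metrics.accuracy, which keeps its running total/count in local variables, so those must be initialized (and can be re-initialized to reset the metric between evaluation rounds). A minimal sketch, assuming a model instance `net` and a hypothetical `val_batches` iterable:

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())  # total/count for tf.metrics.*
for imgs, lbls in val_batches:
    acc = sess.run(net.val_accumulated_accuracy,
                   feed_dict={net.input: imgs, net.gt_lbls: lbls,
                              net.is_training: False})
print('accumulated val accuracy:', acc)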