Example #1
    def _build_input_pipelines(self):
        self._train_input_pipeline = InputPipeline(
            records=self._train_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': True},
            drop_remainder=False,
            iterator_type=InputPipeline.INITIALIZABLE_ITERATOR)

        self._val_input_pipeline = InputPipeline(
            records=self._val_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=-1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': False},
            drop_remainder=False,
            iterator_type=InputPipeline.ONESHOT_ITERATOR)
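
A minimal sketch of how these two pipelines are typically consumed (an assumption based on the initialize/next_batch calls that appear in Example #6 below): the initializable train iterator is re-initialized once per epoch, while the one-shot validation iterator repeats indefinitely because of num_repeat=-1.

    train_images, train_labels = self._train_input_pipeline.next_batch()

    with tf.Session() as sess:
        # required for INITIALIZABLE_ITERATOR; one-shot iterators need no init
        self._train_input_pipeline.initialize(sess)
        try:
            while True:  # num_repeat=1 -> exactly one epoch
                images_value, labels_value = sess.run([train_images, train_labels])
        except tf.errors.OutOfRangeError:
            pass  # end of epoch; call initialize(sess) again for the next one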
Example #2
    def run0(self, ae_config, pc_config, cpm_checkpoint):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start =====')

        with tf.Graph().as_default() as graph:
            # datafeed
            self.logger.info('* datafeed')

            input_pipeline = InputPipeline(
                records=self.records_file,
                records_type=RecordsParser.RECORDS_UNLABELLED,
                shuffle_buffer_size=self._batch_size *
                self.NUM_PREPROCESSING_THREADS,
                batch_size=self._batch_size,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=self._get_resize_function(
                    self._image_height, self._image_width),
                preprocessing_kwargs={},
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False,
                dtype_out=tf.float32)

            images = input_pipeline.next_batch()[0]

            # compression + inference op
            self.logger.info('* compression')
            with tf.name_scope('compression'):
                self.logger.info('images shape for compression: {}'.format(
                    images.get_shape().as_list()))
                images = nhwc_to_nchw(images)

                # create networks
                ae_cls = autoencoder.get_network_cls(ae_config)
                pc_cls = probclass.get_network_cls(pc_config)

                # instantiate models
                ae = ae_cls(ae_config)
                pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

                enc_out_val = ae.encode(images, is_training=False)
                images_compressed = ae.decode(enc_out_val.qhard,
                                              is_training=False)

                bitcost_val = pc.bitcost(enc_out_val.qbar,
                                         enc_out_val.symbols,
                                         is_training=False,
                                         pad_value=pc.auto_pad_value(ae))
                avg_bits_per_pixel = bitcost_to_bpp(bitcost_val, images)
                images = nchw_to_nhwc(images)
                images_compressed = nchw_to_nhwc(images_compressed)

            # compute distortions
            self.logger.info('* distortions')
            with tf.name_scope('distortions'):
                distortions_obj = Distortions(
                    reconstructed_images=images_compressed,
                    original_images=tf.cast(images, tf.float32),
                    lambda_ms_ssim=1.0,
                    lambda_psnr=1.0,
                    lambda_feature_loss=1.0,
                    data_format=self.DATA_FORMAT,
                    loss_net_kwargs=None)

                distortions_ops = {
                    'ms_ssim': distortions_obj.compute_ms_ssim(),
                    'mse': distortions_obj.compute_mse(),
                    'psnr': distortions_obj.compute_psnr()
                }

            # cpm saver
            cpm_saver = Saver(
                cpm_checkpoint,
                var_list=Saver.get_var_list_of_ckpt_dir(cpm_checkpoint))
            ckpt_itr, cpm_ckpt_path = Saver.all_ckpts_with_iterations(
                cpm_checkpoint)[-1]
            self.logger.info('ckpt_itr={}'.format(ckpt_itr))
            self.logger.info('ckpt_path={}'.format(cpm_ckpt_path))

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=False),
                        graph=graph) as sess:

            cpm_saver.restore_ckpt(sess, cpm_ckpt_path)

            distortions_values = {key: list() for key in self.DISTORTION_KEYS}
            bpp_values = []
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # compute distortions and bpp
                    batch_bpp_mean_values, batch_distortions_values = sess.run(
                        [avg_bits_per_pixel, distortions_ops])

                    # collect values
                    bpp_values.append(batch_bpp_mean_values)
                    for key in self.DISTORTION_KEYS:
                        distortions_values[key].append(
                            batch_distortions_values[key])

                    n_images_processed += self._batch_size
                    n_images_processed_per_second.append(
                        self._batch_size / (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))

                mean_bpp_values = np.mean(bpp_values)
                mean_dist_values = {
                    key: np.mean(arr)
                    for key, arr in distortions_values.items()
                }

                print('*** intermediate results:')
                print('bits per pixel: {}'.format(mean_bpp_values))
                for key in self.DISTORTION_KEYS:
                    print('{}: {}'.format(key, mean_dist_values[key]))

                return {key: np.nan for key in self.DISTORTION_KEYS}, np.nan

        mean_bpp_values = np.mean(bpp_values)
        mean_dist_values = {
            key: np.mean(arr)
            for key, arr in distortions_values.items()
        }

        return mean_dist_values, mean_bpp_values
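
The layout helpers nhwc_to_nchw / nchw_to_nhwc used above are not defined in this snippet; a plausible minimal implementation (an assumption, not necessarily the project's actual helpers) is a pair of tf.transpose wrappers:

    def nhwc_to_nchw(images):
        # [batch, height, width, channels] -> [batch, channels, height, width]
        return tf.transpose(images, [0, 3, 1, 2])

    def nchw_to_nhwc(images):
        # [batch, channels, height, width] -> [batch, height, width, channels]
        return tf.transpose(images, [0, 2, 3, 1])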
Example #3
    def eval_classifier_model(self,
                              cnn_model: Union[ImagenetClassifier,
                                               FGVCClassifier],
                              ckpt_path=None):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start: {} ====='.format(
            cnn_model.NAME))

        # assertions
        assert self._dataset.NUM_CLASSES == cnn_model.num_classes, 'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

        # image shapes
        image_shape_classification = cnn_model.INPUT_SHAPE
        image_shape_compression = CompressionPreprocessing.pad_image_shape(
            image_shape=image_shape_classification,
            size_multiple_of=self.SIZE_MULTIPLE_OF,
            extra_padding_multiples=2)

        # log image sizes
        self.logger.info(
            'image_shape_classification={}'.format(image_shape_classification))
        self.logger.info(
            'image_shape_compression={}'.format(image_shape_compression))

        with tf.Graph().as_default() as graph:
            # datafeed
            self.logger.info('* datafeed')

            rnn_model = self._get_rnn_model(image_shape_compression[0],
                                            image_shape_compression[1])

            input_pipeline = InputPipeline(
                records=self.records_file,
                records_type=RecordsParser.RECORDS_LABELLED,
                shuffle_buffer_size=self.BATCH_SIZE *
                self.NUM_PREPROCESSING_THREADS,
                batch_size=self.BATCH_SIZE,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=CompressionPreprocessing.preprocess_image,
                preprocessing_kwargs={
                    'height': image_shape_compression[0],
                    'width': image_shape_compression[1],
                    'resize_side_min': min(image_shape_compression[:2]),
                    'is_training': False,
                    'dtype_out': tf.uint8
                },
                drop_remainder=False,
                compute_bpp=False,
                shuffle=False)

            images, labels = input_pipeline.next_batch()

            # compression + inference op
            self.logger.info('* compression')
            with tf.name_scope('rnn_compression'):
                image_batch_compressed = rnn_model.build_model(
                    images=images,
                    is_training=tf.cast(False, tf.bool),
                    reuse=tf.get_variable_scope().reuse)

            # inference kwargs
            self.logger.info('* inference')
            if self._dataset_name == Imagenet.NAME:

                def inference_kwargs(**kwargs):
                    return dict(graph=kwargs['graph'])
            else:

                def inference_kwargs(**kwargs):
                    return dict(
                        arg_scope=cnn_model.arg_scope(weight_decay=float(0)),
                        is_training=False,
                        return_predictions=True,
                        reuse=kwargs['j'] > 0)

            predictions_per_compression = []

            with tf.name_scope('inference_rnn'):

                for rnn_iteration in range(self._num_iterations):
                    with tf.name_scope('iteration_{}'.format(rnn_iteration)):
                        image_batch_compressed_iteration = tf.cast(
                            image_batch_compressed[rnn_iteration], tf.float32)

                        # take central crop of images in batch
                        image_batch_compressed_iteration = tf.image.resize_image_with_crop_or_pad(
                            image=image_batch_compressed_iteration,
                            target_height=image_shape_classification[0],
                            target_width=image_shape_classification[1])

                        # standardize appropriately
                        image_batch_compressed_iteration = cnn_model.standardize_tensor(
                            image_batch_compressed_iteration)

                        # predict
                        preds = cnn_model.inference(
                            image_batch_compressed_iteration,
                            **inference_kwargs(graph=graph, j=rnn_iteration))
                        predictions_per_compression.append(preds)

                # aggregate
                predictions_per_compression_op = tf.stack(
                    predictions_per_compression, axis=0)
                self.logger.info('predictions_shape: {}'.format(
                    predictions_per_compression_op.get_shape().as_list()))

            # restorers
            if self._dataset_name == Imagenet.NAME:
                classifier_saver = None
            else:
                classifier_saver = tf.train.Saver(
                    var_list=cnn_model.model_variables())

            # rnn saver
            saver = tf.train.Saver(var_list=rnn_model.model_variables)

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=False),
                        graph=graph) as sess:

            if classifier_saver is not None:
                classifier_saver.restore(sess, ckpt_path)

            saver.restore(sess, self._rnn_checkpoint)

            labels_values = []
            predictions_all_iters_values = [
                list() for _ in range(self._num_iterations)
            ]
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # run inference
                    batch_predictions_all_iters_values, batch_label_values = sess.run(
                        [predictions_per_compression_op, labels])

                    # collect predictions
                    for rnn_itr, preds_itr in enumerate(
                            batch_predictions_all_iters_values):
                        predictions_all_iters_values[rnn_itr].append(preds_itr)

                    # collect labels and bpp
                    labels_values.append(
                        self.to_categorical(batch_label_values,
                                            Imagenet.NUM_CLASSES))

                    n_images_processed += len(batch_label_values)
                    n_images_processed_per_second.append(
                        len(batch_label_values) /
                        (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))
                return ([(np.nan, np.nan) for _ in range(self._num_iterations)],
                        [np.nan for _ in range(self._num_iterations)])

        labels_values = np.concatenate(labels_values, axis=0)
        predictions_all_iters_values = [
            np.concatenate(preds_iter_values, axis=0)
            for preds_iter_values in predictions_all_iters_values
        ]

        accuracies = [
            (self.top_k_accuracy(labels_values, preds_iter_values, 1),
             self.top_k_accuracy(labels_values, preds_iter_values, 5))
            for preds_iter_values in predictions_all_iters_values
        ]

        return accuracies
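
top_k_accuracy is called but not shown; for the one-hot labels produced by to_categorical, one way it could be implemented (a sketch under that assumption, not the class's actual method) is:

    def top_k_accuracy(labels_onehot, predictions, k):
        # labels_onehot: [num_samples, num_classes]; predictions: per-class scores
        true_classes = np.argmax(labels_onehot, axis=1)
        top_k_classes = np.argsort(predictions, axis=1)[:, -k:]  # k best-scoring classes
        hits = [label in row for label, row in zip(true_classes, top_k_classes)]
        return np.mean(hits)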
Example #4
    def run(self):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start =====')

        # datafeed
        self.logger.info('* datafeed')

        with tf.Graph().as_default() as graph:

            image_height_compression, image_width_compression, _ = RNNCompressionModel.pad_image_shape(
                image_shape=[self._image_height, self._image_width, 3])

            rnn_model = self._get_rnn_model(image_height_compression,
                                            image_width_compression)

            input_pipeline = InputPipeline(
                records=self._records_file,
                records_type=RecordsParser.RECORDS_UNLABELLED,
                shuffle_buffer_size=0,
                batch_size=self._batch_size,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=self._get_resize_function(
                    self._image_height, self._image_width),
                preprocessing_kwargs={},
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False)

            images = input_pipeline.next_batch()[0]
            if image_height_compression != self._image_height or image_width_compression != self._image_width:
                images = tf.image.resize_image_with_crop_or_pad(
                    images, image_height_compression, image_width_compression)

            num_images_in_batch_op = tf.shape(images)[0]
            self.logger.info('images shape for compression: {}'.format(
                images.get_shape().as_list()))

            # compress images
            self.logger.info('* compression')
            images_compressed = rnn_model.build_model(
                images=images,
                is_training=tf.cast(False, tf.bool),
                reuse=tf.get_variable_scope().reuse)
            images_compressed.set_shape([
                self._num_iterations, self._batch_size,
                image_height_compression, image_width_compression, 3
            ])
            self.logger.info('compressed images shape: {}'.format(
                images_compressed.get_shape().as_list()))

            # compute distortions
            self.logger.info('* distortions')
            distortions_obj_per_compression = [
                Distortions(
                    reconstructed_images=tf.image.resize_image_with_crop_or_pad(
                        image=images_compressed[ii],
                        target_width=self._image_width,
                        target_height=self._image_height),
                    original_images=tf.cast(images, tf.float32),
                    lambda_ms_ssim=1.0,
                    lambda_psnr=1.0,
                    lambda_feature_loss=1.0,
                    data_format=self.DATA_FORMAT,
                    loss_net_kwargs=None)
                for ii in range(self._num_iterations)
            ]

            distortions_ops_per_compression = [{
                'ms_ssim': d.compute_ms_ssim()
            } for d in distortions_obj_per_compression]

            # savers
            rnn_saver = tf.train.Saver(var_list=rnn_model.model_variables)

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=True),
                        graph=graph) as sess:

            rnn_saver.restore(sess, self._rnn_checkpoint)

            distortions_values_per_compression = [{
                key: list()
                for key in self.DISTORTION_KEYS
            } for _ in range(self._num_iterations)]
            bpp_values_per_compression = [
                list() for _ in range(self._num_iterations)
            ]
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # compute distortions and bpp
                    batch_distortions_values_per_compression, num_images_in_batch = sess.run(
                        [
                            distortions_ops_per_compression,
                            num_images_in_batch_op
                        ])

                    # collect values
                    for comp_level, dist_comp in enumerate(
                            batch_distortions_values_per_compression):
                        bpp_values_per_compression[comp_level].append(
                            0.125 * (comp_level + 1))
                        for key in self.DISTORTION_KEYS:
                            distortions_values_per_compression[comp_level][
                                key].append(dist_comp[key])

                    n_images_processed += num_images_in_batch
                    n_images_processed_per_second.append(
                        num_images_in_batch / (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))
                return

            mean_bpp_values_per_compression = [
                np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
            ]
            mean_dist_values_per_compression = [{
                key: np.mean(arr)
                for key, arr in dist_dict.items()
            } for dist_dict in distortions_values_per_compression]

            self._save_results(mean_bpp_values_per_compression,
                               mean_dist_values_per_compression,
                               self._rnn_unit + '_' + self._loss_name,
                               [q + 1 for q in range(self._num_iterations)])
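
The hard-coded 0.125 * (comp_level + 1) above implies a codec that adds 1/8 bit per pixel with every RNN iteration, so the rate at each compression level is taken as deterministic rather than measured per batch:

    BPP_PER_ITERATION = 0.125  # assumed fixed rate added by each RNN pass
    bpp_at_level = [BPP_PER_ITERATION * (q + 1) for q in range(8)]
    # -> [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]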
Example #5
    def eval_classifier_model(self, ae_config, pc_config, cpm_checkpoint,
                              cnn_model: Union[ImagenetClassifier, FGVCClassifier],
                              fgvc_ckpt_path=None):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start: {} ====='.format(cnn_model.NAME))

        # assertions
        assert self._dataset.NUM_CLASSES == cnn_model.num_classes, 'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

        # image shapes
        image_shape_classification = cnn_model.INPUT_SHAPE
        image_shape_compression = CompressionPreprocessing.pad_image_shape(image_shape=image_shape_classification,
                                                                           size_multiple_of=self.SIZE_MULTIPLE_OF,
                                                                           extra_padding_multiples=2)

        # log image sizes
        self.logger.info('image_shape_classification={}'.format(image_shape_classification))
        self.logger.info('image_shape_compression={}'.format(image_shape_compression))

        with tf.Graph().as_default() as graph:
            # datafeed
            self.logger.info('* datafeed')

            input_pipeline = InputPipeline(records=self.records_file,
                                           records_type=RecordsParser.RECORDS_LABELLED,
                                           shuffle_buffer_size=self.BATCH_SIZE * self.NUM_PREPROCESSING_THREADS,
                                           batch_size=self.BATCH_SIZE,
                                           num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                                           num_repeat=1,
                                           preprocessing_fn=CompressionPreprocessing.preprocess_image,
                                           preprocessing_kwargs={'height': image_shape_compression[0],
                                                                 'width': image_shape_compression[1],
                                                                 'resize_side_min': min(image_shape_compression[:2]),
                                                                 'is_training': False,
                                                                 'dtype_out': tf.uint8},
                                           drop_remainder=False,
                                           compute_bpp=False,
                                           shuffle=False, dtype_out=tf.float32)

            images, labels = input_pipeline.next_batch()

            # compression + inference op
            self.logger.info('* compression')
            with tf.name_scope('compression'):

                images = nhwc_to_nchw(images)

                # create networks
                ae_cls = autoencoder.get_network_cls(ae_config)
                pc_cls = probclass.get_network_cls(pc_config)

                # instantiate models
                ae = ae_cls(ae_config)
                pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

                enc_out_val = ae.encode(images, is_training=False)
                images_compressed = ae.decode(enc_out_val.qhard, is_training=False)

                bitcost_val = pc.bitcost(enc_out_val.qbar, enc_out_val.symbols, is_training=False,
                                         pad_value=pc.auto_pad_value(ae))
                avg_bits_per_pixel = bitcost_to_bpp(bitcost_val, images)
                images_compressed = nchw_to_nhwc(images_compressed)

            # inference kwargs
            self.logger.info('* inference')
            if self._dataset_name == Imagenet.NAME:
                def inference_kwargs(**kwargs):
                    return dict(graph=kwargs['graph'])
            else:
                def inference_kwargs(**kwargs):
                    return dict(arg_scope=cnn_model.arg_scope(weight_decay=float(0)),
                                is_training=False,
                                return_predictions=True,
                                reuse=None)

            with tf.name_scope('inference_rnn'):

                # take central crop of images in batch
                images_compressed = tf.image.resize_image_with_crop_or_pad(
                    image=images_compressed,
                    target_height=image_shape_classification[0],
                    target_width=image_shape_classification[1])

                # standardize appropriately
                images_compressed = cnn_model.standardize_tensor(
                    images_compressed)

                # predict
                predictions = cnn_model.inference(images_compressed, **inference_kwargs(graph=graph))

                # aggregate
                self.logger.info('predictions_shape: {}'.format(predictions.get_shape().as_list()))

            # restorers
            if self._dataset_name == Imagenet.NAME:
                classifier_saver = None
            else:
                classifier_saver = tf.train.Saver(var_list=cnn_model.model_variables())

            # cpm saver
            cpm_saver = Saver(cpm_checkpoint, var_list=Saver.get_var_list_of_ckpt_dir(cpm_checkpoint))
            ckpt_itr, cpm_ckpt_path = Saver.all_ckpts_with_iterations(cpm_checkpoint)[-1]
            self.logger.info('ckpt_itr={}'.format(ckpt_itr))
            self.logger.info('ckpt_path={}'.format(cpm_ckpt_path))

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=False), graph=graph) as sess:

            cpm_saver.restore_ckpt(sess, cpm_ckpt_path)

            if classifier_saver is not None:
                classifier_saver.restore(sess, fgvc_ckpt_path)

            labels_values = []
            predictions_values = []
            bpp_values = []
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(n_images_processed, self._dataset.NUM_VAL,
                     '{}/{} images processed'.format(n_images_processed, self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # run inference
                    batch_predictions_values, batch_label_values, batch_avg_bpp_values = sess.run(
                        [predictions, labels, avg_bits_per_pixel])

                    # collect predictions
                    predictions_values.append(batch_predictions_values)

                    # collect labels and bpp
                    labels_values.append(self.to_categorical(batch_label_values, Imagenet.NUM_CLASSES))
                    bpp_values.append(batch_avg_bpp_values)

                    n_images_processed += len(batch_label_values)
                    n_images_processed_per_second.append(len(batch_label_values) / (time.time() - batch_start_time))

                    progress(n_images_processed, self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info('reached end of dataset; processed {} images'.format(n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(n_images_processed, self._dataset.NUM_VAL))

                labels_values = np.concatenate(labels_values, axis=0)
                predictions_values = np.concatenate(predictions_values, axis=0)
                bpp_values_mean = np.mean(bpp_values)

                accuracies = (self.top_k_accuracy(labels_values, predictions_values, 1),
                              self.top_k_accuracy(labels_values, predictions_values, 5))

                print('*** intermediate results:')
                print('bits per pixel: {}'.format(bpp_values_mean))
                print('Top-1 Accuracy: {}'.format(accuracies[0]))
                print('Top-5 Accuracy: {}'.format(accuracies[1]))

                return (np.nan, np.nan), np.nan

        labels_values = np.concatenate(labels_values, axis=0)
        predictions_values = np.concatenate(predictions_values, axis=0)
        bpp_values_mean = np.mean(bpp_values)

        accuracies = (self.top_k_accuracy(labels_values, predictions_values, 1),
                      self.top_k_accuracy(labels_values, predictions_values, 5))

        return accuracies, bpp_values_mean
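
bitcost_to_bpp is used in Examples #2 and #5 but not defined there; given its arguments (and images in NCHW layout at the call site), a plausible sketch of the helper (an assumption, not the project's actual code) divides the total bit cost by the pixel count of the batch:

    def bitcost_to_bpp(bitcost, images):
        # images: NCHW, so pixels = batch * height * width
        shape = tf.shape(images)
        num_pixels = tf.cast(shape[0] * shape[2] * shape[3], tf.float32)
        return tf.reduce_sum(tf.cast(bitcost, tf.float32)) / num_pixels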
Example #6
class TrainModel:
    ALLOWED_MODELS = [
        'inception_v3', 'mobilenet_v1', 'resnet_v1_50', 'vgg_16',
        'inception_resnet_v2'
    ]
    ALLOWED_DATASETS = [Cub200.NAME, StanfordDogs.NAME]

    def __init__(self, model_name, dataset_name, train_records, val_records,
                 config_path, config_optimizer, config_learning_rate,
                 config_data, config_transfer_learning, init_checkpoint_path,
                 job_id, eval_epochs, checkpoint_epochs):

        # assertions
        assert os.path.isfile(train_records), 'train_records not found'
        assert os.path.isfile(val_records), 'val_records not found'

        assert model_name in self.ALLOWED_MODELS, 'unknown model'
        assert dataset_name in self.ALLOWED_DATASETS, 'unknown dataset'

        # general args
        self._model_name = model_name
        self._dataset = get_dataset(dataset_name)
        self._config_path = config_path
        self._init_checkpoint_path = init_checkpoint_path
        self._eval_epochs = eval_epochs
        self._checkpoint_epochs = checkpoint_epochs

        # optimizer
        self._optimizer_name = config_optimizer['name']
        self._opt_epsilon = config_optimizer['opt_epsilon']
        self._weight_decay = config_optimizer['weight_decay']

        # learning rate
        self._initial_learning_rate = config_learning_rate[
            'initial_learning_rate']
        self._learning_rate_decay_type = config_learning_rate[
            'learning_rate_decay_type']
        self._learning_rate_decay_factor = config_learning_rate[
            'learning_rate_decay_factor']
        self._num_epochs_per_decay = config_learning_rate[
            'num_epochs_per_decay']
        self._end_learning_rate = config_learning_rate['end_learning_rate']

        # data
        self._batch_size = config_data['batch_size']
        self._num_epochs = config_data['num_epochs']
        self._shuffle_buffer_size = config_data['shuffle_buffer_size']
        self._num_preprocessing_threads = config_data[
            'num_preprocessing_threads']
        self._train_records = train_records
        self._val_records = val_records

        # transfer_learning
        self._trainable_scopes = config_transfer_learning['trainable_scopes']
        self._checkpoint_exclude_scopes = config_transfer_learning[
            'checkpoint_exclude_scopes']

        # timing
        self._epoch_times = deque(10 * [0.0], 10)

        # setup logging, directories
        self._job_id = job_id
        self._setup_dirs_and_logging()

        # log params
        log_configs(self._logger, [
            config_optimizer, config_transfer_learning, config_data,
            config_learning_rate
        ])
        self._logger.info('model_name: {}'.format(model_name))
        self._logger.info('train_records: {}'.format(train_records))
        self._logger.info('val_records: {}'.format(val_records))
        self._logger.info('job_id: {}'.format(job_id))

        # build graph
        self._build_graph()

    def _init_model(self):
        self._classifier = classifier_factory.get_fgvc_classifier(
            self._dataset.NAME, self._model_name)

    def _build_input_pipelines(self):
        self._train_input_pipeline = InputPipeline(
            records=self._train_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': True},
            drop_remainder=False,
            iterator_type=InputPipeline.INITIALIZABLE_ITERATOR)

        self._val_input_pipeline = InputPipeline(
            records=self._val_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=-1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': False},
            drop_remainder=False,
            iterator_type=InputPipeline.ONESHOT_ITERATOR)

    def _compute_loss(self, logits, labels, end_points):
        # cross entropy loss
        cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        cross_entropy_loss = tf.reduce_mean(cross_entropy_loss,
                                            name='cross_entropy_loss')

        # aux loss
        if 'AuxLogits' in end_points:
            aux_loss = 0.4 * tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=end_points['AuxLogits'])
            aux_loss = tf.reduce_mean(aux_loss, name='aux_loss')
        else:
            aux_loss = tf.cast(0.0, tf.float32)

        # regularization loss
        if self._trainable_scopes is None:
            regularization_terms = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
        else:
            regularization_terms = []
            for sc in self._trainable_scopes:
                regularization_terms.extend(
                    tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES,
                        scope=tf.get_default_graph().get_name_scope() + '/' +
                        sc))

        regularization_loss = tf.reduce_sum(regularization_terms,
                                            name='regularization_loss')

        return cross_entropy_loss, aux_loss, regularization_loss

    def _get_trainable_variables(self, verbose=False):
        if self._trainable_scopes is None:
            return tf.trainable_variables()

        assert isinstance(self._trainable_scopes, list)

        variables_to_train = []
        for scope in self._trainable_scopes:
            scope_variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope)
            variables_to_train.extend(scope_variables)

        if verbose:
            self._logger.info('====== trainable variables:')
            self._logger.info('trainable_scopes: {}'.format(
                self._trainable_scopes))
            self._logger.info('num_trainable_variables: {}'.format(
                len(variables_to_train)))
            for v in variables_to_train:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.name,
                    v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================')

        return variables_to_train

    def _build_graph(self):
        with tf.Graph().as_default() as graph:
            self._init_model()
            self._build_input_pipelines()
            self._build_graph0()

        graph.finalize()
        self._graph = graph

    def _build_graph0(self):

        self._global_step = tf.train.get_or_create_global_step(
            tf.get_default_graph())

        # ========= train
        with tf.name_scope('train'):
            train_images, train_labels = self._train_input_pipeline.next_batch()
            train_images.set_shape([None, *self._classifier.INPUT_SHAPE])

            # ========= inference
            arg_scope = self._classifier.arg_scope(
                weight_decay=self._weight_decay, is_training=True)
            with tf.variable_scope(tf.get_variable_scope(), reuse=None):
                train_logits, train_end_points = self._classifier.inference(
                    input_tensor=train_images,
                    is_training=True,
                    reuse=None,
                    arg_scope=arg_scope)

            # log shapes
            self._logger.info('images_shape: {}'.format(
                train_images.get_shape().as_list()))
            self._logger.info('logits_shape: {}'.format(
                train_logits.get_shape().as_list()))

            trainable_variables = self._get_trainable_variables(verbose=True)

            # ========= compute losses
            self._cross_entropy_loss, self._aux_loss, self._regularization_loss = self._compute_loss(
                train_logits, train_labels, train_end_points)
            self._train_loss = self._regularization_loss + self._cross_entropy_loss + self._aux_loss

            # ========= accuracy
            train_predictions = tf.argmax(tf.nn.softmax(train_logits),
                                          axis=1,
                                          name='train_predictions')
            self._train_accuracy, self._train_accuracy_update = tf.metrics.accuracy(
                train_labels, train_predictions, name='train_accuracy')

            # ========= configure optimizer
            learning_rate = configure_learning_rate(
                self._global_step, self._batch_size,
                self._initial_learning_rate, 1, self._dataset.NUM_TRAIN,
                self._num_epochs_per_decay, self._learning_rate_decay_type,
                self._learning_rate_decay_factor, self._end_learning_rate)

            self._optimizer = configure_optimizer(learning_rate,
                                                  self._optimizer_name,
                                                  self._opt_epsilon)

            # ========= optimization
            if self._trainable_scopes is None:
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            else:
                update_ops = []
                for sc in self._trainable_scopes:
                    update_ops.extend(
                        tf.get_collection(
                            key=tf.GraphKeys.UPDATE_OPS,
                            scope=tf.get_default_graph().get_name_scope() +
                            '/' + sc))

            self._logger.info('update_ops: {}'.format(update_ops))

            with tf.control_dependencies(update_ops):
                self._train_op = self._optimizer.minimize(
                    self._train_loss, self._global_step, trainable_variables)

            # ========= summaries
            self._train_summaries = self._collect_summaries(
                'train', self._train_loss, self._cross_entropy_loss,
                self._aux_loss, self._regularization_loss,
                self._train_accuracy, learning_rate, trainable_variables,
                train_end_points)

        # ========= validation
        with tf.name_scope('val'):
            val_images, val_labels = self._val_input_pipeline.next_batch()

            # ========= inference
            arg_scope = self._classifier.arg_scope(self._weight_decay)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                val_logits, val_end_points = self._classifier.inference(
                    input_tensor=val_images,
                    is_training=False,
                    reuse=True,
                    arg_scope=arg_scope)

            # ========= compute losses
            self._val_cross_entropy_loss, self._val_aux_loss, self._val_regularization_loss = self._compute_loss(
                val_logits, val_labels, val_end_points)
            self._val_loss = self._val_regularization_loss + self._val_cross_entropy_loss + self._val_aux_loss

            # ========= accuracy
            val_predictions = tf.argmax(tf.nn.softmax(val_logits), axis=1)
            self._val_accuracy, self._val_accuracy_update = tf.metrics.accuracy(
                val_labels, val_predictions, name='validation_accuracy')

            # ========= summaries
            self._val_summaries = self._collect_summaries(
                'val', self._val_loss, self._val_cross_entropy_loss,
                self._val_aux_loss, self._val_regularization_loss,
                self._val_accuracy)

        # ========= saving and restoring
        self._optimizer.get_name()
        vars_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.op.name
        ]
        vars_to_save.extend(
            [v for v in tf.local_variables() if 'Adam' not in v.op.name])
        self._logger.info('vars_to_save: {}'.format(vars_to_save))
        self._model_saver = tf.train.Saver(var_list=vars_to_save,
                                           save_relative_paths=True,
                                           max_to_keep=2)
        self._init_fn = self._get_init_fn(self._classifier.model_variables(),
                                          self._init_checkpoint_path,
                                          self._checkpoint_exclude_scopes,
                                          verbose=True)

    def train(self):
        train_writer = tf.summary.FileWriter(
            os.path.join(self._tensorboard_dir, 'train/'), self._graph)
        val_writer = tf.summary.FileWriter(
            os.path.join(self._tensorboard_dir, 'val/'), self._graph)

        with tf.Session(graph=self._graph,
                        config=get_sess_config(allow_growth=True)) as sess:

            # initialize model and train input pipeline
            self._init_fn(sess)
            self._train_input_pipeline.initialize(sess)

            step = sess.run(self._global_step)

            for epoch in range(1, self._num_epochs + 1):

                epoch_start_time = time.time()

                try:
                    while True:
                        _, step = sess.run([self._train_op, self._global_step])

                except tf.errors.OutOfRangeError:
                    pass

                except KeyboardInterrupt:
                    self._logger.info('manual interrupt')
                    self._clean_up(sess, step, [train_writer, val_writer])
                    return

                self._epoch_times.append(time.time() - epoch_start_time)
                self._train_input_pipeline.initialize(sess)

                if epoch % self._eval_epochs == 0 or epoch == 1:
                    self._eval_procedure(sess, step, epoch, train_writer,
                                         val_writer)

                if epoch % self._checkpoint_epochs == 0:
                    self._model_saver.save(sess,
                                           os.path.join(
                                               self._checkpoint_dir,
                                               'model.ckpt'),
                                           global_step=step,
                                           write_meta_graph=False)
                    self._logger.info('[{} steps] saved model.'.format(step))

            self._clean_up(sess, step, [train_writer, val_writer])

    def _get_init_fn(self,
                     model_variables,
                     checkpoint_path,
                     checkpoint_exclude_scopes,
                     ignore_missing_vars=False,
                     verbose=False):
        """ returns fetches, feed_dict for restoring model """

        # ========== continue with training
        latest_checkpoint = tf.train.latest_checkpoint(self._checkpoint_dir)
        if latest_checkpoint is not None:
            self._logger.info(
                'continue training from {}'.format(latest_checkpoint))

            def _init_fn(sess):
                self._model_saver.restore(sess, latest_checkpoint)

            return _init_fn

        # ========== start training from scratch if no checkpoint is provided
        if checkpoint_path is None:
            self._logger.info(
                'no init_checkpoint_path provided; training network from scratch.'
            )

            def _init_fn(sess):
                sess.run(
                    tf.group([
                        tf.global_variables_initializer(),
                        tf.local_variables_initializer()
                    ]))

            return _init_fn

        # ========== fine tune trainable variables
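        # optimizer slot variables (e.g. Adam moments) are never restored from
        # the classification checkpoint, so exclude them up front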
        exclusions = [v.op.name for v in self._optimizer.variables()]
        if checkpoint_exclude_scopes:
            assert isinstance(checkpoint_exclude_scopes, list)
            exclusions.extend(
                [scope.strip() for scope in checkpoint_exclude_scopes])

        vars_to_restore_from_checkpoint = []
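        # for/else: a variable is restored only if no exclusion prefix matched
        # (the else branch runs when the inner loop finishes without break)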
        for var in model_variables:
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    break
            else:
                vars_to_restore_from_checkpoint.append(var)

        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        vars_to_init_from_scratch = [
            v for v in [*tf.global_variables(), *tf.local_variables()]
            if v not in vars_to_restore_from_checkpoint
        ]

        if verbose:
            self._logger.info(
                '====== variables to be initialized from checkpoint:')
            self._logger.info('num: {}'.format(
                len(vars_to_restore_from_checkpoint)))
            for v in vars_to_restore_from_checkpoint:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.op.name,
                    v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================\n')

            self._logger.info(
                '====== variables to be initialized from scratch:')
            self._logger.info('num: {}'.format(len(vars_to_init_from_scratch)))
            for v in vars_to_init_from_scratch:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.op.name,
                    v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================\n')

        # randomly initialize vars to train
        init_fetches = [
            tf.variables_initializer(var_list=vars_to_init_from_scratch)
        ]

        # restore rest of vars from pretrained model

        restore_op, restore_dict = slim.assign_from_checkpoint(
            checkpoint_path,
            var_list=vars_to_restore_from_checkpoint,
            ignore_missing_vars=ignore_missing_vars)
        init_fetches.append(restore_op)

        def _init_fn(sess):
            sess.run(init_fetches, feed_dict=restore_dict)

        return _init_fn

    def _clean_up(self, sess, step, summary_writers=None):
        print('cleanup...')

        # stop writers
        if not isinstance(summary_writers, list):
            summary_writers = [summary_writers]
        for writer in summary_writers:
            if writer is None:
                continue
            else:
                writer.flush()
                writer.close()

        # save model
        self._model_saver.save(sess,
                               os.path.join(self._checkpoint_dir,
                                            'model.ckpt'),
                               step,
                               write_meta_graph=False)
        self._logger.info('[{} steps] saved model.'.format(step))

    @staticmethod
    def _collect_summaries(split,
                           total_loss,
                           cross_entropy_loss,
                           aux_loss,
                           regularization_loss,
                           accuracy,
                           learning_rate=None,
                           variables=None,
                           end_points=None):
        summaries = list()
        summaries.append(tf.summary.scalar(split + '/total_loss', total_loss))
        summaries.append(tf.summary.scalar(split + '/accuracy', accuracy))
        summaries.append(
            tf.summary.scalar(split + '/cross_entropy_loss',
                              cross_entropy_loss))
        summaries.append(tf.summary.scalar(split + '/aux_loss', aux_loss))
        summaries.append(
            tf.summary.scalar(split + '/regularization_loss',
                              regularization_loss))

        if split == 'val':
            return tf.summary.merge(summaries, name=split + '_summaries')

        summaries.append(
            tf.summary.scalar(split + '/learning_rate', learning_rate))
        for v in variables:
            summaries.append(
                tf.summary.histogram(split + '/variables/{}'.format(v.op.name),
                                     v))
        for ep in end_points:
            a = end_points[ep]
            summaries.append(
                tf.summary.histogram(split + '/activations/{}'.format(ep), a))

        return tf.summary.merge(summaries, name=split + '_summaries')

    def _eval_procedure(self, sess, step, epoch, train_summary_writer,
                        val_summary_writer):

        # train stats
        (train_summaries, train_total_loss, train_regularization_loss,
         train_cross_entropy_loss, train_aux_loss, train_accuracy,
         _) = sess.run([
             self._train_summaries, self._train_loss,
             self._regularization_loss, self._cross_entropy_loss,
             self._aux_loss, self._train_accuracy, self._train_accuracy_update
         ])

        train_summary_writer.add_summary(train_summaries, epoch)
        train_summary_writer.flush()

        # val stats
        (val_summaries, val_total_loss, val_cross_entropy_loss, val_aux_loss,
         val_accuracy, _) = sess.run([
             self._val_summaries, self._val_loss, self._val_cross_entropy_loss,
             self._val_aux_loss, self._val_accuracy, self._val_accuracy_update
         ])

        val_summary_writer.add_summary(val_summaries, epoch)
        val_summary_writer.flush()

        # compute average epoch time
        avg_epoch_time = np.mean([t for t in self._epoch_times if t > 0])
        avg_epoch_time_hms = seconds_to_minutes_seconds(avg_epoch_time)

        # print stats
        self._logger.info(
            self._eval_str(epoch=epoch,
                           num_epochs=self._num_epochs,
                           step=step,
                           avg_epoch_time=avg_epoch_time_hms,
                           total_loss=train_total_loss,
                           val_total_loss=val_total_loss,
                           regularization_loss=train_regularization_loss,
                           cross_entropy_loss=train_cross_entropy_loss,
                           val_cross_entropy_loss=val_cross_entropy_loss,
                           aux_loss=train_aux_loss,
                           val_aux_loss=val_aux_loss,
                           accuracy=100.0 * train_accuracy,
                           val_accuracy=100.0 * val_accuracy))

    @staticmethod
    def _eval_str(epoch, num_epochs, step, avg_epoch_time, total_loss,
                  val_total_loss, regularization_loss, cross_entropy_loss,
                  val_cross_entropy_loss, aux_loss, val_aux_loss, accuracy,
                  val_accuracy):

        eval_str = "[{epoch}/{num_epochs} epochs ({avg_epoch_time} / epoch)] | "
        eval_str += "total_loss: {total_loss:.4f} ({val_total_loss:.4f})| "
        eval_str += "regularization_loss: {regularization_loss:.6f} (-) | "
        eval_str += "cross_entropy_loss: {cross_entropy_loss:.4f} ({val_cross_entropy_loss:.4f}) | "
        eval_str += "aux_loss: {aux_loss:.4f} ({val_aux_loss:.4f}) | "
        eval_str += "accuracy: {accuracy:.2f}% ({val_accuracy:.2f}%)"

        return eval_str.format(epoch=epoch,
                               num_epochs=num_epochs,
                               step=step,
                               avg_epoch_time=avg_epoch_time,
                               total_loss=total_loss,
                               val_total_loss=val_total_loss,
                               regularization_loss=regularization_loss,
                               cross_entropy_loss=cross_entropy_loss,
                               val_cross_entropy_loss=val_cross_entropy_loss,
                               aux_loss=aux_loss,
                               val_aux_loss=val_aux_loss,
                               accuracy=accuracy,
                               val_accuracy=val_accuracy)

    @property
    def graph(self):
        return self._graph

    def _setup_dirs_and_logging(self):
        self._module_dir = os.path.dirname(
            os.path.abspath(inspect.getfile(inspect.currentframe())))
        base_dir = os.path.join(
            self._module_dir,
            'checkpoints/{}_id{}/'.format(self._model_name, self._job_id))

        self._checkpoint_dir = os.path.join(base_dir, 'checkpoints')
        self._tensorboard_dir = os.path.join(base_dir, 'tensorboard')
        self._log_dir = os.path.join(base_dir, 'logs')

        log_strings = []
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)
            log_strings.append('created dir {}'.format(base_dir))

        if not os.path.exists(self._checkpoint_dir):
            os.makedirs(self._checkpoint_dir)
            log_strings.append('created dir {}'.format(self._checkpoint_dir))

        if not os.path.exists(self._log_dir):
            os.makedirs(self._log_dir)
            log_strings.append('created dir {}'.format(self._log_dir))

        if not os.path.exists(self._tensorboard_dir):
            os.makedirs(self._tensorboard_dir)
            log_strings.append('created dir {}'.format(self._tensorboard_dir))

        # copy config file to base dir
        config_dest_file = os.path.join(base_dir, 'config.json')
        copyfile(self._config_path, config_dest_file)

        # logging
        logfile = os.path.join(
            self._log_dir, 'train_{}_id{}.log'.format(self._model_name,
                                                      self._job_id))
        self._logger = get_logger(logfile)
        for s in log_strings:
            self._logger.info(s)
        self._logger.info('logfile={}'.format(logfile))
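
The snippet above relies on a `get_logger` helper that is not shown in this
listing. A plausible minimal implementation, assuming only the standard
`logging` module (the helper's actual code may differ), could look like this:

import logging

def get_logger(logfile):
    """Return a logger that writes to both `logfile` and the console."""
    logger = logging.getLogger(logfile)
    logger.setLevel(logging.INFO)
    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
        for handler in (logging.FileHandler(logfile), logging.StreamHandler()):
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger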
Example #7
    def run(self):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start =====')

        with tf.Graph().as_default() as graph:
            # datafeed
            self.logger.info('* datafeed')

            input_pipeline = InputPipeline(
                records=self._records_file,
                records_type=RecordsParser.RECORDS_UNLABELLED,
                shuffle_buffer_size=0,
                batch_size=self._batch_size,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=self._get_resize_function(
                    self._image_height, self._image_width),
                preprocessing_kwargs={},
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False)

            images = input_pipeline.next_batch()[0]

            image_shape = images.get_shape().as_list()
            self.logger.info('image_shape: {}'.format(image_shape))

            # compression op
            self.logger.info('* compression')

            images_per_compression = []
            bpp_op_per_compression = []
            for j, compression_level in enumerate(self.COMPRESSION_LEVELS):
                # compress batch
                with tf.name_scope(
                        'compression_webp_{}'.format(compression_level)):
                    with tf.device(CPU_DEVICE):  # -> webp compression on cpu
                        img_batch_compressed, _bpp = TFWebp.tf_encode_decode_image_batch(
                            image_batch=tf.cast(images, tf.uint8),
                            quality=compression_level)

                    img_batch_compressed.set_shape(
                        images.get_shape().as_list())
                    images_per_compression.append(
                        tf.cast(img_batch_compressed, tf.float32))
                    bpp_op_per_compression.append(_bpp)

            # compute distortions
            self.logger.info('* distortions')
            distortions_obj_per_compression = [
                Distortions(reconstructed_images=c_img_batch,
                            original_images=tf.cast(images, tf.float32),
                            lambda_ms_ssim=1.0,
                            lambda_psnr=1.0,
                            lambda_feature_loss=1.0,
                            data_format=self.DATA_FORMAT,
                            loss_net_kwargs=None)
                for c_img_batch in images_per_compression
            ]

            distortions_ops_per_compression = [{
                'ms_ssim': d.compute_ms_ssim()
            } for d in distortions_obj_per_compression]

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=True),
                        graph=graph) as sess:

            distortions_values_per_compression = [{
                key: list()
                for key in self.DISTORTION_KEYS
            } for _ in self.COMPRESSION_LEVELS]
            bpp_values_per_compression = [
                list() for _ in self.COMPRESSION_LEVELS
            ]
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # compute distortions and bpp
                    batch_bpp_values_per_compression, batch_distortions_values_per_compression = sess.run(
                        [
                            bpp_op_per_compression,
                            distortions_ops_per_compression
                        ])

                    # collect values
                    for comp_level, (dist_comp, bpp_comp) in enumerate(
                            zip(batch_distortions_values_per_compression,
                                batch_bpp_values_per_compression)):
                        bpp_values_per_compression[comp_level].extend(bpp_comp)
                        for key in self.DISTORTION_KEYS:
                            distortions_values_per_compression[comp_level][
                                key].append(dist_comp[key])

                    n_images_processed += len(
                        batch_bpp_values_per_compression[0])
                    n_images_processed_per_second.append(
                        len(batch_bpp_values_per_compression[0]) /
                        (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))
                return [(np.nan, np.nan) for _ in self.COMPRESSION_LEVELS]

            mean_bpp_values_per_compression = [
                np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
            ]
            mean_dist_values_per_compression = [{
                key: np.mean(arr)
                for key, arr in dist_dict.items()
            } for dist_dict in distortions_values_per_compression]

            self._save_results(mean_bpp_values_per_compression,
                               mean_dist_values_per_compression, 'webp',
                               self.COMPRESSION_LEVELS)
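
The `TFWebp.tf_encode_decode_image_batch` op used above is opaque in this
listing. A minimal sketch of the kind of round trip it might wrap, assuming
Pillow is installed (all names below are hypothetical, not the actual helper):

import io

import numpy as np
import tensorflow as tf
from PIL import Image

def _webp_roundtrip(image_uint8, quality):
    """Encode one HWC uint8 image as WebP, decode it, return (image, bpp)."""
    buf = io.BytesIO()
    Image.fromarray(image_uint8).save(buf, format='WEBP', quality=int(quality))
    num_bytes = buf.getbuffer().nbytes
    buf.seek(0)
    decoded = np.asarray(Image.open(buf).convert('RGB'), dtype=np.uint8)
    height, width = image_uint8.shape[:2]
    return decoded, np.float32(8.0 * num_bytes / (height * width))

def webp_encode_decode_batch(image_batch, quality):
    """TF1-style wrapper: apply the round trip image-wise via tf.py_func."""
    def _batch_fn(batch):
        pairs = [_webp_roundtrip(img, quality) for img in batch]
        images = np.stack([p[0] for p in pairs]).astype(np.uint8)
        bpps = np.array([p[1] for p in pairs], dtype=np.float32)
        return images, bpps
    return tf.py_func(_batch_fn, [image_batch], [tf.uint8, tf.float32])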
Example #8
    def run(self):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start =====')

        with tf.Graph().as_default() as graph:

            # datafeed
            self.logger.info('* datafeed')

            ip0 = InputPipeline(
                records=self._records_file,
                records_type=RecordsParser.RECORDS_UNLABELLED,
                shuffle_buffer_size=0,
                batch_size=self._batch_size,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=CompressionPreprocessing.
                preprocess_image_for_eval,
                preprocessing_kwargs={
                    'height': self._image_height,
                    'width': self._image_width,
                    'resize_side_min': min(self._image_height,
                                           self._image_width)
                },
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False)

            original_images = ip0.next_batch()[0]

            image_batches, bpp_op_per_compression = [], []
            for records in self._bpg_records_files:
                ip = InputPipeline(
                    records=records,
                    records_type=RecordsParser.RECORDS_BPP,
                    shuffle_buffer_size=0,
                    batch_size=self._batch_size,
                    num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                    num_repeat=1,
                    preprocessing_fn=CompressionPreprocessing.
                    preprocess_image_with_identity,
                    preprocessing_kwargs={
                        'height': self._image_height,
                        'width': self._image_width,
                        'dtype_out': tf.uint8
                    },
                    drop_remainder=True,
                    compute_bpp=False,
                    shuffle=False)

                images, bpp = ip.next_batch()

                image_batches.append(images)
                bpp_op_per_compression.append(bpp)

            # compute distortions
            self.logger.info('* distortions')
            distortions_obj_per_compression = [
                Distortions(reconstructed_images=c_img_batch,
                            original_images=original_images,
                            lambda_ms_ssim=1.0,
                            lambda_psnr=1.0,
                            lambda_feature_loss=1.0,
                            data_format=self.DATA_FORMAT,
                            loss_net_kwargs=None)
                for c_img_batch in image_batches
            ]

            distortions_ops_per_compression = [{
                'ms_ssim': d.compute_ms_ssim()
            } for d in distortions_obj_per_compression]

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=True),
                        graph=graph) as sess:

            distortions_values_per_compression = [{
                key: list()
                for key in self.DISTORTION_KEYS
            } for _ in range(self._num_compression_levels)]
            bpp_values_per_compression = [
                list() for _ in range(self._num_compression_levels)
            ]
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # compute distortions and bpp
                    batch_bpp_values_per_compression, batch_distortions_values_per_compression = sess.run(
                        [
                            bpp_op_per_compression,
                            distortions_ops_per_compression
                        ])

                    # collect values
                    for comp_level, (dist_comp, bpp_comp) in enumerate(
                            zip(batch_distortions_values_per_compression,
                                batch_bpp_values_per_compression)):
                        bpp_values_per_compression[comp_level].extend(bpp_comp)
                        for key in self.DISTORTION_KEYS:
                            distortions_values_per_compression[comp_level][
                                key].append(dist_comp[key])

                    n_images_processed += len(
                        batch_bpp_values_per_compression[0])
                    n_images_processed_per_second.append(
                        len(batch_bpp_values_per_compression[0]) /
                        (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))
                return [(np.nan, np.nan)
                        for _ in range(self._num_compression_levels)]

            mean_bpp_values_per_compression = [
                np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
            ]
            mean_dist_values_per_compression = [{
                key: np.mean(arr)
                for key, arr in dist_dict.items()
            } for dist_dict in distortions_values_per_compression]

            self._save_results(mean_bpp_values_per_compression,
                               mean_dist_values_per_compression, 'bpg', None)
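
The `_save_results` helper is also not part of this listing. A hypothetical
sketch of what it could do with the values computed above, persisting the mean
bpp and mean distortions per compression level as JSON:

import json
import os

def save_results(mean_bpps, mean_dists, codec_name, compression_levels,
                 out_dir='results'):
    """Write one record per compression level to <out_dir>/<codec>_results.json."""
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    records = []
    for i, (bpp, dists) in enumerate(zip(mean_bpps, mean_dists)):
        level = compression_levels[i] if compression_levels is not None else i
        record = {'codec': codec_name, 'level': level, 'mean_bpp': float(bpp)}
        record.update({k: float(v) for k, v in dists.items()})
        records.append(record)
    out_file = os.path.join(out_dir, '{}_results.json'.format(codec_name))
    with open(out_file, 'w') as f:
        json.dump(records, f, indent=2)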
Example #9
    # NOTE: requires `from typing import Union` at module level.
    def eval_classifier_model(self,
                              cnn_model: Union[ImagenetClassifier,
                                               FGVCClassifier],
                              ckpt_path=None):

        tf.reset_default_graph()

        print('')
        self.logger.info('===== building graph start: {} ====='.format(
            cnn_model.NAME))

        # assertions
        assert self._dataset.NUM_CLASSES == cnn_model.num_classes, 'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

        # image shapes
        image_shape_classification = cnn_model.INPUT_SHAPE
        image_shape_compression = CompressionPreprocessing.pad_image_shape(
            image_shape=image_shape_classification,
            size_multiple_of=self.SIZE_MULTIPLE_OF,
            extra_padding_multiples=2)

        # log image sizes
        self.logger.info(
            'image_shape_classification={}'.format(image_shape_classification))
        self.logger.info(
            'image_shape_compression={}'.format(image_shape_compression))

        # records files depending on inference resolution
        if image_shape_classification[0] < 256:
            bpg_records_files = list(self._bpg_records_files256)
        else:
            bpg_records_files = list(self._bpg_records_files336)

        self.logger.info('bpg_records_files: {}'.format(bpg_records_files))

        with tf.Graph().as_default() as graph:

            # datafeed
            self.logger.info('* datafeed')
            image_batches, labels_batches, bpp_batches = [], [], []
            for records in bpg_records_files:
                ip = InputPipeline(
                    records=records,
                    records_type=RecordsParser.RECORDS_LABELLED_BPP,
                    shuffle_buffer_size=1,
                    batch_size=self.BATCH_SIZE,
                    num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                    num_repeat=1,
                    preprocessing_fn=CompressionPreprocessing.
                    preprocess_image_with_identity,
                    preprocessing_kwargs={
                        'height': image_shape_compression[0],
                        'width': image_shape_compression[1],
                        'dtype_out': tf.uint8
                    },
                    drop_remainder=True,
                    compute_bpp=False,
                    shuffle=False,
                    dtype_out=tf.uint8)

                images, labels, bpp = ip.next_batch()

                image_batches.append(images)
                labels_batches.append(labels)
                bpp_batches.append(bpp)

            # compression + inference op
            self.logger.info('* inference')

            predictions_per_compression = []

            # inference kwargs
            if self._dataset_name == Imagenet.NAME:

                def inference_kwargs(**kwargs):
                    return dict(graph=kwargs['graph'])
            else:

                def inference_kwargs(**kwargs):
                    return dict(
                        arg_scope=cnn_model.arg_scope(weight_decay=0.0),
                        is_training=False,
                        return_predictions=True,
                        reuse=kwargs['j'] > 0)

            for j, image_batch_compressed in enumerate(image_batches):
                with tf.name_scope('inference_bpg{}'.format(j)):
                    # crop center
                    image_batch_compressed = tf.image.resize_image_with_crop_or_pad(
                        image_batch_compressed, image_shape_classification[0],
                        image_shape_classification[1])

                    # standardize appropriately
                    image_batch_for_inference = cnn_model.standardize_tensor(
                        image_batch_compressed)

                    # predict
                    preds = cnn_model.inference(
                        image_batch_for_inference,
                        **inference_kwargs(graph=graph, j=j))
                    predictions_per_compression.append(preds)

            # aggregate
            predictions_per_compression_op = tf.stack(
                predictions_per_compression, axis=0)
            self.logger.info('predictions_shape: {}'.format(
                predictions_per_compression_op.get_shape().as_list()))

            # restore
            if self._dataset_name == Imagenet.NAME:
                classifier_saver = None
            else:
                classifier_saver = tf.train.Saver(
                    var_list=cnn_model.model_variables())

        graph.finalize()

        with tf.Session(config=get_sess_config(allow_growth=False),
                        graph=graph) as sess:

            if classifier_saver is not None:
                classifier_saver.restore(sess, ckpt_path)

            labels_all_comp_values = [list() for _ in range(self._num_records)]
            predictions_all_comp_values = [
                list() for _ in range(self._num_records)
            ]
            bpp_all_comp_values = [list() for _ in range(self._num_records)]
            n_images_processed = 0
            n_images_processed_per_second = deque(10 * [0.0], 10)
            progress(
                n_images_processed, self._dataset.NUM_VAL,
                '{}/{} images processed'.format(n_images_processed,
                                                self._dataset.NUM_VAL))

            try:
                while True:
                    batch_start_time = time.time()

                    # run inference
                    (batch_predictions_all_comp_values,
                     batch_label_all_comp_values,
                     batch_bpp_all_comp_values) = sess.run([
                         predictions_per_compression_op, labels_batches,
                         bpp_batches
                     ])

                    # collect predictions
                    for comp_level, (preds_comp, bpp_comp,
                                     labels_comp) in enumerate(
                                         zip(batch_predictions_all_comp_values,
                                             batch_bpp_all_comp_values,
                                             batch_label_all_comp_values)):
                        predictions_all_comp_values[comp_level].append(
                            preds_comp)
                        bpp_all_comp_values[comp_level].append(bpp_comp)
                        labels_all_comp_values[comp_level].append(
                            self.to_categorical(labels_comp,
                                                cnn_model.num_classes))

                    n_images_processed += len(batch_label_all_comp_values[0])
                    n_images_processed_per_second.append(
                        len(batch_label_all_comp_values[0]) /
                        (time.time() - batch_start_time))

                    progress(n_images_processed,
                             self._dataset.NUM_VAL,
                             status='{}/{} images processed ({} img/s)'.format(
                                 n_images_processed, self._dataset.NUM_VAL,
                                 np.mean(n_images_processed_per_second)))

            except tf.errors.OutOfRangeError:
                self.logger.info(
                    'reached end of dataset; processed {} images'.format(
                        n_images_processed))

            except KeyboardInterrupt:
                self.logger.info(
                    'manual interrupt; processed {}/{} images'.format(
                        n_images_processed, self._dataset.NUM_VAL))
                return ([(np.nan, np.nan) for _ in range(self._num_records)],
                        [np.nan for _ in range(self._num_records)])

        labels_all_comp_values = [
            np.concatenate(labels_comp_values, axis=0)
            for labels_comp_values in labels_all_comp_values
        ]
        bpp_all_comp_values = [
            np.mean(np.concatenate(bpp_values, 0))
            for bpp_values in bpp_all_comp_values
        ]
        predictions_all_comp_values = [
            np.concatenate(preds_comp_values, axis=0)
            for preds_comp_values in predictions_all_comp_values
        ]

        accuracies = [(self.top_k_accuracy(labels_values, preds_comp_values,
                                           1),
                       self.top_k_accuracy(labels_values, preds_comp_values,
                                           5))
                      for preds_comp_values, labels_values in zip(
                          predictions_all_comp_values, labels_all_comp_values)]

        return accuracies, bpp_all_comp_values
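
The `to_categorical` and `top_k_accuracy` helpers used above are not shown in
this listing. Minimal sketches consistent with how they are called (one-hot
labels; top-k accuracy over stacked prediction scores):

import numpy as np

def to_categorical(labels, num_classes):
    """One-hot encode integer labels into an (N, num_classes) float matrix."""
    labels = np.asarray(labels).reshape(-1)
    one_hot = np.zeros((labels.size, num_classes), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot

def top_k_accuracy(labels_one_hot, predictions, k):
    """Fraction of samples whose true class is among the k highest scores."""
    true_classes = np.argmax(labels_one_hot, axis=1)
    top_k = np.argsort(predictions, axis=1)[:, -k:]
    return float(np.mean([t in row for t, row in zip(true_classes, top_k)]))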