def run0(self, ae_config, pc_config, cpm_checkpoint):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start =====')

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        input_pipeline = InputPipeline(
            records=self.records_file,
            records_type=RecordsParser.RECORDS_UNLABELLED,
            shuffle_buffer_size=self._batch_size * self.NUM_PREPROCESSING_THREADS,
            batch_size=self._batch_size,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=self._get_resize_function(self._image_height,
                                                       self._image_width),
            preprocessing_kwargs={},
            drop_remainder=True,
            compute_bpp=False,
            shuffle=False,
            dtype_out=tf.float32)

        images = input_pipeline.next_batch()[0]

        # compression + inference op
        self.logger.info('* compression')
        with tf.name_scope('compression'):
            print(images.get_shape().as_list())
            images = nhwc_to_nchw(images)

            # create networks
            ae_cls = autoencoder.get_network_cls(ae_config)
            pc_cls = probclass.get_network_cls(pc_config)

            # instantiate models
            ae = ae_cls(ae_config)
            pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

            enc_out_val = ae.encode(images, is_training=False)
            images_compressed = ae.decode(enc_out_val.qhard, is_training=False)
            bitcost_val = pc.bitcost(enc_out_val.qbar,
                                     enc_out_val.symbols,
                                     is_training=False,
                                     pad_value=pc.auto_pad_value(ae))
            avg_bits_per_pixel = bitcost_to_bpp(bitcost_val, images)

            images = nchw_to_nhwc(images)
            images_compressed = nchw_to_nhwc(images_compressed)

        # compute distortions
        self.logger.info('* distortions')
        with tf.name_scope('distortions'):
            distortions_obj = Distortions(
                reconstructed_images=images_compressed,
                original_images=tf.cast(images, tf.float32),
                lambda_ms_ssim=1.0,
                lambda_psnr=1.0,
                lambda_feature_loss=1.0,
                data_format=self.DATA_FORMAT,
                loss_net_kwargs=None)

            distortions_ops = {
                'ms_ssim': distortions_obj.compute_ms_ssim(),
                'mse': distortions_obj.compute_mse(),
                'psnr': distortions_obj.compute_psnr()
            }

        # cpm saver
        cpm_saver = Saver(cpm_checkpoint,
                          var_list=Saver.get_var_list_of_ckpt_dir(cpm_checkpoint))
        ckpt_itr, cpm_ckpt_path = Saver.all_ckpts_with_iterations(cpm_checkpoint)[-1]
        self.logger.info('ckpt_itr={}'.format(ckpt_itr))
        self.logger.info('ckpt_path={}'.format(cpm_ckpt_path))

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=False), graph=graph) as sess:
        cpm_saver.restore_ckpt(sess, cpm_ckpt_path)

        distortions_values = {key: list() for key in self.DISTORTION_KEYS}
        bpp_values = []

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # compute distortions and bpp
                batch_bpp_mean_values, batch_distortions_values = sess.run(
                    [avg_bits_per_pixel, distortions_ops])

                # collect values
                bpp_values.append(batch_bpp_mean_values)
                for key in self.DISTORTION_KEYS:
                    distortions_values[key].append(batch_distortions_values[key])

                n_images_processed += self._batch_size
                n_images_processed_per_second.append(
                    self._batch_size / (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))

            mean_bpp_values = np.mean(bpp_values)
            mean_dist_values = {
                key: np.mean(arr) for key, arr in distortions_values.items()
            }

            print('*** intermediate results:')
            print('bits per pixel: {}'.format(mean_bpp_values))
            for key in self.DISTORTION_KEYS:
                print('{}: {}'.format(key, mean_dist_values[key]))

            return {key: np.nan for key in self.DISTORTION_KEYS}, np.nan

        mean_bpp_values = np.mean(bpp_values)
        mean_dist_values = {
            key: np.mean(arr) for key, arr in distortions_values.items()
        }

        return mean_dist_values, mean_bpp_values
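# NOTE: `bitcost_to_bpp` above reduces the bit-cost tensor to a mean
# bits-per-pixel figure. A minimal NumPy sketch of that conversion, assuming
# `bit_cost` holds the total encoding cost per image in bits and the batch is
# NCHW as in the code above (the real helper lives elsewhere in the repo;
# this sketch is illustrative only):
def bitcost_to_bpp_sketch(bit_cost, images_nchw):
    """Mean bits per pixel: total bits divided by spatial pixels per image."""
    _, _, height, width = images_nchw.shape
    return float(np.mean(bit_cost) / (height * width))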
def eval_classifier_model(self,
                          cnn_model: 'Union[ImagenetClassifier, FGVCClassifier]',
                          ckpt_path=None):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start: {} ====='.format(cnn_model.NAME))

    # assertions
    assert self._dataset.NUM_CLASSES == cnn_model.num_classes, \
        'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

    # image shapes
    image_shape_classification = cnn_model.INPUT_SHAPE
    image_shape_compression = CompressionPreprocessing.pad_image_shape(
        image_shape=image_shape_classification,
        size_multiple_of=self.SIZE_MULTIPLE_OF,
        extra_padding_multiples=2)

    # log image sizes
    self.logger.info('image_shape_classification={}'.format(image_shape_classification))
    self.logger.info('image_shape_compression={}'.format(image_shape_compression))

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        rnn_model = self._get_rnn_model(image_shape_compression[0],
                                        image_shape_compression[1])
        input_pipeline = InputPipeline(
            records=self.records_file,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self.BATCH_SIZE * self.NUM_PREPROCESSING_THREADS,
            batch_size=self.BATCH_SIZE,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=CompressionPreprocessing.preprocess_image,
            preprocessing_kwargs={
                'height': image_shape_compression[0],
                'width': image_shape_compression[1],
                'resize_side_min': min(image_shape_compression[:2]),
                'is_training': False,
                'dtype_out': tf.uint8
            },
            drop_remainder=False,
            compute_bpp=False,
            shuffle=False)

        images, labels = input_pipeline.next_batch()

        # compression + inference op
        self.logger.info('* compression')
        with tf.name_scope('rnn_compression'):
            image_batch_compressed = rnn_model.build_model(
                images=images,
                is_training=tf.cast(False, tf.bool),
                reuse=tf.get_variable_scope().reuse)

        # inference kwargs
        self.logger.info('* inference')
        if self._dataset_name == Imagenet.NAME:
            def inference_kwargs(**kwargs):
                return dict(graph=kwargs['graph'])
        else:
            def inference_kwargs(**kwargs):
                return dict(arg_scope=cnn_model.arg_scope(weight_decay=float(0)),
                            is_training=False,
                            return_predictions=True,
                            reuse=True if kwargs['j'] > 0 else False)

        predictions_per_compression = []
        with tf.name_scope('inference_rnn'):
            for rnn_iteration in range(self._num_iterations):
                with tf.name_scope('iteration_{}'.format(rnn_iteration)):
                    image_batch_compressed_iteration = tf.cast(
                        image_batch_compressed[rnn_iteration], tf.float32)

                    # take central crop of images in batch
                    image_batch_compressed_iteration = tf.image.resize_image_with_crop_or_pad(
                        image=image_batch_compressed_iteration,
                        target_height=image_shape_classification[0],
                        target_width=image_shape_classification[1])

                    # standardize appropriately
                    image_batch_compressed_iteration = cnn_model.standardize_tensor(
                        image_batch_compressed_iteration)

                    # predict
                    preds = cnn_model.inference(
                        image_batch_compressed_iteration,
                        **inference_kwargs(graph=graph, j=rnn_iteration))
                    predictions_per_compression.append(preds)

        # aggregate
        predictions_per_compression_op = tf.stack(predictions_per_compression, axis=0)
        self.logger.info('predictions_shape: {}'.format(
            predictions_per_compression_op.get_shape().as_list()))

        # restorers
        if self._dataset_name == Imagenet.NAME:
            classifier_saver = None
        else:
            classifier_saver = tf.train.Saver(var_list=cnn_model.model_variables())

        # rnn saver
        saver = tf.train.Saver(var_list=rnn_model.model_variables)

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=False), graph=graph) as sess:
        if classifier_saver is not None:
            classifier_saver.restore(sess, ckpt_path)
        saver.restore(sess, self._rnn_checkpoint)

        labels_values = []
        predictions_all_iters_values = [list() for _ in range(self._num_iterations)]

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # run inference
                batch_predictions_all_iters_values, batch_label_values = sess.run(
                    [predictions_per_compression_op, labels])

                # collect predictions
                for rnn_itr, preds_itr in enumerate(batch_predictions_all_iters_values):
                    predictions_all_iters_values[rnn_itr].append(preds_itr)

                # collect labels
                labels_values.append(
                    self.to_categorical(batch_label_values, Imagenet.NUM_CLASSES))

                n_images_processed += len(batch_label_values)
                n_images_processed_per_second.append(
                    len(batch_label_values) / (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))
            return ([(np.nan, np.nan) for _ in range(self._num_iterations)],
                    [np.nan for _ in range(self._num_iterations)])

        labels_values = np.concatenate(labels_values, axis=0)
        predictions_all_iters_values = [
            np.concatenate(preds_iter_values, axis=0)
            for preds_iter_values in predictions_all_iters_values
        ]

        accuracies = [(self.top_k_accuracy(labels_values, preds_iter_values, 1),
                       self.top_k_accuracy(labels_values, preds_iter_values, 5))
                      for preds_iter_values in predictions_all_iters_values]

        return accuracies
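# The loop above accumulates one-hot labels (via `to_categorical`) and one
# prediction matrix per RNN iteration, then scores them with
# `top_k_accuracy`. A minimal NumPy sketch of what that helper is assumed to
# compute (the real implementation is defined on the evaluator class):
def top_k_accuracy_sketch(labels_onehot, predictions, k):
    """Fraction of samples whose true class is among the k highest scores."""
    true_class = np.argmax(labels_onehot, axis=1)    # [N]
    top_k = np.argsort(predictions, axis=1)[:, -k:]  # [N, k]
    return float(np.mean(np.any(top_k == true_class[:, None], axis=1)))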
def run(self):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start =====')

    # datafeed
    self.logger.info('* datafeed')
    with tf.Graph().as_default() as graph:
        image_height_compression, image_width_compression, _ = RNNCompressionModel.pad_image_shape(
            image_shape=[self._image_height, self._image_width, 3])
        rnn_model = self._get_rnn_model(image_height_compression,
                                        image_width_compression)

        input_pipeline = InputPipeline(
            records=self._records_file,
            records_type=RecordsParser.RECORDS_UNLABELLED,
            shuffle_buffer_size=0,
            batch_size=self._batch_size,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=self._get_resize_function(self._image_height,
                                                       self._image_width),
            preprocessing_kwargs={},
            drop_remainder=True,
            compute_bpp=False,
            shuffle=False)

        images = input_pipeline.next_batch()[0]
        if image_height_compression != self._image_height or \
                image_width_compression != self._image_width:
            images = tf.image.resize_image_with_crop_or_pad(
                images, image_height_compression, image_width_compression)
        num_images_in_batch_op = tf.shape(images)[0]

        self.logger.info('images shape for compression: {}'.format(
            images.get_shape().as_list()))

        # compress images
        self.logger.info('* compression')
        images_compressed = rnn_model.build_model(
            images=images,
            is_training=tf.cast(False, tf.bool),
            reuse=tf.get_variable_scope().reuse)
        images_compressed.set_shape([
            self._num_iterations, self._batch_size, image_height_compression,
            image_width_compression, 3
        ])
        self.logger.info('compressed images shape: {}'.format(
            images_compressed.get_shape().as_list()))

        # compute distortions
        self.logger.info('* distortions')
        distortions_obj_per_compression = [
            Distortions(
                reconstructed_images=tf.image.resize_image_with_crop_or_pad(
                    image=images_compressed[ii],
                    target_width=self._image_width,
                    target_height=self._image_height),
                original_images=tf.cast(images, tf.float32),
                lambda_ms_ssim=1.0,
                lambda_psnr=1.0,
                lambda_feature_loss=1.0,
                data_format=self.DATA_FORMAT,
                loss_net_kwargs=None) for ii in range(self._num_iterations)
        ]
        distortions_ops_per_compression = [{
            'ms_ssim': d.compute_ms_ssim()
        } for d in distortions_obj_per_compression]

        # savers
        rnn_saver = tf.train.Saver(var_list=rnn_model.model_variables)

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=True), graph=graph) as sess:
        rnn_saver.restore(sess, self._rnn_checkpoint)

        distortions_values_per_compression = [{
            key: list() for key in self.DISTORTION_KEYS
        } for _ in range(self._num_iterations)]
        bpp_values_per_compression = [list() for _ in range(self._num_iterations)]

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # compute distortions and bpp
                batch_distortions_values_per_compression, num_images_in_batch = sess.run(
                    [distortions_ops_per_compression, num_images_in_batch_op])

                # collect values
                for comp_level, dist_comp in enumerate(
                        batch_distortions_values_per_compression):
                    # bpp grows linearly with the iteration: 0.125 bpp each
                    bpp_values_per_compression[comp_level].append(
                        0.125 * (comp_level + 1))
                    for key in self.DISTORTION_KEYS:
                        distortions_values_per_compression[comp_level][key].append(
                            dist_comp[key])

                n_images_processed += num_images_in_batch
                n_images_processed_per_second.append(
                    num_images_in_batch / (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))
            return

        mean_bpp_values_per_compression = [
            np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
        ]
        mean_dist_values_per_compression = [{
            key: np.mean(arr) for key, arr in dist_dict.items()
        } for dist_dict in distortions_values_per_compression]

        self._save_results(mean_bpp_values_per_compression,
                           mean_dist_values_per_compression,
                           self._rnn_unit + '_' + self._loss_name,
                           [q + 1 for q in range(self._num_iterations)])
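# The hard-coded rate above reflects the progressive RNN coder: every
# iteration adds a fixed 0.125 bpp, so the rate at iteration i is known
# without running an entropy coder:
bpp_per_iteration = [0.125 * (i + 1) for i in range(16)]
assert bpp_per_iteration[7] == 1.0  # 8 iterations -> 1.0 bpp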
def eval_classifier_model(self,
                          ae_config,
                          pc_config,
                          cpm_checkpoint,
                          cnn_model: 'Union[ImagenetClassifier, FGVCClassifier]',
                          fgvc_ckpt_path=None):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start: {} ====='.format(cnn_model.NAME))

    # assertions
    assert self._dataset.NUM_CLASSES == cnn_model.num_classes, \
        'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

    # image shapes
    image_shape_classification = cnn_model.INPUT_SHAPE
    image_shape_compression = CompressionPreprocessing.pad_image_shape(
        image_shape=image_shape_classification,
        size_multiple_of=self.SIZE_MULTIPLE_OF,
        extra_padding_multiples=2)

    # log image sizes
    self.logger.info('image_shape_classification={}'.format(image_shape_classification))
    self.logger.info('image_shape_compression={}'.format(image_shape_compression))

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        input_pipeline = InputPipeline(
            records=self.records_file,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self.BATCH_SIZE * self.NUM_PREPROCESSING_THREADS,
            batch_size=self.BATCH_SIZE,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=CompressionPreprocessing.preprocess_image,
            preprocessing_kwargs={
                'height': image_shape_compression[0],
                'width': image_shape_compression[1],
                'resize_side_min': min(image_shape_compression[:2]),
                'is_training': False,
                'dtype_out': tf.uint8
            },
            drop_remainder=False,
            compute_bpp=False,
            shuffle=False,
            dtype_out=tf.float32)

        images, labels = input_pipeline.next_batch()

        # compression + inference op
        self.logger.info('* compression')
        with tf.name_scope('compression'):
            images = nhwc_to_nchw(images)

            # create networks
            ae_cls = autoencoder.get_network_cls(ae_config)
            pc_cls = probclass.get_network_cls(pc_config)

            # instantiate models
            ae = ae_cls(ae_config)
            pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

            enc_out_val = ae.encode(images, is_training=False)
            images_compressed = ae.decode(enc_out_val.qhard, is_training=False)
            bitcost_val = pc.bitcost(enc_out_val.qbar,
                                     enc_out_val.symbols,
                                     is_training=False,
                                     pad_value=pc.auto_pad_value(ae))
            avg_bits_per_pixel = bitcost_to_bpp(bitcost_val, images)

            images_compressed = nchw_to_nhwc(images_compressed)

        # inference kwargs
        self.logger.info('* inference')
        if self._dataset_name == Imagenet.NAME:
            def inference_kwargs(**kwargs):
                return dict(graph=kwargs['graph'])
        else:
            def inference_kwargs(**kwargs):
                return dict(arg_scope=cnn_model.arg_scope(weight_decay=float(0)),
                            is_training=False,
                            return_predictions=True,
                            reuse=None)

        with tf.name_scope('inference_rnn'):
            # take central crop of images in batch
            images_compressed = tf.image.resize_image_with_crop_or_pad(
                image=images_compressed,
                target_height=image_shape_classification[0],
                target_width=image_shape_classification[1])

            # standardize appropriately
            images_compressed = cnn_model.standardize_tensor(images_compressed)

            # predict
            predictions = cnn_model.inference(images_compressed,
                                              **inference_kwargs(graph=graph))

        # aggregate
        self.logger.info('predictions_shape: {}'.format(
            predictions.get_shape().as_list()))

        # restorers
        if self._dataset_name == Imagenet.NAME:
            classifier_saver = None
        else:
            classifier_saver = tf.train.Saver(var_list=cnn_model.model_variables())

        # cpm saver
        cpm_saver = Saver(cpm_checkpoint,
                          var_list=Saver.get_var_list_of_ckpt_dir(cpm_checkpoint))
        ckpt_itr, cpm_ckpt_path = Saver.all_ckpts_with_iterations(cpm_checkpoint)[-1]
        self.logger.info('ckpt_itr={}'.format(ckpt_itr))
        self.logger.info('ckpt_path={}'.format(cpm_ckpt_path))

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=False), graph=graph) as sess:
        cpm_saver.restore_ckpt(sess, cpm_ckpt_path)
        if classifier_saver is not None:
            classifier_saver.restore(sess, fgvc_ckpt_path)

        labels_values = []
        predictions_values = []
        bpp_values = []

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # run inference
                batch_predictions_values, batch_label_values, batch_avg_bpp_values = sess.run(
                    [predictions, labels, avg_bits_per_pixel])

                # collect predictions
                predictions_values.append(batch_predictions_values)

                # collect labels and bpp
                labels_values.append(
                    self.to_categorical(batch_label_values, Imagenet.NUM_CLASSES))
                bpp_values.append(batch_avg_bpp_values)

                n_images_processed += len(batch_label_values)
                n_images_processed_per_second.append(
                    len(batch_label_values) / (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))

            labels_values = np.concatenate(labels_values, axis=0)
            predictions_values = np.concatenate(predictions_values, axis=0)
            bpp_values_mean = np.mean(bpp_values)
            accuracies = (self.top_k_accuracy(labels_values, predictions_values, 1),
                          self.top_k_accuracy(labels_values, predictions_values, 5))

            print('*** intermediate results:')
            print('bits per pixel: {}'.format(bpp_values_mean))
            print('Top-1 Accuracy: {}'.format(accuracies[0]))
            print('Top-5 Accuracy: {}'.format(accuracies[1]))

            return (np.nan, np.nan), np.nan

        labels_values = np.concatenate(labels_values, axis=0)
        predictions_values = np.concatenate(predictions_values, axis=0)
        bpp_values_mean = np.mean(bpp_values)
        accuracies = (self.top_k_accuracy(labels_values, predictions_values, 1),
                      self.top_k_accuracy(labels_values, predictions_values, 5))

        return accuracies, bpp_values_mean
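# `nhwc_to_nchw` / `nchw_to_nhwc` used above are presumably thin
# `tf.transpose` wrappers; a sketch under that assumption:
def nhwc_to_nchw_sketch(images):
    # [batch, height, width, channels] -> [batch, channels, height, width]
    return tf.transpose(images, [0, 3, 1, 2])

def nchw_to_nhwc_sketch(images):
    # [batch, channels, height, width] -> [batch, height, width, channels]
    return tf.transpose(images, [0, 2, 3, 1])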
class TrainModel:
    ALLOWED_MODELS = [
        'inception_v3', 'mobilenet_v1', 'resnet_v1_50', 'vgg_16',
        'inception_resnet_v2'
    ]
    ALLOWED_DATASETS = [Cub200.NAME, StanfordDogs.NAME]

    def __init__(self, model_name, dataset_name, train_records, val_records,
                 config_path, config_optimizer, config_learning_rate,
                 config_data, config_transfer_learning, init_checkpoint_path,
                 job_id, eval_epochs, checkpoint_epochs):

        # assertions
        assert os.path.isfile(train_records), 'train_records not found'
        assert os.path.isfile(val_records), 'val_records not found'
        assert model_name in self.ALLOWED_MODELS, 'unknown model'
        assert dataset_name in self.ALLOWED_DATASETS, 'unknown dataset'

        # general args
        self._model_name = model_name
        self._dataset = get_dataset(dataset_name)
        self._config_path = config_path
        self._init_checkpoint_path = init_checkpoint_path
        self._eval_epochs = eval_epochs
        self._checkpoint_epochs = checkpoint_epochs

        # optimizer
        self._optimizer_name = config_optimizer['name']
        self._opt_epsilon = config_optimizer['opt_epsilon']
        self._weight_decay = config_optimizer['weight_decay']

        # learning rate
        self._initial_learning_rate = config_learning_rate['initial_learning_rate']
        self._learning_rate_decay_type = config_learning_rate['learning_rate_decay_type']
        self._learning_rate_decay_factor = config_learning_rate['learning_rate_decay_factor']
        self._num_epochs_per_decay = config_learning_rate['num_epochs_per_decay']
        self._end_learning_rate = config_learning_rate['end_learning_rate']

        # data
        self._batch_size = config_data['batch_size']
        self._num_epochs = config_data['num_epochs']
        self._shuffle_buffer_size = config_data['shuffle_buffer_size']
        self._num_preprocessing_threads = config_data['num_preprocessing_threads']
        self._train_records = train_records
        self._val_records = val_records

        # transfer learning
        self._trainable_scopes = config_transfer_learning['trainable_scopes']
        self._checkpoint_exclude_scopes = config_transfer_learning['checkpoint_exclude_scopes']

        # timing
        self._epoch_times = deque(10 * [0.0], 10)

        # setup logging, directories
        self._job_id = job_id
        self._setup_dirs_and_logging()

        # log params
        log_configs(self._logger, [
            config_optimizer, config_transfer_learning, config_data,
            config_learning_rate
        ])
        self._logger.info('model_name: {}'.format(model_name))
        self._logger.info('train_records: {}'.format(train_records))
        self._logger.info('val_records: {}'.format(val_records))
        self._logger.info('job_id: {}'.format(job_id))

        # build graph
        self._build_graph()

    def _init_model(self):
        self._classifier = classifier_factory.get_fgvc_classifier(
            self._dataset.NAME, self._model_name)

    def _build_input_pipelines(self):
        self._train_input_pipeline = InputPipeline(
            records=self._train_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': True},
            drop_remainder=False,
            iterator_type=InputPipeline.INITIALIZABLE_ITERATOR)

        self._val_input_pipeline = InputPipeline(
            records=self._val_records,
            records_type=RecordsParser.RECORDS_LABELLED,
            shuffle_buffer_size=self._shuffle_buffer_size,
            batch_size=self._batch_size,
            num_preprocessing_threads=self._num_preprocessing_threads,
            num_repeat=-1,
            preprocessing_fn=self._classifier.preprocess_image,
            preprocessing_kwargs={'is_training': False},
            drop_remainder=False,
            iterator_type=InputPipeline.ONESHOT_ITERATOR)

    def _compute_loss(self, logits, labels, end_points):
        # cross entropy loss
        cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        cross_entropy_loss = tf.reduce_mean(cross_entropy_loss,
                                            name='cross_entropy_loss')

        # aux loss
        if 'AuxLogits' in end_points:
            aux_loss = 0.4 * tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=end_points['AuxLogits'])
            aux_loss = tf.reduce_mean(aux_loss, name='aux_loss')
        else:
            aux_loss = tf.cast(0.0, tf.float32)

        # regularization loss
        if self._trainable_scopes is None:
            regularization_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        else:
            regularization_terms = []
            for sc in self._trainable_scopes:
                regularization_terms.extend(
                    tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES,
                        scope=tf.get_default_graph().get_name_scope() + '/' + sc))

        regularization_loss = tf.reduce_sum(regularization_terms,
                                            name='regularization_loss')

        return cross_entropy_loss, aux_loss, regularization_loss

    def _get_trainable_variables(self, verbose=False):
        if self._trainable_scopes is None:
            return tf.trainable_variables()

        assert isinstance(self._trainable_scopes, list)

        variables_to_train = []
        for scope in self._trainable_scopes:
            scope_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
            variables_to_train.extend(scope_variables)

        if verbose:
            self._logger.info('====== trainable variables:')
            self._logger.info('trainable_scopes: {}'.format(self._trainable_scopes))
            self._logger.info('num_trainable_variables: {}'.format(len(variables_to_train)))
            for v in variables_to_train:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.name, v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================')

        return variables_to_train

    def _build_graph(self):
        with tf.Graph().as_default() as graph:
            self._init_model()
            self._build_input_pipelines()
            self._build_graph0()
            graph.finalize()
        self._graph = graph

    def _build_graph0(self):
        self._global_step = tf.train.get_or_create_global_step(tf.get_default_graph())

        # ========= train
        with tf.name_scope('train'):
            train_images, train_labels = self._train_input_pipeline.next_batch()
            train_images.set_shape([None, *self._classifier.INPUT_SHAPE])

            # ========= inference
            arg_scope = self._classifier.arg_scope(weight_decay=self._weight_decay,
                                                   is_training=True)
            with tf.variable_scope(tf.get_variable_scope(), reuse=None):
                train_logits, train_end_points = self._classifier.inference(
                    input_tensor=train_images,
                    is_training=True,
                    reuse=None,
                    arg_scope=arg_scope)

            # log shapes
            self._logger.info('images_shape: {}'.format(
                train_images.get_shape().as_list()))
            self._logger.info('logits_shape: {}'.format(
                train_logits.get_shape().as_list()))

            trainable_variables = self._get_trainable_variables(verbose=True)

            # ========= compute losses
            self._cross_entropy_loss, self._aux_loss, self._regularization_loss = \
                self._compute_loss(train_logits, train_labels, train_end_points)
            self._train_loss = (self._regularization_loss +
                                self._cross_entropy_loss + self._aux_loss)

            # ========= accuracy
            train_predictions = tf.argmax(tf.nn.softmax(train_logits),
                                          axis=1,
                                          name='train_predictions')
            self._train_accuracy, self._train_accuracy_update = tf.metrics.accuracy(
                train_labels, train_predictions, name='train_accuracy')

            # ========= configure optimizer
            learning_rate = configure_learning_rate(
                self._global_step, self._batch_size, self._initial_learning_rate,
                1, self._dataset.NUM_TRAIN, self._num_epochs_per_decay,
                self._learning_rate_decay_type, self._learning_rate_decay_factor,
                self._end_learning_rate)
            self._optimizer = configure_optimizer(learning_rate,
                                                  self._optimizer_name,
                                                  self._opt_epsilon)

            # ========= optimization
            if self._trainable_scopes is None:
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            else:
                update_ops = []
                for sc in self._trainable_scopes:
                    update_ops.extend(
                        tf.get_collection(
                            key=tf.GraphKeys.UPDATE_OPS,
                            scope=tf.get_default_graph().get_name_scope() + '/' + sc))

            self._logger.info('update_ops: {}'.format(update_ops))

            with tf.control_dependencies(update_ops):
                self._train_op = self._optimizer.minimize(
                    self._train_loss, self._global_step, trainable_variables)

            # ========= summaries
            self._train_summaries = self._collect_summaries(
                'train', self._train_loss, self._cross_entropy_loss,
                self._aux_loss, self._regularization_loss, self._train_accuracy,
                learning_rate, trainable_variables, train_end_points)

        # ========= validation
        with tf.name_scope('val'):
            val_images, val_labels = self._val_input_pipeline.next_batch()

            # ========= inference
            arg_scope = self._classifier.arg_scope(self._weight_decay)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                val_logits, val_end_points = self._classifier.inference(
                    input_tensor=val_images,
                    is_training=False,
                    reuse=True,
                    arg_scope=arg_scope)

            # ========= compute losses
            self._val_cross_entropy_loss, self._val_aux_loss, self._val_regularization_loss = \
                self._compute_loss(val_logits, val_labels, val_end_points)
            self._val_loss = (self._val_regularization_loss +
                              self._val_cross_entropy_loss + self._val_aux_loss)

            # ========= accuracy
            val_predictions = tf.argmax(tf.nn.softmax(val_logits), axis=1)
            self._val_accuracy, self._val_accuracy_update = tf.metrics.accuracy(
                val_labels, val_predictions, name='validation_accuracy')

            # ========= summaries
            self._val_summaries = self._collect_summaries(
                'val', self._val_loss, self._val_cross_entropy_loss,
                self._val_aux_loss, self._val_regularization_loss,
                self._val_accuracy)

        # ========= saving and restoring
        self._optimizer.get_name()
        vars_to_save = [v for v in tf.global_variables() if 'Adam' not in v.op.name]
        vars_to_save.extend(
            [v for v in tf.local_variables() if 'Adam' not in v.op.name])
        self._logger.info('vars_to_save: {}'.format(vars_to_save))

        self._model_saver = tf.train.Saver(var_list=vars_to_save,
                                           save_relative_paths=True,
                                           max_to_keep=2)
        self._init_fn = self._get_init_fn(self._classifier.model_variables(),
                                          self._init_checkpoint_path,
                                          self._checkpoint_exclude_scopes,
                                          verbose=True)

    def train(self):
        train_writer = tf.summary.FileWriter(
            os.path.join(self._tensorboard_dir, 'train/'), self._graph)
        val_writer = tf.summary.FileWriter(
            os.path.join(self._tensorboard_dir, 'val/'), self._graph)

        with tf.Session(graph=self._graph,
                        config=get_sess_config(allow_growth=True)) as sess:
            # initialize model and train input pipeline
            self._init_fn(sess)
            self._train_input_pipeline.initialize(sess)

            step = sess.run(self._global_step)
            for epoch in range(1, self._num_epochs + 1):
                epoch_start_time = time.time()

                try:
                    while True:
                        _, step = sess.run([self._train_op, self._global_step])
                except tf.errors.OutOfRangeError:
                    pass
                except KeyboardInterrupt:
                    self._logger.info('manual interrupt')
                    self._clean_up(sess, step, [train_writer, val_writer])
                    return

                self._epoch_times.append(time.time() - epoch_start_time)
                self._train_input_pipeline.initialize(sess)

                if epoch % self._eval_epochs == 0 or epoch == 1:
                    self._eval_procedure(sess, step, epoch, train_writer, val_writer)

                if epoch % self._checkpoint_epochs == 0:
                    self._model_saver.save(sess,
                                           os.path.join(self._checkpoint_dir,
                                                        'model.ckpt'),
                                           global_step=step,
                                           write_meta_graph=False)
                    self._logger.info('[{} steps] saved model.'.format(step))

            self._clean_up(sess, step, [train_writer, val_writer])

    def _get_init_fn(self, model_variables, checkpoint_path,
                     checkpoint_exclude_scopes, ignore_missing_vars=False,
                     verbose=False):
        """ returns fetches, feed_dict for restoring model """
        # ========== continue with training
        latest_checkpoint = tf.train.latest_checkpoint(self._checkpoint_dir)
        if latest_checkpoint is not None:
            self._logger.info('continue training from {}'.format(latest_checkpoint))

            def _init_fn(sess):
                self._model_saver.restore(sess, latest_checkpoint)

            return _init_fn

        # ========== start training from scratch if no checkpoint is provided
        if checkpoint_path is None:
            self._logger.info(
                'no init_checkpoint_path provided; training network from scratch.')

            def _init_fn(sess):
                sess.run(
                    tf.group([
                        tf.global_variables_initializer(),
                        tf.local_variables_initializer()
                    ]))

            return _init_fn

        # ========== fine tune trainable variables
        exclusions = [v.op.name for v in self._optimizer.variables()]
        if checkpoint_exclude_scopes:
            assert isinstance(checkpoint_exclude_scopes, list)
            exclusions.extend(
                [scope.strip() for scope in checkpoint_exclude_scopes])

        vars_to_restore_from_checkpoint = []
        for var in model_variables:
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    break
            else:
                vars_to_restore_from_checkpoint.append(var)

        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        vars_to_init_from_scratch = [
            v for v in [*tf.global_variables(), *tf.local_variables()]
            if v not in vars_to_restore_from_checkpoint
        ]

        if verbose:
            self._logger.info('====== variables to be initialized from checkpoint:')
            self._logger.info('num: {}'.format(len(vars_to_restore_from_checkpoint)))
            for v in vars_to_restore_from_checkpoint:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.op.name, v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================\n')

            self._logger.info('====== variables to be initialized from scratch:')
            self._logger.info('num: {}'.format(len(vars_to_init_from_scratch)))
            for v in vars_to_init_from_scratch:
                self._logger.info('name: {}, shape: {}, dtype: {}'.format(
                    v.op.name, v.get_shape().as_list(), v.dtype))
            self._logger.info('===========================\n')

        # randomly initialize vars to train
        init_fetches = [tf.variables_initializer(var_list=vars_to_init_from_scratch)]

        # restore rest of vars from pretrained model
        restore_op, restore_dict = slim.assign_from_checkpoint(
            checkpoint_path,
            var_list=vars_to_restore_from_checkpoint,
            ignore_missing_vars=ignore_missing_vars)
        init_fetches.append(restore_op)

        def _init_fn(sess):
            sess.run(init_fetches, feed_dict=restore_dict)

        return _init_fn

    def _clean_up(self, sess, step, summary_writers=None):
        print('cleanup...')

        # stop writers
        summary_writers = [summary_writers] if not isinstance(
            summary_writers, list) else summary_writers
        for writer in summary_writers:
            if writer is None:
                continue
            writer.flush()
            writer.close()

        # save model
        self._model_saver.save(sess,
                               os.path.join(self._checkpoint_dir, 'model.ckpt'),
                               step,
                               write_meta_graph=False)
        self._logger.info('[{} steps] saved model.'.format(step))

    @staticmethod
    def _collect_summaries(split, total_loss, cross_entropy_loss, aux_loss,
                           regularization_loss, accuracy, learning_rate=None,
                           variables=None, end_points=None):
        summaries = list()
        summaries.append(tf.summary.scalar(split + '/total_loss', total_loss))
        summaries.append(tf.summary.scalar(split + '/accuracy', accuracy))
        summaries.append(
            tf.summary.scalar(split + '/cross_entropy_loss', cross_entropy_loss))
        summaries.append(tf.summary.scalar(split + '/aux_loss', aux_loss))
        summaries.append(
            tf.summary.scalar(split + '/regularization_loss', regularization_loss))

        if split == 'val':
            return tf.summary.merge(summaries, name=split + '_summaries')

        summaries.append(tf.summary.scalar(split + '/learning_rate', learning_rate))
        for v in variables:
            summaries.append(
                tf.summary.histogram(split + '/variables/{}'.format(v.op.name), v))
        for ep in end_points:
            a = end_points[ep]
            summaries.append(
                tf.summary.histogram(split + '/activations/{}'.format(ep), a))

        return tf.summary.merge(summaries, name=split + '_summaries')

    def _eval_procedure(self, sess, step, epoch, train_summary_writer,
                        val_summary_writer):
        # train stats
        (train_summaries, train_total_loss, train_regularization_loss,
         train_cross_entropy_loss, train_aux_loss, train_accuracy, _) = sess.run([
             self._train_summaries, self._train_loss, self._regularization_loss,
             self._cross_entropy_loss, self._aux_loss, self._train_accuracy,
             self._train_accuracy_update
         ])
        train_summary_writer.add_summary(train_summaries, epoch)
        train_summary_writer.flush()

        # val stats
        (val_summaries, val_total_loss, val_cross_entropy_loss, val_aux_loss,
         val_accuracy, _) = sess.run([
             self._val_summaries, self._val_loss, self._val_cross_entropy_loss,
             self._val_aux_loss, self._val_accuracy, self._val_accuracy_update
         ])
        val_summary_writer.add_summary(val_summaries, epoch)
        val_summary_writer.flush()

        # compute average epoch time
        avg_epoch_time = np.mean([t for t in self._epoch_times if t > 0])
        avg_epoch_time_hms = seconds_to_minutes_seconds(avg_epoch_time)

        # print stats
        self._logger.info(
            self._eval_str(epoch=epoch,
                           num_epochs=self._num_epochs,
                           step=step,
                           avg_epoch_time=avg_epoch_time_hms,
                           total_loss=train_total_loss,
                           val_total_loss=val_total_loss,
                           regularization_loss=train_regularization_loss,
                           cross_entropy_loss=train_cross_entropy_loss,
                           val_cross_entropy_loss=val_cross_entropy_loss,
                           aux_loss=train_aux_loss,
                           val_aux_loss=val_aux_loss,
                           accuracy=100.0 * train_accuracy,
                           val_accuracy=100.0 * val_accuracy))

    @staticmethod
    def _eval_str(epoch, num_epochs, step, avg_epoch_time, total_loss,
                  val_total_loss, regularization_loss, cross_entropy_loss,
                  val_cross_entropy_loss, aux_loss, val_aux_loss, accuracy,
                  val_accuracy):
        eval_str = "[{epoch}/{num_epochs} epochs ({avg_epoch_time} / epoch)] | "
        eval_str += "total_loss: {total_loss:.4f} ({val_total_loss:.4f}) | "
        eval_str += "regularization_loss: {regularization_loss:.6f} (-) | "
        eval_str += "cross_entropy_loss: {cross_entropy_loss:.4f} ({val_cross_entropy_loss:.4f}) | "
        eval_str += "aux_loss: {aux_loss:.4f} ({val_aux_loss:.4f}) | "
        eval_str += "accuracy: {accuracy:.2f}% ({val_accuracy:.2f}%)"
        return eval_str.format(epoch=epoch,
                               num_epochs=num_epochs,
                               step=step,
                               avg_epoch_time=avg_epoch_time,
                               total_loss=total_loss,
                               val_total_loss=val_total_loss,
                               regularization_loss=regularization_loss,
                               cross_entropy_loss=cross_entropy_loss,
                               val_cross_entropy_loss=val_cross_entropy_loss,
                               aux_loss=aux_loss,
                               val_aux_loss=val_aux_loss,
                               accuracy=accuracy,
                               val_accuracy=val_accuracy)

    @property
    def graph(self):
        return self._graph

    def _setup_dirs_and_logging(self):
        self._module_dir = os.path.dirname(
            os.path.abspath(inspect.getfile(inspect.currentframe())))
        base_dir = os.path.join(
            self._module_dir,
            'checkpoints/{}_id{}/'.format(self._model_name, self._job_id))
        self._checkpoint_dir = os.path.join(base_dir, 'checkpoints')
        self._tensorboard_dir = os.path.join(base_dir, 'tensorboard')
        self._log_dir = os.path.join(base_dir, 'logs')

        log_strings = []
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)
            log_strings.append('created dir {}'.format(base_dir))
        if not os.path.exists(self._checkpoint_dir):
            os.makedirs(self._checkpoint_dir)
            log_strings.append('created dir {}'.format(self._checkpoint_dir))
        if not os.path.exists(self._log_dir):
            os.makedirs(self._log_dir)
            log_strings.append('created dir {}'.format(self._log_dir))
        if not os.path.exists(self._tensorboard_dir):
            os.makedirs(self._tensorboard_dir)
            log_strings.append('created dir {}'.format(self._tensorboard_dir))

        # copy config file to base dir
        config_dest_file = os.path.join(base_dir, 'config.json')
        copyfile(self._config_path, config_dest_file)

        # logging
        logfile = os.path.join(
            self._log_dir,
            'train_{}_id{}.log'.format(self._model_name, self._job_id))
        self._logger = get_logger(logfile)
        for s in log_strings:
            self._logger.info(s)
        self._logger.info('logfile={}'.format(logfile))
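# A hypothetical driver for `TrainModel`. The JSON layout and all paths below
# are assumptions; the keys simply mirror the `config_*` dicts read in
# __init__:
if __name__ == '__main__':
    import json
    with open('config.json') as f:
        config = json.load(f)
    trainer = TrainModel(model_name='inception_v3',
                         dataset_name=Cub200.NAME,
                         train_records='/path/to/train.tfrecord',
                         val_records='/path/to/val.tfrecord',
                         config_path='config.json',
                         config_optimizer=config['optimizer'],
                         config_learning_rate=config['learning_rate'],
                         config_data=config['data'],
                         config_transfer_learning=config['transfer_learning'],
                         init_checkpoint_path='/path/to/pretrained/model.ckpt',
                         job_id=0,
                         eval_epochs=1,
                         checkpoint_epochs=5)
    trainer.train()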
def run(self):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start =====')

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        input_pipeline = InputPipeline(
            records=self._records_file,
            records_type=RecordsParser.RECORDS_UNLABELLED,
            shuffle_buffer_size=0,
            batch_size=self._batch_size,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=self._get_resize_function(self._image_height,
                                                       self._image_width),
            preprocessing_kwargs={},
            drop_remainder=True,
            compute_bpp=False,
            shuffle=False)

        images = input_pipeline.next_batch()[0]
        image_shape = images.get_shape().as_list()
        self.logger.info('image_shape: {}'.format(image_shape))

        # compression op
        self.logger.info('* compression')
        images_per_compression = []
        bpp_op_per_compression = []
        for compression_level in self.COMPRESSION_LEVELS:
            # compress batch
            with tf.name_scope('compression_webp_{}'.format(compression_level)):
                with tf.device(CPU_DEVICE):  # -> webp compression on cpu
                    img_batch_compressed, _bpp = TFWebp.tf_encode_decode_image_batch(
                        image_batch=tf.cast(images, tf.uint8),
                        quality=compression_level)
                    img_batch_compressed.set_shape(images.get_shape().as_list())
                images_per_compression.append(
                    tf.cast(img_batch_compressed, tf.float32))
                bpp_op_per_compression.append(_bpp)

        # compute distortions
        self.logger.info('* distortions')
        distortions_obj_per_compression = [
            Distortions(reconstructed_images=c_img_batch,
                        original_images=tf.cast(images, tf.float32),
                        lambda_ms_ssim=1.0,
                        lambda_psnr=1.0,
                        lambda_feature_loss=1.0,
                        data_format=self.DATA_FORMAT,
                        loss_net_kwargs=None)
            for c_img_batch in images_per_compression
        ]
        distortions_ops_per_compression = [{
            'ms_ssim': d.compute_ms_ssim()
        } for d in distortions_obj_per_compression]

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=True), graph=graph) as sess:
        distortions_values_per_compression = [{
            key: list() for key in self.DISTORTION_KEYS
        } for _ in self.COMPRESSION_LEVELS]
        bpp_values_per_compression = [list() for _ in self.COMPRESSION_LEVELS]

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # compute distortions and bpp
                batch_bpp_values_per_compression, batch_distortions_values_per_compression = sess.run(
                    [bpp_op_per_compression, distortions_ops_per_compression])

                # collect values
                for comp_level, (dist_comp, bpp_comp) in enumerate(
                        zip(batch_distortions_values_per_compression,
                            batch_bpp_values_per_compression)):
                    bpp_values_per_compression[comp_level].extend(bpp_comp)
                    for key in self.DISTORTION_KEYS:
                        distortions_values_per_compression[comp_level][key].append(
                            dist_comp[key])

                n_images_processed += len(batch_bpp_values_per_compression[0])
                n_images_processed_per_second.append(
                    len(batch_bpp_values_per_compression[0]) /
                    (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))
            return [(np.nan, np.nan) for _ in self.COMPRESSION_LEVELS]

        mean_bpp_values_per_compression = [
            np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
        ]
        mean_dist_values_per_compression = [{
            key: np.mean(arr) for key, arr in dist_dict.items()
        } for dist_dict in distortions_values_per_compression]

        self._save_results(mean_bpp_values_per_compression,
                           mean_dist_values_per_compression, 'webp',
                           self.COMPRESSION_LEVELS)
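# `TFWebp.tf_encode_decode_image_batch` is assumed to return the decoded
# batch plus per-image bpp of the WebP bitstream; for a single image the
# conversion from encoded size to bpp reduces to:
def bytes_to_bpp_sketch(num_encoded_bytes, height, width):
    """Bits per pixel: 8 bits per encoded byte over all pixels."""
    return 8.0 * num_encoded_bytes / (height * width)

# e.g. a 10 kB WebP encoding of a 224x224 image:
# bytes_to_bpp_sketch(10000, 224, 224) -> ~1.59 bpp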
def run(self):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start =====')

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        ip0 = InputPipeline(
            records=self._records_file,
            records_type=RecordsParser.RECORDS_UNLABELLED,
            shuffle_buffer_size=0,
            batch_size=self._batch_size,
            num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
            num_repeat=1,
            preprocessing_fn=CompressionPreprocessing.preprocess_image_for_eval,
            preprocessing_kwargs={
                'height': self._image_height,
                'width': self._image_width,
                'resize_side_min': min(self._image_height, self._image_width)
            },
            drop_remainder=True,
            compute_bpp=False,
            shuffle=False)

        original_images = ip0.next_batch()[0]

        image_batches, bpp_op_per_compression = [], []
        for records in self._bpg_records_files:
            ip = InputPipeline(
                records=records,
                records_type=RecordsParser.RECORDS_BPP,
                shuffle_buffer_size=0,
                batch_size=self._batch_size,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=CompressionPreprocessing.preprocess_image_with_identity,
                preprocessing_kwargs={
                    'height': self._image_height,
                    'width': self._image_width,
                    'dtype_out': tf.uint8
                },
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False)

            images, bpp = ip.next_batch()
            image_batches.append(images)
            bpp_op_per_compression.append(bpp)

        # compute distortions
        self.logger.info('* distortions')
        distortions_obj_per_compression = [
            Distortions(reconstructed_images=c_img_batch,
                        original_images=original_images,
                        lambda_ms_ssim=1.0,
                        lambda_psnr=1.0,
                        lambda_feature_loss=1.0,
                        data_format=self.DATA_FORMAT,
                        loss_net_kwargs=None) for c_img_batch in image_batches
        ]
        distortions_ops_per_compression = [{
            'ms_ssim': d.compute_ms_ssim()
        } for d in distortions_obj_per_compression]

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=True), graph=graph) as sess:
        distortions_values_per_compression = [{
            key: list() for key in self.DISTORTION_KEYS
        } for _ in range(self._num_compression_levels)]
        bpp_values_per_compression = [
            list() for _ in range(self._num_compression_levels)
        ]

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # compute distortions and bpp
                batch_bpp_values_per_compression, batch_distortions_values_per_compression = sess.run(
                    [bpp_op_per_compression, distortions_ops_per_compression])

                # collect values
                for comp_level, (dist_comp, bpp_comp) in enumerate(
                        zip(batch_distortions_values_per_compression,
                            batch_bpp_values_per_compression)):
                    bpp_values_per_compression[comp_level].extend(bpp_comp)
                    for key in self.DISTORTION_KEYS:
                        distortions_values_per_compression[comp_level][key].append(
                            dist_comp[key])

                n_images_processed += len(batch_bpp_values_per_compression[0])
                n_images_processed_per_second.append(
                    len(batch_bpp_values_per_compression[0]) /
                    (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))
            return [(np.nan, np.nan) for _ in range(self._num_compression_levels)]

        mean_bpp_values_per_compression = [
            np.mean(bpp_vals) for bpp_vals in bpp_values_per_compression
        ]
        mean_dist_values_per_compression = [{
            key: np.mean(arr) for key, arr in dist_dict.items()
        } for dist_dict in distortions_values_per_compression]

        self._save_results(mean_bpp_values_per_compression,
                           mean_dist_values_per_compression, 'bpg', None)
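# Note on the averaging above: distortion values are appended once per batch,
# so the final np.mean is a mean of batch means. Because the pipelines use
# drop_remainder=True, all batches are equally sized and this equals the
# per-image mean, as the identity below checks:
_values = np.random.RandomState(0).rand(96)   # 96 per-image values
_batches = _values.reshape(12, 8)             # 12 equally sized batches of 8
assert np.isclose(np.mean(_batches.mean(axis=1)), _values.mean())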
def eval_classifier_model(self,
                          cnn_model: 'Union[ImagenetClassifier, FGVCClassifier]',
                          ckpt_path=None):
    tf.reset_default_graph()
    print('')
    self.logger.info('===== building graph start: {} ====='.format(cnn_model.NAME))

    # assertions
    assert self._dataset.NUM_CLASSES == cnn_model.num_classes, \
        'inconsistent number of classes ({} != {})'.format(
            self._dataset.NUM_CLASSES, cnn_model.num_classes)

    # image shapes
    image_shape_classification = cnn_model.INPUT_SHAPE
    image_shape_compression = CompressionPreprocessing.pad_image_shape(
        image_shape=image_shape_classification,
        size_multiple_of=self.SIZE_MULTIPLE_OF,
        extra_padding_multiples=2)

    # log image sizes
    self.logger.info('image_shape_classification={}'.format(image_shape_classification))
    self.logger.info('image_shape_compression={}'.format(image_shape_compression))

    # records files depending on inference resolution
    if image_shape_classification[0] < 256:
        bpg_records_files = list(self._bpg_records_files256)
    else:
        bpg_records_files = list(self._bpg_records_files336)
    self.logger.info('bpg_records_files: {}'.format(bpg_records_files))

    with tf.Graph().as_default() as graph:
        # datafeed
        self.logger.info('* datafeed')
        image_batches, labels_batches, bpp_batches = [], [], []
        for records in bpg_records_files:
            ip = InputPipeline(
                records=records,
                records_type=RecordsParser.RECORDS_LABELLED_BPP,
                shuffle_buffer_size=1,
                batch_size=self.BATCH_SIZE,
                num_preprocessing_threads=self.NUM_PREPROCESSING_THREADS,
                num_repeat=1,
                preprocessing_fn=CompressionPreprocessing.preprocess_image_with_identity,
                preprocessing_kwargs={
                    'height': image_shape_compression[0],
                    'width': image_shape_compression[1],
                    'dtype_out': tf.uint8
                },
                drop_remainder=True,
                compute_bpp=False,
                shuffle=False,
                dtype_out=tf.uint8)

            images, labels, bpp = ip.next_batch()
            image_batches.append(images)
            labels_batches.append(labels)
            bpp_batches.append(bpp)

        # compression + inference op
        self.logger.info('* inference')
        predictions_per_compression = []

        # inference kwargs
        if self._dataset_name == Imagenet.NAME:
            def inference_kwargs(**kwargs):
                return dict(graph=kwargs['graph'])
        else:
            def inference_kwargs(**kwargs):
                return dict(arg_scope=cnn_model.arg_scope(weight_decay=float(0)),
                            is_training=False,
                            return_predictions=True,
                            reuse=True if kwargs['j'] > 0 else False)

        for j, image_batch_compressed in enumerate(image_batches):
            with tf.name_scope('inference_bpg{}'.format(j)):
                # crop center
                image_batch_compressed = tf.image.resize_image_with_crop_or_pad(
                    image_batch_compressed, image_shape_classification[0],
                    image_shape_classification[1])

                # standardize appropriately
                image_batch_for_inference = cnn_model.standardize_tensor(
                    image_batch_compressed)

                # predict
                preds = cnn_model.inference(image_batch_for_inference,
                                            **inference_kwargs(graph=graph, j=j))
                predictions_per_compression.append(preds)

        # aggregate
        predictions_per_compression_op = tf.stack(predictions_per_compression, axis=0)
        self.logger.info('predictions_shape: {}'.format(
            predictions_per_compression_op.get_shape().as_list()))

        # restore
        if self._dataset_name == Imagenet.NAME:
            classifier_saver = None
        else:
            classifier_saver = tf.train.Saver(var_list=cnn_model.model_variables())

        graph.finalize()

    with tf.Session(config=get_sess_config(allow_growth=False), graph=graph) as sess:
        if classifier_saver is not None:
            classifier_saver.restore(sess, ckpt_path)

        labels_all_comp_values = [list() for _ in range(self._num_records)]
        predictions_all_comp_values = [list() for _ in range(self._num_records)]
        bpp_all_comp_values = [list() for _ in range(self._num_records)]

        n_images_processed = 0
        n_images_processed_per_second = deque(10 * [0.0], 10)
        progress(n_images_processed, self._dataset.NUM_VAL,
                 '{}/{} images processed'.format(n_images_processed,
                                                 self._dataset.NUM_VAL))

        try:
            while True:
                batch_start_time = time.time()

                # run inference
                (batch_predictions_all_comp_values, batch_label_all_comp_values,
                 batch_bpp_all_comp_values) = sess.run([
                     predictions_per_compression_op, labels_batches, bpp_batches
                 ])

                # collect predictions
                for comp_level, (preds_comp, bpp_comp, labels_comp) in enumerate(
                        zip(batch_predictions_all_comp_values,
                            batch_bpp_all_comp_values,
                            batch_label_all_comp_values)):
                    predictions_all_comp_values[comp_level].append(preds_comp)
                    bpp_all_comp_values[comp_level].append(bpp_comp)
                    labels_all_comp_values[comp_level].append(
                        self.to_categorical(labels_comp, cnn_model.num_classes))

                n_images_processed += len(batch_label_all_comp_values[0])
                n_images_processed_per_second.append(
                    len(batch_label_all_comp_values[0]) /
                    (time.time() - batch_start_time))

                progress(n_images_processed, self._dataset.NUM_VAL,
                         status='{}/{} images processed ({} img/s)'.format(
                             n_images_processed, self._dataset.NUM_VAL,
                             np.mean(n_images_processed_per_second)))

        except tf.errors.OutOfRangeError:
            self.logger.info('reached end of dataset; processed {} images'.format(
                n_images_processed))

        except KeyboardInterrupt:
            self.logger.info('manual interrupt; processed {}/{} images'.format(
                n_images_processed, self._dataset.NUM_VAL))
            return ([(np.nan, np.nan) for _ in range(self._num_records)],
                    [np.nan for _ in range(self._num_records)])

        labels_all_comp_values = [
            np.concatenate(labels_comp_values, axis=0)
            for labels_comp_values in labels_all_comp_values
        ]
        bpp_all_comp_values = [
            np.mean(np.concatenate(bpp_values, 0))
            for bpp_values in bpp_all_comp_values
        ]
        predictions_all_comp_values = [
            np.concatenate(preds_comp_values, axis=0)
            for preds_comp_values in predictions_all_comp_values
        ]

        accuracies = [(self.top_k_accuracy(labels_values, preds_comp_values, 1),
                       self.top_k_accuracy(labels_values, preds_comp_values, 5))
                      for preds_comp_values, labels_values in zip(
                          predictions_all_comp_values, labels_all_comp_values)]

        return accuracies, bpp_all_comp_values
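# `to_categorical` used above is assumed to behave like the Keras utility of
# the same name; a minimal NumPy sketch:
def to_categorical_sketch(labels, num_classes):
    """One-hot encode integer labels: [N] -> [N, num_classes]."""
    labels = np.asarray(labels, dtype=np.int64).ravel()
    one_hot = np.zeros((labels.size, num_classes), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot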