def inputs(self, eval_data):
    """Construct evaluation input for the dataset named by ``FLAGS.dataset``.

    Args:
      eval_data: bool, indicating if one should use the train or eval data
        set. Only forwarded to the CIFAR-10 reader; the ImageNet branch calls
        ``imagenet_input.inputs()`` without arguments.

    Returns:
      images: Images tensor produced by the selected reader.
      labels: Labels tensor produced by the selected reader.
      Both are cast to ``tf.float16`` when ``FLAGS.use_fp16`` is set.

    Raises:
      ValueError: If no data_dir is given, or if ``FLAGS.dataset`` is neither
        'cifar10' nor 'imagenet'.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    if FLAGS.dataset == 'cifar10':
        data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
        images, labels = cifar10_input.inputs(eval_data=eval_data,
                                              data_dir=data_dir,
                                              batch_size=FLAGS.batch_size)
    elif FLAGS.dataset == 'imagenet':
        # The "test" split is remapped to "validation" on the global flag
        # before the reader is invoked.
        if FLAGS.dataset_split_name == "test":
            FLAGS.dataset_split_name = "validation"
        images, labels = imagenet_input.inputs()
    else:
        # Previously an unknown dataset fell through and crashed later with an
        # UnboundLocalError on `images`; fail early with a clear message.
        raise ValueError('Unsupported dataset: %r' % (FLAGS.dataset,))

    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
def inputs(data_class, shuffle=True):
    """Build evaluation input by delegating to ``data_input.inputs``.

    Args:
      data_class: string, indicating if one should use the 'train' or 'eval'
        or 'test' data set.
      shuffle: bool, to shuffle dataset list to read.

    Returns:
      Whatever ``data_input.inputs`` returns for a batch of ``BATCH_SIZE``
      (documented upstream as an images tensor and a labels tensor).
    """
    reader_kwargs = {
        'data_class': data_class,
        'batch_size': BATCH_SIZE,
        'shuffle': shuffle,
    }
    return data_input.inputs(**reader_kwargs)
def inputs(eval_data):
    """Construct evaluation input through ``imagenet_input.inputs``.

    Args:
      eval_data: bool, indicating if one should use the train or eval data
        set; forwarded unchanged to the reader.

    Returns:
      images: Images tensor from the reader.
      labels: Labels tensor from the reader.
      Both are cast to ``tf.float16`` when ``FLAGS.use_fp16`` is set.

    Raises:
      ValueError: If no data_dir
    """
    data_dir = FLAGS.data_dir
    if not data_dir:
        raise ValueError('Please supply a data_dir')

    batch = imagenet_input.inputs(eval_data=eval_data,
                                  data_dir=data_dir,
                                  batch_size=FLAGS.batch_size)
    images, labels = batch

    if not FLAGS.use_fp16:
        return images, labels

    # Half-precision mode: cast images first, then labels.
    images, labels = (tf.cast(tensor, tf.float16) for tensor in (images, labels))
    return images, labels
def distorted_inputs():
    """Construct training input through ``imagenet_input.inputs``.

    The reader is called with ``eval_data=False`` on the
    'cifar-10-batches-bin' sub-directory of ``FLAGS.data_dir``.

    Returns:
      images: Images tensor from the reader, cast to ``tf.float16`` when
        ``FLAGS.use_fp16`` is set.
      labels: Labels tensor cast to ``tf.int32`` (cast op named "xxsad").

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = imagenet_input.inputs(eval_data=False,
                                           data_dir=data_dir,
                                           batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)

    # Labels always end up as int32. The former float16 round trip in fp16
    # mode added an op and cannot represent every integer above 2048, so the
    # labels are now cast exactly once.
    # NOTE(review): assumes the int32 cast was unconditional in the original
    # layout - confirm against the un-collapsed source.
    labels = tf.cast(labels, tf.int32, name="xxsad")
    return images, labels
def train():
    """Train a ResNet on ImageNet with periodic validation and checkpoints.

    Prints the run configuration from ``FLAGS``, builds the training and
    validation input pipelines and networks in a fresh graph, optionally
    restores ``FLAGS.checkpoint`` (resuming at its global step) or
    ``FLAGS.basemodel`` (weights only), then runs steps up to
    ``FLAGS.max_steps``. Validation metrics and training summaries are written
    under ``FLAGS.train_dir``; checkpoints are saved every
    ``FLAGS.checkpoint_interval`` steps and on the last step.

    Typing 'b' on stdin while training drops into ``embed()``.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Floor division keeps the thread count an int on Python 3 as well;
        # never ask the reader for fewer than one thread.
        num_threads = max(1, multiprocessing.cpu_count() // FLAGS.num_gpus)
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            tf.summary.image('images', train_images[0][:2])

        # Build model
        # Convert the comma-separated epoch boundaries into step boundaries.
        lr_decay_steps = [
            int(float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus)
            for epoch in FLAGS.lr_step_epoch.split(',')
        ]
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp, train_images, train_labels,
                                      global_step, name="train")
        network_train.build_model()
        network_train.build_train_op()
        # Merge summaries before the validation network is built, so only the
        # summaries registered up to this point are included.
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp, val_images, val_labels, global_step,
                                    name="val", reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Restore weights only: optimizer slots and the step counter are
            # excluded so training starts from step 0.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            vars_restore = [
                var for var in tf.global_variables()
                if "Momentum" not in var.name
                and "global_step" not in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir,
                         str(global_step.eval(session=sess))), sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for _ in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val) step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay,
                              lr_decay_steps, step)
            start_time = time.time()
            _, loss_value, acc_value, train_summary_str = sess.run(
                [network_train.train_op, network_train.loss,
                 network_train.acc, train_summary_op],
                feed_dict={network_train.is_train: True,
                           network_train.lr: lr_value})
            duration = time.time() - start_time

            # Abort as soon as the loss diverges.
            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0
                ) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Non-blocking poll of stdin: typing 'b' opens an interactive
            # shell via embed().
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
def train():
    """Evaluate a ResNet checkpoint on the ImageNet test list.

    Despite its name, this function performs no training. It prints the run
    configuration from ``FLAGS``, builds a center-cropped test input pipeline
    on the CPU and the network on '/GPU:0', restores ``FLAGS.checkpoint`` when
    given, and runs ``FLAGS.test_iter`` batches. Accuracy, the confusion
    matrix and timing are pickled to ``FLAGS.output_file``.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root,
                    FLAGS.test_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=1,
                    center_crop=True)

        # Build a Graph that computes the predictions from the inference
        # model. The network reads from placeholders; each batch is pulled
        # from the input queue first and then fed, so only the forward pass
        # is timed.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
            network = resnet.ResNet(hp, images, labels, global_step)
            network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        # (tf.initialize_all_variables / tf.all_variables are deprecated
        # aliases of the global_variables API used here.)
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        # Test!
        test_loss = 0.0
        test_acc = 0.0
        test_time = 0.0
        # confusion_matrix[label, prediction] counts examples.
        confusion_matrix = np.zeros((FLAGS.num_classes, FLAGS.num_classes),
                                    dtype=np.int32)
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])

            start_time = time.time()
            loss_value, acc_value, pred_value = sess.run(
                [network.loss, network.acc, network.preds],
                feed_dict={
                    network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            duration = time.time() - start_time

            test_loss += loss_value
            test_acc += acc_value
            test_time += duration
            for label, pred in zip(test_labels_val, pred_value):
                confusion_matrix[label, pred] += 1

            if i % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: iter %d, loss=%.4f, acc=%.4f (%.1f examples/sec; %.3f sec/batch)'
                )
                print(format_str % (datetime.now(), i, loss_value, acc_value,
                                    examples_per_sec, sec_per_batch))
        test_loss /= FLAGS.test_iter
        test_acc /= FLAGS.test_iter

        # Print and save results
        sec_per_image = test_time / FLAGS.test_iter / FLAGS.batch_size
        print('Done! Acc: %.6f, Test time: %.3f sec, %.7f sec/example' %
              (test_acc, test_time, sec_per_image))
        print('Saving result... ')
        result = {
            'accuracy': test_acc,
            'confusion_matrix': confusion_matrix,
            'test_time': test_time,
            'sec_per_image': sec_per_image
        }
        with open(FLAGS.output_file, 'wb') as fd:
            pickle.dump(result, fd)
        print('done!')
def _checkpoint_step(checkpoint):
    """Return the integer after the last '-' in *checkpoint* (0 if None)."""
    if checkpoint is None:
        return 0
    return int(checkpoint.rsplit('-', 1)[-1])


def train():
    """Dump per-image gradient saliency maps for a ResNet-50 checkpoint.

    Despite its name, this function performs no training. For each of
    ``FLAGS.test_iter`` batches it classifies the batch, records the batch
    accuracy, and, for the first image of the batch only, saves the
    de-normalised image and the channel-averaged gradient of the ground-truth
    logit with respect to the input.

    Outputs go to ``./<FLAGS.save_path>/weight_<step>/``, where ``<step>`` is
    the number after the last '-' of ``FLAGS.checkpoint`` (0 when no
    checkpoint is given):
      img_XXXXX.jpg, grad_XXXXX.npy, classified_flag.npy, ground_truth.npy,
      pred_label.npy.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    graph = tf.Graph()
    with graph.as_default() as g:
        # Not used directly here, but created so it is part of the variable
        # set handled by the saver below.
        global_step = tf.Variable(0,
                                  trainable=False,
                                  name='global_step',
                                  dtype=tf.int64)

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root,
                    FLAGS.test_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=1,
                    center_crop=True)

        # Build a Graph that computes the predictions from the inference model.
        # Used below to undo the input normalisation before saving images.
        imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        # Never fed in this function; kept so the graph is built exactly as
        # before.
        labels = tf.placeholder(tf.int64, [FLAGS.batch_size])

        def build_network():
            network = resnet_model.resnet_v1(resnet_depth=50,
                                             num_classes=1000,
                                             dropblock_size=None,
                                             dropblock_keep_probs=[None] * 4,
                                             data_format='channels_last')
            return network(inputs=images, is_training=False)

        logits = build_network()
        sess = tf.Session(graph=g,
                          config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=True))

        # Build an initialization operation to run below.
        # (tf.initialize_all_variables / tf.all_variables are deprecated
        # aliases of the global_variables API used here.)
        init = tf.global_variables_initializer()
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        # Test!
        test_acc = 0.0
        test_time = 0.0

        # Parsing the step from the last '-' tolerates dashes in directory
        # names and no longer crashes when no checkpoint was supplied.
        path = _checkpoint_step(FLAGS.checkpoint)
        neuron_selector = tf.placeholder(tf.int32)
        # Logit of the selected class for the first image of the batch.
        y = logits[0][neuron_selector]
        gradient_saliency = saliency.GradientSaliency(g, sess, y, images)

        # Create the output directory without going through a shell.
        out_dir = './{}/weight_{:05d}'.format(FLAGS.save_path, path)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        classified_flag = []
        ground_truth = []
        pred_label = []
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images[0], test_labels[0]])
            start_time = time.time()

            # Evaluate metrics
            # To analyse a single class only, guard the rest of this loop
            # body with "test_labels_val[0] == FLAGS.class_ind".
            predictions = np.argmax(logits.eval(
                session=sess, feed_dict={images: test_images_val}),
                                    axis=1)
            ones = np.ones([FLAGS.batch_size])
            zeros = np.zeros([FLAGS.batch_size])
            correct = np.where(np.equal(predictions, test_labels_val), ones,
                               zeros)
            acc = np.mean(correct)
            duration = time.time() - start_time
            test_acc += acc
            test_time += duration

            classified_flag.append([i, acc])
            ground_truth.append(test_labels_val[0])
            pred_label.append(predictions[0])

            # Get gradients
            grad = gradient_saliency.GetMask(
                test_images_val[0, :],
                feed_dict={neuron_selector: test_labels_val[0]})
            imsave(
                './{}/weight_{:05d}/img_{:05d}.jpg'.format(
                    FLAGS.save_path, path, i),
                (test_images_val[0, :] * imagenet_std + imagenet_mean))
            np.save(
                './{}/weight_{:05d}/grad_{:05d}.npy'.format(
                    FLAGS.save_path, path, i), np.mean(grad, axis=-1))

        test_acc /= FLAGS.test_iter

        np.save(
            './{}/weight_{:05d}/classified_flag.npy'.format(
                FLAGS.save_path, path), np.array(classified_flag))
        np.save(
            './{}/weight_{:05d}/ground_truth.npy'.format(
                FLAGS.save_path, path), np.array(ground_truth))
        np.save(
            './{}/weight_{:05d}/pred_label.npy'.format(FLAGS.save_path, path),
            np.array(pred_label))

        # Print and save results
        sec_per_image = test_time / FLAGS.test_iter / FLAGS.batch_size
        print('Done! Acc: %.6f, Test time: %.3f sec, %.7f sec/example' %
              (test_acc, test_time, sec_per_image))
    print('done!')
def _kernel_to_image(kernel, in_order, out_order, rows, cols):
    """Return |kernel| with its axis-2/axis-3 entries reordered, as a 2-D image.

    Axis 2 is permuted by ``in_order`` and axis 3 by ``out_order``; the result
    is transposed to axis order (2, 0, 3, 1) and reshaped to (rows, cols).
    """
    permuted = kernel[:, :, in_order, :][:, :, :, out_order]
    return np.abs(permuted.transpose([2, 0, 3, 1]).reshape(rows, cols))


def train():
    """Train a group-split ResNet on ImageNet with validation and checkpoints.

    Prints the run configuration from ``FLAGS``, builds the training and
    validation input pipelines and networks in a fresh graph, optionally
    restores ``FLAGS.checkpoint`` (resuming at its global step) or
    ``FLAGS.basemodel`` (weights only, without optimizer slots, 'group'
    variables and the step counter), then runs steps up to
    ``FLAGS.max_steps``. After every training step
    ``network_train.validity_op`` is run. Step 153 is executed with full
    tracing and its timeline is written to 'timeline.json'.

    When ``FLAGS.group_summary_interval`` is set, image summaries of the split
    matrices and of the reordered logits / conv5 kernels are written every
    that many steps.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tNumber of Groups: %d-%d-%d' %
          (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tOverlap loss weight: %f' % FLAGS.gamma1)
    print('\tWeight split loss weight: %f' % FLAGS.gamma2)
    print('\tUniform loss weight: %f' % FLAGS.gamma3)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale)
    print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # The single training step that is run with full tracing.
    profile_step = 153
    # Channel widths indexed as filters[stage][0 = inner, 1 = output].
    filters = [64, [64, 256], [128, 512], [256, 1024], [512, 2048]]
    c4_out = filters[3][1]
    c5_mid, c5_out = filters[4]

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Floor division keeps the thread count an int on Python 3 as well;
        # never ask the reader for fewer than one thread.
        num_threads = max(1, multiprocessing.cpu_count() // FLAGS.num_gpus)
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)

        # Build model
        # A real list (not a lazy `map`), because the boundaries are handed to
        # get_lr on every training step.
        lr_decay_steps = [
            int(float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus)
            for epoch in FLAGS.lr_step_epoch.split(',')
        ]
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            gamma1=FLAGS.gamma1,
                            gamma2=FLAGS.gamma2,
                            gamma3=FLAGS.gamma3,
                            momentum=FLAGS.momentum,
                            bn_no_scale=FLAGS.bn_no_scale,
                            weighted_group_loss=FLAGS.weighted_group_loss,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp, train_images, train_labels,
                                      global_step, name="train")
        network_train.build_model()
        network_train.build_train_op()
        # Merge summaries before the validation network is built, so only the
        # summaries registered up to this point are included.
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp, val_images, val_labels, global_step,
                                    name="val", reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Restore base weights only: optimizer slots, split ('group')
            # variables and the step counter keep their fresh initial values.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            vars_restore = [
                var for var in tf.global_variables()
                if "Momentum" not in var.name
                and "group" not in var.name
                and "global_step" not in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir,
                         str(global_step.eval(session=sess))), sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for _ in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val) step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay,
                              lr_decay_steps, step)
            start_time = time.time()

            # Only the profiled step passes tracing options to sess.run.
            profiling = (step == profile_step)
            run_kwargs = {}
            if profiling:
                run_metadata = tf.RunMetadata()
                run_kwargs = {
                    'options': tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    'run_metadata': run_metadata,
                }
            _, loss_value, acc_value, train_summary_str = sess.run(
                [network_train.train_op, network_train.loss,
                 network_train.acc, train_summary_op],
                feed_dict={network_train.is_train: True,
                           network_train.lr: lr_value},
                **run_kwargs)
            _ = sess.run(network_train.validity_op)

            if profiling:
                # Create the Timeline object, and write it to a json
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)
                print('Wrote the timeline profile of %d iter training on %s' %
                      (step, 'timeline.json'))
            duration = time.time() - start_time

            # Abort as soon as the loss diverges.
            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0
                ) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Add weights and groupings visualization
            if FLAGS.group_summary_interval is None:
                continue
            if step % FLAGS.group_summary_interval != 0:
                continue

            img_summaries = []

            if FLAGS.ngroups1 > 1:
                logits_weights = get_var_value('logits/fc/weights', sess)
                split_p1 = get_var_value('group/split_p1/q', sess)
                split_q1 = get_var_value('group/split_q1/q', sess)
                # Sort columns by the row holding their maximum so that
                # columns with the same argmax become contiguous.
                feature_indices = np.argsort(np.argmax(split_p1, axis=0))
                class_indices = np.argsort(np.argmax(split_q1, axis=0))

                img_summaries.append(
                    img_to_summary(
                        np.repeat(split_p1[:, feature_indices], 20, axis=0),
                        'split_p1'))
                img_summaries.append(
                    img_to_summary(
                        np.repeat(split_q1[:, class_indices], 200, axis=0),
                        'split_q1'))
                img_summaries.append(
                    img_to_summary(
                        np.abs(logits_weights[feature_indices, :]
                               [:, class_indices]), 'logits'))

            if FLAGS.ngroups2 > 1:
                # NOTE(review): split_p1 is only read when ngroups1 > 1; this
                # branch presumably relies on ngroups1 >= ngroups2 - confirm.
                split_p2 = get_var_value('group/split_p2/q', sess)
                split_q2 = _merge_split_q(
                    split_p1,
                    _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
                order_p2 = np.argsort(np.argmax(split_p2, axis=0))
                order_q2 = np.argsort(np.argmax(split_q2, axis=0))

                # Inner split matrices of the three conv5 units and the
                # column order derived from each.
                inner_names = ('r211', 'r212', 'r221', 'r222', 'r231', 'r232')
                inner_split = {}
                inner_order = {}
                for name in inner_names:
                    matrix = get_var_value('group/split_%s/q' % name, sess)
                    inner_split[name] = matrix
                    inner_order[name] = np.argsort(np.argmax(matrix, axis=0))

                # (variable, summary tag, order applied to axis 2,
                #  order applied to axis 3, image rows, image cols)
                kernel_specs = [
                    ('conv5_1/conv_shortcut/kernel', 'conv5_1/shortcut',
                     order_p2, order_q2, c4_out, c5_out),
                    ('conv5_1/conv_1/kernel', 'conv5_1/conv_1',
                     order_p2, inner_order['r211'], c4_out, c5_mid),
                    ('conv5_1/conv_2/kernel', 'conv5_1/conv_2',
                     inner_order['r211'], inner_order['r212'],
                     c5_mid * 3, c5_mid * 3),
                    ('conv5_1/conv_3/kernel', 'conv5_1/conv_3',
                     inner_order['r212'], order_q2, c5_mid, c5_out),
                    ('conv5_2/conv_1/kernel', 'conv5_2/conv_1',
                     order_q2, inner_order['r221'], c5_out, c5_mid),
                    ('conv5_2/conv_2/kernel', 'conv5_2/conv_2',
                     inner_order['r221'], inner_order['r222'],
                     c5_mid * 3, c5_mid * 3),
                    ('conv5_2/conv_3/kernel', 'conv5_2/conv_3',
                     inner_order['r222'], order_q2, c5_mid, c5_out),
                    ('conv5_3/conv_1/kernel', 'conv5_3/conv_1',
                     order_q2, inner_order['r231'], c5_out, c5_mid),
                    ('conv5_3/conv_2/kernel', 'conv5_3/conv_2',
                     inner_order['r231'], inner_order['r232'],
                     c5_mid * 3, c5_mid * 3),
                    ('conv5_3/conv_3/kernel', 'conv5_3/conv_3',
                     inner_order['r232'], order_q2, c5_mid, c5_out),
                ]
                kernel_images = [
                    (_kernel_to_image(get_var_value(var_name, sess),
                                      in_order, out_order, rows, cols), tag)
                    for var_name, tag, in_order, out_order, rows, cols
                    in kernel_specs
                ]

                img_summaries.append(
                    img_to_summary(
                        np.repeat(split_p2[:, order_p2], 20, axis=0),
                        'split_p2'))
                for name in inner_names:
                    img_summaries.append(
                        img_to_summary(
                            np.repeat(inner_split[name][:, inner_order[name]],
                                      20, axis=0),
                            'split_%s' % name))
                for image, tag in kernel_images:
                    img_summaries.append(img_to_summary(image, tag))

            # No visualization is produced for FLAGS.ngroups3 (conv4 stage).

            if img_summaries:
                img_summary = tf.Summary(value=img_summaries)
                summary_writer.add_summary(img_summary, step)
                summary_writer.flush()
def train():
    """Run a TensorFI fault-injection evaluation of a ResNet on ImageNet.

    Despite its name this function performs no training.  It prints the
    configuration, builds the inference graph, restores ``FLAGS.checkpoint``
    when one is given, and then for each of ``FLAGS.test_iter`` test batches:

    * fetches the batch with fault injection turned off, and
    * runs inference 1000 times with fault injection turned on.

    For every prediction a ``1,`` (hit) or ``0,`` (miss) is appended to
    ``fi-org-resnet-top1.csv`` and ``fi-org-resnet-top5.csv``; one CSV row is
    written per test batch.

    NOTE(review): the original text had lost its indentation; loop nesting
    was reconstructed from variable usage -- confirm against the upstream
    script.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root, FLAGS.test_dataset,
                    FLAGS.batch_size, False, num_threads=1, center_crop=True)

        # Placeholders the model reads from; batches are fetched from the
        # input queue first and then fed back in through these.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_gpus=1,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, [images], [labels], global_step)
        network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore the weights when a checkpoint is given.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        fi = ti.TensorFI(sess,
                         logLevel=50,
                         name="convolutional",
                         disableInjections=False)

        # Number of fault-injection runs performed on every test batch.
        fi_runs = 1000

        # Result files are opened in append mode and are now closed (and
        # flushed) deterministically, even if a run raises.
        with open("fi-org-resnet-top1.csv", "a") as t1, \
                open("fi-org-resnet-top5.csv", "a") as t5:
            for i in range(FLAGS.test_iter):
                # Fetch the input batch without injecting faults into the
                # input pipeline.
                fi.turnOffInjections()
                test_images_val, test_labels_val = sess.run(
                    [test_images[0], test_labels[0]])
                fi.turnOnInjections()

                for j in range(fi_runs):
                    probs = sess.run(
                        [network.probs],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    probs = np.asarray(probs)[0]

                    for each_prob, label in zip(probs, test_labels_val):
                        # Indices of the five highest scores, best first.
                        pred = (np.argsort(each_prob)[::-1])[0:5]
                        print(pred, 'label:', label)
                        if label == pred[0]:
                            t1.write("1,")
                            t5.write("1,")
                        elif label in pred[1:]:
                            t1.write("0,")
                            t5.write("1,")
                        else:
                            t1.write("0,")
                            t5.write("0,")

                    print('--------fi on resnet, %d img, %d FI run' %
                          (i + 1, j + 1))

                # One CSV row per test batch.
                t1.write("\n")
                t5.write("\n")
def train():
    """Train the (optionally split) ResNet and periodically evaluate it.

    Prints the configuration, builds the input pipelines and the model, then
    restores weights from the latest checkpoint in ``FLAGS.train_dir`` if one
    exists, otherwise from the loadable variables of the baseline checkpoint
    in ``FLAGS.baseline_dir``, otherwise starts from scratch.  The training
    loop runs from the restored step up to ``FLAGS.max_steps``:

    * every ``FLAGS.test_interval`` steps it averages loss/accuracy over
      ``FLAGS.test_iter`` test batches and writes test summaries;
    * every ``FLAGS.display`` steps it prints training statistics and writes
      the training summary;
    * every ``FLAGS.checkpoint_interval`` steps (and at the last step) it
      saves a checkpoint to ``FLAGS.train_dir``.

    NOTE(review): the original text had lost its indentation; nesting was
    reconstructed from semantics -- confirm against the upstream script.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tNumber of GPUs: %d' % FLAGS.num_gpu)
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tSplitted Network: %s' % FLAGS.split)
    if FLAGS.split:
        print('\tClustering path: %s' % FLAGS.cluster_path)
    print('\tNo logit map: %s' % FLAGS.no_logit_map)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # The clustering is only consumed when the network is split, so only
    # then is the file required to exist.  Binary mode is what pickle expects.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point cluster_path at trusted files.
    clustering = None
    if FLAGS.split:
        with open(FLAGS.cluster_path, 'rb') as fd:
            clustering = pickle.load(fd)

    # Number of examples consumed by one training step (all GPUs together).
    examples_per_step = FLAGS.batch_size * FLAGS.num_gpu

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of the training and test sets
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.distorted_inputs(
                FLAGS.train_image_root, FLAGS.train_dataset,
                examples_per_step, True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(
                FLAGS.test_image_root, FLAGS.test_dataset,
                examples_per_step, False)

        # Placeholders the model reads from; batches are fetched from the
        # input queues first and then fed back in through these.
        images = tf.placeholder(tf.float32, [
            examples_per_step, data_input.IMAGE_HEIGHT,
            data_input.IMAGE_WIDTH, 3
        ])
        labels = tf.placeholder(tf.int32, [examples_per_step])

        # Convert the comma-separated epoch boundaries into step indices.
        lr_step_epochs = [float(e) for e in FLAGS.lr_step_epoch.split(',')]
        lr_decay_steps = [
            int(e * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpu) for e in lr_step_epochs
        ]
        print('Learning rate decays at iter: %s' % str(lr_decay_steps))

        # Build model
        hp = resnet.HParams(num_gpu=FLAGS.num_gpu,
                            batch_size=FLAGS.batch_size,
                            split=FLAGS.split,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map)
        network = resnet.ResNet(hp, images, labels, global_step)
        if FLAGS.split:
            network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d params' % network._weights)
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables(),
                               max_to_keep=10000,
                               write_version=tf.train.SaverDef.V2)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint; the step is the numeric suffix of
            # the checkpoint file name ("model.ckpt-<step>").
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            ckpt_base = tf.train.get_checkpoint_state(FLAGS.baseline_dir)
            if ckpt_base and ckpt_base.model_checkpoint_path:
                # Check loadable variables(variable with same name and same
                # shape) and load them only
                print('No checkpoint file found. Start from the baseline.')
                loadable_vars = utils._get_loadable_vars(
                    ckpt_base.model_checkpoint_path, verbose=True)
                saver_base = tf.train.Saver(
                    loadable_vars, write_version=tf.train.SaverDef.V2)
                saver_base.restore(sess, ckpt_base.model_checkpoint_path)
            else:
                print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str %
                      (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss',
                                       simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay,
                              lr_decay_steps, step)
            # The measured duration covers both fetching the batch and
            # running the training step.
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = sess.run(
                [
                    network.train_op, network.loss, network.acc,
                    train_summary_op
                ],
                feed_dict={
                    network.is_train: True,
                    network.lr: lr_value,
                    images: train_images_val,
                    labels: train_labels_val
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged: loss is NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                # Every step feeds batch_size * num_gpu examples.
                examples_per_sec = examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
def train():
    """Measure top-1/top-5 accuracy of a ResNet on the ImageNet test list.

    Despite its name this function performs no training.  It prints the
    configuration, builds the inference graph, restores ``FLAGS.checkpoint``
    when one is given and evaluates ``FLAGS.test_iter`` batches.  After each
    batch the running hit counts are appended to ``acyOnValSet.csv`` as
    ``top1,top5,numOfImg``; the final top-1/top-5 ratios are printed at the
    end.

    Changes relative to the previous revision:

    * an unconditional ``break`` that followed the batch-shape debug print
      made the whole evaluation unreachable and led to a division by zero in
      the final print; it has been removed and the division is guarded;
    * the CSV file is closed via ``with``;
    * large commented-out (string literal) code blocks and unused locals
      were dropped.

    NOTE(review): the original text had lost its indentation; nesting was
    reconstructed from semantics -- confirm against the upstream script.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root, FLAGS.test_dataset,
                    FLAGS.batch_size, False, num_threads=1, center_crop=True)

        # Placeholders the model reads from; batches are fetched from the
        # input queue first and then fed back in through these.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_gpus=1,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, [images], [labels], global_step)
        network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore the weights when a checkpoint is given.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        # Test!
        numOfImg = 0
        top1 = 0.
        top5 = 0.
        with open("acyOnValSet.csv", "a") as wri:
            wri.write("top1,top5,numOfImg\n")

            for i in range(FLAGS.test_iter):
                test_images_val, test_labels_val = sess.run(
                    [test_images[0], test_labels[0]])
                print(len(test_labels_val), test_images_val.shape,
                      test_labels_val.shape)

                probs = sess.run(
                    [network.probs],
                    feed_dict={
                        network.is_train: False,
                        images: test_images_val,
                        labels: test_labels_val
                    })
                probs = np.asarray(probs)[0]

                for each_prob, label in zip(probs, test_labels_val):
                    # Indices of the five highest scores, best first.
                    pred = (np.argsort(each_prob)[::-1])[0:5]
                    if label == pred[0]:
                        top1 += 1
                        top5 += 1
                    elif label in pred[1:]:
                        top5 += 1
                    numOfImg += 1

                print('------------ evaluating on validation set', i,
                      ' batch')
                print("top1: %f, top5: %f, numImg: %d" %
                      (top1, top5, numOfImg))
                # Running totals after this batch.
                wri.write(
                    repr(top1) + "," + repr(top5) + "," + repr(numOfImg) +
                    "\n")

        if numOfImg:
            print('%s %s' % (top1 / numOfImg, top5 / numOfImg))
        else:
            # Nothing was evaluated (e.g. FLAGS.test_iter == 0); avoid a
            # ZeroDivisionError.
            print('No images were evaluated.')
def train():
    """Evaluate a (optionally split) ResNet and report per-class accuracy.

    Despite its name this function performs no training.  It prints the
    configuration, builds the model, restores weights from
    ``FLAGS.ckpt_path`` (a checkpoint directory or a checkpoint file; the
    process exits with status 1 if nothing can be restored) and evaluates
    ``FLAGS.test_iter`` batches of either the training list
    (``FLAGS.train_data``) or the test list.  A table with the number of
    correct / wrong predictions and the accuracy of every class is printed
    and, if ``FLAGS.output`` is non-blank, also written to that file.

    Changes relative to the previous revision:

    * ``test_loss`` was initialised to the tuple ``(0.0, 0.0)``; it is now a
      scalar and its average is reported;
    * ``clustering`` was used without being defined in this function; it is
      now loaded from ``FLAGS.cluster_path`` when ``FLAGS.split`` is set;
    * classes without any evaluated sample no longer raise
      ``ZeroDivisionError`` (their accuracy is reported as 0).

    NOTE(review): the original text had lost its indentation; nesting was
    reconstructed from semantics -- confirm against the upstream script.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tImageNet class name list: %s' % FLAGS.class_list)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tNumber of GPUs: %d' % FLAGS.num_gpu)
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tSplitted Network: %s' % FLAGS.split)
    if FLAGS.split:
        print('\tClustering path: %s' % FLAGS.cluster_path)
    print('\tNo logit map: %s' % FLAGS.no_logit_map)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # Number of examples in one evaluated batch (all GPUs together).
    examples_per_step = FLAGS.batch_size * FLAGS.num_gpu

    with tf.Graph().as_default():
        # The evaluation dataset
        with tf.variable_scope('test_image'):
            if FLAGS.train_data:
                test_images, test_labels = data_input.inputs(
                    FLAGS.train_image_root, FLAGS.train_dataset,
                    examples_per_step, False)
            else:
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root, FLAGS.test_dataset,
                    examples_per_step, False)

        # The class labels (one per line, truncated to 30 characters so they
        # fit the report column)
        with open(FLAGS.class_list) as fd:
            classes = [temp.strip()[:30] for temp in fd.readlines()]

        # Placeholders the model reads from; batches are fetched from the
        # input queue first and then fed back in through these.
        images = tf.placeholder(tf.float32, [
            examples_per_step, data_input.IMAGE_HEIGHT,
            data_input.IMAGE_WIDTH, 3
        ])
        labels = tf.placeholder(tf.int32, [examples_per_step])

        # Build model
        hp = resnet.HParams(num_gpu=FLAGS.num_gpu,
                            batch_size=FLAGS.batch_size,
                            split=FLAGS.split,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map)
        network = resnet.ResNet(hp, images, labels, None)
        if FLAGS.split:
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; only point cluster_path at trusted files.
            import pickle
            with open(FLAGS.cluster_path, 'rb') as fd:
                clustering = pickle.load(fd)
            network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d params' % network._weights)
        # No training op is built: this function only evaluates.

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore from a directory or a single file.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            # Restores from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        # Testing!
        # result_ll[c] == [correct count, wrong count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        test_loss = 0.0
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value, acc_value = sess.run(
                [network.preds, network.loss, network.acc],
                feed_dict={
                    network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for j in range(examples_per_step):
                # Column 0 counts hits, column 1 counts misses.
                column = 0 if test_labels_val[j] == preds_val[j] else 1
                result_ll[test_labels_val[j] %
                          FLAGS.num_classes][column] += 1
            if i % FLAGS.display == 0:
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str %
                      (datetime.now(), i, loss_value, acc_value))
        if FLAGS.test_iter:
            test_loss /= FLAGS.test_iter

        # Summary display & output
        acc_list = [
            float(r[0]) / float(r[0] + r[1]) if (r[0] + r[1]) else 0.0
            for r in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        acc_total = float(result_total[0]) / np.sum(result_total)

        row_format = '%-31s %7d %7d %.5f'
        print('Class    \t\t\tT\tF\tAcc.')
        for i in range(FLAGS.num_classes):
            print(row_format % (classes[i], result_ll[i][0], result_ll[i][1],
                                acc_list[i]))
        print(row_format %
              ('(Total)', result_total[0], result_total[1], acc_total))
        print('Average loss: %.4f' % test_loss)

        # Output to file(if specified)
        if FLAGS.output.strip():
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                line_format = row_format + '\n'
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    # Spaces are replaced so every row keeps four
                    # whitespace-separated columns.
                    fd.write(line_format % (classes[i].replace(' ', '-'), t,
                                            f, acc_list[i]))
                fd.write(line_format % ('(Total)', result_total[0],
                                        result_total[1], acc_total))
def train():
    """Train a grouped ("split") ResNet initialised from a base model.

    Prints the configuration, builds the training/validation input
    pipelines, extracts per-group weight slices from the checkpoint
    ``FLAGS.basemodel`` (the process exits with status -1 if no base model is
    given), builds a training and a weight-sharing validation network, and
    restores weights from ``FLAGS.checkpoint`` if set, otherwise the
    non-split variables from the base model.  The training loop runs up to
    ``FLAGS.max_steps``:

    * every ``FLAGS.val_interval`` steps it averages loss/accuracy over
      ``FLAGS.val_iter`` validation batches and writes validation summaries;
    * every ``FLAGS.display`` steps it prints training statistics and writes
      the training summary;
    * every ``FLAGS.checkpoint_interval`` steps (and at the last step) it
      saves a checkpoint to ``FLAGS.train_dir``;
    * typing ``b`` on stdin drops into an interactive shell (``embed()``).

    NOTE(review): the original text had lost its indentation; nesting was
    reconstructed from semantics -- confirm against the upstream script.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tNumber of Groups: %d-%d-%d' %
          (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Integer division: the thread count must stay an int on Python 3.
        num_threads = multiprocessing.cpu_count() // FLAGS.num_gpus
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root, FLAGS.train_dataset,
                    FLAGS.batch_size, True,
                    num_threads=num_threads, num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root, FLAGS.val_dataset,
                    FLAGS.batch_size, False,
                    num_threads=num_threads, num_sets=FLAGS.num_gpus)

        # Get splitted params
        if not FLAGS.basemodel:
            print('No basemodel found to load split params')
            sys.exit(-1)
        print('Load split params from %s' % FLAGS.basemodel)
        reader = tf.train.NewCheckpointReader(FLAGS.basemodel)

        def get_perms(q_name, ngroups):
            """Return, per group, the column indices whose argmax over axis 0
            of the tensor ``q_name`` equals that group."""
            split_q = reader.get_tensor(q_name)
            q_amax = np.argmax(split_q, axis=0)
            return [np.where(q_amax == i)[0] for i in range(ngroups)]

        def take_io(kernel, in_idxs, out_idxs):
            """Select ``in_idxs`` on axis 2, then ``out_idxs`` on axis 3."""
            return kernel[:, :, in_idxs, :][:, :, :, out_idxs]

        split_params = {}

        print('\tlogits...')
        base_logits_w = reader.get_tensor('logits/fc/weights')
        base_logits_b = reader.get_tensor('logits/fc/biases')
        split_p1_idxs = get_perms('group/split_p1/q', FLAGS.ngroups1)
        split_q1_idxs = get_perms('group/split_q1/q', FLAGS.ngroups1)

        logits_params = {'weights': [], 'biases': [],
                         'input_perms': [], 'output_perms': []}
        for i in range(FLAGS.ngroups1):
            logits_params['weights'].append(
                base_logits_w[split_p1_idxs[i], :][:, split_q1_idxs[i]])
            logits_params['biases'].append(base_logits_b[split_q1_idxs[i]])
        logits_params['input_perms'] = split_p1_idxs
        logits_params['output_perms'] = split_q1_idxs
        split_params['logits'] = logits_params

        if FLAGS.ngroups2 > 1:
            print('\tconv5_x...')
            base_conv5_1_shortcut_k = reader.get_tensor('conv5_1/shortcut/kernel')
            base_conv5_1_conv1_k = reader.get_tensor('conv5_1/conv_1/kernel')
            base_conv5_1_conv2_k = reader.get_tensor('conv5_1/conv_2/kernel')
            base_conv5_2_conv1_k = reader.get_tensor('conv5_2/conv_1/kernel')
            base_conv5_2_conv2_k = reader.get_tensor('conv5_2/conv_2/kernel')
            split_p2_idxs = get_perms('group/split_p2/q', FLAGS.ngroups2)
            split_q2_idxs = _merge_split_idxs(
                split_p1_idxs,
                _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
            split_r21_idxs = get_perms('group/split_r21/q', FLAGS.ngroups2)
            split_r22_idxs = get_perms('group/split_r22/q', FLAGS.ngroups2)

            conv5_1_params = {'shortcut': [], 'conv1': [], 'conv2': [],
                              'p_perms': [], 'q_perms': [], 'r_perms': []}
            for i in range(FLAGS.ngroups2):
                conv5_1_params['shortcut'].append(take_io(
                    base_conv5_1_shortcut_k, split_p2_idxs[i], split_q2_idxs[i]))
                conv5_1_params['conv1'].append(take_io(
                    base_conv5_1_conv1_k, split_p2_idxs[i], split_r21_idxs[i]))
                conv5_1_params['conv2'].append(take_io(
                    base_conv5_1_conv2_k, split_r21_idxs[i], split_q2_idxs[i]))
            conv5_1_params['p_perms'] = split_p2_idxs
            conv5_1_params['q_perms'] = split_q2_idxs
            conv5_1_params['r_perms'] = split_r21_idxs
            split_params['conv5_1'] = conv5_1_params

            conv5_2_params = {'conv1': [], 'conv2': [],
                              'p_perms': [], 'r_perms': []}
            for i in range(FLAGS.ngroups2):
                conv5_2_params['conv1'].append(take_io(
                    base_conv5_2_conv1_k, split_q2_idxs[i], split_r22_idxs[i]))
                conv5_2_params['conv2'].append(take_io(
                    base_conv5_2_conv2_k, split_r22_idxs[i], split_q2_idxs[i]))
            conv5_2_params['p_perms'] = split_q2_idxs
            conv5_2_params['r_perms'] = split_r22_idxs
            split_params['conv5_2'] = conv5_2_params

            # NOTE(review): this loop replaces the 'conv5_1' and 'conv5_2'
            # entries built just above with empty dicts (and adds empty
            # 'conv5_3'..'conv5_6').  It looks like a leftover; behaviour is
            # kept unchanged here -- confirm whether it should be removed.
            for i, unit_name in enumerate(['conv5_1', 'conv5_2', 'conv5_3',
                                           'conv5_4', 'conv5_5', 'conv5_6']):
                print('\t' + unit_name)
                sp = {}
                split_params[unit_name] = sp

        if FLAGS.ngroups3 > 1:
            # NOTE(review): split_p2_idxs is only defined when
            # FLAGS.ngroups2 > 1, so this branch presumably requires that
            # as well -- confirm.
            print('\tconv4_x...')
            base_conv4_1_shortcut_k = reader.get_tensor('conv4_1/shortcut/kernel')
            base_conv4_1_conv1_k = reader.get_tensor('conv4_1/conv_1/kernel')
            base_conv4_1_conv2_k = reader.get_tensor('conv4_1/conv_2/kernel')
            base_conv4_2_conv1_k = reader.get_tensor('conv4_2/conv_1/kernel')
            base_conv4_2_conv2_k = reader.get_tensor('conv4_2/conv_2/kernel')
            split_p3_idxs = get_perms('group/split_p3/q', FLAGS.ngroups3)
            split_q3_idxs = _merge_split_idxs(
                split_p2_idxs,
                _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
            split_r31_idxs = get_perms('group/split_r31/q', FLAGS.ngroups3)
            split_r32_idxs = get_perms('group/split_r32/q', FLAGS.ngroups3)

            conv4_1_params = {'shortcut': [], 'conv1': [], 'conv2': [],
                              'p_perms': [], 'q_perms': [], 'r_perms': []}
            for i in range(FLAGS.ngroups3):
                conv4_1_params['shortcut'].append(take_io(
                    base_conv4_1_shortcut_k, split_p3_idxs[i], split_q3_idxs[i]))
                conv4_1_params['conv1'].append(take_io(
                    base_conv4_1_conv1_k, split_p3_idxs[i], split_r31_idxs[i]))
                conv4_1_params['conv2'].append(take_io(
                    base_conv4_1_conv2_k, split_r31_idxs[i], split_q3_idxs[i]))
            conv4_1_params['p_perms'] = split_p3_idxs
            conv4_1_params['q_perms'] = split_q3_idxs
            conv4_1_params['r_perms'] = split_r31_idxs
            split_params['conv4_1'] = conv4_1_params

            conv4_2_params = {'conv1': [], 'conv2': [],
                              'p_perms': [], 'r_perms': []}
            for i in range(FLAGS.ngroups3):
                conv4_2_params['conv1'].append(take_io(
                    base_conv4_2_conv1_k, split_q3_idxs[i], split_r32_idxs[i]))
                conv4_2_params['conv2'].append(take_io(
                    base_conv4_2_conv2_k, split_r32_idxs[i], split_q3_idxs[i]))
            conv4_2_params['p_perms'] = split_q3_idxs
            conv4_2_params['r_perms'] = split_r32_idxs
            split_params['conv4_2'] = conv4_2_params

        # Convert the comma-separated epoch boundaries into step indices.
        lr_step_epochs = [float(e) for e in FLAGS.lr_step_epoch.split(',')]
        lr_decay_steps = [
            int(e * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus) for e in lr_step_epochs
        ]

        # Build model
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            split_params=split_params,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp, train_images, train_labels,
                                      global_step, name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp, val_images, val_labels, global_step,
                                    name="val", reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        # (allow_soft_placement is intentionally left at its default here.)
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        elif FLAGS.basemodel:
            # Define a different saver to restore from the base model.
            # Select only base variables (exclude split layers)
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [
                var for var in variables
                if "Momentum" not in var.name
                and "logits" not in var.name
                and "global_step" not in var.name
            ]
            if FLAGS.ngroups2 > 1:
                vars_restore = [var for var in vars_restore
                                if "conv5_" not in var.name]
            if FLAGS.ngroups3 > 1:
                vars_restore = [var for var in vars_restore
                                if "conv4_" not in var.name]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print('No checkpoint file of basemodel found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        # Summaries go to a sub-directory named after the starting step.
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir,
                         str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay,
                              lr_decay_steps, step)
            start_time = time.time()
            _, loss_value, acc_value, train_summary_str = sess.run(
                [network_train.train_op, network_train.loss,
                 network_train.acc, train_summary_op],
                feed_dict={network_train.is_train: True,
                           network_train.lr: lr_value})
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged: loss is NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                # One step consumes one batch on every GPU.
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    acc_value, lr_value, examples_per_sec,
                                    sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0) \
                    or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Non-blocking poll of stdin: typing 'b' opens an interactive
            # shell without stopping the training process.
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
def train():
    """Evaluate a ResNet with and without Ranger-style range restriction.

    Despite its name, this function does not train anything. It:

    1. prints the configuration read from ``FLAGS``;
    2. builds the test input pipeline and the ResNet graph, then restores
       ``FLAGS.checkpoint`` when one is given;
    3. walks every tensor of that graph and, after each activation (and after
       the MaxPool/AvgPool/Reshape/concat tensors that follow one), inserts a
       ``tf.maximum``/``tf.minimum`` pair that clamps the values to
       ``low_bound``/``up_bound``.  Every insertion re-imports the graph under
       a new name scope, so the graph is pruned to its latest copy each time;
    4. copies the variable values of the original graph into the final copy;
    5. runs two batches through both the original and the restricted graph
       and prints, for the first row of each prediction, the indices of the
       ten highest scores.
    """
    import re

    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    global_step = tf.Variable(0, trainable=False, name='global_step')

    # NOTE(review): this overrides the --test_dataset value that was printed
    # above with a hard-coded list file; confirm that this is intended.
    FLAGS.test_dataset = "./val.txt"

    # Get images and labels of ImageNet
    print('Load ImageNet dataset')
    with tf.device('/cpu:0'):
        print('\tLoading test data from %s' % FLAGS.test_dataset)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(
                FLAGS.test_image_root, FLAGS.test_dataset, FLAGS.batch_size,
                False, num_threads=1, center_crop=True)

    # Build a Graph that computes the predictions from the inference model.
    images = tf.placeholder(
        tf.float32,
        [FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH, 3])
    labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

    # Build model
    # NOTE(review): only the HParams construction is placed under the device
    # scope here; confirm against the intended layout of this block.
    with tf.device('/cpu:0'):
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=1,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
    network = resnet.ResNet(hp, [images], [labels], global_step)
    network.build_model()
    print('\tNumber of Weights: %d' % network._weights)
    print('\tFLOPs: %d' % network._flops)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph.
    sess = tf.Session(
        config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
    if FLAGS.checkpoint is not None:
        saver.restore(sess, FLAGS.checkpoint)
        print('Load checkpoint %s' % FLAGS.checkpoint)
    else:
        print('No checkpoint file of basemodel found. Start from the scratch.')

    # Start queue runners & summary_writer
    tf.train.start_queue_runners(sess=sess)

    # ========================================================================
    # Begin to insert restriction on selective layers
    # ========================================================================
    # NOTE: this model requires GPU (otherwise it'll report an error while
    # restoring the variables from the old graph to the new graph).

    def get_op_dependency(op):
        """Append every tensor that `op` transitively depends on to a file.

        The dependencies are walked level by level and written to
        ``resnet-op.txt`` (opened in append mode); a blank separator is
        written after the inputs of each op that has inputs.  Shared
        ancestors are visited once per path, so the output may contain
        duplicates.
        """
        cur_op = [op]
        with open("resnet-op.txt", "a") as dep_file:
            while cur_op:
                next_op = []
                for each in cur_op:
                    has_inputs = False
                    for inp in each.inputs:
                        has_inputs = True
                        dep_file.write(str(inp) + "\n")
                        next_op.append(inp.op)
                    if has_inputs:
                        dep_file.write("\n\n")
                cur_op = next_op

    def get_target_scope_prefix(scope_name, dup_cnt, dummy_scope_name,
                                dummy_graph_dup_cnt):
        """Return the scope prefix of the latest duplicated graph copy.

        Each duplication nests the graph under `scope_name`, and each graph
        reset nests it under `dummy_scope_name`, so the prefix interleaves
        the two, innermost first: "" for no duplication, "ranger/" after one
        duplication, "dummy/ranger/" after one duplication plus one reset,
        "ranger/dummy/ranger/" after the next duplication, and so on.
        """
        target_graph_prefix = ""
        if dup_cnt == 0:
            target_graph_prefix = ""
        elif dup_cnt == 1:
            target_graph_prefix = str(scope_name + "/")  # e.g., ranger/relu:0
            if dummy_graph_dup_cnt == 1:
                # e.g., dummy/ranger/relu:0
                target_graph_prefix = dummy_scope_name + "/" + target_graph_prefix
        else:
            target_graph_prefix = str(scope_name + "/")
            if dummy_graph_dup_cnt > 0:
                # e.g., dummy/ranger/relu:0
                target_graph_prefix = dummy_scope_name + "/" + target_graph_prefix
                dummy_graph_dup_cnt -= 1
            for _ in range(1, dup_cnt):
                # e.g., ranger/dummy/ranger/relu
                target_graph_prefix = scope_name + "/" + target_graph_prefix
                if dummy_graph_dup_cnt > 0:
                    # e.g., dummy/ranger/dummy/ranger/relu:0
                    target_graph_prefix = dummy_scope_name + "/" + target_graph_prefix
                    dummy_graph_dup_cnt -= 1
        return target_graph_prefix

    def restore_all_var(sess, scope_name, dup_cnt, all_var, dummy_scope_name,
                        dummy_graph_dup_cnt, OLD_SESS):
        """Copy the value of every variable in `all_var` into the new graph.

        Each value is read from `OLD_SESS` under the variable's original
        name and assigned, in `sess`, to the tensor of the same name under
        the latest scope prefix.
        """
        target_graph_prefix = get_target_scope_prefix(
            scope_name, dup_cnt, dummy_scope_name, dummy_graph_dup_cnt)
        for each in all_var:
            old_value = OLD_SESS.run(
                OLD_SESS.graph.get_tensor_by_name(each.name))
            new_tensor = sess.graph.get_tensor_by_name(
                target_graph_prefix + each.name)
            sess.run(tf.assign(new_tensor, old_value))

    def get_op_with_prefix(op_name, dup_cnt, scope_name, dummy_graph_dup_cnt,
                           dummy_scope_name):
        """Return `op_name` as it is named in the NEW graph.

        The name is the original one prepended with the scope prefix of the
        latest duplication (a new prefix is added on every duplication).
        """
        return get_target_scope_prefix(
            scope_name, dup_cnt, dummy_scope_name,
            dummy_graph_dup_cnt) + op_name

    def modify_graph(sess, dup_cnt, scope_name, prefix_of_bound_op,
                     dummy_graph_dup_cnt, dummy_scope_name):
        """Return a GraphDef of `sess.graph` pruned to the latest copy.

        Two things are done:
        1) nodes of older copies are dropped (only the latest copy is kept,
           otherwise the graph size would explode);
        2) the inputs of the restriction ops are rewired so that they depend
           on the latest copy only.
        """
        graph_def = sess.graph.as_graph_def()
        target_graph_prefix = get_target_scope_prefix(
            scope_name, dup_cnt, dummy_scope_name, dummy_graph_dup_cnt)

        # Delete nodes from the redundant paths; we only want the most
        # recent path.
        nodes = []
        for node in graph_def.node:
            if (target_graph_prefix in node.name
                    and prefix_of_bound_op not in node.name):
                # ops to be kept, otherwise removed from graph
                nodes.append(node)
            elif prefix_of_bound_op in node.name:
                # `graph_dup_cnt` is the enclosing function's counter.  Every
                # call in this function passes the current counter as
                # `dup_cnt`, so this first branch is not taken by those calls.
                if dup_cnt != graph_dup_cnt:
                    # keep the new op from the most recent duplication
                    # (with lesser prefix)
                    if target_graph_prefix not in node.name:
                        nodes.append(node)
                else:
                    nodes.append(node)

            # remove dummy nodes like dummy/dummy/relu
            if dummy_scope_name + "/" + dummy_scope_name + "/" in node.name:
                nodes.remove(node)

        mod_graph_def = tf.GraphDef()
        mod_graph_def.node.extend(nodes)

        # Matches one scope component such as "ranger_1/", "ranger/",
        # "dummy_2/" or "dummy/".
        scope_pattern = re.compile(
            re.escape(scope_name) + r"_\d+/|" +
            re.escape(scope_name) + "/|" +
            re.escape(dummy_scope_name) + r"_\d+/|" +
            re.escape(dummy_scope_name) + "/")

        # The restriction ops (e.g. tf.maximum(relu_1, 100)) were created
        # against the PREVIOUS copy of the graph, which has just been pruned
        # away.  Rewire their inputs to the equivalent tensors of the latest
        # copy, e.g. "test/x3" becomes "test_1/test/x3".
        for node in mod_graph_def.node:
            if prefix_of_bound_op not in node.name:
                continue  # only the restriction ops need rewiring
            inp_names = []
            for inp in node.input:
                if prefix_of_bound_op in inp or target_graph_prefix in inp:
                    inp_names.append(inp)
                elif scope_name in inp:
                    # strip every old scope component, then add the latest
                    inp_names.append(
                        target_graph_prefix + scope_pattern.sub("", inp))
                else:
                    inp_names.append(target_graph_prefix + inp)
            del node.input[:]  # delete all the inputs
            node.input.extend(inp_names)  # keep the modified dependencies

        return mod_graph_def

    def printgraphdef(graphdef):
        """Print the name of every node in `graphdef` (debug helper)."""
        for each in graphdef.node:
            print(each.name)

    def printgraph(sess):
        """Write the name of every tensor in `sess.graph` to ``op.txt``."""
        tensors = [
            tensor for op in sess.graph.get_operations()
            for tensor in op.values()
        ]
        with open("op.txt", "w") as out_file:
            for tensor in tensors:
                out_file.write(tensor.name + "\n")

    # in resnet-18, Relu is renamed as relu
    act = "relu"
    op_follow_act = ["MaxPool", "Reshape", "AvgPool"]
    special_op_follow_act = "concat"
    # Real lists (not `map` objects): they are indexed and measured below.
    up_bound = [
        float(v)
        for v in [7, 8, 7, 5, 11, 5, 12, 6, 11, 5, 12, 5, 14, 5, 12, 5, 66]
    ]  # upper bound for restriction
    low_bound = [
        float(v)
        for v in [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ]  # low bound for restriction

    PREFIX = 'ranger'  # scope name in the graph
    DUMMY_PREFIX = 'dummy'
    # number of duplications so far, used to track the scope prefix of the
    # new ops
    graph_dup_cnt = 0
    # number of dummy graph duplications (made when resetting the default
    # graph so that it contains only the latest path)
    dummy_graph_dup_cnt = 0
    act_cnt = 0  # count num of act
    # flag for checking the following op (when the current op is ACT)
    check_follow = False
    # prefix of the newly created ops for range restriction
    new_op_prefix = "bound_op_prefix"

    OLD_SESS = sess  # keep the old session
    all_var = tf.global_variables()  # all vars before duplication

    # get all the tensors produced by the operators in the graph
    ops = [
        tensor for op in sess.graph.get_operations()
        for tensor in op.values()
    ]

    # iterate each op in the graph and insert bounding ops
    for cur_op in ops:
        if act in cur_op.name and "gradients" not in cur_op.name:
            # bounding
            op_to_be_replaced = get_op_with_prefix(
                cur_op.name, graph_dup_cnt, PREFIX, dummy_graph_dup_cnt,
                DUMMY_PREFIX)
            # the restriction ops get the special scope prefix name
            with tf.name_scope(new_op_prefix):
                bound_tensor = sess.graph.get_tensor_by_name(op_to_be_replaced)
                print("bounding: ", bound_tensor, up_bound[act_cnt])
                rest = tf.maximum(bound_tensor, low_bound[act_cnt])
                rest = tf.minimum(rest, up_bound[act_cnt])

            # delete redundant paths in the graphdef and make the inputs
            # depend on the latest path only
            truncated_graphdef = modify_graph(
                sess, graph_dup_cnt, PREFIX, new_op_prefix,
                dummy_graph_dup_cnt, DUMMY_PREFIX)
            # import the modified graphdef (with the bounding ops inserted)
            # into the current graph
            tf.import_graph_def(truncated_graphdef,
                                name=PREFIX,
                                input_map={op_to_be_replaced: rest})
            graph_dup_cnt += 1

            # reset the graph to contain only the duplicated path
            truncated_graphdef = modify_graph(
                sess, graph_dup_cnt, PREFIX, new_op_prefix,
                dummy_graph_dup_cnt, DUMMY_PREFIX)
            tf.reset_default_graph()
            sess = tf.Session()
            tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
            dummy_graph_dup_cnt += 1

            # this is an ACT, so we need to check the following op
            check_follow = True
            # Number of visited ACTs; wraps around for the case where there
            # are two copies of the ops (one for training, one for testing).
            act_cnt = (act_cnt + 1) % len(up_bound)

        # this will check the next operator that follows the ACT op
        elif check_follow:
            keep_rest = False  # whether the following op needs to be bounded

            # this is the case for MaxPool, AvgPool and Reshape
            for each in op_follow_act:
                # "/shape" excludes tensors like "Reshape_1/shape:0", which
                # shouldn't be bounded
                if each in cur_op.name and "/shape" not in cur_op.name:
                    keep_rest = True
                    low = low_bound[act_cnt - 1]
                    up = up_bound[act_cnt - 1]
                    break

            # this is the case for ConcatV2; "axis" and "values" are
            # parameters of the actual concat op
            if (special_op_follow_act in cur_op.name
                    and "axis" not in cur_op.name
                    and "values" not in cur_op.name):
                keep_rest = True
                low = np.minimum(low_bound[act_cnt - 1],
                                 low_bound[act_cnt - 2])
                up = np.maximum(up_bound[act_cnt - 1],
                                up_bound[act_cnt - 2])

            if keep_rest:
                op_to_be_replaced = get_op_with_prefix(
                    cur_op.name, graph_dup_cnt, PREFIX, dummy_graph_dup_cnt,
                    DUMMY_PREFIX)
                # bound the values, using either float (default) or int
                try:
                    with tf.name_scope(new_op_prefix):
                        bound_tensor = sess.graph.get_tensor_by_name(
                            op_to_be_replaced)
                        print("bounding: ", bound_tensor)
                        rest = tf.maximum(bound_tensor, low)
                        rest = tf.minimum(rest, up)
                except (TypeError, ValueError):
                    # The float bounds did not match the tensor's dtype;
                    # retry with integer bounds.
                    with tf.name_scope(new_op_prefix):
                        bound_tensor = sess.graph.get_tensor_by_name(
                            op_to_be_replaced)
                        print("bounding: ", bound_tensor)
                        rest = tf.maximum(bound_tensor, int(low))
                        rest = tf.minimum(rest, int(up))

                # replace the tensor at the place where Ranger is inserted,
                # e.g., for Ranger(Relu) we replace Relu
                truncated_graphdef = modify_graph(
                    sess, graph_dup_cnt, PREFIX, new_op_prefix,
                    dummy_graph_dup_cnt, DUMMY_PREFIX)
                tf.import_graph_def(truncated_graphdef,
                                    name=PREFIX,
                                    input_map={op_to_be_replaced: rest})
                graph_dup_cnt += 1

                # reset the graph to contain only the duplicated path
                truncated_graphdef = modify_graph(
                    sess, graph_dup_cnt, PREFIX, new_op_prefix,
                    dummy_graph_dup_cnt, DUMMY_PREFIX)
                tf.reset_default_graph()
                sess = tf.Session()
                tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
                dummy_graph_dup_cnt += 1

            # check the ops, but not to bound the ops
            else:
                # the default setting is not to check the next op
                check_follow = False
                # Ops following the listed ones are still tracked: the listed
                # ops perform no actual computation, so the restriction bound
                # still applies.
                oblivious_ops = [
                    "Const", "truncated_normal", "Variable", "weights",
                    "biases", "dropout"
                ]
                if (("Reshape" in cur_op.name and "/shape" in cur_op.name)
                        or ("concat" in cur_op.name
                            and ("axis" in cur_op.name
                                 or "values" in cur_op.name))):
                    # Reshape/shape:0 and concat/axis are not the actual
                    # reshape/concat ops, so keep checking the next op
                    check_follow = True
                else:
                    for ea in oblivious_ops:
                        if ea in cur_op.name:
                            check_follow = True

    # we need to call modify_graph to modify the input dependency for
    # finalization
    truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                      new_op_prefix, dummy_graph_dup_cnt,
                                      DUMMY_PREFIX)
    tf.import_graph_def(truncated_graphdef, name=PREFIX)
    graph_dup_cnt += 1

    # restore the variables to the latest path
    truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                      new_op_prefix, dummy_graph_dup_cnt,
                                      DUMMY_PREFIX)
    tf.reset_default_graph()
    sess = tf.Session()
    tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
    dummy_graph_dup_cnt += 1

    # restore all the variables from the original graph to the new graph
    restore_all_var(sess, PREFIX, graph_dup_cnt, all_var, DUMMY_PREFIX,
                    dummy_graph_dup_cnt, OLD_SESS)

    print("Finish graph modification!")
    print('')
    # ========================================================================

    # This is the operator to be evaluated; we look up the corresponding one
    # under Ranger's scope.
    OP_FOR_EVAL = network.probs
    new_op_for_eval_name = get_op_with_prefix(OP_FOR_EVAL.op.name,
                                              graph_dup_cnt, PREFIX,
                                              dummy_graph_dup_cnt,
                                              DUMMY_PREFIX)
    print(new_op_for_eval_name, 'op to be eval')
    new_op_for_eval = sess.graph.get_tensor_by_name(
        new_op_for_eval_name + ":0")

    # get_op_dependency(new_op_for_eval.op) can be called here to check the
    # dependency of the final operator; the bounding ops should show up in
    # it.  NOTE: the output might contain duplicates.

    # input to eval the results
    for i in range(2):
        test_images_val, test_labels_val = OLD_SESS.run(
            [test_images[0], test_labels[0]])

        # evaluation on the old path
        preds = OLD_SESS.run(OP_FOR_EVAL,
                             feed_dict={
                                 network.is_train: False,
                                 images: test_images_val,
                                 labels: test_labels_val
                             })
        print((np.argsort(np.asarray(preds)[0])[::-1])[0:10])
        print('')

        # evaluation on the new path
        new_x = sess.graph.get_tensor_by_name(
            get_op_with_prefix(images.op.name, graph_dup_cnt, PREFIX,
                               dummy_graph_dup_cnt, DUMMY_PREFIX) + ":0")
        new_y = sess.graph.get_tensor_by_name(
            get_op_with_prefix(labels.op.name, graph_dup_cnt, PREFIX,
                               dummy_graph_dup_cnt, DUMMY_PREFIX) + ":0")
        new_is_train = sess.graph.get_tensor_by_name(
            get_op_with_prefix(network.is_train.op.name, graph_dup_cnt,
                               PREFIX, dummy_graph_dup_cnt, DUMMY_PREFIX) +
            ":0")
        preds = sess.run(new_op_for_eval,
                         feed_dict={
                             new_is_train: False,
                             new_x: test_images_val,
                             new_y: test_labels_val
                         })
        print((np.argsort(np.asarray(preds)[0])[::-1])[0:10])