Code Example #1
def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  Raises:
    ValueError: If no data_dir
  """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    images, labels = imagenet_input.distorted_inputs(
        data_dir=FLAGS.data_dir, num_batches=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
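Note: the function above only builds the input pipeline; the returned tensors must be evaluated in a session. A minimal TF1-style usage sketch, assuming the FLAGS referenced above are defined and that imagenet_input uses the queue-based readers this excerpt suggests (not part of the original example):

import tensorflow as tf

with tf.Graph().as_default():
    images, labels = distorted_inputs()  # build the input pipeline
    with tf.compat.v1.Session() as sess:
        # Queue-based readers need their runners started before any fetch.
        tf.compat.v1.train.start_queue_runners(sess=sess)
        # Evaluate one distorted batch; images and labels stay paired.
        batch_images, batch_labels = sess.run([images, labels])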
Code Example #2
def distorted_inputs(data_class, shuffle=True):
    """Construct input for training using the Reader ops.

    Args:
      data_class: string, one of 'train', 'eval', or 'test', selecting the data set to use.
      shuffle: bool, whether to shuffle the list of files to read.

    Returns:
      images: Images. 4D tensor of [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [BATCH_SIZE] size.

    Raises:
      ValueError: If no data_dir is supplied.
    """
    return data_input.distorted_inputs(data_class=data_class,
                                       batch_size=BATCH_SIZE,
                                       shuffle=shuffle)
Code Example #3
File: nac_net.py  Project: pkamath/envelopenets
    def distorted_inputs(self):
        """Construct distorted input for a given dataset using the Reader ops.

        Returns:
            images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
            labels: Labels. 1D tensor of [batch_size] size.

        Raises:
            ValueError: If no data_dir is supplied.
        """
        if not FLAGS.data_dir:
            raise ValueError('Please supply a data_dir')
        if FLAGS.dataset == 'cifar10':
            data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
            images, labels = cifar10_input.distorted_inputs(
                data_dir=data_dir, batch_size=FLAGS.batch_size)
        elif FLAGS.dataset == 'imagenet':
            images, labels = imagenet_input.distorted_inputs()
        else:
            # Guard against an unknown dataset flag, which would otherwise
            # leave `images` and `labels` undefined below.
            raise ValueError('Unknown dataset: %s' % FLAGS.dataset)
        if FLAGS.use_fp16:
            images = tf.cast(images, tf.float16)
            labels = tf.cast(labels, tf.float16)
        return images, labels
Code Example #4
File: train.py  Project: nlpng/resnet-18-tensorflow
def train():
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        num_threads = multiprocessing.cpu_count() // FLAGS.num_gpus  # integer division; the thread count must be an int
        print('Load ImageNet dataset (%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            tf.summary.image('images', train_images[0][:2])

        # Build model
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = list(
            map(int, [
                s * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus for s in lr_decay_steps
            ]))
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp,
                                      train_images,
                                      train_labels,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp,
                                    val_images,
                                    val_labels,
                                    global_step,
                                    name="val",
                                    reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=False,
            # allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Define a separate saver that restores only the basemodel parameters
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [
                var for var in variables
                if not "Momentum" in var.name and not "global_step" in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print('No checkpoint or basemodel found. Starting from scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            # For timeline profiling:
            # if step == 153:
            #     run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            #     run_metadata = tf.RunMetadata()
            #     _, loss_value, acc_value, train_summary_str = \
            #         sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
            #                  feed_dict={network_train.is_train: True, network_train.lr: lr_value},
            #                  options=run_options, run_metadata=run_metadata)
            #     # Create the Timeline object, and write it to a json
            #     tl = timeline.Timeline(run_metadata.step_stats)
            #     ctf = tl.generate_chrome_trace_format()
            #     with open('timeline.json', 'w') as f:
            #         f.write(ctf)
            #     print('Wrote the timeline profile of %d iter training on %s' % (step, 'timeline.json'))
            # else:
            #     _, loss_value, acc_value, train_summary_str = \
            #         sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
            #                  feed_dict={network_train.is_train: True, network_train.lr: lr_value})
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
                            feed_dict={network_train.is_train:True, network_train.lr:lr_value})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
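get_lr is called above but not defined in this excerpt. A plausible step-decay implementation, consistent with how it is invoked (initial rate, decay factor, list of decay-boundary steps, current step); a sketch, not necessarily the project's actual code:

def get_lr(initial_lr, lr_decay, lr_decay_steps, step):
    # Multiply the initial rate by lr_decay once for every decay
    # boundary the current step has already passed.
    lr = initial_lr
    for decay_step in lr_decay_steps:
        if step >= decay_step:
            lr *= lr_decay
    return lr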
Code Example #5
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import nest
import datasets
import imagenet_input

import numpy as np
import tensorflow as tf  # used throughout but missing from the original excerpt

data_dir = '/spartan/tf/train'
num_batches = 16

cpu_device = '/cpu:0'

with tf.Graph().as_default() as g:
    global_step = tf.compat.v1.train.get_or_create_global_step()

    with tf.device('/cpu:0'):
        images, labels = imagenet_input.distorted_inputs(data_dir, num_batches)

    summary_op = tf.compat.v1.summary.merge_all()

    # Sets up a timestamped log directory.
    #logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d")
    #summary_writer = tf.compat.v1.summary.FileWriter(logdir, g)

    with tf.compat.v1.Session() as sess:
        # If imagenet_input uses queue-based readers, the runners must be
        # started before fetching; with no queue runners this is a no-op.
        tf.compat.v1.train.start_queue_runners(sess=sess)
        # Fetch images and labels in a single run() so the two stay paired;
        # separate run() calls would draw them from different batches.
        final_batches, final_labels = sess.run([images, labels])
        print(final_batches)
        print(final_labels)
Code Example #6
File: train.py  Project: dalgu90/splitnet-imagenet22k
def train():
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tNumber of Groups: %d-%d-%d' %
          (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tOverlap loss weight: %f' % FLAGS.gamma1)
    print('\tWeight split loss weight: %f' % FLAGS.gamma2)
    print('\tUniform loss weight: %f' % FLAGS.gamma3)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale)
    print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        num_threads = multiprocessing.cpu_count() / FLAGS.num_gpus
        print('Load ImageNet dataset (%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            # tf.summary.image('images', train_images[0])
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)

        # Build model
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = map(int, [
            s * FLAGS.num_train_instance / FLAGS.batch_size / FLAGS.num_gpus
            for s in lr_decay_steps
        ])
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            gamma1=FLAGS.gamma1,
                            gamma2=FLAGS.gamma2,
                            gamma3=FLAGS.gamma3,
                            momentum=FLAGS.momentum,
                            bn_no_scale=FLAGS.bn_no_scale,
                            weighted_group_loss=FLAGS.weighted_group_loss,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp,
                                      train_images,
                                      train_labels,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp,
                                    val_images,
                                    val_labels,
                                    global_step,
                                    name="val",
                                    reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            # allow_soft_placement=False,
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Define a separate saver that restores only the basemodel parameters
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [
                var for var in variables if not "Momentum" in var.name
                and not "group" in var.name and not "global_step" in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print('No checkpoint or basemodel found. Starting from scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            if step == 153:
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                _, loss_value, acc_value, train_summary_str = \
                        sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
                                 feed_dict={network_train.is_train:True, network_train.lr:lr_value}
                                 , options=run_options, run_metadata=run_metadata)
                _ = sess.run(network_train.validity_op)
                # Create the Timeline object, and write it to a json
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)
                print('Wrote the timeline profile of %d iter training on %s' %
                      (step, 'timeline.json'))
            else:
                _, loss_value, acc_value, train_summary_str = \
                        sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
                                feed_dict={network_train.is_train:True, network_train.lr:lr_value})
                _ = sess.run(network_train.validity_op)
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Does it work correctly?
            # if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            #     char = sys.stdin.read(1)
            #     if char == 'b':
            #         embed()

            # Add weights and groupings visualization
            filters = [64, [64, 256], [128, 512], [256, 1024], [512, 2048]]
            if FLAGS.group_summary_interval is not None:
                if step % FLAGS.group_summary_interval == 0:
                    img_summaries = []

                    if FLAGS.ngroups1 > 1:
                        logits_weights = get_var_value('logits/fc/weights',
                                                       sess)
                        split_p1 = get_var_value('group/split_p1/q', sess)
                        split_q1 = get_var_value('group/split_q1/q', sess)
                        feature_indices = np.argsort(
                            np.argmax(split_p1, axis=0))
                        class_indices = np.argsort(np.argmax(split_q1, axis=0))

                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_p1[:, feature_indices],
                                          20,
                                          axis=0), 'split_p1'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_q1[:, class_indices],
                                          200,
                                          axis=0), 'split_q1'))
                        img_summaries.append(
                            img_to_summary(
                                np.abs(logits_weights[feature_indices, :]
                                       [:, class_indices]), 'logits'))

                    if FLAGS.ngroups2 > 1:
                        conv5_1_shortcut = get_var_value(
                            'conv5_1/conv_shortcut/kernel', sess)
                        conv5_1_conv_1 = get_var_value('conv5_1/conv_1/kernel',
                                                       sess)
                        conv5_1_conv_2 = get_var_value('conv5_1/conv_2/kernel',
                                                       sess)
                        conv5_1_conv_3 = get_var_value('conv5_1/conv_3/kernel',
                                                       sess)
                        conv5_2_conv_1 = get_var_value('conv5_2/conv_1/kernel',
                                                       sess)
                        conv5_2_conv_2 = get_var_value('conv5_2/conv_2/kernel',
                                                       sess)
                        conv5_2_conv_3 = get_var_value('conv5_2/conv_3/kernel',
                                                       sess)
                        conv5_3_conv_1 = get_var_value('conv5_3/conv_1/kernel',
                                                       sess)
                        conv5_3_conv_2 = get_var_value('conv5_3/conv_2/kernel',
                                                       sess)
                        conv5_3_conv_3 = get_var_value('conv5_3/conv_3/kernel',
                                                       sess)
                        split_p2 = get_var_value('group/split_p2/q', sess)
                        split_q2 = _merge_split_q(
                            split_p1,
                            _get_even_merge_idxs(FLAGS.ngroups1,
                                                 FLAGS.ngroups2))
                        split_r211 = get_var_value('group/split_r211/q', sess)
                        split_r212 = get_var_value('group/split_r212/q', sess)
                        split_r221 = get_var_value('group/split_r221/q', sess)
                        split_r222 = get_var_value('group/split_r222/q', sess)
                        split_r231 = get_var_value('group/split_r231/q', sess)
                        split_r232 = get_var_value('group/split_r232/q', sess)
                        feature_indices1 = np.argsort(
                            np.argmax(split_p2, axis=0))
                        feature_indices2 = np.argsort(
                            np.argmax(split_q2, axis=0))
                        feature_indices3 = np.argsort(
                            np.argmax(split_r211, axis=0))
                        feature_indices4 = np.argsort(
                            np.argmax(split_r212, axis=0))
                        feature_indices5 = np.argsort(
                            np.argmax(split_r221, axis=0))
                        feature_indices6 = np.argsort(
                            np.argmax(split_r222, axis=0))
                        feature_indices7 = np.argsort(
                            np.argmax(split_r231, axis=0))
                        feature_indices8 = np.argsort(
                            np.argmax(split_r232, axis=0))
                        conv5_1_shortcut_img = np.abs(
                            conv5_1_shortcut[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[3][1], filters[4][1]))
                        conv5_1_conv_1_img = np.abs(
                            conv5_1_conv_1[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices3].transpose([2, 0, 3, 1]).reshape(
                                 filters[3][1], filters[4][0]))
                        conv5_1_conv_2_img = np.abs(
                            conv5_1_conv_2[:, :, feature_indices3, :]
                            [:, :, :,
                             feature_indices4].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0] * 3, filters[4][0] * 3))
                        conv5_1_conv_3_img = np.abs(
                            conv5_1_conv_3[:, :, feature_indices4, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0], filters[4][1]))
                        conv5_2_conv_1_img = np.abs(
                            conv5_2_conv_1[:, :, feature_indices2, :]
                            [:, :, :,
                             feature_indices5].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][1], filters[4][0]))
                        conv5_2_conv_2_img = np.abs(
                            conv5_2_conv_2[:, :, feature_indices5, :]
                            [:, :, :,
                             feature_indices6].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0] * 3, filters[4][0] * 3))
                        conv5_2_conv_3_img = np.abs(
                            conv5_2_conv_3[:, :, feature_indices6, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0], filters[4][1]))
                        conv5_3_conv_1_img = np.abs(
                            conv5_3_conv_1[:, :, feature_indices2, :]
                            [:, :, :,
                             feature_indices7].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][1], filters[4][0]))
                        conv5_3_conv_2_img = np.abs(
                            conv5_3_conv_2[:, :, feature_indices7, :]
                            [:, :, :,
                             feature_indices8].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0] * 3, filters[4][0] * 3))
                        conv5_3_conv_3_img = np.abs(
                            conv5_3_conv_3[:, :, feature_indices8, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[4][0], filters[4][1]))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_p2[:, feature_indices1],
                                          20,
                                          axis=0), 'split_p2'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r211[:, feature_indices3],
                                          20,
                                          axis=0), 'split_r211'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r212[:, feature_indices4],
                                          20,
                                          axis=0), 'split_r212'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r221[:, feature_indices5],
                                          20,
                                          axis=0), 'split_r221'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r222[:, feature_indices6],
                                          20,
                                          axis=0), 'split_r222'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r231[:, feature_indices7],
                                          20,
                                          axis=0), 'split_r231'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r232[:, feature_indices8],
                                          20,
                                          axis=0), 'split_r232'))
                        img_summaries.append(
                            img_to_summary(conv5_1_shortcut_img,
                                           'conv5_1/shortcut'))
                        img_summaries.append(
                            img_to_summary(conv5_1_conv_1_img,
                                           'conv5_1/conv_1'))
                        img_summaries.append(
                            img_to_summary(conv5_1_conv_2_img,
                                           'conv5_1/conv_2'))
                        img_summaries.append(
                            img_to_summary(conv5_1_conv_3_img,
                                           'conv5_1/conv_3'))
                        img_summaries.append(
                            img_to_summary(conv5_2_conv_1_img,
                                           'conv5_2/conv_1'))
                        img_summaries.append(
                            img_to_summary(conv5_2_conv_2_img,
                                           'conv5_2/conv_2'))
                        img_summaries.append(
                            img_to_summary(conv5_2_conv_3_img,
                                           'conv5_2/conv_3'))
                        img_summaries.append(
                            img_to_summary(conv5_3_conv_1_img,
                                           'conv5_3/conv_1'))
                        img_summaries.append(
                            img_to_summary(conv5_3_conv_2_img,
                                           'conv5_3/conv_2'))
                        img_summaries.append(
                            img_to_summary(conv5_3_conv_3_img,
                                           'conv5_3/conv_3'))

                    # if FLAGS.ngroups3 > 1:
                    #     conv4_1_shortcut = get_var_value('conv4_1/conv_shortcut/kernel', sess)
                    #     conv4_1_conv_1 = get_var_value('conv4_1/conv_1/kernel', sess)
                    #     conv4_1_conv_2 = get_var_value('conv4_1/conv_2/kernel', sess)
                    #     conv4_2_conv_1 = get_var_value('conv4_2/conv_1/kernel', sess)
                    #     conv4_2_conv_2 = get_var_value('conv4_2/conv_2/kernel', sess)
                    #     split_p3 = get_var_value('group/split_p3/q', sess)
                    #     split_q3 = _merge_split_q(split_p2, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
                    #     split_r31 = get_var_value('group/split_r31/q', sess)
                    #     split_r32 = get_var_value('group/split_r32/q', sess)
                    #     feature_indices1 = np.argsort(np.argmax(split_p3, axis=0))
                    #     feature_indices2 = np.argsort(np.argmax(split_q3, axis=0))
                    #     feature_indices3 = np.argsort(np.argmax(split_r31, axis=0))
                    #     feature_indices4 = np.argsort(np.argmax(split_r32, axis=0))
                    #     conv4_1_shortcut_img = np.abs(conv4_1_shortcut[:,:,feature_indices1,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2], filters[3]))
                    #     conv4_1_conv_1_img = np.abs(conv4_1_conv_1[:,:,feature_indices1,:][:,:,:,feature_indices3].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[3] * 3))
                    #     conv4_1_conv_2_img = np.abs(conv4_1_conv_2[:,:,feature_indices3,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
                    #     conv4_2_conv_1_img = np.abs(conv4_2_conv_1[:,:,feature_indices2,:][:,:,:,feature_indices4].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
                    #     conv4_2_conv_2_img = np.abs(conv4_2_conv_2[:,:,feature_indices4,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
                    #     img_summaries.append(img_to_summary(np.repeat(split_p3[:, feature_indices1], 20, axis=0), 'split_p3'))
                    #     img_summaries.append(img_to_summary(np.repeat(split_r31[:, feature_indices3], 20, axis=0), 'split_r31'))
                    #     img_summaries.append(img_to_summary(np.repeat(split_r32[:, feature_indices4], 20, axis=0), 'split_r32'))
                    #     img_summaries.append(img_to_summary(conv4_1_shortcut_img, 'conv4_1/shortcut'))
                    #     img_summaries.append(img_to_summary(conv4_1_conv_1_img, 'conv4_1/conv_1'))
                    #     img_summaries.append(img_to_summary(conv4_1_conv_2_img, 'conv4_1/conv_2'))
                    #     img_summaries.append(img_to_summary(conv4_2_conv_1_img, 'conv4_2/conv_1'))
                    #     img_summaries.append(img_to_summary(conv4_2_conv_2_img, 'conv4_2/conv_2'))

                    if img_summaries:
                        img_summary = tf.Summary(value=img_summaries)
                        summary_writer.add_summary(img_summary, step)
                        summary_writer.flush()
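get_var_value and img_to_summary are used above but not shown in the excerpt. A hedged reconstruction: get_var_value presumably evaluates a variable by name, and img_to_summary must return a tf.Summary.Value holding an encoded image, since the results are passed straight to tf.Summary(value=img_summaries). One way to write them (Pillow assumed for PNG encoding; both functions are hypothetical sketches):

import io
import numpy as np
import tensorflow as tf
from PIL import Image

def get_var_value(name, sess):
    # Evaluate a graph variable by its op name (sketch; the project may
    # resolve names differently).
    var = [v for v in tf.global_variables() if v.op.name == name][0]
    return sess.run(var)

def img_to_summary(img, tag):
    # Wrap a 2-D array as a grayscale PNG inside a tf.Summary.Value so a
    # list of these can be passed to tf.Summary(value=...) as done above.
    img = (255.0 * img / (img.max() + 1e-8)).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(img).save(buf, format='PNG')
    return tf.Summary.Value(
        tag=tag,
        image=tf.Summary.Image(height=img.shape[0], width=img.shape[1],
                               colorspace=1,  # grayscale
                               encoded_image_string=buf.getvalue()))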
Code Example #7
File: train.py  Project: dalgu90/splitnet-imagenet
def train():
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tNumber of GPUs: %d' % FLAGS.num_gpu)
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tSplit network: %s' % FLAGS.split)
    if FLAGS.split:
        print('\tClustering path: %s' % FLAGS.cluster_path)
        print('\tNo logit map: %s' % FLAGS.no_logit_map)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with open(FLAGS.cluster_path) as fd:
        clustering = pickle.load(fd)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.distorted_inputs(
                FLAGS.train_image_root, FLAGS.train_dataset,
                FLAGS.batch_size * FLAGS.num_gpu, True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(
                FLAGS.test_image_root, FLAGS.test_dataset,
                FLAGS.batch_size * FLAGS.num_gpu, False)

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size * FLAGS.num_gpu, data_input.IMAGE_HEIGHT,
            data_input.IMAGE_WIDTH, 3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size * FLAGS.num_gpu])

        # Build model
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = map(int, [
            s * FLAGS.num_train_instance / FLAGS.batch_size / FLAGS.num_gpu
            for s in lr_decay_steps
        ])
        print('Learning rate decays at iter: %s' % str(lr_decay_steps))
        hp = resnet.HParams(num_gpu=FLAGS.num_gpu,
                            batch_size=FLAGS.batch_size,
                            split=FLAGS.split,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map)
        network = resnet.ResNet(hp, images, labels, global_step)
        if FLAGS.split:
            network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d params' % network._weights)
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        # saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        saver = tf.train.Saver(tf.all_variables(),
                               max_to_keep=10000,
                               write_version=tf.train.SaverDef.V2)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            ckpt_base = tf.train.get_checkpoint_state(FLAGS.baseline_dir)
            if ckpt_base and ckpt_base.model_checkpoint_path:
                # Check loadable variables (those with the same name and shape) and load only them
                print('No checkpoint file found. Start from the baseline.')
                loadable_vars = utils._get_loadable_vars(
                    ckpt_base.model_checkpoint_path, verbose=True)
                # saver_base = tf.train.Saver(loadable_vars)
                saver_base = tf.train.Saver(loadable_vars,
                                            write_version=tf.train.SaverDef.V2)
                saver_base.restore(sess, ckpt_base.model_checkpoint_path)
            else:
                print('No checkpoint file found. Starting from scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network.train_op, network.loss, network.acc, train_summary_op],
                             feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
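utils._get_loadable_vars is referenced above but not included in the excerpt. A sketch of what it most plausibly does, keeping only variables whose checkpoint entry matches in both name and shape so the partial restore cannot fail (hypothetical reconstruction):

def _get_loadable_vars(ckpt_path, verbose=False):
    # Keep only graph variables whose name and shape both match a tensor
    # stored in the checkpoint.
    reader = tf.train.NewCheckpointReader(ckpt_path)
    saved_shapes = reader.get_variable_to_shape_map()
    loadable_vars = []
    for var in tf.all_variables():  # deprecated API, matching the excerpt
        name = var.op.name
        if name in saved_shapes and var.get_shape().as_list() == saved_shapes[name]:
            loadable_vars.append(var)
        elif verbose:
            print('\tSkipping %s (missing or shape mismatch)' % name)
    return loadable_vars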
Code Example #8
def train():
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)


    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        num_threads = multiprocessing.cpu_count() / FLAGS.num_gpus
        print('Load ImageNet dataset (%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root, FLAGS.train_dataset, FLAGS.batch_size,
                    True, num_threads=num_threads, num_sets=FLAGS.num_gpus)
            # tf.summary.image('images', train_images[0])
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root, FLAGS.val_dataset, FLAGS.batch_size,
                    False, num_threads=num_threads, num_sets=FLAGS.num_gpus)

        # Get split params
        if not FLAGS.basemodel:
            print('No basemodel found to load split params')
            sys.exit(-1)
        else:
            print('Load split params from %s' % FLAGS.basemodel)

            def get_perms(q_name, ngroups):
                split_q = reader.get_tensor(q_name)
                q_amax = np.argmax(split_q, axis=0)
                return [np.where(q_amax == i)[0] for i in range(ngroups)]

            reader = tf.train.NewCheckpointReader(FLAGS.basemodel)
            split_params = {}

            print('\tlogits...')
            base_logits_w = reader.get_tensor('logits/fc/weights')
            base_logits_b = reader.get_tensor('logits/fc/biases')
            split_p1_idxs = get_perms('group/split_p1/q', FLAGS.ngroups1)
            split_q1_idxs = get_perms('group/split_q1/q', FLAGS.ngroups1)

            logits_params = {'weights':[], 'biases':[], 'input_perms':[], 'output_perms':[]}
            for i in range(FLAGS.ngroups1):
                logits_params['weights'].append(base_logits_w[split_p1_idxs[i], :][:, split_q1_idxs[i]])
                logits_params['biases'].append(base_logits_b[split_q1_idxs[i]])
            logits_params['input_perms'] = split_p1_idxs
            logits_params['output_perms'] = split_q1_idxs
            split_params['logits'] = logits_params

            if FLAGS.ngroups2 > 1:
                print('\tconv5_x...')
                base_conv5_1_shortcut_k = reader.get_tensor('conv5_1/shortcut/kernel')
                base_conv5_1_conv1_k = reader.get_tensor('conv5_1/conv_1/kernel')
                base_conv5_1_conv2_k = reader.get_tensor('conv5_1/conv_2/kernel')
                base_conv5_2_conv1_k = reader.get_tensor('conv5_2/conv_1/kernel')
                base_conv5_2_conv2_k = reader.get_tensor('conv5_2/conv_2/kernel')
                split_p2_idxs = get_perms('group/split_p2/q', FLAGS.ngroups2)
                split_q2_idxs = _merge_split_idxs(split_p1_idxs, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
                split_r21_idxs = get_perms('group/split_r21/q', FLAGS.ngroups2)
                split_r22_idxs = get_perms('group/split_r22/q', FLAGS.ngroups2)

                conv5_1_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]}
                for i in range(FLAGS.ngroups2):
                    conv5_1_params['shortcut'].append(base_conv5_1_shortcut_k[:,:,split_p2_idxs[i],:][:,:,:,split_q2_idxs[i]])
                    conv5_1_params['conv1'].append(base_conv5_1_conv1_k[:,:,split_p2_idxs[i],:][:,:,:,split_r21_idxs[i]])
                    conv5_1_params['conv2'].append(base_conv5_1_conv2_k[:,:,split_r21_idxs[i],:][:,:,:,split_q2_idxs[i]])
                conv5_1_params['p_perms'] = split_p2_idxs
                conv5_1_params['q_perms'] = split_q2_idxs
                conv5_1_params['r_perms'] = split_r21_idxs
                split_params['conv5_1'] = conv5_1_params

                conv5_2_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]}
                for i in range(FLAGS.ngroups2):
                    conv5_2_params['conv1'].append(base_conv5_2_conv1_k[:,:,split_q2_idxs[i],:][:,:,:,split_r22_idxs[i]])
                    conv5_2_params['conv2'].append(base_conv5_2_conv2_k[:,:,split_r22_idxs[i],:][:,:,:,split_q2_idxs[i]])
                conv5_2_params['p_perms'] = split_q2_idxs
                conv5_2_params['r_perms'] = split_r22_idxs
                split_params['conv5_2'] = conv5_2_params


                for i, unit_name in enumerate(['conv5_1', 'conv5_2', 'conv5_3', 'conv5_4', 'conv5_5', 'conv5_6']):
                    print('\t' + unit_name)
                    sp = {}
                    split_params[unit_name] = sp

            if FLAGS.ngroups3 > 1:
                print('\tconv4_x...')
                base_conv4_1_shortcut_k = reader.get_tensor('conv4_1/shortcut/kernel')
                base_conv4_1_conv1_k = reader.get_tensor('conv4_1/conv_1/kernel')
                base_conv4_1_conv2_k = reader.get_tensor('conv4_1/conv_2/kernel')
                base_conv4_2_conv1_k = reader.get_tensor('conv4_2/conv_1/kernel')
                base_conv4_2_conv2_k = reader.get_tensor('conv4_2/conv_2/kernel')
                split_p3_idxs = get_perms('group/split_p3/q', FLAGS.ngroups3)
                split_q3_idxs = _merge_split_idxs(split_p2_idxs, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
                split_r31_idxs = get_perms('group/split_r31/q', FLAGS.ngroups3)
                split_r32_idxs = get_perms('group/split_r32/q', FLAGS.ngroups3)

                conv4_1_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]}
                for i in range(FLAGS.ngroups3):
                    conv4_1_params['shortcut'].append(base_conv4_1_shortcut_k[:,:,split_p3_idxs[i],:][:,:,:,split_q3_idxs[i]])
                    conv4_1_params['conv1'].append(base_conv4_1_conv1_k[:,:,split_p3_idxs[i],:][:,:,:,split_r31_idxs[i]])
                    conv4_1_params['conv2'].append(base_conv4_1_conv2_k[:,:,split_r31_idxs[i],:][:,:,:,split_q3_idxs[i]])
                conv4_1_params['p_perms'] = split_p3_idxs
                conv4_1_params['q_perms'] = split_q3_idxs
                conv4_1_params['r_perms'] = split_r31_idxs
                split_params['conv4_1'] = conv4_1_params

                conv4_2_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]}
                for i in range(FLAGS.ngroups3):
                    conv4_2_params['conv1'].append(base_conv4_2_conv1_k[:,:,split_q3_idxs[i],:][:,:,:,split_r32_idxs[i]])
                    conv4_2_params['conv2'].append(base_conv4_2_conv2_k[:,:,split_r32_idxs[i],:][:,:,:,split_q3_idxs[i]])
                conv4_2_params['p_perms'] = split_q3_idxs
                conv4_2_params['r_perms'] = split_r32_idxs
                split_params['conv4_2'] = conv4_2_params


        # Build model
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = map(int, [s * FLAGS.num_train_instance / FLAGS.batch_size / FLAGS.num_gpus
                                   for s in lr_decay_steps])
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            split_params=split_params,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp, train_images, train_labels, global_step, name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp, val_images, val_labels, global_step, name="val", reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)


        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            # allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        '''debugging attempt
        from tensorflow.python import debug as tf_debug
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        def _get_data(datum, tensor):
            return tensor == train_images
        sess.add_tensor_filter("get_data", _get_data)
        '''

        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        elif FLAGS.basemodel:
            # Define a separate saver that restores only the basemodel parameters
            # Select only base variables (exclude split layers)
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [var for var in variables
                            if not "Momentum" in var.name and
                               not "logits" in var.name and
                               not "global_step" in var.name]
            if FLAGS.ngroups2 > 1:
                vars_restore = [var for var in vars_restore
                                if not "conv5_" in var.name]
            if FLAGS.ngroups3 > 1:
                vars_restore = [var for var in vars_restore
                                if not "conv4_" in var.name]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print('No checkpoint or basemodel found. Starting from scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
                                                sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run([network_val.loss, network_val.acc],
                                feed_dict={network_val.is_train:False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps, step)
            start_time = time.time()
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
                            feed_dict={network_train.is_train:True, network_train.lr:lr_value})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value, acc_value, lr_value,
                                    examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
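_merge_split_idxs and _get_even_merge_idxs are called above without being shown. From their call sites, _get_even_merge_idxs(num_old, num_new) must assign the num_old finer groups evenly to num_new coarser groups, and _merge_split_idxs must concatenate the index arrays of the old groups mapped to each new group. A hedged sketch of both (hypothetical, inferred from usage):

import numpy as np

def _get_even_merge_idxs(num_old, num_new):
    # Map each of num_old source groups to one of num_new target groups
    # as evenly as possible, e.g. (4, 2) -> [0, 0, 1, 1].
    return [i * num_new // num_old for i in range(num_old)]

def _merge_split_idxs(split_idxs, merge_idxs):
    # Concatenate the index arrays of all old groups assigned to the same
    # new group, yielding one sorted index array per merged group.
    num_new = max(merge_idxs) + 1
    merged = [[] for _ in range(num_new)]
    for old_group, new_group in zip(split_idxs, merge_idxs):
        merged[new_group].extend(old_group)
    return [np.sort(np.array(idxs)) for idxs in merged]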