예제 #1
0
def main(_):
    """Pick a device, build the ResNet hyper-parameters and run the chosen mode.

    Configuration is read from ``FLAGS.num_gpus``, ``FLAGS.mode`` and
    ``FLAGS.dataset``; the positional argument is ignored.

    Raises:
        ValueError: if ``FLAGS.num_gpus`` is not 0 or 1, if ``FLAGS.mode`` is
            not ``'train'``/``'eval'``, or if ``FLAGS.dataset`` is not
            ``'cifar10'``/``'cifar100'``.
    """
    # Device selection
    if FLAGS.num_gpus == 0:
        dev = '/cpu:0'
    elif FLAGS.num_gpus == 1:
        dev = '/gpu:0'
    else:
        raise ValueError('Only support 0 or 1 gpu.')

    # Execution mode
    if FLAGS.mode == 'train':
        batch_size = 1
    elif FLAGS.mode == 'eval':
        batch_size = 2
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # when the hyper-parameters are assembled below.
        raise ValueError('Unsupported mode: %r (expected train or eval).' %
                         (FLAGS.mode,))

    # Number of classes in the dataset
    if FLAGS.dataset == 'cifar10':
        num_classes = 10  # CIFAR-10 has ten classes
    elif FLAGS.dataset == 'cifar100':
        num_classes = 100
    else:
        raise ValueError(
            'Unsupported dataset: %r (expected cifar10 or cifar100).' %
            (FLAGS.dataset,))

    # Residual network model parameters
    hps = resnet.HParams(batch_size=batch_size,
                               num_classes=num_classes,
                               min_lrn_rate=0.0001,
                               lrn_rate=0.1,
                               num_residual_units=5,
                               use_bottleneck=False,
                               weight_decay_rate=0.0002,
                               relu_leakiness=0.1,
                               optimizer='mom')
    # Run training or evaluation on the selected device
    with tf.device(dev):
        if FLAGS.mode == 'train':
            train(hps)
        elif FLAGS.mode == 'eval':
            evaluate(hps)
예제 #2
0
def _ordinal_age_labels(ages, num_bins=88):
    """Encode integer ages as cumulative 0/1 vectors of length ``num_bins``.

    For an age ``a`` the first ``a`` entries of its row are 1 and the rest 0.
    An age of 0 is treated as 1 and ages above ``num_bins`` are capped.
    """
    rows = []
    for age in list(ages):
        row = np.zeros(shape=(num_bins), dtype=np.float32)
        if age == 0:
            age = 1
        if age > num_bins:
            age = num_bins
        row[:age] = 1
        rows.append(row)
    return np.array(rows)


def run_training(image_path, batch_size, epoch, model_path, log_dir, start_lr):
    """Train the WResNet age model with periodic evaluation and checkpoints.

    Args:
        image_path: record file path handed to ``inputs_data`` for the test
            input pipeline.
        batch_size: examples per batch; used for both input pipelines and for
            the placeholder shapes.
        epoch: ``num_epochs`` of the training input pipeline.
        model_path: directory checkpoints are restored from and saved to.
            Checkpoints taken at the best test step go to the ``test_good``
            sub-directory.
        log_dir: directory for the TensorBoard summary writer.
        start_lr: currently unused; the learning-rate schedule is configured
            through the hard-coded ``resnet.HParams`` values below.
    """
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Input images and labels.
        record_file_names = ['./record_save/train.tfrecords']
        train_images, train_labels = inputs_multifile_data(
            record_file_names=record_file_names,
            train=True,
            batch_size=batch_size,
            num_epochs=epoch)
        test_images, test_labels = inputs_data(record_file_path=image_path,
                                               train=False,
                                               batch_size=batch_size,
                                               num_epochs=None)

        images = tf.placeholder(tf.float32,
                                [batch_size, HEIGHT, WIDTH, CHANNEL])
        labels = tf.placeholder(tf.int32, [batch_size, 88])
        labels_test = tf.placeholder(tf.int32, [batch_size])

        # Load network.
        decay_step = 10000
        hp = resnet.HParams(
            batch_size=batch_size,
            num_classes=89,
            num_residual_units=2,
            k=4,
            weight_decay=0.0005,
            initial_lr=0.001,
            decay_step=decay_step,
            decay_rate=0.9,
            momentum=0.9,
            drop_prob=0.5)

        net = resnet.WResNet(hp, images, labels, labels_test, global_step)
        net.build_model()
        net.build_train_op()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # Single session for the whole run (a second, never-closed session
        # used to be created before the graph was built).
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=0.45)))
        train_writer = None
        coord = tf.train.Coordinator()
        threads = []
        try:
            sess.run(init_op)

            if not os.path.exists(log_dir):
                os.makedirs(log_dir)
            train_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # If you want to transfer weights from another model, comment out
            # the restore logic below.
            init_step = 0
            saver = tf.train.Saver(max_to_keep=10000)
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # The step is the numeric suffix of the checkpoint file name.
                init_step = int(
                    ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
                print('init_step: ', init_step)
                print("restore and continue training!")
            else:
                print(
                    'No checkpoint file found. No old saved network, start from the scratch.'
                )

            # Start input enqueue threads.
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Training
            test_best_acc = 0.0
            test_interval = 1000
            max_steps = 1000000
            test_good_step = 0
            display = 100

            for step in range(init_step, max_steps):
                # Test
                if step % test_interval == 0:
                    test_acc = 0.0
                    for _ in range(test_interval):
                        test_images_val, test_labels_val = sess.run(
                            [test_images, test_labels])
                        acc_value = sess.run(net.acc,
                                             feed_dict={
                                                 images: test_images_val,
                                                 labels_test: test_labels_val,
                                                 net.is_train: False
                                             })
                        test_acc += acc_value
                    test_acc /= test_interval

                    if test_best_acc < test_acc:
                        test_best_acc = test_acc
                        test_good_step = step
                    print('!!!!!! test_good_step: ', test_good_step)
                    format_str = ('%s: (Test)     step %d, acc=%.4f ')
                    print(format_str % (datetime.now(), step, test_acc))

                    test_summary = tf.Summary()
                    test_summary.value.add(tag='test/acc',
                                           simple_value=test_acc)
                    test_summary.value.add(tag='test/best_acc',
                                           simple_value=test_best_acc)
                    train_writer.add_summary(test_summary, step)

                    train_writer.flush()

                # Train
                show = step % display == 0
                start_time = time.time()
                train_images_val, train_labels_val = sess.run(
                    [train_images, train_labels])
                batch_labels = _ordinal_age_labels(train_labels_val)

                # Fetch the predictions in the same run on display steps so
                # the batch is not trained on twice just for printing.
                fetches = [net.train_op, net.lr, net.loss]
                if show:
                    fetches.append(net.preds)
                results = sess.run(fetches,
                                   feed_dict={
                                       images: train_images_val,
                                       labels: batch_labels,
                                       net.is_train: True
                                   })
                lr_value, train_loss = results[1], results[2]
                duration = time.time() - start_time

                assert not np.isnan(train_loss)

                # Display & Summary(training)
                if show:
                    num_examples_per_step = batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = (
                        '%s: (Training) step %d, loss=%.4f, lr=%f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), step, train_loss,
                                        lr_value, examples_per_sec,
                                        sec_per_batch))
                    print(results[3])

                # Save the model checkpoint periodically.
                if show or (step + 1) == max_steps:
                    if step == test_good_step:
                        save_dir = model_path + '/test_good'
                    else:
                        save_dir = model_path
                    # Make sure the target directory exists before saving.
                    if save_dir and not os.path.isdir(save_dir):
                        os.makedirs(save_dir)
                    saver.save(sess,
                               os.path.join(save_dir, 'model_age.ckpt'),
                               global_step=step)
        finally:
            # Ask the queue-runner threads to stop; joining without a stop
            # request can block forever on the endless test queue.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                if train_writer is not None:
                    train_writer.close()
                sess.close()
예제 #3
0
def train():
    """Train the clustered ResNet on CIFAR-100 using the settings in ``FLAGS``.

    Prints the configuration, loads the clustering from
    ``FLAGS.cluster_path``, builds the input pipelines and the network,
    restores the latest checkpoint found in ``FLAGS.train_dir`` (if any) and
    then runs the train / test / checkpoint loop until ``FLAGS.max_steps``.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)
    print('\tClustering file: %s' % FLAGS.cluster_path)
    print('\tWrong logit map: %d' % FLAGS.no_logit_map)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs to step down lr: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)
    print('\tLearning rate multiplier for split net: %f' % FLAGS.split_lr_mult)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # Pickle data is binary, so the file must be opened in binary mode.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(FLAGS.cluster_path, 'rb') as fd:
        clustering = pickle.load(fd)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        print('Load CIFAR-100 dataset')
        train_dataset_path = os.path.join(FLAGS.data_dir, 'train')
        test_dataset_path = os.path.join(FLAGS.data_dir, 'val')
        print('\tLoading training data from %s' % train_dataset_path)
        with tf.variable_scope('train_image'):
            cifar100_train = cifar100.CIFAR100Runner(train_dataset_path,
                                                     shuffle=True,
                                                     distort=True,
                                                     capacity=10000)
            train_images, train_labels = cifar100_train.get_inputs(
                FLAGS.batch_size)
        print('\tLoading test data from %s' % test_dataset_path)
        with tf.variable_scope('test_image'):
            cifar100_test = cifar100.CIFAR100Runner(test_dataset_path,
                                                    shuffle=False,
                                                    distort=False,
                                                    capacity=5000)
            test_images, test_labels = cifar100_test.get_inputs(
                FLAGS.batch_size)

        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        print('Build model')
        # Convert the comma-separated epoch boundaries into step boundaries.
        # Kept as a real list because it is passed to get_lr on every step.
        lr_decay_epochs = [float(e) for e in FLAGS.lr_step_epoch.split(',')]
        lr_decay_steps = [
            int(e * FLAGS.num_train_instance / FLAGS.batch_size)
            for e in lr_decay_epochs
        ]
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map,
                            split_lr_mult=FLAGS.split_lr_mult)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d weights' % sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint; the step is the numeric suffix of the
            # checkpoint file name.
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        cifar100_train.start_threads(sess, n_threads=10)
        cifar100_test.start_threads(sess, n_threads=1)

        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network.train_op, network.loss, network.acc, train_summary_op],
                             feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        # Flush any pending training summaries to disk.
        summary_writer.close()
예제 #4
0
def train():
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    sess = tf.Session()

    global_step = tf.Variable(0, trainable=False, name='global_step')

    FLAGS.test_dataset = "./val.txt"
    # Get images and labels of ImageNet
    print('Load ImageNet dataset')
    with tf.device('/cpu:0'):
        print('\tLoading test data from %s' % FLAGS.test_dataset)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(FLAGS.test_image_root,
                                                         FLAGS.test_dataset,
                                                         FLAGS.batch_size,
                                                         False,
                                                         num_threads=1,
                                                         center_crop=True)

    # Build a Graph that computes the predictions from the inference model.
    images = tf.placeholder(
        tf.float32,
        [FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH, 3])
    labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

    #        images = tf.placeholder(tf.float32, [1, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH, 3])
    #        labels = tf.placeholder(tf.int32, [1])

    # Build model
    with tf.device('/cpu:0'):
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=1,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
    network = resnet.ResNet(hp, [images], [labels], global_step)
    network.build_model()
    print('\tNumber of Weights: %d' % network._weights)
    print('\tFLOPs: %d' % network._flops)

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(
        config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
                              allow_soft_placement=True,
                              log_device_placement=FLAGS.log_device_placement))

    sess.run(init)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
    if FLAGS.checkpoint is not None:
        saver.restore(sess, FLAGS.checkpoint)
        print('Load checkpoint %s' % FLAGS.checkpoint)
    else:
        print('No checkpoint file of basemodel found. Start from the scratch.')

    # Start queue runners & summary_writer
    tf.train.start_queue_runners(sess=sess)

    "============================================================================================================"
    "=================  Begin to insert restriction on selective layers  ========================================"
    "============================================================================================================"
    "NOTE: this model requires GPU (otherwise it'll report error while restoring the variables from the old graph to the new graph)"

    # get all the operators in the graph
    ops = [
        tensor for op in sess.graph.get_operations() for tensor in op.values()
    ]
    graph_def = sess.graph.as_graph_def()

    def get_op_dependency(op):
        """Append every input tensor reachable from ``op`` to ``resnet-op.txt``.

        The graph is walked level by level: for each op of the current level
        all of its input tensors are written (one per line, followed by a
        blank separator when the op had inputs) and their producing ops form
        the next level. Ops reachable through several paths are visited once
        per path.
        """
        cur_op = [op]
        # ``with`` closes the log file even if the traversal raises.
        with open("resnet-op.txt", "a") as a:
            while cur_op:
                next_op = []
                for each in cur_op:
                    printline = False
                    for inp in each.inputs:
                        printline = True
                        a.write(str(inp) + "\n")
                        next_op.append(inp.op)
                    if printline:
                        a.write("\n\n")
                cur_op = next_op

    def get_target_scope_prefix(scope_name, dup_cnt, dummy_scope_name,
                                dummy_graph_dup_cnt):
        """Return the scope prefix of the latest duplicated path.

        With no duplication the prefix is empty. Each duplication prepends
        ``scope_name + "/"``; dummy scopes are interleaved in front of the
        scope segments while ``dummy_graph_dup_cnt`` lasts, e.g.
        ``dummy/ranger/dummy/ranger/``. For a single duplication the dummy
        scope is added only when ``dummy_graph_dup_cnt`` is exactly 1.
        """
        if dup_cnt == 0:
            return ""

        if dup_cnt == 1:
            single = str(scope_name + "/")  # e.g., ranger/relu:0
            if dummy_graph_dup_cnt == 1:
                return dummy_scope_name + "/" + single  # e.g., dummy/ranger/relu:0
            return single

        dummies_left = dummy_graph_dup_cnt
        prefix = str(scope_name + "/")
        if dummies_left > 0:  # e.g., dummy/ranger/relu:0
            prefix = dummy_scope_name + "/" + prefix
            dummies_left -= 1

        # One additional scope segment per further duplication.
        for _ in range(1, dup_cnt):
            prefix = scope_name + "/" + prefix  # e.g., ranger/dummy/ranger/relu
            if dummies_left > 0:
                prefix = dummy_scope_name + "/" + prefix
                dummies_left -= 1

        return prefix

    def restore_all_var(sess, scope_name, dup_cnt, all_var, dummy_scope_name,
                        dummy_graph_dup_cnt, OLD_SESS):
        """Copy variable values from ``OLD_SESS`` into the re-scoped graph.

        For every variable in ``all_var`` the value is read from ``OLD_SESS``
        under its original name and assigned in ``sess`` to the tensor of the
        same name prefixed with the latest path's scope prefix.
        """
        prefix = get_target_scope_prefix(scope_name, dup_cnt,
                                         dummy_scope_name,
                                         dummy_graph_dup_cnt)

        for var in all_var:
            destination = sess.graph.get_tensor_by_name(prefix + var.name)
            source = OLD_SESS.graph.get_tensor_by_name(var.name)
            value = OLD_SESS.run(source)
            sess.run(tf.assign(destination, value))

    def get_op_with_prefix(op_name, dup_cnt, scope_name, dummy_graph_dup_cnt,
                           dummy_scope_name):
        """Return ``op_name`` as it is named under the new (duplicated) graph.

        Every duplication adds a scope prefix, so the name is ``op_name``
        preceded by the scope prefix of the latest duplicated path.
        """
        prefix = get_target_scope_prefix(scope_name, dup_cnt,
                                         dummy_scope_name,
                                         dummy_graph_dup_cnt)
        return prefix + op_name

    import re

    def modify_graph(sess, dup_cnt, scope_name, prefix_of_bound_op,
                     dummy_graph_dup_cnt, dummy_scope_name):
        """Return a pruned GraphDef that only contains the latest duplicated path.

        Two things are done to the session's graph definition:
        1) nodes from older paths are dropped (only the latest path is kept,
           otherwise the graph size explodes), and
        2) the inputs of the restriction ops (nodes whose name contains
           ``prefix_of_bound_op``) are rewired to the latest path.

        Reads ``graph_dup_cnt`` from the enclosing scope.
        """
        graph_def = sess.graph.as_graph_def()

        target_graph_prefix = get_target_scope_prefix(scope_name, dup_cnt,
                                                      dummy_scope_name,
                                                      dummy_graph_dup_cnt)
        nested_dummy = dummy_scope_name + "/" + dummy_scope_name + "/"

        # Delete nodes from the redundant paths; we only want the most recent
        # path.
        nodes = []
        for node in graph_def.node:
            is_bound_op = prefix_of_bound_op in node.name
            if target_graph_prefix in node.name and not is_bound_op:
                nodes.append(node)  # regular op on the latest path
            elif is_bound_op:
                if dup_cnt != graph_dup_cnt:
                    # Keep the new op from the most recent duplication (with
                    # the lesser prefix); drop dummy nodes like dummy/op.
                    keep = target_graph_prefix not in node.name
                else:
                    keep = True

                # Drop dummy nodes like dummy/dummy/relu. Deciding before
                # appending avoids calling list.remove() on a node that was
                # never added, which raises ValueError.
                if nested_dummy in node.name:
                    keep = False

                if keep:
                    nodes.append(node)

        mod_graph_def = tf.GraphDef()
        mod_graph_def.node.extend(nodes)

        # Matches scope segments such as "ranger_1/", "ranger/", "dummy_2/"
        # or "dummy/". Built once because it does not change per node.
        scope_pattern = re.compile(
            re.escape(scope_name) + r"_\d+/|" + re.escape(scope_name) + "/|" +
            re.escape(dummy_scope_name) + r"_\d+/|" +
            re.escape(dummy_scope_name) + "/")

        # The restriction ops, e.g. tf.maximum(relu_1, 100), were copied from
        # the PREVIOUS graph and still depend on it. Those nodes were pruned
        # above, so their inputs must point at the latest path instead
        # (e.g. test/x3 becomes test_1/test/x3).
        for node in mod_graph_def.node:
            if prefix_of_bound_op not in node.name:
                continue  # only restriction ops are rewired

            inp_names = []
            for inp in node.input:
                if prefix_of_bound_op in inp or target_graph_prefix in inp:
                    inp_names.append(inp)
                elif scope_name in inp:
                    # Strip every old scope segment, then apply the prefix of
                    # the latest path.
                    inp_names.append(target_graph_prefix +
                                     scope_pattern.sub("", inp))
                else:
                    inp_names.append(target_graph_prefix + inp)

            del node.input[:]  # delete all the inputs
            node.input.extend(inp_names)  # keep the modified inputs

        return mod_graph_def

    def printgraphdef(graphdef):
        """Print the name of every node in ``graphdef``, one name per line."""
        node_names = [node.name for node in graphdef.node]
        for node_name in node_names:
            print(node_name)

    def printgraph(sess):
        """Dump the name of every tensor in ``sess.graph`` to ``op.txt``.

        The names of all output tensors of all operations returned by
        ``sess.graph.get_operations()`` are written one per line, in graph
        order. ``op.txt`` is created in the current working directory and
        overwritten if it already exists.
        """
        tensor_names = [
            tensor.name for op in sess.graph.get_operations()
            for tensor in op.values()
        ]
        # The context manager guarantees the file is flushed and closed even
        # if a write fails; the handle was previously left open.
        with open("op.txt", "w") as out_file:
            for tensor_name in tensor_names:
                out_file.write(tensor_name + "\n")

    # in resenet-18, Relu is renamed as relu
    act = "relu"
    op_follow_act = ["MaxPool", "Reshape", "AvgPool"]
    special_op_follow_act = "concat"
    up_bound = map(float,
                   [7, 8, 7, 5, 11, 5, 12, 6, 11, 5, 12, 5, 14, 5, 12, 5, 66
                    ])  # upper bound for restriction
    low_bound = map(float, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
                            ])  # low bound for restriction

    PREFIX = 'ranger'  # scope name in the graph
    DUMMY_PREFIX = 'dummy'  #
    graph_dup_cnt = 0  # count the number of iteration for duplication, used to track the scope prefix of the new op
    dummy_graph_dup_cnt = 0  # count the num of dummy graph duplication (for resetting the default graph to contain only the latest path)

    op_cnt = 0  # count num of op
    act_cnt = 0  # count num of act
    check_follow = False  # flag for checking the following op (when the current op is ACT)
    op_to_keep = [
    ]  # ops to keep while duplicating the graph (we remove the irrelevant ops before duplication, otherwise the graph size will explode)
    new_op_prefix = "bound_op_prefix"  # prefix of the newly created ops for range restriction
    OLD_SESS = sess  # keep the old session
    all_var = tf.global_variables()  # all vars before duplication

    # get all the operators in the graph
    ops = [
        tensor for op in sess.graph.get_operations() for tensor in op.values()
    ]
    graph_def = sess.graph.as_graph_def()

    "iterate each op in the graph and insert bounding ops"
    for cur_op in ops:

        if (act in cur_op.name and ("gradients" not in cur_op.name)):
            # bounding
            with tf.name_scope(
                    new_op_prefix
            ) as scope:  # the restricion ops will have the special scope prefix name
                bound_tensor = sess.graph.get_tensor_by_name(
                    get_op_with_prefix(cur_op.name, graph_dup_cnt, PREFIX,
                                       dummy_graph_dup_cnt, DUMMY_PREFIX))
                print("bounding: ", bound_tensor, up_bound[act_cnt])
                rest = tf.maximum(bound_tensor, low_bound[act_cnt])
                rest = tf.minimum(rest, up_bound[act_cnt])

            op_to_be_replaced = get_op_with_prefix(cur_op.name, graph_dup_cnt,
                                                   PREFIX, dummy_graph_dup_cnt,
                                                   DUMMY_PREFIX)

            # delete redundant paths in graphdef and modify the input dependency to be depending on the latest path only
            truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                              new_op_prefix,
                                              dummy_graph_dup_cnt,
                                              DUMMY_PREFIX)
            # import the modified graghdef (inserted with bouding ops) into the current graph
            tf.import_graph_def(truncated_graphdef,
                                name=PREFIX,
                                input_map={op_to_be_replaced: rest})
            graph_dup_cnt += 1

            "reset the graph to contain only the duplicated path"
            truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                              new_op_prefix,
                                              dummy_graph_dup_cnt,
                                              DUMMY_PREFIX)
            tf.reset_default_graph()
            sess = tf.Session()
            sess.as_default()
            tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
            dummy_graph_dup_cnt += 1

            check_follow = True  # this is a ACT, so we need to check the following op
            act_cnt = (act_cnt + 1) % len(
                up_bound
            )  # count the number of visited ACT (used for the case where there are two copies of ops, one for training and one testing)

        # this will check the next operator that follows the ACT op
        elif (check_follow):
            keep_rest = False  # check whether the following op needs to be bounded

            # this is the case for Maxpool, Avgpool and Reshape
            for each in op_follow_act:
                if (
                        each in cur_op.name and "/shape" not in cur_op.name
                ):  #the latter condition is for checking case like "Reshape_1/shape:0", this shouldn't be bounded
                    keep_rest = True
                    low = low_bound[act_cnt - 1]
                    up = up_bound[act_cnt - 1]
                    break
            # this is the case for ConCatV2, "axis" is the parameter to the actual op concat
            if (special_op_follow_act in cur_op.name
                    and ("axis" not in cur_op.name)
                    and ("values" not in cur_op.name)):
                keep_rest = True
                low = np.minimum(low_bound[act_cnt - 1],
                                 low_bound[act_cnt - 2])
                up = np.maximum(up_bound[act_cnt - 1], up_bound[act_cnt - 2])

            "bound the values, using either float (default) or int"
            if (keep_rest):
                try:
                    with tf.name_scope(
                            new_op_prefix
                    ) as scope:  # the restricion ops will have the special scope prefix name
                        bound_tensor = sess.graph.get_tensor_by_name(
                            get_op_with_prefix(cur_op.name, graph_dup_cnt,
                                               PREFIX, dummy_graph_dup_cnt,
                                               DUMMY_PREFIX))
                        print("bounding: ", bound_tensor)
                        rest = tf.maximum(bound_tensor, low)
                        rest = tf.minimum(rest, up)
                except:
                    with tf.name_scope(
                            new_op_prefix
                    ) as scope:  # the restricion ops will have the special scope prefix name
                        bound_tensor = sess.graph.get_tensor_by_name(
                            get_op_with_prefix(cur_op.name, graph_dup_cnt,
                                               PREFIX, dummy_graph_dup_cnt,
                                               DUMMY_PREFIX))
                        print("bounding: ", bound_tensor)
                        rest = tf.maximum(bound_tensor, int(low))
                        rest = tf.minimum(rest, int(up))
                #print(cur_op, act_cnt)
                #print(rest.op.node_def,' -----')
                "replace the input to the tensor, at the palce where we place Ranger, e.g., Ranger(ReLu), then we replace Relu"
                op_to_be_replaced = get_op_with_prefix(cur_op.name,
                                                       graph_dup_cnt, PREFIX,
                                                       dummy_graph_dup_cnt,
                                                       DUMMY_PREFIX)

                truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                                  new_op_prefix,
                                                  dummy_graph_dup_cnt,
                                                  DUMMY_PREFIX)
                tf.import_graph_def(truncated_graphdef,
                                    name=PREFIX,
                                    input_map={op_to_be_replaced: rest})
                graph_dup_cnt += 1

                "reset the graph to contain only the duplicated path"
                truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                                  new_op_prefix,
                                                  dummy_graph_dup_cnt,
                                                  DUMMY_PREFIX)
                tf.reset_default_graph()
                sess = tf.Session()
                sess.as_default()
                tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
                dummy_graph_dup_cnt += 1

            # check the ops, but not to bound the ops
            else:
                check_follow = False  # the default setting is not to check the next op

                # the following ops of the listed operaions will be kept tracking,
                # becuase the listed ops do not perform actual computation, so the restriction bound still applies
                oblivious_ops = [
                    "Const", "truncated_normal", "Variable", "weights",
                    "biases", "dropout"
                ]
                if( ("Reshape" in cur_op.name and "/shape" in cur_op.name) or \
                    ("concat" in cur_op.name and ("axis" in cur_op.name or "values" in cur_op.name) )
                   ):
                    check_follow = True  # we need to check the following op of Reshape/shape:0, concat/axis (these are not the actual reshape/concat ops)
                else:
                    for ea in oblivious_ops:  # we need to check the op follows the listed ops
                        if (ea in cur_op.name):
                            check_follow = True

        op_cnt += 1

    # we need to call modify_graph to modify the input dependency for finalization
    truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                      new_op_prefix, dummy_graph_dup_cnt,
                                      DUMMY_PREFIX)
    tf.import_graph_def(truncated_graphdef, name=PREFIX)
    graph_dup_cnt += 1
    # restore the variables to the latest path

    truncated_graphdef = modify_graph(sess, graph_dup_cnt, PREFIX,
                                      new_op_prefix, dummy_graph_dup_cnt,
                                      DUMMY_PREFIX)
    tf.reset_default_graph()
    sess = tf.Session()
    sess.as_default()
    #printgraphdef(truncated_graphdef)
    tf.import_graph_def(truncated_graphdef, name=DUMMY_PREFIX)
    dummy_graph_dup_cnt += 1

    "restore all the variables from the orignial garph to the new graph"
    restore_all_var(sess, PREFIX, graph_dup_cnt, all_var, DUMMY_PREFIX,
                    dummy_graph_dup_cnt, OLD_SESS)
    #    printgraph(sess)

    print("Finish graph modification!")
    print('')
    "============================================================================================================"
    "============================================================================================================"

    "This is the name of the operator to be evaluated, we will find the corresponding one under the Ranger's scope"
    OP_FOR_EVAL = network.probs
    new_op_for_eval_name = get_op_with_prefix(OP_FOR_EVAL.op.name,
                                              graph_dup_cnt, PREFIX,
                                              dummy_graph_dup_cnt,
                                              DUMMY_PREFIX)
    print(new_op_for_eval_name, 'op to be eval')
    new_op_for_eval = sess.graph.get_tensor_by_name(new_op_for_eval_name +
                                                    ":0")

    # you can call this function to check the depenency of the final operator
    # you should see the bouding ops are inserted into the dependency
    # NOTE: the printing might contain duplicated output
    #get_op_dependency(new_op_for_eval.op)

    # input to eval the results
    for i in range(2):
        test_images_val, test_labels_val = OLD_SESS.run(
            [test_images[0], test_labels[0]])

    # evaluation on the old path
    preds = OLD_SESS.run(OP_FOR_EVAL,
                         feed_dict={
                             network.is_train: False,
                             images: test_images_val,
                             labels: test_labels_val
                         })
    print((np.argsort(np.asarray(preds)[0])[::-1])[0:10])
    print('')

    # evaluation on the new path
    new_x = sess.graph.get_tensor_by_name(
        get_op_with_prefix(images.op.name, graph_dup_cnt, PREFIX,
                           dummy_graph_dup_cnt, DUMMY_PREFIX) + ":0")
    new_y = sess.graph.get_tensor_by_name(
        get_op_with_prefix(labels.op.name, graph_dup_cnt, PREFIX,
                           dummy_graph_dup_cnt, DUMMY_PREFIX) + ":0")
    new_is_train = sess.graph.get_tensor_by_name(
        get_op_with_prefix(network.is_train.op.name, graph_dup_cnt, PREFIX,
                           dummy_graph_dup_cnt, DUMMY_PREFIX) + ":0")
    #new_prob2 = sess.graph.get_tensor_by_name(  get_op_with_prefix(model.prob2.op.name, graph_dup_cnt, PREFIX, dummy_graph_dup_cnt, DUMMY_PREFIX)+":0")

    preds = sess.run(new_op_for_eval,
                     feed_dict={
                         new_is_train: False,
                         new_x: test_images_val,
                         new_y: test_labels_val
                     })
    print((np.argsort(np.asarray(preds)[0])[::-1])[0:10])
예제 #5
0
def train():
    """Train a ResNet on ImageNet with periodic validation and checkpointing.

    All configuration is read from the module-level ``FLAGS`` object. The
    function prints the configuration, builds the input pipelines and the
    train/validation networks in a fresh graph, optionally restores weights
    from ``FLAGS.checkpoint`` or ``FLAGS.basemodel``, and then runs the
    training loop from the restored step up to ``FLAGS.max_steps``.

    Typing ``b`` on stdin during training drops into an interactive shell via
    ``embed()``.

    Raises:
        ZeroDivisionError: if ``FLAGS.num_gpus`` is 0 (it is used as a
            divisor when sizing the input threads and the decay schedule).
        AssertionError: if the training loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Floor division keeps the per-GPU thread count an integer on
        # Python 3 (true division would hand a float to the input pipeline);
        # never go below one thread.
        num_threads = max(1, multiprocessing.cpu_count() // FLAGS.num_gpus)
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            # Summarise the first two entries of the first training image set.
            tf.summary.image('images', train_images[0][:2])

        # Build model
        # FLAGS.lr_step_epoch is a comma-separated list of epoch counts;
        # convert each one to the global step at which the decay applies.
        lr_decay_steps = [
            int(
                float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus) for epoch in FLAGS.lr_step_epoch.split(',')
        ]
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp,
                                      train_images,
                                      train_labels,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        # Collect summaries now, before the validation network is built, so
        # only the ops created so far are merged.
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        network_val = resnet.ResNet(hp,
                                    val_images,
                                    val_labels,
                                    global_step,
                                    name="val",
                                    reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            # Resume: restore every variable and continue from the saved step.
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Fine-tune: restore everything except the optimizer's Momentum
            # slots and the step counter, so training starts at step 0.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            vars_restore = [
                var for var in tf.global_variables()
                if "Momentum" not in var.name
                and "global_step" not in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        # makedirs (rather than mkdir) also creates missing parent directories.
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for _ in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            _, loss_value, acc_value, train_summary_str = sess.run(
                [
                    network_train.train_op, network_train.loss,
                    network_train.acc, train_summary_op
                ],
                feed_dict={
                    network_train.is_train: True,
                    network_train.lr: lr_value
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Non-blocking poll of stdin: typing 'b' opens an interactive
            # shell without stopping the process.
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
예제 #6
0
def train():
    """Train, test, or run a custom inference pass of the regression ResNet.

    Behaviour is selected by module-level ``FLAGS``:

    * ``FLAGS.is_CustomTest``: run the network on ``Input/gx.npy`` and save
      the predicted parameters to ``tmp/gx.txt``.
    * ``not FLAGS.is_Train``: run 10 test batches, saving the first
      prediction of each batch to ``tmp/<i>.txt`` and the matching input
      image to ``tmp/<i>.jpg``.
    * otherwise: run the training loop up to ``FLAGS.max_steps``, saving
      checkpoints to ``FLAGS.train_dir``.

    In every mode the latest checkpoint in ``FLAGS.train_dir`` is restored
    when one exists; if none exists the freshly initialised weights are used.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """

    def Load():
        """Load the train/test arrays and normalisation statistics from ``Input/``.

        In simple or non-training mode the test set is returned in place of
        the training set; otherwise the training set is returned shuffled.
        """
        if (FLAGS.is_Simple or not FLAGS.is_Train):
            train_images = np.load("Input/test_data.npy")
            train_labels = np.load("Input/test_label.npy")
        else:
            train_data = np.load("Input/train_data.npy")
            train_label = np.load("Input/train_label.npy")
            # One shared permutation keeps images and labels aligned.
            permutation = np.random.permutation(train_data.shape[0])
            train_images = train_data[permutation, :, :, :]
            train_labels = train_label[permutation, :]
        test_data = np.load("Input/test_data.npy")
        test_label = np.load("Input/test_label.npy")
        mean_data = np.array(np.load("Input/mean_data.npy"), dtype=np.float16)
        mean_label = np.load("Input/mean_label.npy")
        std_label = np.load("Input/std_label.npy")

        return train_images, train_labels, test_data, test_label, mean_data, mean_label, std_label

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        X = tf.placeholder(tf.float32, [None, 224, 224, 3],
                           name='Input_Images')
        SHAPE = tf.placeholder(tf.float32, [None, 100], name='SHAPE')
        EXP = tf.placeholder(tf.float32, [None, 79], name='EXP')
        EULAR = tf.placeholder(tf.float32, [None, 3], name='EULAR')
        T = tf.placeholder(tf.float32, [None, 2], name='T')
        S = tf.placeholder(tf.float32, [None], name='S')

        def label_feed(batch_labels):
            """Split the packed label matrix (columns 0-184) across the label placeholders."""
            return {
                SHAPE: batch_labels[:, :100],
                EXP: batch_labels[:, 100:179],
                EULAR: batch_labels[:, 179:182],
                T: batch_labels[:, 182:184],
                S: batch_labels[:, 184],
            }

        # Build model
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_output=FLAGS.dim_output,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)

        network_train = resnet.ResNet(hp,
                                      X,
                                      SHAPE,
                                      EXP,
                                      EULAR,
                                      T,
                                      S,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        print("sess 0")
        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement))
        print("sess 1")
        sess.run(init)
        print("sess done")
        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)

        # Restore the latest checkpoint from the training directory, if any.
        # The same logic applies to both training and test mode. The "Load
        # checkpoint" message is printed only after a successful restore; it
        # used to run unconditionally and crashed on an unbound name when the
        # directory held no checkpoint.
        checkpoint_dir = FLAGS.train_dir
        checkpoints = tf.train.get_checkpoint_state(checkpoint_dir)
        if checkpoints and checkpoints.model_checkpoint_path:
            checkpoints_name = os.path.basename(
                checkpoints.model_checkpoint_path)
            saver.restore(sess,
                          os.path.join(checkpoint_dir, checkpoints_name))
            print('Load checkpoint %s' % checkpoints_name)
        else:
            print('No checkpoint found in %s. Start from the scratch.' %
                  checkpoint_dir)
        init_step = global_step.eval(session=sess)

        # Training!
        train_images, train_labels, test_data, test_label, mean_data, mean_label, std_label = Load(
        )
        one_epoch_step = len(train_labels) // FLAGS.batch_size
        # Inputs are mean-centred and scaled by 255; labels are fed as loaded
        # (mean_label / std_label are not applied here).
        train_images = (train_images - mean_data) / 255.0
        test_data = (test_data - mean_data) / 255.0
        print("data done")

        # Prediction heads fetched in both inference modes.
        logit_fetches = [
            network_train.shape_logits,
            network_train.exp_logits,
            network_train.eular_logits,
            network_train.t_logits,
            network_train.s_logits,
        ]

        if FLAGS.is_CustomTest:
            batch_data = (np.load('Input/gx.npy') - mean_data) / 255.0
            # Labels are not needed for inference; feed zeros of the right width.
            batch_labels = np.zeros((len(batch_data), 185))
            tmp = np.zeros((len(batch_data), 185))
            feed = {network_train.is_train: False, X: batch_data}
            feed.update(label_feed(batch_labels))
            shape_logits, exp_logits, eular_logits, t_logits, s_logits = sess.run(
                logit_fetches, feed_dict=feed)
            # NOTE(review): exp_logits is stored in columns 0-99 and
            # shape_logits in 100-178, the reverse of the label layout used
            # by the SHAPE/EXP placeholders. Kept as-is; confirm against the
            # network's output widths.
            tmp[:, 0:100] = np.array(exp_logits)
            tmp[:, 100:179] = np.array(shape_logits)
            tmp[:, 179:182] = np.array(eular_logits)
            tmp[:, 182:184] = np.array(t_logits)
            tmp[:, 184][:, None] = np.array(s_logits)
            np.savetxt("tmp/gx.txt", tmp)

        elif not FLAGS.is_Train:

            max_iteration = len(test_label) // FLAGS.batch_size
            print("max iteration is " + str(max_iteration))
            tmp = np.zeros([185])
            # Only the first 10 batches are evaluated, regardless of
            # max_iteration.
            for i in range(10):
                print(i)
                offset = (i * FLAGS.batch_size) % (test_data.shape[0] -
                                                   FLAGS.batch_size)
                batch_data = test_data[offset:(offset +
                                               FLAGS.batch_size), :, :, :]
                batch_labels = test_label[offset:(offset +
                                                  FLAGS.batch_size), :]

                feed = {network_train.is_train: False, X: batch_data}
                feed.update(label_feed(batch_labels))
                shape_logits, exp_logits, eular_logits, t_logits, s_logits = sess.run(
                    logit_fetches, feed_dict=feed)
                # Keep only the first sample of the batch.
                # NOTE(review): same exp/shape column order as the custom
                # test branch above; confirm it is intended.
                tmp[0:100] = np.array(exp_logits[0, :])
                tmp[100:179] = np.array(shape_logits[0, :])
                tmp[179:182] = np.array(eular_logits[0, :])
                tmp[182:184] = np.array(t_logits[0, :])
                tmp[184] = np.array(s_logits[0, :])

                np.savetxt("tmp/" + str(i) + ".txt", tmp)
                # Undo the input normalisation to recover the image for saving.
                fig = np.array((batch_data[0, :, :, :] * 255 + mean_data),
                               dtype=np.uint8)
                cv2.imwrite("tmp/" + str(i) + ".jpg", fig)

        else:

            for step in range(init_step, FLAGS.max_steps):

                # Walk through the training set in consecutive batches,
                # wrapping around before the last partial batch.
                offset = (step * FLAGS.batch_size) % (train_labels.shape[0] -
                                                      FLAGS.batch_size)
                batch_data = train_images[offset:(offset +
                                                  FLAGS.batch_size), :, :, :]
                batch_labels = train_labels[offset:(offset +
                                                    FLAGS.batch_size), :]
                # Train
                lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay,
                                  one_epoch_step, step)
                start_time = time.time()
                feed = {
                    network_train.is_train: True,
                    network_train.lr: lr_value,
                    X: batch_data,
                }
                feed.update(label_feed(batch_labels))
                _, loss_value, shape_loss, exp_loss, eular_loss, t_loss, s_loss, points_loss, geo_loss, pose_loss = sess.run(
                    [
                        network_train.train_op, network_train.loss,
                        network_train.shape_loss, network_train.exp_loss,
                        network_train.eular_loss, network_train.t_loss,
                        network_train.s_loss, network_train.points_loss,
                        network_train.geo_loss, network_train.pose_loss
                    ],
                    feed_dict=feed)

                duration = time.time() - start_time

                assert not np.isnan(loss_value)

                # Display & Summary(training)
                if step % FLAGS.display == 0 or step < 10:
                    num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = (
                        '%s: (Training) step %d, loss=%.4f, lr=%f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.now(), step, loss_value, lr_value,
                           examples_per_sec, sec_per_batch))

                    format_str = (
                        'shape_loss=%.4f, exp_loss=%.4f,eular_loss=%.4f,t_loss=%.4f,s_loss=%.4f,points_loss=%.4f,geo_loss=%.4f,pose_loss=%.4f'
                    )
                    print(format_str %
                          (shape_loss, exp_loss, eular_loss, t_loss, s_loss,
                           points_loss, geo_loss, pose_loss))
                    # Estimate the remaining time from the cost of this step.
                    elapse = time.time() - start_time
                    time_left = (FLAGS.max_steps - step) * elapse
                    print("\tTime left: %02d:%02d:%02d" %
                          (int(time_left / 3600), int(
                              time_left % 3600 / 60), time_left % 60))

                # Save the model checkpoint periodically.
                if (step > init_step and step % FLAGS.checkpoint_interval
                        == 0) or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
예제 #7
0
def train():
    """Shrink a trained ResNet by clustering its kernels, then fine-tune it.

    The function works in two phases, each in its own ``tf.Graph``:

    1. Build the original network, restore it from ``FLAGS.ckpt_path``
       (either a checkpoint directory or a checkpoint file) and, for every
       residual unit, derive replacement parameters with ``cluster_kernel``,
       ``add_kernel`` and ``cluster_batch_norm``.  A list of
       ``(value, variable_name)`` pairs covering every global variable is
       collected; variables whose name matches ``UPDATE_PARAM_REGEX`` (and not
       ``SKIP_PARAM_REGEX``) receive the derived values, all others keep their
       restored values.
    2. Build a new network from those initial values with width multiplier
       ``FLAGS.new_k`` and run the train / test / checkpoint loop, resuming
       from the latest checkpoint in ``FLAGS.train_dir`` when one exists.

    All configuration is read from the module-level ``FLAGS``.  The process
    exits with status 1 when no checkpoint can be found at
    ``FLAGS.ckpt_path``.  An ``AssertionError`` is raised if the training loss
    becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # ------------------------------------------------------------------
    # Phase 1: restore the original network and derive the new parameters.
    # ------------------------------------------------------------------
    with tf.Graph().as_default():

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None)
        network.build_model()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore the trained weights.  FLAGS.ckpt_path may
        # name either a checkpoint directory or a single checkpoint file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            # Restores from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Collect, per residual unit, the values fetched by get_kernel /
        # get_batch_norm.  Units are appended group by group, so the flat list
        # index of a unit is (group_index * num_residual_units + unit_index).
        # NOTE(review): the third argument (1 or 2) presumably selects the
        # convolution inside the unit -- confirm against get_kernel.
        graph = tf.get_default_graph()
        block_num = 3
        old_kernels_to_cluster = []
        old_kernels_to_add = []
        old_batch_norm = []
        for i in range(1, block_num + 1):
            for j in range(FLAGS.num_residual_units):
                old_kernels_to_cluster.append(get_kernel(i, j, 1, graph, sess))
                old_kernels_to_add.append(get_kernel(i, j, 2, graph, sess))
                old_batch_norm.append(get_batch_norm(i, j, 2, graph, sess))

        new_params = []
        new_width = [
            16,
            int(16 * FLAGS.new_k),
            int(32 * FLAGS.new_k),
            int(64 * FLAGS.new_k)
        ]
        units = zip(old_kernels_to_cluster, old_kernels_to_add, old_batch_norm)
        for unit_index, (to_cluster, to_add, batch_norm) in enumerate(units):
            # Map the flat unit index back to its group.  The group size is
            # FLAGS.num_residual_units (the loop above appends that many units
            # per group); entry 0 of new_width is skipped, hence the "+ 1".
            group_index = unit_index // FLAGS.num_residual_units
            cluster_num = new_width[group_index + 1]
            cluster_kernels, cluster_indices = cluster_kernel(
                to_cluster, cluster_num)
            add_kernels = add_kernel(to_add, cluster_indices, cluster_num)
            cluster_batchs_norm = cluster_batch_norm(batch_norm,
                                                     cluster_indices,
                                                     cluster_num)
            # Append order matters: the values are consumed below in the
            # order in which matching variables are enumerated.
            new_params.append(cluster_kernels)
            for p in range(BATCH_NORM_PARAM_NUM):
                new_params.append(cluster_batchs_norm[p])
            new_params.append(add_kernels)

        # Pair every global variable name with its initial value for the new
        # network: derived values for variables selected by the regexes,
        # restored values for everything else.
        init_params = []
        new_param_index = 0
        for var in tf.global_variables():
            update_match = UPDATE_PARAM_REGEX.match(var.name)
            skip_match = SKIP_PARAM_REGEX.match(var.name)
            if update_match and not skip_match:
                print("update {}".format(var.name))
                init_params.append((new_params[new_param_index], var.name))
                new_param_index += 1
            else:
                print("not update {}".format(var.name))
                var_vector = sess.run(var)
                init_params.append((var_vector, var.name))

        # Close the old graph's session before building the new graph.
        sess.close()
    tf.reset_default_graph()

    # ------------------------------------------------------------------
    # Phase 2: build the new network and fine-tune it.
    # ------------------------------------------------------------------
    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.input_fn(FLAGS.data_dir,
                                                             FLAGS.batch_size,
                                                             train_mode=True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(FLAGS.data_dir,
                                                           FLAGS.batch_size,
                                                           train_mode=False)

        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        new_network = resnet.ResNet(hp, images, labels, global_step,
                                    init_params, FLAGS.new_k)
        new_network.build_model()
        new_network.build_train_op()

        train_summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and resume from FLAGS.train_dir when possible.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # The step number is the suffix after the last '-' in the
            # checkpoint file name.
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')
        sys.stdout.flush()

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs (not mkdir) so a nested train_dir can be created.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [new_network.loss, new_network.acc],
                        feed_dict={
                            new_network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))
                sys.stdout.flush()
                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                [
                    new_network.train_op, new_network.lr, new_network.loss,
                    new_network.acc, train_summary_op
                ],
                feed_dict={
                    new_network.is_train: True,
                    images: train_images_val,
                    labels: train_labels_val
                })
            duration = time.time() - start_time
            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                sys.stdout.flush()
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and on the final step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
예제 #8
0
파일: eval.py 프로젝트: idobronstein/my_WRN
def train():
    """Evaluate a restored ResNet checkpoint and report per-class accuracy.

    Despite its name this function performs no training: it builds the
    inference graph, restores weights from ``FLAGS.ckpt_path`` (a checkpoint
    directory or a checkpoint file), runs ``FLAGS.test_iter`` batches and
    prints a table of correct / wrong counts and accuracy for every class
    plus a total row.  When ``FLAGS.output`` is non-blank the same table is
    also written to that file, with spaces in class names replaced by ``-``.

    All configuration is read from the module-level ``FLAGS``.  The process
    exits with status 1 when no checkpoint can be found at
    ``FLAGS.ckpt_path``.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        # The CIFAR-100 dataset
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(
                FLAGS.data_dir,
                FLAGS.batch_size,
                train_mode=FLAGS.train_data,
                num_threads=1)

        # The class labels, one name per line.
        with open(os.path.join(FLAGS.data_dir, 'fine_label_names.txt')) as fd:
            classes = [line.strip() for line in fd]

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None, new_k=FLAGS.new_k)
        network.build_model()
        # No training op is built: this function only evaluates.

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore the weights.  FLAGS.ckpt_path may name
        # either a checkpoint directory or a single checkpoint file.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            # Restores from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        # Testing!
        # result_ll[c] == [correct_count, wrong_count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        # Scalar accumulator (the previous "0.0, 0.0" created a tuple).
        test_loss = 0.0
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value, acc_value = sess.run(
                [network.preds, network.loss, network.acc],
                feed_dict={
                    network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for j in range(FLAGS.batch_size):
                # Column 0 counts correct predictions, column 1 wrong ones.
                column = 0 if test_labels_val[j] == preds_val[j] else 1
                result_ll[test_labels_val[j] % FLAGS.num_classes][column] += 1
        # Average loss is accumulated but not part of the report below.
        if FLAGS.test_iter > 0:
            test_loss /= FLAGS.test_iter

        # Summary display & output
        # A class that never appeared in the evaluated batches has no samples;
        # report 0.0 for it instead of dividing by zero.
        acc_list = [
            float(right) / float(right + wrong) if (right + wrong) else 0.0
            for right, wrong in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        total_count = np.sum(result_total)
        acc_total = float(result_total[0]) / total_count if total_count else 0.0

        print('Class    \t\t\tT\tF\tAcc.')
        format_str = '%-31s %7d %7d %.5f'
        for i in range(FLAGS.num_classes):
            print(format_str %
                  (classes[i], result_ll[i][0], result_ll[i][1], acc_list[i]))
        print(format_str %
              ('(Total)', result_total[0], result_total[1], acc_total))

        # Output to file(if specified)
        if FLAGS.output.strip():
            # Same row layout as the console table, newline-terminated.
            line_format = '%-31s %7d %7d %.5f\n'
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    fd.write(line_format %
                             (classes[i].replace(' ', '-'), t, f, acc_list[i]))
                fd.write(
                    line_format %
                    ('(Total)', result_total[0], result_total[1], acc_total))
예제 #9
0
def train():
    """Train the ResNet on CIFAR-100 while dumping graph and memory profiles.

    Besides the regular train / test / checkpoint loop, this variant writes
    profiling artefacts into the ``profs`` directory:

    * ``profs/wrn.dot``: the result of ``graph_to_dot`` for the graph.
    * ``profs/tensors_sz_32.txt``: for each operation, the byte size of its
      first output computed from the static shape (unknown dimensions are
      skipped in the product), or ``-1`` when the rank is unknown or the
      operation has no outputs.
    * ``profs/operations_attributes.txt``: each operation's type and, when
      available, whether its first output has a reference dtype.
    * On every step with ``step % 10 == 1`` the run metadata is passed to
      ``profile``; on step 1 a ``tf.profiler`` memory report is also written.

    All configuration is read from the module-level ``FLAGS``.  Training
    resumes from the latest checkpoint in ``FLAGS.train_dir`` when one
    exists.  An ``AssertionError`` is raised if the training loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.input_fn(FLAGS.data_dir,
                                                             FLAGS.batch_size,
                                                             train_mode=True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(FLAGS.data_dir,
                                                           FLAGS.batch_size,
                                                           train_mode=False)

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=False))
        sess.run(init)

        # ---- Profiling instrumentation: static graph dumps ----
        graph = tf.get_default_graph()

        # Every profiling artefact goes under 'profs'; create it up front so
        # the open() calls below do not fail on a fresh checkout.
        if not os.path.exists('profs'):
            os.makedirs('profs')

        dot_rep = graph_to_dot(graph)
        with open('profs/wrn.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        # Full tracing (trace_level=tf.RunOptions.FULL_TRACE) is not enabled;
        # only OOM allocation reports are requested.
        options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        run_metadata = tf.RunMetadata()

        operations_tensors = {}     # op name -> byte size of first output, or -1
        operations_attributes = {}  # op name -> [op type, is-ref-dtype flag]
        unknown_size_count = 0      # first output has unknown rank or dtype size
        no_output_count = 0         # operation produces no outputs

        for operation in graph.get_operations():
            operation_name = operation.name
            operations_info = operation.values()

            operations_attributes[operation_name] = [operation.type]
            try:
                operations_attributes[operation_name].append(
                    graph.get_tensor_by_name(
                        operation_name + ':0').dtype._is_ref_dtype)
            except (KeyError, ValueError, AttributeError):
                # Best effort: operations without a ':0' output (or a dtype
                # lacking the private flag) simply keep the type only.
                pass

            if len(operations_info) == 0:
                no_output_count += 1
                operations_tensors[operation_name] = -1
                continue

            first_output = operations_info[0]
            if first_output.shape.ndims is None:
                unknown_size_count += 1
                operations_tensors[operation_name] = -1
                continue

            operation_dtype_size = first_output.dtype.size
            if operation_dtype_size is None:
                # Counted, but no size entry is recorded for this operation.
                unknown_size_count += 1
                continue

            # Multiply the statically known dimensions; unknown (None)
            # dimensions are skipped rather than treated as zero.
            operation_no_of_elements = 1
            for dim in first_output.shape.as_list():
                if dim is not None:
                    operation_no_of_elements *= dim
            operations_tensors[operation_name] = (
                operation_no_of_elements * operation_dtype_size)

        print(unknown_size_count)
        print(no_output_count)
        with open('./profs/tensors_sz_32.txt', 'w') as f:
            for tensor, size in operations_tensors.items():
                f.write('"' + tensor + '"::' + str(size) + '\n')

        with open('./profs/operations_attributes.txt', 'w') as f:
            for op, attrs in operations_attributes.items():
                f.write('::'.join([op] + [str(attr) for attr in attrs]) + '\n')
        # ---- End of static graph dumps ----

        # Create a saver and resume from FLAGS.train_dir when possible.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # The step number is the suffix after the last '-' in the
            # checkpoint file name.
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs (not mkdir) so a nested train_dir can be created.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            # NOTE(review): the remainder 777 means testing only runs when
            # FLAGS.test_interval > 777; this looks like a deliberate way to
            # disable testing while profiling -- confirm before changing to 0.
            if step % FLAGS.test_interval == 777:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train.  The step is timed on both branches so that `duration`
            # is always fresh when the display block below reads it.
            start_time = time.time()
            if step % 10 == 1:
                # Profiled step: collect run metadata for both session calls.
                train_images_val, train_labels_val = sess.run(
                    [train_images, train_labels],
                    run_metadata=run_metadata,
                    options=options)
                _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                    [
                        network.train_op, network.lr, network.loss,
                        network.acc, train_summary_op
                    ],
                    feed_dict={
                        network.is_train: True,
                        images: train_images_val,
                        labels: train_labels_val
                    },
                    run_metadata=run_metadata,
                    options=options)
                # Stop the clock before the profiling output is written.
                duration = time.time() - start_time
                profile(run_metadata, step)

                if step == 1:
                    # One-off memory report from the first profiled step.
                    options_mem = tf.profiler.ProfileOptionBuilder.time_and_memory(
                    )
                    options_mem["min_bytes"] = 0
                    options_mem["min_micros"] = 0
                    options_mem["output"] = 'file:outfile=./profs/mem.txt'
                    options_mem["select"] = ("bytes", "peak_bytes",
                                             "output_bytes", "residual_bytes")
                    mem = tf.profiler.profile(graph,
                                              run_meta=run_metadata,
                                              cmd="scope",
                                              options=options_mem)
                    with open('profs/mem_2.txt', 'w') as f:
                        f.write(str(mem))
            else:
                train_images_val, train_labels_val = sess.run(
                    [train_images, train_labels])
                _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                    [
                        network.train_op, network.lr, network.loss,
                        network.acc, train_summary_op
                    ],
                    feed_dict={
                        network.is_train: True,
                        images: train_images_val,
                        labels: train_labels_val
                    },
                    options=options)
                duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and on the final step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
예제 #10
0
def train():
    """Evaluate a ResNet on the ImageNet test list and record top-1/top-5 hits.

    Despite its name this function does not train anything.  It prints the
    configuration held in ``FLAGS``, builds the test input pipeline and the
    network, restores ``FLAGS.checkpoint`` when one is given, and then feeds
    ``FLAGS.test_iter`` batches through ``network.probs``.  For every image it
    checks whether the label is the most probable class (top-1) or one of the
    five most probable classes (top-5).

    Side effects:
        Appends a header line and one ``top1,top5,numOfImg`` line holding the
        raw hit counts to ``acyOnValSet.csv``; prints progress and the final
        top-1/top-5 ratios to stdout.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root,
                    FLAGS.test_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=1,
                    center_crop=True)

        # Placeholders the network reads from; batches fetched from the input
        # pipeline are fed into them explicitly in the loop below.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_gpus=1,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, [images], [labels], global_step)
        network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore the weights to evaluate, if any.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners so the input pipeline produces batches.
        tf.train.start_queue_runners(sess=sess)

        # Running hit counters (kept as floats so the final ratios are floats).
        numOfImg = 0
        top1 = 0.
        top5 = 0.

        # The context manager guarantees the CSV is flushed and closed even
        # if evaluation raises part-way through.
        with open("acyOnValSet.csv", "a") as wri:
            wri.write("top1" + "," + "top5" + "," + "numOfImg" + "\n")

            # Test!
            for i in range(FLAGS.test_iter):
                test_images_val, test_labels_val = sess.run(
                    [test_images[0], test_labels[0]])

                print(len(test_labels_val), test_images_val.shape,
                      test_labels_val.shape)

                # NOTE(review): a debugging `break` used to sit here, which
                # skipped the forward pass entirely and made the final ratio
                # divide by zero.  It has been removed so the loop evaluates.
                probs = sess.run(
                    [network.probs],
                    feed_dict={
                        network.is_train: False,
                        images: test_images_val,
                        labels: test_labels_val
                    })

                # sess.run returned a one-element list; unwrap it to get the
                # per-image probability rows.
                probs = np.asarray(probs)[0]

                for each_prob, label in zip(probs, test_labels_val):
                    # Indices of the five largest probabilities, best first.
                    pred = (np.argsort(each_prob)[::-1])[0:5]

                    if label == pred[0]:
                        top1 += 1
                        top5 += 1
                    elif label in pred[1:]:
                        top5 += 1

                    numOfImg += 1

                print('------------ evaluating on validation set', i,
                      ' batch')
                print("top1: %f, top5: %f, numImg: %d" %
                      (top1, top5, numOfImg))

            # repr() is the portable spelling of the Python-2-only backticks.
            wri.write(
                repr(top1) + "," + repr(top5) + "," + repr(numOfImg) + "\n")

        # Guard against FLAGS.test_iter == 0 (no images seen).
        if numOfImg:
            print('%s %s' % (top1 / numOfImg, top5 / numOfImg))
        else:
            print('No test images were evaluated; accuracy is undefined.')
        '''
예제 #11
0
def train():
    """Evaluate a ResNet checkpoint and report per-class accuracy.

    Despite its name this function only tests.  It prints the configuration
    held in ``FLAGS``, builds the input pipeline (training or test list,
    depending on ``FLAGS.train_data``) and the network, restores weights from
    ``FLAGS.ckpt_path`` (a checkpoint directory or a checkpoint file), runs
    ``FLAGS.test_iter`` batches and counts correct/wrong predictions for each
    class.

    Side effects:
        Prints a per-class table to stdout and, when ``FLAGS.output`` is
        non-blank, writes the same table to that file.  Calls ``sys.exit(1)``
        when no checkpoint can be found at ``FLAGS.ckpt_path``.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tImageNet class name list: %s' % FLAGS.class_list)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tNumber of GPUs: %d' % FLAGS.num_gpu)
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tSplitted Network: %s' % FLAGS.split)
    if FLAGS.split:
        print('\tClustering path: %s' % FLAGS.cluster_path)
        print('\tNo logit map: %s' % FLAGS.no_logit_map)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # One fetched batch covers all GPUs.
    total_batch = FLAGS.batch_size * FLAGS.num_gpu

    with tf.Graph().as_default():
        # Input pipeline: evaluate on the training list or on the test list.
        with tf.variable_scope('test_image'):
            if FLAGS.train_data:
                test_images, test_labels = data_input.inputs(
                    FLAGS.train_image_root, FLAGS.train_dataset,
                    total_batch, False)
            else:
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root, FLAGS.test_dataset,
                    total_batch, False)

        # The class labels (names are cut to 30 characters for the table).
        with open(FLAGS.class_list) as fd:
            classes = [line.strip()[:30] for line in fd]

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(tf.float32, [
            total_batch, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH, 3
        ])
        labels = tf.placeholder(tf.int32, [total_batch])

        # Build model
        hp = resnet.HParams(num_gpu=FLAGS.num_gpu,
                            batch_size=FLAGS.batch_size,
                            split=FLAGS.split,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map)
        network = resnet.ResNet(hp, images, labels, None)
        if FLAGS.split:
            # NOTE(review): `clustering` is not defined inside this function;
            # it must be provided at module scope (presumably loaded from
            # FLAGS.cluster_path) -- confirm, otherwise this raises NameError.
            network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d params' % network._weights)
        # network.build_train_op()  # NO training op

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver and restore from a directory or a single file.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            # Restores from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        # Testing!
        # result_ll[c] == [correct count, wrong count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        # Scalar accumulator (it used to be initialised as a 2-tuple, which
        # made the running sum meaningless).
        test_loss = 0.0
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value, acc_value = sess.run(
                [network.preds, network.loss, network.acc],
                feed_dict={
                    network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for j in range(total_batch):
                # Column 0 counts hits, column 1 counts misses.
                column = 0 if test_labels_val[j] == preds_val[j] else 1
                result_ll[test_labels_val[j] % FLAGS.num_classes][column] += 1
            if i % FLAGS.display == 0:
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), i, loss_value, acc_value))
        # Mean loss over all iterations; computed but not reported below.
        test_loss /= FLAGS.test_iter

        # Summary display & output
        # A class that never appeared in the evaluated batches gets 0.0
        # instead of raising ZeroDivisionError.
        acc_list = [
            float(hits) / float(hits + misses) if hits + misses else 0.0
            for hits, misses in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        acc_total = float(result_total[0]) / np.sum(result_total)

        print('Class    \t\t\tT\tF\tAcc.')
        format_str = '%-31s %7d %7d %.5f'
        for i in range(FLAGS.num_classes):
            print(format_str % (classes[i], result_ll[i][0], result_ll[i][1],
                                acc_list[i]))
        print(format_str %
              ('(Total)', result_total[0], result_total[1], acc_total))

        # Output to file(if specified)
        if FLAGS.output.strip():
            row_format = '%-31s %7d %7d %.5f\n'
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    # Spaces in class names are replaced so that every row
                    # splits into exactly four whitespace-separated fields.
                    fd.write(row_format %
                             (classes[i].replace(' ', '-'), t, f, acc_list[i]))
                fd.write(
                    row_format %
                    ('(Total)', result_total[0], result_total[1], acc_total))
예제 #12
0
def train():
    """Train a ResNet, logging summaries and saving checkpoints periodically.

    The function prints the configuration held in ``FLAGS``, builds the
    training input via ``get_dataflow`` and the network via ``resnet.ResNet``,
    optionally restores ``FLAGS.checkpoint`` (resuming from the restored
    ``global_step``), and then runs training steps up to ``FLAGS.max_steps``.
    The learning rate for each step is obtained from ``get_lr`` and fed into
    ``network_train.lr``.

    Side effects:
        Creates ``FLAGS.train_dir`` if missing, writes TensorBoard summaries
        into a sub-directory named after the starting global step, and saves
        ``model.ckpt`` checkpoints every ``FLAGS.checkpoint_interval`` steps
        and at the final step.

    Raises:
        AssertionError: if the loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tMax steps to run: %d' % FLAGS.max_steps)
    print('\tTraining total epochs: %d' % FLAGS.epoch)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)
    # The checkpoint is a path (it is passed to saver.restore below), so it is
    # formatted with %s; the parentheses make the conditional pick the value
    # rather than replace the whole formatted message.
    print('\tCheckpoint to load: %s' %
          (FLAGS.checkpoint if FLAGS.checkpoint is not None else -1))

    with tf.Graph().as_default():
        init_step = 0
        epoch_init = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Floor division keeps the thread count an integer on Python 3 too.
        num_threads = multiprocessing.cpu_count() // FLAGS.num_gpus
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            with tf.variable_scope('train_image'):
                # train_images, train_labels = data_input.distorted_inputs(FLAGS.train_image_root, FLAGS.train_dataset
                #                                , FLAGS.batch_size, True, num_threads=num_threads, num_sets=FLAGS.num_gpus)
                train_images, train_labels = get_dataflow(
                    FLAGS.batch_size, FLAGS.num_gpus)
            # print('\tLoading validation data from %s' % FLAGS.val_dataset)
            # with tf.variable_scope('test_image'):
            #     val_images, val_labels = data_input.inputs(FLAGS.val_image_root, FLAGS.val_dataset
            #                                    , FLAGS.batch_size, False, num_threads=num_threads, num_sets=FLAGS.num_gpus)
            # tf.summary.image('images', train_images[0][:2])

        # Build model
        # Convert the comma-separated epoch boundaries into global-step
        # boundaries.  Real lists are built (not `map` objects) because the
        # result is handed to get_lr() on every step and a Python 3 map
        # iterator would be exhausted after its first traversal.
        lr_step_epochs = [float(e) for e in FLAGS.lr_step_epoch.split(',')]
        lr_decay_steps = [
            int(e * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus) for e in lr_step_epochs
        ]
        hp = resnet.HParams(
            batch_size=FLAGS.batch_size,
            num_gpus=FLAGS.num_gpus,
            weight_decay=FLAGS.l2_weight,
        )
        network_train = resnet.ResNet(hp,
                                      train_images,
                                      train_labels,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        # network_val = resnet.ResNet(hp, val_images, val_labels, global_step, name="val", reuse_weights=True)
        # network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        # config.gpu_options.visible_device_list = "1"
        config.gpu_options.per_process_gpu_memory_fraction = 0.9
        config.allow_soft_placement = False
        config.log_device_placement = FLAGS.log_device_placement
        # tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.9),
        #     allow_soft_placement=False,
        #     # allow_soft_placement=True,
        #     log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=config)
        sess.run(init)
        sess.run(tf.local_variables_initializer())

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            # Resume counting from the step stored in the checkpoint.
            init_step = global_step.eval(session=sess)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        val_best_acc = 0.0
        # for epoch in range(epoch_init, FLAGS.epoch):
        for step in range(init_step, FLAGS.max_steps):
            # # val
            # if step % FLAGS.val_interval == 0:
            #     val_loss, val_acc = 0.0, 0.0
            #     for i in range(FLAGS.val_iter):
            #         loss_value, acc_value = sess.run([network_val.loss, network_val.acc],
            #                     feed_dict={network_val.is_train:False})
            #         val_loss += loss_value
            #         val_acc += acc_value
            #     val_loss /= FLAGS.val_iter
            #     val_acc /= FLAGS.val_iter
            #     val_best_acc = max(val_best_acc, val_acc)
            #     format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
            #     print (format_str % (datetime.now(), step, val_loss, val_acc))
            #
            #     val_summary = tf.Summary()
            #     val_summary.value.add(tag='val/loss', simple_value=val_loss)
            #     val_summary.value.add(tag='val/acc', simple_value=val_acc)
            #     val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc)
            #     summary_writer.add_summary(val_summary, step)
            #     summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op],
                             feed_dict={network_train.is_train: True, network_train.lr: lr_value})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            # if step % FLAGS.display == 0 or step < 10:
            if True:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, epoch %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, 0, loss_value, acc_value,
                       lr_value, examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and at the last step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
예제 #13
0
def cifar10_blackbox(nb_classes=10,
                     batch_size=128,
                     nb_samples=10,
                     l2_weight=0.0001,
                     momentum=0.9,
                     initial_lr=0.1,
                     lr_step_epoch=100.0,
                     lr_decay=0.1,
                     num_residual_units=2,
                     num_train_instance=50000,
                     num_test_instance=10000,
                     k=1,
                     eps=0.3,
                     learning_rate=0.001,
                     nb_epochs=10,
                     holdout=150,
                     data_aug=6,
                     nb_epochs_s=10,
                     lmbda=0.1,
                     binary=False,
                     scale=False,
                     model_path=None,
                     targeted=False,
                     data_dir=None,
                     adv=False,
                     delay=0):
    """
    Restore a CIFAR-10 ResNet "black-box" model and report its test accuracy.

    Based on the black-box attack tutorial from arxiv.org/abs/1602.02697.
    This function builds the ResNet, restores its weights from ``model_path``
    and evaluates it with ``model_eval`` on the test images that remain after
    the first ``holdout`` samples have been set aside for the adversary.

    :param nb_classes: number of classes (network output size, one-hot width)
    :param batch_size: batch size for the network and for evaluation
    :param l2_weight: passed to the network as ``weight_decay``
    :param momentum: passed to the network hyper-parameters
    :param initial_lr: passed to the network hyper-parameters
    :param lr_step_epoch: epochs per learning-rate step; converted to
                          ``decay_step`` using ``num_train_instance`` and
                          ``batch_size``
    :param lr_decay: passed to the network hyper-parameters
    :param num_residual_units: passed to the network hyper-parameters
    :param num_train_instance: number of training images (for ``decay_step``)
    :param k: passed to the network hyper-parameters
    :param holdout: number of leading test samples reserved as the
                    adversary's substitute training set
    :param binary: truthy to build the network with its ``binary`` argument
                   set to True
    :param model_path: checkpoint file (last path component contains
                       ``'model'``) or a directory to take the latest
                       checkpoint from; required
    :raises ValueError: if ``model_path`` is None

    The remaining parameters (``nb_samples``, ``num_test_instance``, ``eps``,
    ``learning_rate``, ``nb_epochs``, ``data_aug``, ``nb_epochs_s``,
    ``lmbda``, ``scale``, ``targeted``, ``data_dir``, ``adv``, ``delay``) are
    accepted but not read in this function body.

    :return: None; the accuracy is printed.  The ``accuracies`` dictionary is
             created but neither filled in nor returned here.
    """

    # Fail early with a clear message instead of an AttributeError on
    # `None.split` after the whole graph has already been built.
    if model_path is None:
        raise ValueError(
            'model_path must point to a checkpoint file or directory')

    # Set logging level to see debug information
    set_log_level(logging.DEBUG)

    # Dictionary used to keep track and return key accuracies
    accuracies = {}

    # Perform tutorial setup
    assert setup_tutorial()

    if not hasattr(backend, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    # Image dimension ordering should follow the TensorFlow convention
    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    # Create TF session and set as Keras backend session
    sess = tf.Session()
    keras.backend.set_session(sess)

    # Get CIFAR10 test data
    X_train, Y_train, X_test, Y_test = data_cifar10_std()

    # Y_train_onehot = np_utils.to_categorical(Y_train, nb_classes)
    Y_test_onehot = np_utils.to_categorical(Y_test, nb_classes)

    # Y_test is for evaluating oracle
    # NOTE(review): argmax over axis 1 presumes Y_test is already 2-D
    # (one-hot); confirm against data_cifar10_std().
    Y_test_bbox = np.argmax(Y_test, axis=1)
    Y_test_bbox = Y_test_bbox.reshape(Y_test_bbox.shape[0], )
    Y_test_bbox = Y_test_bbox.astype('int32')

    #Y_test = Y_test.reshape(Y_test.shape[0],)
    #Y_test = Y_test.astype('int32')
    #Y_train = Y_train.astype('int32')

    # Initialize substitute training set reserved for adversary
    X_sub = X_test[:holdout]
    Y_sub = np.argmax(Y_test_onehot[:holdout], axis=1)

    # Redefine test set as remaining samples unavailable to adversaries
    X_test = X_test[holdout:]
    Y_test = Y_test[holdout:]

    # CIFAR10-specific dimensions
    img_rows = 32
    img_cols = 32
    channels = 3

    # Seeded random state so the tutorial is reproducible
    rng = np.random.RandomState([2017, 8, 30])

    # with tf.Graph().as_default():

    # Define input and output TF placeholders
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, channels))
    # NOTE(review): `(None)` is plain None (fully unknown shape), not the
    # 1-tuple `(None,)` -- confirm which one the network expects.
    y = tf.placeholder(tf.int32, shape=(None))

    phase = tf.placeholder(tf.bool, name='phase')
    y_s = tf.placeholder(tf.float32, shape=(None, nb_classes))

    # Simulate the black-box model locally
    # You could replace this by a remote labeling API for instance
    print("Preparing the WideResNet black-box model.")
    # Number of optimizer steps between learning-rate decays.
    decay_step = lr_step_epoch * num_train_instance / batch_size
    hp = resnet.HParams(batch_size=batch_size,
                        num_classes=nb_classes,
                        num_residual_units=num_residual_units,
                        k=k,
                        weight_decay=l2_weight,
                        initial_lr=initial_lr,
                        decay_step=decay_step,
                        lr_decay=lr_decay,
                        momentum=momentum)

    # The flag is printed before and after being coerced to a real bool.
    print(binary)
    binary = bool(binary)
    print(binary)
    network = resnet.ResNet(binary, hp, x, y, None)
    network.build_model()

    # bbox_preds = network.preds
    bbox_preds = network.probs

    init = tf.global_variables_initializer()
    sess.run(init)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)

    # A last path component containing 'model' is treated as a checkpoint
    # file; anything else is treated as a directory of checkpoints.
    if 'model' in model_path.split('/')[-1]:
        saver.restore(sess, model_path)
    else:
        saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print('restored %s' % model_path)

    eval_params = {'batch_size': batch_size}
    acc = model_eval(sess,
                     x,
                     y,
                     bbox_preds,
                     X_test,
                     Y_test,
                     phase=phase,
                     args=eval_params)
    print('Test accuracy of black-box on legitimate test examples: %.4f' % acc)
예제 #14
0
import cPickle as pickle

import tensorflow as tf
import numpy as np

import resnet

# Script: build a ResNet graph, load pickled weights from disk and start an
# initialised TensorFlow session.

model_pkl_fname = 'baseline/ResNet-50.pkl'
model_ckpt_fname = 'baseline/ResNet50.ckpt'

# Build model to load weights
global_step = tf.Variable(0, trainable=False, name='global_step')
images = tf.placeholder(tf.float32, [100, 224, 224, 3])
labels = tf.placeholder(tf.int32, [100])
hp = resnet.HParams(batch_size=100,
                    num_classes=1000,
                    weight_decay=0.0005,
                    momentum=0.9)
network = resnet.ResNet(hp, images, labels, global_step)
network.build_model()

# Load pkl weight file
print('Load pkl weight file')
# Pickle data is binary: open the file in 'rb' mode so that no newline
# translation or text decoding can corrupt it.
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file; only load weight files from a trusted source.
with open(model_pkl_fname, 'rb') as fd:
    weights = pickle.load(fd)

# Build an initialization operation to run below.
init = tf.initialize_all_variables()
sess = tf.Session(config=tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.96),
    log_device_placement=False))
sess.run(init)
예제 #15
0
def train():
    """Run fault-injection inference with a ResNet and log top-1/top-5 hits.

    Despite its name this function does not build or run a training op. It:

    1. prints the configuration read from ``FLAGS``;
    2. builds the test input pipeline and the ResNet graph;
    3. restores ``FLAGS.checkpoint`` when one is given;
    4. wraps the session with TensorFI and, for each of ``FLAGS.test_iter``
       batches, runs the forward pass ``fiTime`` times with injections on.

    For every image of every run a ``"1,"`` (hit) or ``"0,"`` (miss) is
    appended to ``fi-org-resnet-top1.csv`` and ``fi-org-resnet-top5.csv``;
    a newline is written after all runs of one batch.
    """
    print('[Dataset Configuration]')
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Evaluation Configuration]')
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root,
                    FLAGS.test_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=1,
                    center_crop=True)

        # The model reads from placeholders (not directly from the input
        # queue) so the same fetched batch can be fed repeatedly below.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size, data_input.IMAGE_HEIGHT, data_input.IMAGE_WIDTH,
            3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_gpus=1,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, [images], [labels], global_step)
        network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        sess.run(init)

        # Create a saver and restore the checkpoint over the fresh init.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners feeding the input pipeline.
        tf.train.start_queue_runners(sess=sess)

        fi = ti.TensorFI(sess,
                         logLevel=50,
                         name="convolutional",
                         disableInjections=False)

        # Number of fault-injected forward passes per fetched batch.
        fiTime = 1000

        # Result files are opened in append mode and closed on exit, even
        # when a run raises, so buffered results are not lost.
        with open("fi-org-resnet-top1.csv", "a") as t1, \
                open("fi-org-resnet-top5.csv", "a") as t5:
            for i in range(FLAGS.test_iter):
                # Fetch the next batch with injections disabled so the
                # input data itself is not corrupted.
                fi.turnOffInjections()
                test_images_val, test_labels_val = sess.run(
                    [test_images[0], test_labels[0]])

                fi.turnOnInjections()
                for j in range(fiTime):
                    probs = sess.run(
                        [network.probs],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })

                    # sess.run was given a one-element fetch list; unwrap it.
                    probs = np.asarray(probs)[0]

                    for idx, each_prob in enumerate(probs):
                        # Indices of the five highest scores, best first.
                        pred = (np.argsort(each_prob)[::-1])[0:5]
                        label = test_labels_val[idx]

                        print(pred, 'label:', label)

                        top1_hit = label == pred[0]
                        top5_hit = top1_hit or label in pred[1:]
                        t1.write("1," if top1_hit else "0,")
                        t5.write("1," if top5_hit else "0,")

                        print('--------fi on resnet, %d img, %d FI run' %
                              (i + 1, j + 1))

                t1.write("\n")
                t5.write("\n")
예제 #16
0
def _print_train_config():
    """Print the dataset, network, optimization and training flags."""
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tNumber of Groups: %d-%d-%d' %
          (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tOverlap loss weight: %f' % FLAGS.gamma1)
    print('\tWeight split loss weight: %f' % FLAGS.gamma2)
    print('\tUniform loss weight: %f' % FLAGS.gamma3)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale)
    print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)


def _sorted_group_indices(split):
    """Return the column order that sorts `split`'s columns by argmax row."""
    return np.argsort(np.argmax(split, axis=0))


def _kernel_to_img(kernel, in_indices, out_indices, shape):
    """Reorder axes 2 and 3 of a 4-D kernel and flatten |kernel| to `shape`.

    `in_indices` reorders axis 2 and `out_indices` reorders axis 3; the
    result is transposed to axis order (2, 0, 3, 1) before reshaping.
    """
    reordered = kernel[:, :, in_indices, :][:, :, :, out_indices]
    return np.abs(reordered.transpose([2, 0, 3, 1]).reshape(shape))


def _group_image_summaries(sess):
    """Collect image summaries of the split variables and reordered weights.

    Reads variable values through `get_var_value` and converts arrays with
    `img_to_summary`. Returns a (possibly empty) list of summary values in
    a fixed order: the ngroups1 images first, then the ngroups2 images.
    """
    # Channel widths used to shape the conv5 kernel images.
    filters = [64, [64, 256], [128, 512], [256, 1024], [512, 2048]]
    in_ch, mid_ch, out_ch = filters[3][1], filters[4][0], filters[4][1]

    img_summaries = []
    split_p1 = None

    if FLAGS.ngroups1 > 1:
        logits_weights = get_var_value('logits/fc/weights', sess)
        split_p1 = get_var_value('group/split_p1/q', sess)
        split_q1 = get_var_value('group/split_q1/q', sess)
        feature_indices = _sorted_group_indices(split_p1)
        class_indices = _sorted_group_indices(split_q1)

        # Rows are repeated only to make the thin split images visible.
        img_summaries.append(
            img_to_summary(
                np.repeat(split_p1[:, feature_indices], 20, axis=0),
                'split_p1'))
        img_summaries.append(
            img_to_summary(
                np.repeat(split_q1[:, class_indices], 200, axis=0),
                'split_q1'))
        img_summaries.append(
            img_to_summary(
                np.abs(logits_weights[feature_indices, :][:, class_indices]),
                'logits'))

    if FLAGS.ngroups2 > 1:
        # split_q2 is derived from split_p1. Fetch it here when the
        # ngroups1 branch did not run, instead of failing on an unbound name.
        # NOTE(review): assumes 'group/split_p1/q' exists whenever
        # ngroups2 > 1 -- confirm against the resnet module.
        if split_p1 is None:
            split_p1 = get_var_value('group/split_p1/q', sess)

        split_p2 = get_var_value('group/split_p2/q', sess)
        split_q2 = _merge_split_q(
            split_p1, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
        idx_p2 = _sorted_group_indices(split_p2)
        idx_q2 = _sorted_group_indices(split_q2)

        r_names = [
            'split_r211', 'split_r212', 'split_r221', 'split_r222',
            'split_r231', 'split_r232'
        ]
        r_splits = [get_var_value('group/%s/q' % name, sess)
                    for name in r_names]
        r_indices = [_sorted_group_indices(split) for split in r_splits]
        (idx_r211, idx_r212, idx_r221, idx_r222, idx_r231,
         idx_r232) = r_indices

        img_summaries.append(
            img_to_summary(np.repeat(split_p2[:, idx_p2], 20, axis=0),
                           'split_p2'))
        for name, split, indices in zip(r_names, r_splits, r_indices):
            img_summaries.append(
                img_to_summary(np.repeat(split[:, indices], 20, axis=0),
                               name))

        # (summary tag, variable name, input order, output order, shape)
        kernel_specs = [
            ('conv5_1/shortcut', 'conv5_1/conv_shortcut/kernel', idx_p2,
             idx_q2, (in_ch, out_ch)),
            ('conv5_1/conv_1', 'conv5_1/conv_1/kernel', idx_p2, idx_r211,
             (in_ch, mid_ch)),
            ('conv5_1/conv_2', 'conv5_1/conv_2/kernel', idx_r211, idx_r212,
             (mid_ch * 3, mid_ch * 3)),
            ('conv5_1/conv_3', 'conv5_1/conv_3/kernel', idx_r212, idx_q2,
             (mid_ch, out_ch)),
            ('conv5_2/conv_1', 'conv5_2/conv_1/kernel', idx_q2, idx_r221,
             (out_ch, mid_ch)),
            ('conv5_2/conv_2', 'conv5_2/conv_2/kernel', idx_r221, idx_r222,
             (mid_ch * 3, mid_ch * 3)),
            ('conv5_2/conv_3', 'conv5_2/conv_3/kernel', idx_r222, idx_q2,
             (mid_ch, out_ch)),
            ('conv5_3/conv_1', 'conv5_3/conv_1/kernel', idx_q2, idx_r231,
             (out_ch, mid_ch)),
            ('conv5_3/conv_2', 'conv5_3/conv_2/kernel', idx_r231, idx_r232,
             (mid_ch * 3, mid_ch * 3)),
            ('conv5_3/conv_3', 'conv5_3/conv_3/kernel', idx_r232, idx_q2,
             (mid_ch, out_ch)),
        ]
        for tag, var_name, in_idx, out_idx, shape in kernel_specs:
            kernel = get_var_value(var_name, sess)
            img_summaries.append(
                img_to_summary(
                    _kernel_to_img(kernel, in_idx, out_idx, shape), tag))

    # NOTE: visualization for the ngroups3 stage (conv4_x kernels,
    # split_p3 / split_r31 / split_r32) was disabled in the original code
    # and is not produced here.
    return img_summaries


def train():
    """Train the grouped ResNet on ImageNet with periodic validation.

    Reads every setting from ``FLAGS``. Builds a training network and a
    weight-sharing validation network, restores ``FLAGS.checkpoint`` (full
    state, resuming at its global step) or ``FLAGS.basemodel`` (weights
    only), then runs steps ``init_step .. FLAGS.max_steps - 1``:

    * every ``val_interval`` steps: average loss/accuracy over ``val_iter``
      validation batches and write them as summaries;
    * one training step, followed by ``network_train.validity_op``;
    * at one fixed step a full trace is written to ``timeline.json``;
    * every ``checkpoint_interval`` steps and at the last step: save a
      checkpoint under ``FLAGS.train_dir``;
    * every ``group_summary_interval`` steps (if set): write image
      summaries of the group splits and reordered weights.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    _print_train_config()

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of ImageNet
        import multiprocessing
        # Floor division keeps the thread count an integer regardless of
        # the interpreter's `/` semantics.
        num_threads = multiprocessing.cpu_count() // FLAGS.num_gpus
        print('Load ImageNet dataset(%d threads)' % num_threads)
        with tf.device('/cpu:0'):
            print('\tLoading training data from %s' % FLAGS.train_dataset)
            with tf.variable_scope('train_image'):
                train_images, train_labels = data_input.distorted_inputs(
                    FLAGS.train_image_root,
                    FLAGS.train_dataset,
                    FLAGS.batch_size,
                    True,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)
            print('\tLoading validation data from %s' % FLAGS.val_dataset)
            with tf.variable_scope('test_image'):
                val_images, val_labels = data_input.inputs(
                    FLAGS.val_image_root,
                    FLAGS.val_dataset,
                    FLAGS.batch_size,
                    False,
                    num_threads=num_threads,
                    num_sets=FLAGS.num_gpus)

        # Convert the comma-separated decay epochs into global-step counts.
        # A real list (not a lazy map object) is required because it is
        # passed to get_lr on every step.
        lr_decay_steps = [
            int(
                float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpus)
            for epoch in FLAGS.lr_step_epoch.split(',')
        ]

        # Build model
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            gamma1=FLAGS.gamma1,
                            gamma2=FLAGS.gamma2,
                            gamma3=FLAGS.gamma3,
                            momentum=FLAGS.momentum,
                            bn_no_scale=FLAGS.bn_no_scale,
                            weighted_group_loss=FLAGS.weighted_group_loss,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp,
                                      train_images,
                                      train_labels,
                                      global_step,
                                      name="train")
        network_train.build_model()
        network_train.build_train_op()
        # Merged before the validation network is built, so it only holds
        # the summaries defined so far (training).
        train_summary_op = tf.summary.merge_all()
        network_val = resnet.ResNet(hp,
                                    val_images,
                                    val_labels,
                                    global_step,
                                    name="val",
                                    reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            # Full restore: resume from the checkpoint's global step.
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Partial restore: skip optimizer slots, group variables and
            # the step counter, which are filtered out by name below.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            vars_restore = [
                var for var in tf.global_variables()
                if "Momentum" not in var.name and "group" not in var.name
                and "global_step" not in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        # makedirs (not mkdir) so a nested train_dir is created as well.
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Step at which one fully traced run is profiled to timeline.json.
        profile_step = 153
        train_fetches = [
            network_train.train_op, network_train.loss, network_train.acc,
            train_summary_op
        ]

        # Training!
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    loss_value, acc_value = sess.run(
                        [network_val.loss, network_val.acc],
                        feed_dict={network_val.is_train: False})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()

            profiling = step == profile_step
            run_kwargs = {}
            if profiling:
                run_metadata = tf.RunMetadata()
                run_kwargs = {
                    'options':
                    tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                    'run_metadata': run_metadata,
                }
            _, loss_value, acc_value, train_summary_str = sess.run(
                train_fetches,
                feed_dict={
                    network_train.is_train: True,
                    network_train.lr: lr_value
                },
                **run_kwargs)
            sess.run(network_train.validity_op)
            if profiling:
                # Create the Timeline object, and write it to a json
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)
                print('Wrote the timeline profile of %d iter training on %s' %
                      (step, 'timeline.json'))
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Training loss is NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0 or step < 10:
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and at the last step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Add weights and groupings visualization
            if (FLAGS.group_summary_interval is not None
                    and step % FLAGS.group_summary_interval == 0):
                img_summaries = _group_image_summaries(sess)
                if img_summaries:
                    img_summary = tf.Summary(value=img_summaries)
                    summary_writer.add_summary(img_summary, step)
                    summary_writer.flush()

        # Flush pending events and release the event file.
        summary_writer.close()
예제 #17
0
def train():
    """Run the training loop for the ResNet classifier configured by FLAGS.

    Prints the active configuration, then, inside a fresh graph, builds a
    training network and a second network named "val" constructed with
    ``reuse_weights=True``.  Parameters are optionally restored from
    ``FLAGS.checkpoint`` (full restore, including the global step) or from
    ``FLAGS.basemodel`` (partial restore).  The loop then runs from the
    restored step up to ``FLAGS.max_steps``; each iteration may run a
    validation pass, always performs one training step, and periodically
    prints metrics, writes summaries and saves checkpoints under
    ``FLAGS.log_dir``.

    Takes no arguments and returns nothing; every setting is read from
    ``FLAGS``.  An ``AssertionError`` is raised when the training loss is NaN.

    NOTE: this function uses Python 2 syntax (print statements, ``xrange``).
    """
    # ---- Echo the configuration so that it ends up in the run log ----
    print('[Dataset Configuration]')
    # print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dir)
    # print('\tImageNet val root: %s' % FLAGS.val_image_root)
    print('\tImageNet val list: %s' % FLAGS.val_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tNumber of GPUs: %d' % FLAGS.num_gpus)
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tlog dir: %s' % FLAGS.log_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        # First step of the loop below; overwritten when a checkpoint is
        # restored.
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Build model
        # Turn the comma-separated epoch boundaries of FLAGS.lr_step_epoch
        # into step boundaries:
        #   epochs * num_train_instance / batch_size / num_gpus.
        # Under Python 2, map() yields a list, which is what get_lr receives.
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = map(int, [
            s * FLAGS.num_train_instance / FLAGS.batch_size / FLAGS.num_gpus
            for s in lr_decay_steps
        ])
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_gpus=FLAGS.num_gpus,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network_train = resnet.ResNet(hp, global_step, name="train")
        network_train.build_model()
        network_train.build_train_op()
        # merge_all() is evaluated here, before the val network exists, so
        # any summaries registered afterwards are not part of this op.
        train_summary_op = tf.summary.merge_all()  # Summaries(training)
        # Second network, built with the same hyper-parameters and
        # reuse_weights=True; it is only used for the validation pass.
        network_val = resnet.ResNet(hp,
                                    global_step,
                                    name="val",
                                    reuse_weights=True)
        network_val.build_model()
        print('Number of Weights: %d' % network_train._weights)
        print('FLOPs: %d' % network_train._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            # allow_soft_placement=False,
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        # Initialise everything first; a restore below overwrites the
        # variables it covers.
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            # Full restore: resume from the step stored in the checkpoint.
            print('Load checkpoint %s' % FLAGS.checkpoint)
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
        elif FLAGS.basemodel:
            # Define a different saver to save model checkpoints
            # Partial restore: variables whose name contains "Momentum",
            # "global_step" or "logits" keep their freshly initialised values.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [
                var for var in variables if not "Momentum" in var.name
                and not "global_step" in var.name and not "logits" in var.name
            ]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
            # vars_fc = [var for var in variables
            #            if "logtis" in var.name and
            #            not "Momentum" in var.name and
            #            not "global_step" in var.name]
            # init_fc = tf.contrib.layers.xavier_initializer()
            # sess.run(init_fc)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)

        # if not os.path.exists(FLAGS.train_dir):
        #     os.mkdir(FLAGS.train_dir)
        # Summaries go to a sub-directory of log_dir named after the global
        # step at start-up.
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.log_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        # Data is loaded by project helpers and fed through placeholders.
        train_data = stack_obj_eps(FLAGS.train_dir)
        val_data = stack_obj_eps(FLAGS.val_dir)
        train_data_splited = split_stack_infos(train_data)
        # Passed to rand_imgs_acts_specify_batch_size below; presumably the
        # number of samples drawn from each split - TODO confirm.
        batch_size_list = [64, 8, 8]

        # Best validation accuracy seen during this run (not restored from a
        # checkpoint).
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            # Validation runs before the training step of the same iteration.
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc, val_prec, val_rec, val_f1 = 0.0, 0.0, 0.0, 0.0, 0.0
                # NOTE(review): the confusion matrix size and the label list
                # below are hard-coded to 3 classes; FLAGS.num_classes is not
                # consulted here.
                val_conf_mat = np.zeros((3, 3))
                for i in range(FLAGS.val_iter):
                    val_imgs, val_labels = rand_imgs_acts(
                        val_data, FLAGS.batch_size)
                    loss_value, acc_value, preds = sess.run(
                        [network_val.loss, network_val.acc, network_val.preds],
                        feed_dict={
                            network_val._images: val_imgs,
                            network_val._labels: val_labels,
                            network_val.is_train: False
                        })
                    # preds = np.zeros(FLAGS.batch_size)
                    # y_true = sess.run(val_labels)
                    # y_true = y_true[0]
                    y_true = val_labels
                    # print "y_true: ", y_true
                    # print "y_pred: ", preds
                    prec_value, rec_value, f1_value, conf_mat_value = evaluate(
                        y_true, preds, labels=[0, 1, 2])
                    # Accumulate per-batch metrics; averaged after the loop.
                    val_loss += loss_value
                    val_acc += acc_value
                    val_prec += prec_value
                    val_rec += rec_value
                    val_f1 += f1_value
                    val_conf_mat += conf_mat_value
                # Mean over the val_iter validation batches.
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_prec /= FLAGS.val_iter
                val_rec /= FLAGS.val_iter
                val_f1 /= FLAGS.val_iter
                val_conf_mat /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))
                print "val_prec: ", val_prec
                print "val_rec: ", val_rec
                print "val_f1: ", val_f1
                print "val confusion matrix: "
                print val_conf_mat
                # Hand-built summary so that validation metrics share the
                # writer (and step axis) with the training summaries.
                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                # np.mean collapses each metric to one scalar; presumably
                # evaluate() returns per-class values - TODO confirm.
                val_summary.value.add(tag='val/prec',
                                      simple_value=np.mean(val_prec))
                val_summary.value.add(tag='val/rec',
                                      simple_value=np.mean(val_rec))
                val_summary.value.add(tag='val/f1',
                                      simple_value=np.mean(val_f1))
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            # The learning rate is computed in Python and fed in each step.
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            # The timed span covers batch sampling, the session run and the
            # metric computation below.
            start_time = time.time()
            train_imgs, train_labels = rand_imgs_acts_specify_batch_size(
                train_data_splited, batch_size_list)
            _, loss_value, acc_value, train_summary_str, preds = \
                    sess.run([network_train.train_op, network_train.loss, network_train.acc, train_summary_op, network_train.preds],
                            feed_dict={network_train._images:train_imgs, network_train._labels:train_labels,
                                       network_train.is_train:True, network_train.lr:lr_value})
            # preds = np.zeros(FLAGS.batch_size)
            # y_true = sess.run(train_labels)
            y_true = train_labels
            # print "y_true: ", y_true
            # print "y_pred: ", preds
            train_prec, train_rec, train_f1, train_conf_mat = evaluate(
                y_true, preds, labels=[0, 1, 2])

            duration = time.time() - start_time
            # sys.stdout.flush()
            # Abort as soon as the loss diverges.
            assert not np.isnan(loss_value)

            # Display & Summary(training)
            # The first 10 steps are always shown, then every FLAGS.display
            # steps.
            if step % FLAGS.display == 0 or step < 10:
                # NOTE(review): throughput is derived from FLAGS.batch_size,
                # while the batch itself is drawn using batch_size_list -
                # confirm that the two agree.
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: (Training) step %d, loss=%.4f, '
                              'acc=%.4f, '
                              'lr=%f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                print "train_prec: ", train_prec
                print "train_rec: ", train_rec
                print "train_f1: ", train_f1
                print "train confusion matrix: "
                print train_conf_mat
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            # NOTE(review): `step >= init_step` always holds inside this loop
            # (step starts at init_step), so a checkpoint is also written on
            # the very first iteration when it is a multiple of the interval.
            if (step >= init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Zero-timeout poll of stdin: when a character is waiting and it
            # is 'b', call embed() (presumably an interactive shell - TODO
            # confirm against the file's imports).
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
예제 #18
0
def train():
    """Run the CIFAR-100 training loop for the grouped ResNet set up by FLAGS.

    Prints the active configuration, then, inside a fresh graph, creates the
    CIFAR-100 train/val input runners, the image/label placeholders and one
    ``resnet.ResNet`` network.  Parameters are optionally restored from
    ``FLAGS.checkpoint`` (full restore, including the global step) or from
    ``FLAGS.basemodel`` (partial restore).  The loop then runs from the
    restored step up to ``FLAGS.max_steps``; each iteration may run a
    validation pass, always performs one training step, and periodically
    prints metrics, writes summaries, saves checkpoints under
    ``FLAGS.train_dir`` and emits image summaries of the split matrices and
    re-ordered weights.

    Takes no arguments and returns nothing; every setting is read from
    ``FLAGS``.  An ``AssertionError`` is raised when the training loss is NaN.

    NOTE: this function uses Python 2 constructs (``xrange``, list-returning
    ``map``).
    """
    # ---- Echo the configuration so that it ends up in the run log ----
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)
    print('\tNumber of Groups: %d-%d-%d' %
          (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tOverlap loss weight: %f' % FLAGS.gamma1)
    print('\tWeight split loss weight: %f' % FLAGS.gamma2)
    print('\tUniform loss weight: %f' % FLAGS.gamma3)
    print('\tDropout keep probability: %f' % FLAGS.dropout_keep_prob)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale)
    print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)
    print('\tFinetune: %d' % FLAGS.finetune)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tSteps per plot splits: %d' % FLAGS.group_summary_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        # First step of the loop below; overwritten when a checkpoint is
        # restored.
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        print('Load CIFAR-100 dataset')
        train_dataset_path = os.path.join(FLAGS.data_dir, 'train')
        val_dataset_path = os.path.join(FLAGS.data_dir, 'val')
        print('\tLoading training data from %s' % train_dataset_path)
        with tf.variable_scope('train_image'):
            cifar100_train = cifar100.CIFAR100Runner(train_dataset_path,
                                                     image_per_thread=32,
                                                     shuffle=True,
                                                     distort=True,
                                                     capacity=10000)
            train_images, train_labels = cifar100_train.get_inputs(
                FLAGS.batch_size)
        print('\tLoading validation data from %s' % val_dataset_path)
        with tf.variable_scope('val_image'):
            cifar100_val = cifar100.CIFAR100Runner(val_dataset_path,
                                                   image_per_thread=32,
                                                   shuffle=False,
                                                   distort=False,
                                                   capacity=5000)
            # shuffle=False, distort=False, capacity=10000)
            val_images, val_labels = cifar100_val.get_inputs(FLAGS.batch_size)

        # Build a Graph that computes the predictions from the inference model.
        # The network reads from these placeholders; batches are fetched from
        # the runners with sess.run and fed back in, which lets one network
        # serve both training and validation data.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        # Turn the comma-separated epoch boundaries of FLAGS.lr_step_epoch
        # into step boundaries: epochs * num_train_instance / batch_size.
        # Under Python 2, map() yields a list, which is what get_lr receives.
        lr_decay_steps = map(float, FLAGS.lr_step_epoch.split(','))
        lr_decay_steps = map(int, [
            s * FLAGS.num_train_instance / FLAGS.batch_size
            for s in lr_decay_steps
        ])
        # NOTE(review): only the HParams construction sits inside this device
        # scope; the network itself is built after the scope has closed -
        # confirm whether the model was meant to be placed on /GPU:0 as well.
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_classes=FLAGS.num_classes,
                                num_residual_units=FLAGS.num_residual_units,
                                k=FLAGS.k,
                                weight_decay=FLAGS.l2_weight,
                                ngroups1=FLAGS.ngroups1,
                                ngroups2=FLAGS.ngroups2,
                                ngroups3=FLAGS.ngroups3,
                                gamma1=FLAGS.gamma1,
                                gamma2=FLAGS.gamma2,
                                gamma3=FLAGS.gamma3,
                                dropout_keep_prob=FLAGS.dropout_keep_prob,
                                momentum=FLAGS.momentum,
                                bn_no_scale=FLAGS.bn_no_scale,
                                weighted_group_loss=FLAGS.weighted_group_loss,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()
        print('Number of Weights: %d' % network._weights)
        print('FLOPs: %d' % network._flops)

        train_summary_op = tf.summary.merge_all()  # Summaries(training)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        # Initialise everything first; a restore below overwrites the
        # variables it covers.
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            # Full restore: resume from the step stored in the checkpoint.
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        elif FLAGS.basemodel:
            # Define a different saver to load model checkpoints
            # Select only base variables (exclude split layers)
            # Variables whose name contains "Momentum", "group" or
            # "global_step" keep their freshly initialised values.
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [
                var for var in variables if not "Momentum" in var.name
                and not "group" in var.name and not "global_step" in var.name
            ]
            # vars_restore = [var for var in variables
            # if not "alpha" in var.name and
            # not "fc_beta" in var.name and
            # not "unit_3" in var.name and
            # not "unit_last" in var.name and
            # not "logits" in var.name and
            # not "Momentum" in var.name and
            # not "global_step" in var.name]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)
        else:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )

        # Start queue runners & summary_writer
        cifar100_train.start_threads(sess, n_threads=20)
        cifar100_val.start_threads(sess, n_threads=1)

        # os.mkdir creates a single directory level only, so the parent of
        # FLAGS.train_dir must already exist.
        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        # Summaries go to a sub-directory of train_dir named after the global
        # step at start-up.
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
            sess.graph)

        # Training!
        # Best validation accuracy seen during this run (not restored from a
        # checkpoint).
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            # Validation runs before the training step of the same iteration.
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    # Pull one batch out of the val runner, then feed it
                    # through the placeholders.
                    val_images_val, val_labels_val = sess.run(
                        [val_images, val_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: val_images_val,
                            labels: val_labels_val
                        })
                    val_loss += loss_value
                    val_acc += acc_value
                # Mean over the val_iter validation batches.
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, val_loss, val_acc))

                # Hand-built summary so that validation metrics share the
                # writer (and step axis) with the training summaries.
                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc',
                                      simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            # The learning rate is computed in Python and fed in each step.
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            # The timed span covers fetching the batch and the training run.
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network.train_op, network.loss, network.acc, train_summary_op],
                             feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val})
            duration = time.time() - start_time

            # Abort as soon as the loss diverges.
            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            # `step > init_step` skips the first iteration of this run; the
            # final step is always saved.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Zero-timeout poll of stdin: when a character is waiting and it
            # is 'b', call embed() (presumably an interactive shell - TODO
            # confirm against the file's imports).
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()

            # Plot grouped weight matrices as image summary
            # Channel widths used by the reshape calls below.
            # NOTE(review): this list is rebuilt on every step although it
            # only depends on FLAGS.k.
            filters = [16, 16 * FLAGS.k, 32 * FLAGS.k, 64 * FLAGS.k]
            if FLAGS.group_summary_interval is not None:
                if step % FLAGS.group_summary_interval == 0:
                    img_summaries = []

                    # --- Split 1: fully-connected logits layer ---
                    if FLAGS.ngroups1 > 1:
                        logits_weights = get_var_value('logits/fc/weights',
                                                       sess)
                        split_p1 = softmax(
                            get_var_value('group/split_p1/alpha', sess))
                        split_q1 = softmax(
                            get_var_value('group/split_q1/alpha', sess))
                        # Order columns by their arg-max row so that columns
                        # assigned to the same row end up next to each other.
                        feature_indices = np.argsort(
                            np.argmax(split_p1, axis=0))
                        class_indices = np.argsort(np.argmax(split_q1, axis=0))

                        # Each row is repeated 20 times along axis 0,
                        # presumably to make the matrix tall enough to read
                        # as an image.
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_p1[:, feature_indices],
                                          20,
                                          axis=0), 'split_p1'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_q1[:, class_indices],
                                          20,
                                          axis=0), 'split_q1'))
                        img_summaries.append(
                            img_to_summary(
                                np.abs(logits_weights[feature_indices, :]
                                       [:, class_indices]), 'logits'))

                    # --- Split 2: kernels of unit_3_0 / unit_3_1 ---
                    if FLAGS.ngroups2 > 1:
                        unit_3_0_shortcut = get_var_value(
                            'unit_3_0/shortcut/kernel', sess)
                        unit_3_0_conv_1 = get_var_value(
                            'unit_3_0/conv_1/kernel', sess)
                        unit_3_0_conv_2 = get_var_value(
                            'unit_3_0/conv_2/kernel', sess)
                        unit_3_1_conv_1 = get_var_value(
                            'unit_3_1/conv_1/kernel', sess)
                        unit_3_1_conv_2 = get_var_value(
                            'unit_3_1/conv_2/kernel', sess)
                        split_p2 = softmax(
                            get_var_value('group/split_p2/alpha', sess))
                        # NOTE(review): split_p1 is only assigned in the
                        # `ngroups1 > 1` branch above; with ngroups2 > 1 and
                        # ngroups1 <= 1 this name would be undefined.
                        split_q2 = _merge_split_q(
                            split_p1,
                            _get_even_merge_idxs(FLAGS.ngroups1,
                                                 FLAGS.ngroups2))
                        split_r21 = softmax(
                            get_var_value('group/split_r21/alpha', sess))
                        split_r22 = softmax(
                            get_var_value('group/split_r22/alpha', sess))
                        feature_indices1 = np.argsort(
                            np.argmax(split_p2, axis=0))
                        feature_indices2 = np.argsort(
                            np.argmax(split_q2, axis=0))
                        feature_indices3 = np.argsort(
                            np.argmax(split_r21, axis=0))
                        feature_indices4 = np.argsort(
                            np.argmax(split_r22, axis=0))
                        # Each kernel is re-ordered along axes 2 and 3 with
                        # the index arrays above, transposed with
                        # [2, 0, 3, 1] and flattened into a 2-D image.
                        # Assumes a (kh, kw, in, out) layout with a 1x1
                        # shortcut and 3x3 convolutions, as implied by the
                        # reshape sizes - TODO confirm.
                        unit_3_0_shortcut_img = np.abs(
                            unit_3_0_shortcut[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[2], filters[3]))
                        unit_3_0_conv_1_img = np.abs(
                            unit_3_0_conv_1[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices3].transpose([2, 0, 3, 1]).reshape(
                                 filters[2] * 3, filters[3] * 3))
                        unit_3_0_conv_2_img = np.abs(
                            unit_3_0_conv_2[:, :, feature_indices3, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[3] * 3, filters[3] * 3))
                        unit_3_1_conv_1_img = np.abs(
                            unit_3_1_conv_1[:, :, feature_indices2, :]
                            [:, :, :,
                             feature_indices4].transpose([2, 0, 3, 1]).reshape(
                                 filters[3] * 3, filters[3] * 3))
                        unit_3_1_conv_2_img = np.abs(
                            unit_3_1_conv_2[:, :, feature_indices4, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[3] * 3, filters[3] * 3))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_p2[:, feature_indices1],
                                          20,
                                          axis=0), 'split_p2'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r21[:, feature_indices3],
                                          20,
                                          axis=0), 'split_r21'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r22[:, feature_indices4],
                                          20,
                                          axis=0), 'split_r22'))
                        img_summaries.append(
                            img_to_summary(unit_3_0_shortcut_img,
                                           'unit_3_0/shortcut_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_3_0_conv_1_img,
                                           'unit_3_0/conv_1_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_3_0_conv_2_img,
                                           'unit_3_0/conv_2_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_3_1_conv_1_img,
                                           'unit_3_1/conv_1_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_3_1_conv_2_img,
                                           'unit_3_1/conv_2_kernel'))

                    # --- Split 3: kernels of unit_2_0 / unit_2_1 ---
                    if FLAGS.ngroups3 > 1:
                        unit_2_0_shortcut = get_var_value(
                            'unit_2_0/shortcut/kernel', sess)
                        unit_2_0_conv_1 = get_var_value(
                            'unit_2_0/conv_1/kernel', sess)
                        unit_2_0_conv_2 = get_var_value(
                            'unit_2_0/conv_2/kernel', sess)
                        unit_2_1_conv_1 = get_var_value(
                            'unit_2_1/conv_1/kernel', sess)
                        unit_2_1_conv_2 = get_var_value(
                            'unit_2_1/conv_2/kernel', sess)
                        split_p3 = softmax(
                            get_var_value('group/split_p3/alpha', sess))
                        # NOTE(review): split_p2 is only assigned in the
                        # `ngroups2 > 1` branch above; with ngroups3 > 1 and
                        # ngroups2 <= 1 this name would be undefined.
                        split_q3 = _merge_split_q(
                            split_p2,
                            _get_even_merge_idxs(FLAGS.ngroups2,
                                                 FLAGS.ngroups3))
                        split_r31 = softmax(
                            get_var_value('group/split_r31/alpha', sess))
                        split_r32 = softmax(
                            get_var_value('group/split_r32/alpha', sess))
                        feature_indices1 = np.argsort(
                            np.argmax(split_p3, axis=0))
                        feature_indices2 = np.argsort(
                            np.argmax(split_q3, axis=0))
                        feature_indices3 = np.argsort(
                            np.argmax(split_r31, axis=0))
                        feature_indices4 = np.argsort(
                            np.argmax(split_r32, axis=0))
                        # Same re-order / transpose / reshape scheme as for
                        # split 2, using filters[1] and filters[2] as the
                        # channel widths.
                        unit_2_0_shortcut_img = np.abs(
                            unit_2_0_shortcut[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[1], filters[2]))
                        unit_2_0_conv_1_img = np.abs(
                            unit_2_0_conv_1[:, :, feature_indices1, :]
                            [:, :, :,
                             feature_indices3].transpose([2, 0, 3, 1]).reshape(
                                 filters[1] * 3, filters[2] * 3))
                        unit_2_0_conv_2_img = np.abs(
                            unit_2_0_conv_2[:, :, feature_indices3, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[2] * 3, filters[2] * 3))
                        unit_2_1_conv_1_img = np.abs(
                            unit_2_1_conv_1[:, :, feature_indices2, :]
                            [:, :, :,
                             feature_indices4].transpose([2, 0, 3, 1]).reshape(
                                 filters[2] * 3, filters[2] * 3))
                        unit_2_1_conv_2_img = np.abs(
                            unit_2_1_conv_2[:, :, feature_indices4, :]
                            [:, :, :,
                             feature_indices2].transpose([2, 0, 3, 1]).reshape(
                                 filters[2] * 3, filters[2] * 3))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_p3[:, feature_indices1],
                                          20,
                                          axis=0), 'split_p3'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r31[:, feature_indices3],
                                          20,
                                          axis=0), 'split_r31'))
                        img_summaries.append(
                            img_to_summary(
                                np.repeat(split_r32[:, feature_indices4],
                                          20,
                                          axis=0), 'split_r32'))
                        img_summaries.append(
                            img_to_summary(unit_2_0_shortcut_img,
                                           'unit_2_0/shortcut_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_2_0_conv_1_img,
                                           'unit_2_0/conv_1_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_2_0_conv_2_img,
                                           'unit_2_0/conv_2_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_2_1_conv_1_img,
                                           'unit_2_1/conv_1_kernel'))
                        img_summaries.append(
                            img_to_summary(unit_2_1_conv_2_img,
                                           'unit_2_1/conv_2_kernel'))

                    # Write everything collected above as one Summary proto.
                    if img_summaries:  # If not empty
                        img_summary = tf.Summary(value=img_summaries)
                        summary_writer.add_summary(img_summary, step)
                        summary_writer.flush()
예제 #19
0
    'val_dir',
    '/home/pirate03/hobotrl_data/playground/initialD/exp/record_rule_scenes_obj80_vec_rewards_docker005_no_early_stopping_all_green/valid',
    """Path to initialD the test dataset""")
# Number of target classes; the default matches the hard-coded value used for
# `hp` and `labels` below.
tf.app.flags.DEFINE_integer('num_classes', 3,
                            """Number of classes in the dataset.""")
# Training Configuration
tf.app.flags.DEFINE_string('log_dir', './resnet/val_ckp',
                           """Directory where to write log and checkpoint.""")

# Copy the flag values that are used below into module-level names.
FLAGS = tf.app.flags.FLAGS
val_dir = FLAGS.val_dir
log_dir = FLAGS.log_dir

# NOTE(review): the hyper-parameters are literals here; in particular
# num_classes=3 repeats the flag default instead of reading FLAGS.num_classes,
# so overriding the flag does not change the model - confirm this is intended.
hp = resnet.HParams(batch_size=64,
                    num_gpus=1,
                    num_classes=3,
                    weight_decay=0.001,
                    momentum=0.9,
                    finetune=True)
# The validation network is built at module level, in the default graph.
global_step = tf.Variable(0, trainable=False, name='global_step')
network_val = resnet.ResNet(hp, global_step, name="val")
network_val.build_model()
# Input geometry: presumably (height, width, 3 channels per stacked frame)
# with `stack_num` frames stacked along the channel axis - TODO confirm.
stack_num = 3
state_shape = (256, 256, 3 * stack_num)
# Class ids, one per class of the 3-class model configured above.
labels = [0, 1, 2]

graph = tf.get_default_graph()

# Created after build_model() so that it covers the variables defined above.
init_op = tf.global_variables_initializer()

sv = tf.train.Supervisor(
    graph=graph,
예제 #20
0
def train():
    """Evaluate a ResNet checkpoint on the test list and pickle the metrics.

    Despite the name, no training op is built or run.  The routine:

    1. prints the flag configuration,
    2. builds the test input pipeline and the ResNet graph,
    3. restores ``FLAGS.checkpoint`` when one is given,
    4. runs ``FLAGS.test_iter`` batches with ``is_train`` fed as False while
       accumulating loss, accuracy, wall time and a confusion matrix,
    5. pickles accuracy, confusion matrix and timings to
       ``FLAGS.output_file``.
    """

    def _report(heading, entries):
        # One bracketed heading, then one formatted line per (template, value).
        print(heading)
        for template, value in entries:
            print(template % (value,))

    _report('[Dataset Configuration]', [
        ('\tImageNet test root: %s', FLAGS.test_image_root),
        ('\tImageNet test list: %s', FLAGS.test_dataset),
        ('\tNumber of classes: %d', FLAGS.num_classes),
        ('\tNumber of test images: %d', FLAGS.num_test_instance),
    ])
    _report('[Network Configuration]', [
        ('\tBatch size: %d', FLAGS.batch_size),
        ('\tCheckpoint file: %s', FLAGS.checkpoint),
    ])
    _report('[Optimization Configuration]', [
        ('\tL2 loss weight: %f', FLAGS.l2_weight),
        ('\tThe momentum optimizer: %f', FLAGS.momentum),
        ('\tInitial learning rate: %f', FLAGS.initial_lr),
        ('\tEpochs per lr step: %s', FLAGS.lr_step_epoch),
        ('\tLearning rate decay: %f', FLAGS.lr_decay),
    ])
    _report('[Evaluation Configuration]', [
        ('\tOutput file path: %s', FLAGS.output_file),
        ('\tTest iterations: %d', FLAGS.test_iter),
        ('\tSteps per displaying info: %d', FLAGS.display),
        ('\tGPU memory fraction: %f', FLAGS.gpu_fraction),
        ('\tLog device placement: %d', FLAGS.log_device_placement),
    ])

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Test input pipeline, created under an explicit CPU device scope.
        print('Load ImageNet dataset')
        with tf.device('/cpu:0'):
            print('\tLoading test data from %s' % FLAGS.test_dataset)
            with tf.variable_scope('test_image'):
                test_images, test_labels = data_input.inputs(
                    FLAGS.test_image_root, FLAGS.test_dataset,
                    FLAGS.batch_size, False,
                    num_threads=1, center_crop=True)

        # The model reads from placeholders: every batch is first fetched from
        # the input pipeline and then fed back in through feed_dict.
        image_shape = [
            FLAGS.batch_size,
            data_input.IMAGE_HEIGHT,
            data_input.IMAGE_WIDTH,
            3,
        ]
        images = tf.placeholder(tf.float32, image_shape)
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Only the hyper-parameter construction sits inside this device scope;
        # the network itself is built after the scope has been left.
        with tf.device('/GPU:0'):
            hp = resnet.HParams(batch_size=FLAGS.batch_size,
                                num_classes=FLAGS.num_classes,
                                weight_decay=FLAGS.l2_weight,
                                momentum=FLAGS.momentum,
                                finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        print('\tNumber of Weights: %d' % network._weights)
        print('\tFLOPs: %d' % network._flops)

        # Variable initializer, created before the session is opened.
        init = tf.initialize_all_variables()

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
        session_config = tf.ConfigProto(
            gpu_options=gpu_options,
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_config)
        # NOTE: for interactive debugging, sess can be wrapped here with
        # tensorflow.python.debug.LocalCLIDebugWrapperSession.

        sess.run(init)

        # Restore weights on top of the fresh initialization, if requested.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is None:
            print(
                'No checkpoint file of basemodel found. Start from the scratch.'
            )
        else:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)

        # The input pipeline needs its queue runners before the first fetch.
        tf.train.start_queue_runners(sess=sess)

        # Evaluation loop.
        loss_sum = 0.0
        acc_sum = 0.0
        elapsed_total = 0.0
        # Rows are indexed by true label, columns by predicted label.
        confusion = np.zeros((FLAGS.num_classes, FLAGS.num_classes),
                             dtype=np.int32)
        progress_fmt = (
            '%s: iter %d, loss=%.4f, acc=%.4f (%.1f examples/sec; %.3f sec/batch)'
        )
        for it in range(FLAGS.test_iter):
            batch_images, batch_labels = sess.run([test_images, test_labels])

            # Only the forward pass is timed, not the batch fetch above.
            tic = time.time()
            batch_loss, batch_acc, batch_preds = sess.run(
                [network.loss, network.acc, network.preds],
                feed_dict={
                    network.is_train: False,
                    images: batch_images,
                    labels: batch_labels
                })
            elapsed = time.time() - tic

            loss_sum += batch_loss
            acc_sum += batch_acc
            elapsed_total += elapsed
            for true_cls, pred_cls in zip(batch_labels, batch_preds):
                confusion[true_cls, pred_cls] += 1

            if it % FLAGS.display != 0:
                continue
            throughput = FLAGS.batch_size / elapsed
            print(progress_fmt % (datetime.now(), it, batch_loss, batch_acc,
                                  throughput, float(elapsed)))

        # The mean loss is computed alongside the accuracy but is not part of
        # the saved result.
        mean_loss = loss_sum / FLAGS.test_iter
        mean_acc = acc_sum / FLAGS.test_iter

        # Print and save results
        sec_per_image = elapsed_total / FLAGS.test_iter / FLAGS.batch_size
        print('Done! Acc: %.6f, Test time: %.3f sec, %.7f sec/example' %
              (mean_acc, elapsed_total, sec_per_image))
        print('Saving result... ')
        result = dict([
            ('accuracy', mean_acc),
            ('confusion_matrix', confusion),
            ('test_time', elapsed_total),
            ('sec_per_image', sec_per_image),
        ])
        with open(FLAGS.output_file, 'wb') as fd:
            pickle.dump(result, fd)
        print('done!')
예제 #21
0
def train():
    """Train an (optionally split) ResNet with periodic testing and checkpoints.

    The routine prints the flag configuration, builds the train/test input
    pipelines and the ResNet graph, then restores weights from the newest
    checkpoint in ``FLAGS.train_dir``.  If none exists it falls back to the
    loadable variables of the baseline checkpoint in ``FLAGS.baseline_dir``,
    and otherwise starts from scratch.  It then runs the training loop up to
    ``FLAGS.max_steps``:

    * every ``FLAGS.test_interval`` steps the model is evaluated on
      ``FLAGS.test_iter`` test batches and the result is written as summaries;
    * every ``FLAGS.display`` steps the training loss/accuracy are logged;
    * every ``FLAGS.checkpoint_interval`` steps, and at the last step, a
      checkpoint is saved to ``FLAGS.train_dir``.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tImageNet training root: %s' % FLAGS.train_image_root)
    print('\tImageNet training list: %s' % FLAGS.train_dataset)
    print('\tImageNet test root: %s' % FLAGS.test_image_root)
    print('\tImageNet test list: %s' % FLAGS.test_dataset)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tNumber of GPUs: %d' % FLAGS.num_gpu)
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tSplitted Network: %s' % FLAGS.split)
    if FLAGS.split:
        print('\tClustering path: %s' % FLAGS.cluster_path)
        print('\tNo logit map: %s' % FLAGS.no_logit_map)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # The clustering is consumed only by the split network, so it is loaded
    # only when splitting is requested; a non-split run no longer needs a
    # valid cluster file.  Pickles are read in binary mode.
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    clustering = None
    if FLAGS.split:
        with open(FLAGS.cluster_path, 'rb') as fd:
            clustering = pickle.load(fd)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Input pipelines for the training and test lists.  Each batch holds
        # the examples for all GPUs at once.
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.distorted_inputs(
                FLAGS.train_image_root, FLAGS.train_dataset,
                FLAGS.batch_size * FLAGS.num_gpu, True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(
                FLAGS.test_image_root, FLAGS.test_dataset,
                FLAGS.batch_size * FLAGS.num_gpu, False)

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(tf.float32, [
            FLAGS.batch_size * FLAGS.num_gpu, data_input.IMAGE_HEIGHT,
            data_input.IMAGE_WIDTH, 3
        ])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size * FLAGS.num_gpu])

        # Convert the comma-separated decay epochs into iteration numbers.
        # A real list is built (not ``map``) so that it can be printed and
        # re-read on every step under Python 3 as well.
        lr_decay_steps = [
            int(float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size /
                FLAGS.num_gpu) for epoch in FLAGS.lr_step_epoch.split(',')
        ]
        print('Learning rate decays at iter: %s' % str(lr_decay_steps))

        # Build model
        hp = resnet.HParams(num_gpu=FLAGS.num_gpu,
                            batch_size=FLAGS.batch_size,
                            split=FLAGS.split,
                            num_classes=FLAGS.num_classes,
                            weight_decay=FLAGS.l2_weight,
                            momentum=FLAGS.momentum,
                            no_logit_map=FLAGS.no_logit_map)
        network = resnet.ResNet(hp, images, labels, global_step)
        if FLAGS.split:
            network.set_clustering(clustering)
        network.build_model()
        print('%d flops' % network._flops)
        print('%d params' % network._weights)
        network.build_train_op()

        # Summaries(training)
        # NOTE: this block uses the pre-1.0 summary/initializer API names
        # (merge_all_summaries, SummaryWriter, initialize_all_variables).
        train_summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables(),
                               max_to_keep=10000,
                               write_version=tf.train.SaverDef.V2)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Resume: the step is taken from the "model.ckpt-<step>" suffix of
            # the checkpoint file name.
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            ckpt_base = tf.train.get_checkpoint_state(FLAGS.baseline_dir)
            if ckpt_base and ckpt_base.model_checkpoint_path:
                # Check loadable variables(variable with same name and same
                # shape) and load them only
                print('No checkpoint file found. Start from the baseline.')
                loadable_vars = utils._get_loadable_vars(
                    ckpt_base.model_checkpoint_path, verbose=True)
                saver_base = tf.train.Saver(loadable_vars,
                                            write_version=tf.train.SaverDef.V2)
                saver_base.restore(sess, ckpt_base.model_checkpoint_path)
            else:
                print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for _ in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train.  The learning rate is computed in Python for the current
            # step and fed into the graph.
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps,
                              step)
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = sess.run(
                [
                    network.train_op, network.loss, network.acc,
                    train_summary_op
                ],
                feed_dict={
                    network.is_train: True,
                    network.lr: lr_value,
                    images: train_images_val,
                    labels: train_labels_val
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Training loss became NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                # One step consumes a batch for every GPU, so the throughput
                # must count all of them.
                num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpu
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and at the final step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
예제 #22
0
def train():
    """Shrink a trained ResNet by kernel clustering and evaluate the result.

    Despite the name, no optimization step is run.  The routine works in two
    phases, each in its own graph:

    1. Build the original network, restore it from ``FLAGS.ckpt_path`` (a
       checkpoint directory or file) and read the kernels and batch-norm
       parameters of every residual unit.  These are clustered down to the
       widths implied by ``FLAGS.new_k``; every other variable is copied
       as is.
    2. Build a new network initialised from those parameters, run
       ``FLAGS.test_iter`` batches through it, print per-class correct/wrong
       counts and accuracies, and optionally write the same table to
       ``FLAGS.output``.

    Exits the process with status 1 when no checkpoint can be found.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # ------------------------------------------------------------------
    # Phase 1: load the trained network and derive the clustered parameters.
    # ------------------------------------------------------------------
    with tf.Graph().as_default():

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None)
        network.build_model()

        # Build an initialization operation to run below.  The non-deprecated
        # names are used, matching the tf.global_variables() call further down.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Restore the trained weights; ckpt_path may be a directory or a file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            # Restores from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Collect, per residual unit, the first conv kernel (to be clustered),
        # the second conv kernel (to be merged accordingly) and the batch-norm
        # parameters of the second conv.  Entries are ordered block by block,
        # unit by unit.
        graph = tf.get_default_graph()
        block_num = 3
        old_kernels_to_cluster = []
        old_kernels_to_add = []
        old_batch_norm = []
        for i in range(1, block_num + 1):
            for j in range(FLAGS.num_residual_units):
                old_kernels_to_cluster.append(get_kernel(i, j, 1, graph, sess))
                old_kernels_to_add.append(get_kernel(i, j, 2, graph, sess))
                old_batch_norm.append(get_batch_norm(i, j, 2, graph, sess))

        new_params = []
        # new_width[0] is the fixed 16; entries 1..3 are the three blocks
        # scaled by new_k.
        new_width = [
            16,
            int(16 * FLAGS.new_k),
            int(32 * FLAGS.new_k),
            int(64 * FLAGS.new_k)
        ]
        units = zip(old_kernels_to_cluster, old_kernels_to_add, old_batch_norm)
        for i, (kernel_to_cluster, kernel_to_add, batch_norm) in enumerate(units):
            # Each block contributes num_residual_units consecutive entries,
            # so the block index is i // num_residual_units.  The previous
            # hard-coded divisor of 4 was only right for 4 units per block.
            cluster_num = new_width[i // FLAGS.num_residual_units + 1]
            cluster_kernels, cluster_indices = cluster_kernel(
                kernel_to_cluster, cluster_num)
            add_kernels = add_kernel(kernel_to_add, cluster_indices,
                                     cluster_num)
            cluster_batchs_norm = cluster_batch_norm(batch_norm,
                                                     cluster_indices,
                                                     cluster_num)
            # Order matters: it has to match the order in which the matching
            # variables are encountered in tf.global_variables() below.
            new_params.append(cluster_kernels)
            new_params.extend(cluster_batchs_norm[p]
                              for p in range(BATCH_NORM_PARAM_NUM))
            new_params.append(add_kernels)

        # Pair every variable with its initial value for the new network:
        # a clustered parameter if its name matches the update pattern (and
        # not the skip pattern), its current value otherwise.
        init_params = []
        new_param_index = 0
        for var in tf.global_variables():
            update_match = UPDATE_PARAM_REGEX.match(var.name)
            skip_match = SKIP_PARAM_REGEX.match(var.name)
            if update_match and not skip_match:
                print("update {}".format(var.name))
                init_params.append((new_params[new_param_index], var.name))
                new_param_index += 1
            else:
                print("not update {}".format(var.name))
                var_vector = sess.run(var)
                init_params.append((var_vector, var.name))

        # close old graph
        sess.close()
    tf.reset_default_graph()

    # ------------------------------------------------------------------
    # Phase 2: build the shrunken network and evaluate it.
    # ------------------------------------------------------------------
    with tf.Graph().as_default():
        # The CIFAR-100 dataset
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(
                FLAGS.data_dir,
                FLAGS.batch_size,
                train_mode=FLAGS.train_data,
                num_threads=1)

        # The class labels
        with open(os.path.join(FLAGS.data_dir, 'fine_label_names.txt')) as fd:
            classes = [line.strip() for line in fd]

        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        new_network = resnet.ResNet(hp, images, labels, None, init_params,
                                    FLAGS.new_k)
        new_network.build_model()

        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Testing!
        # result_ll[c] = [correct count, wrong count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        # Scalar accumulator (this used to be the tuple ``0.0, 0.0``).
        test_loss = 0.0
        for _ in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value, acc_value = sess.run(
                [new_network.preds, new_network.loss, new_network.acc],
                feed_dict={
                    new_network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for j in range(FLAGS.batch_size):
                # Column 0 counts correct predictions, column 1 wrong ones.
                column = 0 if test_labels_val[j] == preds_val[j] else 1
                result_ll[test_labels_val[j] % FLAGS.num_classes][column] += 1
        # Mean loss over the evaluated batches; currently not reported.
        test_loss /= FLAGS.test_iter

        # Summary display & output.  A class that never occurred in the
        # evaluated batches gets accuracy 0.0 instead of a division by zero.
        acc_list = [
            float(r[0]) / float(r[0] + r[1]) if (r[0] + r[1]) else 0.0
            for r in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        acc_total = float(result_total[0]) / np.sum(result_total)

        print('Class    \t\t\tT\tF\tAcc.')
        format_str = '%-31s %7d %7d %.5f'
        for i in range(FLAGS.num_classes):
            print(format_str %
                  (classes[i], result_ll[i][0], result_ll[i][1], acc_list[i]))
        print(format_str %
              ('(Total)', result_total[0], result_total[1], acc_total))

        # Output to file(if specified)
        if FLAGS.output.strip():
            line_format = '%-31s %7d %7d %.5f\n'
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    # Spaces in class names are replaced with dashes in the
                    # file output.
                    fd.write(line_format %
                             (classes[i].replace(' ', '-'), t, f, acc_list[i]))
                fd.write(
                    line_format %
                    ('(Total)', result_total[0], result_total[1], acc_total))
예제 #23
0
def train():
    """Train a ResNet on ImageNet with periodic evaluation and checkpointing.

    Prints the flag configuration, builds the input pipelines and the model,
    resumes from the newest checkpoint in ``FLAGS.train_dir`` when one exists,
    and then loops from the resumed step up to ``FLAGS.max_steps``:

    * every ``FLAGS.test_interval`` steps: average loss/accuracy over
      ``FLAGS.test_iter`` validation batches, logged and written as summaries;
    * every step: one optimisation step on a training batch;
    * every ``FLAGS.display`` steps: log the training metrics and write the
      merged training summaries;
    * every ``FLAGS.checkpoint_interval`` steps and at the last step: save a
      checkpoint.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """

    def _report(heading, entries):
        # One bracketed heading, then one formatted line per (template, value).
        print(heading)
        for template, value in entries:
            print(template % (value,))

    _report('[Dataset Configuration]', [
        ('\tNumber of classes: %d', FLAGS.num_classes),
        ('\tNumber of training images: %d', FLAGS.num_train_instance),
        ('\tNumber of test images: %d', FLAGS.num_test_instance),
    ])
    _report('[Network Configuration]', [
        ('\tResidual blocks per group: %d', FLAGS.num_residual_units),
        ('\tNetwork width multiplier: %d', FLAGS.k),
    ])
    _report('[Optimization Configuration]', [
        ('\tL2 loss weight: %f', FLAGS.l2_weight),
        ('\tThe momentum optimizer: %f', FLAGS.momentum),
        ('\tInitial learning rate: %f', FLAGS.initial_lr),
        ('\tEpochs per lr step: %f', FLAGS.lr_step_epoch),
        ('\tLearning rate decay: %f', FLAGS.lr_decay),
    ])
    _report('[Training Configuration]', [
        ('\tTrain dir: %s', FLAGS.train_dir),
        ('\tTraining max steps: %d', FLAGS.max_steps),
        ('\tSteps per displaying info: %d', FLAGS.display),
        ('\tSteps per testing: %d', FLAGS.test_interval),
        ('\tSteps during testing: %d', FLAGS.test_iter),
        ('\tSteps per saving checkpoints: %d', FLAGS.checkpoint_interval),
        ('\tGPU memory fraction: %f', FLAGS.gpu_fraction),
        ('\tLog device placement: %d', FLAGS.log_device_placement),
    ])
    sys.stdout.flush()

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # ImageNet input pipelines.
        # NOTE(review): the validation split also goes through
        # distorted_inputs; confirm this is intended.
        with tf.variable_scope('train_image'):
            train_images, train_labels = image_processing.distorted_inputs(
                dataset.Dataset('imagenet', 'train'), num_preprocess_threads=4)
        with tf.variable_scope('test_image'):
            test_images, test_labels = image_processing.distorted_inputs(
                dataset.Dataset('imagenet', 'validation'),
                num_preprocess_threads=4)

        # The model reads from placeholders; batches are fetched from the
        # pipelines above and fed back in.
        image_shape = [
            FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3
        ]
        images = tf.placeholder(tf.float32, image_shape)
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Model and training op.  The learning-rate decay interval is given in
        # epochs and converted to steps here.
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()
        network.count_trainable_params()

        # Merged training summaries and the variable initializer.
        train_summary_op = tf.summary.merge_all()
        init = tf.initialize_all_variables()

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
        session_config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_config)
        sess.run(init)

        # Resume from the newest checkpoint in the train dir, if any.  The
        # step is taken from the "-<step>" suffix of the checkpoint file name.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint file found. Start from the scratch.')
        else:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            ckpt_name = ckpt.model_checkpoint_path.split('/')[-1]
            init_step = int(ckpt_name.split('-')[-1])
        sys.stdout.flush()

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        def _run_test_pass():
            # Average loss and accuracy over FLAGS.test_iter validation batches.
            loss_sum, acc_sum = 0.0, 0.0
            for _ in range(FLAGS.test_iter):
                batch_images, batch_labels = sess.run(
                    [test_images, test_labels])
                # Labels are shifted down by one before being fed
                # (presumably the pipeline yields 1-based ids; TODO confirm).
                batch_labels -= 1
                batch_loss, batch_acc = sess.run(
                    [network.loss, network.acc],
                    feed_dict={
                        network.is_train: False,
                        images: batch_images,
                        labels: batch_labels
                    })
                loss_sum += batch_loss
                acc_sum += batch_acc
            return loss_sum / FLAGS.test_iter, acc_sum / FLAGS.test_iter

        test_log_fmt = '%s: (Test)     step %d, loss=%.4f, acc=%.4f'
        train_log_fmt = (
            '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
            'sec/batch)')

        # Training!
        best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Periodic evaluation (also runs on the very first step when it is
            # a multiple of the test interval).
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = _run_test_pass()
                best_acc = max(best_acc, test_acc)
                print(test_log_fmt % (datetime.now(), step, test_loss, test_acc))
                sys.stdout.flush()

                test_summary = tf.Summary()
                for tag, value in (('test/loss', test_loss),
                                   ('test/acc', test_acc),
                                   ('test/best_acc', best_acc)):
                    test_summary.value.add(tag=tag, simple_value=value)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # One optimisation step; the timing covers both the batch fetch
            # and the train op.
            tic = time.time()
            batch_images, batch_labels = sess.run([train_images, train_labels])
            batch_labels -= 1
            fetches = [
                network.train_op, network.lr, network.loss, network.acc,
                train_summary_op
            ]
            _, lr_value, loss_value, acc_value, summary_str = sess.run(
                fetches,
                feed_dict={
                    network.is_train: True,
                    images: batch_images,
                    labels: batch_labels
                })
            elapsed = time.time() - tic

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                throughput = FLAGS.batch_size / elapsed
                print(train_log_fmt %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       throughput, float(elapsed)))
                sys.stdout.flush()
                summary_writer.add_summary(summary_str, step)

            # Save periodically (never on the step we resumed from) and on the
            # final step.
            periodic_save = (step > init_step
                             and step % FLAGS.checkpoint_interval == 0)
            final_step = (step + 1) == FLAGS.max_steps
            if periodic_save or final_step:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
    elif len(v.shape) == 2:
        model_weights[k] = np.transpose(v)
    else:
        model_weights[k] = v

# Build ResNet-18 model and save parameters
# Everything below is created in a fresh graph rather than the default one.
with tf.Graph().as_default():
    # Non-trainable step counter handed to the network.
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # Inputs are passed as single-element lists of placeholders: a batch of
    # two 224x224x3 float images and two int32 labels.
    # Presumably one list entry per GPU (num_gpus=1 below); TODO confirm.
    images = [tf.placeholder(tf.float32, [2, 224, 224, 3])]
    labels = [tf.placeholder(tf.int32, [2])]

    # Build model
    print("Build ResNet-18 model")
    # batch_size matches the leading dimension of the placeholders above.
    hp = resnet.HParams(batch_size=2,
                        num_gpus=1,
                        num_classes=1000,
                        weight_decay=0.001,
                        momentum=0.9,
                        finetune=False)
    network_train = resnet.ResNet(hp,
                                  images,
                                  labels,
                                  global_step,
                                  name="train")
    network_train.build_model()
    # Report the size figures the network exposes after build_model().
    print('Number of Weights: %d' % network_train._weights)
    print('FLOPs: %d' % network_train._flops)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph.
def cifar10_blackbox(nb_classes=10,
                     batch_size=128,
                     nb_samples=10,
                     l2_weight=0.0001,
                     momentum=0.9,
                     initial_lr=0.1,
                     lr_step_epoch=100.0,
                     lr_decay=0.1,
                     num_residual_units=2,
                     num_train_instance=50000,
                     num_test_instance=10000,
                     k=1,
                     eps=0.3,
                     learning_rate=0.001,
                     nb_epochs=10,
                     holdout=150,
                     data_aug=6,
                     nb_epochs_s=10,
                     lmbda=0.1,
                     binary=False,
                     scale=False,
                     model_path=None,
                     targeted=False,
                     data_dir=None,
                     adv=False,
                     delay=0):
    """
    CIFAR-10 black-box transfer attack, after arxiv.org/abs/1602.02697.

    Builds a ResNet "oracle", restores its weights from ``model_path``,
    trains a substitute model with ``train_sub`` on the first ``holdout``
    test samples, crafts FGSM adversarial examples on the substitute and
    measures the oracle's accuracy on them.

    :param nb_classes: number of classes (oracle outputs, substitute labels)
    :param batch_size: batch size for the oracle hyper-parameters, the
                       substitute training and every evaluation
    :param nb_samples: number of test samples used to build the targeted
                       dataset (only read when ``targeted`` is true)
    :param l2_weight: passed to ``resnet.HParams`` as ``weight_decay``
    :param momentum: passed to ``resnet.HParams``
    :param initial_lr: passed to ``resnet.HParams``
    :param lr_step_epoch: epochs per learning-rate step; converted to
                          ``decay_step`` with ``num_train_instance`` and
                          ``batch_size``
    :param lr_decay: passed to ``resnet.HParams``
    :param num_residual_units: passed to ``resnet.HParams``
    :param num_train_instance: training-set size used for ``decay_step``
    :param num_test_instance: not used by this function
    :param k: passed to ``resnet.HParams``
    :param eps: FGSM ``eps`` (attack uses ``ord=np.inf``)
    :param learning_rate: learning rate handed to ``train_sub``
    :param nb_epochs: not used by this function
    :param holdout: number of leading test samples reserved for training the
                    substitute; they are removed from the evaluation set
    :param data_aug: handed to ``train_sub``
    :param nb_epochs_s: handed to ``train_sub``
    :param lmbda: handed to ``train_sub``
    :param binary: coerced to ``bool`` and passed as the first argument of
                   ``resnet.ResNet``
    :param scale: not used by this function
    :param model_path: checkpoint to restore (required). If the last
                       ``/``-separated component contains ``'model'`` it is
                       restored directly; otherwise it is treated as a
                       directory and its latest checkpoint is restored
    :param targeted: if true, craft targeted adversarial examples with
                     ``build_targeted_dataset`` instead of attacking the
                     whole test set
    :param data_dir: not used by this function
    :param adv: not used by this function
    :param delay: not used by this function
    :return: a dictionary with:
             * ``'bbox'``: black-box model accuracy on the clean test set
             * ``'sub'``: substitute model accuracy on the clean test set
             * ``'bbox_on_sub_adv_ex'``: black-box model accuracy on
               adversarial examples transferred from the substitute model
    :raises ValueError: if ``model_path`` is ``None`` or names a directory
                        without a checkpoint
    :raises RuntimeError: if tutorial setup fails or Keras is not using the
                          TensorFlow backend
    """
    # Fail fast: the oracle is always restored from disk, so a missing path
    # would otherwise only surface after the whole graph has been built.
    if model_path is None:
        raise ValueError("model_path must point to a checkpoint file or a "
                         "directory containing checkpoints.")

    # Set logging level to see debug information
    set_log_level(logging.DEBUG)

    # Dictionary used to keep track and return key accuracies
    accuracies = {}

    # Perform tutorial setup. Checked explicitly rather than with `assert`
    # so that the call is not stripped when Python runs with -O.
    if not setup_tutorial():
        raise RuntimeError("Tutorial setup failed.")

    if not hasattr(backend, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    # Image dimension ordering must be the TensorFlow one, matching the
    # (rows, cols, channels) placeholder layout used below.
    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    # Create TF session and set as Keras backend session
    sess = tf.Session()
    keras.backend.set_session(sess)

    # Get CIFAR10 data (the training split is loaded but not used here)
    X_train, Y_train, X_test, Y_test = data_cifar10_std()

    # Initialize substitute training set reserved for adversary; its labels
    # are reduced to class indices with argmax over axis 1.
    X_sub = X_test[:holdout]
    Y_sub = np.argmax(Y_test[:holdout], axis=1)

    # Redefine test set as remaining samples unavailable to adversaries
    X_test = X_test[holdout:]
    Y_test = Y_test[holdout:]

    print("Y_sub shape=")
    print(Y_sub.shape)

    print("Y_test shape=")
    print(Y_test.shape)

    # CIFAR10-specific dimensions
    img_rows = 32
    img_cols = 32
    channels = 3

    # Seed random number generator so tutorial is reproducible
    rng = np.random.RandomState([2017, 8, 30])

    # Define input and output TF placeholders
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, channels))

    # NOTE: `(None)` is plain None, i.e. these two placeholders have a fully
    # unspecified shape (not a 1-D shape).
    y_oracle = tf.placeholder(tf.int32, shape=(None))
    y_eval = tf.placeholder(tf.int32, shape=(None))
    y_sub = tf.placeholder(tf.float32, shape=(None, nb_classes))

    phase = tf.placeholder(tf.bool, name='phase')

    # Simulate the black-box model locally
    # You could replace this by a remote labeling API for instance
    print("Preparing the WideResNet black-box model.")
    decay_step = lr_step_epoch * num_train_instance / batch_size
    hp = resnet.HParams(batch_size=batch_size,
                        num_classes=nb_classes,
                        num_residual_units=num_residual_units,
                        k=k,
                        weight_decay=l2_weight,
                        initial_lr=initial_lr,
                        decay_step=decay_step,
                        lr_decay=lr_decay,
                        momentum=momentum)

    print(binary)
    binary = bool(binary)
    print(binary)
    network = resnet.ResNet(binary, hp, x, y_oracle, None)
    network.build_model()

    # The oracle is queried through its probability output.
    bbox_preds = network.probs

    init = tf.global_variables_initializer()
    sess.run(init)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)

    # A path whose last component contains 'model' is restored as-is;
    # anything else is treated as a directory holding checkpoints.
    if 'model' in model_path.split('/')[-1]:
        ckpt_path = model_path
    else:
        ckpt_path = tf.train.latest_checkpoint(model_path)
        if ckpt_path is None:
            raise ValueError('No checkpoint file found in the dir [%s]' %
                             model_path)
    saver.restore(sess, ckpt_path)
    # Report the checkpoint that was actually loaded, not just the directory.
    print('restored %s' % ckpt_path)

    eval_params = {'batch_size': batch_size}
    acc = model_eval(sess,
                     x,
                     y_eval,
                     bbox_preds,
                     X_test,
                     Y_test,
                     phase=phase,
                     args=eval_params)
    accuracies['bbox'] = acc
    print('Test accuracy of black-box on legitimate test examples: %.4f' % acc)

    # Train substitute using method from https://arxiv.org/abs/1602.02697
    print("Training the substitute model.")
    train_sub_out = train_sub(sess,
                              x,
                              y_sub,
                              bbox_preds,
                              X_sub,
                              Y_sub,
                              nb_classes,
                              nb_epochs_s,
                              batch_size,
                              learning_rate,
                              data_aug,
                              lmbda,
                              rng=rng,
                              phase=phase)
    model_sub, preds_sub = train_sub_out

    # Evaluate the substitute model on clean test examples
    acc = model_eval(sess,
                     x,
                     y_eval,
                     preds_sub,
                     X_test,
                     Y_test,
                     phase=phase,
                     args=eval_params)
    accuracies['sub'] = acc
    print('Test accuracy of substitute on clean examples: ' + str(acc))

    # Initialize the Fast Gradient Sign Method (FGSM) attack object.
    fgsm_par = {'eps': eps, 'ord': np.inf}
    fgsm = FastGradientMethod(model_sub, sess=sess)

    # Craft adversarial examples using the substitute. The inputs that are
    # attacked and the labels they are scored against must come from the
    # same set, so both are chosen together.
    if targeted:
        from cleverhans.utils import build_targeted_dataset
        adv_inputs, true_labels, adv_ys = build_targeted_dataset(
            X_test, Y_test, np.arange(nb_samples), nb_classes, img_rows,
            img_cols, channels)
        fgsm_par['y_target'] = adv_ys
        adv_np = fgsm.generate_np(adv_inputs, phase, **fgsm_par)
        adv_labels = true_labels
    else:
        adv_np = fgsm.generate_np(X_test, phase, **fgsm_par)
        adv_labels = Y_test

    # Evaluate the accuracy of the "black-box" model on adversarial examples
    adv_acc = model_eval(sess,
                         x,
                         y_eval,
                         bbox_preds,
                         adv_np,
                         adv_labels,
                         phase=phase,
                         args=eval_params)
    print('Test accuracy of oracle on adversarial examples generated '
          'using the substitute: %.4f' % adv_acc)

    cln_acc = model_eval(sess,
                         x,
                         y_eval,
                         bbox_preds,
                         X_test,
                         Y_test,
                         phase=phase,
                         args=eval_params)
    print('Test accuracy of oracle on legitimate test examples: %.4f' %
          cln_acc)

    accuracies['bbox_on_sub_adv_ex'] = adv_acc
    return accuracies