示例#1
0
def train():
    """Cluster the kernels of a trained ResNet and evaluate the shrunken net.

    All settings are read from the module-level ``FLAGS``.

    Phase 1 builds the original network, restores it from ``FLAGS.ckpt_path``
    (a checkpoint directory or a single checkpoint file) and, for every
    residual unit, derives replacement parameters with ``cluster_kernel``,
    ``add_kernel`` and ``cluster_batch_norm``.

    Phase 2 builds a second network in a fresh graph, handing it those
    parameters as ``init_params`` together with ``FLAGS.new_k``, runs
    ``FLAGS.test_iter`` evaluation batches and prints a per-class table of
    correct/wrong counts.  The table is also written to ``FLAGS.output`` when
    that flag is not blank.

    Raises:
        SystemExit: with status 1 if no checkpoint is found at
            ``FLAGS.ckpt_path``.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # ---- Phase 1: restore the original network and compute new parameters.
    with tf.Graph().as_default():

        # The original network is only fed through placeholders here; no
        # input pipeline is needed to read its weights.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = (FLAGS.lr_step_epoch * FLAGS.num_train_instance /
                      FLAGS.batch_size)
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None)
        network.build_model()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Restore the trained weights; FLAGS.ckpt_path may be a directory
        # (latest checkpoint is used) or a checkpoint file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Collect, per residual unit, the values fetched by get_kernel with
        # index 1 and 2 and by get_batch_norm with index 2.
        graph = tf.get_default_graph()
        block_num = 3
        old_kernels_to_cluster = []
        old_kernels_to_add = []
        old_batch_norm = []
        for i in range(1, block_num + 1):
            for j in range(FLAGS.num_residual_units):
                old_kernels_to_cluster.append(get_kernel(i, j, 1, graph, sess))
                old_kernels_to_add.append(get_kernel(i, j, 2, graph, sess))
                old_batch_norm.append(get_batch_norm(i, j, 2, graph, sess))

        new_params = []
        new_width = [
            16,
            int(16 * FLAGS.new_k),
            int(32 * FLAGS.new_k),
            int(64 * FLAGS.new_k)
        ]
        units = zip(old_kernels_to_cluster, old_kernels_to_add, old_batch_norm)
        for i, (to_cluster, to_add, batch_norm) in enumerate(units):
            # The lists hold FLAGS.num_residual_units entries per group, so
            # the group index must be derived from that flag rather than from
            # a fixed unit count.
            cluster_num = new_width[i // FLAGS.num_residual_units + 1]
            cluster_kernels, cluster_indices = cluster_kernel(
                to_cluster, cluster_num)
            add_kernels = add_kernel(to_add, cluster_indices, cluster_num)
            cluster_batchs_norm = cluster_batch_norm(batch_norm,
                                                     cluster_indices,
                                                     cluster_num)
            new_params.append(cluster_kernels)
            for p in range(BATCH_NORM_PARAM_NUM):
                new_params.append(cluster_batchs_norm[p])
            new_params.append(add_kernels)

        # Pair every variable name with its initial value for the new graph.
        # Variables selected by the regexes consume new_params in the order
        # tf.global_variables() yields them; all others keep their current
        # value.
        init_params = []
        new_param_index = 0
        for var in tf.global_variables():
            update_match = UPDATE_PARAM_REGEX.match(var.name)
            skip_match = SKIP_PARAM_REGEX.match(var.name)
            if update_match and not skip_match:
                print("update {}".format(var.name))
                init_params.append((new_params[new_param_index], var.name))
                new_param_index += 1
            else:
                print("not update {}".format(var.name))
                var_vector = sess.run(var)
                init_params.append((var_vector, var.name))

        # close old graph
        sess.close()
    tf.reset_default_graph()

    # ---- Phase 2: build the new graph and evaluate it.
    with tf.Graph().as_default():
        # The CIFAR-100 dataset
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(
                FLAGS.data_dir,
                FLAGS.batch_size,
                train_mode=FLAGS.train_data,
                num_threads=1)

        # The class labels
        with open(os.path.join(FLAGS.data_dir, 'fine_label_names.txt')) as fd:
            classes = [temp.strip() for temp in fd.readlines()]

        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        new_network = resnet.ResNet(hp, images, labels, None, init_params,
                                    FLAGS.new_k)
        new_network.build_model()

        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # The other entry points that read from data_input.input_fn start the
        # queue runners before fetching batches; without this call a
        # queue-based pipeline would block forever on the first fetch.
        tf.train.start_queue_runners(sess=sess)

        # Testing!
        # result_ll[c] == [correct count, wrong count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        test_loss = 0.0  # scalar accumulator (was accidentally a tuple)
        for _ in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value = sess.run(
                [new_network.preds, new_network.loss],
                feed_dict={
                    new_network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for label, pred in zip(test_labels_val, preds_val):
                column = 0 if label == pred else 1  # 0: correct, 1: wrong
                result_ll[label % FLAGS.num_classes][column] += 1
        if FLAGS.test_iter > 0:
            test_loss /= FLAGS.test_iter

        # Summary display & output.  Classes (or runs) without any sample
        # report an accuracy of 0.0 instead of dividing by zero.
        acc_list = [
            float(hit) / float(hit + miss) if (hit + miss) else 0.0
            for hit, miss in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        total_count = np.sum(result_total)
        acc_total = float(result_total[0]) / total_count if total_count else 0.0

        print('Class    \t\t\tT\tF\tAcc.')
        format_str = '%-31s %7d %7d %.5f'
        for i in range(FLAGS.num_classes):
            print(format_str %
                  (classes[i], result_ll[i][0], result_ll[i][1], acc_list[i]))
        print(format_str %
              ('(Total)', result_total[0], result_total[1], acc_total))

        # Output to file(if specified)
        if FLAGS.output.strip():
            line_format = '%-31s %7d %7d %.5f\n'
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    # Spaces in class names are replaced so every row keeps
                    # the same number of whitespace-separated columns.
                    fd.write(line_format %
                             (classes[i].replace(' ', '-'), t, f, acc_list[i]))
                fd.write(
                    line_format %
                    ('(Total)', result_total[0], result_total[1], acc_total))
示例#2
0
def train():
    """Cluster the kernels of a trained ResNet, then fine-tune the result.

    All settings are read from the module-level ``FLAGS``.

    Phase 1 builds the original network, restores it from ``FLAGS.ckpt_path``
    (a checkpoint directory or a single checkpoint file) and, for every
    residual unit, derives replacement parameters with ``cluster_kernel``,
    ``add_kernel`` and ``cluster_batch_norm``.

    Phase 2 builds a second network in a fresh graph, handing it those
    parameters as ``init_params`` together with ``FLAGS.new_k``, resumes from
    the latest checkpoint in ``FLAGS.train_dir`` if there is one, and trains
    up to ``FLAGS.max_steps``.  It evaluates every ``FLAGS.test_interval``
    steps, logs every ``FLAGS.display`` steps and saves a checkpoint every
    ``FLAGS.checkpoint_interval`` steps as well as on the final step.

    Raises:
        SystemExit: with status 1 if no checkpoint is found at
            ``FLAGS.ckpt_path``.
        AssertionError: if the training loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    # ---- Phase 1: restore the original network and compute new parameters.
    with tf.Graph().as_default():

        # The original network is only fed through placeholders here; no
        # input pipeline is needed to read its weights.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = (FLAGS.lr_step_epoch * FLAGS.num_train_instance /
                      FLAGS.batch_size)
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None)
        network.build_model()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Restore the trained weights; FLAGS.ckpt_path may be a directory
        # (latest checkpoint is used) or a checkpoint file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Collect, per residual unit, the values fetched by get_kernel with
        # index 1 and 2 and by get_batch_norm with index 2.
        graph = tf.get_default_graph()
        block_num = 3
        old_kernels_to_cluster = []
        old_kernels_to_add = []
        old_batch_norm = []
        for i in range(1, block_num + 1):
            for j in range(FLAGS.num_residual_units):
                old_kernels_to_cluster.append(get_kernel(i, j, 1, graph, sess))
                old_kernels_to_add.append(get_kernel(i, j, 2, graph, sess))
                old_batch_norm.append(get_batch_norm(i, j, 2, graph, sess))

        new_params = []
        new_width = [
            16,
            int(16 * FLAGS.new_k),
            int(32 * FLAGS.new_k),
            int(64 * FLAGS.new_k)
        ]
        units = zip(old_kernels_to_cluster, old_kernels_to_add, old_batch_norm)
        for i, (to_cluster, to_add, batch_norm) in enumerate(units):
            # The lists hold FLAGS.num_residual_units entries per group, so
            # the group index must be derived from that flag rather than from
            # a fixed unit count.
            cluster_num = new_width[i // FLAGS.num_residual_units + 1]
            cluster_kernels, cluster_indices = cluster_kernel(
                to_cluster, cluster_num)
            add_kernels = add_kernel(to_add, cluster_indices, cluster_num)
            cluster_batchs_norm = cluster_batch_norm(batch_norm,
                                                     cluster_indices,
                                                     cluster_num)
            new_params.append(cluster_kernels)
            for p in range(BATCH_NORM_PARAM_NUM):
                new_params.append(cluster_batchs_norm[p])
            new_params.append(add_kernels)

        # Pair every variable name with its initial value for the new graph.
        # Variables selected by the regexes consume new_params in the order
        # tf.global_variables() yields them; all others keep their current
        # value.
        init_params = []
        new_param_index = 0
        for var in tf.global_variables():
            update_match = UPDATE_PARAM_REGEX.match(var.name)
            skip_match = SKIP_PARAM_REGEX.match(var.name)
            if update_match and not skip_match:
                print("update {}".format(var.name))
                init_params.append((new_params[new_param_index], var.name))
                new_param_index += 1
            else:
                print("not update {}".format(var.name))
                var_vector = sess.run(var)
                init_params.append((var_vector, var.name))

        # close old graph
        sess.close()
    tf.reset_default_graph()

    # ---- Phase 2: build the new graph and fine-tune it.
    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.input_fn(FLAGS.data_dir,
                                                             FLAGS.batch_size,
                                                             train_mode=True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(FLAGS.data_dir,
                                                           FLAGS.batch_size,
                                                           train_mode=False)

        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        new_network = resnet.ResNet(hp, images, labels, global_step,
                                    init_params, FLAGS.new_k)
        new_network.build_model()
        new_network.build_train_op()

        train_summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Resume from the latest checkpoint in the training directory, if any.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoints are saved below as 'model.ckpt-<step>'; the step is
            # the text after the last '-' of the file name.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            init_step = int(ckpt_name.split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')
        sys.stdout.flush()

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for _ in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [new_network.loss, new_network.acc],
                        feed_dict={
                            new_network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))
                sys.stdout.flush()

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                [
                    new_network.train_op, new_network.lr, new_network.loss,
                    new_network.acc, train_summary_op
                ],
                feed_dict={
                    new_network.is_train: True,
                    images: train_images_val,
                    labels: train_labels_val
                })
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                sys.stdout.flush()
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and on the last step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        # Flush and close the event file so summaries added since the last
        # explicit flush are not lost.
        summary_writer.close()
示例#3
0
def train():
    """Train the ResNet on CIFAR-100 using the settings in ``FLAGS``.

    Builds the input pipelines and the network with its training op, resumes
    from the latest checkpoint in ``FLAGS.train_dir`` if there is one, and
    trains up to ``FLAGS.max_steps``.  It evaluates every
    ``FLAGS.test_interval`` steps, logs every ``FLAGS.display`` steps and
    saves a checkpoint every ``FLAGS.checkpoint_interval`` steps as well as
    on the final step.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.input_fn(FLAGS.data_dir,
                                                             FLAGS.batch_size,
                                                             train_mode=True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(FLAGS.data_dir,
                                                           FLAGS.batch_size,
                                                           train_mode=False)

        # The network reads from placeholders; batches fetched from the
        # pipelines above are fed in explicitly, so one network serves both
        # the training and the test data.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = (FLAGS.lr_step_epoch * FLAGS.num_train_instance /
                      FLAGS.batch_size)
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Resume from the latest checkpoint in the training directory, if any.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoints are saved below as 'model.ckpt-<step>'; the step is
            # the text after the last '-' of the file name.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            init_step = int(ckpt_name.split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for _ in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels])
            _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                [
                    network.train_op, network.lr, network.loss, network.acc,
                    train_summary_op
                ],
                feed_dict={
                    network.is_train: True,
                    images: train_images_val,
                    labels: train_labels_val
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and on the last step.
            if (step > init_step and step % FLAGS.checkpoint_interval
                    == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        # Flush and close the event file so summaries added since the last
        # explicit flush are not lost.
        summary_writer.close()
示例#4
0
def train():
    """Evaluate a restored ResNet and report per-class accuracy.

    Despite its name this function does not train anything.  Using the
    settings in the module-level ``FLAGS`` it builds the network with
    ``new_k=FLAGS.new_k``, restores it from ``FLAGS.ckpt_path`` (a checkpoint
    directory or a single checkpoint file), runs ``FLAGS.test_iter``
    evaluation batches and prints a per-class table of correct/wrong counts.
    The table is also written to ``FLAGS.output`` when that flag is not blank.

    Raises:
        SystemExit: with status 1 if no checkpoint is found at
            ``FLAGS.ckpt_path``.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Testing Configuration]')
    print('\tCheckpoint path: %s' % FLAGS.ckpt_path)
    print('\tDataset: %s' % ('Training' if FLAGS.train_data else 'Test'))
    print('\tNumber of testing iterations: %d' % FLAGS.test_iter)
    print('\tOutput path: %s' % FLAGS.output)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        # The CIFAR-100 dataset
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(
                FLAGS.data_dir,
                FLAGS.batch_size,
                train_mode=FLAGS.train_data,
                num_threads=1)

        # The class labels
        with open(os.path.join(FLAGS.data_dir, 'fine_label_names.txt')) as fd:
            classes = [temp.strip() for temp in fd.readlines()]

        # The network reads from placeholders; batches fetched from the
        # pipeline above are fed in explicitly.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model (no training op is needed for evaluation).
        decay_step = (FLAGS.lr_step_epoch * FLAGS.num_train_instance /
                      FLAGS.batch_size)
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, None, new_k=FLAGS.new_k)
        network.build_model()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Restore the trained weights; FLAGS.ckpt_path may be a directory
        # (latest checkpoint is used) or a checkpoint file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if os.path.isdir(FLAGS.ckpt_path):
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt_path)
            if ckpt and ckpt.model_checkpoint_path:
                print('\tRestore from %s' % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found in the dir [%s]' %
                      FLAGS.ckpt_path)
                sys.exit(1)
        elif os.path.isfile(FLAGS.ckpt_path):
            print('\tRestore from %s' % FLAGS.ckpt_path)
            saver.restore(sess, FLAGS.ckpt_path)
        else:
            print('No checkpoint file found in the path [%s]' %
                  FLAGS.ckpt_path)
            sys.exit(1)

        # Start queue runners
        tf.train.start_queue_runners(sess=sess)

        # Testing!
        # result_ll[c] == [correct count, wrong count] for class c.
        result_ll = [[0, 0] for _ in range(FLAGS.num_classes)]
        test_loss = 0.0  # scalar accumulator (was accidentally a tuple)
        for _ in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run(
                [test_images, test_labels])
            preds_val, loss_value = sess.run(
                [network.preds, network.loss],
                feed_dict={
                    network.is_train: False,
                    images: test_images_val,
                    labels: test_labels_val
                })
            test_loss += loss_value
            for label, pred in zip(test_labels_val, preds_val):
                column = 0 if label == pred else 1  # 0: correct, 1: wrong
                result_ll[label % FLAGS.num_classes][column] += 1
        if FLAGS.test_iter > 0:
            test_loss /= FLAGS.test_iter

        # Summary display & output.  Classes (or runs) without any sample
        # report an accuracy of 0.0 instead of dividing by zero.
        acc_list = [
            float(hit) / float(hit + miss) if (hit + miss) else 0.0
            for hit, miss in result_ll
        ]
        result_total = np.sum(np.array(result_ll), axis=0)
        total_count = np.sum(result_total)
        acc_total = float(result_total[0]) / total_count if total_count else 0.0

        print('Class    \t\t\tT\tF\tAcc.')
        format_str = '%-31s %7d %7d %.5f'
        for i in range(FLAGS.num_classes):
            print(format_str %
                  (classes[i], result_ll[i][0], result_ll[i][1], acc_list[i]))
        print(format_str %
              ('(Total)', result_total[0], result_total[1], acc_total))

        # Output to file(if specified)
        if FLAGS.output.strip():
            line_format = '%-31s %7d %7d %.5f\n'
            with open(FLAGS.output, 'w') as fd:
                fd.write('Class    \t\t\tT\tF\tAcc.\n')
                for i in range(FLAGS.num_classes):
                    t, f = result_ll[i]
                    # Spaces in class names are replaced so every row keeps
                    # the same number of whitespace-separated columns.
                    fd.write(line_format %
                             (classes[i].replace(' ', '-'), t, f, acc_list[i]))
                fd.write(
                    line_format %
                    ('(Total)', result_total[0], result_total[1], acc_total))
示例#5
0
def _print_train_config():
    """Print the dataset, network, optimization and training flags."""
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)


def _dump_graph_op_info(graph):
    """Write per-operation output sizes and attributes under ``./profs``.

    For every operation in ``graph`` two records are collected:

    * ``./profs/tensors_sz_32.txt`` gets ``"<op name>"::<bytes>`` where
      ``<bytes>`` is the element count of the first output (unknown
      dimensions are skipped in the product) times its dtype size, or ``-1``
      when the op has no output or the output rank is unknown.  Ops whose
      dtype reports no size are left out of this file.
    * ``./profs/operations_attributes.txt`` gets
      ``<op name>::<op type>[::<first output is a ref dtype>]``.

    Two counters are printed: the number of ops whose first output could not
    be sized, then the number of ops without any output.
    """
    tensor_sizes = {}
    op_attributes = {}
    unsized_count = 0
    no_output_count = 0

    for operation in graph.get_operations():
        name = operation.name
        outputs = operation.outputs

        attrs = [operation.type]
        op_attributes[name] = attrs

        if not outputs:
            # No "<name>:0" tensor exists; only the op type is recorded.
            no_output_count += 1
            tensor_sizes[name] = -1
            continue

        first_output = outputs[0]
        attrs.append(first_output.dtype._is_ref_dtype)

        if first_output.shape.ndims is None:
            unsized_count += 1
            tensor_sizes[name] = -1
            continue

        dtype_size = first_output.dtype.size
        if dtype_size is None:
            unsized_count += 1
            continue

        num_elements = 1
        for dim in first_output.shape.as_list():
            if dim is not None:
                num_elements *= dim
        tensor_sizes[name] = num_elements * dtype_size

    print(unsized_count)
    print(no_output_count)

    with open('./profs/tensors_sz_32.txt', 'w') as f:
        for name, size in tensor_sizes.items():
            f.write('"%s"::%s\n' % (name, size))

    with open('./profs/operations_attributes.txt', 'w') as f:
        for name, attrs in op_attributes.items():
            f.write('::'.join([name] + [str(attr) for attr in attrs]) + '\n')


def train():
    """Build the network, then run the training loop with profiling.

    The function prints the configuration flags, builds the input pipelines
    and the ``resnet.ResNet`` model in a fresh graph, dumps graph information
    under ``./profs``, restores the latest checkpoint from ``FLAGS.train_dir``
    if one exists, and then iterates from the restored step up to
    ``FLAGS.max_steps``.  Inside the loop it optionally evaluates on the test
    pipeline, runs one training step (collecting ``RunMetadata`` on every
    step where ``step % 10 == 1``), logs progress and summaries, and saves
    checkpoints.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    _print_train_config()

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.input_fn(
                FLAGS.data_dir, FLAGS.batch_size, train_mode=True)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.input_fn(
                FLAGS.data_dir, FLAGS.batch_size, train_mode=False)

        # The model reads from placeholders; batches are fetched from the
        # pipelines above and fed in explicitly at every step.
        images = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, data_input.HEIGHT, data_input.WIDTH, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = (FLAGS.lr_step_epoch * FLAGS.num_train_instance /
                      FLAGS.batch_size)
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()

        # Summaries(training)
        train_summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        # NOTE(review): log_device_placement is hard-coded to False although
        # FLAGS.log_device_placement is printed above -- confirm intent.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=False))
        sess.run(init)

        # Profiling artefacts are all written below ./profs; make sure the
        # directory exists before the first file is opened.
        if not os.path.exists('profs'):
            os.makedirs('profs')
        dot_rep = graph_to_dot(tf.get_default_graph())
        with open('profs/wrn.dot', 'w') as fwr:
            fwr.write(str(dot_rep))
        options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        run_metadata = tf.RunMetadata()
        _dump_graph_op_info(tf.get_default_graph())

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # The step number is the suffix after the last '-' of the
            # checkpoint file name.
            init_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found. Start from the scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        train_fetches = [
            network.train_op, network.lr, network.loss, network.acc,
            train_summary_op
        ]

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            # NOTE(review): the remainder is compared against 777, so this
            # branch can only run when FLAGS.test_interval > 777.  This looks
            # like a way to switch evaluation off while profiling -- confirm
            # whether `== 0` was intended.
            if step % FLAGS.test_interval == 777:
                test_loss, test_acc = 0.0, 0.0
                for _ in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run(
                        [test_images, test_labels])
                    loss_value, acc_value = sess.run(
                        [network.loss, network.acc],
                        feed_dict={
                            network.is_train: False,
                            images: test_images_val,
                            labels: test_labels_val
                        })
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print(format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc',
                                       simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            # Every step with step % 10 == 1 additionally collects
            # RunMetadata for both the input fetch and the training step.
            profile_step = step % 10 == 1
            step_metadata = run_metadata if profile_step else None

            # Time every step, profiled or not, so that the display code
            # below never sees a stale or unbound `duration`.
            start_time = time.time()
            train_images_val, train_labels_val = sess.run(
                [train_images, train_labels],
                options=options if profile_step else None,
                run_metadata=step_metadata)
            _, lr_value, loss_value, acc_value, train_summary_str = sess.run(
                train_fetches,
                feed_dict={
                    network.is_train: True,
                    images: train_images_val,
                    labels: train_labels_val
                },
                options=options,
                run_metadata=step_metadata)
            duration = time.time() - start_time

            if profile_step:
                profile(run_metadata, step)

                if step == 1:
                    # One-off memory report of the first profiled step.
                    options_mem = (
                        tf.profiler.ProfileOptionBuilder.time_and_memory())
                    options_mem["min_bytes"] = 0
                    options_mem["min_micros"] = 0
                    options_mem["output"] = 'file:outfile=./profs/mem.txt'
                    options_mem["select"] = ("bytes", "peak_bytes",
                                             "output_bytes", "residual_bytes")
                    mem = tf.profiler.profile(tf.get_default_graph(),
                                              run_meta=run_metadata,
                                              cmd="scope",
                                              options=options_mem)
                    with open('profs/mem_2.txt', 'w') as f:
                        f.write(str(mem))

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                examples_per_sec = FLAGS.batch_size / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, acc_value, lr_value,
                       examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically and at the last step.
            is_checkpoint_step = (step > init_step and
                                  step % FLAGS.checkpoint_interval == 0)
            if is_checkpoint_step or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)