Example #1
def distorted_inputs():
    """Import training data."""

    images, labels = cifar100_input.distorted_inputs(
        batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
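For context, a minimal sketch of how a queue-based input function like this is typically driven under the TF1 session model; the training step itself is elided, and the loop length is arbitrary:

import tensorflow as tf

# Hypothetical driver loop for the queue-based pipeline above (TF1 style).
images, labels = distorted_inputs()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(100):
        # Each run() dequeues one fresh, randomly distorted batch.
        batch_images, batch_labels = sess.run([images, labels])
        # ... run the training op on the fetched batch ...
    coord.request_stop()
    coord.join(threads)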
Example #2
def get_data(FLAGS, dataset):
    tr_data = None
    tr_label = None
    image_size = None
    channel_num = None
    output_num = None
    if dataset == 'cifar10':
        tr_data, tr_label = cifar10_input.distorted_inputs(
            FLAGS.cifar_data_dir, FLAGS.batch_size)
        image_size = cifar10_input.IMAGE_SIZE
        channel_num = 3
        output_num = 10
    elif dataset == 'svhn':
        tr_data, tr_label = svhn.distorted_inputs(FLAGS.svhn_data_dir,
                                                  FLAGS.batch_size)
        image_size = svhn.IMAGE_SIZE
        channel_num = 3
        output_num = 10
    elif dataset == 'cifar20':
        tr_data, tr_label = cifar100_input.distorted_inputs(
            20, FLAGS.cifar100_data_dir, FLAGS.batch_size)
        image_size = cifar100_input.IMAGE_SIZE
        channel_num = 3
        output_num = 20
    elif dataset == 'mnist1':
        tr_data, tr_label = binary_mnist_input.read_train_data(
            FLAGS, FLAGS.a1, FLAGS.a2)
        image_size = 28
        channel_num = 1
        output_num = 2
    elif dataset == 'mnist2':
        tr_data, tr_label = binary_mnist_input.read_train_data(
            FLAGS, FLAGS.b1, FLAGS.b2)
        image_size = 28
        channel_num = 1
        output_num = 2
    elif dataset == 'mnist3':
        tr_data, tr_label = binary_mnist_input.read_train_data(
            FLAGS, FLAGS.c1, FLAGS.c2)
        image_size = 28
        channel_num = 1
        output_num = 2
    elif dataset == 'mnist4':
        tr_data, tr_label = binary_mnist_input.read_train_data(
            FLAGS, FLAGS.d1, FLAGS.d2)
        image_size = 28
        channel_num = 1
        output_num = 2
    else:
        raise ValueError("No such dataset")

    return tr_data, tr_label, image_size, channel_num, output_num
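As an aside, the elif chain above can be collapsed into a table of per-dataset loaders. A sketch under the same module names as the example (abbreviated; the remaining mnist variants follow the `mnist1` pattern):

# Hypothetical registry: name -> (loader, image_size, channel_num, output_num).
DATASETS = {
    'cifar10': (lambda F: cifar10_input.distorted_inputs(F.cifar_data_dir, F.batch_size),
                cifar10_input.IMAGE_SIZE, 3, 10),
    'svhn':    (lambda F: svhn.distorted_inputs(F.svhn_data_dir, F.batch_size),
                svhn.IMAGE_SIZE, 3, 10),
    'cifar20': (lambda F: cifar100_input.distorted_inputs(20, F.cifar100_data_dir, F.batch_size),
                cifar100_input.IMAGE_SIZE, 3, 20),
    'mnist1':  (lambda F: binary_mnist_input.read_train_data(F, F.a1, F.a2), 28, 1, 2),
}

def get_data_v2(FLAGS, dataset):
    try:
        loader, image_size, channel_num, output_num = DATASETS[dataset]
    except KeyError:
        raise ValueError("No such dataset: %s" % dataset)
    tr_data, tr_label = loader(FLAGS)
    return tr_data, tr_label, image_size, channel_num, output_num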
Example #3
def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.
    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.
    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-100-binary')
    images, labels = cifar100_input.distorted_inputs(data_dir=data_dir,
                                                     batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
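Note that this example, like Example #1, casts the labels to tf.float16 along with the images. If the downstream loss expects integer class ids (as tf.nn.sparse_softmax_cross_entropy_with_logits does), casting only the images is the safer variant:

# Variant that leaves integer labels untouched when training in half precision.
if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)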
Example #4
def train():

    images = tf.placeholder(
        tf.float32,
        [batchSize, data_input.IMAGE_SIZE, data_input.IMAGE_SIZE, 3])
    labels = tf.placeholder(tf.int32, [batchSize])

    isTrain = tf.placeholder(tf.bool)
    with tf.variable_scope('train_image'):
        trainImages, trainLabels = data_input.distorted_inputs(Data_DIR, 100)

    trainOperations, predictedClass, loss = getWRNGraph(
        isTrain, images, labels)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()

        for iteration in range(trainingIterations):
            trainImagesVal, trainLabelsVal = sess.run(
                [trainImages, trainLabels])

            _, lossValue = sess.run([trainOperations, loss],
                                    feed_dict={
                                        isTrain: True,
                                        images: trainImagesVal,
                                        labels: trainLabelsVal
                                    })

            if iteration % 1000 == 0:
                print "iteration %d with loss = %f" % (iteration, lossValue)

        coord.request_stop()
        coord.join(threads)
        saver.save(sess, MODEL_DIR + '/' + MODEL_File)
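A side note on this pattern: the loop fetches `trainImages`/`trainLabels` into NumPy and feeds them back through placeholders, which costs an extra host round trip per step. When the placeholders are not needed for a separate eval path, the queue tensors can be wired into the model directly. A sketch, assuming `getWRNGraph` and the module-level constants from the example:

# Hypothetical variant: wire the queue output tensors straight into the graph,
# avoiding the per-step fetch + feed_dict copy.
isTrain = tf.placeholder(tf.bool)
trainImages, trainLabels = data_input.distorted_inputs(Data_DIR, 100)
trainOperations, predictedClass, loss = getWRNGraph(isTrain, trainImages, trainLabels)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for iteration in range(trainingIterations):
        _, lossValue = sess.run([trainOperations, loss], feed_dict={isTrain: True})
    coord.request_stop()
    coord.join(threads)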
Example #5
    def __init__(self, eval=False):
        # pretty standard/simple CNN, loosely based on AlexNet since it's
        #   simple and old, so it trains decently fast on a GTX 860M
        # decrease learning rate over time
        # input -> 32x32x3
        # conv1 (f=3, s=1, k=32, relu) -> 32x32x32
        # conv2 (f=3, s=1, k=64, relu) -> 32x32x64
        # pool1 (f=2, s=2) -> 16x16x64
        # conv3 (f=3, s=1, k=128, relu) -> 16x16x128
        # conv4 (f=3, s=1, k=128, relu) -> 16x16x128
        # pool2 (f=2, s=2) -> 8x8x128
        # fc1 (512, relu) + dropout(.5) -> 1x512
        # fc2 -> logits 1x100

        epochs = 100
        learning_rate = 1e-8
        batch_size = 16
        early_stop = False
        num_train = 50000

        f = [3, 3, 3, 3]
        k = [32, 64, 128, 128, 512]

        # conv1
        w1 = self.weight('w1', [f[0], f[0], 3, k[0]])
        b1 = self.bias('b1', [k[0]])

        # conv2
        w2 = self.weight('w2', [f[1], f[1], k[0], k[1]])
        b2 = self.bias('b2', [k[1]])

        # conv3
        w3 = self.weight('w3', [f[2], f[2], k[1], k[2]])
        b3 = self.bias('b3', [k[2]])

        # conv4
        w4 = self.weight('w4', [f[3], f[3], k[2], k[3]])
        b4 = self.bias('b4', [k[3]])

        # fc1
        w5 = self.weight('w5', [8 * 8 * k[3], k[4]])
        b5 = self.bias('b5', [k[4]])

        # fc2
        w6 = self.weight('w6', [k[4], 100])
        b6 = self.bias('b6', [100])

        # linear
        #  w5 = self.weight('w5', [k[3], 10])
        #  b5 = self.bias('b5', [10])
        self.params = (w1, w2, w3, w4, w5, w6)

        if eval:
            images, labels = cifar100_input.inputs(True, 'cifar-100-binary',
                                                   1000)
        else:
            images, labels = cifar100_input.distorted_inputs(
                'cifar-100-binary', batch_size)
        s = [1, 1, 1, 1]

        global_step = tf.train.get_or_create_global_step()
        X_ = self.conv(images, w1, s[0], b1)
        X_ = self.conv(X_, w2, s[1], b2)
        X_ = self.pool(X_, 2, 2)
        #  X_ = tf.nn.dropout(X_, .25)
        #  X_ = self.batch_norm(X_)
        X_ = self.conv(X_, w3, s[2], b3)
        X_ = self.conv(X_, w4, s[3], b4)
        X_ = self.pool(X_, 2, 2)
        #  X_ = tf.nn.dropout(X_, .25)
        #  X_ = self.batch_norm(X_)
        X_ = tf.nn.relu(self.fc(X_, w5, b5))
        X_ = tf.nn.dropout(X_, .5)
        logits = self.fc(X_, w6, b6)

        pred = tf.nn.softmax(logits)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                    labels=labels))
        #  l2 = tf.reduce_sum([tf.reduce_sum(tf.pow(w,2)) for w in self.params])
        #  loss = loss + weight_decay * l2

        correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        optim = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, global_step=global_step)

        steps_per_epoch = num_train // batch_size

        saver = tf.train.Saver()
        checkpoint = 'checkpoints/model.cifar100.v5.ckpt'
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                saver.restore(sess, checkpoint)
            except Exception:
                # no checkpoint yet; start from the freshly initialized weights
                pass
            if eval:
                total_accuracy = 0
                total_loss = 0
                for batch in range(10):
                    a, l, g = sess.run([accuracy, loss, global_step])
                    total_accuracy += a
                    total_loss += l
                print(
                    'global_step {} (epoch {}): test accuracy={}, test loss={}'
                    .format(g, g // steps_per_epoch, total_accuracy / 10,
                            total_loss / 10))
                coord.request_stop()
                coord.join(threads)
                return

            with tqdm(range(steps_per_epoch * epochs)) as t:
                best = 0
                for step in t:
                    epoch, step_in_epoch = divmod(step, steps_per_epoch)
                    if step_in_epoch == 0:
                        saver.save(sess, checkpoint)
                        total_accuracy = 0
                        total_loss = 0

                    a, l, o, g = sess.run([accuracy, loss, optim, global_step])
                    total_accuracy += a
                    total_loss += l

                    # step_in_epoch is 0 on the first step of each epoch, so
                    # divide by the number of batches seen so far instead
                    t.set_postfix(
                        epoch=g // steps_per_epoch,
                        step=step_in_epoch,
                        acc=total_accuracy / (step_in_epoch + 1),
                        loss=total_loss / (step_in_epoch + 1),
                    )

            if early_stop:
                saver.restore(sess, checkpoint)
            coord.request_stop()
            coord.join(threads)
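One caveat worth flagging: `tf.nn.softmax_cross_entropy_with_logits` and the `tf.argmax(labels, 1)` accuracy check both assume one-hot labels. If this `cifar100_input` yields integer class ids instead (as the stock CIFAR readers do), the sparse variants would be the drop-in fix; a sketch:

# Hypothetical fix for integer (non one-hot) labels from the input pipeline.
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.cast(labels, tf.int64)))
correct_pred = tf.equal(tf.argmax(pred, 1), tf.cast(labels, tf.int64))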
Example #6
def train():
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %f' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per testing: %d' % FLAGS.test_interval)
    print('\tSteps during testing: %d' % FLAGS.test_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)


    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        with tf.variable_scope('train_image'):
            train_images, train_labels = data_input.distorted_inputs(FLAGS.data_dir, FLAGS.batch_size)
        with tf.variable_scope('test_image'):
            test_images, test_labels = data_input.inputs(True, FLAGS.data_dir, FLAGS.batch_size)

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(tf.float32, [FLAGS.batch_size, data_input.IMAGE_SIZE, data_input.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Build model
        decay_step = FLAGS.lr_step_epoch * FLAGS.num_train_instance / FLAGS.batch_size
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            initial_lr=FLAGS.initial_lr,
                            decay_step=decay_step,
                            lr_decay=FLAGS.lr_decay,
                            momentum=FLAGS.momentum)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()

        # Summaries (training)
        train_summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\tRestore from %s' % ckpt.model_checkpoint_path)
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found. Starting from scratch.')

        # Start queue runners & summary_writer
        tf.train.start_queue_runners(sess=sess)
        if not os.path.exists(FLAGS.train_dir):
            os.mkdir(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Training!
        test_best_acc = 0.0
        for step in range(init_step, FLAGS.max_steps):
            # Test
            if step % FLAGS.test_interval == 0:
                test_loss, test_acc = 0.0, 0.0
                for i in range(FLAGS.test_iter):
                    test_images_val, test_labels_val = sess.run([test_images, test_labels])
                    loss_value, acc_value = sess.run([network.loss, network.acc],
                                feed_dict={network.is_train:False, images:test_images_val, labels:test_labels_val})
                    test_loss += loss_value
                    test_acc += acc_value
                test_loss /= FLAGS.test_iter
                test_acc /= FLAGS.test_iter
                test_best_acc = max(test_best_acc, test_acc)
                format_str = ('%s: (Test)     step %d, loss=%.4f, acc=%.4f')
                print (format_str % (datetime.now(), step, test_loss, test_acc))

                test_summary = tf.Summary()
                test_summary.value.add(tag='test/loss', simple_value=test_loss)
                test_summary.value.add(tag='test/acc', simple_value=test_acc)
                test_summary.value.add(tag='test/best_acc', simple_value=test_best_acc)
                summary_writer.add_summary(test_summary, step)
                summary_writer.flush()

            # Train
            start_time = time.time()
            train_images_val, train_labels_val = sess.run([train_images, train_labels])
            _, lr_value, loss_value, acc_value, train_summary_str = \
                    sess.run([network.train_op, network.lr, network.loss, network.acc, train_summary_op],
                        feed_dict={network.is_train:True, images:train_images_val, labels:train_labels_val})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print (format_str % (datetime.now(), step, loss_value, acc_value, lr_value,
                                     examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
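For reference, the `decay_step` computed above (epochs per step x training-set size / batch size) defines a piecewise-constant learning-rate schedule. If `resnet.HParams` were not handling it internally, the same schedule could be expressed directly with a staircase exponential decay; a sketch using the flags from this example:

# Hypothetical equivalent of the initial_lr / decay_step / lr_decay schedule above.
lr = tf.train.exponential_decay(FLAGS.initial_lr, global_step,
                                decay_steps=int(decay_step),
                                decay_rate=FLAGS.lr_decay,
                                staircase=True)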