Example #1
def train():

    print(FLAGS.train_dir)

    with tf.Session() as sess:

        global_step = tf.contrib.framework.get_or_create_global_step()

        images = tf.placeholder(tf.float32,
                                shape=(FLAGS.batch_size, IMAGE_SIZE,
                                       IMAGE_SIZE, 3))
        labels = tf.placeholder(tf.int32, shape=(FLAGS.batch_size,))
        indexes = tf.placeholder(tf.int32, shape=(FLAGS.batch_size,))
        #mode_eval = tf.placeholder(tf.bool, shape=())
        keep_prob = tf.placeholder(tf.float32)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = vgg.inference(images, keep_prob)

        # Calculate loss.
        loss = vgg.loss(logits, labels)

        top_k_op = tf.nn.in_top_k(logits, labels, 1)
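        # in_top_k (above) yields one bool per example: True where the true
        # label is the top-ranked prediction (k = 1).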

        prediction = tf.argmax(logits, 1)

        #tf.summary.scalar('prediction', loss)

        cmatrix = tf.contrib.metrics.confusion_matrix(prediction, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = vgg.train(loss, global_step)

        #train = tf.train.GradientDescentOptimizer(0.00001).minimize(loss)

        tf.summary.scalar('dropout_keep_probability', keep_prob)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        loss_train = np.array([])
        loss_valid = np.array([])
        precision_test = np.array([])

        steps_train = np.array([])
        steps_valid = np.array([])
        steps_precision = np.array([])

        confusion_matrix_predictions = np.array([])
        confusion_matrix_labels = np.array([])

        EPOCH = 0
        start_time_global = time.time()

        for step in range(FLAGS.max_steps):

            #if step > 100: FLAGS.__setattr__("INITIAL_LEARNING_RATE", 0.001)

            if (step % EPOCHS_NUM == 0) and step > 300:
                print("validating")

                # step > 300 here, so there is no need to guard against
                # step == 0 before counting the epoch.
                EPOCH = EPOCH + 1

                # feeding data for validation

                images_batch, labels_batch, index_batch = sess.run(
                    [images_v, labels_v, indexs_v])

                # Evaluate the loss only; the original also ran train_op here,
                # which would keep training on the validation batch.
                loss_value = sess.run(
                    loss,
                    feed_dict={
                        images: images_batch,
                        labels: labels_batch,
                        indexes: index_batch,
                        keep_prob: 1.0
                    })

                print('%s: loss = %.5f' % (datetime.now(), loss_value))

                loss_valid = np.concatenate((loss_valid, [loss_value]))
                steps_valid = np.concatenate((steps_valid, [EPOCH]))

            else:

                start_time = time.time()

                # feed data for training

                images_batch, labels_batch, index_batch = sess.run(
                    [images_t, labels_t, indexs_t])

                # Run model
                _, loss_value, summary_str = sess.run(
                    [train_op, loss, summary_op],
                    feed_dict={
                        images: images_batch,
                        labels: labels_batch,
                        indexes: index_batch,
                        keep_prob: 0.5
                    })

                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if step % 10 == 0:
                    writer.add_summary(summary_str, step)

                if step % 200 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if (step - 2) % EPOCHS_NUM == 0:
                    loss_train = np.concatenate((loss_train, [loss_value]))
                    steps_train = np.concatenate((steps_train, [EPOCH]))

                # Save the model checkpoint periodically.
                if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

                    np.savez(FLAGS.train_dir + '_losses.npz',
                             steps_train=steps_train,
                             loss_train=loss_train,
                             steps_valid=steps_valid,
                             loss_valid=loss_valid,
                             precision=precision_test,
                             steps_precision=steps_precision,
                             confusion_matrix_predictions=confusion_matrix_predictions,
                             confusion_matrix_labels=confusion_matrix_labels)

            if EPOCH == 400:
                break

        final_time_global = time.time()

        print("Finish")

        print(final_time_global - start_time_global)

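Example #1 assumes that the batched tensors images_t, labels_t, indexs_t (training) and images_v, labels_v, indexs_v (validation) are created elsewhere in the module. Below is a minimal sketch of a queue-based pipeline that could produce them, assuming TFRecord files with image_raw, label and index features; the helper and the feature names are illustrative, not from the original.

import tensorflow as tf

def input_pipeline(filenames, batch_size, image_size):
    """Hypothetical TF 1.x queue-based pipeline yielding (image, label, index) batches."""
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'index': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.cast(tf.reshape(image, [image_size, image_size, 3]), tf.float32)
    label = tf.cast(features['label'], tf.int32)
    index = tf.cast(features['index'], tf.int32)
    # The queue behind shuffle_batch is filled by the threads launched via
    # tf.train.start_queue_runners() in the training code above.
    return tf.train.shuffle_batch([image, label, index],
                                  batch_size=batch_size,
                                  capacity=10 * batch_size,
                                  min_after_dequeue=2 * batch_size)

With such a helper, images_t, labels_t, indexs_t = input_pipeline(train_files, FLAGS.batch_size, IMAGE_SIZE) (and likewise for the _v validation tensors) would produce the names used above.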
Example #2
def train():
    with tf.Graph().as_default():

        data_set = cifar10.CIFAR10()
        images = data_set.load(FLAGS.data_path)

        global_step = tf.Variable(0, trainable=False)

        random_z = vgg.inputs()

        (D_logits_real, D_logits_fake, D_logits_fake_for_G,
         D_sigmoid_real, D_sigmoid_fake,
         D_sigmoid_fake_for_G) = vgg.inference(images, random_z)

        G_loss, D_loss = vgg.loss_l2(D_logits_real, D_logits_fake,
                                     D_logits_fake_for_G)

        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'g_' in var.name]
        D_vars = [var for var in t_vars if 'd_' in var.name]
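        # The 'g_'/'d_' name prefixes assigned inside vgg.inference are what
        # make this split between generator and discriminator variables work.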

        G_train_op, D_train_op = vgg.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        sampler = vgg.sampler(random_z)

        sess = sess_init()

        tf.train.start_queue_runners(sess=sess)

        saver = tf.train.Saver()

        for step in range(1, FLAGS.max_steps + 1):
            batch_z = np.random.uniform(
                -1, 1, [FLAGS.batch_size, FLAGS.z_dim]).astype(np.float32)

            _, errD = sess.run([D_train_op, D_loss],
                               feed_dict={random_z: batch_z})

            _, errG = sess.run([G_train_op, G_loss],
                               feed_dict={random_z: batch_z})

            if step % 100 == 0:
                print("step = %d, errD = %f, errG = %f" % (step, errD, errG))

            if step % 1000 == 0:
                samples = sess.run(sampler, feed_dict={random_z: batch_z})
                save_images(samples, [8, 8],
                            './samples/train_{:d}.bmp'.format(step))

            if step % 10000 == 0:
                # The step number is already part of the filename, so the
                # extra global_step argument from the original (which would
                # append the step a second time) is dropped.
                saver.save(sess,
                           '{0}/vgg-{1}.model'.format(FLAGS.checkpoint_dir,
                                                      step))
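
Example #2 tiles generator samples into an 8x8 grid with a save_images helper that is not shown. Here is a minimal sketch, assuming the samples arrive as a float batch scaled to [-1, 1]; the scaling convention and the helper body are assumptions, not from the original.

import numpy as np
from PIL import Image

def save_images(images, grid, path):
    """Tile a batch of images into a rows x cols grid and write it to disk (sketch)."""
    rows, cols = grid
    n, h, w, c = images.shape
    canvas = np.zeros((rows * h, cols * w, c), dtype=np.uint8)
    for idx in range(min(n, rows * cols)):
        r, col = divmod(idx, cols)
        # Map [-1, 1] floats to [0, 255] bytes (assumed generator output range).
        img = ((images[idx] + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    Image.fromarray(canvas).save(path)  # format inferred from the extension (.bmp)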
Example #3
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))

    images_train, labels_train = vgg.distorted_inputs()

    images_val, labels_val = vgg.inputs(eval_data='test')

    is_training = tf.placeholder('bool', [], name='is_training')

    images, labels = tf.cond(is_training,
        lambda: (images_train, labels_train),
        lambda: (images_val, labels_val))
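    # Both input pipelines live in the graph; is_training only selects which
    # batch tensors feed the model at run time.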


    # Build a Graph that computes the logits predictions from the
    # inference model.
    graph = vgg.inference(images, gpu, is_training)  # 'gpu' is presumably a module-level flag
    logits = graph['s']  # network output scores
    params = graph['p']  # layer parameters (consumed by load_weights below)
    logits = tf.transpose(logits)
    # Calculate loss.
    loss = vgg.loss(logits, labels)


    logits_r = tf.reshape(logits, [vgg.batch_size])

    diff = vgg.ang_diff(logits_r, labels)

    # Count predictions within 25 degrees; cast to int32 so the sum cannot
    # overflow (the original's uint8 wraps for batch sizes above 255).
    true_count = tf.reduce_sum(tf.cast(tf.less(diff, 25), tf.int32))

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = vgg.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())
    
    summary = tf.Summary()
    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph.
    config = tf.ConfigProto(allow_soft_placement=True)
    # config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, sess.graph)

    if FLAGS.Resume:
      ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
      if ckpt and ckpt.model_checkpoint_path:
        # Restores from checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        
        sub_saver = tf.train.Saver(graph['v'])
        sub_saver.save(sess, FLAGS.train_dir)
      else:
        print('No checkpoint file found')
        return
    else:
      sess.run(init, {is_training: False})
      load_weights(params, 'vgg19.npy', sess)

    test_iters = 11
    total_sample_count = test_iters * FLAGS.batch_size 
    
    for step in range(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss], {is_training: True})
      duration = time.time() - start_time
      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step > 1 and step % 250 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))
        # Evaluate summary_op once per logging step; the original ran it
        # twice and wrote the same summary to the event file twice.
        summary.ParseFromString(sess.run(summary_op, {is_training: False}))
        summary_writer.add_summary(summary, step)


      
      if (step > 1 and step % 1000 == 0) or (step + 1) == FLAGS.max_steps:

        true_count_sum = 0  # Counts the number of correct predictions.
        diffs = np.array([])
        for i in range(test_iters):
          # Run the pre-built 'diff' tensor directly; calling tf.unstack()
          # inside sess.run on every iteration (as the original did) keeps
          # adding new ops to the graph.
          true_count_, diff_ = sess.run([true_count, diff],
                                        {is_training: False})
          true_count_sum += true_count_
          diffs = np.append(diffs, diff_)
          
        diffs_var = np.var(diffs)
        diffs_mean = np.mean(diffs)
        
        # Compute precision @ 1.
        precision = true_count_sum / total_sample_count
        print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
        summary.ParseFromString(sess.run(summary_op, {is_training: False}))
        summary.value.add(tag='Precision @ 1', simple_value=precision)
        summary.value.add(tag='diffs_var', simple_value=diffs_var)
        summary.value.add(tag='diffs_mean', simple_value=diffs_mean)  
        summary_writer.add_summary(summary, step)
        # Save the model checkpoint periodically.
        saver.save(sess, checkpoint_path, global_step=step)
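
Example #3 bootstraps from the publicly available vgg19.npy weight file through a load_weights(params, 'vgg19.npy', sess) helper that is not shown. A plausible sketch follows, assuming params lists the model variables as alternating weight/bias pairs in sorted layer order; the actual layout of graph['p'] is not given in the original.

import numpy as np

def load_weights(params, weight_file, sess):
    """Assign pre-trained VGG19 weights to the model variables (sketch)."""
    # The public vgg19.npy file is a pickled dict: layer name -> [weights, biases].
    data = np.load(weight_file, allow_pickle=True, encoding='latin1').item()
    idx = 0
    for layer_name in sorted(data.keys()):
        # Assumption: params alternates weight and bias variables in the
        # same sorted-layer order as the dictionary keys.
        weights, biases = data[layer_name]
        sess.run(params[idx].assign(weights))
        sess.run(params[idx + 1].assign(biases))
        idx += 2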
Example #4
def train():
    sys.stdout.write("\033[93m")  # yellow message

    print("Load and test model")

    sys.stdout.write("\033[0;0m")

    with tf.Session() as sess:

        global_step = tf.contrib.framework.get_or_create_global_step()

        images = tf.placeholder(tf.float32,
                                shape=(FLAGS.batch_size, IMAGE_SIZE,
                                       IMAGE_SIZE, 3))
        labels = tf.placeholder(tf.int32, shape=(FLAGS.batch_size,))
        indexes = tf.placeholder(tf.int32, shape=(FLAGS.batch_size,))
        # mode_eval = tf.placeholder(tf.bool, shape=())
        keep_prob = tf.placeholder(tf.float32)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = vgg.inference(images, keep_prob)

        # Calculate loss.
        loss = vgg.loss(logits, labels)

        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        prediction = tf.argmax(logits, 1)

        # cmatrix = tf.contrib.metrics.confusion_matrix(prediction, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = vgg.train(loss, global_step)

        # train = tf.train.GradientDescentOptimizer(0.00001).minimize(loss)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Restore the moving average version of the learned variables for eval.

        #variable_averages = tf.train.ExponentialMovingAverage(
        #    vgg.MOVING_AVERAGE_DECAY)
        #variables_to_restore = variable_averages.variables_to_restore()
        #saver = tf.train.Saver(variables_to_restore)

        #saver = tf.train.import_meta_graph('/home/mikelf/Datasets/T-lessV2/restore_models/model.ckpt-61000.meta')

        saver.restore(
            sess,
            "/home/mikelf/experiments/full_test/vgg_scratch/100p/checkpoint/vgg_train_rgb_16bs01lr_SGD_100p/model.ckpt-85000"
        )

        #sess.run(tf.global_variables_initializer())

        print("Model restored.")

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # No explicit variable initialization is needed here: all variables
        # were restored from the checkpoint above.

        coord = tf.train.Coordinator()

        # Start the queue runners.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer_train = tf.summary.FileWriter(FLAGS.train_dir,
                                                     sess.graph)

        loss_train = np.array([])
        loss_valid = np.array([])
        precision_test = np.array([])

        steps_train = np.array([])
        steps_valid = np.array([])
        steps_precision = np.array([])

        EPOCH = 0
        start_time_global = time.time()

        print("getting precision on test dataset")

        # assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        # feeding data for evaluation

        num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
        true_count = 0  # Counts the number of correct predictions.
        total_sample_count = num_iter * FLAGS.batch_size
        step = 0
        seen_indexes = []
        pred_labels_array = []
        labels_array = []

        while step < num_iter:
            images_batch, labels_batch, index_batch = sess.run(
                [images_p, labels_p, indexs_p])

            # top_k_op yields per-example correctness; 'prediction' yields the
            # predicted class ids (the original misleadingly called them
            # 'cf_matrix'; the confusion matrix is only built after the loop).
            correct, pred_labels = sess.run(
                [top_k_op, prediction],
                feed_dict={
                    images: images_batch,
                    labels: labels_batch,
                    indexes: index_batch,
                    keep_prob: 1.0
                })

            true_count += np.sum(correct)
            step += 1
            seen_indexes.extend(index_batch)
            pred_labels_array = np.append(pred_labels_array, pred_labels)
            labels_array = np.append(labels_array, labels_batch)

        print(pred_labels_array.shape)

        print(len(seen_indexes))
        dupes = [idx for n, idx in enumerate(seen_indexes)
                 if idx in seen_indexes[:n]]
        # print(sorted(dupes))
        print(len(dupes))

        precision = true_count / total_sample_count

        print('%s: precision @ 1 = %.5f' % (datetime.now(), precision))

        precision_test = np.concatenate((precision_test, [precision]))
        steps_precision = np.concatenate((steps_precision, [EPOCH]))

        final_time_global = time.time()

        print("Finish")

        print(final_time_global - start_time_global)

        cnf_matrix = confusion_matrix(labels_array, pred_labels_array)
        np.set_printoptions(precision=2)

        class_names = [str(i) for i in range(1, 31)]

        # Plot non-normalized confusion matrix
        plt.figure()
        plot_confusion_matrix(cnf_matrix,
                              classes=class_names,
                              title='Confusion matrix, without normalization')

        # Plot normalized confusion matrix
        #plt.figure()
        #plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
        #                      title='Normalized confusion matrix')

        plt.show()

        coord.request_stop()
        coord.join(threads)
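
Example #4 depends on two helpers that are not shown: confusion_matrix (as found in sklearn.metrics) and a plot_confusion_matrix function. Below is a minimal sketch of the latter along the lines of the well-known scikit-learn documentation recipe; treat it as illustrative rather than the original's implementation.

import itertools
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix with matplotlib (sketch)."""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Print each cell's count, switching text color for readability.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], '.2f' if normalize else 'd'),
                 horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()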