Exemple #1
0
def do_train(sess, args):
    """Build the training graph and run the training (and optional validation) loop.

    The function creates the input placeholders, the bookkeeping variables
    (``epoch_number`` and ``global_step``), the model and the input pipelines,
    initializes (or restores) the variables, then trains for
    ``args.num_epochs`` epochs. A checkpoint is saved after every epoch and,
    when ``args.run_validation`` is set, the validation set is evaluated after
    every epoch as well.

    Args:
        sess: the TensorFlow session used to run every operation. It is closed
            before this function returns, including when an error is raised
            inside the training loop.
        args: configuration object. Attributes read here include
            ``processed_size``, ``raw_size``, ``WD_policy``/``WD_details``,
            ``LR_policy``/``LR_details``, ``optimizer``, ``architecture``,
            ``num_classes``, ``transfer_mode``, ``num_gpus``, ``batch_size``,
            ``chunked_batch_size``, ``num_batches``, ``num_epochs``,
            ``train_info``, ``val_info``, ``run_validation``,
            ``num_val_batches``, ``num_val_samples``, ``retrain_from``,
            ``log_dir``, ``snapshot_prefix``, ``log_debug_info``,
            ``run_options``, ``run_metadata`` and ``top_n``.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    # Set CPU as the default device for the graph. Some of the operations will
    # be moved to GPU later.
    with tf.device('/cpu:0'):

        # Images and labels placeholders
        images_ph = tf.placeholder(tf.float32,
                                   shape=(None, ) + tuple(args.processed_size),
                                   name='input')
        labels_ph = tf.placeholder(tf.int32, shape=(None), name='label')

        # Tells the network whether it is being trained or validated; passed to
        # the model so it can set dropout rates and batchnorm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # Bookkeeping counters, stored with the model through SAVE_VARIABLES.
        epoch_number = tf.get_variable(
            'epoch_number', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])
        global_step = tf.get_variable(
            'global_step', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])

        # Build the counter-update ops once. Calling `variable.assign(value)`
        # inside the loops would add a new op to the graph on every step.
        epoch_value_ph = tf.placeholder(tf.int32, shape=[], name='epoch_value')
        step_value_ph = tf.placeholder(tf.int32,
                                       shape=[],
                                       name='global_step_value')
        assign_epoch = epoch_number.assign(epoch_value_ph)
        assign_global_step = global_step.assign(step_value_ph)

        # Weight decay policy
        wd = utils.get_policy(args.WD_policy, args.WD_details)

        # Learning rate decay policy (if needed)
        lr = utils.get_policy(args.LR_policy, args.LR_details)

        # Create an optimizer that performs gradient descent.
        optimizer = utils.get_optimizer(args.optimizer, lr)

        # Build the computational graph using the provided configuration.
        dnn_model = model(images_ph,
                          labels_ph,
                          utils.loss,
                          optimizer,
                          wd,
                          args.architecture,
                          args.num_classes,
                          is_training_ph,
                          args.transfer_mode,
                          num_gpus=args.num_gpus)

        # Batch size of the validation input pipeline. It is a placeholder so
        # the last validation batch can be shrunk, ensuring each validation
        # example is fed only once. Because only 1 GPU is used for validation,
        # the validation batch size is capped at 512.
        val_batch_size = min(512, args.batch_size)
        batch_size_tf = tf.placeholder_with_default(val_batch_size, shape=())

        # A data loader pipeline to read training images and their labels
        train_loader = loader(args.train_info, args.delimiter, args.raw_size,
                              args.processed_size, True,
                              args.chunked_batch_size, args.num_prefetch,
                              args.num_threads, args.path_prefix, args.shuffle)
        # The loader returns images, their labels, and their paths
        images, labels, info = train_loader.load()

        # If validation data are provided, create an input pipeline for them
        if args.run_validation:
            val_loader = loader(args.val_info, args.delimiter, args.raw_size,
                                args.processed_size, False, batch_size_tf,
                                args.num_prefetch, args.num_threads,
                                args.path_prefix)
            val_images, val_labels, val_info = val_loader.load()

        # Get training operations to run from the deep learning model
        train_ops = dnn_model.train_ops()

        # Merge the summaries once, after the whole graph has been built.
        # `merge_all` returns None when no summary has been registered.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation and run it.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        if args.retrain_from is not None:
            dnn_model.load(sess, args.retrain_from)

        # Resume from the epoch following the one stored in the variables.
        start_epoch = int(sess.run(epoch_number)) + 1
        last_epoch = start_epoch + args.num_epochs - 1

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Setup a summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

        try:
            # The main training loop
            for epoch in range(start_epoch, last_epoch + 1):

                # update epoch_number
                sess.run(assign_epoch, feed_dict={epoch_value_ph: epoch})

                print("Epoch %d of %d started" % (epoch, last_epoch))

                # Training batches
                for step in range(args.num_batches):
                    current_step = step + epoch * args.num_batches
                    sess.run(assign_global_step,
                             feed_dict={step_value_ph: current_step})

                    # The measured time covers both loading and training.
                    start_time = time.time()

                    # load a batch from input pipeline
                    img, lbl = sess.run([images, labels],
                                        options=args.run_options,
                                        run_metadata=args.run_metadata)

                    # train on the loaded batch of data
                    _, loss_value, top1_accuracy, topn_accuracy = sess.run(
                        train_ops,
                        feed_dict={
                            images_ph: img,
                            labels_ph: lbl,
                            is_training_ph: True
                        },
                        options=args.run_options,
                        run_metadata=args.run_metadata)
                    duration = time.time() - start_time

                    # Check for errors
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    # Log progress every ten batches
                    if step % 10 == 0:
                        num_examples_per_step = (args.chunked_batch_size *
                                                 args.num_gpus)
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = duration / args.num_gpus

                        format_str = (
                            '%s: epoch %d of %d, step %d of %d, loss = %.2f, Top-1 = %.2f Top-'
                            + str(args.top_n) +
                            ' = %.2f (%.1f examples/sec; %.3f sec/batch)')
                        print(format_str %
                              (datetime.now(), epoch, last_epoch, step,
                               args.num_batches, loss_value, top1_accuracy,
                               topn_accuracy, examples_per_sec, sec_per_batch))
                        sys.stdout.flush()

                    # Write tensorboard summaries every hundred batches
                    if step % 100 == 0:
                        if summary_op is not None:
                            summary_str = sess.run(summary_op,
                                                   feed_dict={
                                                       images_ph: img,
                                                       labels_ph: lbl,
                                                       is_training_ph: True
                                                   })
                            summary_writer.add_summary(summary_str,
                                                       current_step)
                        if args.log_debug_info:
                            summary_writer.add_run_metadata(
                                args.run_metadata,
                                'epoch%d step%d' % (epoch, step))

                # Save the model checkpoint after each training epoch
                checkpoint_path = os.path.join(args.log_dir,
                                               args.snapshot_prefix)
                dnn_model.save(sess, checkpoint_path, global_step=epoch)

                print("Epoch %d of %d ended. a checkpoint saved at %s" %
                      (epoch, last_epoch, args.log_dir))
                sys.stdout.flush()

                # If validation data are provided, evaluate accuracy on the
                # validation set after the end of each epoch
                if args.run_validation:

                    print("Evaluating on validation set")
                    total_loss = utils.AverageMeter()  # cross entropy loss
                    top1 = utils.AverageMeter()  # top-1 accuracy
                    topn = utils.AverageMeter()  # top-n accuracy

                    last_val_step = args.num_val_batches - 1

                    # The validation loop
                    for val_step in range(args.num_val_batches):
                        # Only the last batch overrides the pipeline batch size.
                        # NOTE(review): when num_val_samples is an exact
                        # multiple of the batch size this feeds a batch size
                        # of 0 - confirm against how args.num_val_batches is
                        # computed by the caller.
                        if val_step == last_val_step:
                            batch_feed = {
                                batch_size_tf:
                                args.num_val_samples % val_batch_size
                            }
                        else:
                            batch_feed = None

                        # Load a batch of data
                        val_img, val_lbl = sess.run(
                            [val_images, val_labels],
                            feed_dict=batch_feed,
                            options=args.run_options,
                            run_metadata=args.run_metadata)

                        # Validate the network on the loaded batch; the first
                        # entry of train_ops is skipped (its result is
                        # discarded during training as well).
                        val_loss, top1_predictions, topn_predictions = sess.run(
                            [train_ops[1], train_ops[2], train_ops[3]],
                            feed_dict={
                                images_ph: val_img,
                                labels_ph: val_lbl,
                                is_training_ph: False
                            },
                            options=args.run_options,
                            run_metadata=args.run_metadata)

                        current_batch_size = val_lbl.shape[0]
                        total_loss.update(val_loss, current_batch_size)
                        top1.update(top1_predictions, current_batch_size)
                        topn.update(topn_predictions, current_batch_size)

                        if val_step % 10 == 0 or val_step == last_val_step:
                            print(
                                "Validation step %d of %d, Loss %.2f, Top-1 Accuracy %.2f, Top-%d Accuracy %.2f "
                                % (val_step, args.num_val_batches,
                                   total_loss.avg, top1.avg, args.top_n,
                                   topn.avg))
                            sys.stdout.flush()
        finally:
            # Always stop the input threads and release the writer and the
            # session, even when training aborts with an error.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                summary_writer.close()
                sess.close()
Exemple #2
0
def do_evaluate(sess, args):
    """Evaluate a trained model, or run inference only, and write predictions to a file.

    The function builds the model and the validation input pipeline, restores
    the parameters from ``args.log_dir``, then processes
    ``args.num_val_batches`` batches. One line per example is written to
    ``args.save_predictions``. When ``args.inference_only`` is false, labels
    are loaded as well and running loss / top-1 / top-n statistics are printed
    after every batch.

    Args:
        sess: the TensorFlow session used to run every operation. It is closed
            before this function returns, including when an error is raised
            while processing the batches.
        args: configuration object. Attributes read here include
            ``processed_size``, ``raw_size``, ``architecture``,
            ``num_classes``, ``transfer_mode``, ``batch_size``, ``val_info``,
            ``delimiter``, ``num_prefetch``, ``num_threads``, ``path_prefix``,
            ``inference_only``, ``log_dir``, ``save_predictions``,
            ``num_val_batches``, ``num_val_samples``, ``run_options``,
            ``run_metadata`` and ``top_n``.
    """
    with tf.device('/cpu:0'):
        # Images and labels placeholders
        images_ph = tf.placeholder(tf.float32,
                                   shape=(None, ) + tuple(args.processed_size),
                                   name='input')
        labels_ph = tf.placeholder(tf.int32, shape=(None), name='label')

        # Tells the network whether it is being trained or validated; passed to
        # the model so it can set dropout rates and batchnorm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # Build a deep learning model using the provided configuration
        # (no optimizer and a weight decay of 0.0, since nothing is trained).
        dnn_model = model(images_ph, labels_ph, utils.loss, None, 0.0,
                          args.architecture, args.num_classes, is_training_ph,
                          args.transfer_mode)

        # Batch size of the input pipeline. It is a placeholder so the last
        # batch can be shrunk, ensuring each example is fed only once.
        batch_size_tf = tf.placeholder_with_default(args.batch_size, shape=())

        # A data loader pipeline to read test data
        val_loader = loader(args.val_info,
                            args.delimiter,
                            args.raw_size,
                            args.processed_size,
                            False,
                            batch_size_tf,
                            args.num_prefetch,
                            args.num_threads,
                            args.path_prefix,
                            inference_only=args.inference_only)

        # In inference-only mode (i.e. no label is provided) the loader only
        # returns the images and their info.
        if not args.inference_only:
            val_images, val_labels, val_info = val_loader.load()
        else:
            val_images, val_info = val_loader.load()

        # Get evaluation operations from the dnn model
        eval_ops = dnn_model.evaluate_ops(args.inference_only)

        # Build an initialization operation and run it.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        # Load pretrained parameters from disk
        dnn_model.load(sess, args.log_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        last_step = args.num_val_batches - 1
        # Only the last batch overrides the pipeline batch size.
        # NOTE(review): when num_val_samples is an exact multiple of the batch
        # size this feeds a batch size of 0 - confirm against how
        # args.num_val_batches is computed by the caller.
        last_batch_feed = {
            batch_size_tf: args.num_val_samples % args.batch_size
        }

        try:
            # The context manager closes the predictions file even if a batch
            # fails part-way through.
            with open(args.save_predictions, 'w') as out_file:
                if not args.inference_only:
                    # Evaluation: labels are available.
                    total_loss = utils.AverageMeter()  # cross entropy loss
                    top1 = utils.AverageMeter()  # top-1 accuracy
                    topn = utils.AverageMeter()  # top-n accuracy

                    predictions_format_str = ('%d, %s, %s, %s, %s\n')

                    for step in range(args.num_val_batches):
                        # Load a batch of data
                        val_img, val_lbl, val_inf = sess.run(
                            [val_images, val_labels, val_info],
                            feed_dict=(last_batch_feed
                                       if step == last_step else None))

                        # Evaluate the network on the loaded batch
                        val_loss, top1_predictions, topn_predictions, topnguesses, topnconf = sess.run(
                            eval_ops,
                            feed_dict={
                                images_ph: val_img,
                                labels_ph: val_lbl,
                                is_training_ph: False
                            },
                            options=args.run_options,
                            run_metadata=args.run_metadata)

                        current_batch_size = val_lbl.shape[0]
                        total_loss.update(val_loss, current_batch_size)
                        top1.update(top1_predictions, current_batch_size)
                        topn.update(topn_predictions, current_batch_size)
                        print(
                            'Batch Number: %d of %d, Top-1 Hit: %d, Top-%d Hit: %d, Loss %.2f, Top-1 Accuracy: %.2f, Top-%d Accuracy: %.2f'
                            % (step, args.num_val_batches, top1.sum,
                               args.top_n, topn.sum, total_loss.avg, top1.avg,
                               args.top_n, topn.avg))

                        # Log results into the output file: 1-based example
                        # index, example info, label looked up in the loader's
                        # label_dict, top-n guesses and their confidences.
                        first_index = step * args.batch_size + 1
                        for i in range(val_inf.shape[0]):
                            guesses = ', '.join('%d' % item
                                                for item in topnguesses[i])
                            confidences = ', '.join('%.4f' % item
                                                    for item in topnconf[i])
                            out_file.write(
                                predictions_format_str %
                                (first_index + i, val_inf[i],
                                 val_loader.label_dict[val_lbl[i]], guesses,
                                 confidences))
                        # Flush once per batch so progress stays visible on
                        # disk without flushing after every single line.
                        out_file.flush()
                else:
                    # Inference: no labels, only predictions are produced.
                    predictions_format_str = ('%d, %s, %s, %s\n')

                    for step in range(args.num_val_batches):
                        # Load a batch of data
                        val_img, val_inf = sess.run(
                            [val_images, val_info],
                            feed_dict=(last_batch_feed
                                       if step == last_step else None))

                        # Run the network on the loaded batch
                        topnguesses, topnconf = sess.run(
                            eval_ops,
                            feed_dict={
                                images_ph: val_img,
                                is_training_ph: False
                            },
                            options=args.run_options,
                            run_metadata=args.run_metadata)
                        print('Batch Number: %d of %d is done' %
                              (step, args.num_val_batches))

                        # Log to the output file: 1-based example index,
                        # example info, top-n guesses and their confidences.
                        first_index = step * args.batch_size + 1
                        for i in range(val_inf.shape[0]):
                            guesses = ', '.join('%d' % item
                                                for item in topnguesses[i])
                            confidences = ', '.join('%.4f' % item
                                                    for item in topnconf[i])
                            out_file.write(
                                predictions_format_str %
                                (first_index + i, val_inf[i], guesses,
                                 confidences))
                        # Flush once per batch (see the evaluation branch).
                        out_file.flush()
        finally:
            # Always stop the input threads and release the session, even when
            # evaluation aborts with an error.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()