Example #1
0
def main(_):
    """Build the graph selected by FLAGS, then run inference, validation and training.

    Depending on which of ``FLAGS.train_db``, ``FLAGS.validation_db`` and
    ``FLAGS.inference_db`` are set, a model is created per stage, a session is
    opened, variables are initialised (and optionally preloaded from
    ``FLAGS.weights``), and the stages are executed. During training the
    summaries are written to TensorBoard, validation is run every
    ``FLAGS.validation_interval`` epochs and a snapshot is saved every
    ``FLAGS.snapshotInterval`` epochs.

    Args:
        _: Unused positional argument (presumably the argv list passed in by
            the application runner -- confirm against the caller).

    This function does not return: it always terminates the process through
    ``exit`` (status 0 on success, -1 on a configuration or training error).
    """
    # Always keep the cpu as default
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # A validation interval of 0 disables validation altogether.
        if FLAGS.validation_interval == 0:
            FLAGS.validation_db = None

        # Set Tensorboard log directory
        if FLAGS.summaries_dir:
            # The following gives a nice but unrobust timestamp
            FLAGS.summaries_dir = os.path.join(
                FLAGS.summaries_dir,
                datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))

        if not FLAGS.train_db and not FLAGS.validation_db and not FLAGS.inference_db and not FLAGS.visualizeModelPath:
            logging.error(
                "At least one of the following file sources should be specified: "
                "train_db, validation_db or inference_db")
            exit(-1)

        if FLAGS.seed:
            tf.set_random_seed(FLAGS.seed)

        batch_size_train = FLAGS.batch_size
        batch_size_val = FLAGS.batch_size
        logging.info("Train batch size is %s and validation batch size is %s",
                     batch_size_train, batch_size_val)

        # This variable keeps track of next epoch, when to perform validation.
        next_validation = FLAGS.validation_interval
        logging.info(
            "Training epochs to be completed for each validation : %s",
            next_validation)

        # This variable keeps track of next epoch, when to save model weights.
        next_snapshot_save = FLAGS.snapshotInterval
        logging.info(
            "Training epochs to be completed before taking a snapshot : %s",
            next_snapshot_save)
        last_snapshot_save_epoch = 0

        # Defined up front because the final validation check below reads it
        # even when no training database was given.
        current_epoch = 0

        snapshot_prefix = FLAGS.snapshotPrefix if FLAGS.snapshotPrefix else FLAGS.network.split(
            '.')[0]
        logging.info("Model weights will be saved as %s_<EPOCH>_Model.ckpt",
                     snapshot_prefix)

        if not os.path.exists(FLAGS.save):
            os.makedirs(FLAGS.save)
            logging.info("Created a directory %s to save all the snapshots",
                         FLAGS.save)

        # Load mean variable
        if FLAGS.subtractMean == 'none':
            mean_loader = None
        else:
            if not FLAGS.mean:
                logging.error(
                    "subtractMean parameter not set to 'none' yet mean image path is unset"
                )
                exit(-1)
            logging.info("Loading mean tensor from %s file", FLAGS.mean)
            mean_loader = tf_data.MeanLoader(FLAGS.mean, FLAGS.subtractMean,
                                             FLAGS.bitdepth)

        classes = 0
        nclasses = 0
        if FLAGS.labels_list:
            logging.info("Loading label definitions from %s file",
                         FLAGS.labels_list)
            classes = loadLabels(FLAGS.labels_list)
            # Validate before calling len() so that a failed load is reported
            # through the log instead of surfacing as a TypeError.
            if not classes:
                logging.error("Reading labels file %s failed.",
                              FLAGS.labels_list)
                exit(-1)
            nclasses = len(classes)
            logging.info("Found %s classes", nclasses)

        # Create a data-augmentation dict
        aug_dict = {
            'aug_flip': FLAGS.augFlip,
            'aug_noise': FLAGS.augNoise,
            'aug_contrast': FLAGS.augContrast,
            'aug_whitening': FLAGS.augWhitening,
            'aug_HSV': {
                'h': FLAGS.augHSVh,
                's': FLAGS.augHSVs,
                'v': FLAGS.augHSVv,
            },
        }

        # Import the network file
        path_network = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            FLAGS.networkDirectory, FLAGS.network)
        # NOTE(review): the network file is executed as code in this module's
        # globals; it must come from a trusted location.
        with open(path_network) as network_file:
            network_source = network_file.read()
        exec(network_source, globals())

        # The executed network file is expected to define a class 'UserModel'.
        try:
            UserModel
        except NameError:
            logging.error("The user model class 'UserModel' is not defined.")
            exit(-1)
        if not inspect.isclass(UserModel):  # noqa
            logging.error("The user model class 'UserModel' is not a class.")
            exit(-1)
        # @TODO(tzaman) - add mode checks to UserModel

        if FLAGS.train_db:
            with tf.name_scope(digits.STAGE_TRAIN) as stage_scope:
                train_model = Model(digits.STAGE_TRAIN, FLAGS.croplen,
                                    nclasses, FLAGS.optimization,
                                    FLAGS.momentum)
                train_model.create_dataloader(FLAGS.train_db)
                train_model.dataloader.setup(FLAGS.train_labels, FLAGS.shuffle,
                                             FLAGS.bitdepth, batch_size_train,
                                             FLAGS.epoch, FLAGS.seed)
                train_model.dataloader.set_augmentation(mean_loader, aug_dict)
                train_model.create_model(UserModel, stage_scope)  # noqa

        if FLAGS.validation_db:
            with tf.name_scope(digits.STAGE_VAL) as stage_scope:
                val_model = Model(digits.STAGE_VAL, FLAGS.croplen, nclasses)
                val_model.create_dataloader(FLAGS.validation_db)
                val_model.dataloader.setup(
                    FLAGS.validation_labels, False, FLAGS.bitdepth,
                    batch_size_val, 1e9,
                    FLAGS.seed)  # @TODO(tzaman): set numepochs to 1
                val_model.dataloader.set_augmentation(mean_loader)
                val_model.create_model(UserModel, stage_scope)  # noqa

        if FLAGS.inference_db:
            with tf.name_scope(digits.STAGE_INF) as stage_scope:
                inf_model = Model(digits.STAGE_INF, FLAGS.croplen, nclasses)
                inf_model.create_dataloader(FLAGS.inference_db)
                inf_model.dataloader.setup(None, False, FLAGS.bitdepth,
                                           FLAGS.batch_size, 1, FLAGS.seed)
                inf_model.dataloader.set_augmentation(mean_loader)
                inf_model.create_model(UserModel, stage_scope)  # noqa

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,  # will automatically do non-gpu supported ops on cpu
            inter_op_parallelism_threads=TF_INTER_OP_THREADS,
            intra_op_parallelism_threads=TF_INTRA_OP_THREADS,
            log_device_placement=FLAGS.log_device_placement))

        # When only a graph visualization is requested, write it and stop.
        if FLAGS.visualizeModelPath:
            visualize_graph(sess.graph_def, FLAGS.visualizeModelPath)
            exit(0)

        # Saver creation.
        if FLAGS.save_vars == 'all':
            vars_to_save = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        elif FLAGS.save_vars == 'trainable':
            # Only the trainable subset, as the flag value says.
            vars_to_save = tf.trainable_variables()
        else:
            logging.error('Unknown save_var flag (%s)', FLAGS.save_vars)
            exit(-1)
        saver = tf.train.Saver(vars_to_save,
                               max_to_keep=0,
                               sharded=FLAGS.serving_export)

        # Initialize variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # If weights option is set, preload weights from existing models appropriately
        if FLAGS.weights:
            load_snapshot(sess, FLAGS.weights,
                          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        # Tensorboard: Merge all the summaries and write them out
        writer = tf.train.SummaryWriter(
            os.path.join(FLAGS.summaries_dir, 'tb'), sess.graph)

        # Run inference first when an inference database is given.
        if FLAGS.inference_db:
            inf_model.start_queue_runners(sess)
            Inference(sess, inf_model)

        # Collect the names of the queue-size tensors; only used by the
        # (commented-out) development logging inside the training loop.
        queue_size_op = []
        for n in tf.get_default_graph().as_graph_def().node:
            if '_Size' in n.name:
                queue_size_op.append(n.name + ':0')

        start = time.time()  # @TODO(tzaman) - removeme

        # Initial Forward Validation Pass
        if FLAGS.validation_db:
            val_model.start_queue_runners(sess)
            Validation(sess, val_model, 0)

        if FLAGS.train_db:
            # During training, a log output should occur at least X times per epoch or every X images, whichever lower
            train_steps_per_epoch = train_model.dataloader.get_total(
            ) / batch_size_train
            logging_interval_step = int(
                min(math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH),
                    math.ceil(5000 / batch_size_train)))
            logging.info(
                "During training. details will be logged after every %s steps (batches)",
                logging_interval_step)

            # epoch value will be calculated for every batch size. To maintain unique epoch value between batches,
            # it needs to be rounded to the required number of significant digits.
            epoch_round = 0  # holds the required number of significant digits for round function.
            tmp_batchsize = batch_size_train * logging_interval_step
            while tmp_batchsize <= train_model.dataloader.get_total():
                tmp_batchsize = tmp_batchsize * 10
                epoch_round += 1
            logging.info(
                "While logging, epoch value will be rounded to %s significant digits",
                epoch_round)

            # Create the learning rate policy
            total_training_steps = train_model.dataloader.num_epochs * train_model.dataloader.get_total() / \
                train_model.dataloader.batch_size
            lrpolicy = lr_policy.LRPolicy(FLAGS.lr_policy, FLAGS.lr_base_rate,
                                          FLAGS.lr_gamma, FLAGS.lr_power,
                                          total_training_steps,
                                          FLAGS.lr_stepvalues)
            train_model.start_queue_runners(sess)

            # Training
            logging.info('Started training the model')

            try:
                step = 0
                step_last_log = 0
                print_vals_sum = 0
                while not train_model.queue_coord.should_stop():
                    # Runtime tracing happens every
                    # FLAGS.log_runtime_stats_per_step steps (never when 0).
                    log_runtime = FLAGS.log_runtime_stats_per_step and (
                        step % FLAGS.log_runtime_stats_per_step == 0)

                    run_options = None
                    run_metadata = None
                    if log_runtime:
                        # For a HARDWARE_TRACE you need NVIDIA CUPTI, a 'CUDA-EXTRA'
                        # SOFTWARE_TRACE HARDWARE_TRACE FULL_TRACE
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                    feed_dict = {
                        train_model.learning_rate:
                        lrpolicy.get_learning_rate(step)
                    }

                    # Run all train ops together with the summary and the
                    # global step in a single session call.
                    _, summary_str, step = sess.run(
                        [
                            train_model.train, train_model.summary,
                            train_model.global_step
                        ],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)

                    # HACK: scale the global step by the number of train ops
                    # (presumably each op advances global_step -- confirm).
                    step = step / len(train_model.train)

                    # logging.info(sess.run(queue_size_op)) # DEVELOPMENT: for checking the queue size

                    if log_runtime:
                        writer.add_run_metadata(run_metadata, str(step))
                        save_timeline_trace(run_metadata, FLAGS.save,
                                            int(step))

                    writer.add_summary(summary_str, step)

                    # Parse the summary
                    tags, print_vals = summary_to_lists(summary_str)

                    # Accumulate so the log line below reports an average
                    # over the steps since the previous log.
                    print_vals_sum = print_vals + print_vals_sum

                    # @TODO(tzaman): account for variable batch_size value on very last epoch
                    current_epoch = round((step * batch_size_train) /
                                          train_model.dataloader.get_total(),
                                          epoch_round)

                    # Start with a forward pass
                    if (step % logging_interval_step) == 0:
                        steps_since_log = step - step_last_log
                        print_list = print_summarylist(
                            tags, print_vals_sum / steps_since_log)
                        logging.info("Training (epoch " + str(current_epoch) +
                                     "): " + print_list)
                        print_vals_sum = 0
                        step_last_log = step

                    # Potential Validation Pass
                    if FLAGS.validation_db and current_epoch >= next_validation:
                        Validation(sess, val_model, current_epoch)
                        # Find next nearest epoch value that exactly divisible by FLAGS.validation_interval:
                        next_validation = (round(float(current_epoch)/FLAGS.validation_interval) + 1) * \
                            FLAGS.validation_interval

                    # Saving Snapshot
                    if FLAGS.snapshotInterval > 0 and current_epoch >= next_snapshot_save:
                        save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                                      current_epoch, FLAGS.serving_export)

                        # To find next nearest epoch value that exactly divisible by FLAGS.snapshotInterval
                        next_snapshot_save = (round(float(current_epoch)/FLAGS.snapshotInterval) + 1) * \
                            FLAGS.snapshotInterval
                        last_snapshot_save_epoch = current_epoch
                    writer.flush()
            except tf.errors.OutOfRangeError:
                # Raised once the input queues are exhausted, i.e. the
                # requested number of epochs has been consumed.
                logging.info(
                    'Done training for epochs: tf.errors.OutOfRangeError')
            except ValueError as err:
                logging.error(err.args[0])
                exit(-1)  # DIGITS wants a dirty error.
            except KeyboardInterrupt:
                logging.info('Interrupt signal received.')

            # If required, perform final snapshot save
            if FLAGS.snapshotInterval > 0 and FLAGS.epoch > last_snapshot_save_epoch:
                save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                              FLAGS.epoch, FLAGS.serving_export)

        print('Training wall-time:',
              time.time() - start)  # @TODO(tzaman) - removeme

        # If required, perform final Validation pass
        if FLAGS.validation_db and current_epoch >= next_validation:
            Validation(sess, val_model, current_epoch)

        if FLAGS.train_db:
            del train_model
        if FLAGS.validation_db:
            del val_model
        if FLAGS.inference_db:
            del inf_model

        # The session was not opened through a with block, so it has to be
        # closed explicitly.
        sess.close()

        writer.close()
        logging.info('END')
        exit(0)
Example #2
0
def main(_):
    """Build the graph selected by FLAGS, then run inference, validation and training.

    Depending on which of ``FLAGS.train_db``, ``FLAGS.validation_db`` and
    ``FLAGS.inference_db`` are set, a model is created per stage, a session is
    opened, variables are initialised (and optionally preloaded from
    ``FLAGS.weights``), and the stages are executed. During training the
    summaries are written to TensorBoard, validation is run every
    ``FLAGS.validation_interval`` epochs and a snapshot is saved every
    ``FLAGS.snapshotInterval`` epochs.

    Args:
        _: Unused positional argument (presumably the argv list passed in by
            the application runner -- confirm against the caller).

    This function does not return: it always terminates the process through
    ``exit`` (status 0 on success, -1 on a configuration or training error).
    """
    # Always keep the cpu as default
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # A validation interval of 0 disables validation altogether.
        if FLAGS.validation_interval == 0:
            FLAGS.validation_db = None

        # Set Tensorboard log directory
        if FLAGS.summaries_dir:
            # The following gives a nice but unrobust timestamp
            FLAGS.summaries_dir = os.path.join(
                FLAGS.summaries_dir,
                datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))

        if not FLAGS.train_db and not FLAGS.validation_db and not FLAGS.inference_db:
            logging.error(
                "At least one of the following file sources should be specified: "
                "train_db, validation_db or inference_db")
            exit(-1)

        if FLAGS.seed:
            tf.set_random_seed(FLAGS.seed)

        batch_size_train = FLAGS.batch_size
        batch_size_val = FLAGS.batch_size
        logging.info("Train batch size is %s and validation batch size is %s",
                     batch_size_train, batch_size_val)

        # This variable keeps track of next epoch, when to perform validation.
        next_validation = FLAGS.validation_interval
        logging.info(
            "Training epochs to be completed for each validation : %s",
            next_validation)

        # This variable keeps track of next epoch, when to save model weights.
        next_snapshot_save = FLAGS.snapshotInterval
        logging.info(
            "Training epochs to be completed before taking a snapshot : %s",
            next_snapshot_save)
        last_snapshot_save_epoch = 0

        # Defined up front because the final validation check below reads it
        # even when no training database was given.
        current_epoch = 0

        snapshot_prefix = FLAGS.snapshotPrefix if FLAGS.snapshotPrefix else FLAGS.network.split(
            '.')[0]
        logging.info(
            'Model weights will be saves as {}_<EPOCH>_Model.ckpt'.format(
                snapshot_prefix))

        if not os.path.exists(FLAGS.save):
            os.makedirs(FLAGS.save)
            logging.info("Created a directory %s to save all the snapshots",
                         FLAGS.save)

        classes = 0
        nclasses = 0
        if FLAGS.labels_list:
            logging.info("Loading label definitions from %s file",
                         FLAGS.labels_list)
            classes = loadLabels(FLAGS.labels_list)
            # Validate before calling len() so that a failed load is reported
            # through the log instead of surfacing as a TypeError.
            if not classes:
                logging.error("Reading labels file %s failed.",
                              FLAGS.labels_list)
                exit(-1)
            nclasses = len(classes)
            logging.info("Found %s classes", nclasses)

        # Debugging NaNs and Inf
        # NOTE(review): this is created before any model op has been added to
        # the graph; tf.add_check_numerics_ops() presumably only instruments
        # ops that already exist, so it may check nothing here -- confirm
        # before relying on it.
        check_op = tf.add_check_numerics_ops()

        # Import the network file
        path_network = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            FLAGS.networkDirectory, FLAGS.network)
        # NOTE(review): the network file is executed as code in this module's
        # globals; it must come from a trusted location.
        with open(path_network) as network_file:
            network_source = network_file.read()
        exec(network_source, globals())

        # The executed network file is expected to define a class 'UserModel'.
        try:
            UserModel
        except NameError:
            logging.error("The user model class 'UserModel' is not defined.")
            exit(-1)
        if not inspect.isclass(UserModel):
            logging.error("The user model class 'UserModel' is not a class.")
            exit(-1)

        if FLAGS.train_db:
            with tf.name_scope(utils.STAGE_TRAIN) as stage_scope:
                train_model = Model(utils.STAGE_TRAIN, FLAGS.croplen, nclasses,
                                    FLAGS.optimization, FLAGS.momentum)
                train_model.create_dataloader(FLAGS.train_db)
                train_model.dataloader.setup(FLAGS.train_labels, FLAGS.shuffle,
                                             FLAGS.bitdepth, batch_size_train,
                                             FLAGS.epoch, FLAGS.seed)

                train_model.create_model(UserModel, stage_scope)

        if FLAGS.validation_db:
            with tf.name_scope(utils.STAGE_VAL) as stage_scope:
                val_model = Model(utils.STAGE_VAL,
                                  FLAGS.croplen,
                                  nclasses,
                                  reuse_variable=True)
                val_model.create_dataloader(FLAGS.validation_db)
                val_model.dataloader.setup(FLAGS.validation_labels, False,
                                           FLAGS.bitdepth, batch_size_val, 1e9,
                                           FLAGS.seed)
                val_model.create_model(UserModel, stage_scope)

        if FLAGS.inference_db:
            with tf.name_scope(utils.STAGE_INF) as stage_scope:
                inf_model = Model(utils.STAGE_INF, FLAGS.croplen, nclasses)
                inf_model.create_dataloader(FLAGS.inference_db)
                inf_model.dataloader.setup(None, False, FLAGS.bitdepth,
                                           FLAGS.batch_size, 1, FLAGS.seed)
                inf_model.create_model(UserModel, stage_scope)

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,  # will automatically do non-gpu supported ops on cpu
            inter_op_parallelism_threads=TF_INTER_OP_THREADS,
            intra_op_parallelism_threads=TF_INTRA_OP_THREADS,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True)))

        # Saver creation.
        if FLAGS.save_vars == 'all':
            vars_to_save = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        elif FLAGS.save_vars == 'trainable':
            vars_to_save = tf.trainable_variables()
        else:
            logging.error('Unknown save_var flag ({})'.format(FLAGS.save_vars))
            exit(-1)

        saver = tf.train.Saver(vars_to_save,
                               max_to_keep=0,
                               sharded=FLAGS.serving_export)

        # Initialize variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # If weights option is set, preload weights from existing models appropriately
        if FLAGS.weights:
            load_snapshot(sess, FLAGS.weights,
                          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        # Tensorboard: Merge all the summaries and write them out
        writer = ops.SummaryWriter(os.path.join(FLAGS.summaries_dir, 'tb'),
                                   sess.graph)

        # Run inference first when an inference database is given.
        if FLAGS.inference_db:
            inf_model.start_queue_runners(sess)
            Inference(sess, inf_model)

        # Collect the names of the queue-size tensors (debug aid; the list is
        # not read again inside this function).
        queue_size_op = []
        for n in tf.get_default_graph().as_graph_def().node:
            if '_Size' in n.name:
                logging.debug('graph node name: {}'.format(n.name))
                queue_size_op.append(n.name + ':0')

        start = time.time()

        # Initial Forward Validation Pass
        if FLAGS.validation_db:
            val_model.start_queue_runners(sess)
            Validation(sess, val_model, 0, check_op)

        if FLAGS.train_db:
            # During training, a log output should occur at least X times per epoch or every X images, whichever lower
            train_steps_per_epoch = train_model.dataloader.get_total(
            ) / batch_size_train
            logging_interval_step = int(
                min(math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH),
                    math.ceil(5000 / batch_size_train)))
            logging.info(
                "During training. details will be logged after every %s steps (batches)",
                logging_interval_step)

            # epoch value will be calculated for every batch size. To maintain unique epoch value between batches,
            # it needs to be rounded to the required number of significant digits.
            epoch_round = 0  # holds the required number of significant digits for round function.
            tmp_batchsize = batch_size_train * logging_interval_step
            while tmp_batchsize <= train_model.dataloader.get_total():
                tmp_batchsize = tmp_batchsize * 10
                epoch_round += 1
            logging.info(
                "While logging, epoch value will be rounded to %s significant digits",
                epoch_round)

            # Create the learning rate policy
            total_training_steps = train_model.dataloader.num_epochs * train_model.dataloader.get_total(
            ) / train_model.dataloader.batch_size
            lrpolicy = lr_policy.LRPolicy(FLAGS.lr_policy, FLAGS.lr_base_rate,
                                          FLAGS.lr_gamma, FLAGS.lr_power,
                                          total_training_steps,
                                          FLAGS.lr_stepvalues)

            train_model.start_queue_runners(sess)

            # Training
            logging.info('Started training the model')

            try:
                step = 0
                step_last_log = 0
                print_vals_sum = 0
                while not train_model.queue_coord.should_stop():
                    # Runtime tracing happens every
                    # FLAGS.log_runtime_stats_per_step steps (never when 0).
                    log_runtime = FLAGS.log_runtime_stats_per_step and (
                        step % FLAGS.log_runtime_stats_per_step == 0)

                    # log runtime for benchmark
                    run_options = None
                    run_metadata = None
                    if log_runtime:
                        # For a HARDWARE_TRACE you need NVIDIA CUPTI, a 'CUDA-EXTRA'
                        # SOFTWARE_TRACE HARDWARE_TRACE FULL_TRACE
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                    feed_dict = {
                        train_model.learning_rate:
                        lrpolicy.get_learning_rate(step)
                    }

                    # Run the train ops, the summary, the global step and the
                    # numerics check in a single session call.
                    _, summary_str, step, _ = sess.run(
                        [
                            train_model.train, train_model.summary,
                            train_model.global_step, check_op
                        ],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)

                    # Scale the global step by the number of train ops
                    # (presumably each op advances global_step -- confirm).
                    step = step / len(train_model.train)

                    if log_runtime:
                        # 1. display the exact total memory, compute time, and tensor output sizes in TensorBoard
                        writer.add_run_metadata(run_metadata, str(step))

                        # 2. another method to trace the timeline of operations:
                        # open Google Chrome, go to chrome://tracing and load the timeline.json file
                        tl = timeline.Timeline(run_metadata.step_stats)
                        ctf = tl.generate_chrome_trace_format()
                        prof_path = os.path.join(FLAGS.summaries_dir,
                                                 'benchmark')
                        # The benchmark directory is not created anywhere
                        # else in this function.
                        if not os.path.exists(prof_path):
                            os.makedirs(prof_path)
                        with open(os.path.join(prof_path, 'timeline.json'),
                                  'w') as f:
                            f.write(ctf)

                    # summary for TensorBoard
                    writer.add_summary(summary_str, step)

                    current_epoch = round((step * batch_size_train) /
                                          train_model.dataloader.get_total(),
                                          epoch_round)

                    # Potential Validation Pass
                    if FLAGS.validation_db and current_epoch >= next_validation:
                        Validation(sess, val_model, current_epoch, check_op)
                        # Find next nearest epoch value that exactly divisible by FLAGS.validation_interval:
                        next_validation = (round(float(current_epoch)/FLAGS.validation_interval) + 1) * \
                                            FLAGS.validation_interval

                    # Saving Snapshot
                    if FLAGS.snapshotInterval > 0 and current_epoch >= next_snapshot_save:
                        save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                                      current_epoch, FLAGS.serving_export)

                        # To find next nearest epoch value that exactly divisible by FLAGS.snapshotInterval
                        next_snapshot_save = (round(float(current_epoch)/FLAGS.snapshotInterval) + 1) * \
                                            FLAGS.snapshotInterval
                        last_snapshot_save_epoch = current_epoch

                    writer.flush()
            except tf.errors.OutOfRangeError:
                # Raised once the input queues are exhausted, i.e. the
                # requested number of epochs has been consumed.
                logging.info(
                    'Done training for epochs limit reached: tf.errors.OutOfRangeError'
                )
            except ValueError as err:
                logging.error(err.args[0])
                exit(-1)
            except KeyboardInterrupt:
                logging.info('Interrupt signal received.')

            # If required, perform final snapshot save
            if FLAGS.snapshotInterval > 0 and FLAGS.epoch > last_snapshot_save_epoch:
                save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                              FLAGS.epoch, FLAGS.serving_export)

        print('Training wall-time:', time.time() - start)

        # If required, perform final Validation pass
        if FLAGS.validation_db and current_epoch >= next_validation:
            Validation(sess, val_model, current_epoch, check_op)

        # Drop the models (presumably this closes and terminates their
        # queues -- confirm against Model's destructor).
        if FLAGS.train_db:
            del train_model
        if FLAGS.validation_db:
            del val_model
        if FLAGS.inference_db:
            del inf_model

        # The session was not opened through a with block, so it has to be
        # closed explicitly.
        sess.close()

        writer.close()
        logging.info('END')
        exit(0)