Example #1
    def __init__(self):
        # Load the training set definition. It is used to determine the
        # dataset size and, if needed, to build a new training graph.
        training_set_def_path = paths.DatasetDefinitions.TRAINING
        with open(training_set_def_path, 'r') as f:
            training_set_def = json.load(f)

        self.dataset_size = len(training_set_def)
        self.batches_per_epoch = utils.batches_per_epoch(
            dataset_size=self.dataset_size,
            batch_size=config.ExperimentConfig.BATCH_SIZE_TRAINING,
            drop_last=True)

        # Import the training MetaGraph if one exists on disk; otherwise
        # create and save a new one.
        metagraph_path = paths.MetaGraphs.TRAINING
        if os.path.exists(metagraph_path):
            logger.info('Importing existing training MetaGraph '
                        'from {}'.format(metagraph_path))
            training_graph = tf.Graph()
            with training_graph.as_default():
                tf.train.import_meta_graph(metagraph_path)
        else:
            logger.info('Creating new training MetaGraph')
            training_graph = graphs.build_training_graph(training_set_def)
            with training_graph.as_default():
                tf.train.export_meta_graph(metagraph_path)

        self._session = tf.Session(graph=training_graph)

        with training_graph.as_default():
            self._saver = tf.train.Saver(max_to_keep=1)
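
utils.batches_per_epoch is not part of the snippet; with drop_last=True it is plain integer division. A hypothetical one-liner consistent with the call above:

def batches_per_epoch(dataset_size, batch_size, drop_last):
    """Hypothetical sketch of utils.batches_per_epoch."""
    if drop_last:
        # The incomplete final batch is dropped.
        return dataset_size // batch_size
    # Otherwise count the partial batch too (ceiling division).
    return -(-dataset_size // batch_size)
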
    def load(self):
        """Loads the model weights of the "best" model trained so far."""
        metagraph_path = paths.MetaGraphs.INFERENCE
        if not os.path.exists(metagraph_path):
            raise IOError('No MetaGraph at {}'.format(metagraph_path))

        # Load the MetaGraph from disk.
        inference_graph = tf.Graph()
        with inference_graph.as_default():
            tf.train.import_meta_graph(metagraph_path)
            self._saver = tf.train.Saver()

        logger.info('Imported existing inference MetaGraph '
                    'from {}'.format(metagraph_path))

        # Create a tf.Session object with the loaded graph.
        self._session = tf.Session(graph=inference_graph)

        # Load the trained weights.
        best_ckpt_prefix = paths.Checkpoints.BEST_MODEL
        best_ckpt_pattern = best_ckpt_prefix + '.*'
        if not glob.glob(best_ckpt_pattern):
            raise IOError('No checkpoint at {}'.format(best_ckpt_prefix))

        self._saver.restore(self._session, best_ckpt_prefix)

        logger.info('Loaded trained weights '
                    'from {}'.format(best_ckpt_prefix))
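
A TensorFlow 1.x checkpoint saved under a prefix is really a set of files (prefix.index, prefix.data-00000-of-00001, plus prefix.meta unless suppressed); the prefix itself never exists on disk, which is why load() probes best_ckpt_prefix + '.*' with glob instead of os.path.exists. A minimal, self-contained sketch of the save/restore round trip (paths hypothetical):

import glob
import os

import tensorflow as tf

ckpt_prefix = '/tmp/demo_ckpt/model'
os.makedirs(os.path.dirname(ckpt_prefix), exist_ok=True)

graph = tf.Graph()
with graph.as_default():
    v = tf.get_variable('v', shape=[2], initializer=tf.zeros_initializer())
    saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_prefix)  # writes model.index, model.data-*, model.meta

# The prefix is only a file-name stem, hence the glob pattern.
assert glob.glob(ckpt_prefix + '.*')
with tf.Session(graph=graph) as sess:
    saver.restore(sess, ckpt_prefix)
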
def _load_or_create_training_status(training_status_path):
    """Returns a dict containing the current status of the experiment.

    This function tries to load the current status from disk, in case an
    existing one has been saved; otherwise returns a new one.

    Args:
        training_status_path (str): Path to the training status file to load.

    Returns:
        A dict containing the current experiment status.
    """
    if os.path.exists(training_status_path):
        # Load an existing training status.
        with open(training_status_path, 'r') as f:
            training_status = json.load(f)
        logger.info('Loaded existing training status '
                    'from {}'.format(training_status_path))
    else:
        # Create a new training status.
        training_status = {
            _CNST.LATEST_TRAINED_KEY: _CNST.DEFAULT_EPOCH_IDX,
            _CNST.LATEST_EVALUATED_KEY: {
                _CNST.EPOCH_IDX_KEY: _CNST.DEFAULT_EPOCH_IDX,
                _CNST.METRIC_KEY: _CNST.DEFAULT_METRIC_VALUE,
            },
            _CNST.BEST_KEY: {
                _CNST.EPOCH_IDX_KEY: _CNST.DEFAULT_EPOCH_IDX,
                _CNST.METRIC_KEY: _CNST.DEFAULT_METRIC_VALUE,
            },
        }
        logger.info('Created new training status')

    return training_status
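
The matching save step is not shown; a minimal counterpart, assuming the same JSON layout (the helper name _save_training_status and the write-then-rename strategy are illustrative, not from the source):

import json
import os


def _save_training_status(training_status_path, training_status):
    """Persists the training status dict as JSON, atomically.

    Writing to a temporary file and renaming it avoids leaving a
    half-written status file behind if the process dies mid-write.
    """
    tmp_path = training_status_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(training_status, f, indent=2)
    os.replace(tmp_path, training_status_path)
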
def create_metagraph():
    """Creates and stores the MetaGraph for the inference graph."""
    metagraph_path = paths.MetaGraphs.INFERENCE
    if not os.path.exists(metagraph_path):
        logger.info('Creating new inference MetaGraph')
        inference_graph = graphs.build_inference_graph()
        with inference_graph.as_default():
            tf.train.export_meta_graph(metagraph_path)
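
Note that a MetaGraph serializes only the graph structure, not the trained variable values; those live in the checkpoint, which is why load() above both imports the MetaGraph and then restores the weights through a Saver. A hypothetical usage sequence (how load() gets invoked is not shown in the source):

create_metagraph()                  # one-off: export the inference MetaGraph
engine = inf_eng.InferenceEngine()  # hypothetical construction
engine.load()                       # import the MetaGraph, restore best weights
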
Example #5
    def __init__(self):
        # Import the evaluation MetaGraph if one exists on disk; otherwise
        # create and save a new one.
        metagraph_path = paths.MetaGraphs.EVALUATION
        if os.path.exists(metagraph_path):
            logger.info('Importing existing evaluation MetaGraph '
                        'from {}'.format(metagraph_path))
            evaluation_graph = tf.Graph()
            with evaluation_graph.as_default():
                tf.train.import_meta_graph(metagraph_path)
        else:
            logger.info('Creating new evaluation MetaGraph')
            evaluation_graph = graphs.build_evaluation_graph()
            with evaluation_graph.as_default():
                tf.train.export_meta_graph(metagraph_path)

        self._session = tf.Session(graph=evaluation_graph)

        with evaluation_graph.as_default():
            self._saver = tf.train.Saver()
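
main() in the prediction script below uses the inference engine as a context manager, so the engines presumably implement the protocol along these lines (a sketch; not shown in the source):

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release the resources held by the tf.Session on exit.
        self._session.close()
        return False  # propagate any exception raised in the with-block
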
    def _before_new_training(self):
        # Randomly initialize the training graph and evaluate the untrained
        # model to set a worst-case baseline.
        self._training_engine.initialize()
        self._training_engine.save()
        epoch_idx = -1
        avg_loss, accuracy = self._run_validation(epoch_idx)
        logger.info('Untrained model: '
                    'loss={:.3f}, '
                    'accuracy={:.3f}'.format(avg_loss, accuracy))

        # Log the evaluation results.
        self._logging_engine.log_evaluation_results(epoch_idx=epoch_idx,
                                                    avg_loss=avg_loss,
                                                    accuracy=accuracy)

        # Update the training status.
        self._update_training_status(epoch_idx=epoch_idx,
                                     validation_loss=avg_loss)

    def _after_epoch(self, epoch_idx):
        # Evaluate on the validation set.
        avg_loss, accuracy = self._run_validation(epoch_idx)
        # Log the evaluation results.
        self._logging_engine.log_evaluation_results(epoch_idx=epoch_idx,
                                                    avg_loss=avg_loss,
                                                    accuracy=accuracy)

        log_msg = 'After {} training epochs: ' \
                  'loss={:.3f}, ' \
                  'accuracy={:.3f}'.format(epoch_idx + 1,
                                           avg_loss,
                                           accuracy)
        is_new_best_epoch = \
            self._update_training_status(epoch_idx=epoch_idx,
                                         validation_loss=avg_loss)
        if is_new_best_epoch:
            log_msg += ' [new best model]'
            self._validation_engine.save()
        logger.info(log_msg)
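
_update_training_status is referenced in both hooks but not shown. A plausible sketch, assuming lower validation loss is better and that _CNST.DEFAULT_METRIC_VALUE is a sentinel any real loss beats:

    def _update_training_status(self, epoch_idx, validation_loss):
        """Records the evaluation; returns True if it is a new best."""
        status = self._training_status
        status[_CNST.LATEST_EVALUATED_KEY] = {
            _CNST.EPOCH_IDX_KEY: epoch_idx,
            _CNST.METRIC_KEY: validation_loss,
        }
        best = status[_CNST.BEST_KEY]
        is_new_best = validation_loss < best[_CNST.METRIC_KEY]
        if is_new_best:
            status[_CNST.BEST_KEY] = {
                _CNST.EPOCH_IDX_KEY: epoch_idx,
                _CNST.METRIC_KEY: validation_loss,
            }
        return is_new_best
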
Example #8
def main():
    parser = argparse.ArgumentParser(description='Predict the digit '
                                     'represented in the provided '
                                     'images.')
    parser.add_argument('image_paths',
                        type=str,
                        nargs='+',
                        help='Paths to the images to predict on')
    parser.add_argument('--show',
                        action='store_true',
                        help='If set, the input images are displayed, '
                        'with the prediction as a title.')

    args = parser.parse_args()
    image_paths = args.image_paths
    show = args.show

    with inf_eng.InferenceEngine() as inference_engine:
        predictions, probabilities = \
            inference_engine.load_preprocess_and_predict(image_paths)

    logger.info('Predictions generated by model: '
                '{}'.format(paths.Checkpoints.BEST_MODEL))
    for image_path, pred, probs in zip(image_paths, predictions,
                                       probabilities):
        print('Image: {}\n'
              '  Prediction: {}\n'
              '  Probabilities: [{}]'.format(
                  image_path, pred,
                  ', '.join(['{:.3f}'.format(p) for p in probs])))

        if show:
            pil_image = Image.open(image_path)
            title = 'Prediction: {}\n' \
                    'with probability: {:.3f}'.format(pred, probs[pred])
            plt.figure()
            plt.imshow(pil_image, cmap='gray')
            plt.title(title)
            plt.axis('off')
            plt.show()
            plt.close()
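
The snippet ends without the usual entry-point guard; assuming the script is meant to be executed directly, it would close with:

if __name__ == '__main__':
    main()
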
    def run(self):
        num_epochs = config.ExperimentConfig.NUM_EPOCHS

        latest_trained_epoch_idx = \
            self._training_status[_CNST.LATEST_TRAINED_KEY]

        if latest_trained_epoch_idx == -1:
            # No epoch has been trained: start a new training.
            self._before_new_training()
        else:
            # Some epochs were trained. Was the last epoch evaluated?
            latest_evaluated = self._training_status[_CNST.LATEST_EVALUATED_KEY]
            latest_evaluated_epoch_idx = latest_evaluated[_CNST.EPOCH_IDX_KEY]
            if latest_evaluated_epoch_idx == latest_trained_epoch_idx - 1:
                # The latest trained epoch was not evaluated.
                self._after_epoch(latest_trained_epoch_idx)
            elif latest_evaluated_epoch_idx < latest_trained_epoch_idx - 1:
                raise ValueError('Missing evaluations: the latest trained '
                                 'epoch index is {}, but the latest evaluated '
                                 'epoch index is {}'
                                 ''.format(latest_trained_epoch_idx,
                                           latest_evaluated_epoch_idx))
            else:
                # The latest trained epoch was evaluated: do nothing.
                pass

            # Resume the training where it was left.
            self._training_engine.resume()

        # At this point every trained epoch (possibly none) has been evaluated.

        for epoch_idx in range(latest_trained_epoch_idx + 1, num_epochs):
            self._before_epoch(epoch_idx)
            self._run_epoch(epoch_idx)
            self._after_epoch(epoch_idx)

        logger.info('Training completed')

        self._after_training()
        self._shut_down_engines()
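
_before_epoch and _run_epoch are not included in the snippet. A rough sketch of _run_epoch consistent with the resume logic above (the train_step method is invented for illustration):

    def _run_epoch(self, epoch_idx):
        # Run every training batch of the epoch, then checkpoint and
        # record progress so run() can resume after an interruption.
        for _ in range(self._training_engine.batches_per_epoch):
            self._training_engine.train_step()
        self._training_engine.save()
        self._training_status[_CNST.LATEST_TRAINED_KEY] = epoch_idx
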
    def _after_training(self):
        for testset_name, testset_def_path \
                in paths.DatasetDefinitions.TEST.items():
            avg_loss, accuracy = self._run_evaluation(testset_def_path,
                                                      testset_name)
            logger.info('Evaluated on: {}\n'
                        '  loss: {:.3f}\n'
                        '  accuracy: {:.3f}'.format(testset_name,
                                                    avg_loss,
                                                    accuracy))

            # Append the new evaluation results.
            eval_results = {
                _CNST.EVAL_RESULTS_TIMESTAMP_KEY: time.strftime(
                    '%a %d %b %Y %H:%M:%S'),
                _CNST.EVAL_RESULTS_DATASET_KEY: testset_name,
                'avg_loss': avg_loss,
                'accuracy': accuracy,
            }
            eval_utils.save_evaluation_results(
                eval_results_path=paths.EvaluationResults.PATH,
                eval_results=eval_results)
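
eval_utils.save_evaluation_results is not shown; a minimal sketch consistent with "append the new evaluation results" (the JSON-list file layout is an assumption):

import json
import os


def save_evaluation_results(eval_results_path, eval_results):
    """Appends one evaluation record to a JSON list on disk."""
    if os.path.exists(eval_results_path):
        with open(eval_results_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = []
    all_results.append(eval_results)
    with open(eval_results_path, 'w') as f:
        json.dump(all_results, f, indent=2)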