Example #1
0
def run(dataset_name,
        model_dir,
        batch_size,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False,
        example_budget=50000):
    """Runs predictions on the given dataset using the specified model.

    Builds the dataset named `dataset_name`, restores the model described by
    `model_dir`, and hands both to `experiment_utils.make_predictions`, which
    writes to two `StatsWriter`s ('full' and 'small') under `output_dir`.

    Args:
      dataset_name: name passed to `image_data_utils.get_data_config`; also
        used as the suffix of the two output file names.
      model_dir: directory holding `model_options.json` and the `model.ckpt`
        checkpoint prefix.
      batch_size: batch size passed to `data_lib.build_dataset`. It is also
        the divisor of the batch cap below, so it must be non-zero.
      predictions_per_example: forwarded to `make_predictions`.
      max_examples: if truthy, the dataset is truncated with
        `dataset.take(max_examples)`.
      output_dir: directory passed to `gfile.makedirs` before any outputs
        are written into it.
      fake_data: forwarded to `data_lib.build_dataset`.
      example_budget: the number of batches given to `make_predictions` is
        capped at `example_budget // batch_size`. Defaults to 50000, the
        value that was previously hard-coded.
    """
    gfile.makedirs(output_dir)
    data_config = image_data_utils.get_data_config(dataset_name)
    dataset = data_lib.build_dataset(data_config,
                                     batch_size,
                                     fake_data=fake_data)
    if max_examples:
        # NOTE(review): if `build_dataset` returns an already-batched dataset,
        # `take` limits the number of batches, not examples -- confirm intent.
        dataset = dataset.take(max_examples)

    # Join paths the same way as the output paths below, so a trailing slash
    # on `model_dir` does not produce a doubled separator.
    model_opts = experiment_utils.load_config(
        os.path.join(model_dir, 'model_options.json'))
    model_opts = models_lib.ModelOptions(**model_opts)
    logging.info('Loaded model options: %s', model_opts)

    model = models_lib.build_model(model_opts)
    logging.info('Loading model weights...')
    model.load_weights(os.path.join(model_dir, 'model.ckpt'))
    logging.info('done loading model weights.')

    writer = array_utils.StatsWriter(
        os.path.join(output_dir, 'predictions_%s' % dataset_name))
    writer_small = array_utils.StatsWriter(
        os.path.join(output_dir, 'predictions_small_%s' % dataset_name))

    writers = {'full': writer, 'small': writer_small}
    max_batches = example_budget // batch_size
    experiment_utils.make_predictions(model, dataset, predictions_per_example,
                                      writers, max_batches)
Example #2
0
def load_model(model_dir):
  """Builds a model from its saved options and restores its weights.

  Args:
    model_dir: directory prefix containing `model_options.json` and the
      `model.ckpt` checkpoint; each path is formed as `model_dir + '/<name>'`.

  Returns:
    The object returned by `build_model`, after `load_weights` has been
    called on it with the checkpoint path.
  """
  options_path = model_dir + '/model_options.json'
  checkpoint_path = model_dir + '/model.ckpt'

  options = ModelOptions(**experiment_utils.load_config(options_path))
  logging.info('Loaded model options: %s', options)

  restored = build_model(options)
  logging.info('Loading model weights...')
  restored.load_weights(checkpoint_path)
  logging.info('done loading model weights.')
  return restored
Example #3
0
def load_trained_model(model_dir, load_weights=True, as_components=False):
    """Load a trained model using recorded options and weights.

    Args:
      model_dir: directory holding `model_options.json` and the `model.ckpt`
        checkpoint; paths are formed as `model_dir + '/<name>'`.
      load_weights: if False, the freshly built model is returned without
        restoring any weights.
      as_components: forwarded to `build_model`. When True, the value it
        returns is iterated and `load_weights` is called on every element.

    Returns:
      Whatever `build_model` returned (a single model, or an iterable of
      components when `as_components` is True). Weights are restored if
      `load_weights` is True.
    """
    model_opts = experiment_utils.load_config(model_dir +
                                              '/model_options.json')
    model_opts = ModelOptions(**model_opts)
    logging.info('Loaded model options: %s', model_opts)

    model = build_model(model_opts, as_components=as_components)
    if load_weights:
        checkpoint_path = model_dir + '/model.ckpt'
        logging.info('Loading model weights...')
        # Treat the single-model case as a one-element list so both paths
        # share one plain loop (side effects only; no throwaway list).
        # Every component is restored from the same checkpoint prefix.
        components = model if as_components else [model]
        for component in components:
            component.load_weights(checkpoint_path)
        logging.info('done loading model weights.')
    return model