def run(dataset_name, model_dir, batch_size, predictions_per_example,
        max_examples, output_dir, fake_data=False, max_predictions=50000):
    """Runs predictions on the given dataset using the specified model.

    Args:
      dataset_name: name passed to `image_data_utils.get_data_config`; also
        used to name the output prediction files.
      model_dir: directory containing `model_options.json` and the
        `model.ckpt` weights.
      batch_size: batch size passed to `data_lib.build_dataset`; also divides
        `max_predictions` to obtain the batch limit.
      predictions_per_example: forwarded to
        `experiment_utils.make_predictions`.
      max_examples: if truthy, the dataset is truncated with
        `dataset.take(max_examples)`.
      output_dir: directory (created via `gfile.makedirs`) where the
        `predictions_<dataset_name>` and `predictions_small_<dataset_name>`
        stats are written.
      fake_data: forwarded to `data_lib.build_dataset`.
      max_predictions: numerator for the batch limit handed to
        `make_predictions` (`max_predictions // batch_size`). Defaults to
        50000, the value previously hard-coded here.
    """
    gfile.makedirs(output_dir)

    data_config = image_data_utils.get_data_config(dataset_name)
    dataset = data_lib.build_dataset(data_config, batch_size,
                                     fake_data=fake_data)
    if max_examples:
        # NOTE(review): if build_dataset returns an already-batched dataset,
        # take() limits the number of batches, not examples -- confirm intent.
        dataset = dataset.take(max_examples)

    # Paths are built with '+' (not os.path.join) so they stay identical to
    # the paths this function has always used for model_dir.
    model_opts = experiment_utils.load_config(model_dir + '/model_options.json')
    model_opts = models_lib.ModelOptions(**model_opts)
    logging.info('Loaded model options: %s', model_opts)
    model = models_lib.build_model(model_opts)
    logging.info('Loading model weights...')
    model.load_weights(model_dir + '/model.ckpt')
    logging.info('done loading model weights.')

    # Dict literal evaluates left to right, so 'full' is created before 'small'.
    writers = {
        'full': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_%s' % dataset_name)),
        'small': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_small_%s' % dataset_name)),
    }

    max_batches = max_predictions // batch_size
    experiment_utils.make_predictions(model, dataset, predictions_per_example,
                                      writers, max_batches)
def load_model(model_dir):
    """Build a model described in `model_dir` and restore its weights.

    Reads `model_dir + '/model_options.json'` via `experiment_utils.load_config`,
    expands it into `ModelOptions`, builds the model with `build_model`, and
    then loads weights from `model_dir + '/model.ckpt'`.

    Args:
      model_dir: directory holding `model_options.json` and `model.ckpt`.

    Returns:
      The object returned by `build_model`, after `load_weights` was called.
    """
    options_path = model_dir + '/model_options.json'
    checkpoint_path = model_dir + '/model.ckpt'

    options = ModelOptions(**experiment_utils.load_config(options_path))
    logging.info('Loaded model options: %s', options)

    loaded = build_model(options)
    logging.info('Loading model weights...')
    loaded.load_weights(checkpoint_path)
    logging.info('done loading model weights.')
    return loaded
def load_trained_model(model_dir, load_weights=True, as_components=False):
    """Load a trained model using recorded options and weights.

    Args:
      model_dir: directory holding `model_options.json` and, when weights are
        loaded, the `model.ckpt` checkpoint.
      load_weights: if True, restore weights from `model_dir + '/model.ckpt'`.
      as_components: forwarded to `build_model`. When True, the value returned
        by `build_model` is iterated and the checkpoint is loaded into each
        element.

    Returns:
      The result of `build_model(model_opts, as_components=as_components)`,
      with weights restored when `load_weights` is True.
    """
    model_opts = experiment_utils.load_config(model_dir + '/model_options.json')
    model_opts = ModelOptions(**model_opts)
    logging.info('Loaded model options: %s', model_opts)
    model = build_model(model_opts, as_components=as_components)
    if not load_weights:
        return model

    checkpoint_path = model_dir + '/model.ckpt'
    logging.info('Loading model weights...')
    # With as_components=True every component restores from the same
    # checkpoint; a plain loop replaces the former side-effect-only list
    # comprehension.
    components = model if as_components else [model]
    for component in components:
        component.load_weights(checkpoint_path)
    logging.info('done loading model weights.')
    return model