Example #1
0
def run(dataset_name,
        model_dir,
        batch_size,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False,
        max_predicted_examples=50000):
    """Runs predictions on the given dataset using the specified model.

    Loads `model_options.json` and the `model.ckpt` weights from `model_dir`,
    rebuilds the model, and writes its predictions for `dataset_name` through
    two `array_utils.StatsWriter`s ("full" and "small") under `output_dir`.

    Args:
      dataset_name: Name passed to `image_data_utils.get_data_config`; also
        used as the suffix of the output file names.
      model_dir: Directory containing `model_options.json` and the
        `model.ckpt` checkpoint.
      batch_size: Batch size passed to `data_lib.build_dataset`; also used to
        convert `max_predicted_examples` into a batch count.
      predictions_per_example: Forwarded to
        `experiment_utils.make_predictions`.
      max_examples: If truthy, the dataset is truncated with
        `dataset.take(max_examples)`.
      output_dir: Directory for the prediction files; created if missing.
      fake_data: Forwarded to `data_lib.build_dataset`.
      max_predicted_examples: Example budget handed to `make_predictions` as
        `max_predicted_examples // batch_size` batches. Defaults to 50000,
        the value that was previously hard-coded.
    """
    gfile.makedirs(output_dir)
    data_config = image_data_utils.get_data_config(dataset_name)
    dataset = data_lib.build_dataset(data_config,
                                     batch_size,
                                     fake_data=fake_data)
    if max_examples:
        # NOTE(review): build_dataset receives batch_size, so the dataset is
        # presumably already batched; if so, take() limits *batches* rather
        # than examples. Confirm against data_lib before relying on this.
        dataset = dataset.take(max_examples)

    # Build paths with os.path.join, consistent with the output paths below.
    model_opts = experiment_utils.load_config(
        os.path.join(model_dir, 'model_options.json'))
    model_opts = models_lib.ModelOptions(**model_opts)
    logging.info('Loaded model options: %s', model_opts)

    model = models_lib.build_model(model_opts)
    logging.info('Loading model weights...')
    model.load_weights(os.path.join(model_dir, 'model.ckpt'))
    logging.info('done loading model weights.')

    # Dict literal entries are evaluated in order, so the "full" writer is
    # still created before the "small" one.
    writers = {
        'full': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_%s' % dataset_name)),
        'small': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_small_%s' % dataset_name)),
    }
    # Integer division rounds down to a whole number of batches.
    max_batches = max_predicted_examples // batch_size
    experiment_utils.make_predictions(model, dataset, predictions_per_example,
                                      writers, max_batches)
Example #2
0
def model_opts_from_hparams(hps, method, use_tpu, tpu, fake_training=False):
    """Returns a ModelOptions instance using given hyperparameters.

    The options are configured with ImageNet data constants from `data_lib`,
    90 training epochs, bfloat16 enabled, 8 TPU cores and 8 GPUs.

    Args:
      hps: Hyperparameter object. It must expose `batch_size` and
        `init_learning_rate`; `dropout_rate` is optional and defaults to 0.
        When `method` is 'svi' or 'll_svi' it must also expose
        `std_prior_scale`, `init_prior_scale_mean` and
        `init_prior_scale_std`. For other methods those options are None.
      method: Method name stored in the options.
      use_tpu: Stored as `use_tpu` in the options.
      tpu: Stored as `tpu` in the options.
      fake_training: If True, shrink the run after construction. The batch
        size becomes 32, examples per epoch 256, and training epochs 1.

    Returns:
      A `models_lib.ModelOptions` instance.
    """
    # getattr with a default replaces the hasattr-then-access double lookup.
    dropout_rate = getattr(hps, 'dropout_rate', 0)
    variational = method in ('svi', 'll_svi')

    model_opts = models_lib.ModelOptions(
        # Modeling params
        method=method,
        # Data params.
        image_shape=data_lib.IMAGENET_SHAPE,
        num_classes=data_lib.IMAGENET_NUM_CLASSES,
        examples_per_epoch=data_lib.APPROX_IMAGENET_TRAINING_IMAGES,
        validation_size=data_lib.IMAGENET_VALIDATION_IMAGES,
        use_bfloat16=True,
        # SGD params
        train_epochs=90,
        batch_size=hps.batch_size,
        dropout_rate=dropout_rate,
        init_learning_rate=hps.init_learning_rate,
        # Variational params (only read from hps for variational methods).
        std_prior_scale=hps.std_prior_scale if variational else None,
        init_prior_scale_mean=(
            hps.init_prior_scale_mean if variational else None),
        init_prior_scale_std=hps.init_prior_scale_std if variational else None,
        # NOTE(review): num_updates is set to the training-set size and is
        # not changed by fake_training below. Confirm this is intended.
        num_updates=data_lib.APPROX_IMAGENET_TRAINING_IMAGES,

        # TPU params
        use_tpu=use_tpu,
        tpu=tpu,
        num_cores=8,
        # GPU params
        num_gpus=8,
        num_replicas=1,
    )

    if fake_training:
        model_opts.batch_size = 32
        model_opts.examples_per_epoch = 256
        model_opts.train_epochs = 1
    return model_opts