import argparse
import sys
from os import path

import numpy as np
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils.generic_utils import Progbar

# The remaining names used below (models, load_dataset, OracleWrapper,
# BiasedReweightingPolicy, build_grad_batched, uniform_score,
# get_models_dictionary) are project-local helpers whose import lines are
# not shown in this excerpt.


def main(argv):
    parser = argparse.ArgumentParser(
        description="Plot the loss distribution of a model and dataset pair")

    parser.add_argument("model",
                        choices=[
                            "small_cnn", "cnn", "lstm_lm", "lstm_lm2",
                            "lstm_lm3", "small_cnn_sq"
                        ],
                        help="Choose the type of the model")
    parser.add_argument("weights",
                        help="The file containing the model weights")
    parser.add_argument("dataset",
                        choices=[
                            "mnist", "cifar10", "cifar100",
                            "cifar10-augmented", "cifar100-augmented", "ptb"
                        ],
                        help="Choose the dataset to compute the loss")
    parser.add_argument("--score",
                        choices=["gnorm", "loss"],
                        default="loss",
                        help="Choose a score to plot")
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="The batch size for computing the loss")
    parser.add_argument(
        "--random_seed",
        type=int,
        default=0,
        help="A seed for the PRNG (mainly used for dataset generation)")

    args = parser.parse_args(argv)

    np.random.seed(args.random_seed)

    dataset = load_dataset(args.dataset)
    network = models.get(args.model)(dataset.shape, dataset.output_size)
    model = OracleWrapper(network, score=args.score)
    model.model.load_weights(args.weights)

    # Stream the chosen score over the training set, one batch at a time
    for i in range(0, dataset.train_size, args.batch_size):
        idxs = slice(i, i + args.batch_size)
        for s in model.score_batch(*dataset.train_data(idxs)):
            print(s)
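
# A minimal entry point so the script above can be run directly; the
# standard `sys.argv` idiom here is an assumption, not taken from the
# original file.
if __name__ == "__main__":
    main(sys.argv[1:])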
def build_model(model, wrapper, dataset, hyperparams, reweighting):
    def build_optimizer(opt, hyperparams):
        # Keras only clips when clipnorm > 0, so a default of 0 disables
        # clipping; a default of None would break that comparison.
        return {
            "sgd":
            SGD(lr=hyperparams.get("lr", 0.001),
                momentum=hyperparams.get("momentum", 0.0),
                clipnorm=hyperparams.get("clipnorm", 0)),
            "adam":
            Adam(lr=hyperparams.get("lr", 0.001),
                 clipnorm=hyperparams.get("clipnorm", 0)),
            "rmsprop":
            RMSprop(lr=hyperparams.get("lr", 0.001),
                    decay=hyperparams.get("lr_decay", 0.0),
                    clipnorm=hyperparams.get("clipnorm", 0))
        }[opt]

    model = models.get(model)(dataset.shape, dataset.output_size)
    model.compile(optimizer=build_optimizer(hyperparams.get("opt", "adam"),
                                            hyperparams),
                  loss=model.loss,
                  metrics=model.metrics)

    return get_models_dictionary(hyperparams, reweighting)[wrapper](model)
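
# Note on build_optimizer above: the dict literal instantiates all three
# optimizers before one is selected by key. A lazy equivalent (a sketch
# under the same old-Keras constructor API) builds only the requested one:
def build_optimizer_lazy(opt, hyperparams):
    factories = {
        "sgd": lambda: SGD(lr=hyperparams.get("lr", 0.001),
                           momentum=hyperparams.get("momentum", 0.0)),
        "adam": lambda: Adam(lr=hyperparams.get("lr", 0.001)),
        "rmsprop": lambda: RMSprop(lr=hyperparams.get("lr", 0.001),
                                   decay=hyperparams.get("lr_decay", 0.0)),
    }
    return factories[opt]()  # unknown names raise KeyError immediately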
def build_model(model, wrapper, dataset, hyperparams, reweighting):
    return get_models_dictionary(hyperparams, reweighting)[wrapper](
        models.get(model)(dataset.shape, dataset.output_size)
    )
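
# The variant directly above skips model.compile(), presumably leaving
# compilation to the wrapper returned by get_models_dictionary. A
# hypothetical invocation of either variant (the wrapper key and the
# hyperparameter names are illustrative assumptions, not confirmed
# registry entries):
#
#   wrapped = build_model(
#       model="small_cnn",
#       wrapper="oracle",
#       dataset=load_dataset("mnist"),
#       hyperparams={"opt": "adam", "lr": 1e-3},
#       reweighting=BiasedReweightingPolicy(),
#   )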
def main(argv):
    parser = argparse.ArgumentParser(
        description=("Compute the variance reduction achieved by different "
                     "importance sampling methods"))

    parser.add_argument(
        "model",
        choices=["small_cnn", "cnn", "wide_resnet_28_2", "lstm_lm"],
        help="Choose the type of the model")
    parser.add_argument("weights",
                        help="The file containing the model weights")
    parser.add_argument("dataset",
                        choices=[
                            "mnist", "cifar10", "cifar100",
                            "cifar10-augmented", "cifar100-augmented",
                            "imagenet-32x32", "ptb",
                            "cifar10-whitened-augmented",
                            "cifar100-whitened-augmented"
                        ],
                        help="Choose the dataset to compute the loss")
    parser.add_argument("--samples",
                        type=int,
                        default=10,
                        help="How many samples to choose")
    parser.add_argument("--score",
                        choices=["gnorm", "full_gnorm", "loss", "ones"],
                        nargs="+",
                        default="loss",
                        help="Choose a score to perform sampling with")
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="The batch size for computing the loss")
    parser.add_argument(
        "--inner_batch_size",
        type=int,
        default=32,
        help=("The batch size to use for gradient computations "
              "(to decrease memory usage)"))
    parser.add_argument(
        "--sample_size",
        type=int,
        default=1024,
        help="The sample size to compute the variance reduction")
    parser.add_argument(
        "--random_seed",
        type=int,
        default=0,
        help="A seed for the PRNG (mainly used for dataset generation)")
    parser.add_argument("--save_scores",
                        help="Directory to save the scores in")

    args = parser.parse_args(argv)

    np.random.seed(args.random_seed)

    dataset = load_dataset(args.dataset)
    network = models.get(args.model)(dataset.shape, dataset.output_size)
    network.load_weights(args.weights)
    grad = build_grad_batched(network, args.inner_batch_size)
    reweighting = BiasedReweightingPolicy()

    # Estimate the "full" gradient on a random subsample of the training set
    idxs = np.random.choice(len(dataset.train_data), args.sample_size)
    x, y = dataset.train_data[idxs]
    full_grad = grad([x, y, np.ones(len(x))])[0]

    # For each score: sample batches by that score and compare the
    # reweighted mini-batch gradients against the full gradient
    for score_metric in args.score:
        if score_metric != "ones":
            model = OracleWrapper(network, reweighting, score=score_metric)
            score = model.score
        else:
            score = uniform_score
        gs = np.zeros(shape=(args.samples, ) + full_grad.shape,
                      dtype=np.float32)
        print("Calculating %s..." % (score_metric, ))
        scores = score(x, y, batch_size=1)
        p = scores / scores.sum()
        pb = Progbar(args.samples)
        for i in range(args.samples):
            pb.update(i)
            idxs = np.random.choice(args.sample_size, args.batch_size, p=p)
            w = reweighting.sample_weights(idxs, scores).ravel()
            gs[i] = grad([x[idxs], y[idxs], w])[0]
        pb.update(args.samples)
        norms = np.sqrt(((full_grad - gs)**2).sum(axis=1))
        # Cosine similarity between each sampled gradient and the full one
        alignment = gs.dot(full_grad[:, np.newaxis]) / np.sqrt(
            np.sum(full_grad**2))
        alignment /= np.sqrt((gs**2).sum(axis=1, keepdims=True))
        print "Mean of norms of diff", np.mean(norms)
        print "Variance of norms of diff", np.var(norms)
        print "Mean of alignment", np.mean(alignment)
        print "Variance of alignment", np.var(alignment)
        if args.save_scores:
            np.savetxt(path.join(args.save_scores, score_metric + ".txt"),
                       scores)
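
# Self-contained numpy sketch of the estimator measured above: drawing
# indices with probability p_i proportional to a score and reweighting by
# 1 / (N * p_i) keeps the mini-batch mean an unbiased estimate of the full
# mean. The plain 1/(N*p) weights are a simplification of
# BiasedReweightingPolicy, used here only for illustration.
def _importance_sampling_demo(N=1024, D=16, B=128, seed=0):
    rng = np.random.RandomState(seed)
    g = rng.randn(N, D)                   # stand-in per-sample gradients
    scores = np.sqrt((g**2).sum(axis=1))  # e.g. gradient-norm scores
    p = scores / scores.sum()
    full = g.mean(axis=0)
    idxs = rng.choice(N, B, p=p)
    w = 1.0 / (N * p[idxs])               # importance weights
    est = (w[:, None] * g[idxs]).mean(axis=0)
    print("||full - est|| =", np.sqrt(((full - est)**2).sum()))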