import argparse

import numpy as np

# models, load_dataset and OracleWrapper are provided by the surrounding
# project; only the standard-library and NumPy imports are spelled out here.


def main(argv):
    parser = argparse.ArgumentParser(
        description="Plot the loss distribution of a model and dataset pair"
    )
    parser.add_argument(
        "model",
        choices=["small_cnn", "cnn", "lstm_lm", "lstm_lm2", "lstm_lm3",
                 "small_cnn_sq"],
        help="Choose the type of the model"
    )
    parser.add_argument(
        "weights",
        help="The file containing the model weights"
    )
    parser.add_argument(
        "dataset",
        choices=["mnist", "cifar10", "cifar100", "cifar10-augmented",
                 "cifar100-augmented", "ptb"],
        help="Choose the dataset to compute the loss"
    )
    parser.add_argument(
        "--score",
        choices=["gnorm", "loss"],
        default="loss",
        help="Choose a score to plot"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="The batch size for computing the loss"
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=0,
        help="A seed for the PRNG (mainly used for dataset generation)"
    )
    args = parser.parse_args(argv)

    np.random.seed(args.random_seed)

    dataset = load_dataset(args.dataset)
    network = models.get(args.model)(dataset.shape, dataset.output_size)
    model = OracleWrapper(network, score=args.score)
    model.model.load_weights(args.weights)

    # Score the training set batch by batch and print one score per line
    for i in range(0, dataset.train_size, args.batch_size):
        idxs = slice(i, i + args.batch_size)
        for s in model.score_batch(*dataset.train_data(idxs)):
            print(s)
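# Example invocation; the file name "plot_loss.py" is hypothetical.
#
#     python plot_loss.py small_cnn weights.h5 mnist --score loss > losses.txt
#
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])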
def build_model(model, wrapper, dataset, hyperparams, reweighting):
    def build_optimizer(opt, hyperparams):
        # All three optimizers are instantiated eagerly and one is selected
        # by key; optimizer construction is cheap so this costs little.
        return {
            "sgd": SGD(lr=hyperparams.get("lr", 0.001),
                       momentum=hyperparams.get("momentum", 0.0),
                       clipnorm=hyperparams.get("clipnorm", None)),
            "adam": Adam(lr=hyperparams.get("lr", 0.001),
                         clipnorm=hyperparams.get("clipnorm", None)),
            "rmsprop": RMSprop(lr=hyperparams.get("lr", 0.001),
                               decay=hyperparams.get("lr_decay", 0.0),
                               clipnorm=hyperparams.get("clipnorm", None))
        }[opt]

    # Rebind `model` from an architecture name to the built Keras model
    model = models.get(model)(dataset.shape, dataset.output_size)
    model.compile(
        optimizer=build_optimizer(hyperparams.get("opt", "adam"), hyperparams),
        loss=model.loss,
        metrics=model.metrics
    )

    return get_models_dictionary(hyperparams, reweighting)[wrapper](model)
def build_model(model, wrapper, dataset, hyperparams, reweighting):
    return get_models_dictionary(hyperparams, reweighting)[wrapper](
        models.get(model)(dataset.shape, dataset.output_size)
    )
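# A minimal usage sketch for build_model. The wrapper key "oracle" and the
# hyperparameter values are illustrative; the valid keys depend on
# get_models_dictionary in the surrounding project, `dataset` is assumed to
# be an already loaded dataset object, and BiasedReweightingPolicy to be
# importable from the project's reweighting module.
#
#     reweighting = BiasedReweightingPolicy()
#     wrapped = build_model(
#         "small_cnn",                   # model architecture key
#         "oracle",                      # wrapper key (illustrative)
#         dataset,
#         {"opt": "adam", "lr": 0.001},
#         reweighting
#     )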
from os import path

# Progbar is Keras' progress bar utility; models, load_dataset, OracleWrapper,
# BiasedReweightingPolicy, uniform_score and build_grad_batched are provided
# by the surrounding project.


def main(argv):
    parser = argparse.ArgumentParser(
        description=("Compute the variance reduction achieved by different "
                     "importance sampling methods")
    )
    parser.add_argument(
        "model",
        choices=["small_cnn", "cnn", "wide_resnet_28_2", "lstm_lm"],
        help="Choose the type of the model"
    )
    parser.add_argument(
        "weights",
        help="The file containing the model weights"
    )
    parser.add_argument(
        "dataset",
        choices=["mnist", "cifar10", "cifar100", "cifar10-augmented",
                 "cifar100-augmented", "imagenet-32x32", "ptb",
                 "cifar10-whitened-augmented", "cifar100-whitened-augmented"],
        help="Choose the dataset to compute the loss"
    )
    parser.add_argument(
        "--samples",
        type=int,
        default=10,
        help="How many samples to choose"
    )
    parser.add_argument(
        "--score",
        choices=["gnorm", "full_gnorm", "loss", "ones"],
        nargs="+",
        default=["loss"],  # a list, since nargs="+" produces a list
        help="Choose a score to perform sampling with"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="The batch size for computing the loss"
    )
    parser.add_argument(
        "--inner_batch_size",
        type=int,
        default=32,
        help=("The batch size to use for gradient computations "
              "(to decrease memory usage)")
    )
    parser.add_argument(
        "--sample_size",
        type=int,
        default=1024,
        help="The sample size to compute the variance reduction"
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=0,
        help="A seed for the PRNG (mainly used for dataset generation)"
    )
    parser.add_argument(
        "--save_scores",
        help="Directory to save the scores in"
    )
    args = parser.parse_args(argv)

    np.random.seed(args.random_seed)

    dataset = load_dataset(args.dataset)
    network = models.get(args.model)(dataset.shape, dataset.output_size)
    network.load_weights(args.weights)
    grad = build_grad_batched(network, args.inner_batch_size)
    reweighting = BiasedReweightingPolicy()

    # Compute the full gradient on a random subset of the training set
    # (drawn with replacement)
    idxs = np.random.choice(len(dataset.train_data), args.sample_size)
    x, y = dataset.train_data[idxs]
    full_grad = grad([x, y, np.ones(len(x))])[0]

    # Sample and approximate
    for score_metric in args.score:
        if score_metric != "ones":
            model = OracleWrapper(network, reweighting, score=score_metric)
            score = model.score
        else:
            score = uniform_score

        # One row per Monte Carlo estimate of the mini-batch gradient
        gs = np.zeros(shape=(args.samples,) + full_grad.shape,
                      dtype=np.float32)
        print("Calculating %s..." % (score_metric,))
        scores = score(x, y, batch_size=1)
        p = scores / scores.sum()
        pb = Progbar(args.samples)
        for i in range(args.samples):
            pb.update(i)
            idxs = np.random.choice(args.sample_size, args.batch_size, p=p)
            w = reweighting.sample_weights(idxs, scores).ravel()
            gs[i] = grad([x[idxs], y[idxs], w])[0]
        pb.update(args.samples)

        # L2 distance of each estimate from the full gradient
        norms = np.sqrt(((full_grad - gs)**2).sum(axis=1))
        # Cosine similarity of each estimate with the full gradient
        alignment = gs.dot(full_grad[:, np.newaxis]) / np.sqrt(
            np.sum(full_grad**2))
        alignment /= np.sqrt((gs**2).sum(axis=1, keepdims=True))

        print("Mean of norms of diff", np.mean(norms))
        print("Variance of norms of diff", np.var(norms))
        print("Mean of alignment", np.mean(alignment))
        print("Variance of alignment", np.var(alignment))

        if args.save_scores:
            np.savetxt(
                path.join(args.save_scores, score_metric + ".txt"),
                scores
            )
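# A self-contained sketch of the estimator the sampling loop above measures:
# drawing indices with probability proportional to a score and reweighting by
# 1 / (N * p_i) keeps the weighted batch mean an unbiased estimate of the
# full mean (the repository's BiasedReweightingPolicy computes its own
# variant of such weights). Every name below is illustrative and independent
# of the code above.
def _importance_sampling_demo(seed=0):
    rng = np.random.RandomState(seed)
    g = rng.randn(1024, 50)                  # stand-in per-sample gradients
    scores = np.sqrt((g ** 2).sum(axis=1))   # stand-in importance scores
    p = scores / scores.sum()

    full = g.mean(axis=0)                    # the "full gradient"
    idxs = rng.choice(len(g), 128, p=p)
    w = 1.0 / (len(g) * p[idxs])             # unbiased importance weights
    estimate = (w[:, np.newaxis] * g[idxs]).mean(axis=0)

    # L2 distance between the estimate and the full mean; it shrinks as the
    # score correlates with the per-sample gradient norm
    return np.sqrt(((full - estimate) ** 2).sum())


# Example invocation; the file name "variance_reduction.py" is hypothetical.
#
#     python variance_reduction.py cnn weights.h5 cifar10 --score loss gnorm
#
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])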