Ejemplo n.º 1
0
def make_model(dataset: str, model: str = "test"):
    """Build a Keras model, either a tiny test network or a named architecture.

    Args:
        dataset: Dataset name used to pick the input shape and class count.
            Only consulted when ``model`` is not ``"test"``; must then be one of
            ``"mnist"``, ``"cifar10"``, ``"cifar100"`` or ``"imagenet"``.
        model: ``"test"`` (default) returns a small subclassed model
            (BatchNorm -> Conv2D(32, 3) -> Flatten -> Dense(128) ->
            Dense(10, softmax)). Any other value is passed to
            ``get_keras_model`` as the architecture name.

    Returns:
        A Keras model instance.

    Raises:
        KeyError: if ``model`` is not ``"test"`` and ``dataset`` is not one of
            the supported dataset names.
    """
    if model == "test":
        from tensorflow.keras import Model
        from tensorflow.keras.layers import Dense, Flatten, Conv2D, BatchNormalization

        # NOTE: ``dataset`` is ignored here; the test model always has a
        # 10-way softmax output regardless of the dataset's class count.
        class MyModel(Model):
            def __init__(self):
                super().__init__()
                self.bn = BatchNormalization()
                self.conv1 = Conv2D(32, 3, activation="relu")
                self.flatten = Flatten()
                self.d1 = Dense(128, activation="relu")
                self.d2 = Dense(10, activation="softmax")

            def call(self, x):
                x = self.bn(x)
                x = self.conv1(x)
                x = self.flatten(x)
                x = self.d1(x)
                return self.d2(x)

        return MyModel()

    # dataset name -> (input_shape, num_classes)
    shapes = {
        "mnist": ((28, 28, 1), 10),
        "cifar10": ((32, 32, 3), 10),
        "cifar100": ((32, 32, 3), 100),
        "imagenet": ((224, 224, 3), 1000),
    }
    try:
        input_shape, num_classes = shapes[dataset]
    except KeyError:
        # Keep KeyError for existing callers, but say what went wrong.
        raise KeyError(
            "unknown dataset {!r}; expected one of {}".format(dataset, sorted(shapes))
        ) from None
    return get_keras_model(model, input_shape=input_shape, num_classes=num_classes)
Ejemplo n.º 2
0
def test_mlpblock_extract():
    """Extract dataflow graphs from VGG16 via tf.function and via Keras.

    NOTE(review): this body appears to splice two unrelated snippets. Everything
    after the first ``dfgraph_from_tf_function`` call references ``args``,
    ``logger``, ``model_name``, ``log_base``, ``CostModel``, ``plot_dfgraph``
    and ``dfgraph_from_keras``, none of which are defined inside this function
    (presumably module globals of another script — verify). The final
    ``logger.info(`` call is unterminated in this view. TODO confirm the
    intended test body.
    """
    vgg = get_keras_model("VGG16")

    @tf.function
    def get_grads(inputs):
        # Gradients of the mean network output w.r.t. every trainable variable.
        y = vgg(inputs)
        y = tf.reduce_mean(y)
        return tf.gradients(y, vgg.trainable_variables)

    # Trace once with a single all-ones 224x224x3 input to obtain a concrete graph.
    x = tf.ones(shape=(1, 224, 224, 3), name="input")
    grad_conc = get_grads.get_concrete_function(x)
    g = dfgraph_from_tf_function(grad_conc)
    # NOTE(review): this `g` is overwritten by dfgraph_from_keras below and never used.
    # load costs, and plot optionally, if platform is not flops
    logger.info("Loading costs")
    if args.platform == "flops":
        cost_model = None
    else:
        cost_model = CostModel(model_name,
                               args.platform,
                               log_base,
                               quantization=5)
        cost_model.fit()
        if args.debug:
            cost_model.plot_costs()

    # load model from Keras
    logger.info("Loading model {}".format(model_name))
    model = get_keras_model(model_name, input_shape=args.input_shape)
    # loss_ram_cost scales with batch size (presumably 4 bytes, one float32,
    # per example — TODO confirm).
    g = dfgraph_from_keras(model,
                           batch_size=args.batch_size,
                           cost_model=cost_model,
                           loss_cpu_cost=0,
                           loss_ram_cost=(4 * args.batch_size))
    if args.debug:
        tf.keras.utils.plot_model(model,
                                  to_file=log_base /
                                  "plot_{}_keras.png".format(model_name),
                                  show_shapes=True,
                                  show_layer_names=True)
        plot_dfgraph(g, log_base, name=model_name)

    # sweep constant baselines
    logger.info(
Ejemplo n.º 4
0
import tensorflow as tf
import logging

from checkmate.core.solvers.strategy_chen import solve_chen_sqrtn
from checkmate.core.utils.timer import Timer
from checkmate.tf2.extraction import dfgraph_from_tf_function
from checkmate.tf2.load_keras_model import get_keras_model

# Batch size baked into the traced image/label signatures below.
BS = 128

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Stage 1: build ResNet50 and a concrete gradient function for it.
    logging.info("building graph")
    with Timer("build_graph", print_results=True):
        model = get_keras_model("ResNet50")

        @tf.function
        def grads(images, labels):
            """Return a toy loss (mean of prediction - label) and its gradients."""
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(model(images) - labels)
            return loss, tape.gradient(loss, model.trainable_variables)

        image_spec = tf.TensorSpec(shape=(BS, 224, 224, 3))
        label_spec = tf.TensorSpec(shape=(BS, 1000))
        grad_fn = grads.get_concrete_function(image_spec, label_spec)

    # Stage 2: convert the traced concrete function into a dataflow graph.
    logging.info("tracing graph")
    with Timer("trace_graph", print_results=True):
        g = dfgraph_from_tf_function(grad_fn)

    # Scheduling of `g` is currently disabled. Earlier drafts called:
    #   solve_ilp_gurobi(g, budget=platform_memory("p2xlarge"), approx=False, eps_noise=0.0)
    #   solve_approx_lp_deterministic_05_threshold(g, budget=platform_memory("p2xlarge"))
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from checkmate.core.solvers.strategy_optimal_ilp import solve_ilp_gurobi
from checkmate.tf2_keras.extraction import dfgraph_from_keras
from experiments.common.definitions import checkmate_data_dir
from checkmate.tf2.load_keras_model import get_keras_model

if __name__ == "__main__":
    # get sample network and generate a graph on it
    model = get_keras_model("VGG16")
    g = dfgraph_from_keras(mod=model)
    # Budget = RAM cost of every node summed, plus parameter memory: enough to
    # hold everything at once (presumably so the ILP is unconstrained — confirm).
    budget = sum(g.cost_ram.values()) + g.cost_ram_parameters

    # solve for a schedule
    scheduler_result = solve_ilp_gurobi(g, budget)
    R = scheduler_result.schedule_aux_data.R

    # compute costs for 1000 runs
    # Column sums of R. Assumes R is a (stage x node) 0/1 compute matrix, making
    # r[i] the number of times node i is computed — TODO confirm.
    r = R.sum(axis=0)
    # Per-node CPU costs ordered by sorted node key. Assumes R's columns use the
    # same ordering — verify.
    C = [g.cost_cpu[key] for key in sorted(g.cost_cpu)]
    # Each run perturbs every node cost with Gaussian noise (mean C, std 1e7)
    # and takes the dot product with r to get one total schedule cost.
    results = [np.random.normal(C, 1e7) @ r for i in range(1000)]
    x = pd.Series(results, name="Cost in flops")

    # plot costs
    plt.figure()
    # NOTE(review): seaborn.distplot is deprecated since seaborn 0.11; consider
    # sns.histplot(x, kde=True) instead.
    sns.distplot(x)
    checkmate_data_dir().mkdir(parents=True, exist_ok=True)
    # NOTE(review): the savefig path expression and closing parenthesis are
    # missing in this view (the snippet is truncated here).
    plt.savefig(checkmate_data_dir() /
        scratch_dir.mkdir(parents=True, exist_ok=True)

        # load costs, and plot optionally, if platform is not flops
        print("Loading costs")
        if args.platform == "flops":
            cost_model = None
        else:
            cost_model = CostModel(args.model_name,
                                   args.platform,
                                   scratch_dir,
                                   quantization=5)
            cost_model.fit()

        # load model from Keras
        print("Loading model {}".format(args.model_name))
        model = get_keras_model(args.model_name)
        g = dfgraph_from_keras(model,
                               batch_size=args.batch_size,
                               cost_model=cost_model,
                               loss_cpu_cost=0,
                               loss_ram_cost=(4 * args.batch_size))

    common_kwargs = dict(g=g,
                         budget=B,
                         print_to_console=False,
                         eps_noise=0,
                         approx=False)

    print("Common args:", common_kwargs)

    data = []