Example 1
def model_creator(batch_size, name="default", dtype=np.float32):
    """Create MNIST autoencoder model. Dataset is part of model."""

    model = Model(name)

    def get_batch_size(data):
        if isinstance(data, IndexedGrad):
            return int(data.live[0].shape[1])
        else:
            return int(data.shape[1])

    init_dict = {}
    global_vars = []
    local_vars = []

    # TODO: factor out to reuse between scripts
    # TODO: change feed_dict logic to reuse value provided to VarStruct
    # current situation makes reinitialization of a global variable change
    # its value, which is counterintuitive
    def init_var(val, name, is_global=False):
        """Helper to create variables with numpy or TF initial values."""
        if isinstance(val, tf.Tensor):
            var = u.get_variable(name=name, initializer=val, reuse=is_global)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Non-numeric type."

            var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
            holder = var_struct.val_
            init_dict[holder] = val
            var = var_struct.var

        if is_global:
            global_vars.append(var)
        else:
            local_vars.append(var)

        return var

    # TODO: get rid of purely_relu
    def nonlin(x):
        if purely_relu:
            return tf.nn.relu(x)
        elif purely_linear:
            return tf.identity(x)
        else:
            return tf.sigmoid(x)

    # TODO: rename to "nonlin_d"
    def d_nonlin(y):
        if purely_relu:
            return u.relu_mask(y)
        elif purely_linear:
            return 1
        else:
            return y * (1 - y)

    patches = train_images[:, :args.batch_size]
    test_patches = test_images[:, :args.batch_size]

    if args.dataset == 'cifar':
        input_dim = 3 * 32 * 32
    elif args.dataset == 'mnist':
        input_dim = 28 * 28
    else:
        assert False
    if release_name == 'kfac_tiny':
        fs = [args.batch_size, input_dim, 196, input_dim]
    else:
        fs = [
            args.batch_size, input_dim, 1024, 1024, 1024, 196, 1024, 1024,
            1024, input_dim
        ]

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    n = len(fs) - 2
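    # fs = [batch_size, input_dim, hidden sizes..., input_dim]; n is the number
    # of weight matrices W[1..n].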

    # Full dataset from which new batches are sampled
    X_full = init_var(train_images, "X_full", is_global=True)

    X = init_var(patches, "X", is_global=False)  # stores local batch per model
    W = [None] * n
    W.insert(0, X)
    A = [None] * (n + 2)
    A[1] = W[0]
    for i in range(1, n + 1):
        init_val = ng_init(f(i), f(i - 1)).astype(dtype)
        W[i] = init_var(init_val, "W_%d" % (i, ), is_global=True)
        A[i + 1] = nonlin(kfac_lib.matmul(W[i], A[i]))
    err = A[n + 1] - A[1]
    model.loss = u.L2(err) / (2 * get_batch_size(err))

    # create test error eval
    layer0 = init_var(test_patches, "X_test", is_global=True)
    layer = layer0
    for i in range(1, n + 1):
        layer = nonlin(W[i] @ layer)
    verr = (layer - layer0)
    model.vloss = u.L2(verr) / (2 * get_batch_size(verr))

    # manually compute backprop to use for sanity checking
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
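    # B[i] is the backprop of the reconstruction error; B2[i] is the backprop of
    # synthetic (sampled) labels, used for the cov_B2 covariance statistics.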
    B[n] = err * d_nonlin(A[n + 1])
    _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    if args.fixed_labels:
        _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)

    _sampled_labels = init_var(_sampled_labels_live,
                               "to_be_deleted",
                               is_global=False)

    B2[n] = _sampled_labels * d_nonlin(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        B[i] = backprop * d_nonlin(A[i + 1])
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        B2[i] = backprop2 * d_nonlin(A[i + 1])

    # cov_A = [None]*(n+1)     # covariance of activations[i]
    # cov_B2 = [None]*(n+1)    # covariance of synthetic backprops[i]
    # vars_svd_A = [None]*(n+1)
    # vars_svd_B2 = [None]*(n+1)
    # dW = [None]*(n+1)
    # pre_dW = [None]*(n+1)    # preconditioned dW
    # TODO: decouple initial value from covariance update;
    # maybe need to start with identity and do a running average
    # for i in range(1, n+1):
    #   if regularized_svd:
    #     cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size+args.Lambda*u.Identity(f(i-1)), "cov_A%d"%(i,))
    #     cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size+args.Lambda*u.Identity(f(i)), "cov_B2%d"%(i,))
    #   else:
    #     cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size, "cov_A%d"%(i,))
    #     cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size, "cov_B2%d"%(i,))
    #   vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d"%(i,), do_inverses=False)
    #   vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d"%(i,), do_inverses=False)
    #   whitened_A = u.cached_inverse(vars_svd_A[i], args.Lambda) @ A[i]
    #   whitened_B = u.cached_inverse(vars_svd_B2[i], args.Lambda) @ B[i]
    #   dW[i] = (B[i] @ t(A[i]))/args.batch_size
    #   pre_dW[i] = (whitened_B @ t(whitened_A))/args.batch_size

    sampled_labels_live = A[n + 1] + tf.random_normal(
        (f(n), f(-1)), dtype=dtype, seed=0)
    if args.fixed_labels:
        sampled_labels_live = A[n + 1] + tf.ones(shape=(f(n), f(-1)),
                                                 dtype=dtype)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              is_global=False)
    err2 = A[n + 1] - sampled_labels
    model.loss2 = u.L2(err2) / (2 * args.batch_size)
    model.global_vars = global_vars
    model.local_vars = local_vars
    model.trainable_vars = W[1:]

    # TODO: the model step is tracked in three places; consolidate
    model.step = init_var(u.as_int32(0), "step", is_global=False)
    advance_step_op = model.step.assign_add(1)
    assert get_batch_size(X_full) % args.batch_size == 0
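    # Each advance_batch call bumps model.step and assigns the next
    # batch_size-wide column slice of X_full into X, wrapping around via tf.mod.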
    batches_per_dataset = (get_batch_size(X_full) // args.batch_size)
    batch_idx = tf.mod(model.step, batches_per_dataset)
    start_idx = batch_idx * args.batch_size
    advance_batch_op = X.assign(X_full[:,
                                       start_idx:start_idx + args.batch_size])

    def advance_batch():
        #    print("Step for model(%s) is %s"%(model.name, u.eval(model.step)))
        sess = u.get_default_session()
        # TODO: get rid of _sampled_labels
        sessrun([sampled_labels.initializer, _sampled_labels.initializer])
        if args.advance_batch:
            sessrun(advance_batch_op)
        sessrun(advance_step_op)

    model.advance_batch = advance_batch

    # TODO: refactor this to take initial values out of Var struct
    #global_init_op = tf.group(*[v.initializer for v in global_vars])
    global_init_ops = [v.initializer for v in global_vars]
    global_init_op = tf.group(*[v.initializer for v in global_vars])
    global_init_query_ops = [
        tf.logical_not(tf.is_variable_initialized(v)) for v in global_vars
    ]

    def initialize_global_vars(verbose=False, reinitialize=False):
        """If reinitialize is false, will not reinitialize variables already
    initialized."""

        sess = u.get_default_session()
        if not reinitialize:
            uninited = sessrun(global_init_query_ops)
            # use numpy boolean indexing to select list of initializers to run
            to_initialize = list(np.asarray(global_init_ops)[uninited])
        else:
            to_initialize = global_init_ops

        if verbose:
            print("Initializing following:")
            for v in to_initialize:
                print("   " + v.name)

        sessrun(to_initialize, feed_dict=init_dict)

    model.initialize_global_vars = initialize_global_vars

    # didn't quite work (likely can't initialize a var in the same run call as its deps)
    # enforce that the batch is initialized before everything
    # except the fake-label ops
    # for v in local_vars:
    #   if v != X and v != sampled_labels and v != _sampled_labels:
    #     print("Adding dep %s on %s"%(v.initializer.name, X.initializer.name))
    #     u.add_dep(v.initializer, on_op=X.initializer)

    local_init_op = tf.group(*[v.initializer for v in local_vars],
                             name="%s_localinit" % (model.name))
    print("Local vars:")
    for v in local_vars:
        print(v.name)

    def initialize_local_vars():
        sess = u.get_default_session()
        sessrun(_sampled_labels.initializer, feed_dict=init_dict)
        sessrun(X.initializer, feed_dict=init_dict)
        sessrun(local_init_op, feed_dict=init_dict)

    model.initialize_local_vars = initialize_local_vars

    return model
Example 2
def model_creator(batch_size, name='defaultmodel', dtype=np.float32):
    """Create MNIST autoencoder model. Dataset is part of model."""

    model = Model(name)

    init_dict = {}
    global_vars = []
    local_vars = []

    # TODO: factor out to reuse between scripts
    # TODO: change feed_dict logic to reuse value provided to VarStruct
    # current situation makes reinitialization of a global variable change
    # its value, which is counterintuitive
    def init_var(val, name, is_global=False):
        """Helper to create variables with numpy or TF initial values."""
        if isinstance(val, tf.Tensor):
            var = u.get_variable(name=name, initializer=val, reuse=is_global)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Non-numeric type."

            var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
            holder = var_struct.val_
            init_dict[holder] = val
            var = var_struct.var

        if is_global:
            global_vars.append(var)
        else:
            local_vars.append(var)

        return var

    # TODO: get rid of purely_relu
    def nonlin(x):
        if purely_relu:
            return tf.nn.relu(x)
        elif purely_linear:
            return tf.identity(x)
        else:
            return tf.sigmoid(x)

    # TODO: rename to "nonlin_d"
    def d_nonlin(y):
        if purely_relu:
            return u.relu_mask(y)
        elif purely_linear:
            return 1
        else:
            return y * (1 - y)

    train_images = load_MNIST.load_MNIST_images(
        'data/train-images-idx3-ubyte').astype(dtype)
    patches = train_images[:, :batch_size]
    test_patches = train_images[:, -batch_size:]
    assert dsize < 25000

    fs = [
        batch_size, 28 * 28, 1024, 1024, 1024, 196, 1024, 1024, 1024, 28 * 28
    ]

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    n = len(fs) - 2

    X = init_var(patches, "X", is_global=False)
    W = [None] * n
    W.insert(0, X)
    A = [None] * (n + 2)
    A[1] = W[0]
    for i in range(1, n + 1):
        init_val = ng_init(f(i), f(i - 1)).astype(dtype)
        W[i] = init_var(init_val, "W_%d" % (i, ), is_global=True)
        A[i + 1] = nonlin(kfac_lib.matmul(W[i], A[i]))

    err = A[n + 1] - A[1]

    # create test error eval
    layer = init_var(test_patches, "X_test", is_global=False)
    for i in range(1, n + 1):
        layer = nonlin(W[i] @ layer)
    verr = (layer - test_patches)
    model.vloss = u.L2(verr) / (2 * batch_size)

    # manually compute backprop to use for sanity checking
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
    B[n] = err * d_nonlin(A[n + 1])
    _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)

    _sampled_labels = init_var(_sampled_labels_live,
                               "to_be_deleted",
                               is_global=False)

    B2[n] = _sampled_labels * d_nonlin(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        B[i] = backprop * d_nonlin(A[i + 1])
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        B2[i] = backprop2 * d_nonlin(A[i + 1])

    cov_A = [None] * (n + 1)  # covariance of activations[i]
    cov_B2 = [None] * (n + 1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    dW = [None] * (n + 1)
    dW2 = [None] * (n + 1)
    pre_dW = [None] * (n + 1)  # preconditioned dW
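    # KFAC-style preconditioning: activations and backprops are whitened by the
    # (pseudo-)inverses of their covariance factors, so that
    # pre_dW[i] ~ inverse(cov_B2[i]) @ dW[i] @ inverse(cov_A[i]).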
    for i in range(1, n + 1):
        if regularized_svd:
            cov_A[i] = init_var(
                A[i] @ t(A[i]) / batch_size + LAMBDA * u.Identity(f(i - 1)),
                "cov_A%d" % (i, ))
            cov_B2[i] = init_var(
                B2[i] @ t(B2[i]) / batch_size + LAMBDA * u.Identity(f(i)),
                "cov_B2%d" % (i, ))
        else:
            cov_A[i] = init_var(A[i] @ t(A[i]) / batch_size, "cov_A%d" % (i, ))
            cov_B2[i] = init_var(B2[i] @ t(B2[i]) / batch_size,
                                 "cov_B2%d" % (i, ))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        if use_tikhonov:
            whitened_A = u.regularized_inverse3(vars_svd_A[i], L=LAMBDA) @ A[i]
            whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],
                                                 L=LAMBDA) @ B[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]

        dW[i] = (B[i] @ t(A[i])) / batch_size
        dW2[i] = B[i] @ t(A[i])
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / batch_size

        #  model.extra['A'] = A
        #  model.extra['B'] = B
        #  model.extra['B2'] = B2
        #  model.extra['cov_A'] = cov_A
        #  model.extra['cov_B2'] = cov_B2
        #  model.extra['vars_svd_A'] = vars_svd_A
        #  model.extra['vars_svd_B2'] = vars_svd_B2
        #  model.extra['W'] = W
        #  model.extra['dW'] = dW
        #  model.extra['dW2'] = dW2
        #  model.extra['pre_dW'] = pre_dW

    model.loss = u.L2(err) / (2 * batch_size)
    sampled_labels_live = A[n + 1] + tf.random_normal(
        (f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        sampled_labels_live = A[n + 1] + tf.ones(shape=(f(n), f(-1)),
                                                 dtype=dtype)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              is_global=False)
    err2 = A[n + 1] - sampled_labels
    model.loss2 = u.L2(err2) / (2 * batch_size)
    model.global_vars = global_vars
    model.local_vars = local_vars
    model.trainable_vars = W[1:]

    def advance_batch():
        sess = tf.get_default_session()
        # TODO: get rid of _sampled_labels
        sess.run([sampled_labels.initializer, _sampled_labels.initializer])

    model.advance_batch = advance_batch

    # TODO: refactor this to take initial values out of Var struct
    #global_init_op = tf.group(*[v.initializer for v in global_vars])
    global_init_ops = [v.initializer for v in global_vars]
    global_init_op = tf.group(*[v.initializer for v in global_vars])

    global_init_query_op = [
        tf.logical_not(tf.is_variable_initialized(v)) for v in global_vars
    ]

    def initialize_global_vars(verbose=False, reinitialize=False):
        """If reinitialize is false, will not reinitialize variables already
    initialized."""

        sess = tf.get_default_session()
        if not reinitialize:
            uninited = sess.run(global_init_query_op)
            # use numpy boolean indexing to select list of initializers to run
            to_initialize = list(np.asarray(global_init_ops)[uninited])
        else:
            to_initialize = global_init_ops

        if verbose:
            print("Initializing following:")
            for v in to_initialize:
                print("   " + v.name)

        sess.run(to_initialize, feed_dict=init_dict)

    model.initialize_global_vars = initialize_global_vars

    local_init_op = tf.group(*[v.initializer for v in local_vars])

    def initialize_local_vars():
        sess = tf.get_default_session()
        sess.run(X.initializer, feed_dict=init_dict)  # A's depend on X
        sess.run(_sampled_labels.initializer, feed_dict=init_dict)
        sess.run(local_init_op, feed_dict=init_dict)

    model.initialize_local_vars = initialize_local_vars

    return model
Example 3
def model_creator(batch_size, name='defaultmodel', dtype=np.float32):
    """Create MNIST autoencoder model. Dataset is part of model."""

    global hack_global_init_dict

    model = Model(name)

    # TODO: actually use batch_size
    init_dict = {}  # todo: rename to feed_dict?
    global_vars = []
    local_vars = []

    # TODO: rename to make_var
    def init_var(val, name, is_global=False):
        """Helper to create variables with numpy or TF initial values."""
        if isinstance(val, tf.Tensor):
            var = u.get_variable(name=name, initializer=val, reuse=is_global)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Non-numeric type."

            var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
            holder = var_struct.val_
            init_dict[holder] = val
            var = var_struct.var

        if is_global:
            global_vars.append(var)
        else:
            local_vars.append(var)

        return var

    # TODO: get rid of purely_relu
    def nonlin(x):
        if purely_relu:
            return tf.nn.relu(x)
        elif purely_linear:
            return tf.identity(x)
        else:
            return tf.sigmoid(x)

    # TODO: rename to "nonlin_d"
    def d_nonlin(y):
        if purely_relu:
            return u.relu_mask(y)
        elif purely_linear:
            return 1
        else:
            return y * (1 - y)

    train_images = load_MNIST.load_MNIST_images(
        'data/train-images-idx3-ubyte').astype(dtype)
    patches = train_images[:, :batch_size]
    fs = [batch_size, 28 * 28, 196, 28 * 28]

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    n = len(fs) - 2

    X = init_var(patches, "X", is_global=False)
    W = [None] * n
    W.insert(0, X)
    A = [None] * (n + 2)
    A[1] = W[0]
    W0f_old = W_uniform(fs[2],
                        fs[3]).astype(dtype)  # to match previous generation
    W0s_old = u.unflatten(W0f_old, fs[1:])  # perftodo: this creates transposes
    for i in range(1, n + 1):
        #    temp = init_var(ng_init(f(i), f(i-1)), "W_%d"%(i,), is_global=True)
        #    init_val1 = W0s_old[i-1]
        init_val = ng_init(f(i), f(i - 1)).astype(dtype)
        W[i] = init_var(init_val, "W_%d" % (i, ), is_global=True)
        A[i + 1] = nonlin(kfac_lib.matmul(W[i], A[i]))

    err = A[n + 1] - A[1]

    # manually compute backprop to use for sanity checking
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
    B[n] = err * d_nonlin(A[n + 1])
    _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)

    _sampled_labels = init_var(_sampled_labels_live,
                               "to_be_deleted",
                               is_global=False)

    B2[n] = _sampled_labels * d_nonlin(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        B[i] = backprop * d_nonlin(A[i + 1])
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        B2[i] = backprop2 * d_nonlin(A[i + 1])

    cov_A = [None] * (n + 1)  # covariance of activations[i]
    cov_B2 = [None] * (n + 1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    dW = [None] * (n + 1)
    dW2 = [None] * (n + 1)
    pre_dW = [None] * (n + 1)  # preconditioned dW
    for i in range(1, n + 1):
        if regularized_svd:
            cov_A[i] = init_var(
                A[i] @ t(A[i]) / batch_size + LAMBDA * u.Identity(f(i - 1)),
                "cov_A%d" % (i, ))
            cov_B2[i] = init_var(
                B2[i] @ t(B2[i]) / batch_size + LAMBDA * u.Identity(f(i)),
                "cov_B2%d" % (i, ))
        else:
            cov_A[i] = init_var(A[i] @ t(A[i]) / batch_size, "cov_A%d" % (i, ))
            cov_B2[i] = init_var(B2[i] @ t(B2[i]) / batch_size,
                                 "cov_B2%d" % (i, ))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        if use_tikhonov:
            whitened_A = u.regularized_inverse3(vars_svd_A[i], L=LAMBDA) @ A[i]
            whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],
                                                 L=LAMBDA) @ B[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]

        dW[i] = (B[i] @ t(A[i])) / batch_size
        dW2[i] = B[i] @ t(A[i])
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / batch_size

        #  model.extra['A'] = A
        #  model.extra['B'] = B
        #  model.extra['B2'] = B2
        #  model.extra['cov_A'] = cov_A
        #  model.extra['cov_B2'] = cov_B2
        #  model.extra['vars_svd_A'] = vars_svd_A
        #  model.extra['vars_svd_B2'] = vars_svd_B2
        #  model.extra['W'] = W
        #  model.extra['dW'] = dW
        #  model.extra['dW2'] = dW2
        #  model.extra['pre_dW'] = pre_dW

    model.loss = u.L2(err) / (2 * batch_size)
    sampled_labels_live = A[n + 1] + tf.random_normal(
        (f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        sampled_labels_live = A[n + 1] + tf.ones(shape=(f(n), f(-1)),
                                                 dtype=dtype)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              is_global=False)
    err2 = A[n + 1] - sampled_labels
    model.loss2 = u.L2(err2) / (2 * batch_size)
    model.global_vars = global_vars
    model.local_vars = local_vars
    model.trainable_vars = W[1:]

    def advance_batch():
        sess = tf.get_default_session()
        # TODO: get rid of _sampled_labels
        sess.run([sampled_labels.initializer, _sampled_labels.initializer])

    model.advance_batch = advance_batch

    global_init_op = tf.group(*[v.initializer for v in global_vars])

    def initialize_global_vars():
        sess = tf.get_default_session()
        sess.run(global_init_op, feed_dict=init_dict)

    model.initialize_global_vars = initialize_global_vars

    local_init_op = tf.group(*[v.initializer for v in local_vars])

    def initialize_local_vars():
        sess = tf.get_default_session()
        sess.run(X.initializer, feed_dict=init_dict)  # A's depend on X
        sess.run(_sampled_labels.initializer, feed_dict=init_dict)
        sess.run(local_init_op, feed_dict=init_dict)

    model.initialize_local_vars = initialize_local_vars

    hack_global_init_dict = init_dict

    return model
Example 4
def model_creator(batch_size, name="default", dtype=np.float32):
  """Create MNIST autoencoder model. Dataset is part of model."""

  model = Model(name)

  def get_batch_size(data):
    if isinstance(data, IndexedGrad):
      return int(data.live[0].shape[1])
    else:
      return int(data.shape[1])

  init_dict = {}
  global_vars = []
  local_vars = []
  
  # TODO: factor out to reuse between scripts
  # TODO: change feed_dict logic to reuse value provided to VarStruct
  # current situation makes reinitialization of a global variable change
  # its value, which is counterintuitive
  def init_var(val, name, is_global=False):
    """Helper to create variables with numpy or TF initial values."""
    if isinstance(val, tf.Tensor):
      var = u.get_variable(name=name, initializer=val, reuse=is_global)
    else:
      val = np.array(val)
      assert u.is_numeric(val), "Non-numeric type."
      
      var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
      holder = var_struct.val_
      init_dict[holder] = val
      var = var_struct.var

    if is_global:
      global_vars.append(var)
    else:
      local_vars.append(var)
      
    return var

  # TODO: get rid of purely_relu
  def nonlin(x):
    if purely_relu:
      return tf.nn.relu(x)
    elif purely_linear:
      return tf.identity(x)
    else:
      return tf.sigmoid(x)

  # TODO: rename to "nonlin_d"
  def d_nonlin(y):
    if purely_relu:
      return u.relu_mask(y)
    elif purely_linear:
      return 1
    else: 
      return y*(1-y)

  patches = train_images[:, :args.batch_size]
  test_patches = test_images[:, :args.batch_size]

  if args.dataset == 'cifar':
    input_dim = 3*32*32
  elif args.dataset == 'mnist':
    input_dim = 28*28
  else:
    assert False
  fs = [args.batch_size, input_dim, 1024, 1024, 1024, 196, 1024, 1024, 1024,
        input_dim]
    
  def f(i): return fs[i+1]  # W[i] has shape f[i] x f[i-1]
  n = len(fs) - 2

  # Full dataset from which new batches are sampled
  X_full = init_var(train_images, "X_full", is_global=True)

  X = init_var(patches, "X", is_global=False)  # stores local batch per model
  W = [None]*n
  W.insert(0, X)
  A = [None]*(n+2)
  A[1] = W[0]
  for i in range(1, n+1):
    init_val = ng_init(f(i), f(i-1)).astype(dtype)
    W[i] = init_var(init_val, "W_%d"%(i,), is_global=True)
    A[i+1] = nonlin(kfac_lib.matmul(W[i], A[i]))
  err = A[n+1] - A[1]
  model.loss = u.L2(err) / (2 * get_batch_size(err))

  # create test error eval
  layer0 = init_var(test_patches, "X_test", is_global=True)
  layer = layer0
  for i in range(1, n+1):
    layer = nonlin(W[i] @ layer)
  verr = (layer - layer0)
  model.vloss = u.L2(verr) / (2 * get_batch_size(verr))

  # manually compute backprop to use for sanity checking
  B = [None]*(n+1)
  B2 = [None]*(n+1)
  B[n] = err*d_nonlin(A[n+1])
  _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
  if args.fixed_labels:
    _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)
    
  _sampled_labels = init_var(_sampled_labels_live, "to_be_deleted",
                             is_global=False)

  B2[n] = _sampled_labels*d_nonlin(A[n+1])
  for i in range(n-1, -1, -1):
    backprop = t(W[i+1]) @ B[i+1]
    B[i] = backprop*d_nonlin(A[i+1])
    backprop2 = t(W[i+1]) @ B2[i+1]
    B2[i] = backprop2*d_nonlin(A[i+1])

  cov_A = [None]*(n+1)    # covariance of activations[i]
  cov_B2 = [None]*(n+1)   # covariance of synthetic backprops[i]
  vars_svd_A = [None]*(n+1)
  vars_svd_B2 = [None]*(n+1)
  dW = [None]*(n+1)
  dW2 = [None]*(n+1)
  pre_dW = [None]*(n+1)   # preconditioned dW
  # todo: decouple initial value from covariance update
  # maybe need start with identity and do running average
  for i in range(1,n+1):
    if regularized_svd:
      cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size+args.Lambda*u.Identity(f(i-1)), "cov_A%d"%(i,))
      cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size+args.Lambda*u.Identity(f(i)), "cov_B2%d"%(i,))
    else:
      cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size, "cov_A%d"%(i,))
      cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size, "cov_B2%d"%(i,))
    vars_svd_A[i] = u.SvdWrapper(cov_A[i],"svd_A_%d"%(i,))
    vars_svd_B2[i] = u.SvdWrapper(cov_B2[i],"svd_B2_%d"%(i,))
    if use_tikhonov:
      whitened_A = u.regularized_inverse3(vars_svd_A[i],L=args.Lambda) @ A[i]
      whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],L=args.Lambda) @ B[i]
    else:
      whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
      whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]
    
    dW[i] = (B[i] @ t(A[i]))/args.batch_size
    dW2[i] = B[i] @ t(A[i])
    pre_dW[i] = (whitened_B2 @ t(whitened_A))/args.batch_size

    
  sampled_labels_live = A[n+1] + tf.random_normal((f(n), f(-1)),
                                                  dtype=dtype, seed=0)
  if args.fixed_labels:
    sampled_labels_live = A[n+1]+tf.ones(shape=(f(n), f(-1)), dtype=dtype)
  sampled_labels = init_var(sampled_labels_live, "sampled_labels", is_global=False)
  err2 = A[n+1] - sampled_labels
  model.loss2 = u.L2(err2) / (2 * args.batch_size)
  model.global_vars = global_vars
  model.local_vars = local_vars
  model.trainable_vars = W[1:]

  # TODO: the model step is tracked in three places; consolidate
  model.step = init_var(u.as_int32(0), "step", is_global=False)
  advance_step_op = model.step.assign_add(1)
  assert get_batch_size(X_full) % args.batch_size == 0
  batches_per_dataset = (get_batch_size(X_full) // args.batch_size)
  batch_idx = tf.mod(model.step, batches_per_dataset)
  start_idx = batch_idx * args.batch_size
  advance_batch_op = X.assign(X_full[:,start_idx:start_idx + args.batch_size])
  
  def advance_batch():
    print("Step for model(%s) is %s"%(model.name, u.eval(model.step)))
    sess = u.get_default_session()
    # TODO: get rid of _sampled_labels
    sessrun([sampled_labels.initializer, _sampled_labels.initializer])
    if args.advance_batch:
      with u.timeit("advance_batch"):
        sessrun(advance_batch_op)
    sessrun(advance_step_op)
    
  model.advance_batch = advance_batch

  # TODO: refactor this to take initial values out of Var struct
  #global_init_op = tf.group(*[v.initializer for v in global_vars])
  global_init_ops = [v.initializer for v in global_vars]
  global_init_op = tf.group(*[v.initializer for v in global_vars])
  global_init_query_ops = [tf.logical_not(tf.is_variable_initialized(v))
                           for v in global_vars]
  
  def initialize_global_vars(verbose=False, reinitialize=False):
    """If reinitialize is false, will not reinitialize variables already
    initialized."""
    
    sess = u.get_default_session()
    if not reinitialize:
      uninited = sessrun(global_init_query_ops)
      # use numpy boolean indexing to select list of initializers to run
      to_initialize = list(np.asarray(global_init_ops)[uninited])
    else:
      to_initialize = global_init_ops
      
    if verbose:
      print("Initializing following:")
      for v in to_initialize:
        print("   " + v.name)

    sessrun(to_initialize, feed_dict=init_dict)
  model.initialize_global_vars = initialize_global_vars

  # didn't quite work (likely can't initialize a var in the same run call as its deps)
  # enforce that the batch is initialized before everything
  # except the fake-label ops
  # for v in local_vars:
  #   if v != X and v != sampled_labels and v != _sampled_labels:
  #     print("Adding dep %s on %s"%(v.initializer.name, X.initializer.name))
  #     u.add_dep(v.initializer, on_op=X.initializer)
      
  local_init_op = tf.group(*[v.initializer for v in local_vars],
                           name="%s_localinit"%(model.name))
  print("Local vars:")
  for v in local_vars:
    print(v.name)
    
  def initialize_local_vars():
    sess = u.get_default_session()
    sessrun(_sampled_labels.initializer, feed_dict=init_dict)
    sessrun(X.initializer, feed_dict=init_dict)
    sessrun(local_init_op, feed_dict=init_dict)
  model.initialize_local_vars = initialize_local_vars

  return model
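
A minimal driver sketch, assuming a default TensorFlow 1.x session and the module-level
globals the examples above rely on (train_images, test_images, args, u, sessrun). It only
calls the attributes set on the returned model (initialize_global_vars,
initialize_local_vars, advance_batch, loss; vloss is defined in Examples 1, 2 and 4);
the session setup, model name, and step count below are illustrative, not part of the
original examples.

sess = tf.InteractiveSession()
model = model_creator(args.batch_size, name="demo")  # "demo" is an arbitrary name
model.initialize_global_vars()   # runs initializers for not-yet-initialized global vars
model.initialize_local_vars()    # feeds init_dict for X and the sampled labels

for _ in range(10):
    model.advance_batch()        # resample labels; advance the batch window (Examples 1 and 4)
    print("train loss:", model.loss.eval(), "test loss:", model.vloss.eval())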