Example No. 1
  def correct(self, grad):
    """Accepts IndexedGrad object, produces corrected version."""
    s = self

    vars_ = []
    grads_new = []

    assert list(grad) == self.model.trainable_vars

    dsize = get_batch_size(grad)
    
    for var in grad:
      vars_.append(var)
      A = s.extract_A(grad, var)    # extract activations
      B = s.extract_B(grad, var)*dsize    # extract backprops
      if s.needs_correction(var):
        # correct the gradient. Assume op is left matmul
        A_svd = s[var].A.svd
        B2_svd = s[var].B2.svd
        if inverse_method == 'pseudo_inverse':
          A_new = u.pseudo_inverse2(A_svd) @ A
          B_new = u.pseudo_inverse2(B2_svd) @ B
        elif inverse_method == 'inverse':
          A_new = A_svd.inv @ A
          B_new = B2_svd.inv @ B
        else:
          assert False, "unknown inverse_method: %s" % (inverse_method,)
          
        dW_new = (B_new @ t(A_new)) / dsize
        grads_new.append(dW_new)
      else:  
        dW = B@t(A)/dsize   
        grads_new.append(dW)

    return IndexedGrad(grads=grads_new, vars_=vars_)
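
The correction above is the usual Kronecker-factored preconditioning: whitening the activations with cov(A)^-1 and the backprops with cov(B2)^-1 and then recombining them is the same as multiplying the raw gradient by the two inverse covariance factors. A minimal NumPy sketch with made-up shapes (independent of the project's u.* helpers) that checks this equivalence:

import numpy as np

rng = np.random.RandomState(0)
f_in, f_out, dsize = 4, 3, 50           # made-up layer sizes and batch size
A = rng.randn(f_in, dsize)              # activations for one layer
B = rng.randn(f_out, dsize)             # backprops for the same layer

cov_A = A @ A.T / dsize
cov_B = B @ B.T / dsize

dW = B @ A.T / dsize                    # raw gradient of a matmul layer
whitened_A = np.linalg.pinv(cov_A) @ A
whitened_B = np.linalg.pinv(cov_B) @ B
dW_corrected = whitened_B @ whitened_A.T / dsize

# equivalent form: two-sided preconditioning of the raw gradient
dW_precond = np.linalg.pinv(cov_B) @ dW @ np.linalg.pinv(cov_A)
assert np.allclose(dW_corrected, dW_precond)
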
Example No. 2
 def hessian_quadratic_inv(delta):
   #    update_covariances()
   W = u.unflatten(delta, fs[1:])
   W.insert(0, None)
   total = 0
   for l in range(1, n+1):
     invB2 = u.pseudo_inverse2(vars_svd_B2[l])
     invA = u.pseudo_inverse2(vars_svd_A[l])
     decrement = tf.trace(t(W[l])@invB2@W[l]@invA)
     total += decrement
   return (total/2).eval()
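
hessian_quadratic_inv sums, per layer, tr(W^T invB2 W invA), which is the quadratic form delta^T H^-1 delta under the Kronecker approximation H_l ≈ cov_A[l] ⊗ cov_B2[l]. A standalone NumPy sketch (made-up sizes, not the project's u.* helpers) checking that trace identity against an explicit Kronecker product:

import numpy as np

rng = np.random.RandomState(0)
f_in, f_out, dsize = 4, 3, 50
A = rng.randn(f_in, dsize)
B2 = rng.randn(f_out, dsize)
W = rng.randn(f_out, f_in)              # one layer's slice of delta as a matrix

invA = np.linalg.pinv(A @ A.T / dsize)
invB2 = np.linalg.pinv(B2 @ B2.T / dsize)

term_trace = np.trace(W.T @ invB2 @ W @ invA)
# column-stacking vec(), so that (X kron Y) vec(W) = vec(Y W X^T)
vecW = W.reshape(-1, order='F')
term_kron = vecW @ np.kron(invA, invB2) @ vecW
assert np.allclose(term_trace, term_kron)
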
Example No. 3
  pre_dW = [None]*(n+1)  # preconditioned dW
  pre_dW_stable = [None]*(n+1)  # preconditioned stable dW

  cov_A = [None]*(n+1)    # covariance of activations[i]
  cov_B2 = [None]*(n+1)   # covariance of synthetic backprops[i]
  vars_svd_A = [None]*(n+1)
  vars_svd_B2 = [None]*(n+1)
  for i in range(1,n+1):
    cov_A[i] = init_var(A[i]@t(A[i])/dsize, "cov_A%d"%(i,))
    cov_B2[i] = init_var(B2[i]@t(B2[i])/dsize, "cov_B2%d"%(i,))
    vars_svd_A[i] = u.SvdWrapper(cov_A[i],"svd_A_%d"%(i,))
    vars_svd_B2[i] = u.SvdWrapper(cov_B2[i],"svd_B2_%d"%(i,))
    if use_tikhonov:
      whitened_A = u.regularized_inverse2(vars_svd_A[i],L=Lambda) @ A[i]
    else:
      whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
    if use_tikhonov:
      whitened_B2 = u.regularized_inverse2(vars_svd_B2[i],L=Lambda) @ B[i]
    else:
      whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]
    whitened_A_stable = u.pseudo_inverse_sqrt2(vars_svd_A[i]) @ A[i]
    whitened_B2_stable = u.pseudo_inverse_sqrt2(vars_svd_B2[i]) @ B[i]
    pre_dW[i] = (whitened_B2 @ t(whitened_A))/dsize
    pre_dW_stable[i] = (whitened_B2_stable @ t(whitened_A_stable))/dsize
    dW[i] = (B[i] @ t(A[i]))/dsize

  # Loss function
  reconstruction = u.L2(err) / (2 * dsize)
  sparsity = beta * tf.reduce_sum(kl(rho, rho_hat))
  L2 = (lambda_ / 2) * (u.L2(W[1]) + u.L2(W[2]))  # weight decay over both weight matrices
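
The sparsity term uses the Bernoulli KL divergence kl(rho, rho_hat); during backprop its derivative in rho_hat, d_kl(rho, rho_hat), is what gets added to the residual (both definitions appear in Example No. 8). A quick standalone NumPy check with made-up numbers that d_kl is indeed that derivative:

import numpy as np

def kl(x, y):
    return x * np.log(x / y) + (1 - x) * np.log((1 - x) / (1 - y))

def d_kl(x, y):
    return (1 - x) / (1 - y) - x / y

rho, rho_hat, h = 0.1, 0.37, 1e-6
numeric = (kl(rho, rho_hat + h) - kl(rho, rho_hat - h)) / (2 * h)
assert abs(numeric - d_kl(rho, rho_hat)) < 1e-5
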
Example No. 4
    # TODO: add tiling for natural sampling
    cov_A = [None] * (n + 1)
    cov_B2 = [None] * (n + 1)
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    # eps=1e-5 gives same behavior on mnist (converge in 17 steps)
    eps_to_use = 1e-7
    for i in range(1, n + 1):
        cov_A[i] = init_var(A[i] @ t(A[i]) / dsize, "cov_A%d" % (i, ))
        cov_B2[i] = init_var(B2[i] @ t(B2[i]) / dsize, "cov_B2%d" % (i, ))
        vars_svd_A[i] = SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        # alternatives tried for whitening the activations (only the last
        # assignment is kept):
        #   whitened_A = u.regularized_inverse(cov_A[i]) @ A[i]
        #   whitened_A = u.pseudo_inverse_sqrt2(vars_svd_A[i], eps=eps_to_use) @ A[i]
        #   whitened_A = u.pseudo_inverse_scipy(cov_A[i]) @ A[i]
        whitened_A = u.pseudo_inverse2(vars_svd_A[i], eps=eps_to_use) @ A[i]
        # raise epsilon because b's get weird
        # alternatives tried for whitening the backprops:
        #   whitened_B2 = u.pseudo_inverse_scipy(cov_B2[i]) @ B[i]
        #   whitened_B2 = u.regularized_inverse(cov_B2[i]) @ B[i]
        #   whitened_B2 = u.pseudo_inverse_sqrt2(vars_svd_B2[i], eps=eps_to_use) @ B[i]
        whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i], eps=eps_to_use) @ B[i]
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / dsize
        dW[i] = (B[i] @ t(A[i])) / dsize

    # Loss function
    reconstruction = u.L2(err) / (2 * dsize)
    sparsity = beta * tf.reduce_sum(kl(rho, rho_hat))
    L2 = (lambda_ / 2) * (u.L2(W[1]) + u.L2(W[2]))  # weight decay over both weight matrices
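
The eps passed to the pseudo-inverse calls above acts as a singular-value cutoff, which is why raising it helps when the B factors become ill-conditioned. The exact rule inside u.pseudo_inverse2 is not shown here, so the following is only a plausible stand-in with made-up data:

import numpy as np

def pseudo_inverse_eps(mat, eps=1e-7):
    u_, s, vt = np.linalg.svd(mat)
    s_inv = np.where(s > eps, 1.0 / s, 0.0)   # drop directions below the cutoff
    return vt.T @ np.diag(s_inv) @ u_.T

cov = np.diag([1.0, 1e-3, 1e-12])             # one direction is numerically dead
print(pseudo_inverse_eps(cov, eps=1e-7))      # the 1e-12 direction is zeroed, not inverted to 1e12
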
Example No. 5
def model_creator(batch_size, name='defaultmodel', dtype=np.float32):
    """Create MNIST autoencoder model. Dataset is part of model."""

    global hack_global_init_dict

    model = Model(name)

    # TODO: actually use batch_size
    init_dict = {}  # todo: rename to feed_dict?
    global_vars = []
    local_vars = []

    # TODO: rename to make_var
    def init_var(val, name, is_global=False):
        """Helper to create variables with numpy or TF initial values."""
        if isinstance(val, tf.Tensor):
            var = u.get_variable(name=name, initializer=val, reuse=is_global)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Non-numeric type."

            var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
            holder = var_struct.val_
            init_dict[holder] = val
            var = var_struct.var

        if is_global:
            global_vars.append(var)
        else:
            local_vars.append(var)

        return var

    # TODO: get rid of purely_relu
    def nonlin(x):
        if purely_relu:
            return tf.nn.relu(x)
        elif purely_linear:
            return tf.identity(x)
        else:
            return tf.sigmoid(x)

    # TODO: rename into "nonlin_d"
    def d_nonlin(y):
        if purely_relu:
            return u.relu_mask(y)
        elif purely_linear:
            return 1
        else:
            return y * (1 - y)

    train_images = load_MNIST.load_MNIST_images(
        'data/train-images-idx3-ubyte').astype(dtype)
    patches = train_images[:, :batch_size]
    fs = [batch_size, 28 * 28, 196, 28 * 28]

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    n = len(fs) - 2

    X = init_var(patches, "X", is_global=False)
    W = [None] * n
    W.insert(0, X)
    A = [None] * (n + 2)
    A[1] = W[0]
    W0f_old = W_uniform(fs[2],
                        fs[3]).astype(dtype)  # to match previous generation
    W0s_old = u.unflatten(W0f_old, fs[1:])  # perftodo: this creates transposes
    for i in range(1, n + 1):
        #    temp = init_var(ng_init(f(i), f(i-1)), "W_%d"%(i,), is_global=True)
        #    init_val1 = W0s_old[i-1]
        init_val = ng_init(f(i), f(i - 1)).astype(dtype)
        W[i] = init_var(init_val, "W_%d" % (i, ), is_global=True)
        A[i + 1] = nonlin(kfac_lib.matmul(W[i], A[i]))

    err = A[n + 1] - A[1]

    # manually compute backprop to use for sanity checking
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
    B[n] = err * d_nonlin(A[n + 1])
    _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)

    _sampled_labels = init_var(_sampled_labels_live,
                               "to_be_deleted",
                               is_global=False)

    B2[n] = _sampled_labels * d_nonlin(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        B[i] = backprop * d_nonlin(A[i + 1])
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        B2[i] = backprop2 * d_nonlin(A[i + 1])

    cov_A = [None] * (n + 1)  # covariance of activations[i]
    cov_B2 = [None] * (n + 1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    dW = [None] * (n + 1)
    dW2 = [None] * (n + 1)
    pre_dW = [None] * (n + 1)  # preconditioned dW
    for i in range(1, n + 1):
        if regularized_svd:
            cov_A[i] = init_var(
                A[i] @ t(A[i]) / batch_size + LAMBDA * u.Identity(f(i - 1)),
                "cov_A%d" % (i, ))
            cov_B2[i] = init_var(
                B2[i] @ t(B2[i]) / batch_size + LAMBDA * u.Identity(f(i)),
                "cov_B2%d" % (i, ))
        else:
            cov_A[i] = init_var(A[i] @ t(A[i]) / batch_size, "cov_A%d" % (i, ))
            cov_B2[i] = init_var(B2[i] @ t(B2[i]) / batch_size,
                                 "cov_B2%d" % (i, ))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        if use_tikhonov:
            whitened_A = u.regularized_inverse3(vars_svd_A[i], L=LAMBDA) @ A[i]
            whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],
                                                 L=LAMBDA) @ B[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]

        dW[i] = (B[i] @ t(A[i])) / batch_size
        dW2[i] = B[i] @ t(A[i])
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / batch_size

        #  model.extra['A'] = A
        #  model.extra['B'] = B
        #  model.extra['B2'] = B2
        #  model.extra['cov_A'] = cov_A
        #  model.extra['cov_B2'] = cov_B2
        #  model.extra['vars_svd_A'] = vars_svd_A
        #  model.extra['vars_svd_B2'] = vars_svd_B2
        #  model.extra['W'] = W
        #  model.extra['dW'] = dW
        #  model.extra['dW2'] = dW2
        #  model.extra['pre_dW'] = pre_dW

    model.loss = u.L2(err) / (2 * batch_size)
    sampled_labels_live = A[n + 1] + tf.random_normal(
        (f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        sampled_labels_live = A[n + 1] + tf.ones(shape=(f(n), f(-1)),
                                                 dtype=dtype)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              is_global=False)
    err2 = A[n + 1] - sampled_labels
    model.loss2 = u.L2(err2) / (2 * batch_size)
    model.global_vars = global_vars
    model.local_vars = local_vars
    model.trainable_vars = W[1:]

    def advance_batch():
        sess = tf.get_default_session()
        # TODO: get rid of _sampled_labels
        sess.run([sampled_labels.initializer, _sampled_labels.initializer])

    model.advance_batch = advance_batch

    global_init_op = tf.group(*[v.initializer for v in global_vars])

    def initialize_global_vars():
        sess = tf.get_default_session()
        sess.run(global_init_op, feed_dict=init_dict)

    model.initialize_global_vars = initialize_global_vars

    local_init_op = tf.group(*[v.initializer for v in local_vars])

    def initialize_local_vars():
        sess = tf.get_default_session()
        sess.run(X.initializer, feed_dict=init_dict)  # A's depend on X
        sess.run(_sampled_labels.initializer, feed_dict=init_dict)
        sess.run(local_init_op, feed_dict=init_dict)

    model.initialize_local_vars = initialize_local_vars

    hack_global_init_dict = init_dict

    return model
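
Note that d_nonlin above takes the output of the nonlinearity rather than its input: for the sigmoid branch, y * (1 - y) equals sigmoid'(x) when y = sigmoid(x). A quick standalone NumPy check:

import numpy as np

x = np.linspace(-3, 3, 7)
y = 1.0 / (1.0 + np.exp(-x))                      # sigmoid outputs
h = 1e-6
numeric = (1.0 / (1.0 + np.exp(-(x + h))) -
           1.0 / (1.0 + np.exp(-(x - h)))) / (2 * h)
assert np.allclose(y * (1 - y), numeric, atol=1e-6)
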
Example No. 6
def kfac_optimizer(model_creator):
    stats_batch_size = 10000
    main_batch_size = 10000

    stats_model, loss, labels = model_creator(stats_batch_size)
    # replace labels_node with synthetic labels

    main_model, _, _ = model_creator(main_batch_size)

    opt = tf.train.GradientDescentOptimizer(0.2)
    grads_and_vars = opt.compute_gradients(loss)

    trainable_vars = tf.trainable_variables()

    # create SVD and preconditioning variables for matmul vars
    for var in trainable_vars:
        if var not in matmul_registry:
            continue
        dW[var] = u.extract_grad(grads_and_vars, var)  # TF-computed gradient; recomputed manually below
        A[var] = get_activations(var)
        B[var] = get_backprops(var)
        B2[var] = get_backprops2(var)  # get backprops with synthetic labels
        dW[var] = B[var] @ t(A[var])  # todo: sort out dsize division
        cov_A[var] = init_var(A[var] @ t(A[var]) / dsize,
                              "cov_A_%s" % (var.name, ))
        cov_B2[var] = init_var(B2[var] @ t(B2[var]) / dsize,
                               "cov_B2_%s" % (var.name, ))

        vars_svd_A[var] = SvdWrapper(cov_A[var], "svd_A_%s" % (var.name, ))
        vars_svd_B2[var] = SvdWrapper(cov_B2[var], "svd_B2_%s" % (var.name, ))
        whitened_A = u.pseudo_inverse2(vars_svd_A[var]) @ A[var]
        whitened_B2 = u.pseudo_inverse2(vars_svd_B2[var]) @ B[var]
        whitened_A_stable = u.pseudo_inverse_sqrt2(vars_svd_A[var]) @ A[var]
        whitened_B2_stable = u.pseudo_inverse_sqrt2(vars_svd_B2[var]) @ B[var]

        pre_dW[var] = (whitened_B2 @ t(whitened_A)) / dsize
        pre_dW_stable[var] = (
            whitened_B2_stable @ t(whitened_A_stable)) / dsize
        dW[var] = (B[var] @ t(A[var])) / dsize

    # create update params ops

    # new_grads_and_vars = []
    # for grad, var in grads_and_vars:
    #   if var in kfac_registry:
    #     pre_A, pre_B = kfac_registry[var]
    #     new_grad_live = pre_B @ grad @ t(pre_A)
    #     new_grads_and_vars.append((new_grad, var))
    #     print("Preconditioning %s"%(var.name))
    #   else:
    #     new_grads_and_vars.append((grad, var))
    # train_op = opt.apply_gradients(new_grads_and_vars)

    # Each variable has an associated gradient, pre_gradient, variable save op
    def update_grad():
        ops = [grad_update_ops[var] for var in trainable_vars]
        sess.run(ops)

    def update_pre_grad():
        ops = [pre_grad_update_ops[var] for var in trainable_vars]
        sess.run(ops)

    def update_pre_grad2():
        ops = [pre_grad2_update_ops[var] for var in trainable_vars]
        sess.run(ops)

    def save_params():
        ops = [var_save_ops[var] for var in trainable_vars]
        sess.run(ops)

    for step in range(num_steps):
        update_covariances()
        if step % whitened_every_n_steps == 0:
            update_svds()

        update_grad()
        update_pre_grad()  # perf todo: update one of these
        update_pre_grad2()  # stable alternative

        lr0, loss0 = sess.run([lr, loss])
        save_params()

        # when grad norm<1, Fisher is unstable, switch to Sqrt(Fisher)
        # TODO: switch to per-matrix normalization
        stabilized_mode = grad_norm.eval() < 1

        if stabilized_mode:
            update_params2()
        else:
            update_params()

        loss1 = loss.eval()
        advance_batch()

        # line search stuff
        target_slope = (-pre_grad_stable_dot_grad.eval() if stabilized_mode else
                        -pre_grad_dot_grad.eval())
        target_delta = lr0 * target_slope
        actual_delta = loss1 - loss0
        actual_slope = actual_delta / lr0
        slope_ratio = actual_slope / target_slope  # between 0 and 1.01

        losses.append(loss0)
        step_lengths.append(lr0)
        ratios.append(slope_ratio)

        if step % report_frequency == 0:
            print(
                "Step %d loss %.2f, target decrease %.3f, actual decrease, %.3f ratio %.2f grad norm: %.2f pregrad norm: %.2f"
                % (step, loss0, target_delta, actual_delta, slope_ratio,
                   grad_norm.eval(), pre_grad_norm.eval()))

        u.record_time()
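
The line-search bookkeeping at the end of the loop compares the realized loss decrease against the first-order prediction lr * (-g . P g). A pure-Python sketch with made-up numbers showing how slope_ratio is formed:

lr0 = 0.2
pre_grad_dot_grad = 1.5          # g . P g for the preconditioner actually applied
loss0, loss1 = 30.0, 29.8        # loss before / after the parameter update

target_slope = -pre_grad_dot_grad
target_delta = lr0 * target_slope           # predicted decrease: -0.3
actual_delta = loss1 - loss0                # realized decrease: roughly -0.2
slope_ratio = (actual_delta / lr0) / target_slope
print("slope_ratio = %.2f" % slope_ratio)   # 0.67: got two thirds of the prediction
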
Example No. 7
def model_creator(batch_size, name='defaultmodel', dtype=np.float32):
    """Create MNIST autoencoder model. Dataset is part of model."""

    model = Model(name)

    init_dict = {}
    global_vars = []
    local_vars = []

    # TODO: factor out to reuse between scripts
    # TODO: change feed_dict logic to reuse the value provided to VarStruct;
    # currently reinitializing a global variable changes its value, which is
    # counterintuitive
    def init_var(val, name, is_global=False):
        """Helper to create variables with numpy or TF initial values."""
        if isinstance(val, tf.Tensor):
            var = u.get_variable(name=name, initializer=val, reuse=is_global)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Non-numeric type."

            var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
            holder = var_struct.val_
            init_dict[holder] = val
            var = var_struct.var

        if is_global:
            global_vars.append(var)
        else:
            local_vars.append(var)

        return var

    # TODO: get rid of purely_relu
    def nonlin(x):
        if purely_relu:
            return tf.nn.relu(x)
        elif purely_linear:
            return tf.identity(x)
        else:
            return tf.sigmoid(x)

    # TODO: rename into "nonlin_d"
    def d_nonlin(y):
        if purely_relu:
            return u.relu_mask(y)
        elif purely_linear:
            return 1
        else:
            return y * (1 - y)

    train_images = load_MNIST.load_MNIST_images(
        'data/train-images-idx3-ubyte').astype(dtype)
    patches = train_images[:, :batch_size]
    test_patches = train_images[:, -batch_size:]
    assert batch_size < 25000

    fs = [
        batch_size, 28 * 28, 1024, 1024, 1024, 196, 1024, 1024, 1024, 28 * 28
    ]

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    n = len(fs) - 2

    X = init_var(patches, "X", is_global=False)
    W = [None] * n
    W.insert(0, X)
    A = [None] * (n + 2)
    A[1] = W[0]
    for i in range(1, n + 1):
        init_val = ng_init(f(i), f(i - 1)).astype(dtype)
        W[i] = init_var(init_val, "W_%d" % (i, ), is_global=True)
        A[i + 1] = nonlin(kfac_lib.matmul(W[i], A[i]))

    err = A[n + 1] - A[1]

    # create test error eval
    layer = init_var(test_patches, "X_test", is_global=False)
    for i in range(1, n + 1):
        layer = nonlin(W[i] @ layer)
    verr = (layer - test_patches)
    model.vloss = u.L2(verr) / (2 * batch_size)

    # manually compute backprop to use for sanity checking
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
    B[n] = err * d_nonlin(A[n + 1])
    _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)

    _sampled_labels = init_var(_sampled_labels_live,
                               "to_be_deleted",
                               is_global=False)

    B2[n] = _sampled_labels * d_nonlin(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        B[i] = backprop * d_nonlin(A[i + 1])
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        B2[i] = backprop2 * d_nonlin(A[i + 1])

    cov_A = [None] * (n + 1)  # covariance of activations[i]
    cov_B2 = [None] * (n + 1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    dW = [None] * (n + 1)
    dW2 = [None] * (n + 1)
    pre_dW = [None] * (n + 1)  # preconditioned dW
    for i in range(1, n + 1):
        if regularized_svd:
            cov_A[i] = init_var(
                A[i] @ t(A[i]) / batch_size + LAMBDA * u.Identity(f(i - 1)),
                "cov_A%d" % (i, ))
            cov_B2[i] = init_var(
                B2[i] @ t(B2[i]) / batch_size + LAMBDA * u.Identity(f(i)),
                "cov_B2%d" % (i, ))
        else:
            cov_A[i] = init_var(A[i] @ t(A[i]) / batch_size, "cov_A%d" % (i, ))
            cov_B2[i] = init_var(B2[i] @ t(B2[i]) / batch_size,
                                 "cov_B2%d" % (i, ))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        if use_tikhonov:
            whitened_A = u.regularized_inverse3(vars_svd_A[i], L=LAMBDA) @ A[i]
            whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],
                                                 L=LAMBDA) @ B[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]

        dW[i] = (B[i] @ t(A[i])) / batch_size
        dW2[i] = B[i] @ t(A[i])
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / batch_size

        #  model.extra['A'] = A
        #  model.extra['B'] = B
        #  model.extra['B2'] = B2
        #  model.extra['cov_A'] = cov_A
        #  model.extra['cov_B2'] = cov_B2
        #  model.extra['vars_svd_A'] = vars_svd_A
        #  model.extra['vars_svd_B2'] = vars_svd_B2
        #  model.extra['W'] = W
        #  model.extra['dW'] = dW
        #  model.extra['dW2'] = dW2
        #  model.extra['pre_dW'] = pre_dW

    model.loss = u.L2(err) / (2 * batch_size)
    sampled_labels_live = A[n + 1] + tf.random_normal(
        (f(n), f(-1)), dtype=dtype, seed=0)
    if use_fixed_labels:
        sampled_labels_live = A[n + 1] + tf.ones(shape=(f(n), f(-1)),
                                                 dtype=dtype)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              is_global=False)
    err2 = A[n + 1] - sampled_labels
    model.loss2 = u.L2(err2) / (2 * batch_size)
    model.global_vars = global_vars
    model.local_vars = local_vars
    model.trainable_vars = W[1:]

    def advance_batch():
        sess = tf.get_default_session()
        # TODO: get rid of _sampled_labels
        sess.run([sampled_labels.initializer, _sampled_labels.initializer])

    model.advance_batch = advance_batch

    # TODO: refactor this to take initial values out of Var struct
    #global_init_op = tf.group(*[v.initializer for v in global_vars])
    global_init_ops = [v.initializer for v in global_vars]
    global_init_op = tf.group(*[v.initializer for v in global_vars])

    global_init_query_op = [
        tf.logical_not(tf.is_variable_initialized(v)) for v in global_vars
    ]

    def initialize_global_vars(verbose=False, reinitialize=False):
        """If reinitialize is false, will not reinitialize variables already
    initialized."""

        sess = tf.get_default_session()
        if not reinitialize:
            uninited = sess.run(global_init_query_op)
            # use numpy boolean indexing to select list of initializers to run
            to_initialize = list(np.asarray(global_init_ops)[uninited])
        else:
            to_initialize = global_init_ops

        if verbose:
            print("Initializing following:")
            for v in to_initialize:
                print("   " + v.name)

        sess.run(to_initialize, feed_dict=init_dict)

    model.initialize_global_vars = initialize_global_vars

    local_init_op = tf.group(*[v.initializer for v in local_vars])

    def initialize_local_vars():
        sess = tf.get_default_session()
        sess.run(X.initializer, feed_dict=init_dict)  # A's depend on X
        sess.run(_sampled_labels.initializer, feed_dict=init_dict)
        sess.run(local_init_op, feed_dict=init_dict)

    model.initialize_local_vars = initialize_local_vars

    return model
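
initialize_global_vars only runs the initializers whose variables report as uninitialized, selected with NumPy boolean indexing over the list of ops. A standalone sketch of that selection with stand-in values:

import numpy as np

global_init_ops = ["init_W_1", "init_W_2", "init_W_3"]   # stand-ins for the tf initializer ops
uninited = np.array([False, True, True])                 # stand-in result of global_init_query_op
to_initialize = list(np.asarray(global_init_ops)[uninited])
print(to_initialize)                                     # ['init_W_2', 'init_W_3']
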
Example No. 8
def main():
    np.random.seed(0)
    tf.set_random_seed(0)

    dtype = np.float32
    # 64-bit doesn't help much, search for 64-bit in
    # https://www.wolframcloud.com/objects/5f297f41-30f7-4b1b-972c-cac8d1f8d8e4
    u.default_dtype = dtype
    machine_epsilon = np.finfo(dtype).eps  # 1e-7 or 1e-16
    train_images = load_MNIST.load_MNIST_images('data/train-images-idx3-ubyte')
    dsize = 10000
    patches = train_images[:, :dsize]
    fs = [dsize, 28 * 28, 196, 28 * 28]

    # values from deeplearning.stanford.edu/wiki/index.php/UFLDL_Tutorial
    X0 = patches
    lambda_ = 3e-3
    rho = tf.constant(0.1, dtype=dtype)
    beta = 3
    W0f = W_uniform(fs[2], fs[3])

    def f(i):
        return fs[i + 1]  # W[i] has shape f[i] x f[i-1]

    dsize = f(-1)
    n = len(fs) - 2

    # helper to create variables with numpy or TF initial value
    init_dict = {}  # {var_placeholder: init_value}
    vard = {}  # {var: util.VarInfo}

    def init_var(val, name, trainable=False, noinit=False):
        if isinstance(val, tf.Tensor):
            collections = [] if noinit else None
            var = tf.Variable(val, name=name, collections=collections)
        else:
            val = np.array(val)
            assert u.is_numeric(val), "Unknown type"
            holder = tf.placeholder(dtype,
                                    shape=val.shape,
                                    name=name + "_holder")
            var = tf.Variable(holder, name=name, trainable=trainable)
            init_dict[holder] = val
        var_p = tf.placeholder(var.dtype, var.shape)
        var_setter = var.assign(var_p)
        vard[var] = u.VarInfo(var_setter, var_p)
        return var

    lr = init_var(0.2, "lr")
    if purely_linear:  # need lower LR without sigmoids
        lr = init_var(.02, "lr")

    Wf = init_var(W0f, "Wf", True)
    Wf_copy = init_var(W0f, "Wf_copy", True)
    W = u.unflatten(Wf, fs[1:])  # perftodo: this creates transposes
    X = init_var(X0, "X")
    W.insert(0, X)

    def sigmoid(x):
        if not purely_linear:
            return tf.sigmoid(x)
        else:
            return tf.identity(x)

    def d_sigmoid(y):
        if not purely_linear:
            return y * (1 - y)
        else:
            return 1

    def kl(x, y):
        return x * tf.log(x / y) + (1 - x) * tf.log((1 - x) / (1 - y))

    def d_kl(x, y):
        return (1 - x) / (1 - y) - x / y

    # A[i] = activations needed to compute gradient of W[i]
    # A[n+1] = network output
    A = [None] * (n + 2)

    # A[0] is just for shape checks, assert fail on run
    # tf.assert always fails because of static assert
    # fail_node = tf.assert_equal(1, 0, message="too huge")
    fail_node = tf.Print(0, [0], "fail, this must never run")
    with tf.control_dependencies([fail_node]):
        A[0] = u.Identity(dsize, dtype=dtype)
    A[1] = W[0]
    for i in range(1, n + 1):
        A[i + 1] = sigmoid(W[i] @ A[i])

    # reconstruction error and sparsity error
    err = (A[3] - A[1])
    rho_hat = tf.reduce_sum(A[2], axis=1, keep_dims=True) / dsize

    # B[i] = backprops needed to compute gradient of W[i]
    # B2[i] = backprops from sampled labels needed for natural gradient
    B = [None] * (n + 1)
    B2 = [None] * (n + 1)
    B[n] = err * d_sigmoid(A[n + 1])
    sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    sampled_labels = init_var(sampled_labels_live,
                              "sampled_labels",
                              noinit=True)
    B2[n] = sampled_labels * d_sigmoid(A[n + 1])
    for i in range(n - 1, -1, -1):
        backprop = t(W[i + 1]) @ B[i + 1]
        backprop2 = t(W[i + 1]) @ B2[i + 1]
        if i == 1 and not drop_sparsity:
            backprop += beta * d_kl(rho, rho_hat)
            backprop2 += beta * d_kl(rho, rho_hat)
        B[i] = backprop * d_sigmoid(A[i + 1])
        B2[i] = backprop2 * d_sigmoid(A[i + 1])

    # dW[i] = gradient of W[i]
    dW = [None] * (n + 1)
    pre_dW = [None] * (n + 1)  # preconditioned dW
    pre_dW_stable = [None] * (n + 1)  # preconditioned stable dW

    cov_A = [None] * (n + 1)  # covariance of activations[i]
    cov_B2 = [None] * (n + 1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None] * (n + 1)
    vars_svd_B2 = [None] * (n + 1)
    for i in range(1, n + 1):
        cov_A[i] = init_var(A[i] @ t(A[i]) / dsize, "cov_A%d" % (i, ))
        cov_B2[i] = init_var(B2[i] @ t(B2[i]) / dsize, "cov_B2%d" % (i, ))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d" % (i, ))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d" % (i, ))
        if use_tikhonov:
            whitened_A = u.regularized_inverse2(vars_svd_A[i], L=Lambda) @ A[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
        if use_tikhonov:
            whitened_B2 = u.regularized_inverse2(vars_svd_B2[i],
                                                 L=Lambda) @ B[i]
        else:
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]
        whitened_A_stable = u.pseudo_inverse_sqrt2(vars_svd_A[i]) @ A[i]
        whitened_B2_stable = u.pseudo_inverse_sqrt2(vars_svd_B2[i]) @ B[i]
        pre_dW[i] = (whitened_B2 @ t(whitened_A)) / dsize
        pre_dW_stable[i] = (whitened_B2_stable @ t(whitened_A_stable)) / dsize
        dW[i] = (B[i] @ t(A[i])) / dsize

    # Loss function
    reconstruction = u.L2(err) / (2 * dsize)
    sparsity = beta * tf.reduce_sum(kl(rho, rho_hat))
    L2 = (lambda_ / 2) * (u.L2(W[1]) + u.L2(W[2]))  # weight decay over both weight matrices

    loss = reconstruction
    if not drop_l2:
        loss = loss + L2
    if not drop_sparsity:
        loss = loss + sparsity

    grad_live = u.flatten(dW[1:])
    pre_grad_live = u.flatten(pre_dW[1:])  # fisher preconditioned gradient
    pre_grad_stable_live = u.flatten(
        pre_dW_stable[1:])  # sqrt fisher preconditioned grad
    grad = init_var(grad_live, "grad")
    pre_grad = init_var(pre_grad_live, "pre_grad")
    pre_grad_stable = init_var(pre_grad_stable_live, "pre_grad_stable")

    update_params_op = Wf.assign(Wf - lr * pre_grad).op
    update_params_stable_op = Wf.assign(Wf - lr * pre_grad_stable).op
    save_params_op = Wf_copy.assign(Wf).op
    pre_grad_dot_grad = tf.reduce_sum(pre_grad * grad)
    pre_grad_stable_dot_grad = tf.reduce_sum(pre_grad_stable * grad)
    grad_norm = tf.reduce_sum(grad * grad)
    pre_grad_norm = u.L2(pre_grad)
    pre_grad_stable_norm = u.L2(pre_grad_stable)

    def dump_svd_info(step):
        """Dump singular values and gradient values in those coordinates."""
        for i in range(1, n + 1):
            svd = vars_svd_A[i]
            s0, u0, v0 = sess.run([svd.s, svd.u, svd.v])
            util.dump(s0, "A_%d_%d" % (i, step))
            A0 = A[i].eval()
            At0 = v0.T @ A0
            util.dump(A0 @ A0.T, "Acov_%d_%d" % (i, step))
            util.dump(At0 @ At0.T, "Atcov_%d_%d" % (i, step))
            util.dump(s0, "As_%d_%d" % (i, step))

        for i in range(1, n + 1):
            svd = vars_svd_B2[i]
            s0, u0, v0 = sess.run([svd.s, svd.u, svd.v])
            util.dump(s0, "B2_%d_%d" % (i, step))
            B0 = B[i].eval()
            Bt0 = v0.T @ B0
            util.dump(B0 @ B0.T, "Bcov_%d_%d" % (i, step))
            util.dump(Bt0 @ Bt0.T, "Btcov_%d_%d" % (i, step))
            util.dump(s0, "Bs_%d_%d" % (i, step))

    def advance_batch():
        sess.run(sampled_labels.initializer)  # new labels for next call

    def update_covariances():
        ops_A = [cov_A[i].initializer for i in range(1, n + 1)]
        ops_B2 = [cov_B2[i].initializer for i in range(1, n + 1)]
        sess.run(ops_A + ops_B2)

    def update_svds():
        if whitening_mode > 1:
            vars_svd_A[2].update()
        if whitening_mode > 2:
            vars_svd_B2[2].update()
        if whitening_mode > 3:
            vars_svd_B2[1].update()

    def init_svds():
        """Initialize our SVD to identity matrices."""
        ops = []
        for i in range(1, n + 1):
            ops.extend(vars_svd_A[i].init_ops)
            ops.extend(vars_svd_B2[i].init_ops)
        sess = tf.get_default_session()
        sess.run(ops)

    init_op = tf.global_variables_initializer()
    #  tf.get_default_graph().finalize()

    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
    graph_options = tf.GraphOptions(optimizer_options=optimizer_options,
                                    rewrite_options=rewrite_options)
    config = tf.ConfigProto(graph_options=graph_options)
    #sess = tf.Session(config=config)
    sess = tf.InteractiveSession(config=config)
    sess.run(Wf.initializer, feed_dict=init_dict)
    sess.run(X.initializer, feed_dict=init_dict)
    advance_batch()
    update_covariances()
    init_svds()
    sess.run(init_op, feed_dict=init_dict)  # initialize everything else

    print("Running training.")
    u.reset_time()

    step_lengths = []  # keep track of learning rates
    losses = []
    ratios = []  # actual loss decrease / expected decrease
    grad_norms = []
    pre_grad_norms = []  # preconditioned grad norm squared
    pre_grad_stable_norms = []  # sqrt preconditioned grad norms squared
    target_delta_list = []  # predicted decrease linear approximation
    target_delta2_list = []  # predicted decrease, quadratic approximation
    actual_delta_list = []  # actual decrease

    # adaptive line search parameters
    alpha = 0.3  # acceptable fraction of predicted decrease
    beta = 0.8  # how much to shrink when the prediction is violated
    growth_rate = 1.05  # how much to grow when too conservative

    def update_cov_A(i):
        sess.run(cov_A[i].initializer)

    def update_cov_B2(i):
        sess.run(cov_B2[i].initializer)

    # only update whitening matrix of input activations in the beginning
    if whitening_mode > 0:
        vars_svd_A[1].update()

    # compute t(delta).H.delta/2
    def hessian_quadratic(delta):
        #    update_covariances()
        W = u.unflatten(delta, fs[1:])
        W.insert(0, None)
        total = 0
        for l in range(1, n + 1):
            decrement = tf.trace(t(W[l]) @ cov_B2[l] @ W[l] @ cov_A[l])
            total += decrement
        return (total / 2).eval()

    # compute t(delta).H^-1.delta/2
    def hessian_quadratic_inv(delta):
        #    update_covariances()
        W = u.unflatten(delta, fs[1:])
        W.insert(0, None)
        total = 0
        for l in range(1, n + 1):
            invB2 = u.pseudo_inverse2(vars_svd_B2[l])
            invA = u.pseudo_inverse2(vars_svd_A[l])
            decrement = tf.trace(t(W[l]) @ invB2 @ W[l] @ invA)
            total += decrement
        return (total / 2).eval()

    # do line search, dump values as csv
    def line_search(initial_value, direction, step, num_steps):
        saved_val = tf.Variable(Wf)
        sess.run(saved_val.initializer)
        pl = tf.placeholder(dtype, shape=(), name="linesearch_p")
        assign_op = Wf.assign(initial_value - direction * step * pl)
        vals = []
        for i in range(num_steps):
            sess.run(assign_op, feed_dict={pl: i})
            vals.append(loss.eval())
        sess.run(Wf.assign(saved_val))  # restore original value
        return vals

    for step in range(num_steps):
        update_covariances()
        if step % whiten_every_n_steps == 0:
            update_svds()

        sess.run(grad.initializer)
        sess.run(pre_grad.initializer)

        lr0, loss0 = sess.run([lr, loss])
        save_params_op.run()

        # regular inverse becomes unstable when grad norm exceeds 1
        stabilized_mode = grad_norm.eval() < 1

        if stabilized_mode and not use_tikhonov:
            update_params_stable_op.run()
        else:
            update_params_op.run()

        loss1 = loss.eval()
        advance_batch()

        # line search stuff
        target_slope = (-pre_grad_stable_dot_grad.eval() if stabilized_mode else
                        -pre_grad_dot_grad.eval())
        target_delta = lr0 * target_slope
        target_delta_list.append(target_delta)

        # second order prediction of target delta
        # TODO: the sign is wrong, debug this
        # https://www.wolframcloud.com/objects/8f287f2f-ceb7-42f7-a599-1c03fda18f28
        if local_quadratics:
            x0 = Wf_copy.eval()
            x_opt = x0 - pre_grad.eval()
            # computes t(x)@H^-1 @(x)/2
            y_opt = loss0 - hessian_quadratic_inv(grad)
            # computes t(x)@H @(x)/2
            y_expected = hessian_quadratic(Wf - x_opt) + y_opt
            target_delta2 = y_expected - loss0
            target_delta2_list.append(target_delta2)

        actual_delta = loss1 - loss0
        actual_slope = actual_delta / lr0
        slope_ratio = actual_slope / target_slope  # between 0 and 1.01
        actual_delta_list.append(actual_delta)

        if do_line_search:
            vals1 = line_search(Wf_copy, pre_grad, lr / 100, 40)
            vals2 = line_search(Wf_copy, grad, lr / 100, 40)
            u.dump(vals1, "line1-%d" % (step, ))
            u.dump(vals2, "line2-%d" % (step, ))

        losses.append(loss0)
        step_lengths.append(lr0)
        ratios.append(slope_ratio)
        grad_norms.append(grad_norm.eval())
        pre_grad_norms.append(pre_grad_norm.eval())
        pre_grad_stable_norms.append(pre_grad_stable_norm.eval())

        if step % report_frequency == 0:
            print(
                "Step %d loss %.2f, target decrease %.3f, actual decrease, %.3f ratio %.2f grad norm: %.2f pregrad norm: %.2f"
                % (step, loss0, target_delta, actual_delta, slope_ratio,
                   grad_norm.eval(), pre_grad_norm.eval()))

        if adaptive_step_frequency and adaptive_step and step > adaptive_step_burn_in:
            # shrink if wrong prediction, don't shrink if prediction is tiny
            if slope_ratio < alpha and abs(
                    target_delta) > 1e-6 and adaptive_step:
                print("%.2f %.2f %.2f" % (loss0, loss1, slope_ratio))
                print(
                    "Slope optimality %.2f, shrinking learning rate to %.2f" %
                    (
                        slope_ratio,
                        lr0 * beta,
                    ))
                sess.run(vard[lr].setter, feed_dict={vard[lr].p: lr0 * beta})

            # grow learning rate, slope_ratio .99 worked best for gradient
            elif step > 0 and step % 50 == 0 and slope_ratio > 0.90 and adaptive_step:
                print("%.2f %.2f %.2f" % (loss0, loss1, slope_ratio))
                print("Growing learning rate to %.2f" % (lr0 * growth_rate))
                sess.run(vard[lr].setter,
                         feed_dict={vard[lr].p: lr0 * growth_rate})

        u.record_time()

    # check against expected loss
    if 'Apple' in sys.version:
        #    u.dump(losses, "kfac_small_final_mac.csv")
        targets = np.loadtxt("data/kfac_small_final_mac.csv", delimiter=",")
    else:
        #    u.dump(losses, "kfac_small_final_linux.csv")
        targets = np.loadtxt("data/kfac_small_final_linux.csv", delimiter=",")

    u.check_equal(targets, losses[:len(targets)], rtol=1e-1)
    u.summarize_time()
    print("Test passed")
Example No. 9
def model_creator(batch_size, name="default", dtype=np.float32):
  """Create MNIST autoencoder model. Dataset is part of model."""

  model = Model(name)

  def get_batch_size(data):
    if isinstance(data, IndexedGrad):
      return int(data.live[0].shape[1])
    else:
      return int(data.shape[1])

  init_dict = {}
  global_vars = []
  local_vars = []
  
  # TODO: factor out to reuse between scripts
  # TODO: change feed_dict logic to reuse the value provided to VarStruct;
  # currently reinitializing a global variable changes its value, which is
  # counterintuitive
  def init_var(val, name, is_global=False):
    """Helper to create variables with numpy or TF initial values."""
    if isinstance(val, tf.Tensor):
      var = u.get_variable(name=name, initializer=val, reuse=is_global)
    else:
      val = np.array(val)
      assert u.is_numeric(val), "Non-numeric type."
      
      var_struct = u.get_var(name=name, initializer=val, reuse=is_global)
      holder = var_struct.val_
      init_dict[holder] = val
      var = var_struct.var

    if is_global:
      global_vars.append(var)
    else:
      local_vars.append(var)
      
    return var

  # TODO: get rid of purely_relu
  def nonlin(x):
    if purely_relu:
      return tf.nn.relu(x)
    elif purely_linear:
      return tf.identity(x)
    else:
      return tf.sigmoid(x)

  # TODO: rename into "nonlin_d"
  def d_nonlin(y):
    if purely_relu:
      return u.relu_mask(y)
    elif purely_linear:
      return 1
    else: 
      return y*(1-y)

  patches = train_images[:, :args.batch_size]
  test_patches = test_images[:, :args.batch_size]

  if args.dataset == 'cifar':
    input_dim = 3*32*32
  elif args.dataset == 'mnist':
    input_dim = 28*28
  else:
    assert False
  fs = [args.batch_size, input_dim, 1024, 1024, 1024, 196, 1024, 1024, 1024,
        input_dim]
    
  def f(i): return fs[i+1]  # W[i] has shape f[i] x f[i-1]
  n = len(fs) - 2

  # Full dataset from which new batches are sampled
  X_full = init_var(train_images, "X_full", is_global=True)

  X = init_var(patches, "X", is_global=False)  # stores local batch per model
  W = [None]*n
  W.insert(0, X)
  A = [None]*(n+2)
  A[1] = W[0]
  for i in range(1, n+1):
    init_val = ng_init(f(i), f(i-1)).astype(dtype)
    W[i] = init_var(init_val, "W_%d"%(i,), is_global=True)
    A[i+1] = nonlin(kfac_lib.matmul(W[i], A[i]))
  err = A[n+1] - A[1]
  model.loss = u.L2(err) / (2 * get_batch_size(err))

  # create test error eval
  layer0 = init_var(test_patches, "X_test", is_global=True)
  layer = layer0
  for i in range(1, n+1):
    layer = nonlin(W[i] @ layer)
  verr = (layer - layer0)
  model.vloss = u.L2(verr) / (2 * get_batch_size(verr))

  # manually compute backprop to use for sanity checking
  B = [None]*(n+1)
  B2 = [None]*(n+1)
  B[n] = err*d_nonlin(A[n+1])
  _sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
  if args.fixed_labels:
    _sampled_labels_live = tf.ones(shape=(f(n), f(-1)), dtype=dtype)
    
  _sampled_labels = init_var(_sampled_labels_live, "to_be_deleted",
                             is_global=False)

  B2[n] = _sampled_labels*d_nonlin(A[n+1])
  for i in range(n-1, -1, -1):
    backprop = t(W[i+1]) @ B[i+1]
    B[i] = backprop*d_nonlin(A[i+1])
    backprop2 = t(W[i+1]) @ B2[i+1]
    B2[i] = backprop2*d_nonlin(A[i+1])

  cov_A = [None]*(n+1)    # covariance of activations[i]
  cov_B2 = [None]*(n+1)   # covariance of synthetic backprops[i]
  vars_svd_A = [None]*(n+1)
  vars_svd_B2 = [None]*(n+1)
  dW = [None]*(n+1)
  dW2 = [None]*(n+1)
  pre_dW = [None]*(n+1)   # preconditioned dW
  # todo: decouple initial value from covariance update
  # maybe we need to start with identity and do a running average
  for i in range(1,n+1):
    if regularized_svd:
      cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size+args.Lambda*u.Identity(f(i-1)), "cov_A%d"%(i,))
      cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size+args.Lambda*u.Identity(f(i)), "cov_B2%d"%(i,))
    else:
      cov_A[i] = init_var(A[i]@t(A[i])/args.batch_size, "cov_A%d"%(i,))
      cov_B2[i] = init_var(B2[i]@t(B2[i])/args.batch_size, "cov_B2%d"%(i,))
    vars_svd_A[i] = u.SvdWrapper(cov_A[i],"svd_A_%d"%(i,))
    vars_svd_B2[i] = u.SvdWrapper(cov_B2[i],"svd_B2_%d"%(i,))
    if use_tikhonov:
      whitened_A = u.regularized_inverse3(vars_svd_A[i],L=args.Lambda) @ A[i]
      whitened_B2 = u.regularized_inverse3(vars_svd_B2[i],L=args.Lambda) @ B[i]
    else:
      whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
      whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]
    
    dW[i] = (B[i] @ t(A[i]))/args.batch_size
    dW2[i] = B[i] @ t(A[i])
    pre_dW[i] = (whitened_B2 @ t(whitened_A))/args.batch_size

    
  sampled_labels_live = A[n+1] + tf.random_normal((f(n), f(-1)),
                                                  dtype=dtype, seed=0)
  if args.fixed_labels:
    sampled_labels_live = A[n+1]+tf.ones(shape=(f(n), f(-1)), dtype=dtype)
  sampled_labels = init_var(sampled_labels_live, "sampled_labels", is_global=False)
  err2 = A[n+1] - sampled_labels
  model.loss2 = u.L2(err2) / (2 * args.batch_size)
  model.global_vars = global_vars
  model.local_vars = local_vars
  model.trainable_vars = W[1:]

  # todo: model step is tracked in 3 places; consolidate
  model.step = init_var(u.as_int32(0), "step", is_global=False)
  advance_step_op = model.step.assign_add(1)
  assert get_batch_size(X_full) % args.batch_size == 0
  batches_per_dataset = (get_batch_size(X_full) // args.batch_size)
  batch_idx = tf.mod(model.step, batches_per_dataset)
  start_idx = batch_idx * args.batch_size
  advance_batch_op = X.assign(X_full[:,start_idx:start_idx + args.batch_size])
  
  def advance_batch():
    print("Step for model(%s) is %s"%(model.name, u.eval(model.step)))
    sess = u.get_default_session()
    # TODO: get rid of _sampled_labels
    sessrun([sampled_labels.initializer, _sampled_labels.initializer])
    if args.advance_batch:
      with u.timeit("advance_batch"):
        sessrun(advance_batch_op)
    sessrun(advance_step_op)
    
  model.advance_batch = advance_batch

  # TODO: refactor this to take initial values out of Var struct
  #global_init_op = tf.group(*[v.initializer for v in global_vars])
  global_init_ops = [v.initializer for v in global_vars]
  global_init_op = tf.group(*[v.initializer for v in global_vars])
  global_init_query_ops = [tf.logical_not(tf.is_variable_initialized(v))
                           for v in global_vars]
  
  def initialize_global_vars(verbose=False, reinitialize=False):
    """If reinitialize is false, will not reinitialize variables already
    initialized."""
    
    sess = u.get_default_session()
    if not reinitialize:
      uninited = sessrun(global_init_query_ops)
      # use numpy boolean indexing to select list of initializers to run
      to_initialize = list(np.asarray(global_init_ops)[uninited])
    else:
      to_initialize = global_init_ops
      
    if verbose:
      print("Initializing following:")
      for v in to_initialize:
        print("   " + v.name)

    sessrun(to_initialize, feed_dict=init_dict)
  model.initialize_global_vars = initialize_global_vars

  # didn't quite work (can't initialize var in same run call as deps likely)
  # enforce that batch is initialized before everything
  # except the fake-label ops
  # for v in local_vars:
  #   if v != X and v != sampled_labels and v != _sampled_labels:
  #     print("Adding dep %s on %s"%(v.initializer.name, X.initializer.name))
  #     u.add_dep(v.initializer, on_op=X.initializer)
      
  local_init_op = tf.group(*[v.initializer for v in local_vars],
                           name="%s_localinit"%(model.name))
  print("Local vars:")
  for v in local_vars:
    print(v.name)
    
  def initialize_local_vars():
    sess = u.get_default_session()
    sessrun(_sampled_labels.initializer, feed_dict=init_dict)
    sessrun(X.initializer, feed_dict=init_dict)
    sessrun(local_init_op, feed_dict=init_dict)
  model.initialize_local_vars = initialize_local_vars

  return model
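
advance_batch maps the model step to a column slice of X_full, wrapping around once every batches_per_dataset steps. A standalone NumPy sketch of that indexing with made-up sizes:

import numpy as np

batch_size = 4
X_full = np.arange(2 * 12).reshape(2, 12)      # 12 "examples" of dimension 2
batches_per_dataset = X_full.shape[1] // batch_size

for step in range(5):
    batch_idx = step % batches_per_dataset
    start_idx = batch_idx * batch_size
    X = X_full[:, start_idx:start_idx + batch_size]
    print("step", step, "-> columns", start_idx, "to", start_idx + batch_size - 1)
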