Exemplo n.º 1
0
def train(
    q,
    k,
    scale,
    proj,
    true_attn,
    L_dL,
    proj_fn,
    alpha,
    num_iters,
    key,
    sample=True,
    post_renorm=False,
):
    """Run ``num_iters`` plain gradient-descent steps on ``q`` and ``k``.

    Every step obtains a projection matrix from ``proj_fn``, evaluates
    ``L_dL(q, k, projection_matrix, true_attn)`` to get the loss value and the
    gradients ``(dq, dk)``, and moves ``q`` and ``k`` against those gradients
    with step size ``alpha``.

    Args:
        q: Array to optimise. The caller's object is not modified; the
            updated value is returned.
        k: Array to optimise, handled like ``q``.
        scale: Not used by the update; returned unchanged.
        proj: Only used as the returned projection matrix when
            ``num_iters`` is 0 (no matrix is drawn in that case).
        true_attn: Forwarded unchanged to ``L_dL``.
        L_dL: Callable returning ``(loss_value, (dq, dk))``.
        proj_fn: Callable mapping a PRNG key to a projection matrix.
        alpha: Step size.
        num_iters: Number of update steps.
        key: PRNG key. Split once per step when ``sample`` is true,
            otherwise passed to ``proj_fn`` as is on every step.
        sample: Whether to draw a fresh key (and hence projection) per step.
        post_renorm: When true, ``renorm`` is applied to ``q`` and ``k``
            after each step.

    Returns:
        Tuple ``(losses, grads, q, k, scale, projection_matrix)``. ``losses``
        and ``grads`` are arrays of length ``num_iters``; ``grads[i]`` is
        ``norm(dq)**2 + norm(dk)**2`` for step ``i``. ``projection_matrix``
        is the one used in the last step.
    """
    losses = onp.zeros((num_iters, ))
    grads = onp.zeros((num_iters, ))
    # Keeps the return value defined when the loop body never runs.
    projection_matrix = proj
    for i in range(num_iters):
        if sample:
            key, key_sample = jax.random.split(key)
        else:
            key_sample = key
        projection_matrix = proj_fn(key_sample)
        kl_val, (dq, dk) = L_dL(q, k, projection_matrix, true_attn)
        # Out-of-place update: ``q -= ...`` would write into the caller's
        # buffer whenever q/k are mutable (NumPy) arrays.
        q = q - alpha * dq
        k = k - alpha * dk
        losses[i] = kl_val
        grads[i] += norm(dq)**2
        grads[i] += norm(dk)**2

        if post_renorm:
            q = renorm(q)
            k = renorm(k)

    return losses, grads, q, k, scale, projection_matrix
Exemplo n.º 2
0
def train_proj(
    q,
    k,
    scale,
    proj,
    true_attn,
    L_dL,
    proj_fn_unused,
    alpha,
    num_iters,
    key,
    sample,
    post_renorm=False,
):
    """Run ``num_iters`` gradient-descent steps on ``q``, ``k``, ``scale``, ``proj``.

    Every step evaluates ``L_dL(q, k, scale, proj, true_attn)`` to get the
    loss value and the gradients ``(dq, dk, dscale, dproj)`` and moves each
    quantity against its gradient with step size ``alpha``. A gradient of
    ``None`` for ``scale`` or ``proj`` leaves that quantity untouched.

    Args:
        q: Array to optimise. The caller's object is not modified; the
            updated value is returned.
        k: Array to optimise, handled like ``q``.
        scale: Value to optimise when ``dscale`` is not ``None``.
        proj: Projection to optimise when ``dproj`` is not ``None``.
        true_attn: Forwarded unchanged to ``L_dL``.
        L_dL: Callable returning ``(loss_value, (dq, dk, dscale, dproj))``.
        proj_fn_unused: Ignored; present so the signature lines up with
            ``train``.
        alpha: Step size.
        num_iters: Number of update steps.
        key: Ignored.
        sample: Ignored.
        post_renorm: When true, ``renorm`` is applied to ``q`` and ``k``
            after each step.

    Returns:
        Tuple ``(losses, grads, q, k, scale, proj)``. ``losses`` and ``grads``
        are arrays of length ``num_iters``; ``grads[i]`` is the sum of the
        squared norms of every non-``None`` gradient at step ``i``.
    """
    losses = onp.zeros((num_iters, ))
    grads = onp.zeros((num_iters, ))
    for i in range(num_iters):
        kl_val, (dq, dk, dscale, dproj) = L_dL(q, k, scale, proj, true_attn)

        # Out-of-place updates: ``x -= ...`` would write into the caller's
        # buffers whenever the inputs are mutable (NumPy) arrays.
        q = q - alpha * dq
        k = k - alpha * dk
        if dscale is not None:
            scale = scale - alpha * dscale
        if dproj is not None:
            proj = proj - alpha * dproj

        losses[i] = kl_val

        grads[i] += norm(dq)**2
        grads[i] += norm(dk)**2
        if dscale is not None:
            grads[i] += norm(dscale)**2
        if dproj is not None:
            grads[i] += norm(dproj)**2

        if post_renorm:
            q = renorm(q)
            k = renorm(k)

    return losses, grads, q, k, scale, proj
Exemplo n.º 3
0
def main():
    """Plot one resampling step of ``ParticleFilter`` as a scatter snapshot.

    Initialises a filter with 1000 particles, scores each particle by its
    distance from the origin (first two particle columns), resamples once
    with those scores, and draws the particle sets before (red) and after
    (green) resampling. Marker opacity and size both follow the per-particle
    score. Blocks in ``plt.show()`` until the window is closed.
    """
    n = 1000
    pf = ParticleFilter()
    pf.initialize(size=n)

    # visualization
    from matplotlib import pyplot as plt
    ps0 = pf.particles

    cost0 = np.linalg.norm(ps0[:, :2],
                           axis=-1)  # test: cost by distance from zero
    kp, kx = 0.5, 1.0  # configure as p=50% match at 1.0m distance
    k = (-np.log(kp) / kx)
    p0 = np.exp(-k * cost0)
    # NOTE(review): presumably rescales the scores into [0.2, 0.7] - confirm
    # against U.renorm.
    p0 = U.renorm(p0, 0.2, 0.7)

    nax = np.newaxis

    # RGBA rows: constant red, with the score as the alpha channel.
    c0 = np.full((n, 3), [1.0, 0.0, 0.0])
    col0 = np.concatenate([c0, p0[:, nax]], axis=-1)

    pf.resample(p0, noise=(0.5, 0.5, 0.5))
    ps1 = pf.particles
    cost1 = np.linalg.norm(ps1[:, :2],
                           axis=-1)  # test: cost by distance from zero
    p1 = np.exp(-k * cost1)
    p1 = U.renorm(p1, 0.2, 0.7)

    # RGBA rows: constant green, with the score as the alpha channel.
    c1 = np.full((n, 3), [0.0, 1.0, 0.0])
    col1 = np.concatenate([c1, p1[:, nax]], axis=-1)

    sc0 = 20.0 * p0
    sc1 = 20.0 * p1

    plt.scatter(ps0[:, 0], ps0[:, 1], label='original', c=col0, s=sc0)
    plt.scatter(ps1[:, 0], ps1[:, 1], label='resample', c=col1, s=sc1)

    plt.xlim(-10.0, 10.0)
    plt.ylim(-10.0, 10.0)

    plt.legend(loc=1)
    ax = plt.gca()
    ax.set_axisbelow(True)
    plt.grid()

    leg = ax.get_legend()
    # Matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles`` and 3.9
    # removed the old attribute; prefer the new name, fall back for old
    # installations.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    # Give the legend markers a solid colour instead of the per-point RGBA.
    hl_dict = {handle.get_label(): handle for handle in handles}
    hl_dict['original'].set_color([1.0, 0.0, 0.0])
    hl_dict['resample'].set_color([0.0, 1.0, 0.0])
    plt.title('Resample (single step snapshot)')
    plt.show()
Exemplo n.º 4
0
def plot_image(image, args, flag):
    """Display an original image or a reconstruction through ``frame_image``.

    Args:
        image: Image data. For ``flag == 'orig'`` it is indexed with ``[0]``
            and converted via ``.cpu().numpy()``; for ``flag == 'recons'``
            the expected layout depends on ``args.ALG`` (see the branches
            below).
        args: Object providing ``NUM_CHANNELS`` (1 or 3), ``IMG_SIZE`` (used
            for grayscale reshapes) and ``ALG`` (used for rgb
            reconstructions).
        flag: ``'orig'`` or ``'recons'``.

    Raises:
        ValueError: If ``flag``, ``args.NUM_CHANNELS`` or (for rgb
            reconstructions) ``args.ALG`` is not one of the supported values.
    """
    wants_original = flag == 'orig'
    if not wants_original and not flag == 'recons':
        raise ValueError(
            'flag input must be orig or recons for plotting original image or reconstruction, respectively.'
        )

    is_rgb = args.NUM_CHANNELS == 3
    if not is_rgb and not args.NUM_CHANNELS == 1:
        raise ValueError(
            'NUM_CHANNELS must be 1 for grayscale or 3 for rgb images.')

    if wants_original:
        first = image[0].cpu().numpy()
        if is_rgb:
            # channels-first -> channels-last
            frame_image(renorm(first.transpose((1, 2, 0))))
        else:
            side_shape = (args.IMG_SIZE, args.IMG_SIZE)
            frame_image(renorm(first.reshape(side_shape)), cmap='gray')
        return

    # Reconstructions from here on.
    if not is_rgb:
        square = image.reshape(args.IMG_SIZE, args.IMG_SIZE)
        frame_image(renorm(square), cmap='gray')
        return

    if args.ALG == 'csdip':
        frame_image(renorm(image[0][0].transpose((1, 2, 0))))
    elif args.ALG == 'bm3d':
        channels_last = np.asarray(image.transpose(1, 2, 0))
        frame_image(utils.renorm(channels_last))
    elif args.ALG == 'dct' or args.ALG == 'wavelet':
        # NOTE(review): the 128 is hard-coded here rather than taken from
        # args.IMG_SIZE - confirm this is intended.
        unflattened = image.reshape(-1, 128, 3, order='F').swapaxes(0, 1)
        frame_image(renorm(unflattened))
    else:
        raise ValueError(
            'Plotting rgb images is supported only by csdip, bm3d, dct, wavelet.'
        )
Exemplo n.º 5
0
def report_train(
    q,
    k,
    proj_fn,
    L_dL,
    num_features,
    key,
    train_fn,
    sample=True,
    title=None,
    post_renorm=False,
):
    """Train fresh ``q``/``k`` against the attention of the given ones and chart it.

    The target is ``exp(q @ k.T)`` normalised over the last axis. New ``q``
    and ``k`` of the same shapes are drawn uniformly from
    ``[-gamma, gamma]``, handed to ``train_fn`` together with a projection
    built from the fixed seed 1111, and the resulting loss and squared
    gradient-norm curves are shown as two scatter charts via
    ``st.plotly_chart``.

    Args:
        q, k: Reference arrays; define the target attention and the shapes
            of the initial values.
        proj_fn: Callable mapping a PRNG key to a projection; also forwarded
            to ``train_fn``.
        L_dL: Loss-and-gradient callable, forwarded to ``train_fn``.
        num_features: Not used by this function.
        key: PRNG key used for the initial values and for training.
        train_fn: Training routine called with the same positional layout as
            ``train`` / ``train_proj``; must return a 6-tuple beginning with
            ``(losses, grads, ...)``.
        sample: Forwarded to ``train_fn``.
        title: Chart title; the gradient chart prefixes it with
            ``||GRAD||^2``.
        post_renorm: When true, ``renorm`` is applied to the initial values
            and the flag is forwarded to ``train_fn``.

    Returns:
        Mean of the last 32 entries of the loss curve.

    Reads the module-level names ``gamma``, ``alpha`` and ``num_iters``.
    """

    def show_scatter(values, chart_title):
        # One marker per training step.
        figure = go.Figure(data=go.Scattergl(
            x=onp.arange(num_iters),
            y=values,
            mode="markers",
        ), )
        figure.update_layout(title=chart_title)
        st.plotly_chart(figure, use_container_width=True)

    unnormalized = jnp.exp(q @ k.T)
    true_attn = unnormalized / unnormalized.sum(-1, keepdims=True)

    # sample embeddings close to 0 to start
    key, key_q, key_k = jax.random.split(key, 3)
    q_init = jax.random.uniform(
        key_q, shape=q.shape, minval=-gamma, maxval=gamma)
    k_init = jax.random.uniform(
        key_k, shape=k.shape, minval=-gamma, maxval=gamma)

    if post_renorm:
        q_init = renorm(q_init)
        k_init = renorm(k_init)

    # The projection always starts from the same fixed seed, independent of
    # ``key``.
    proj_init = jax.device_put(proj_fn(jax.random.PRNGKey(1111)))
    initial_scale = 1.

    key, key_train = jax.random.split(key)

    losses, grads, _q_t, _k_t, _scale, _proj = train_fn(
        q_init,
        k_init,
        initial_scale,
        proj_init,
        true_attn,
        L_dL,
        proj_fn,
        alpha,
        num_iters,
        key_train,
        sample,
        post_renorm=post_renorm,
    )

    show_scatter(losses, title)
    show_scatter(grads, f"||GRAD||^2 {title}")

    return losses[-32:].mean()
Exemplo n.º 6
0
 def loss(q, k, scale, proj, attn_dist):
     """Return the mean of ``fat.kl`` between ``attn_dist`` and the RFF attention.

     ``q`` and ``k`` are passed through ``renorm(..., axis=-1)`` before being
     given to ``fat.rff_attn`` together with ``proj``; only the first element
     of its result is used. ``scale`` is accepted but not used.
     """
     approx_attn, _ = fat.rff_attn(
         renorm(q, axis=-1),
         renorm(k, axis=-1),
         proj,
     )
     divergence = fat.kl(attn_dist, approx_attn)
     return divergence.mean()
Exemplo n.º 7
0
 def loss(q, k, scale, proj, attn_dist):
     """Return the mean of ``fat.kl`` between ``attn_dist`` and the RFF attention.

     ``q`` and ``k`` are passed through ``renorm(..., axis=-1)`` before being
     given to ``fat.rff_attn``. ``proj`` is wrapped in
     ``jax.lax.stop_gradient`` so differentiation treats it as a constant.
     Only the first element of the ``rff_attn`` result is used. ``scale`` is
     accepted but not used.
     """
     approx_attn, _ = fat.rff_attn(
         renorm(q, axis=-1),
         renorm(k, axis=-1),
         jax.lax.stop_gradient(proj),
     )
     divergence = fat.kl(attn_dist, approx_attn)
     return divergence.mean()