Ejemplo n.º 1
0
def disp(net, ma, ol, b, batch=None):
    """Save diagnostic plots under ``/home/jk/matt/nat_learn_synch/``.

    Files written:

    * ``filters{b}.png`` -- 4x4 grid of ``net.conv1`` filters (one panel per
      output filter, squeezed for ``imshow``) with a colorbar on the last
      drawn panel. Panels beyond the number of filters are hidden.
    * ``activities.png`` -- line plot of ``ma``.
    * ``ortho_loss.png`` -- line plot of ``ol``.
    * ``synch{i}.png`` -- only when ``batch`` is given: for each of the first
      (up to) five items, the complex image ``batch[i, 0] + 1j * batch[i, 1]``
      drawn with ``cplx_imshow`` and the HSV colormap.

    Args:
        net: model with a ``conv1.weight`` tensor.
        ma: values plotted to ``activities.png``.
        ol: values plotted to ``ortho_loss.png``.
        b: value inserted into the filters file name.
        batch: optional array indexed as ``batch[item, channel, y, x]`` with
            at least two channels.
    """
    out_dir = '/home/jk/matt/nat_learn_synch/'

    filters = net.conv1.weight.data.cpu().numpy()
    fig, axes = plt.subplots(4, 4)
    flat_axes = axes.reshape(-1)
    # Draw at most one panel per filter so a network with fewer than 16
    # filters does not raise IndexError; unused panels are blanked.
    num_shown = min(len(flat_axes), filters.shape[0])
    for a, ax in enumerate(flat_axes):
        if a >= num_shown:
            ax.axis('off')
            continue
        im = ax.imshow(filters[a, :, :].squeeze())
        if a == num_shown - 1:
            fig.colorbar(im, ax=ax)
    plt.savefig(out_dir + 'filters{}.png'.format(b))
    plt.close()

    plt.plot(ma)
    plt.savefig(out_dir + 'activities.png')
    plt.close()
    plt.plot(ol)
    plt.savefig(out_dir + 'ortho_loss.png')
    plt.close()

    if batch is not None:
        # Plot up to five items; batches with fewer items are handled
        # without an IndexError.
        for i in range(min(5, len(batch))):
            fig, ax = plt.subplots()
            cplx = batch[i, 0, :, :] + 1j * batch[i, 1, :, :]
            cplx_imshow(ax, cplx, cm=plt.cm.hsv)
            plt.savefig(out_dir + 'synch{}.png'.format(i))
            plt.close()
Ejemplo n.º 2
0
def disp(Phi, save_dir, batch, w, inner_energies, outer_energies):
    """Plot energy curves, dictionary atoms and one complex reconstruction.

    Files written into ``save_dir``:

    * ``inner_energies.png`` / ``outer_energies.png`` -- each history is
      converted with ``np.array(...).T`` and plotted with the legend
      Total / Reconstruction / Sparsity / Desynchrony.
    * ``weights.png`` -- 4x4 grid of ``Phi[0..15]``, each reshaped to
      ``batch.shape[2] x batch.shape[3]``.
    * ``cplx_reconstruction.png`` -- side by side: channel 0 of ``batch[0]``
      shown as a complex image with zero imaginary part, and the first item
      of ``einsum('bcm,cmn->bcn', w, Phi repeated twice)`` reshaped to
      ``(2, batch.shape[2], batch.shape[3])``.

    Args:
        Phi: dictionary tensor; ``Phi[a]`` must hold
            ``batch.shape[2] * batch.shape[3]`` elements.
        save_dir: output directory.
        batch: input tensor indexed as ``batch[item, channel, y, x]``.
        w: coefficients used as the ``bcm`` operand of the einsum.
        inner_energies: energy history for ``inner_energies.png``.
        outer_energies: energy history for ``outer_energies.png``.
    """
    height, width = batch.shape[2], batch.shape[3]

    legend = ('Total', 'Reconstruction', 'Sparsity', 'Desynchrony')
    curves = (('inner_energies.png', inner_energies),
              ('outer_energies.png', outer_energies))
    for file_name, history in curves:
        plt.plot(np.array(history).T)
        plt.legend(legend)
        plt.savefig(os.path.join(save_dir, file_name))
        plt.close()

    _, grid = plt.subplots(4, 4)
    for idx, panel in enumerate(grid.reshape(-1)):
        atom = Phi.data[idx, ...].reshape(height, width)
        panel.imshow(atom.cpu().numpy())
    plt.savefig(os.path.join(save_dir, 'weights.png'))
    plt.close()

    _, pair = plt.subplots(1, 2)
    phi_pair = Phi.unsqueeze(0).repeat(2, 1, 1)
    first = batch[0, 0, ...]
    # Real part = the input image, imaginary part = zeros.
    img = torch.stack((first, torch.zeros_like(first)), dim=0)
    img = img.data.cpu().numpy()
    recon = torch.einsum('bcm,cmn->bcn', w, phi_pair)[0, ...]
    recon = recon.reshape(2, height, width).data.cpu().numpy()
    for panel, arr in zip(pair, (img, recon)):
        cplx_imshow(panel, arr[0, ...] + 1j * arr[1, ...])
    plt.savefig(os.path.join(save_dir, 'cplx_reconstruction.png'))
    plt.close()
Ejemplo n.º 3
0
def disp(net, ma, batch):
    """Save filter, activity and synchrony plots under ``/home/jk/matt/learn_synch/``.

    Files written:

    * ``filters.png`` -- 3x3 grid of ``net.conv1`` filters (squeezed for
      ``imshow``) with a colorbar on the last drawn panel; panels beyond the
      number of filters are hidden.
    * ``activities.png`` -- line plot of ``ma``.
    * ``synch_{i}.png`` -- for each of the first (up to) five batch items, the
      complex image ``batch[i, 0] + 1j * batch[i, 1]`` drawn with
      ``cplx_imshow`` and the HSV colormap.

    Args:
        net: model with a ``conv1.weight`` tensor.
        ma: values plotted to ``activities.png``.
        batch: array indexed as ``batch[item, channel, y, x]`` with at least
            two channels.
    """
    out_dir = '/home/jk/matt/learn_synch/'

    filters = net.conv1.weight.data.cpu().numpy()
    fig, axes = plt.subplots(3, 3)
    flat_axes = axes.reshape(-1)
    # Draw at most one panel per filter so fewer than 9 filters does not
    # raise IndexError; unused panels are blanked.
    num_shown = min(len(flat_axes), filters.shape[0])
    for a, ax in enumerate(flat_axes):
        if a >= num_shown:
            ax.axis('off')
            continue
        im = ax.imshow(filters[a, :, :].squeeze())
        if a == num_shown - 1:
            fig.colorbar(im, ax=ax)
    plt.savefig(out_dir + 'filters.png')
    plt.close()

    plt.plot(ma)
    plt.savefig(out_dir + 'activities.png')
    plt.close()

    for i in range(min(5, len(batch))):
        fig, ax = plt.subplots()
        cplx = batch[i, 0, :, :] + 1j * batch[i, 1, :, :]
        cplx_imshow(ax, cplx, cm=plt.cm.hsv)
        plt.savefig(out_dir + 'synch_{}.png'.format(i))
        plt.close()
Ejemplo n.º 4
0
def evaluate(phase_batch,
             coupling,
             p_sigma=1.0,
             num_samples=16,
             num_steps=256,
             save_dir=None):
    '''Generate num_samples batches of successive samples using HMC and save them.

    Sampling starts from phases drawn uniformly from [0, 2*pi) with the shape
    of ``phase_batch`` (cast to double). ``hmc`` is called ``num_samples``
    times, each call continuing from the previous call's output.

    When ``save_dir`` is given, element 0 of each of the first 16 samples is
    drawn as the complex image ``cos(phase) + 1j * sin(phase)`` on a 4x4 grid
    and saved to ``save_dir/samples.png``. Grid panels without a sample (when
    ``num_samples < 16``) are hidden.

    Args:
        phase_batch: tensor whose shape/device the initial phases copy.
        coupling: passed through to ``hmc``.
        p_sigma: passed through to ``hmc``.
        num_samples: number of successive ``hmc`` calls.
        num_steps: passed through to ``hmc``.
        save_dir: directory for ``samples.png``; when None nothing is saved.

    Returns:
        List of numpy arrays, one per sample, in generation order.
    '''
    phase = 2 * np.pi * torch.rand_like(phase_batch).double()
    phase_history = []
    for _ in tqdm(range(num_samples)):
        phase = hmc(phase, coupling, p_sigma=p_sigma, num_steps=num_steps)
        # .cpu() so that .numpy() also works when phase lives on a GPU.
        phase_history.append(phase.data.cpu().numpy())

    if save_dir is None:
        return phase_history

    fig, axes = plt.subplots(4, 4)
    for a, ax in enumerate(axes.reshape(-1)):
        if a >= len(phase_history):
            ax.axis('off')
            continue
        sample = phase_history[a][0, ...]
        cplx_img = np.cos(sample) + 1j * np.sin(sample)
        cplx_imshow(ax, cplx_img)
    plt.savefig(os.path.join(save_dir, 'samples.png'))
    plt.close()
    return phase_history
Ejemplo n.º 5
0
def disp(net, save_dir, phase_arrays, inner_energies, outer_energies, out):
    """Save energy curves, conv weights, phase images and feature maps.

    Files written into ``save_dir``:

    * ``inner_energies.png`` / ``outer_energies.png`` -- each history is
      converted with ``np.array(...).T`` and plotted with the legend
      Total / Local Coherence / Global Coherence.
    * ``weights.png`` -- 4x4 grid of ``net.conv.weight[0..15]``, each
      reshaped to ``kernel_size x kernel_size``.
    * ``inner_opt.png`` -- for (up to) the first two entries of
      ``phase_arrays``, item 0 drawn as the complex image
      ``channel0 + 1j * channel1``.
    * ``feature_maps.png`` -- 4x4 grid of ``out[0, 0, a] + 1j * out[0, 1, a]``
      for ``a`` in 0..15.
    * ``biases.png`` -- histogram of ``net.conv.bias``, only if ``net.bias``
      is truthy.

    NOTE(review): ``kernel_size`` is not defined here; it is read from module
    scope and presumably matches the spatial size of ``net.conv.weight`` --
    confirm.
    """
    legend = ('Total', 'Local Coherence', 'Global Coherence')
    curves = zip(('inner_energies.png', 'outer_energies.png'),
                 (inner_energies, outer_energies))
    for file_name, history in curves:
        plt.plot(np.array(history).T)
        plt.legend(legend)
        plt.savefig(os.path.join(save_dir, file_name))
        plt.close()

    weights = net.conv.weight.data
    _, grid = plt.subplots(4, 4)
    for idx, panel in enumerate(grid.reshape(-1)):
        kernel = weights[idx, ...].reshape(kernel_size, kernel_size)
        panel.imshow(kernel.cpu().numpy())
    plt.savefig(os.path.join(save_dir, 'weights.png'))
    plt.close()

    _, pair = plt.subplots(1, 2)
    for panel, tensor in zip(pair, phase_arrays):
        first = tensor.data.cpu().numpy()[0, ...]
        cplx_imshow(panel, first[0, ...] + 1j * first[1, ...])
    plt.savefig(os.path.join(save_dir, 'inner_opt.png'))
    plt.close()

    _, grid = plt.subplots(4, 4)
    for idx, panel in enumerate(grid.reshape(-1)):
        fmap = out[0, :, idx, ...].cpu().numpy()
        cplx_imshow(panel, fmap[0, ...] + 1j * fmap[1, ...])
    plt.savefig(os.path.join(save_dir, 'feature_maps.png'))
    plt.close()

    if not net.bias:
        return
    plt.hist(net.conv.bias.data.cpu().numpy())
    plt.savefig(os.path.join(save_dir, 'biases.png'))
    plt.close()
Ejemplo n.º 6
0
def run(z0, k, model, energy=Energy1, constraint=None):
    """Optimise the complex image ``z0`` against ``energy`` by gradient descent.

    The image is held as a CUDA tensor of shape (1, 2, img_side, img_side)
    whose channels are the real and imaginary parts. Each of ``max_iter``
    steps:

    1. evaluates ``energy(model.forward_cplx(z).squeeze(0), target=k)`` and
       records it;
    2. backpropagates and subtracts ``grad * lr / mean(|grad|)``, i.e. a step
       normalised by the mean gradient magnitude (skipped when that mean is
       zero or not finite);
    3. converts back to a complex array, applies
       ``clip_norm(z0, constraint=constraint)`` and rebuilds the tensor.

    ``img_side``, ``max_iter``, ``lr``, ``save_every``, ``mu``, ``sigma`` and
    ``SAVE_PATH`` are read from module scope. On the first step and every
    ``save_every`` steps the current image is saved to
    ``SAVE_PATH/dream%04d.png``.

    Args:
        z0: complex array with ``img_side * img_side`` elements.
        k: target passed to ``energy``.
        model: object providing ``zero_grad()`` and ``forward_cplx()``.
        energy: callable ``energy(out, target=k)`` returning a scalar tensor.
        constraint: forwarded to ``clip_norm``.

    Returns:
        ``(z0, energies)``: the final complex array and a numpy array of the
        energy evaluated at every step (before that step's update).
    """
    z0_real = np.real(z0).reshape(1, 1, img_side, img_side)
    z0_imag = np.imag(z0).reshape(1, 1, img_side, img_side)
    z0_cplx = torch.tensor(np.concatenate((z0_real, z0_imag), axis=1)).cuda()

    energies = []
    for i in tqdm(range(max_iter)):

        z0_variable = Variable(z0_cplx, requires_grad=True)
        model.zero_grad()

        out = model.forward_cplx(z0_variable).squeeze(0)
        E = energy(out, target=k)
        energies.append(E.cpu().data.numpy())
        E.backward()

        ratio = np.abs(z0_variable.grad.data.cpu().numpy()).mean()
        # The step is normalised by the mean gradient magnitude. A zero (or
        # non-finite) mean would make lr / ratio infinite and the update
        # 0 * inf = NaN, corrupting z0 for all later steps -- skip instead.
        if np.isfinite(ratio) and ratio > 0:
            lr_use = lr / ratio
            z0_variable.data.sub_(z0_variable.grad.data * lr_use)
        z0_cplx = z0_variable.data.cpu().numpy()  # b, c, h, w

        z0 = np.expand_dims(z0_cplx[:, 0, :, :] + 1j * z0_cplx[:, 1, :, :],
                            axis=0)
        z0 = clip_norm(z0, constraint=constraint)

        # Shape for input: real/imag parts concatenated along axis 1.
        z0_real = np.real(z0)
        z0_imag = np.imag(z0)
        z0_cplx = torch.tensor(np.concatenate((z0_real, z0_imag),
                                              axis=1)).cuda()

        if i == 0 or (i + 1) % save_every == 0:
            fig, ax = plt.subplots()
            cplx_imshow(ax, z0, remap=(mu, sigma))
            plt.savefig(os.path.join(SAVE_PATH, 'dream%04d.png' % i))
            plt.close()
    return z0, np.array(energies)
Ejemplo n.º 7
0
        cons = tuple([{'type' : 'eq',
		   'fun'  : constraint,
	 	   'args' : (j,)} for j in range(v_side**2) + range(h_side**2)])
        res = minimize(Ham, z0, method='SLSQP', constraints=cons, options={'disp':True, 'maxiter':max_iter})
        ex = res['x']
        ex_cplx = ex[:v_side**2].reshape(v_side, v_side) + 1j*ex[v_side**2:2*v_side**2].reshape(v_side, v_side)
        if sub_exp == 1:
            bar_1_avg_phase = np.mean(np.angle(ex_cplx[:int(np.ceil(v_side / 2.0)),i]))
            bar_2_avg_phase = np.mean(np.angle(ex_cplx[int(np.ceil(v_side / 2.0)):,v_side - i - 1]))
            phase_diff.append(np.abs(bar_1_avg_phase - bar_2_avg_phase))
        elif sub_exp == 2:
	    p = pool(ex)
	    phase_diff.append(np.abs(np.angle(p[0]) - np.angle(p[1])))
	elif sub_exp == 3:
           fig, ax = plt.subplots()
	   cplx_imshow(ax, ex_cplx, cm=plt.cm.hsv)
	   plt.savefig('/home/matt/geman_style_videos/e{0}.png'.format(i))
	   continue
	    
    avg_phase_diff.append(np.mean(phase_diff)) 
    std_phase_diff.append(np.std(phase_diff))
if sub_exp != 3:
    # Mean phase difference (with a +/- one standard deviation band) against
    # the bar offset, shifted so that the middle offset is zero.
    t = np.arange(v_side) - v_side // 2
    x = np.array(avg_phase_diff)
    s = np.array(std_phase_diff)
    plt.plot(t, x, color='#CC4F1B')
    lower, upper = x - s, x + s
    plt.fill_between(t, lower, upper,
                     edgecolor='#CC4F1B', facecolor='#FF9848', alpha=.5)
    plt.savefig('/home/matt/geman_style_videos/phase_diff2.png')

    
Ejemplo n.º 8
0
def noisy_mnist(max_noise_patches,
                cplx=True,
                patch_side=8,
                num_images=60000,
                save_dir=None,
                cplx_bgrnd=True,
                display=True):
    """Build a dataset of MNIST digits cluttered with random MNIST patches.

    Image ``n`` pastes MNIST training digit ``n`` at a random position on a
    40x40 canvas, then adds ``max_noise_patches`` crops of size
    ``patch_side`` x ``patch_side`` taken from randomly chosen MNIST digits
    and pasted at random positions.

    If ``cplx`` is True the result has shape (2, 40, 40): channels 0/1 hold
    the real/imaginary parts of ``intensity * exp(1j * phase)``, with one
    random phase for the digit and one per patch; each patch overwrites the
    canvas where it is non-zero. If ``cplx_bgrnd`` is also True, zero entries
    of each channel are replaced by cos/sin of a random background phase.
    If ``cplx`` is False the patches are summed onto the digit and clipped to
    [0, 255] (``cplx_bgrnd`` is ignored).

    Args:
        max_noise_patches: number of noise patches per image.
        cplx: produce the 2-channel complex representation.
        patch_side: side length of each noise patch.
        num_images: number of images to generate.
        save_dir: if given, ``(imgs, labels)`` is saved to
            ``save_dir/processed.pt``.
        cplx_bgrnd: fill the complex background with a random phase.
        display: save a preview of the first image to the home directory.

    Returns:
        ``(imgs, labels)``: stacked images and the digit labels.

    Raises:
        ValueError: if ``num_images`` exceeds the number of MNIST images.
    """
    clean_mnist = torch.load(
        os.path.join(os.path.expanduser('~'), 'data', 'synch_data',
                     'MNIST/processed/training.pt'))
    num_mnist = clean_mnist[0].shape[0]
    if num_images > num_mnist:
        raise ValueError(
            'num_images must be at most {}'.format(num_mnist))
    im_side = 40
    digit_side = 28

    imgs = []
    labels = []

    pbar = tqdm(total=num_images)
    for counter in range(num_images):
        noise_patches = []
        noise_phases = []
        canvas = torch.zeros(im_side, im_side)
        label = clean_mnist[1][counter]
        digit_y = torch.randint(im_side - digit_side - 1, (1, ))
        digit_x = torch.randint(im_side - digit_side - 1, (1, ))
        canvas[digit_y:digit_y + digit_side,
               digit_x:digit_x + digit_side] = clean_mnist[0][counter]
        digit_phase = 2 * np.pi * torch.rand((1, ))

        # Add 'structured' noise: crops of other digits.
        patch_inds = torch.randperm(num_mnist)[:max_noise_patches]
        for ind in patch_inds:
            patch_canvas = torch.zeros(im_side, im_side)

            # Acquire patch
            rand_y = torch.randint(digit_side - patch_side - 1, (1, ))
            rand_x = torch.randint(digit_side - patch_side - 1, (1, ))
            patch = clean_mnist[0][ind][rand_y:rand_y + patch_side,
                                        rand_x:rand_x + patch_side]

            # Place patch
            patch_y = torch.randint(im_side - patch_side - 1, (1, ))
            patch_x = torch.randint(im_side - patch_side - 1, (1, ))
            patch_canvas[patch_y:patch_y + patch_side,
                         patch_x:patch_x + patch_side] = patch
            noise_patches.append(patch_canvas)
            noise_phases.append(2 * np.pi * torch.rand((1, )))

        if cplx:
            canvas = canvas.unsqueeze(0)
            cplx_canvas = torch.cat([
                canvas * torch.cos(digit_phase),
                canvas * torch.sin(digit_phase)
            ],
                                    dim=0)
            # Layer every noise patch (each with its own phase) over the
            # digit; later patches overwrite earlier ones where non-zero.
            for pch, phs in zip(noise_patches, noise_phases):
                pch = pch.unsqueeze(0)
                cplx_pch = torch.cat(
                    [pch * torch.cos(phs), pch * torch.sin(phs)], dim=0)
                mask = cplx_pch != 0
                cplx_canvas[mask] = cplx_pch[mask]
            canvas = cplx_canvas
            # Background fill only makes sense for the 2-channel canvas.
            if cplx_bgrnd:
                bgrnd_phase = 2 * np.pi * torch.rand((1, ))
                canvas[0, ...] = torch.where(
                    canvas[0, ...] == 0.0,
                    torch.cos(bgrnd_phase) * torch.ones_like(canvas[0, ...]),
                    canvas[0, ...])
                canvas[1, ...] = torch.where(
                    canvas[1, ...] == 0.0,
                    torch.sin(bgrnd_phase) * torch.ones_like(canvas[1, ...]),
                    canvas[1, ...])
        else:
            for pch in noise_patches:
                canvas += pch
            # All terms are non-negative, so one clamp equals clamping after
            # every addition.
            canvas = torch.clamp(canvas, 0, 255)
        imgs.append(canvas)
        labels.append(label)
        if display and counter == 0:
            if cplx:
                np_img = canvas.numpy() / 255.
                cplx_img = np_img[0, ...] + 1j * np_img[1, ...]
                fig, ax = plt.subplots()
                cplx_imshow(ax, cplx_img)
                plt.savefig(
                    os.path.join(os.path.expanduser('~'),
                                 'tmp_{}.png'.format(max_noise_patches)))
                plt.close()
            else:
                np_img = canvas.numpy()
                plt.imshow(np_img)
                plt.savefig(os.path.join(os.path.expanduser('~'), 'tmp.png'))
                plt.close()
        pbar.update(1)
    pbar.close()
    imgs = torch.stack(imgs)
    labels = torch.stack(labels)
    if save_dir is not None:
        torch.save((imgs, labels), os.path.join(save_dir, 'processed.pt'))
    return imgs, labels
Ejemplo n.º 9
0
def multi_mnist(num_digits,
                cplx=True,
                num_images=60000,
                cplx_bgrnd=True,
                save_dir=None,
                display=True):
    """Build a dataset of images each containing ``num_digits`` MNIST digits.

    The canvas side is ``28 + (num_digits - 1) * 10``. Image ``n`` uses MNIST
    training digit ``n`` plus ``num_digits - 1`` digits chosen by a random
    permutation. With more than one digit, each digit is pasted at a random
    position on its own zero canvas.

    If ``cplx`` is True the result has shape (2, side, side): channels 0/1
    hold the real/imaginary parts of ``intensity * exp(1j * phase)`` with one
    random phase per digit; later digits overwrite earlier ones where
    non-zero. If ``cplx_bgrnd`` is also True, zero entries of each channel
    are replaced by ``128 * cos`` / ``128 * sin`` of a random background
    phase. If ``cplx`` is False the digits are summed and clipped to [0, 255]
    (``cplx_bgrnd`` is ignored).

    Args:
        num_digits: digits per image (>= 1).
        cplx: produce the 2-channel complex representation.
        num_images: number of images to generate (<= 60000).
        cplx_bgrnd: fill the complex background with a random phase.
        save_dir: if given, ``(imgs, labels)`` is saved to
            ``save_dir/processed.pt``.
        display: save a preview of the first image to the home directory.

    Returns:
        ``(imgs, labels)``: stacked images and a (num_images, num_digits)
        label tensor.

    Raises:
        ValueError: if ``num_images`` > 60000 or ``num_digits`` < 1.
    """
    if num_images > 60000:
        raise ValueError('Number of images must be at most 60000')
    if num_digits < 1:
        raise ValueError('num_digits must be at least 1')

    clean_mnist = torch.load(
        os.path.join(os.path.expanduser('~'),
                     'data/synch_data/MNIST/processed/training.pt'))
    num_mnist = clean_mnist[0].shape[0]
    im_side = 28 + (num_digits - 1) * 10

    imgs = []
    labels = []

    pbar = tqdm(total=num_images)
    for counter in range(num_images):
        lb = []
        extra_digits = []
        extra_phases = []
        clean_img = clean_mnist[0][counter]
        if num_digits > 1:
            canvas = torch.zeros(im_side, im_side)
            digit_y1 = torch.randint(im_side - 28, (1, ))
            digit_x1 = torch.randint(im_side - 28, (1, ))
            canvas[digit_y1:digit_y1 + 28, digit_x1:digit_x1 + 28] = clean_img
        else:
            canvas = clean_img.float()
        lb.append(clean_mnist[1][counter])
        first_phase = 2 * np.pi * torch.rand((1, ))

        # Add the remaining num_digits - 1 digits. This loop only runs when
        # num_digits > 1, so every digit gets a random placement.
        digit_inds = torch.randperm(num_mnist)[:num_digits - 1]
        for ind in digit_inds:
            digit_canvas = torch.zeros(im_side, im_side)
            lb.append(clean_mnist[1][ind])
            digit_yn = torch.randint(im_side - 28, (1, ))
            digit_xn = torch.randint(im_side - 28, (1, ))
            digit_canvas[digit_yn:digit_yn + 28,
                         digit_xn:digit_xn + 28] = clean_mnist[0][ind]
            extra_digits.append(digit_canvas)
            extra_phases.append(2 * np.pi * torch.rand((1, )))

        if cplx:
            canvas = canvas.unsqueeze(0)
            cplx_canvas = torch.cat([
                canvas * torch.cos(first_phase),
                canvas * torch.sin(first_phase)
            ],
                                    dim=0)
            for dig, phs in zip(extra_digits, extra_phases):
                dig = dig.unsqueeze(0)
                cplx_dig = torch.cat(
                    [dig * torch.cos(phs), dig * torch.sin(phs)], dim=0)
                mask = cplx_dig != 0
                cplx_canvas[mask] = cplx_dig[mask]
            canvas = cplx_canvas
            # Background fill only makes sense for the 2-channel canvas.
            if cplx_bgrnd:
                bgrnd_phase = 2 * np.pi * torch.rand((1, ))
                canvas[0, ...] = torch.where(
                    canvas[0, ...] == 0.0,
                    128 * torch.cos(bgrnd_phase * torch.ones_like(canvas[0, ...])),
                    canvas[0, ...])
                canvas[1, ...] = torch.where(
                    canvas[1, ...] == 0.0,
                    128 * torch.sin(bgrnd_phase * torch.ones_like(canvas[0, ...])),
                    canvas[1, ...])
        else:
            for dig in extra_digits:
                canvas += dig
                canvas = torch.clamp(canvas, 0, 255)
        imgs.append(canvas)
        labels.append(torch.tensor(lb))
        if display and counter == 0:
            if cplx:
                np_img = canvas.numpy() / 255.
                cplx_img = np_img[0, ...] + 1j * np_img[1, ...]
                fig, ax = plt.subplots()
                cplx_imshow(ax, cplx_img)
            else:
                plt.imshow(canvas.numpy())
            plt.savefig(
                os.path.join(os.path.expanduser('~'),
                             'tmp{}.png'.format(num_digits)))
            plt.close()

        pbar.update(1)
    pbar.close()
    imgs = torch.stack(imgs)
    labels = torch.stack(labels)
    if save_dir is not None:
        torch.save((imgs, labels), os.path.join(save_dir, 'processed.pt'))
    return imgs, labels
Ejemplo n.º 10
0
# For each batch: start from a random image, optimise it with SLSQP against
# Energy under one 'ineq' constraint per pixel, and save the initial and
# optimised complex images. Relies on module-level dl, mu, sigma, init,
# constraint, free, Energy, max_iter, img_side and cplx_imshow.
for (batch, target) in dl:
    # Random initial image, shifted by mu and scaled by 1 / sigma.
    v_prime = (np.random.rand(img_side, img_side) - mu) / sigma
    k = target[0].data.numpy()
    print('Target: {}'.format(k))
    z0 = init(v_prime)
    z0_real = np.real(z0)
    z0_imag = np.imag(z0)
    # SLSQP works on real vectors: [real parts..., imaginary parts...].
    z0_flat = np.concatenate((z0_real.reshape(-1), z0_imag.reshape(-1)), axis=0)

    print('Making constraints...')
    cons = tuple({'type': 'ineq',
                  'fun': constraint,
                  'args': (j, free)} for j in range(img_side**2))
    print('Optimizing...')
    # 'verbose' is not an SLSQP option (scipy only warns about it), so it is
    # not passed.
    res = minimize(Energy, z0_flat, method='SLSQP', constraints=cons,
                   options={'disp': True, 'maxiter': max_iter, 'ftol': 0.0},
                   args=(k,))
    cp = res.x
    n_pix = img_side**2
    cp_real = cp[:n_pix].reshape(img_side, img_side)
    cp_imag = cp[n_pix:].reshape(img_side, img_side)
    cp_cplx = cp_real + 1j * cp_imag

    for i, img in enumerate([z0, cp_cplx]):
        name = 'init' if i == 0 else 'seg'
        fig, ax = plt.subplots()
        cplx_imshow(ax, img, cm=plt.cm.hsv, remap=(mu, sigma))
        plt.savefig('/home/jk/matt/' + name + '.png')
        plt.close()