Code Example #1
def pred(config, mode='cifar10'):
    if mode == 'cifar10':
        obs = (3, 32, 32)
    else:
        raise ValueError('unsupported mode: %s' % mode)  # obs would otherwise be unbound
    sample_batch_size = 25
    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')
    sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    rescaling_inv = lambda x: .5 * x + .5

    def sample(model):
        model.train(False)
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.cuda()
        for i in range(obs[1]):
            for j in range(obs[2]):
                with torch.no_grad():
                    data_v = data
                    out = model(data_v, sample=True)
                    out_sample = sample_op(out)
                    data[:, :, i, j] = out_sample.data[:, :, i, j]
        return data

    print('sampling...')
    sample_t = sample(model)
    sample_t = rescaling_inv(sample_t)
    save_image(sample_t, 'images/sample.png', nrow=5, padding=0)
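The loop above fills the image one pixel at a time, and sample_from_discretized_mix_logistic turns the network's mixture parameters into pixel values at each position. For reference, here is a minimal single-channel sketch of such a sampler (the version used in these projects also models dependencies between RGB channels, so treat this as illustrative, not the exact implementation):

import torch
import torch.nn.functional as F

def sample_dml_1d(l, nr_mix):
    # l: (B, 3*nr_mix, H, W) holding mixture logits, means, and log-scales
    l = l.permute(0, 2, 3, 1)
    logit_probs = l[..., :nr_mix]
    means = l[..., nr_mix:2 * nr_mix]
    log_scales = torch.clamp(l[..., 2 * nr_mix:3 * nr_mix], min=-7.)
    # pick one mixture component per pixel with the Gumbel-max trick
    u = torch.rand_like(logit_probs).clamp(1e-5, 1. - 1e-5)
    sel = F.one_hot((logit_probs - torch.log(-torch.log(u))).argmax(-1), nr_mix).float()
    mean = (means * sel).sum(-1)
    log_scale = (log_scales * sel).sum(-1)
    # invert the logistic CDF to draw a sample, then clip to the data range [-1, 1]
    v = torch.rand_like(mean).clamp(1e-5, 1. - 1e-5)
    x = mean + log_scale.exp() * (torch.log(v) - torch.log(1. - v))
    return x.clamp(-1., 1.).unsqueeze(1)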
Code Example #2
def generate_results(model, data_loader, nr_logistic_mix, do_use_cuda):
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if do_use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)

        x_d, z_e_x, z_q_x, latents = model(x)
        # z_e_x is the encoder output
        # z_q_x is the quantized input to the decoder
        # latents are the codebook indices
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        nx_tilde = (0.5 * nx_tilde + 0.5) * 255
        nx_tilde = nx_tilde.astype(np.uint8)
        # embed()  # IPython debugging hook, commented out so the loop runs through
        #vae_input = z_e_x.contiguous().view(z_e_x.shape[0],-1)
        #vqvae_rec_images = x_tilde.cpu().data.numpy()
        nz_q_x = z_q_x.contiguous().view(z_q_x.shape[0], -1).cpu().data.numpy()
        nlatents = latents.cpu().data.numpy()
        for ind, img_name in enumerate(img_names):
            #gen_img_name = img_name.replace('.png', 'vqvae_gen.png')
            #imwrite(gen_img_name, nx_tilde[ind][0])
            #latents_name = img_name.replace('.png', 'vqvae_latents.npy')
            z_q_x_name = img_name.replace('.png', 'vqvae_z_q_x.npy')
            np.save(z_q_x_name, nz_q_x[ind])
        if not batch_idx % 10:
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx,
                time.time() - start_time))
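The comments above reference the VQ-VAE quantization step. A minimal sketch of that step, assuming an embedding matrix codebook of shape (K, D) and an encoder output z_e_x of shape (B, D, H, W) (names here are illustrative, not this repo's exact API):

import torch

def quantize(z_e_x, codebook):
    B, D, H, W = z_e_x.shape
    flat = z_e_x.permute(0, 2, 3, 1).reshape(-1, D)   # (B*H*W, D)
    dists = torch.cdist(flat, codebook)               # distance to every code
    latents = dists.argmin(dim=1)                     # nearest codebook index
    z_q_x = codebook[latents].view(B, H, W, D).permute(0, 3, 1, 2)
    return z_q_x, latents.view(B, H, W)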
Code Example #3
File: vqvae_vae.py, Project: johannah/trajectories
def generate_results(data_loader):
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)
        dec = vae(x)
        decr = dec.contiguous().view(dec.shape[0], 32, 10, 10)
        udec = (decr * z_q_x_std) + z_q_x_mean
        # TODO - k-nearest neighbors
        # going to use vae.z_mean and vae.z_std
        # look at slides from Laurent and Vincent (CIFAR summer school)
        # on pruning unused dimensions
        # now we have mu and sigma to run a knn lookup against the training
        # set; once we've mapped back to the original frame that produced the
        # mu/sigma, we know the frame and its mu and sigma
        x_d = qmodel.decoder(udec)
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        nx_tilde = (0.5 * nx_tilde + 0.5) * 255
        nx_tilde = nx_tilde.astype(np.uint8)

        for ind, img_name in enumerate(img_names):
            print('ind', ind)
            gen_img_name = img_name.replace('.npy', 'vv_gen.png')
            imwrite(gen_img_name, nx_tilde[ind][0])
        #    z_q_x_name = img_name.replace('.png', 'vqvae_z_q_x.npy')
        #    np.save(z_q_x_name, nz_q_x[ind])
        if not batch_idx % 10:
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx,
                time.time() - start_time))
Code Example #4
def generate_reconstruction(base_path, data_loader, nr_logistic_mix,
                            do_use_cuda):
    if not os.path.exists(base_path):
        os.makedirs(base_path)
    for i, (mu, sigma) in enumerate(zip(mus, sigmas)):
        if 100 < i < 200:
            if do_use_cuda:
                tmu = Variable(torch.FloatTensor(mu),
                               requires_grad=False).cuda()
                tsigma = Variable(torch.FloatTensor(sigma),
                                  requires_grad=False).cuda()
            else:
                tmu = Variable(torch.FloatTensor(mu), requires_grad=False)
                tsigma = Variable(torch.FloatTensor(sigma),
                                  requires_grad=False)

            bs = 20
            base = Variable(torch.from_numpy(np.zeros((bs, 800))).float(),
                            requires_grad=False)
            for s in range(bs):
                base_noise = Variable(torch.from_numpy(
                    rdn.normal(0, 1, size=tsigma.size())).float(),
                                      requires_grad=False)
                base[s] = tmu + tsigma * base_noise

            base[:, worst_inds] = 0.0
            z = base.contiguous().view(bs, 32, 5, 5)
            x_d = vae.decoder(z)
            x_tilde = sample_from_discretized_mix_logistic(
                x_d, nr_logistic_mix)
            nx_tilde = x_tilde.cpu().data.numpy()
            inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
            mean_tilde = np.mean(inx_tilde, axis=0)[0].astype(np.uint8)
            max_tilde = np.max(inx_tilde, axis=0)[0].astype(np.uint8)
            mean_img_name = os.path.join(base_path, 'gmean_%05d.png' % (i))
            a_img_name = os.path.join(base_path, 'gadapt_%05d.png' % (i))
            max_img_name = os.path.join(base_path, 'gmax_%05d.png' % (i))
            imwrite(mean_img_name, mean_tilde)
            imwrite(max_img_name, max_tilde)
            nonzero = np.count_nonzero(inx_tilde, axis=0)[0]
            adapt_tilde = max_tilde.copy()  # copy so thresholding does not alias max_tilde
            # a pixel must be nonzero in at least 3 samples to appear in the adaptive image
            adapt_tilde[nonzero < 3] = 0
            imwrite(a_img_name, adapt_tilde)
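The inner loop above is the reparameterization trick applied one sample at a time (z = mu + sigma * eps with eps drawn from a standard normal). A vectorized sketch of the same idea, assuming mu and sigma are 1-D tensors:

import torch

def draw_latents(mu, sigma, n_samples=20):
    eps = torch.randn(n_samples, mu.shape[0])          # eps ~ N(0, I)
    return mu.unsqueeze(0) + sigma.unsqueeze(0) * eps  # z = mu + sigma * eps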
Code Example #5
File: train_vqvae.py, Project: johannah/trajectories
def test(epoch, test_loader, do_use_cuda, save_img_path=None):
    test_loss = []
    for batch_idx, (data, _) in enumerate(test_loader):
        start_time = time.time()
        if do_use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)

        x_d, z_e_x, z_q_x, latents = vmodel(x)
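        # loss_1: reconstruction term (discretized mixture of logistics)
        # loss_2: codebook loss, pulls the embeddings toward the encoder output
        # loss_3: commitment loss (beta = 0.25), keeps the encoder close to the codebook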
        loss_1 = discretized_mix_logistic_loss(x_d,
                                               2 * x - 1,
                                               use_cuda=do_use_cuda)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = .25 * F.mse_loss(z_e_x, z_q_x.detach())
        test_loss.append(to_scalar([loss_1, loss_2, loss_3]))
    test_loss_mean = np.asarray(test_loss).mean(0)
    if save_img_path is not None:
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        idx = 0
        x_cat = torch.cat([x[idx], x_tilde[idx]], 0)
        images = x_cat.cpu().data
        pred = (((np.array(x_tilde.cpu().data)[0, 0] + 1.0) / 2.0) *
                float(max_pixel - min_pixel)) + min_pixel
        # input x is between 0 and 1
        real = (np.array(x.cpu().data)[0, 0] *
                float(max_pixel - min_pixel)) + min_pixel
        f, ax = plt.subplots(1, 3, figsize=(10, 3))
        ax[0].imshow(real, vmin=0, vmax=max_pixel)
        ax[0].set_title("original")
        ax[1].imshow(pred, vmin=0, vmax=max_pixel)
        ax[1].set_title("pred epoch %s test loss %s" %
                        (epoch, np.mean(test_loss_mean)))
        ax[2].imshow((pred - real)**2, cmap='gray')
        ax[2].set_title("error")
        f.tight_layout()
        plt.savefig(save_img_path)
        plt.close()
        print("saving example image")
        print("rsync -avhp [email protected]://%s" %
              os.path.abspath(save_img_path))

    return test_loss_mean
Code Example #6
def test(x, model, nr_logistic_mix, do_use_cuda=False, save_img_path=None):
    x_d, z_e_x, z_q_x, latents = model(x)
    x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
    loss_1 = discretized_mix_logistic_loss(x_d,
                                           2 * x - 1,
                                           use_cuda=do_use_cuda)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = .25 * F.mse_loss(z_e_x, z_q_x.detach())
    test_loss = to_scalar([loss_1, loss_2, loss_3])

    if save_img_path is not None:
        idx = np.random.randint(0, x.shape[0])  # index into the current batch; len(test_data) could exceed it
        x_cat = torch.cat([x[idx], x_tilde[idx]], 0)
        images = x_cat.cpu().data
        oo = 0.5 * np.array(x_tilde.cpu().data)[0, 0] + 0.5
        ii = np.array(x.cpu().data)[0, 0]
        imwrite(save_img_path, oo)
        imwrite(save_img_path.replace('.png', 'orig.png'), ii)
    return test_loss
Code Example #7
def generate_results(data_loader, nr_logistic_mix, do_use_cuda):
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if do_use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)
        x_d = vae(x.contiguous().view(x.shape[0], -1))
        #x_tilde = x_d.contiguous().view(x_d.shape[0], 1, dsize, dsize)
        x_di = x_d.contiguous().view(x_d.shape[0], probs_size, dsize, dsize)
        x_tilde = sample_from_discretized_mix_logistic(x_di, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        inx_tilde = ((nx_tilde) * 255).astype(np.uint8)
        for ind, img_name in enumerate(img_names):
            gen_img_name = img_name.replace('.png', 'pixelvae_gen.png')
            imwrite(gen_img_name, inx_tilde[ind][0])
        if not batch_idx % 10:
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx,
                time.time() - start_time))
Code Example #8
def generate_results(base_path, data_loader, nr_logistic_mix, do_use_cuda):
    if not os.path.exists(base_path):
        os.makedirs(base_path)
    start_time = time.time()
    data_mu = np.empty((0, 800))
    data_sigma = np.empty((0, 800))
    limit = 4000
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if batch_idx * 32 > limit:
            continue
        else:
            if do_use_cuda:
                x = Variable(data, requires_grad=False).cuda()
            else:
                x = Variable(data, requires_grad=False)

            x_d = vae(x)
            x_tilde = sample_from_discretized_mix_logistic(
                x_d, nr_logistic_mix)
            nx_tilde = x_tilde.cpu().data.numpy()
            inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
            # vae.z_mean is batch_sizex800
            # vae.z_sigma is batch_sizex800
            zmean = vae.z_mean.cpu().data.numpy()
            zsigma = vae.z_sigma.cpu().data.numpy()
            data_mu = np.vstack((data_mu, zmean))
            data_sigma = np.vstack((data_sigma, zsigma))
            for ind, img_path in enumerate(img_names):
                img_name = os.path.split(img_path)[1]
                gen_img_name = img_name.replace('.png', 'conv_vae_gen.png')
                #gen_latent_name = img_name.replace('.png', 'conv_vae_latents.npz')
                imwrite(os.path.join(base_path, gen_img_name),
                        inx_tilde[ind][0])
                #np.savez(gen_latent_name, zmean=zmean[ind], zsigma=zsigma[ind])
            if not batch_idx % 10:
                print('Generate batch_idx: {} Time: {}'.format(
                    batch_idx,
                    time.time() - start_time))
    np.savez(os.path.join(base_path, 'mu_conv_vae.npz'), data_mu)
    np.savez(os.path.join(base_path, 'sigma_conv_vae.npz'), data_sigma)
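Note that calling np.savez with a positional array stores it under the default key 'arr_0', so loading the statistics written above looks like this (paths shortened; the script writes them under base_path):

import numpy as np

data_mu = np.load('mu_conv_vae.npz')['arr_0']
data_sigma = np.load('sigma_conv_vae.npz')['arr_0']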
Code Example #9
def generate(frame_num, gen_latents, orig_img_path, save_img_path):
    z_q_x = vmodel.embedding(gen_latents.view(gen_latents.size(0), -1))
    z_q_x = z_q_x.view(gen_latents.shape[0], 6, 6, -1).permute(0, 3, 1, 2)
    x_d = vmodel.decoder(z_q_x)
    if save_img_path is not None:
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        pred = (((np.array(x_tilde.cpu().data)[0, 0] + 1.0) / 2.0) *
                float(max_pixel - min_pixel)) + min_pixel
        # input x is between 0 and 1
        real = imread(orig_img_path)
        f, ax = plt.subplots(1, 3, figsize=(10, 3))
        ax[0].imshow(real, vmin=0, vmax=max_pixel)
        ax[0].set_title("original frame %s" % frame_num)
        ax[1].imshow(pred, vmin=0, vmax=max_pixel)
        ax[1].set_title("pred")
        ax[2].imshow((pred - real)**2, cmap='gray')
        ax[2].set_title("error")
        f.tight_layout()
        plt.savefig(save_img_path)
        plt.close()
        print("saving example image")
        print("rsync -avhp [email protected]://%s" %
              os.path.abspath(save_img_path))
Code Example #10
def train(config, mode='cifar10'):
    model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(config.lr, config.nr_resnet, config.nr_filters)
    try:
        os.makedirs('models')
        os.makedirs('images')
        # print('mkdir:', config.outfile)
    except OSError:
        pass

    seed = np.random.randint(0, 10000)
    print("Random Seed: ", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True

    trainset, train_loader, testset, test_loader, classes = load_data(mode=mode, batch_size=config.batch_size)
    if mode == 'cifar10' or mode == 'faces':
        obs = (3, 32, 32)
        loss_op = lambda real, fake: discretized_mix_logistic_loss(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    elif mode == 'mnist':
        obs = (1, 28, 28)
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, config.nr_logistic_mix)
    else:
        raise ValueError('unsupported mode: %s' % mode)  # obs would otherwise be unbound
    sample_batch_size = 25
    rescaling_inv = lambda x: .5 * x + .5

    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config.lr_decay)

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        model.train(False)
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.cuda()
        with tqdm(total=obs[1] * obs[2]) as pbar:
            for i in range(obs[1]):
                for j in range(obs[2]):
                    with torch.no_grad():
                        data_v = data
                        out = model(data_v, sample=True)
                        out_sample = sample_op(out)
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
                    pbar.update(1)
        return data

    print('starting training')
    for epoch in range(config.max_epochs):
        model.train()
        torch.cuda.synchronize()
        train_loss = 0.
        time_ = time.time()
        with tqdm(total=len(train_loader)) as pbar:
            for batch_idx, (data, label) in enumerate(train_loader):
                data = data.requires_grad_(True).cuda()

                output = model(data)
                loss = loss_op(data, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                pbar.update(1)

        deno = (batch_idx + 1) * config.batch_size * np.prod(obs)  # batches are 0-indexed, so add 1
        print('train loss : %s' % (train_loss / deno), end='\t')

        # decrease learning rate
        scheduler.step()

        model.eval()
        test_loss = 0.
        with tqdm(total=len(test_loader)) as pbar:
            for batch_idx, (data, _) in enumerate(test_loader):
                data = data.requires_grad_(False).cuda()

                output = model(data)
                loss = loss_op(data, output)
                test_loss += loss.item()
                del loss, output
                pbar.update(1)
        deno = (batch_idx + 1) * config.batch_size * np.prod(obs)  # batches are 0-indexed, so add 1
        print('test loss : {:.4f}, time : {:.4f}'.format((test_loss / deno), (time.time() - time_)))

        torch.cuda.synchronize()

        if (epoch + 1) % config.save_interval == 0:
            torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch))
            print('sampling...')
            sample_t = sample(model)
            sample_t = rescaling_inv(sample_t)
            save_image(sample_t, 'images/{}_{}.png'.format(model_name, epoch), nrow=5, padding=0)
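A minimal sketch of driving train() from the command line; the flag names below are assumptions that mirror the config attributes the function reads, not a confirmed CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--lr_decay', type=float, default=0.999995)
    parser.add_argument('--nr_resnet', type=int, default=5)
    parser.add_argument('--nr_filters', type=int, default=160)
    parser.add_argument('--nr_logistic_mix', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--max_epochs', type=int, default=100)
    parser.add_argument('--save_interval', type=int, default=10)
    parser.add_argument('--load_params', type=str, default=None)
    train(parser.parse_args(), mode='cifar10')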
Code Example #11
def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    from utils import tsne_plot
    from utils import pca_plot
    from sklearn.cluster import KMeans
    # always be in eval mode so we don't swap neighbors
    model_dict = set_model_mode(model_dict, 'valid')
    srandom_state = np.random.RandomState(1234)
    with torch.no_grad():
        for phase in ['train', 'valid']:
            batch_index = srandom_state.randint(0,
                                                len(data_dict[phase].dataset),
                                                info['batch_size'])
            print(batch_index)
            data = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][0]
                for index in batch_index
            ])
            label = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][1]
                for index in batch_index
            ])
            batch_index = torch.LongTensor(batch_index)
            data = torch.FloatTensor(data)
            fp_out = forward_pass(model_dict, data, label, batch_index,
                                  'valid', info)
            if info['vq_decoder']:
                model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
            else:
                model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out

            bs, c, h, w = data.shape
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            data = data.detach().cpu().numpy()
            rec = rec_yhat.detach().cpu().numpy()
            u_q_flat = u_q.view(bs, info['code_length'])
            # choose limited number to plot
            n = min([20, bs])
            n_neighbors = args.num_k
            if sample:
                all_neighbor_distances, all_neighbor_indexes = model_dict[
                    'prior_model'].kneighbors(u_q_flat,
                                              n_neighbors=n_neighbors)
                all_neighbor_indexes = all_neighbor_indexes.cpu().numpy()
                all_neighbor_distances = all_neighbor_distances.cpu().numpy()

                n_cols = 2 + n_neighbors
                tbatch_index = batch_index.cpu().numpy()
                np_label = label.cpu().numpy()
                for i in np.arange(0, n):
                    # plot each base image
                    plt_path = info['model_loadpath'].replace(
                        '.pt', '_batch_rec_neighbors_%s_%06d_plt.png' %
                        (phase, tbatch_index[i]))
                    # bi 5136
                    neighbor_indexes = all_neighbor_indexes[i]
                    code = u_q[i].view(
                        (1, model_dict['acn_model'].bottleneck_channels,
                         model_dict['acn_model'].eo,
                         model_dict['acn_model'].eo)).cpu().numpy()
                    f, ax = plt.subplots(4, n_cols)
                    ax[0,
                       0].set_title('L%sI%s' % (np_label[i], tbatch_index[i]))
                    ax[0, 0].set_ylabel('true')
                    ax[0, 0].matshow(data[i, 0])
                    ax[1, 0].set_ylabel('rec')
                    ax[1, 0].matshow(rec[i, 0])
                    ax[2, 0].matshow(code[0, 0])
                    ax[3, 0].matshow(code[0, 1])

                    neighbor_data = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][0]
                        for index in neighbor_indexes
                    ])
                    neighbor_label = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][1]
                        for index in neighbor_indexes
                    ])
                    # u_q_flat
                    neighbor_codes_flat = model_dict['prior_model'].codes[
                        neighbor_indexes]
                    neighbor_codes = neighbor_codes_flat.view(
                        n_neighbors,
                        model_dict['acn_model'].bottleneck_channels,
                        model_dict['acn_model'].eo, model_dict['acn_model'].eo)
                    if info['vq_decoder']:
                        neighbor_rec_dml, _, _, _ = model_dict[
                            'acn_model'].decode(
                                neighbor_codes.to(info['device']))
                    else:
                        neighbor_rec_dml = model_dict['acn_model'].decode(
                            neighbor_codes.to(info['device']))
                    neighbor_data = neighbor_data.cpu().numpy()
                    neighbor_rec_yhat = sample_from_discretized_mix_logistic(
                        neighbor_rec_dml,
                        info['nr_logistic_mix'],
                        only_mean=info['sample_mean'],
                        sampling_temperature=info['sampling_temperature']).cpu(
                        ).numpy()
                    for ni in range(n_neighbors):
                        nindex = all_neighbor_indexes[i, ni].item()
                        nlabel = neighbor_label[ni].cpu().numpy()
                        ncode = neighbor_codes[ni].cpu().numpy()
                        ax[0, ni + 2].set_title('L%sI%s' % (nlabel, nindex))
                        ax[0, ni + 2].matshow(neighbor_data[ni, 0])
                        ax[1, ni + 2].matshow(neighbor_rec_yhat[ni, 0])
                        ax[2, ni + 2].matshow(ncode[0])
                        ax[3, ni + 2].matshow(ncode[1])
                    ax[2, 0].set_ylabel('lc0')
                    ax[3, 0].set_ylabel('lc1')
                    [ax[xx, 0].set_xticks([]) for xx in range(4)]
                    [ax[xx, 0].set_yticks([]) for xx in range(4)]
                    for xx in range(4):
                        [ax[xx, col].axis('off') for col in range(1, n_cols)]
                    plt.subplots_adjust(wspace=0, hspace=0)
                    plt.tight_layout()
                    print('plotting', plt_path)
                    plt.savefig(plt_path)
                    plt.close()
            X = u_q_flat.cpu().numpy()
            #km = KMeans(n_clusters=10)
            #y = km.fit_predict(X)
            # color points based on clustering, label, or index
            color = label.cpu().numpy()  #y #batch_indexes
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    tsne_plot(X=X,
                              images=data[:, 0],
                              color=color,
                              perplexity=info['perplexity'],
                              html_out_path=html_path,
                              serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    pca_plot(X=X,
                             images=data[:, 0],
                             color=color,
                             html_out_path=html_path,
                             serve=False)
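The prior_model above exposes a kneighbors() lookup over stored codes. A hedged sketch of the same idea using scikit-learn (this repo's prior is a custom model object and may return tensors rather than arrays, as the .cpu().numpy() calls above suggest):

import numpy as np
from sklearn.neighbors import NearestNeighbors

codes = np.random.randn(1000, 800)                 # stand-in for the stored flat codes
prior = NearestNeighbors(n_neighbors=5).fit(codes)
dists, inds = prior.kneighbors(codes[:16])         # neighbors for a query batch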
Code Example #12
def run(train_cnt, model_dict, data_dict, phase, info):
    st = time.time()
    loss_dict = {
        'running': 0,
        'kl': 0,
        'rec_%s' % info['rec_loss_type']: 0,
        'loss': 0,
    }
    if info['vq_decoder']:
        loss_dict['vq'] = 0
        loss_dict['commit'] = 0

    data_loader = data_dict[phase]
    num_batches = len(data_loader)
    print(phase, 'num batches', num_batches)
    set_model_mode(model_dict, phase)
    torch.set_grad_enabled(phase == 'train')
    batch_cnt = 0
    for idx, (data, label, batch_index) in enumerate(data_loader):
        for key in model_dict.keys():
            model_dict[key].zero_grad()
        fp_out = forward_pass(model_dict, data, label, batch_index, phase,
                              info)
        if info['vq_decoder']:
            model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
        else:
            model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
        bs, c, h, w = data.shape
        if batch_cnt == 0:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        if bs != log_ones.shape[0]:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        kl = kl_loss_function(u_q.view(bs, info['code_length']),
                              log_ones,
                              u_p.view(bs, info['code_length']),
                              s_p.view(bs, info['code_length']),
                              reduction=info['reduction'])
        rec_loss = discretized_mix_logistic_loss(
            rec_dml,
            target,
            nr_mix=info['nr_logistic_mix'],
            reduction=info['reduction'])

        if info['vq_decoder']:
            vq_loss = F.mse_loss(z_q_x,
                                 z_e_x.detach(),
                                 reduction=info['reduction'])
            commit_loss = F.mse_loss(z_e_x,
                                     z_q_x.detach(),
                                     reduction=info['reduction'])
            commit_loss *= info['vq_commitment_beta']
            loss_dict['vq'] += vq_loss.detach().cpu().item()
            loss_dict['commit'] += commit_loss.detach().cpu().item()
            loss = kl + rec_loss + commit_loss + vq_loss
        else:
            loss = kl + rec_loss

        loss_dict['running'] += bs
        loss_dict['rec_%s' %
                  info['rec_loss_type']] += rec_loss.detach().cpu().item()
        loss_dict['loss'] += loss.detach().cpu().item()
        loss_dict['kl'] += kl.detach().cpu().item()
        if phase == 'train':
            model_dict = clip_parameters(model_dict)
            loss.backward()
            model_dict['opt'].step()
            train_cnt += bs
        if batch_cnt == num_batches - 1:
            # store example near end for plotting
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            example = {
                'target': data.detach().cpu().numpy(),
                'rec': rec_yhat.detach().cpu().numpy(),
            }
        if not batch_cnt % 100:
            print(train_cnt, batch_cnt, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))
        batch_cnt += 1

    loss_avg = account_losses(loss_dict)
    torch.cuda.empty_cache()
    print("finished %s after %s secs at cnt %s" % (
        phase,
        time.time() - st,
        train_cnt,
    ))
    del data
    del target
    return loss_avg, example
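account_losses is not shown in this excerpt; a plausible minimal version, assuming it averages each accumulated term by the running sample count:

def account_losses(loss_dict):
    # average every accumulated term by the number of samples seen so far
    n = max(loss_dict['running'], 1)
    return {k: v / n for k, v in loss_dict.items() if k != 'running'}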
Code Example #13
File: model.py, Project: zd-daniel/gated_pixelCNN
        return x_out  # tail of a method whose body is truncated in this excerpt


if __name__ == '__main__':
    # img = torch.zeros(8, 3, 32, 32).float().uniform_(-1, 1).cuda()
    # # img = torch.zeros(8, 3, 32, 32).float().cuda()
    # model = PixelCNN(nr_resnet=3, nr_filters=100, input_channels=img.size(1)).cuda()
    # out = model(img)
    #
    # loss = discretized_mix_logistic_loss(img, out)
    # print('loss : %s' % loss.item())

    img = torch.zeros(1, 3, 32, 32).float().cuda()
    model = PixelCNN(nr_resnet=5, nr_filters=160, input_channels=img.size(1), nr_logistic_mix=10).cuda()
    load_part_of_model(model, 'models/pcnn_lr_0.00020_nr-resnet5_nr-filters160_58.pth')
    sample_op = lambda x: sample_from_discretized_mix_logistic(x, 10)

    from tqdm import tqdm
    with tqdm(total=32*32) as pbar:
        for i in range(32):
            for j in range(32):
                with torch.no_grad():
                    data_v = img
                    out = model(data_v, sample=True)
                    out_sample = sample_op(out)
                    img[:, :, i, j] = out_sample.data[:, :, i, j]
                pbar.update(1)
    from torchvision.utils import save_image
    save_image(img.data.cpu(), '1.jpg', nrow=1)
Code Example #14
File: plot_rnn.py, Project: johannah/trajectories
def generate_imgs(dataloader, output_filepath, true_img_path, data_type,
                  transform):
    if not os.path.exists(output_filepath):
        os.makedirs(output_filepath)
    for batch_idx, (data_mu_diff_scaled, data_mu_diff, data_mu_orig,
                    data_sigma_diff_scaled, data_sigma_diff, data_sigma_orig,
                    name) in enumerate(dataloader):
        # data_mu_orig will be one longer than the diff versions
        batch_size = data_mu_diff_scaled.shape[0]
        # predict one less time step than available (the first is input)
        n_timesteps = data_mu_diff_scaled.shape[1]
        vae_input_size = 800
        #######################
        # get rnn details
        #######################

        rnn_data = data_mu_diff_scaled.permute(1, 0, 2)
        seq = Variable(torch.FloatTensor(rnn_data), requires_grad=False)
        h1_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))),
                          requires_grad=False)
        c1_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))),
                          requires_grad=False)
        h2_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))),
                          requires_grad=False)
        c2_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))),
                          requires_grad=False)
        if use_cuda:
            # mus_vae and out_mu are not defined in this excerpt, so only the
            # tensors created above are moved to the GPU
            seq = seq.cuda()
            h1_tm1 = h1_tm1.cuda()
            c1_tm1 = c1_tm1.cuda()
            h2_tm1 = h2_tm1.cuda()
            c2_tm1 = c2_tm1.cuda()
        # get time offsets correct
        x = seq[:-1]
        # put initial step in
        rnn_outputs = [seq[0]]
        gt_outputs = [seq[0]]
        nrnn_outputs = [seq[0].cpu().data.numpy()]
        ngt_outputs = [seq[0].cpu().data.numpy()]
        for i in range(len(x)):
            # number of frames to start with
            #if i < 4:
            output, h1_tm1, c1_tm1, h2_tm1, c2_tm1 = rnn(
                x[i], h1_tm1, c1_tm1, h2_tm1, c2_tm1)
            #else:
            #    output, h1_tm1, c1_tm1, h2_tm1, c2_tm1 = rnn(output, h1_tm1, c1_tm1, h2_tm1, c2_tm1)
            nrnn_outputs += [output.cpu().data.numpy()]
            rnn_outputs += [output]

            # put ground truth in to check pipeline
            ngt_outputs += [seq[i + 1].cpu().data.numpy()]
            gt_outputs += [seq[i + 1]]
            print(output.sum().item(), seq[i + 1].sum().item())

        # vae data should be batch, timestep (example), features
        # 0th frame is the same here
        gt_rnn_pred = torch.stack(gt_outputs, 0)
        rnn_pred = torch.stack(rnn_outputs, 0)

        # 0th frame is the same here
        rnn_mu_diff_scaled = rnn_pred.permute(1, 0, 2).data.numpy()
        gt_rnn_mu_diff_scaled = gt_rnn_pred.permute(1, 0, 2).data.numpy()

        nrnn_mu_diff_scaled = np.swapaxes(np.array(nrnn_outputs), 0, 1)
        ngt_rnn_mu_diff_scaled = np.swapaxes(np.array(ngt_outputs), 0, 1)

        # only use relevant mus
        orig_mu_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))),
                                       requires_grad=False)
        diff_mu_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))),
                                       requires_grad=False)
        diff_mu_unscaled_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))),
                                                requires_grad=False)
        diff_mu_unscaled_rnn_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))),
                                                    requires_grad=False)
        gt_diff_mu_unscaled_rnn_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))),
                                                       requires_grad=False)

        if transform == "std":
            print("removing standard deviation transform")
            # convert to numpy so broadcasting works
            rnn_mu_diff_unscaled = torch.FloatTensor((rnn_mu_diff_scaled *
                                                      mu_diff_std) +
                                                     mu_diff_mean[None])
            gt_rnn_mu_diff_unscaled = torch.FloatTensor(
                (gt_rnn_mu_diff_scaled * mu_diff_std) + mu_diff_mean[None])
            data_mu_diff_unscaled = torch.FloatTensor(
                (data_mu_diff_scaled.numpy() * mu_diff_std) +
                mu_diff_mean[None])
        else:
            print("no transform")
            rnn_mu_diff_unscaled = rnn_mu_diff_scaled
            gt_rnn_mu_diff_unscaled = gt_rnn_mu_diff_scaled
            data_mu_diff_unscaled = data_mu_diff_scaled

        # go through each distinct episode (should be length of 167)
        for e in range(batch_size):
            basename = os.path.split(name[e])[1].replace('.npz', '')
            if not e:
                print("starting %s" % basename)
            basepath = os.path.join(output_filepath, basename)
            # reconstruct rnn vae
            # now the size going through the decoder is 169x32x5x5
            # original data is one longer since there was no diff applied
            ep_mu_orig = data_mu_orig[e, 1:]
            ep_mu_diff = data_mu_diff[e]
            ep_mu_diff_unscaled = data_mu_diff_unscaled[e]
            ep_mu_diff_unscaled_rnn = rnn_mu_diff_unscaled[e]
            gt_ep_mu_diff_unscaled_rnn = gt_rnn_mu_diff_unscaled[e]

            primer_frame = data_mu_orig[e, 0, :]
            # need to reconstruct from original
            # get the first frame from the original dataset to add diffs to
            # data_mu_orig will be one frame longer
            # unscale the scaled version
            ep_mu_diff[0] += primer_frame
            ep_mu_diff_unscaled[0] += primer_frame
            ep_mu_diff_unscaled_rnn[0] += primer_frame
            gt_ep_mu_diff_unscaled_rnn[0] += primer_frame
            print("before diff add")
            for diff_frame in range(1, n_timesteps):
                #print("adding diff to %s" %diff_frame)
                ep_mu_diff[diff_frame] += ep_mu_diff[diff_frame - 1]
                ep_mu_diff_unscaled[diff_frame] += ep_mu_diff_unscaled[
                    diff_frame - 1]
                ep_mu_diff_unscaled_rnn[diff_frame] += ep_mu_diff_unscaled_rnn[
                    diff_frame - 1]
                gt_ep_mu_diff_unscaled_rnn[
                    diff_frame] += gt_ep_mu_diff_unscaled_rnn[diff_frame - 1]

            rnn_mu_img = ep_mu_diff_unscaled_rnn.numpy()
            gt_rnn_mu_img = gt_ep_mu_diff_unscaled_rnn.numpy()
            ff, axf = plt.subplots(1, 2, figsize=(5, 10))
            axf[0].imshow(gt_rnn_mu_img, origin='lower')
            axf[0].set_title("gt_rnn_mu")
            axf[1].imshow(rnn_mu_img, origin='lower')
            axf[1].set_title("rnn_mu")
            ff.tight_layout()
            fimg_name = basepath + '_rnn_mu_plot.png'
            fimg_name = fimg_name.replace('_frame_%05d' % 0, '')
            print("plotted %s" % fimg_name)
            plt.savefig(fimg_name)
            plt.close()

            orig_mu_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_orig))
            diff_mu_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff))
            diff_mu_unscaled_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff_unscaled))
            diff_mu_unscaled_rnn_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff_unscaled_rnn))
            gt_diff_mu_unscaled_rnn_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(gt_ep_mu_diff_unscaled_rnn))
            #for i in range(1,diff_mu_unscaled_rnn_placeholder.shape[0]):
            #    diff_mu_unscaled_rnn_placeholder[i] = gt_diff_mu_unscaled_rnn_placeholder[i]
            # add a placeholder here if you want to process it
            mu_types = OrderedDict([
                ('orig', orig_mu_placeholder),
                #  ('diff',diff_mu_placeholder),
                #  ('diff_unscaled',diff_mu_unscaled_placeholder),
                ('gtrnn', gt_diff_mu_unscaled_rnn_placeholder),
                ('rnn', diff_mu_unscaled_rnn_placeholder),
            ])
            mu_reconstructed = OrderedDict()
            # get reconstructed image for each type
            for xx, mu_output_name in enumerate(mu_types.keys()):
                mu_output = mu_types[mu_output_name]
                cuts = get_cuts(mu_output.shape[0], 1)
                print(mu_output_name,
                      mu_output.sum().item(), mu_output[0].sum().item())
                x_tildes = []
                for (s, e) in cuts:
                    mu_batch = mu_output[s:e]
                    # only put part of the episode through
                    x_d = vae.decoder(mu_batch.contiguous().view(
                        mu_batch.shape[0], 32, 5, 5))
                    x_tilde = sample_from_discretized_mix_logistic(
                        x_d, nr_logistic_mix, deterministic=True)
                    x_tildes.append(x_tilde.cpu().data.numpy())
                nx_tilde = np.array(x_tildes)[:, 0, 0]
                inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
                mu_reconstructed[mu_output_name] = inx_tilde

            for frame_num in range(n_timesteps):
                true_img_name = os.path.join(
                    true_img_path,
                    basename.replace('_conv_vae',
                                     '.png')).replace('frame_%05d' % 0,
                                                      'frame_%05d' % frame_num)
                true_img = imread(true_img_name)
                print("true img %s" % true_img_name)
                num_imgs = len(mu_reconstructed.keys()) + 1
                f, ax = plt.subplots(1, num_imgs, figsize=(3 * num_imgs, 3))

                ax[0].imshow(true_img, origin='lower')
                ax[0].set_title('true frame %04d' % frame_num)

                for ii, mu_output_name in enumerate(mu_reconstructed.keys()):
                    ax[ii + 1].imshow(
                        mu_reconstructed[mu_output_name][frame_num],
                        origin='lower')
                    ax[ii + 1].set_title(mu_output_name)

                f.tight_layout()
                img_name = basepath + '_rnn_plot.png'
                img_name = img_name.replace('frame_%05d' % 0,
                                            'frame_%05d' % frame_num)
                print("plotted %s" % img_name)
                plt.savefig(img_name)
                plt.close()