Example #1
def valid_vqvae(train_cnt, do_plot=False):
    vqvae_model.eval()
    #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_unique_minibatch()
    states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_framediff_minibatch()
    # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
    states = (2 * reshape_input(states) - 1).to(DEVICE)
    rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE)
    diff = (2 * reshape_input(pred_states[:, 1][:, None]) - 1).to(DEVICE)
    actions = actions.to(DEVICE)
    values = values.to(DEVICE)
    x_d, z_e_x, z_q_x, latents, pred_actions, pred_values = vqvae_model(states)
    # (args.nr_logistic_mix/2)*3 is needed for each reconstruction
    z_q_x.retain_grad()
    rec_est = x_d[:, :nmix]
    diff_est = x_d[:, nmix:]
    loss_rec = discretized_mix_logistic_loss(rec_est,
                                             rec,
                                             nr_mix=args.nr_logistic_mix,
                                             DEVICE=DEVICE)
    loss_diff = discretized_mix_logistic_loss(diff_est,
                                              diff,
                                              nr_mix=args.nr_logistic_mix,
                                              DEVICE=DEVICE)
    loss_act = F.nll_loss(pred_actions, actions)
    loss_act.backward(retain_graph=True)
    loss_values = args.ralpha * F.mse_loss(pred_values, values)
    loss_values.backward(retain_graph=True)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
    bs, yc, yh, yw = x_d.shape
    yhat = sample_from_discretized_mix_logistic(rec_est, args.nr_logistic_mix)
    if do_plot:
        print('writing img')
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        gold = (rec.to('cpu') + 1) / 2.0
        bs, _, h, w = gold.shape
        # sample from discretized should be between 0 and 255
        print("yhat sample", yhat[:, 0].min().item(), yhat[:, 0].max().item())
        yimg = ((yhat + 1.0) / 2.0).to('cpu')
        print("yhat img", yhat.min().item(), yhat.max().item())
        print("gold img", gold.min().item(), gold.max().item())
        comparison = torch.cat(
            [gold.view(bs, 1, h, w)[:n],
             yimg.view(bs, 1, h, w)[:n]])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    loss_list = [
        loss_values.item() / bs,
        loss_act.item() / bs,
        loss_rec.item() / bs,
        loss_diff.item() / bs,
        loss_2.item() / bs,
        loss_3.item() / bs
    ]
    return loss_list
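Note: for orientation, a minimal sketch of how a validation helper like this might be driven on a save cadence. Every name except valid_vqvae (last_save, save_every, avg_train_losses, handle_plot_ckpt) is an assumption here, mirroring the pattern in Example #12, not part of this example:

# hypothetical driver; all names except valid_vqvae are assumptions
if (train_cnt - last_save) >= save_every:
    last_save = train_cnt
    avg_valid_losses = valid_vqvae(train_cnt, do_plot=True)
    handle_plot_ckpt(train_cnt, avg_train_losses, avg_valid_losses)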
Example #2
def valid_vqvae(train_cnt, vqvae_model, info, valid_data_loader, do_plot=True):
    vqvae_model.eval()
    states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_framediff_minibatch()
    states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(
        info['DEVICE'])
    rec = (2 * reshape_input(torch.FloatTensor(pred_states)[:, 0][:, None]) -
           1).to(info['DEVICE'])
    actions = torch.LongTensor(actions).to(info['DEVICE'])
    #rewards = torch.LongTensor(rewards).to(DEVICE)
    # don't normalize diff
    diff = (reshape_input(torch.FloatTensor(pred_states)[:, 1][:, None])).to(
        info['DEVICE'])
    x_d, z_e_x, z_q_x, latents, pred_actions = vqvae_model(states)
    z_q_x.retain_grad()
    rec_est = x_d[:, :info['nmix']]
    diff_est = x_d[:, info['nmix']:]
    loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
        rec_est, rec, info['NR_LOGISTIC_MIX'], DEVICE=info['DEVICE'])
    loss_diff = discretized_mix_logistic_loss(diff_est,
                                              diff,
                                              nr_mix=info['NR_LOGISTIC_MIX'],
                                              DEVICE=info['DEVICE'])
    loss_act = info['ALPHA_ACT'] * F.nll_loss(
        pred_actions, actions, weight=info['actions_weight'])
    loss_act.backward(retain_graph=True)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())
    bs, yc, yh, yw = x_d.shape
    yhat = sample_from_discretized_mix_logistic(rec_est,
                                                info['NR_LOGISTIC_MIX'])
    if do_plot:
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        gold = (rec.to('cpu') + 1) / 2.0
        bs, _, h, w = gold.shape
        # sample from discretized should be between 0 and 255
        print("yhat sample", yhat[:, 0].min().item(), yhat[:, 0].max().item())
        yimg = ((yhat + 1.0) / 2.0).to('cpu')
        print("yhat img", yhat.min().item(), yhat.max().item())
        print("gold img", gold.min().item(), gold.max().item())
        comparison = torch.cat(
            [gold.view(bs, 1, h, w)[:n],
             yimg.view(bs, 1, h, w)[:n]])
        img_name = info['vq_model_base_filepath'] + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    loss_list = [
        loss_act.item() / bs,
        loss_rec.item() / bs,
        loss_diff.item() / bs,
        loss_2.item() / bs,
        loss_3.item() / bs
    ]
    return loss_list
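Note: the info dict this example reads is not defined anywhere in these snippets; below is a hypothetical configuration covering the keys it looks up. The values are placeholders, and nmix follows the "(nr_logistic_mix / 2) * 3 is needed for each reconstruction" comment from Example #1:

# hypothetical configuration; keys mirror the lookups above, values are made up
info = {
    'DEVICE': 'cuda',
    'NR_LOGISTIC_MIX': 10,
    'nmix': int((10 / 2) * 3),  # channels needed per reconstruction
    'ALPHA_REC': 1.0,
    'ALPHA_ACT': 1.0,
    'BETA': 0.25,
    'actions_weight': None,
    'vq_model_base_filepath': 'models/vqvae',
}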
Example #3
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states) - 1).to(DEVICE)
        rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE)
        # don't normalize diff
        diff = (reshape_input(pred_states[:, 1][:, None])).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states)
        # (args.nr_logistic_mix/2)*3 is needed for each reconstruction
        z_q_x.retain_grad()
        rec_est = x_d[:, :nmix]
        diff_est = x_d[:, nmix:]
        loss_rec = discretized_mix_logistic_loss(rec_est,
                                                 rec,
                                                 nr_mix=args.nr_logistic_mix,
                                                 DEVICE=DEVICE)
        loss_diff = discretized_mix_logistic_loss(diff_est,
                                                  diff,
                                                  nr_mix=args.nr_logistic_mix,
                                                  DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        loss_list = [
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > 5:
            handle_checkpointing(train_cnt, loss_list)
        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
def sample_batch(data, episode_number, episode_reward, name):
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        x = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        for i in range(states.shape[0]):
            x_d, z_e_x, z_q_x, latents = vqvae_model(x[i:i + 1])
            loss_1 = discretized_mix_logistic_loss(x_d, x[i:i + 1], nr_mix=largs.nr_logistic_mix, DEVICE=DEVICE)
            yhat = sample_from_discretized_mix_logistic(x_d, largs.nr_logistic_mix)
            # np.int was removed from numpy; the builtin int dtype behaves the same here
            yhat = (((yhat + 1) / 2.0) * 255.0).cpu().numpy().astype(int)
            true = (states[i:i + 1, -1:] * 255.0).cpu().numpy().astype(int)
            f, ax = plt.subplots(1, 2)
            iname = os.path.join(output_savepath, '%s_E%05d_R%03d_%05d.png' % (name, int(episode_number), int(episode_reward), i))
            print("writing", os.path.split(iname)[1])
            title = 'step %s/%s action %s reward %s' % (i, states.shape[0], actions[i].item(), rewards[i].item())
            ax[0].imshow(true[0, 0])
            ax[0].set_title('true')
            ax[1].imshow(yhat[0, 0])
            ax[1].set_title('est')
            plt.suptitle(title)
            plt.savefig(iname)
            # close the figure so long episodes don't accumulate open figures
            plt.close(f)
            print('saving', iname)
        # strip the '_%05d.png' suffix to glob all frames for this episode
        search_path = iname[:-10] + '*.png'
        gif_path = iname[:-10] + '.gif'
        cmd = 'convert %s %s' % (search_path, gif_path)
        print('creating gif', gif_path)
        os.system(cmd)
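Note: the gif step above shells out to ImageMagick's convert. A hypothetical pure-Python variant using imageio (an assumption, not part of the original):

import glob
import imageio

# zero-padded frame indexes make lexical sort equal to temporal order
frames = [imageio.imread(p) for p in sorted(glob.glob(search_path))]
imageio.mimsave(gif_path, frames)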
def valid_vqvae(train_cnt, do_plot=False):
    vqvae_model.eval()
    opt.zero_grad()
    states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_unique_minibatch()
    # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
    states = reshape_input(states).to(DEVICE)
    # only predict future observation - normalize
    targets = (2 * states[:, -1:] - 1).to(DEVICE)
    #actions = actions.to(DEVICE)
    x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets)
    loss_1 = discretized_mix_logistic_loss(x_d,
                                           targets,
                                           nr_mix=args.nr_logistic_mix,
                                           DEVICE=DEVICE)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
    #loss_1, loss_2, loss_3 = get_vqvae_loss(x_d, targets, z_e_x, z_q_x, nr_logistic_mix=args.nr_logistic_mix, beta=args.beta, device=DEVICE)
    bs, yc, yh, yw = x_d.shape
    yhat = sample_from_discretized_mix_logistic(x_d, args.nr_logistic_mix)
    if do_plot:
        print('writing img')
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        gold = states[:, -1:]
        bs, _, h, w = gold.shape
        comparison = torch.cat([
            gold.to('cpu').view(bs, 1, h, w)[:n],
            yhat.to('cpu').view(bs, 1, h, w)[:n]
        ])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    return loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs
def valid_vqvae(train_cnt, do_plot=False):
    vqvae_model.eval()
    states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_unique_minibatch()
    # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
    states = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
    x_d, z_e_x, z_q_x, latents = vqvae_model(states)
    z_q_x.retain_grad()
    loss_1 = discretized_mix_logistic_loss(x_d,
                                           states,
                                           nr_mix=args.nr_logistic_mix,
                                           DEVICE=DEVICE)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
    bs, yc, yh, yw = x_d.shape
    yhat = sample_from_discretized_mix_logistic(x_d, args.nr_logistic_mix)
    if do_plot:
        print('writing img')
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        gold = (states.to('cpu') + 1) / 2.0
        bs, _, h, w = gold.shape
        # sample from discretized should be between 0 and 255
        print("yhat sample", yhat.min(), yhat.max())
        yimg = ((yhat + 1.0) / 2.0).to('cpu')
        print("yhat img", yhat.min().item(), yhat.max().item())
        print("gold img", gold.min().item(), gold.max().item())
        comparison = torch.cat(
            [gold.view(bs, 1, h, w)[:n],
             yimg.view(bs, 1, h, w)[:n]])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    return loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs
def forward_pass(x, y):
    # Variable is deprecated in modern PyTorch; plain tensors carry autograd state
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    x_d, z_e_x, z_q_x, latents = vmodel(x)
    # with bigger model - latents is 64, 6, 6
    z_q_x.retain_grad()
    #loss_1 = F.binary_cross_entropy(x_d, x)
    # y comes in between 0 and 1; rescale to [-1, 1] before the dml loss
    loss_1 = discretized_mix_logistic_loss(x_d, 2 * y - 1, DEVICE=DEVICE)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = .25 * F.mse_loss(z_e_x, z_q_x.detach())
    return loss_1, loss_2, loss_3, x_d, z_e_x, z_q_x, latents
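Note: loss_2 and loss_3 here are the standard VQ-VAE codebook and commitment terms. A self-contained toy illustration follows; the shapes and the rounding stand-in for the nearest-codebook lookup are assumptions:

import torch
import torch.nn.functional as F

z_e_x = torch.randn(4, 64, 6, 6, requires_grad=True)  # encoder output
z_q_x = z_e_x.detach().round().requires_grad_(True)   # stand-in for the codebook lookup
codebook_loss = F.mse_loss(z_q_x, z_e_x.detach())           # loss_2: pulls codes toward encoder outputs
commitment_loss = 0.25 * F.mse_loss(z_e_x, z_q_x.detach())  # loss_3: keeps the encoder near its codes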
Example #8
def find_rec_losses(alpha, nr, nmix, x_d, true, DEVICE):
    rec_losses = []
    rec_ests = []
    # get reconstruction losses for each channel
    for i in range(true.shape[1]):
        st = i * nmix
        en = st + nmix
        pred_x_d = x_d[:, st:en]
        rec_ests.append(pred_x_d.detach())
        rloss = alpha * discretized_mix_logistic_loss(
            pred_x_d, true[:, i][:, None], nr_mix=nr, DEVICE=DEVICE)
        rec_losses.append(rloss)
    return rec_losses, rec_ests
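Note: a hypothetical call, mirroring how Example #2 slices x_d into per-channel blocks of nmix channels; all names are borrowed from the surrounding examples:

rec_losses, rec_ests = find_rec_losses(alpha=info['ALPHA_REC'],
                                       nr=info['NR_LOGISTIC_MIX'],
                                       nmix=info['nmix'],
                                       x_d=x_d,
                                       true=targets,
                                       DEVICE=info['DEVICE'])
loss_rec = sum(rec_losses)  # one dml loss per target channel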
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        # the commented-out vqenc/pcnn path below is stale; put the live model in train mode
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = reshape_input(states).to(DEVICE)
        # only predict future observation - normalize
        targets = (2 * states[:, -1:] - 1).to(DEVICE)
        #actions = actions.to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets)
        #z_e_x, z_q_x, latents = vqenc(states)
        #float_condition = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2]).float()
        #x_d = pcnn_decoder(targets, class_condition=actions, float_condition=float_condition)
        z_q_x.retain_grad()
        vqvae_model.spatial_condition.retain_grad()
        loss_1 = discretized_mix_logistic_loss(x_d,
                                               targets,
                                               nr_mix=args.nr_logistic_mix,
                                               DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        #loss_1, loss_2, loss_3 = get_vqvae_loss(x_d, targets, z_e_x, z_q_x, nr_logistic_mix=args.nr_logistic_mix, beta=args.beta, device=DEVICE)
        loss_1.backward(retain_graph=True)
        #vqvae_model.encoder.embedding.zero_grad()
        #z_e_x.backward(z_q_x.grad, retain_graph=True)
        z_e_x.backward(vqvae_model.spatial_condition.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()

        bs = float(x_d.shape[0])
        handle_checkpointing(train_cnt,
                             loss_1.item() / bs,
                             loss_2.item() / bs,
                             loss_3.item() / bs)
        train_cnt += len(states)

        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
Example #10
def train_acn(train_cnt):
    #test_acn(0,True)
    vae_model.train()
    prior_model.train()
    train_loss = 0
    init_cnt = train_cnt
    st = time.time()
    train_buffer.reset_unique()
    #for batch_idx, (data, _, data_index) in enumerate(train_loader):
    while train_buffer.unique_available:
        #
        batch = train_buffer.get_unique_minibatch(args.batch_size)
        batch_idx = batch[-1]
        states, actions, rewards, next_states = make_state(
            batch[:-1], DEVICE, 255.)
        data = next_states[:, -1:]
        opt.zero_grad()
        z, u_q, s_q = vae_model(data)
        # add the predicted codes to the input
        prior_model.codes[batch_idx] = u_q.detach().cpu().numpy()
        prior_model.fit_knn(prior_model.codes)
        u_p, s_p = prior_model(u_q)
        kl = kl_loss_function(u_q, s_q, u_p, s_p)
        # decoder changed output of pcnn to number of channels needed for dml
        yhat_batch = vae_model.decoder(pcnn_decoder(x=data, float_condition=z))
        # input should be between -1 and 1
        rec_loss = discretized_mix_logistic_loss(yhat_batch,
                                                 data,
                                                 nr_mix=nr_logistic_mix,
                                                 DEVICE=DEVICE)
        #yhat = sample_from_discretized_mix_logistic(yhat_batch, nr_logistic_mix)
        loss = kl + rec_loss
        loss.backward()
        train_loss += loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_loss = train_loss / float((train_cnt + data.shape[0]) -
                                            init_cnt)
        if train_cnt > 50000:
            handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print("finished epoch after %s seconds at cnt %s" %
          (time.time() - st, train_cnt))
    return train_cnt
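Note: kl_loss_function is not defined in these snippets. A minimal sketch consistent with the call above, assuming u_q/u_p are means and s_q/s_p are log-variances of diagonal Gaussians:

import torch

def kl_loss_function(u_q, s_q, u_p, s_p):
    # KL(q || p) for diagonal Gaussians, summed over the batch
    return 0.5 * torch.sum(
        s_p - s_q + (s_q.exp() + (u_q - u_p).pow(2)) / s_p.exp() - 1)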
def train_vqvae(train_cnt):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states)
        z_q_x.retain_grad()
        loss_1 = discretized_mix_logistic_loss(x_d,
                                               states,
                                               nr_mix=args.nr_logistic_mix,
                                               DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        loss_1.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        handle_checkpointing(train_cnt,
                             loss_1.item() / bs,
                             loss_2.item() / bs,
                             loss_3.item() / bs)
        train_cnt += len(states)

        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
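Note: the vqvae_model.embedding.zero_grad() / z_e_x.backward(z_q_x.grad) pair above is the straight-through trick: gradients the reconstruction loss left on the quantized codes are copied back onto the encoder output, skipping the non-differentiable nearest-neighbor lookup. A self-contained toy version (shapes are assumptions):

import torch

e = torch.randn(2, 8, requires_grad=True)    # stands in for z_e_x
q = e.detach().clone().requires_grad_(True)  # stands in for z_q_x
loss = q.pow(2).sum()                        # stands in for the reconstruction loss
loss.backward()                              # fills q.grad only; e.grad stays None
e.backward(q.grad)                           # copies the gradient past the quantization step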
Example #12
def train_vqvae(train_cnt, vqvae_model, opt, info, train_data_loader,
                valid_data_loader):
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < info['VQ_NUM_EXAMPLES_TO_TRAIN']:
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(
            info['DEVICE'])
        rec = (
            2 * reshape_input(torch.FloatTensor(pred_states)[:, 0][:, None]) -
            1).to(info['DEVICE'])
        actions = torch.LongTensor(actions).to(info['DEVICE'])
        rewards = torch.LongTensor(rewards).to(info['DEVICE'])
        # don't normalize diff
        diff = (reshape_input(
            torch.FloatTensor(pred_states)[:, 1][:, None])).to(info['DEVICE'])
        x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
            states)
        z_q_x.retain_grad()
        rec_est = x_d[:, :info['nmix']]
        diff_est = x_d[:, info['nmix']:]
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est,
            rec,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])
        loss_diff = discretized_mix_logistic_loss(
            diff_est,
            diff,
            nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])

        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_rewards = info['ALPHA_REW'] * F.nll_loss(
            pred_rewards, rewards, weight=info['rewards_weight'])
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())

        loss_act.backward(retain_graph=True)
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)

        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()

        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        avg_train_losses = [
            loss_rewards.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > info['VQ_MIN_BATCHES_BEFORE_SAVE']:
            if (train_cnt - info['vq_last_save']) >= info['VQ_SAVE_EVERY']:
                # record the gap before overwriting vq_last_save, otherwise it always prints 0
                cnt_since_last_save = train_cnt - info['vq_last_save']
                info['vq_last_save'] = train_cnt
                info['vq_save_times'].append(time.time())
                avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info,
                                               valid_data_loader)
                handle_plot_ckpt(train_cnt, info, avg_train_losses,
                                 avg_valid_losses)
                filename = info['vq_model_base_filepath'] + "_%010dex.pt" % train_cnt
                print("SAVING MODEL:%s" % filename)
                print("Saving model at cnt:%s cnt since last saved:%s" %
                      (train_cnt, cnt_since_last_save))
                state = {
                    'vqvae_state_dict': vqvae_model.state_dict(),
                    'vq_optimizer': opt.state_dict(),
                    'vq_embedding': vqvae_model.embedding,
                    'vq_info': info,
                }
                save_checkpoint(state, filename=filename)

        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
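Note: save_checkpoint presumably wraps torch.save; a matching reload sketch for the state dict written above (map_location is an assumption):

# hypothetical reload of the checkpoint saved above
state = torch.load(filename, map_location=info['DEVICE'])
vqvae_model.load_state_dict(state['vqvae_state_dict'])
opt.load_state_dict(state['vq_optimizer'])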
Example #13
def test_acn(train_cnt, do_plot):
    vae_model.eval()
    prior_model.eval()
    test_loss = 0
    print('starting test', train_cnt)
    st = time.time()
    seen = 0
    with torch.no_grad():
        valid_buffer.reset_unique()
        for i in range(10):
            if valid_buffer.unique_available:
                batch = valid_buffer.get_unique_minibatch(args.batch_size)
                batch_idx = batch[-1]
                states, actions, rewards, next_states = make_state(
                    batch[:-1], DEVICE, 255.)
                data = next_states[:, -1:]
                # yhat_batch is between 0 and 1
                z, u_q, s_q = vae_model(data)
                u_p, s_p = prior_model(u_q)
                kl = kl_loss_function(u_q, s_q, u_p, s_p)
                yhat_batch = vae_model.decoder(
                    pcnn_decoder(x=data, float_condition=z))
                rec_loss = discretized_mix_logistic_loss(
                    yhat_batch, data, nr_mix=nr_logistic_mix, DEVICE=DEVICE)
                loss = kl + rec_loss
                test_loss += loss.item()
                seen += data.shape[0]
                if i == 0:
                    if do_plot:
                        print('writing img')
                        n = min(data.size(0), 8)
                        bs = data.shape[0]
                        yhat = sample_from_discretized_mix_logistic(
                            yhat_batch, nr_logistic_mix, only_mean=True)
                        # sampled yhat is between -1 and 1; rescale to 0-1
                        #yimg = yhat_batch
                        yimg = ((yhat + 1.0) / 2.0)
                        print('data', data.max(), data.min())
                        # rescale gold the same way so both plot on a 0-1 range
                        gold = (data + 1) / 2.0
                        #gold = data
                        print('bef', yhat_batch.max(), yhat_batch.min())
                        #print('sam', yhat.max(), yhat.min())
                        print('yimg', yimg.max(), yimg.min())
                        print('gold', gold.max(), gold.min())
                        bs, _, h, w = data.shape
                        # data should be between 0 and 1 to be plotted with
                        # save_image
                        assert (yimg.min() >= 0)
                        assert (yimg.max() <= 1)
                        comparison = torch.cat([
                            gold.view(bs, 1, h, w)[:n],
                            yimg.view(bs, 1, h, w)[:n]
                        ])
                        img_name = vae_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
                        save_image(comparison.cpu(), img_name, nrow=n)
                        print('finished writing img', img_name)

    test_loss /= seen
    print('====> Test set loss: {:.4f}'.format(test_loss))
    print('finished test', time.time() - st)
    return test_loss