def call_plot(model_dict, data_dict, info):
    """Write t-SNE and/or PCA HTML plots of the flattened ``u_q`` codes.

    The models are switched to 'valid' mode, one unique minibatch of up to
    84 samples is pushed through ``forward_pass`` and the resulting ``u_q``
    tensor is reshaped to ``(batch, info['code_length'])`` before plotting.
    Which plots are produced is decided by ``args.tsne`` and ``args.pca``
    (``args`` is resolved from the enclosing module, not passed in).

    Output files are named after ``info['model_loadpath']`` with the '.pt'
    suffix replaced by a plot-specific suffix.  Nothing is returned.
    """
    from acn_utils import tsne_plot
    from acn_utils import pca_plot
    # plotting is always done with the models in evaluation mode
    model_dict = set_model_mode(model_dict, 'valid')
    with torch.no_grad():
        for phase in ['valid', 'train']:
            loader = data_dict[phase]
            loader.reset_unique()
            minibatch = loader.get_unique_minibatch(84)
            states, actions, rewards, next_states, _, _, _, index_indexes = minibatch
            outputs = forward_pass(model_dict, states, actions, rewards,
                                   next_states, index_indexes, phase, info)
            # forward_pass returns 13 values; only a few are needed here
            model_dict, states, _, _, target, u_q, _, _, _, _, _, _, _ = outputs
            n_samples = states.shape[0]
            X = u_q.view(n_samples, info['code_length']).cpu().numpy()
            color = index_indexes
            # first channel of the target is shown as the thumbnail image
            images = target[:,0].cpu().numpy()
            if args.tsne:
                suffix = '_tsne_%s_P%s.html'%(phase, info['perplexity'])
                tsne_plot(X=X, images=images, color=color,
                          perplexity=info['perplexity'],
                          html_out_path=info['model_loadpath'].replace('.pt', suffix),
                          serve=False)
            if args.pca:
                suffix = '_pca_%s.html'%(phase)
                pca_plot(X=X, images=images, color=color,
                         html_out_path=info['model_loadpath'].replace('.pt', suffix),
                         serve=False)
            # NOTE(review): this break stops after the first phase, so the
            # 'train' entry above is never plotted - confirm this is intended.
            break
Beispiel #2
0
def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    """Save reconstruction previews and optional t-SNE / PCA embedding plots.

    For each phase ('valid', then 'train') one unique minibatch of up to 84
    samples is pushed through ``forward_pass``.  A PNG with up to 10 rows of
    (previous frame, true target, sampled reconstruction) is written next to
    ``info['model_loadpath']``, followed by the requested HTML embeddings of
    the ``u_q`` codes reshaped to ``(batch, info['code_length'])``.

    Args:
        model_dict: models, forwarded to ``set_model_mode``/``forward_pass``.
        data_dict: maps a phase name to its data loader.
        info: settings dict; this function reads 'nr_logistic_mix',
            'sample_mean', 'sampling_temperature', 'model_loadpath',
            'code_length' and 'perplexity'.
        sample: not used by this function; kept so existing callers work.
        tsne: when truthy, write the t-SNE HTML plot.
        pca: when truthy, write the PCA HTML plot.
    """
    from acn_utils import tsne_plot
    from acn_utils import pca_plot

    def as_image(frame):
        # matplotlib cannot read device tensors, so move them to host memory
        if torch.is_tensor(frame):
            return frame.detach().cpu().numpy()
        return frame

    # always be in eval mode - so we dont swap neighbors
    model_dict = set_model_mode(model_dict, 'valid')
    with torch.no_grad():
        for phase in ['valid', 'train']:
            data_loader = data_dict[phase]
            data_loader.reset_unique()
            batch = data_loader.get_unique_minibatch(84)
            states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
            fp_out = forward_pass(model_dict, states, actions, rewards,
                                  next_states, index_indexes, phase, info)
            model_dict, states, actions, rewards, target, u_q, u_p, s_p, rec_dml, z_e_x, z_q_x, latents = fp_out
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            bs = states.shape[0]
            # never index past the end of a short minibatch
            n_rows = min(10, bs)
            # squeeze=False keeps ax two-dimensional even for a single row
            f, ax = plt.subplots(n_rows, 3, squeeze=False)
            ax[0, 0].set_title('prev')
            ax[0, 1].set_title('true')
            ax[0, 2].set_title('pred')
            for i in range(n_rows):
                for col, frames in enumerate((states, target, rec_yhat)):
                    # last channel of each sample is displayed
                    ax[i, col].matshow(as_image(frames[i, -1]))
                for col in range(3):
                    ax[i, col].axis('off')
            plt.subplots_adjust(wspace=0, hspace=0)
            plt.tight_layout()

            plt_path = info['model_loadpath'].replace('.pt',
                                                      '_%s_plt.png' % phase)
            print('plotting', plt_path)
            plt.savefig(plt_path)
            # release the figure so one is not leaked per phase
            plt.close(f)

            u_q_flat = u_q.view(bs, info['code_length'])
            X = u_q_flat.cpu().numpy()
            color = index_indexes
            images = target[:, 0].cpu().numpy()
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                tsne_plot(X=X,
                          images=images,
                          color=color,
                          perplexity=info['perplexity'],
                          html_out_path=html_path,
                          serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                pca_plot(X=X,
                         images=images,
                         color=color,
                         html_out_path=html_path,
                         serve=False)
Beispiel #3
0
def load_uvdeconv_representation_model(representation_model_path):
    """Load a uvdeconv representation model and put it in 'valid' mode.

    The model is expected to have been trained with
    ../models/train_atari_uvdeconv_tacn_midtwgradloss.py, which outputs both
    an acn flat float representation and a vq discrete representation
    (open question from the original author: which one to use?).

    ``device`` and ``args`` are read from the enclosing module and handed to
    ``create_models`` with ``load_data=False``.

    Returns:
        (rep_model_dict, rep_info, prepare_uv_state_latents, rescale,
        rescale_inv), where the last two come from ``create_models``.
    """
    created = create_models({'device': device, 'args': args},
                            representation_model_path,
                            load_data=False)
    # create_models yields 7 values; the training counters are not needed
    rep_model_dict, _, rep_info, _, _, rescale, rescale_inv = created
    rep_model_dict = set_model_mode(rep_model_dict, 'valid')
    return rep_model_dict, rep_info, prepare_uv_state_latents, rescale, rescale_inv
def sample(model_dict, data_dict, info):
    """Plot target, teacher-forced and pixel-by-pixel sampled frames.

    The models are put in 'valid' mode.  For each phase ('train', then
    'valid') a unique minibatch of 10 samples is run through
    ``forward_pass``; the ``pcnn_decoder_model`` is then called once per
    pixel of a zero-initialised canvas.  A figure with one row per sample and
    three columns ('target', 'tf', 'sam') is saved to
    ``<model_loadpath without .pt>_zc_st<temperature>_sample_<phase>.png``.

    ``args``, ``rescale`` and ``rescale_inv`` are not parameters; they are
    resolved from the enclosing module.  Nothing is returned.
    """
    # NOTE(review): vwrite is imported but never used in this function.
    from skvideo.io import vwrite
    model_dict = set_model_mode(model_dict, 'valid')
    output_savepath = args.model_loadpath.replace('.pt', '')
    bs = 10
    with torch.no_grad():
        for phase in ['train', 'valid']:
            # NOTE(review): this inner no_grad duplicates the enclosing one.
            with torch.no_grad():
                data_loader = data_dict[phase]
                data_loader.reset_unique()
                batch = data_loader.get_unique_minibatch(bs)
                states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
                # '_zc' tags the output file name (the canvas below starts as zeros)
                st_can = '_zc'

                iname = output_savepath + st_can + '_st%s'%args.sampling_temperature + '_sample_%s.png'%phase
                fp_out = forward_pass(model_dict, states, actions, rewards, next_states, index_indexes, phase, info)
                prep_batch = make_atari_channel_action_reward_diff_state(states, actions, rewards, next_states,
                                                                             info['device'],
                                                                             info['num_actions'], info['num_rewards'])

                # states/target from prep_batch are only used to build `last`;
                # both names are rebound from fp_out a few lines below
                states, action_cond, reward_cond, target = prep_batch
                last = states[:,-1:]
                rlast = rescale_inv(last)
                model_dict, states, actions, rewards, target, u_q, u_p, s_p, rec_dml, pcnn_dml, z_e_x, z_q_x, latents = fp_out
                # teacher forced version
                pcnn_yhat = sample_from_discretized_mix_logistic(pcnn_dml, info['nr_logistic_mix'], only_mean=args.sample_mean, sampling_temperature=args.sampling_temperature)
                rec_yhat = sample_from_discretized_mix_logistic(rec_dml, info['nr_logistic_mix'], only_mean=args.sample_mean, sampling_temperature=args.sampling_temperature)
                # create blank canvas for autoregressive sampling
                nprlast = rlast.detach().cpu().numpy()
                # NOTE(review): np_target is not read again in this function.
                np_target = target.detach().cpu().numpy()
                #np_rec_yhat = nprlast+rec_yhat.detach().cpu().numpy()
                #canvas = deconv_yhat_batch
                print('using zero output as sample canvas')
                canvas = torch.zeros_like(target)
                # one decoder call per (channel, row, column) position; each
                # call sees the canvas entries written by earlier positions
                for i in range(canvas.shape[1]):
                    for j in range(canvas.shape[2]):
                        print('sampling row: %s'%j)
                        for k in range(canvas.shape[3]):
                            #output = model_dict['pcnn_decoder_model'](x=canvas, spatial_condition=rec_dml)
                            output = model_dict['pcnn_decoder_model'](x=canvas, spatial_condition=rec_dml)
                            output = sample_from_discretized_mix_logistic(output.detach(), info['nr_logistic_mix'], only_mean=args.sample_mean, sampling_temperature=args.sampling_temperature)
                            # canvas entry is the rescaled (last frame - un-rescaled sample)
                            canvas[:,i,j,k] = rescale(rlast[:,i,j,k]-rescale_inv(output[:,i,j,k]))
                            # debug print for the first sample wherever its target is not -1
                            if target[0,i,j,k] != -1:
                                print(j,k,output[0,i,j,k], canvas[0,i,j,k], target[0,i,j,k])

                f,ax = plt.subplots(bs, 3, sharex=True, sharey=True, figsize=(3,bs))
                #np_output = output.detach().cpu().numpy()
                #np_output = nprlast+rescale_inv(output.detach().cpu().numpy())
                np_pcnn_yhat = pcnn_yhat.detach().cpu().numpy()
                true_diff = target.detach().cpu().numpy()
                # NOTE(review): `output` here is the sample drawn in the very last
                # loop iteration, not the accumulated canvas - confirm intended.
                pred_diff = output.detach().cpu().numpy()
                #pred_diff = canvas.detach().cpu().numpy()
                np_pred = nprlast+rescale_inv(pred_diff)
                np_true = nprlast+rescale_inv(true_diff)
                np_tf = nprlast+rescale_inv(np_pcnn_yhat)
                # NOTE(review): multiplying by (abs(x) > 0) leaves the values
                # unchanged (zeros stay zero) - presumably a masked array was meant.
                ma_true = true_diff * (np.abs(true_diff) > 0)
                ma_pred = pred_diff * (np.abs(pred_diff) > 0)
                ma_tf = np_pcnn_yhat * (np.abs(np_pcnn_yhat) > 0)

                # each panel: gray frame with the diff drawn on top at 50% alpha
                for idx in range(bs):
                    ax[idx,0].matshow(np_true[idx,0], cmap=plt.cm.gray)
                    ax[idx,0].imshow(ma_true[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                    ax[idx,1].matshow(np_tf[idx,0], cmap=plt.cm.gray)
                    ax[idx,1].imshow(ma_tf[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                    ax[idx,2].matshow(np_pred[idx,0], cmap=plt.cm.gray)
                    ax[idx,2].imshow(ma_pred[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                    ax[idx,0].axis('off')
                    ax[idx,1].axis('off')
                    ax[idx,2].axis('off')
                ax[0,0].set_title('target')
                ax[0,1].set_title('tf')
                ax[0,2].set_title('sam')
                print('plotting %s'%iname)
                plt.savefig(iname)
                plt.close()
def run(train_cnt, model_dict, data_dict, phase, info):
    """Run one pass over the unique indexes of ``data_dict[phase]``.

    For every minibatch the KL, PixelCNN reconstruction, VQ and commitment
    losses are computed and accumulated; when ``phase == 'train'`` the summed
    loss is back-propagated and ``model_dict['opt']`` takes a step.

    Args:
        train_cnt: running sample count.  It is incremented locally for the
            progress printout only; the updated value is not returned.
        model_dict: models plus the optimizer under the key 'opt'.  Every
            entry must provide ``zero_grad()``.
        data_dict: maps a phase name to its data loader.
        phase: 'train' enables gradients and optimizer steps; any other
            value runs with gradients disabled.
        info: settings dict; this function reads 'rec_loss_type',
            'batch_size', 'code_length', 'device', 'reduction',
            'nr_logistic_mix', 'vq_commitment_beta', 'sample_mean' and
            'sampling_temperature'.

    Returns:
        (loss_avg, example): ``loss_avg`` is ``account_losses(loss_dict)``.
        ``example`` is a dict of un-rescaled CPU tensors ('prev_frame',
        'target', 'deconv_yhat', 'pcnn_yhat') taken from one batch near the
        end of the pass, or None when the loader produced no batch at all.
    """
    st = time.time()
    rec_key = 'pcnn_%s'%info['rec_loss_type']
    loss_dict = {'running': 0,
                 'kl': 0,
                 rec_key: 0,
                 'vq': 0,
                 'commit': 0,
                 'loss': 0,
                 }
    print('starting', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    data_loader = data_dict[phase]
    data_loader.reset_unique()
    num_batches = len(data_loader.unique_indexes)//info['batch_size']
    # Keep the batch three from the end for plotting.  With fewer than three
    # full batches that index would be negative and never reached, which left
    # `example` unbound at the return below - fall back to the first batch.
    example_idx = max(num_batches-3, 0)
    example = None
    log_ones = None
    idx = 0
    set_model_mode(model_dict, phase)
    # NOTE: this changes the process-wide grad mode and leaves it that way
    # after the function returns.
    torch.set_grad_enabled(phase=='train')
    while data_loader.unique_available:
        for key in model_dict.keys():
            model_dict[key].zero_grad()
        batch = data_loader.get_unique_minibatch(info['batch_size'])
        states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
        fp_out = forward_pass(model_dict, states, actions, rewards, next_states, index_indexes, phase, info)
        model_dict, states, actions, rewards, target, u_q, u_p, s_p, rec_dml, pcnn_dml, z_e_x, z_q_x, latents = fp_out
        bs, _, _, _ = states.shape
        # zero tensor handed to the KL term; rebuilt only when the batch
        # size changes (e.g. a short final batch)
        if log_ones is None or bs != log_ones.shape[0]:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        kl = kl_loss_function(u_q.view(bs, info['code_length']), log_ones,
                              u_p.view(bs, info['code_length']), s_p.view(bs, info['code_length']),
                              reduction=info['reduction'])
        # no loss on deconv rec
        pcnn_loss = discretized_mix_logistic_loss(pcnn_dml, target, nr_mix=info['nr_logistic_mix'], reduction=info['reduction'])
        # each MSE term detaches one side so it only trains the other
        vq_loss = F.mse_loss(z_q_x, z_e_x.detach(), reduction=info['reduction'])
        commit_loss = F.mse_loss(z_e_x, z_q_x.detach(), reduction=info['reduction'])
        commit_loss *= info['vq_commitment_beta']
        loss = kl+pcnn_loss+commit_loss+vq_loss
        loss_dict['running'] += bs
        loss_dict['loss'] += loss.detach().cpu().item()
        loss_dict['kl'] += kl.detach().cpu().item()
        loss_dict['vq'] += vq_loss.detach().cpu().item()
        loss_dict['commit'] += commit_loss.detach().cpu().item()
        loss_dict[rec_key] += pcnn_loss.detach().cpu().item()
        if phase == 'train':
            # NOTE(review): parameters are clipped before backward()/step();
            # ordering kept exactly as in the original.
            model_dict = clip_parameters(model_dict)
            loss.backward()
            model_dict['opt'].step()
            train_cnt += bs
        if idx == example_idx:
            # store example near end for plotting
            pcnn_yhat = sample_from_discretized_mix_logistic(pcnn_dml, info['nr_logistic_mix'], only_mean=info['sample_mean'], sampling_temperature=info['sampling_temperature'])
            rec_yhat = sample_from_discretized_mix_logistic(rec_dml, info['nr_logistic_mix'], only_mean=info['sample_mean'], sampling_temperature=info['sampling_temperature'])
            # rescale_inv is resolved from the enclosing module
            example = {'prev_frame': rescale_inv(states[:,-1:].detach().cpu()),
                       'target': rescale_inv(target.detach().cpu()),
                       'deconv_yhat': rescale_inv(rec_yhat.detach().cpu()),
                       'pcnn_yhat': rescale_inv(pcnn_yhat.detach().cpu()),
                       }
        if not idx % 10:
            print(train_cnt, idx, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))
        idx += 1

    loss_avg = account_losses(loss_dict)
    print("finished %s after %s secs at cnt %s"%(phase,
                                                time.time()-st,
                                                train_cnt,
                                                ))
    print(loss_avg)
    print('end', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    if idx > 0:
        # the batch tensors only exist when at least one batch was processed
        del states, target, actions, rewards
    torch.cuda.empty_cache()
    print('after delete end', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    return loss_avg, example
Beispiel #6
0
def sample(model_dict, data_dict, info):
    """Plot last frame, target, reconstruction, teacher-forced and sampled frames.

    The models are put in 'valid' mode.  For each phase ('valid', then
    'train') a unique minibatch of 10 samples is run through
    ``forward_pass`` and the ``pcnn_decoder_model`` is called once per canvas
    pixel.  The canvas starts as the last input frame when
    ``args.teacher_force_prev`` is set, otherwise as zeros.  A figure with
    one row per sample and five columns ('last', 'true', 'conv', 'tf', 'sam')
    is saved to
    ``<model_loadpath without .pt><_lf|_zc>_st<temperature>_sample_<phase>.png``.

    ``args`` is resolved from the enclosing module.  Nothing is returned.
    """
    model_dict = set_model_mode(model_dict, 'valid')
    output_savepath = args.model_loadpath.replace('.pt', '')
    bs = 10

    def draw(dml):
        # one draw from the mixture using the command-line sampling settings
        return sample_from_discretized_mix_logistic(
            dml,
            info['nr_logistic_mix'],
            only_mean=args.sample_mean,
            sampling_temperature=args.sampling_temperature)

    # a single no_grad context is enough; the original nested a second one
    with torch.no_grad():
        for phase in ['valid', 'train']:
            data_loader = data_dict[phase]
            data_loader.reset_unique()
            batch = data_loader.get_unique_minibatch(bs)
            states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
            fp_out = forward_pass(model_dict, states, actions, rewards,
                                  next_states, index_indexes, phase, info)
            model_dict, states, actions, rewards, target, u_q, u_p, s_p, rec_dml, pcnn_dml, z_e_x, z_q_x, latents = fp_out
            # teacher forced version
            pcnn_yhat = draw(pcnn_dml)
            rec_yhat = draw(rec_dml)
            last = states[:, -1:]
            # Copy before sampling: with teacher forcing the canvas below is
            # `last` itself and is overwritten in place.
            np_last = deepcopy(last.detach().cpu().numpy())
            np_target = deepcopy(target.detach().cpu().numpy())
            np_rec_yhat = rec_yhat.detach().cpu().numpy()
            np_pcnn_yhat = pcnn_yhat.detach().cpu().numpy()
            # choose the starting canvas for autoregressive sampling and
            # report the one that is actually used
            if args.teacher_force_prev:
                st_can = '_lf'
                canvas = last
                print('using last frame as sample canvas')
            else:
                st_can = '_zc'
                canvas = torch.zeros_like(target)
                print('using zero output as sample canvas')
            iname = output_savepath + st_can + '_st%s' % args.sampling_temperature + '_sample_%s.png' % phase
            # one decoder call per (channel, row, column) position; each call
            # sees the canvas entries written by earlier positions
            for i in range(canvas.shape[1]):
                for j in range(canvas.shape[2]):
                    print('sampling row: %s' % j)
                    for k in range(canvas.shape[3]):
                        output = model_dict['pcnn_decoder_model'](
                            x=canvas, spatial_condition=rec_dml)
                        output_o = draw(output.detach())
                        canvas[:, i, j, k] = output_o[:, i, j, k]
            # NOTE(review): the plotted 'sam' image is a fresh draw from the
            # final decoder output, not the canvas built above - confirm
            # this is intended.
            output_o = draw(output.detach())

            f, ax = plt.subplots(bs,
                                 5,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(3, bs))
            np_output = output_o.detach().cpu().numpy()
            columns = (('last', np_last),
                       ('true', np_target),
                       ('conv', np_rec_yhat),
                       ('tf', np_pcnn_yhat),
                       ('sam', np_output))
            for idx in range(bs):
                for col, (title, frames) in enumerate(columns):
                    ax[idx, col].matshow(frames[idx, 0], cmap=plt.cm.gray)
                    ax[idx, col].set_title(title)
                    ax[idx, col].axis('off')
            print('plotting %s' % iname)
            plt.savefig(iname)
            plt.close()