Example #1
def sample_autoregressive_batch_last_state(data, episode_number,
                                           episode_reward, name):
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        states = reshape_input(states).to(DEVICE)
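        # scale the target frame from [0, 1] to [-1, 1]; the discretized
        # logistic-mixture decoder presumably expects pixels in that range
        # (sample_batch below undoes the scaling with ((yhat + 1) / 2) * 255)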
        targets = (2 * reshape_input(next_states[:, -1:]) - 1).to(DEVICE)
        actions = actions.to(DEVICE)
        #
        if args.teacher_force:
            name += '_tf'
        bs = states.shape[0]
        #vqvae_model.scl
        print('generating %s images' % (bs))
        np_targets = deepcopy(targets.cpu().numpy())
        #output = np.zeros((targets.shape[2], targets.shape[3]))
        total_reward = 0
        for bi in range(bs):
            # sample one at a time due to memory constraints
            iname = os.path.join(
                output_savepath, '%s_E%05d_R%03d_%05d.png' %
                (name, int(episode_number), int(episode_reward), bi))
            if not os.path.exists(iname):
                total_reward += rewards[bi].item()
                y = targets[bi:bi + 1] * 0.0
                y[0, 0, 0, 0] = targets[bi, 0, 0, 0]
                title = 'step:%05d action:%d reward:%s %s/%s' % (
                    bi, actions[bi].item(), int(rewards[bi]),
                    total_reward, int(episode_reward))
                print("making", title)
                for i in range(y.shape[1]):
                    for j in range(y.shape[2]):
                        for k in range(y.shape[3]):
                            x_d, z_e_x, z_q_x, latents = vqvae_model(
                                states[bi:bi + 1], y=y)
                            yhat = sample_from_discretized_mix_logistic(
                                x_d, largs.nr_logistic_mix)
                            if not args.teacher_force:
                                y[0, 0, j, k] = 2 * (yhat[0, 0, j, k] / 255.0) - 1
                            else:
                                y[0, 0, j, k] = targets[bi, 0, j, k]

                np_canvas = yhat[0, 0].cpu().numpy()
                f, ax = plt.subplots(1, 2)
                ax[0].imshow(np_targets[bi, 0])
                ax[0].set_title('true')
                ax[1].imshow(np_canvas)
                ax[1].set_title('est')
                plt.suptitle(title)
                plt.savefig(iname)
                print('saving', iname)
        search_path = iname[:-10] + '*.png'
        gif_path = iname[:-10] + '.gif'
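        # stitch the per-step pngs into a gif with ImageMagick's `convert`
        # (assumes ImageMagick is installed and on the PATH)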
        cmd = 'convert %s %s' % (search_path, gif_path)
        print('creating gif', gif_path)
        os.system(cmd)
def sample_batch(data, episode_number, episode_reward, name):
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        x = (2*reshape_input(states[:,-1:])-1).to(DEVICE)
        for i in range(states.shape[0]):
            x_d, z_e_x, z_q_x, latents = vqvae_model(x[i:i+1])
            loss_1 = discretized_mix_logistic_loss(x_d, x[i:i+1], nr_mix=largs.nr_logistic_mix, DEVICE=DEVICE)
            yhat = sample_from_discretized_mix_logistic(x_d, largs.nr_logistic_mix)
            yhat = (((yhat+1)/2.0)*255.0).cpu().numpy().astype(int)
            true = (states[i:i+1,-1:]*255.0).cpu().numpy().astype(int)
            f,ax = plt.subplots(1,2)
            iname = os.path.join(output_savepath, '%s_E%05d_R%03d_%05d.png'%(name, int(episode_number), int(episode_reward), i))
            print("writing", os.path.split(iname)[1])
            title = 'step %s/%s action %s reward %s' %(i, states.shape[0], actions[i].item(), rewards[i].item())
            ax[0].imshow(true[0,0])
            ax[0].set_title('true')
            ax[1].imshow(yhat[0,0])
            ax[1].set_title('est')
            plt.suptitle(title)
            plt.savefig(iname)
            print('saving', iname)
        search_path = iname[:-10] + '*.png'
        gif_path = iname[:-10] + '.gif'
        cmd = 'convert %s %s' %(search_path, gif_path)
        print('creating gif', gif_path)
        os.system(cmd)
def generate_forward_datasets():
    with torch.no_grad():
        #for dname, data_loader in {'valid':valid_data_loader, 'train':train_data_loader}.items():
        for dname, data_loader in {'valid': valid_data_loader}.items():
            rmax = data_loader.relative_indexes.max()
            new = True
            st = 1
            if args.debug:
                st = 260
            en = 0
            keep_going = True
            while keep_going:
                en = min(st + args.batch_size, rmax - 1)
                print("all", st, en)
                fdata = data_loader.get_data(np.arange(st, en, dtype=np.int64))
                # use reward as endpoint
                if args.debug:
                    fterminals = list(fdata[2])
                else:
                    fterminals = list(fdata[5])
                # end at end of episode
                if 1 in fterminals:
                    en = st + list(fterminals).index(1) + 1
                    data = data_loader.get_data(np.arange(st, en,
                                                          dtype=np.int64))
                else:
                    data = fdata
                    print("NO END")
                print('generating from %s to %s of %s' % (st, en, rmax))
                states, actions, rewards, values, next_states, terminals, reset, relative_indexes = data
                assert np.sum(terminals[:-1]) == 0
                prev_relative_indexes = relative_indexes - 1
                prev_data = data_loader.get_data(prev_relative_indexes)
                pstates, pactions, prewards, pvalues, pnext_states, pterminals, preset, prelative_indexes = prev_data
                ps = (2 * reshape_input(torch.FloatTensor(pstates)) - 1).to(DEVICE)
                s = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(DEVICE)
                ns = (2 * reshape_input(torch.FloatTensor(next_states)) - 1).to(DEVICE)
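                # sanity-check alignment of the stacked frame histories: if each
                # state is a stack of consecutive frames (an assumption here),
                # the last frame of the previous state should match the
                # second-to-last frame of the current state and the
                # third-to-last frame of the next state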
                for xx in range(s.shape[0]):
                    try:
                        assert ps[xx, -1].sum() == s[xx, -2].sum() == ns[xx, -3].sum()
                    except AssertionError:
                        print("assert broke", xx)
                        embed()
                px_d, zp_e_x, pz_q_x, platents, prev_pred_actions, prev_pred_rewards = vqvae_model(
                    ps)
                x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
                    s)
                nx_d, nz_e_x, nz_q_x, nlatents, next_pred_actions, next_pred_rewards = vqvae_model(
                    ns)
                if new:
                    all_prev_latents = platents.cpu().numpy()
                    all_latents = latents.cpu().numpy()
                    all_next_latents = nlatents.cpu().numpy()

                    if args.debug:
                        all_prev_states = pstates
                        all_states = states
                        all_next_states = next_states
                        all_next_pred_actions = next_pred_actions.cpu().numpy()
                        all_next_pred_rewards = next_pred_rewards.cpu().numpy()
                        all_pred_actions = pred_actions.cpu().numpy()
                        all_pred_rewards = pred_rewards.cpu().numpy()

                    all_prev_actions = pactions
                    all_prev_rewards = prewards
                    all_prev_values = pvalues

                    all_rewards = rewards
                    all_values = values
                    all_actions = actions
                    all_rel_inds = relative_indexes

                    new = False
                else:
                    all_prev_latents = np.concatenate(
                        (all_prev_latents, platents.cpu().numpy()), axis=0)
                    all_latents = np.concatenate(
                        (all_latents, latents.cpu().numpy()), axis=0)
                    all_next_latents = np.concatenate(
                        (all_next_latents, nlatents.cpu().numpy()), axis=0)

                    if args.debug:
                        all_prev_states = np.concatenate(
                            (all_prev_states, pstates), axis=0)
                        all_states = np.concatenate((all_states, states),
                                                    axis=0)
                        all_next_states = np.concatenate(
                            (all_next_states, next_states), axis=0)
                        all_next_pred_rewards = np.concatenate(
                            (all_next_pred_rewards,
                             next_pred_rewards.cpu().numpy()),
                            axis=0)
                        all_pred_rewards = np.concatenate(
                            (all_pred_rewards, pred_rewards.cpu().numpy()),
                            axis=0)
                        all_next_pred_actions = np.concatenate(
                            (all_next_pred_actions,
                             next_pred_actions.cpu().numpy()),
                            axis=0)
                        all_pred_actions = np.concatenate(
                            (all_pred_actions, pred_actions.cpu().numpy()),
                            axis=0)

                    all_prev_rewards = np.concatenate(
                        (all_prev_rewards, prewards))
                    all_prev_values = np.concatenate(
                        (all_prev_values, pvalues))
                    all_prev_actions = np.concatenate(
                        (all_prev_actions, pactions))

                    all_rewards = np.concatenate((all_rewards, rewards))
                    all_values = np.concatenate((all_values, values))
                    all_actions = np.concatenate((all_actions, actions))
                    all_rel_inds = np.concatenate(
                        (all_rel_inds, relative_indexes))

                if 1 in fterminals:
                    # skip ahead one so that prev state is correct
                    st = en + 1
                else:
                    st = en
                if en > rmax - 2:
                    keep_going = False
            forward_dir = args.model_loadname.replace(
                '.pt', '_%s_forward_imgs' % dname)
            if not os.path.exists(forward_dir):
                os.makedirs(forward_dir)

            forward_filename = args.model_loadname.replace(
                '.pt', '_%s_forward.npz' % dname)
            if args.debug:
                forward_filename = forward_filename.replace(
                    '.npz', 'debug.npz')
                np.savez(forward_filename,
                         relative_indexes=all_rel_inds,
                         prev_latents=all_prev_latents,
                         latents=all_latents,
                         next_latents=all_next_latents,
                         rewards=all_rewards,
                         values=all_values,
                         actions=all_actions,
                         prev_rewards=all_prev_rewards,
                         prev_values=all_prev_values,
                         prev_actions=all_prev_actions,
                         prev_states=all_prev_states,
                         states=all_states,
                         next_states=all_next_states,
                         num_k=largs.num_k)
            else:
                print('saving', forward_filename)
                np.savez(forward_filename,
                         relative_indexes=all_rel_inds,
                         prev_latents=all_prev_latents,
                         latents=all_latents,
                         next_latents=all_next_latents,
                         rewards=all_rewards,
                         values=all_values,
                         actions=all_actions,
                         prev_rewards=all_prev_rewards,
                         prev_values=all_prev_values,
                         prev_actions=all_prev_actions,
                         num_k=largs.num_k)

            if args.debug:
                for i in range(all_prev_latents.shape[0]):
                    f, ax = plt.subplots(3, 3)
                    ax[0, 0].imshow(all_prev_states[i, -1])
                    # one ahead in action/reward because vq is predicting transition
                    ax[0, 0].set_title('%04d A%sPA%s' %
                                       (i, all_prev_actions[i],
                                        np.argmax(all_pred_actions[i])))

                    ax[0, 1].imshow(all_states[i, -1])
                    ax[0, 1].set_title('%04d A%sPA%s' %
                                       (i + 1, all_actions[i],
                                        np.argmax(all_next_pred_actions[i])))

                    ax[0, 2].imshow(all_next_states[i, -1])
                    ax[0, 2].set_title('%04d' % (i + 2))

                    ax[1, 0].imshow(all_prev_latents[i], vmin=0, vmax=512)
                    ax[1, 0].set_title('%04d R%sRA%s' %
                                       (i, all_prev_rewards[i],
                                        np.argmax(all_pred_rewards[i])))

                    ax[1, 1].imshow(all_latents[i], vmin=0, vmax=512)
                    ax[1, 1].set_title('%04d R%sRA%s' %
                                       (i + 1, all_rewards[i],
                                        np.argmax(all_next_pred_rewards[i])))

                    ax[1, 2].imshow(all_next_latents[i], vmin=0, vmax=512)
                    ax[1, 2].set_title('%04d' % (i + 2))

                    s_mask = (all_latents[i] - all_prev_latents[i]) == 0
                    diffnl = deepcopy(all_latents[i])
                    diffnl[s_mask] *= 0
                    ax[2, 1].imshow(diffnl, vmin=0, vmax=512)
                    ax[2, 1].set_title('diff %s-%s' % (i, i + 1))

                    s1_mask = (all_next_latents[i] - all_latents[i]) == 0
                    diffs1nl = deepcopy(all_next_latents[i])
                    diffs1nl[s1_mask] *= 0
                    ax[2, 2].imshow(diffs1nl, vmin=0, vmax=512)
                    ax[2, 2].set_title('diff %s-%s' % (i + 1, i + 2))

                    pname = os.path.join(forward_dir,
                                         '%s_frame%05d.png' % (dname, i + 1))
                    plt.savefig(pname)
                    plt.close()
                    if not i % 10:
                        print('plotting', i, pname)
                cmd = 'convert %s %s' % (os.path.join(
                    forward_dir, '%s*.png' %
                    dname), os.path.join(forward_dir, '%s.gif' % dname))
                print("!!!! creating gif")
                print(cmd)
                os.system(cmd)
    #sample_batch(valid_episode_batch, episode_index, episode_reward, 'valid')

    # Can work with any model, but it assumes that the model has a
    # `features` method and a `classifier` method,
    # as in the VGG models in torchvision.

    grad_cam = GradCam(model=vqvae_model,
                       target_layer_names=['10'],
                       use_cuda=args.use_cuda)

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested index.
    target_index = None

    states, actions, rewards, values, pred_states, terminals, reset, relative_indexes = valid_episode_batch

    input_data = (2*reshape_input(states[:1])-1).to(DEVICE)

    mask = grad_cam(input_data, target_index)
    img = input_data[0,-1].cpu().numpy()
    img = (img+1)/2.0
    cimg = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
    show_cam_on_image(cimg, mask)
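    # Grad-CAM: the mask weights the target layer's activations by the gradient
    # of the selected score with respect to those activations, averaged over
    # channels, giving a coarse heatmap of the regions that drove the
    # prediction; show_cam_on_image overlays that heatmap on the RGB-converted
    # grayscale frame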
    #utils.save_image(torch.from_numpy(cam_gb), 'cam_gb.jpg')

    #gb_model = GuidedBackpropReLUModel(model = models.vgg19(pretrained=True), use_cuda=args.use_cuda)
    #gb = gb_model(input, index=target_index)
    #utils.save_image(torch.from_numpy(gb), 'gb.jpg')

    #cam_mask = np.zeros(gb.shape)
    #for i in range(0, gb.shape[0]):
    #    cam_mask[i, :, :] = mask
Example #5
def sample_episode(data, episode_number, episode_reward, name):
    # rollout for number of steps and decode with vqvae decoder
    states, actions, rewards, values, next_states, terminals, reset, relative_indexes = data
    params = (episode_number, episode_reward, name, actions, rewards)
    pred_actions = []
    bs = min(states.shape[0], args.rollout_length)
    snp = reshape_input(deepcopy(states))
    s = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(DEVICE)
    nsnp = reshape_input(next_states)
    # make channels for actions which is the size of the latents
    gen_method = []
    actions = torch.LongTensor(actions).to(DEVICE)
    elen = actions.shape[0]
    #print("setting all actions to one")
    #actions[args.lead_in:] = 1
    #actions[40:] = 0
    #actions[args.lead_in:]=torch.LongTensor(np.random.randint(min(data_loader.action_space),
    #                                                          max(data_loader.action_space),
    #                                                          actions[args.lead_in:].shape[0])).to(DEVICE)
    channel_actions = torch.zeros(
        (elen, forward_info['num_actions'], forward_info['hsize'],
         forward_info['hsize'])).to(DEVICE)
    for a in range(forward_info['num_actions']):
        channel_actions[actions == a, a] = 1.0
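    # channel_actions[t] is a one-hot encoding of action a_t broadcast over the
    # latent grid: plane a is all ones when a_t == a and zero otherwise, so the
    # conv forward model can condition on the action at every spatial location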
    all_tf_pred_latents = np.zeros((args.rollout_length, 10, 10))
    all_real_latents = torch.zeros(
        (args.rollout_length, 10, 10)).to(DEVICE).float()
    # first pred index is zeros - since we cant predict it
    all_pred_latents = torch.zeros(
        (args.rollout_length, 10, 10)).to(DEVICE).float()
    assert args.lead_in >= 2
    # at beginning
    # 0th real action is actually for 1
    # all_real_latents represents i
    # tf_pred_latents is constructing i, given latents from latents of i-1,i-2
    # from tf_pred_latents[i], you can find obs[i]
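    # three latent buffers are filled per step i:
    #   all_real_latents[i]    - VQ codes of the true frame at step i
    #   all_tf_pred_latents[i] - forward-model prediction made from the two
    #                            preceding *real* latents (teacher forced)
    #   all_pred_latents[i]    - forward-model prediction that is fed its own
    #                            previous outputs once i > args.lead_in ('SLF'),
    #                            and real latents before that ('FTF' / copied);
    #                            gen_method records which path was used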
    for i in range(bs):
        x_d, z_e_x, z_q_x, real_latents, pred_actions, pred_signals = vqvae_model(
            s[i:i + 1])
        # for the ith index
        all_real_latents[i] = real_latents.float()
        gmethod = 'NOT'
        if i >= 2:
            # get a teacher force result - where past real latents are used to
            # predict this state
            assert (torch.abs(all_real_latents[i - 2]).sum() > 0)
            assert (torch.abs(all_real_latents[i - 1]).sum() > 0)
            tf_state_input = torch.cat(
                (channel_actions[i][None, :],
                 all_real_latents[i - 2][None, None],
                 all_real_latents[i - 1][None, None]),
                dim=1)
            #tf_pred_next_latents, tf_pred_prev_actions, tf_pred_rewards = conv_forward_model(tf_state_input)
            tf_pred_next_latents = conv_forward_model(tf_state_input)
            # prediction for the i + 1 index
            tf_pred_next_latents = torch.argmax(tf_pred_next_latents, dim=1)
            # THIS is pred s+1 so the indexes are different
            all_tf_pred_latents[i] = deepcopy(
                tf_pred_next_latents[0].cpu().numpy())
            #print('tf', i, all_tf_pred_latents[i].sum())
            if i > args.lead_in:
                # use last prediction
                print('i', i)
                assert (torch.abs(all_pred_latents[i - 2]).sum() > 0)
                assert (torch.abs(all_pred_latents[i - 1]).sum() > 0)
                state_input = torch.cat(
                    (channel_actions[i][None, :],
                     all_pred_latents[i - 2][None, None].float(),
                     all_pred_latents[i - 1][None, None].float()),
                    dim=1)
                gmethod = 'SLF'
                # use teacher forced version if we are in "lead in"
            else:
                gmethod = 'FTF'
                state_input = torch.cat(
                    (channel_actions[i][None, :],
                     all_real_latents[i - 2][None, None].float(),
                     all_real_latents[i - 1][None, None].float()),
                    dim=1)

            out_pred_next_latents = conv_forward_model(state_input)
            # take argmax over channels axis
            pred_next_latents = torch.argmax(out_pred_next_latents, dim=1)
            # replace true with this
            all_pred_latents[i] = pred_next_latents[0].float()
            #print('pred', i, all_pred_latents[i].sum())
        else:
            print("feeding beginning of preds with real", i)
            all_pred_latents[i] = real_latents.float()
        gen_method.append(gmethod)

    all_pred_latents = all_pred_latents.cpu().numpy()
    all_real_latents = all_real_latents.cpu().numpy()
    plot_reconstructions(snp, nsnp, all_real_latents, all_pred_latents,
                         all_tf_pred_latents, params, gen_method)
    plot_latents(all_real_latents, all_pred_latents, all_tf_pred_latents,
                 params, gen_method)