Example #1
def get_mdrnn_cell(rnn_dir):
    rnn_file = os.path.join(rnn_dir, 'best.tar')
    assert os.path.exists(rnn_file)
    state = torch.load(rnn_file)
    mdrnn_cell = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
    # the full MDRNN stores its LSTM weights with an '_l0' suffix; drop it so
    # the keys match the MDRNNCell parameter names
    mdrnn_cell.load_state_dict(
        {k.strip('_l0'): v
         for k, v in state['state_dict'].items()})
    return mdrnn_cell
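A minimal usage sketch for the helper above, assuming LSIZE, ASIZE, RSIZE and device are defined as in the surrounding module; the checkpoint directory 'exp_dir/mdrnn' is hypothetical.

mdrnn_cell = get_mdrnn_cell('exp_dir/mdrnn')  # hypothetical checkpoint directory

# one latent transition with a dummy action and a zeroed hidden state
action = torch.zeros(1, ASIZE, device=device)
latent = torch.randn(1, LSIZE, device=device)
hidden = [torch.zeros(1, RSIZE, device=device) for _ in range(2)]
mus, sigmas, logpi, reward, done, next_hidden = mdrnn_cell(action, latent, hidden)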
Example #2
class SimulatedDoom(gym.Env):  # pylint: disable=too-many-instance-attributes
    """
    Simulated Car Racing.

    Gym environment using learnt VAE and MDRNN to simulate the
    CarRacing-v0 environment.

    :args directory: directory from which the vae and mdrnn are
    loaded.
    """
    def __init__(self, directory):
        vae_file = join(directory, 'vae', 'best.tar')
        rnn_file = join(directory, 'mdrnn', 'best.tar')
        assert exists(vae_file), "No VAE model in the directory..."
        assert exists(rnn_file), "No MDRNN model in the directory..."

        # spaces
        self.action_space = spaces.Box(low=-1, high=1, shape=(1, ))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(RED_SIZE, RED_SIZE, 3),
                                            dtype=np.uint8)

        # load VAE
        vae = VAE(3, LSIZE)
        vae_state = torch.load(vae_file,
                               map_location=lambda storage, location: storage)
        print("Loading VAE at epoch {}, "
              "with test error {}...".format(vae_state['epoch'],
                                             vae_state['precision']))
        vae.load_state_dict(vae_state['state_dict'])
        self._decoder = vae.decoder

        # load MDRNN
        self._rnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5)
        rnn_state = torch.load(rnn_file,
                               map_location=lambda storage, location: storage)
        print("Loading MDRNN at epoch {}, "
              "with test error {}...".format(rnn_state['epoch'],
                                             rnn_state['precision']))
        rnn_state_dict = {
            k.strip('_l0'): v
            for k, v in rnn_state['state_dict'].items()
        }
        self._rnn.load_state_dict(rnn_state_dict)

        # init state
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = 2 * [torch.zeros(1, RSIZE)]

        # obs
        self._obs = None
        self._visual_obs = None

        # rendering
        self.monitor = None
        self.figure = None

    def reset(self):
        """ Resetting """
        import matplotlib.pyplot as plt
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = 2 * [torch.zeros(1, RSIZE)]

        # also reset monitor
        if not self.monitor:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3), dtype=np.uint8))

    def step(self, action):
        """ One step forward """
        with torch.no_grad():
            action = torch.Tensor(action).unsqueeze(0)
            mu, sigma, pi, r, d, n_h = self._rnn(action, self._lstate,
                                                 self._hstate)
            pi = pi.squeeze()
            mixt = Categorical(torch.exp(pi)).sample().item()

            self._lstate = mu[:, mixt, :]  # + sigma[:, mixt, :] * torch.randn_like(mu[:, mixt, :])
            self._hstate = n_h

            self._obs = self._decoder(self._lstate)
            np_obs = self._obs.numpy()
            np_obs = np.clip(np_obs, 0, 1) * 255
            np_obs = np.transpose(np_obs, (0, 2, 3, 1))
            np_obs = np_obs.squeeze()
            np_obs = np_obs.astype(np.uint8)
            self._visual_obs = np_obs

            return np_obs, r.item(), d.item() > 0

    def render(self):  # pylint: disable=arguments-differ
        """ Rendering """
        import matplotlib.pyplot as plt
        if not self.monitor:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3), dtype=np.uint8))
        self.monitor.set_data(self._visual_obs)
        plt.pause(.01)
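A short rollout sketch for the environment above, assuming a hypothetical 'exp_dir' directory containing the vae/ and mdrnn/ checkpoints; note that step() here returns only (obs, reward, done).

env = SimulatedDoom('exp_dir')  # hypothetical directory with vae/ and mdrnn/ subfolders
env.reset()
for _ in range(100):
    # the action space is a single value in [-1, 1]; sample it randomly
    obs, reward, done = env.step(env.action_space.sample())
    env.render()
    if done:
        break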
Example #3
class SimulatedCarracing(gym.Env): # pylint: disable=too-many-instance-attributes
    """
    Simulated Car Racing.

    Gym environment using learnt VAE and MDRNN to simulate the
    CarRacing-v0 environment.

    :args directory: directory from which the vae and mdrnn are
    loaded.
    """
    def __init__(self, directory):
        vae_file = join(directory, 'vae', 'best.tar')
        rnn_file = join(directory, 'mdrnn', 'best.tar')
        assert exists(vae_file), "No VAE model in the directory..."
        assert exists(rnn_file), "No MDRNN model in the directory..."

        # spaces
        self.action_space = spaces.Box(np.array([-1, 0, 0]), np.array([1, 1, 1]))
        self.observation_space = spaces.Box(low=0, high=255, shape=(RED_SIZE, RED_SIZE, 3),
                                            dtype=np.uint8)

        # load VAE
        vae = VAE(3, LSIZE)
        vae_state = torch.load(vae_file, map_location=lambda storage, location: storage)
        print("Loading VAE at epoch {}, "
              "with test error {}...".format(
                  vae_state['epoch'], vae_state['precision']))
        vae.load_state_dict(vae_state['state_dict'])
        self._decoder = vae.decoder

        # load MDRNN
        self._rnn = MDRNNCell(32, 3, RSIZE, 5)
        rnn_state = torch.load(rnn_file, map_location=lambda storage, location: storage)
        print("Loading MDRNN at epoch {}, "
              "with test error {}...".format(
                  rnn_state['epoch'], rnn_state['precision']))
        rnn_state_dict = {k.strip('_l0'): v for k, v in rnn_state['state_dict'].items()}
        self._rnn.load_state_dict(rnn_state_dict)

        # init state
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = 2 * [torch.zeros(1, RSIZE)]

        # obs
        self._obs = None
        self._visual_obs = None

        # rendering
        self.monitor = None
        self.figure = None

    def reset(self):
        """ Resetting """
        import matplotlib.pyplot as plt
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = 2 * [torch.zeros(1, RSIZE)]

        # also reset monitor
        if not self.monitor:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3),
                         dtype=np.uint8))

    def step(self, action):
        """ One step forward """
        with torch.no_grad():
            action = torch.Tensor(action).unsqueeze(0)
            mu, sigma, pi, r, d, n_h = self._rnn(action, self._lstate, self._hstate)
            pi = pi.squeeze()
            mixt = Categorical(torch.exp(pi)).sample().item()

            self._lstate = mu[:, mixt, :] # + sigma[:, mixt, :] * torch.randn_like(mu[:, mixt, :])
            self._hstate = n_h

            self._obs = self._decoder(self._lstate)
            np_obs = self._obs.numpy()
            np_obs = np.clip(np_obs, 0, 1) * 255
            np_obs = np.transpose(np_obs, (0, 2, 3, 1))
            np_obs = np_obs.squeeze()
            np_obs = np_obs.astype(np.uint8)
            self._visual_obs = np_obs

            return np_obs, r.item(), d.item() > 0

    def render(self): # pylint: disable=arguments-differ
        """ Rendering """
        import matplotlib.pyplot as plt
        if not self.monitor:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3),
                         dtype=np.uint8))
        self.monitor.set_data(self._visual_obs)
        plt.pause(.01)
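The class above mirrors SimulatedDoom, the main difference being the three-dimensional CarRacing action (steering, gas, brake). A minimal driving sketch, again with a hypothetical models directory:

env = SimulatedCarracing('exp_dir')  # hypothetical directory with vae/ and mdrnn/ subfolders
env.reset()
for _ in range(200):
    # constant action: no steering, light throttle, no brake
    obs, reward, done = env.step(np.array([0.0, 0.2, 0.0]))
    env.render()
    if done:
        break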
Example #4
def create_memory(env, agent, memory, steps, config, episodes=100):
    print("Create buffer with size {} steps ".format(steps))
    score = 0
    average_score = 0
    average_steps = 0
    agent.eval()

    # model sizes and checkpoint directory used to restore the VAE / MDRNN
    LSIZE = 200
    ASIZE = 1
    RSIZE = 256
    mdir = "15.02_l200_RGB"
    vae_file, rnn_file = [join(mdir, m, 'best.tar') for m in ['vae', 'mdrnn']]
    vae_state, rnn_state = [
        torch.load(fname, map_location={'cuda:0': str(config["device"])})
        for fname in (vae_file, rnn_file)
    ]
    for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
        print("Loading {} at epoch {} "
              "with test loss {}".format(m, s['epoch'], s['precision']))

    vae = VAE(3, LSIZE).to(config["device"])
    vae.load_state_dict(vae_state['state_dict'])
    mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(config["device"])
    mdrnn.load_state_dict(
        {k.strip('_l0'): v
         for k, v in rnn_state['state_dict'].items()})
    for i in range(episodes):
        # env = gym.wrappers.Monitor(env,str(config["locexp"])+"/vid/{}/{}".format(steps, i), video_callable=lambda episode_id: True,force=True)
        env.seed(i)
        state, obs = env.reset("mediumClassic")
        print(obs)
        episode_reward = 0
        index = memory.idx
        hidden = [torch.zeros(1, RSIZE).to(config["device"]) for _ in range(2)]
        for t in range(125):
            # normalise the frame to [0, 1] on the configured device
            state_tensor = state.clone().detach().float().div_(255).to(
                config["device"])
            action = agent.act(state_tensor)
            action_rnn = torch.as_tensor(action, device=config["device"]).type(
                torch.int).unsqueeze(0).unsqueeze(0)
            # print(action_rnn)
            # print(action_rnn.shape)
            states = obs
            states = torch.as_tensor(states,
                                     device=config["device"]).unsqueeze(0)
            states = states.type(torch.float32).div_(255)
            _, latent_mu, _ = vae(states)
            # print(latent_mu.shape)
            _, _, _, _, _, next_hidden = mdrnn(action_rnn, latent_mu, hidden)
            # print(next_hidden[0].shape)
            # print(next_hidden[1].shape)
            # sys.exit()
            next_state, reward, done, next_obs = env.step(action)
            # do not count hitting the step limit as a true terminal state
            done_no_max = done if t != 124 else False
            memory.add(obs, hidden[0], hidden[1], action, next_obs,
                       next_hidden[0], next_hidden[1], done, done_no_max)
            if memory.idx % 500 == 0:
                path = "pacman_expert_memory-{}".format(memory.idx)
                print("save memory to ", path)
                memory.save_memory(path)
                if memory.idx >= steps:
                    return
            hidden = next_hidden
            state = next_state
            obs = next_obs
            score += reward
            episode_reward += reward
            if done or t == 124:
                # discard low-reward episodes by rewinding the buffer index
                if episode_reward < 600:
                    memory.idx = index

                print("Episode_reward {} and memory idx {}".format(
                    episode_reward, memory.idx))
                break
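A hedged invocation sketch for create_memory; env, agent and memory are assumed to expose the interfaces used above (env.reset("mediumClassic"), agent.act, memory.add and memory.save_memory), and config only needs a device entry:

config = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu")}
# hypothetical env / agent / memory objects; collect roughly 10000 transitions
create_memory(env, agent, memory, steps=10000, config=config, episodes=100)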
Example #5
def data_pass(epoch, train, include_reward):  # pylint: disable=too-many-locals
    """ One pass through the data """
    if train:
        mdrnn.train()
        loader = train_loader
    else:
        mdrnn.eval()
        loader = test_loader

    loader.dataset.load_next_buffer()

    cum_loss = 0
    cum_gmm = 0
    cum_bce = 0
    cum_mse = 0

    pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
    for i, data in enumerate(loader):
        obs, action, reward, terminal, next_obs = [
            arr.to(device) for arr in data
        ]

        # transform obs
        latent_obs, latent_next_obs = to_latent(obs, next_obs)

        if train:
            losses = get_loss(latent_obs, action, reward, terminal,
                              latent_next_obs, include_reward)

            # rebuild an MDRNNCell from the current MDRNN weights, train an
            # interim controller against it, and subtract its off-policy
            # evaluation estimate from the model loss
            mdrnnCell = MDRNNCell(LSIZE, ASIZE, RSIZE, 5)
            rnn_state_dict = {
                k.strip('_l0'): v
                for k, v in mdrnn.state_dict().items()
            }
            mdrnnCell.load_state_dict(rnn_state_dict)
            interim_policy = train_C_given_M(mdrnnCell=mdrnnCell,
                                             latent_dim=LSIZE,
                                             hidden_dim=RSIZE,
                                             action_dim=ASIZE)
            interim_policy_ope = ope(mdrnnCell, interim_policy)
            loss = losses['loss'] - interim_policy_ope.squeeze(dim=0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                losses = get_loss(latent_obs, action, reward, terminal,
                                  latent_next_obs, include_reward)
                # no OPE term in evaluation mode; use the raw model loss so
                # `loss` is always defined below
                loss = losses['loss']

        cum_loss += loss.item()
        cum_gmm += losses['gmm'].item()
        cum_bce += losses['bce'].item()
        cum_mse += losses['mse'].item() if hasattr(losses['mse'], 'item') else \
            losses['mse']

        pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                             "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                 loss=cum_loss / (i + 1),
                                 bce=cum_bce / (i + 1),
                                 gmm=cum_gmm / LSIZE / (i + 1),
                                 mse=cum_mse / (i + 1)))
        pbar.update(BSIZE)
    pbar.close()
    return cum_loss * BSIZE / len(loader.dataset)
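A sketch of the surrounding training loop, assuming the module-level objects the function relies on (mdrnn, optimizer, train_loader, test_loader, to_latent, get_loss, device, BSIZE) are set up as in the original training script; the epoch count is arbitrary:

for epoch in range(1, 31):
    data_pass(epoch, train=True, include_reward=True)
    test_loss = data_pass(epoch, train=False, include_reward=True)
    print("Epoch {}: test loss {:.6f}".format(epoch, test_loss))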