Exemplo n.º 1
0
 def _imagine(self, start, policy, horizon, repeats=None):
   """Roll out `policy` inside the world model's dynamics for `horizon` steps.

   Args:
     start: dict of tensors describing the initial latent state. The two
       leading axes of every tensor are merged into one before the rollout
       (presumably batch and time -- confirm against the caller).
     policy: callable mapping features to an object exposing `sample()` and
       `mode()`.
     horizon: number of scan steps; passed to `tf.range`.
     repeats: optional repeat count. When truthy, every start tensor is
       repeated along axis 1 before the rollout, and the outputs are
       reshaped afterwards so the repeats get an axis of their own.

   Returns:
     Tuple `(feats, states, actions)`. `states` is the start state followed
     by every successor state except the last one.
   """
   dynamics = self._world_model.dynamics

   def merge_leading(x):
     # Fold the two leading axes into a single one.
     return tf.reshape(x, [-1] + list(x.shape[2:]))

   if repeats:
     start = {
         key: tf.repeat(val, repeats, axis=1) for key, val in start.items()}
   start = {key: merge_leading(val) for key, val in start.items()}

   def step(carry, _):
     # Only the state part of the carry is used; the previous feature and
     # action are ignored.
     state, _prev_feat, _prev_action = carry
     feat = dynamics.get_feat(state)
     actor_in = tf.stop_gradient(feat) if self._stop_grad_actor else feat
     action = policy(actor_in).sample()
     nxt = dynamics.img_step(state, action, sample=self._config.imag_sample)
     return nxt, feat, action

   # These two entries only seed the scan carry; step() never reads them.
   init_feat = 0 * dynamics.get_feat(start)
   init_action = policy(init_feat).mode()
   succ, feats, actions = tools.static_scan(
       step, tf.range(horizon), (start, init_feat, init_action))

   # Prepend the start state and drop the final successor.
   states = {}
   for key, val in succ.items():
     states[key] = tf.concat([start[key][None], val[:-1]], 0)

   if not repeats:
     return feats, states, actions

   def unfold(tensor):
     # Split axis 1 back into (original size, repeats).
     shape = tensor.shape
     return tf.reshape(
         tensor, [shape[0], shape[1] // repeats, repeats] + shape[2:])

   states, feats, actions = tf.nest.map_structure(
       unfold, (states, feats, actions))
   return feats, states, actions
Exemplo n.º 2
0
 def imagine(self, action, state=None):
     """Scan `img_step` over an action sequence and return the visited states.

     Args:
         action: rank-3 tensor; axes 0 and 1 are swapped before scanning
             (presumably (batch, time, dim) -- confirm with callers).
         state: optional dict used as the initial state. When None,
             `self.initial` is called with the size of `action`'s axis 0.

     Returns:
         Dict with one entry per key of the scan result, each transposed
         back with the same [1, 0, 2] permutation.
     """
     perm = [1, 0, 2]
     init = self.initial(tf.shape(action)[0]) if state is None else state
     assert isinstance(init, dict), init
     scanned = tools.static_scan(
         self.img_step, tf.transpose(action, perm), init)
     return {name: tf.transpose(seq, perm) for name, seq in scanned.items()}
Exemplo n.º 3
0
 def imagine(self, action, state=None):
     """Scan `img_step` over an action sequence and return the visited states.

     Args:
         action: tensor whose first two axes are swapped before scanning
             (presumably (batch, time, ...) -- confirm with callers).
         state: optional dict used as the initial state. When None,
             `self.initial` is called with the size of `action`'s axis 0.

     Returns:
         Dict with one entry per key of the scan result, each with its
         first two axes swapped back.
     """
     def swap(x):
         # Exchange axes 0 and 1 and keep any trailing axes in place.
         trailing = list(range(2, len(x.shape)))
         return tf.transpose(x, [1, 0] + trailing)

     init = self.initial(tf.shape(action)[0]) if state is None else state
     assert isinstance(init, dict), init
     scanned = tools.static_scan(self.img_step, swap(action), init)
     return {name: swap(seq) for name, seq in scanned.items()}
Exemplo n.º 4
0
    def _imagine_ahead(self, start, imag_depth):
        """Roll the dynamics forward from `start` using actions from the actor.

        Args:
            start: initial state handed to `tools.static_scan` as the carry.
            imag_depth: number of scan steps; passed to `tf.range`.

        Returns:
            Tuple `(imag_feat, states)`: the scanned states and the features
            `self._dynamics.get_feat` derives from them.
        """
        def advance(state, _step):
            # The actor sees the features with gradients stopped.
            feat = self._dynamics.get_feat(state)
            sampled = self._actor(tf.stop_gradient(feat)).sample()
            return self._dynamics.img_step(state, sampled)

        rollout = tools.static_scan(advance, tf.range(imag_depth), start)
        return self._dynamics.get_feat(rollout), rollout
Exemplo n.º 5
0
 def observe(self, embed, action, state=None):
     """Scan `obs_step` over paired action/embedding sequences.

     Args:
         embed: rank-3 tensor; axes 0 and 1 are swapped before scanning.
         action: rank-3 tensor; axes 0 and 1 are swapped before scanning
             (presumably both are (batch, time, dim) -- confirm).
         state: optional initial state. When None, `self.initial` is called
             with the size of `action`'s axis 0.

     Returns:
         Tuple `(post, prior)` of dicts, each value transposed back with
         the same [1, 0, 2] permutation.
     """
     perm = [1, 0, 2]
     if state is None:
         state = self.initial(tf.shape(action)[0])

     def step(carry, inputs):
         # Only the first element of the carry is fed into the next step.
         return self.obs_step(carry[0], *inputs)

     embed_t = tf.transpose(embed, perm)
     action_t = tf.transpose(action, perm)
     post, prior = tools.static_scan(
         step, (action_t, embed_t), (state, state))
     post = {name: tf.transpose(seq, perm) for name, seq in post.items()}
     prior = {name: tf.transpose(seq, perm) for name, seq in prior.items()}
     return post, prior
Exemplo n.º 6
0
 def observe(self, embed, action, state=None):
     """Scan `obs_step` over paired action/embedding sequences.

     Args:
         embed: tensor whose first two axes are swapped before scanning.
         action: tensor whose first two axes are swapped before scanning
             (presumably both are (batch, time, ...) -- confirm).
         state: optional initial state. When None, `self.initial` is called
             with the size of `action`'s axis 0.

     Returns:
         Tuple `(post, prior)` of dicts, each value with its first two axes
         swapped back.
     """
     def swap(x):
         # Exchange axes 0 and 1 and keep any trailing axes in place.
         trailing = list(range(2, len(x.shape)))
         return tf.transpose(x, [1, 0] + trailing)

     def step(carry, inputs):
         # Only the first element of the carry is fed into the next step.
         return self.obs_step(carry[0], *inputs)

     if state is None:
         state = self.initial(tf.shape(action)[0])
     embed_t = swap(embed)
     action_t = swap(action)
     post, prior = tools.static_scan(
         step, (action_t, embed_t), (state, state))
     return (
         {name: swap(seq) for name, seq in post.items()},
         {name: swap(seq) for name, seq in prior.items()},
     )
 def observe(self, embed, action, state=None):
   """Scan `obs_step` over paired action/embedding sequences.

   Args:
     embed: rank-3 tensor; axes 0 and 1 are swapped before scanning.
     action: rank-3 tensor; axes 0 and 1 are swapped before scanning
       (presumably both are (batch, length, feature) -- confirm).
     state: optional initial state. When None, `self.initial` is called with
       the size of `action`'s axis 0.

   Returns:
     Tuple `(post, prior)` of dicts, each value transposed back with the
     same [1, 0, 2] permutation.
   """
   perm = [1, 0, 2]
   if state is None:
     # No state given: ask the model for an initial one.
     state = self.initial(tf.shape(action)[0])

   def step(carry, inputs):
     # The carry is a pair; only its first element is passed on, followed
     # by the per-step (action, embed) inputs.
     return self.obs_step(carry[0], *inputs)

   embed_t = tf.transpose(embed, perm)
   action_t = tf.transpose(action, perm)
   post, prior = tools.static_scan(step, (action_t, embed_t), (state, state))
   # Undo the earlier axis swap on every entry of both results.
   post = {name: tf.transpose(seq, perm) for name, seq in post.items()}
   prior = {name: tf.transpose(seq, perm) for name, seq in prior.items()}
   return post, prior
Exemplo n.º 8
0
 def _imagine_ahead(self, post):
     """Roll the dynamics forward from every posterior state.

     Args:
         post: dict of tensors. When `self._c.pcont` is truthy the last
             index along axis 1 is dropped; afterwards the two leading axes
             of every tensor are merged.

     Returns:
         Features computed by `self._dynamics.get_feat` from the scanned
         states.
     """
     def merge_leading(x):
         # Fold the two leading axes into a single one.
         return tf.reshape(x, [-1] + list(x.shape[2:]))

     def constant_policy(_state):
         # NOTE(review): the action ignores the state and is a fixed
         # (100, 6) float16 array of ones -- looks like a debugging stub in
         # place of a learned actor; confirm before relying on it.
         return np.ones((100, 6), dtype=np.float16)

     def advance(state, _step):
         return self._dynamics.img_step(state, constant_policy(state))

     if self._c.pcont:  # Last step could be terminal.
         post = {key: val[:, :-1] for key, val in post.items()}
     start = {key: merge_leading(val) for key, val in post.items()}
     rollout = tools.static_scan(
         advance, tf.range(self._c.horizon), start)
     return self._dynamics.get_feat(rollout)
Exemplo n.º 9
0
 def _imagine_ahead(self, post):
     """Roll the dynamics forward from every posterior state using the actor.

     Args:
         post: dict of tensors. When `self._c.pcont` is truthy the last
             index along axis 1 is dropped; afterwards the two leading axes
             of every tensor are merged.

     Returns:
         Features computed by `self._dynamics.get_feat` from the scanned
         states.
     """
     def merge_leading(x):
         # Fold the two leading axes into a single one.
         return tf.reshape(x, [-1] + list(x.shape[2:]))

     def act(state):
         # The actor is called with training=True on gradient-stopped
         # features, and an action is sampled from its output.
         feat = self._dynamics.get_feat(state)
         return self._actor(tf.stop_gradient(feat), training=True).sample()

     def advance(state, _step):
         return self._dynamics.img_step(state, act(state))

     if self._c.pcont:  # Last step could be terminal.
         post = {key: val[:, :-1] for key, val in post.items()}
     start = {key: merge_leading(val) for key, val in post.items()}
     rollout = tools.static_scan(
         advance, tf.range(self._c.horizon), start)
     return self._dynamics.get_feat(rollout)
Exemplo n.º 10
0
 def imagine(self, action, desc, state=None):
     """Scan `img_step` over paired action/description sequences.

     Args:
         action: rank-3 tensor; axes 0 and 1 are swapped before scanning.
         desc: rank-3 tensor; axes 0 and 1 are swapped before scanning
             (presumably both are (batch, time, dim) -- confirm).
         state: optional dict used as the initial state. When None,
             `self.initial` is called with the size of `action`'s axis 0.

     Returns:
         Dict built from the first element of the scan result, each value
         transposed back with the same [1, 0, 2] permutation.
     """
     perm = [1, 0, 2]
     if state is None:
         state = self.initial(tf.shape(action)[0])
     assert isinstance(state, dict), state

     def step(carry, inputs):
         # Only the first carry element is passed on, followed by the
         # per-step (action, desc) inputs.
         return self.img_step(carry[0], *inputs)

     action_t = tf.transpose(action, perm)
     desc_t = tf.transpose(desc, perm)
     # The second carry slot starts as None and its scanned result is
     # discarded.
     prior, _ = tools.static_scan(step, (action_t, desc_t), (state, None))
     return {name: tf.transpose(seq, perm) for name, seq in prior.items()}
Exemplo n.º 11
0
 def _imagine_ahead(self, post):
     """Roll the dynamics forward while carrying a 'prob' entry in the state.

     Args:
         post: dict of tensors that must contain a 'mean' entry. When
             `self._c.pcont` is truthy the last index along axis 1 is
             dropped; afterwards the two leading axes of every tensor are
             merged.

     Returns:
         Tuple `(imag_feat, prob)`: features computed from the scanned
         states, and the scanned 'prob' entry with gradients stopped.
     """
     def merge_leading(x):
         # Fold the two leading axes into a single one.
         return tf.reshape(x, [-1] + list(x.shape[2:]))

     def act(state):
         feat = self._dynamics.get_feat(state)
         return self._actor(tf.stop_gradient(feat)).sample()

     def advance(state, _step):
         return self._dynamics.img_step(state, act(state), calprob=True)

     if self._c.pcont:  # Last step could be terminal.
         post = {key: val[:, :-1] for key, val in post.items()}
     start = {key: merge_leading(val) for key, val in post.items()}

     # Seed 'prob' with one zero per row of the flattened 'mean' tensor, in
     # the compute dtype of the global precision policy.
     compute_dtype = prec.global_policy().compute_dtype
     num_rows = start['mean'].shape[0]
     start['prob'] = tf.zeros([num_rows], compute_dtype)

     rollout = tools.static_scan(
         advance, tf.range(self._c.horizon), start)
     imag_feat = self._dynamics.get_feat(rollout)
     return imag_feat, tf.stop_gradient(rollout['prob'])
Exemplo n.º 12
0
    def observe(self, embed, action, desc, state=None):
        """Scan `obs_step` over action, embedding and description sequences.

        Args:
            embed: rank-3 tensor, documented by the original author as
                (batch, length, hidden).
            action: rank-3 tensor; axes 0 and 1 are swapped before scanning.
            desc: rank-3 tensor, documented by the original author as
                (batch, length, desc_len).
            state: optional initial state. When None, `self.initial` is
                called with the size of `action`'s axis 0.

        Returns:
            Tuple `(post, prior)` of dicts, each value transposed back with
            the same [1, 0, 2] permutation.
        """
        perm = [1, 0, 2]
        if state is None:
            state = self.initial(tf.shape(action)[0])

        def step(carry, inputs):
            # Only the first carry element is passed on, followed by the
            # per-step (action, embed, desc) inputs.
            return self.obs_step(carry[0], *inputs)

        embed_t = tf.transpose(embed, perm)
        action_t = tf.transpose(action, perm)
        desc_t = tf.transpose(desc, perm)
        post, prior = tools.static_scan(
            step, (action_t, embed_t, desc_t), (state, state))

        post = {name: tf.transpose(seq, perm) for name, seq in post.items()}
        prior = {name: tf.transpose(seq, perm) for name, seq in prior.items()}
        return post, prior
Exemplo n.º 13
0
 def _imagine_ahead(self, post, desc):
     """Roll the dynamics forward while carrying a description tensor.

     Args:
         post: dict of tensors. When `self._c.pcont` is truthy the last
             index along axis 1 is dropped; afterwards the two leading axes
             of every tensor are merged.
         desc: tensor that is trimmed and flattened the same way as the
             entries of `post`.

     Returns:
         Features computed by `self._dynamics.get_feat` from the first
         element of the scan result.
     """
     def merge_leading(x):
         # Fold the two leading axes into a single one.
         return tf.reshape(x, [-1] + list(x.shape[2:]))

     def act(state):
         feat = self._dynamics.get_feat(state)
         return self._actor(tf.stop_gradient(feat)).sample()

     def advance(carry, _step):
         # The carry is (state, desc); img_step receives the state, the
         # sampled action and the description.
         state, cur_desc = carry[0], carry[1]
         return self._dynamics.img_step(state, act(state), cur_desc)

     if self._c.pcont:  # Last step could be terminal.
         post = {key: val[:, :-1] for key, val in post.items()}
         desc = desc[:, :-1]
     start = {key: merge_leading(val) for key, val in post.items()}
     desc = merge_leading(desc)

     states, _ = tools.static_scan(
         advance, tf.range(self._c.horizon), (start, desc))
     return self._dynamics.get_feat(states)
Exemplo n.º 14
0
    def _imagine_ahead(self, post):
        """Roll the dynamics forward from every posterior state using the actor.

        Args:
            post: dict of tensors. When `self._c.pcont` is truthy the last
                index along axis 1 is dropped; afterwards the two leading
                axes of every tensor are merged.

        Returns:
            Features computed by `self._dynamics.get_feat` from the scanned
            states.
        """
        def merge_leading(x):
            # Fold the two leading axes into a single one.
            return tf.reshape(x, [-1] + list(x.shape[2:]))

        def act(state):
            feat = self._dynamics.get_feat(state)
            return self._actor(tf.stop_gradient(feat)).sample()

        def advance(state, _step):
            return self._dynamics.img_step(state, act(state))

        if self._c.pcont:  # Last step could be terminal.
            # Original author's example (unverified here): post['mean'] goes
            # from shape (25, 50, 30) to (25, 49, 30).
            post = {key: val[:, :-1] for key, val in post.items()}

        # Flatten every entry of the dict; in the author's example a
        # (25, 49, 30) tensor becomes (1225, 30).
        start = {key: merge_leading(val) for key, val in post.items()}

        # Scan the policy and dynamics for `self._c.horizon` steps.
        rollout = tools.static_scan(
            advance, tf.range(self._c.horizon), start)

        # Author's example shape for the result: (15, 1225, 230) -- confirm.
        return self._dynamics.get_feat(rollout)