Example #1
    def forward(self, obs, extract_sym_features: bool = False):
        imgs = obs.img
        lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
        imgs = imgs.view(T * B, *img_shape)
        # NHWC -> NCHW
        imgs = imgs.transpose(3, 2).transpose(2, 1)

        # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
        if self.sym_extractor is not None and extract_sym_features:
            sym_feats = self.sym_extractor(imgs[:, -1:])
        else:
            sym_feats = None

        # convert from [0, 255] uint8 to [0, 1] float32
        imgs = imgs.to(torch.float32)
        imgs.mul_(1.0 / 255)

        feats = self.feature_extractor(imgs)
        vector_obs = obs.vector.view(-1, obs.vector.shape[-1])
        feats = torch.cat((feats, vector_obs), -1)
        feats = self.relu(self.linear(feats))
        value = self.value_predictor(feats).squeeze(-1)  # squeeze trailing singleton dim; callers expect a scalar value per step

        if self.categorical:
            action_dist = self.action_predictor(feats)
            action_dist, value = restore_leading_dims((action_dist, value),
                                                      lead_dim, T, B)
            if self.sym_extractor is not None and extract_sym_features:
                # have to "restore dims" for sym_feats manually...
                sym_feats = sym_feats.reshape((*action_dist.shape[:-1], -1))
                # sym_feats = torch.rand_like(action_dist)
            return action_dist, value, sym_feats
        else:
            mu = self.mu_predictor(feats)
            log_std = self.log_std_predictor(feats)
            mu, log_std, value = restore_leading_dims((mu, log_std, value),
                                                      lead_dim, T, B)
            if self.sym_extractor is not None and extract_sym_features:
                # have to "restore dims" for sym_feats manually...
                sym_feats = sym_feats.reshape((*mu.shape[:-1], -1))
                # sym_feats = torch.rand_like(action_dist)
            return mu, log_std, value, sym_feats
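
The preprocessing in Example #1 (NHWC-to-NCHW reorder followed by a uint8-to-float32 rescale) can be checked in isolation. A minimal sketch with made-up shapes and names:

import torch

# Illustrative batch of two uint8 frames in NHWC layout (shapes/names are made up).
imgs = torch.randint(0, 256, (2, 84, 84, 3), dtype=torch.uint8)

# NHWC -> NCHW; a single permute is equivalent to the pair of transposes above.
imgs = imgs.permute(0, 3, 1, 2)

# uint8 in [0, 255] -> float32 in [0, 1].
imgs = imgs.to(torch.float32).mul(1.0 / 255)
assert imgs.shape == (2, 3, 84, 84) and float(imgs.max()) <= 1.0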
Example #2
    def forward(self, observation, prev_action, prev_reward):
        """
        Feedforward layers process as [T*B,H]. Return same leading dims as input, can be [T,B], [B], or [].

        forward() is called implicitly from step() of the agent class DqnAgent: because torch.nn.Module defines __call__(),
        calling self.model(*model_inputs) inside DqnAgent.step() is equivalent to calling forward().
        """
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
        q = self.head(conv_out.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        return q
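
The docstring above notes that forward() is only reached through nn.Module.__call__; the minimal sketch below (a hypothetical TinyQModel, assuming rlpyt is installed for infer_leading_dims/restore_leading_dims) shows that dispatch and the [T,B] / [B] / [] leading-dim convention:

import torch
import torch.nn as nn
from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims


class TinyQModel(nn.Module):  # hypothetical model, for illustration only
    def __init__(self, obs_size=8, n_actions=4):
        super().__init__()
        self.head = nn.Linear(obs_size, n_actions)
        self._obs_ndim = 1

    def forward(self, observation, prev_action=None, prev_reward=None):
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        q = self.head(observation.view(T * B, -1))      # fold leading dims into [T*B, ...]
        return restore_leading_dims(q, lead_dim, T, B)  # unfold to match the input


model = TinyQModel()
# nn.Module.__call__ dispatches each of these calls to forward():
assert model(torch.zeros(5, 3, 8)).shape == (5, 3, 4)  # [T,B] input
assert model(torch.zeros(3, 8)).shape == (3, 4)        # [B] input
assert model(torch.zeros(8)).shape == (4,)             # no leading dims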
    def forward(self, image, prev_action=None, prev_reward=None):
        #input normalization, cast to float then grayscale it
        img = image.type(torch.float)
        img = img.mul_(1. / 255)
        
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        
        img = img.view(T * B, *img_shape)
        if self.augment_obs is not None:
            b, c, h, w = img.shape
            mask_vbox = torch.zeros(size=img.shape, dtype=torch.bool, device=img.device)

            mh = math.ceil(h * 0.20)
            #2 squares side by side
            mw = mh * 2
            ## create velocity mask -> True where the velocity box is, False for the rest of the screen
            vmask = torch.ones((b, c, mh, mw), dtype=torch.bool, device=img.device)
            mask_vbox[:,:,:mh,:mw] = vmask
            obs_without_vbox = torch.where(mask_vbox, torch.zeros_like(img), img)
            
            if self.augment_obs == 'cutout':
                augmented = random_cutout_color(obs_without_vbox)
            elif self.augment_obs == 'jitter':
                if self.transform is None:
                    self.transform = ColorJitterLayer(b)
                augmented = self.transform(obs_without_vbox)
            elif self.augment_obs == 'rand_conv':
                augmented = random_convolution(obs_without_vbox)

            fixed = torch.where(mask_vbox, img, augmented)
            img = fixed


        fc_out = self.conv(img)
        pi = F.softmax(self.pi(fc_out), dim=-1)
        v = self.value(fc_out).squeeze(-1)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        # T: time dimension, B: batch dimension
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)

        return pi, v
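
The velocity-box handling above zeroes the box before augmenting and then pastes the original pixels back with a second torch.where; a tiny self-contained sketch of that composition (dummy 4x4 frame, a constant fill standing in for cutout/jitter/rand_conv):

import torch

img = torch.arange(16.0).view(1, 1, 4, 4)            # dummy NCHW frame
mask_vbox = torch.zeros_like(img, dtype=torch.bool)
mask_vbox[:, :, :2, :2] = True                       # "velocity box" in the top-left corner

# zero out the box, then augment everything (constant fill stands in for the real augmentation)
obs_without_vbox = torch.where(mask_vbox, torch.zeros_like(img), img)
augmented = obs_without_vbox + 99.0

# paste the untouched box back over the augmented frame
fixed = torch.where(mask_vbox, img, augmented)
assert torch.equal(fixed[:, :, :2, :2], img[:, :, :2, :2])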
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.view(T * B,
                                      *img_shape))  # Fold if T dimension.

        # Both branches build the same rnn_input; only the handling of
        # init_rnn_state depends on use_recurrence.
        rnn_input = torch.cat(
            [
                conv_out.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot.
                prev_reward.view(T, B, 1),
            ],
            dim=2)
        if self.use_recurrence:
            init_rnn_state = None if init_rnn_state is None else tuple(
                init_rnn_state)
        else:
            init_rnn_state = None

        rnn_out, (hn, ) = self.run_rnn(rnn_input, init_rnn_state)

        q = self.head(rnn_out.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn)

        return q, next_rnn_state
Example #5
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute mean, log_std, and value estimate from input state. Infers
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], with T=1,B=1 when not given. Used both in sampler
        and in algorithm (both via the agent).
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)

        obs_flat = observation.view(T * B, -1)
        mu = self.mu(obs_flat)
        v = self.v(obs_flat).squeeze(-1)
        log_std = self.log_std.repeat(T * B, 1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)

        return mu, log_std, v
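
Example #5 keeps log_std as a state-independent parameter and tiles it to one copy per flattened step; a minimal sketch of that tiling, with an assumed action dimension of 3:

import torch
import torch.nn as nn

action_dim, T, B = 3, 5, 2
log_std = nn.Parameter(torch.zeros(action_dim))  # learned, state-independent

tiled = log_std.repeat(T * B, 1)                 # one copy per flattened step
assert tiled.shape == (T * B, action_dim)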
Example #6
    def forward(self, obs, extract_sym_features: bool = False):
        lead_dim, T, B, _ = infer_leading_dims(obs, 1)

        feats = self.feature_extractor(obs.view(T * B, -1))
        value = self.value_predictor(feats).squeeze(-1)  # squeeze trailing singleton dim; callers expect a scalar value per step

        sym_feats = obs if extract_sym_features else None

        if self.categorical:
            action_dist = self.action_predictor(feats)
            action_dist, value = restore_leading_dims((action_dist, value),
                                                      lead_dim, T, B)
            return action_dist, value, sym_feats
        else:
            mu = self.mu_predictor(feats)
            log_std = self.log_std_predictor(feats)
            mu, log_std, value = restore_leading_dims((mu, log_std, value),
                                                      lead_dim, T, B)
            return mu, log_std, value, sym_feats
 def forward(self, image, prev_action, prev_reward):
     """
     Compute action probabilities and value estimate from input state.
     Infers leading dimensions of input: can be [T,B], [B], or []; provides
     returns with same leading dims.  Convolution layers process as [T*B,
     *image_shape], with T=1,B=1 when not given.  Expects uint8 images in
     [0,255] and converts them to float32 in [0,1] (to minimize image data
     storage and transfer).  Used in both sampler and in algorithm (both
     via the agent).
     """
     img = image.type(torch.float)  # Expect torch.uint8 inputs
     img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
     # Infer (presence of) leading dimensions: [T,B], [B], or [].
     lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
     fc_out = self.conv(img.view(T * B, *img_shape))
     pi = F.softmax(self.pi(fc_out), dim=-1)
     v = self.value(fc_out).squeeze(-1)
     # Restore leading dimensions: [T,B], [B], or [], as input.
     pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
     return pi, v
Example #8
    def forward(self, obs, q_or_pi: str, extract_sym_features: bool = False):
        """
        :returns: (q1, q2, sym_features) or (mu, log_std, sym_features)
        """
        if q_or_pi == "q":
            obs, actions = obs
        imgs = obs.img
        lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
        imgs = imgs.view(T * B, *img_shape)
        # NHWC -> NCHW
        imgs = imgs.transpose(3, 2).transpose(2, 1)

        # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
        if self.sym_extractor is not None and extract_sym_features:
            sym_feats = self.sym_extractor(imgs[:, -1:])
        else:
            sym_feats = None

        # convert from [0, 255] uint8 to [0, 1] float32
        imgs = imgs.to(torch.float32)
        imgs = imgs.mul(1.0 / 255)

        vector_obs = obs.vector.view(-1, obs.vector.shape[-1])
        feats = self.feature_extractor(imgs, vector_obs)

        if q_or_pi == "q":
            feats = torch.cat((feats, actions.view(-1, actions.shape[-1])), -1)
            r1 = self.q1(feats)
            r2 = self.q2(feats)
        elif q_or_pi == "pi":
            r = self.pi(feats)
            r1 = r[:, :self.action_dim]  # mu
            r2 = r[:, self.action_dim:]  # log_std
        else:
            raise ValueError("q_or_pi must be 'q' or 'pi'.")

        r1, r2 = restore_leading_dims((r1, r2), lead_dim, T, B)
        if self.sym_extractor is not None and extract_sym_features:
            # have to "restore dims" for sym_feats manually...
            sym_feats = sym_feats.reshape((*r1.shape[:-1], -1))
        return r1, r2, sym_feats
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """
        Compute mean, log_std, and value estimate from input state. Infer
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], and recurrent layers as [T,B,H], with T=1,B=1 when
        not given. Used both in sampler and in algorithm (both via the agent).
        Also returns the next RNN state.
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)

        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp(
                (observation - self.obs_rms.mean) / obs_var.sqrt(),
                -self.norm_obs_clip, self.norm_obs_clip)

        mlp_out = self.mlp(observation.view(T * B, -1))
        lstm_input = torch.cat([
            mlp_out.view(T, B, -1),
            prev_action.view(T, B, -1),
            prev_reward.view(T, B, 1),
        ],
                               dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        outputs = self.head(lstm_out.view(T * B, -1))
        mu = outputs[:, :self._action_size]
        log_std = outputs[:, self._action_size:-1]
        v = outputs[:, -1]

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H]
        next_rnn_state = RnnState(h=hn, c=cn)

        return mu, log_std, v, next_rnn_state
Example #10
    def next(self, actions, observation, prev_action, prev_reward):
        if isinstance(observation, tuple):
            observation = torch.cat(observation, dim=-1)

        lead_dim, T, B, _ = infer_leading_dims(observation,
                                               self._obs_ndim)
        input_obs = observation.view(T * B, -1)
        if self._counter == 0:
            output = self.mlp_loc(input_obs)
            mu, log_std = output.chunk(2, dim=-1)
        elif self._counter == 1:
            assert len(actions) == 1
            action_loc = actions[0].view(T * B, -1)
            model_input = torch.cat((input_obs, action_loc.repeat((1, self._n_tile))), dim=-1)
            output = self.mlp_delta(model_input)
            mu, log_std = output.chunk(2, dim=-1)
        else:
            raise Exception('Invalid self._counter', self._counter)
        mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
        self._counter += 1
        return mu, log_std
Example #11
    def forward(self, image, prev_action, prev_reward):
        """
        Overrides AtariFfModel forward to also run separate
        intrinsic value head.
        """
        img = image.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        fc_out = self.conv(img.view(T * B, *img_shape))
        pi = F.softmax(self.pi(fc_out), dim=-1)
        ext_val = self.value(fc_out).squeeze(-1)
        int_val = self.int_value(fc_out).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, ext_val, int_val = restore_leading_dims((pi, ext_val, int_val),
                                                    lead_dim, T, B)

        return pi, ext_val, int_val
Example #12
    def forward(self,
                observation: torch.Tensor,
                prev_action: torch.Tensor = None,
                prev_state: RSSMState = None):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        observation = observation.reshape(T * B, *img_shape).type(
            self.dtype) / 255.0 - 0.5
        prev_action = prev_action.reshape(T * B, -1).to(self.dtype)
        if prev_state is None:
            prev_state = self.representation.initial_state(
                prev_action.size(0),
                device=prev_action.device,
                dtype=self.dtype)
        state = self.get_state_representation(observation, prev_action,
                                              prev_state)

        action, action_dist = self.policy(state)
        return_spec = ModelReturnSpec(action, state)
        return_spec = buffer_func(return_spec, restore_leading_dims, lead_dim,
                                  T, B)
        return return_spec
Example #13
    def write_videos(self, observation, action, image_pred, post, step=None, n=4, t=25):
        """
        observation shape T,N,C,H,W
        generates n rollouts with the model.
        For t time steps, observations are used to generate state representations.
        Then for time steps t+1:T, uses the state transition model.
        Outputs 3 different frames to video: ground truth, reconstruction, error
        """
        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(observation, 3)
        model = self.agent.model
        ground_truth = observation[:, :n] + 0.5
        reconstruction = image_pred.mean[:t, :n]

        prev_state = post[t - 1, :n]
        prior = model.rollout.rollout_transition(batch_t - t, action[t:, :n], prev_state)
        imagined = model.observation_decoder(get_feat(prior)).mean
        model_frames = torch.cat((reconstruction, imagined), dim=0) + 0.5
        error = (model_frames - ground_truth + 1) / 2
        # concatenate vertically on height dimension
        openl = torch.cat((ground_truth, model_frames, error), dim=3)
        openl = openl.transpose(1, 0)  # N,T,C,H,W
        video_summary('videos/model_error', torch.clamp(openl, 0., 1.), step)
Example #14
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute action Q-value estimates from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        image_shape[0], image_shape[1],...,image_shape[-1]], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32 in [0,1] (to minimize image data
        storage and transfer).  Used in both sampler and in algorithm (both
        via the agent).
        """
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
        q = self.head(conv_out.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        return q
Example #15
    def forward(self, image, prev_action, prev_reward, init_rnn_state):
        """
        Compute action probabilities and value estimate from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        *image_shape], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32 in [0,1] (to minimize image data
        storage and transfer).  Recurrent layers processed as [T,B,H]. Used in
        both sampler and in algorithm (both via the agent).  Also returns the
        next RNN state.
        """
        if self.obs_stats is not None:  # normalize observation
            image = (image - self.obs_mean) / self.obs_std

        img = image.type(torch.float)  # Expect torch.uint8 inputs

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        fc_out = self.conv(img.view(T * B, *img_shape))
        lstm_input = torch.cat(
            [
                fc_out.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot.
            ],
            dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)

        pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
        v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return pi, v, next_rnn_state
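
The recurrent examples above all follow the same convention: the LSTM consumes [T, B, H] input and its state keeps the B dimension as [N, B, H]. A minimal sketch with illustrative sizes:

import torch
import torch.nn as nn

T, B, input_size, hidden_size = 4, 2, 16, 32
lstm = nn.LSTM(input_size, hidden_size)

rnn_input = torch.zeros(T, B, input_size)        # recurrent layers see [T, B, H]
init_rnn_state = None                            # as at the start of an episode
lstm_out, (hn, cn) = lstm(rnn_input, init_rnn_state)

assert lstm_out.shape == (T, B, hidden_size)
assert hn.shape == (1, B, hidden_size)           # state keeps the B dim: [N, B, H]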
Example #16
    def forward(self, observation, prev_action=None, prev_reward=None):
        self.device = observation.device
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        obs = observation.view(T*B, *img_shape)
        obs = obs.permute(0, 3, 1, 2).float() / 255.
        if self.greyscale:
            obs = torch.mean(obs, dim=1, keepdim=True)
        noise_idx = None
        if self.noise_prob:
            obs, noise_idx = salt_and_pepper(obs,self.noise_prob)

        z, mu, logsd = self.encoder(obs)
        reconstruction = self.decoder(z)
        extractor_in = mu if self.deterministic else z

        if self.detach_vae:
            extractor_in = extractor_in.detach()
        extractor_out = self.shared_extractor(extractor_in)

        if self.detach_policy:
            policy_in = extractor_out.detach()
        else:
            policy_in = extractor_out
        if self.detach_value:
            value_in = extractor_out.detach()
        else:
            value_in = extractor_out
        
        act_dist = self.policy(policy_in)
        value = self.value(value_in).squeeze(-1)

        if self.rae:
            latent = mu
        else:
            latent = torch.cat((mu, logsd), dim=1)
        
        act_dist, value, latent, reconstruction = restore_leading_dims((act_dist, value, latent, reconstruction), lead_dim, T, B)
        return act_dist, value, latent, reconstruction, noise_idx
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute mean, log_std, q-value, and termination estimates from input state. Infers
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], with T=1,B=1 when not given. Used both in sampler
        and in algorithm (both via the agent).
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp((observation - self.obs_rms.mean) /
                obs_var.sqrt(), -self.norm_obs_clip, self.norm_obs_clip)

        obs_flat = observation.view(T * B, -1)
        (mu, logstd), beta, q, pi_I, q_ent = self.model(obs_flat)
        log_std = logstd.repeat(T * B, 1, 1)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, q, beta, pi, q_ent = restore_leading_dims((mu, log_std, q, beta, pi_I, q_ent), lead_dim, T, B)
        return mu, log_std, beta, q, pi
Example #18
    def head_forward(self,
                     conv_out,
                     prev_action,
                     prev_reward,
                     logits=False):
        if len(conv_out.shape) > 2:
            lead_dim, T, B, img_shape = infer_leading_dims(conv_out, 3)
        # 1, 1, 32, [64, 7, 7]
        p = self.head(conv_out) # [32, 4, 51]

        if self.distributional:
            if logits:
                p = F.log_softmax(p, dim=-1)
            else:
                p = F.softmax(p, dim=-1)
        else:
            p = p.squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        if len(conv_out.shape) > 2:
            p = restore_leading_dims(p, lead_dim, T, B) # [32, 4, 51]
        # [32, 4, 51]
        return p
Example #19
    def forward(self, observation, prev_action, prev_reward):
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        conv = self.conv(img.view(T * B, *img_shape))

        if self.stop_conv_grad:
            conv = conv.detach()
        if self.normalize_conv_out:
            conv_var = self.conv_rms.var
            conv_var = torch.clamp(conv_var, min=self.var_clip)
            # stddev of uniform [a,b] = (b-a)/sqrt(12), 1/sqrt(12)~0.29
            # then allow [0, 10]?
            conv = torch.clamp(0.29 * conv / conv_var.sqrt(), 0, 10)

        q = self.q_mlp(conv.view(T * B, -1))

        q, conv = restore_leading_dims((q, conv), lead_dim, T, B)
        return q, conv
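
The 0.29 scale in the comment above comes from the standard deviation of a uniform distribution; a quick check:

import math

# stddev of a uniform [a, b] distribution is (b - a) / sqrt(12),
# so for inputs roughly uniform on [0, 1] the scale factor is ~0.29
print(1.0 / math.sqrt(12))  # 0.28867513...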
    def forward(self, observation, prev_action=None, prev_reward=None):
        """
        Compute action Q-value estimates from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        image_shape[0], image_shape[1],...,image_shape[-1]], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32
        Used in both sampler and in algorithm (both via the agent).
        """
        obs = observation.type(torch.float)  # Expect torch.uint8 inputs
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        if len(observation.shape) < 3:
            reshape = (observation.shape[0], -1) if len(observation.shape) == 2 else (-1,)
            view = obs.view(*reshape)
            return self.head(view)

        lead_dim, T, B, _ = infer_leading_dims(obs, 3)

        q = self.head(obs.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        return q
Example #21
    def forward(self, observation, prev_action, prev_reward):
        """Args:
        x: tensor shape [batch_size, input_size]
        train: a boolean scalar.
        loss_coef: a scalar - multiplier on load-balancing losses
        Returns:
        y: a tensor with shape [batch_size, output_size].
        extra_training_loss: a scalar.  This should be added into the overall
        training loss of the model.  The backpropagation of this loss
        encourages all experts to be approximately equally used across a batch.
        """
        train = self.training
        observation = observation.float()

        self.device = observation.device

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, obs_shape = infer_leading_dims(observation, 1)
        observation = observation.view(T * B, *obs_shape)
        action_mask = observation[:, -19:].type(torch.bool)
        observation = observation[:, :-19]

        z = self.encoder(observation)
        gates, load = self.noisy_top_k_gating(z, train)

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(z)
        gates = dispatcher.expert_to_gates()
        expert_outputs = [
            self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
        ]
        y = dispatcher.combine(expert_outputs, device=self.device)
        value = self.value(observation).squeeze(-1)
        y[~action_mask] = -1e24
        y = nn.functional.softmax(y, dim=-1)
        y, value = restore_leading_dims((y, value), lead_dim, T, B)
        return y, value
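
The action masking above writes a very large negative number into invalid logits so the softmax assigns them (effectively) zero probability; a tiny sketch with made-up logits:

import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 3.0]])
action_mask = torch.tensor([[True, False, True]])  # False = invalid action

logits[~action_mask] = -1e24                       # masked logits become effectively -inf
probs = nn.functional.softmax(logits, dim=-1)
assert probs[0, 1] == 0.0                          # invalid action gets zero probability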
Example #22
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""
        obz = observation.type(torch.float)  # Expect torch.uint8 inputs

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)

        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp(
                (observation - self.obs_rms.mean) / obs_var.sqrt(),
                -self.norm_obs_clip, self.norm_obs_clip)

        mlp_out = self.mlp(observation.view(T * B, -1))

        lstm_input = torch.cat([
            mlp_out.view(T, B, -1),
            prev_action.view(T, B, -1),
            prev_reward.view(T, B, 1),
        ],
                               dim=2)

        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)

        q = self.head(lstm_out.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return q, next_rnn_state
Example #23
    def forward(self, obs, prev_action, prev_reward):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""
        # print(obs.shape)
        # print(obs.target_im.shape)
        # print(obs.cur_im.shape)
        # x1 = self.conv(obs.target_im)
        # x2 = self.conv(obs.cur_im)

        # obs.cur_coord
        lead_dim, T, B, img_shape = infer_leading_dims(obs.target_im, 3)

        x1 = self.conv(obs.target_im.view(T * B, *img_shape))
        x2 = self.conv(obs.cur_im.view(T * B, *img_shape))
        x = torch.cat((x1, x2), dim=1)

        # x = self.conv(obs.view(T * B, *img_shape))
        x = self.fc(x.view(T * B, -1))

        pi = F.softmax(self.pi(x), dim=-1)
        v = self.value(x).squeeze(-1)
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)

        return pi, v
Example #24
 def forward(self, observation, prev_action, prev_reward, init_rnn_state):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_dim)
     if init_rnn_state is not None:
         if self.rnn_is_lstm:
             init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> RnnState_pi, RnnState_v
             init_rnn_pi, init_rnn_v = tuple(init_rnn_pi), tuple(init_rnn_v)
         else:
             init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> h, h
     else:
         init_rnn_pi, init_rnn_v = None, None
     o_flat = observation.view(T*B, -1)
     b_pi, b_v = self.body_pi(o_flat), self.body_v(o_flat)
     rnn_input_pi = torch.cat([b_pi.view(T, B, -1), prev_action.view(T, B, -1), prev_reward.view(T, B, 1)], dim=2)
     rnn_input_v = torch.cat([b_v.view(T, B, -1), prev_action.view(T, B, -1), prev_reward.view(T, B, 1)], dim=2)
     rnn_pi, next_rnn_state_pi = self.rnn_pi(rnn_input_pi, init_rnn_pi)
     rnn_v, next_rnn_state_v = self.rnn_v(rnn_input_v, init_rnn_v)
     rnn_pi = rnn_pi.view(T * B, -1)
     rnn_v = rnn_v.view(T * B, -1)
     pi, v = self.pi(rnn_pi), self.v(rnn_v).squeeze(-1)
     pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
     if self.rnn_is_lstm:
         next_rnn_state = DualRnnState(RnnState(*next_rnn_state_pi), RnnState(*next_rnn_state_v))
     else:
         next_rnn_state = DualRnnState(next_rnn_state_pi, next_rnn_state_v)
     return pi, v, next_rnn_state
Example #25
 def forward(self, observation, prev_action, prev_reward):
     obs_shape, T, B, has_T, has_B = infer_leading_dims(
         observation, self._obs_ndim)
     v = self.mlp(observation.view(T * B, -1)).squeeze(-1)
     v = restore_leading_dims(v, T, B, has_T, has_B)
     return v
Example #26
 def forward(self, observation, prev_action, prev_reward):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
     mu = self._output_max * torch.tanh(
         self.mlp(observation.view(T * B, -1)))
     mu = restore_leading_dims(mu, lead_dim, T, B)
     return mu
    def evaluate(self, obs_tuple, act_tensor, update_stats=True):
        # put model into eval mode if necessary
        old_training = self.reward_model.training
        if old_training:
            self.reward_model.eval()

        with torch.no_grad():
            # flatten observations & actions
            obs_image = obs_tuple.observation
            old_dev = obs_image.device
            lead_dim, T, B, _ = infer_leading_dims(obs_image, self.obs_dims)
            # use tree_map so we are able to handle the namedtuple directly
            obs_flat = tree_map(
                lambda t: t.view((T * B, ) + t.shape[lead_dim:]), obs_tuple)
            act_flat = act_tensor.view((T * B, ) + act_tensor.shape[lead_dim:])

            # now evaluate one batch at a time
            reward_tensors = []
            for b_start in range(0, T * B, self.batch_size):
                obs_batch = obs_flat[b_start:b_start + self.batch_size]
                act_batch = act_flat[b_start:b_start + self.batch_size]
                dev_obs = tree_map(lambda t: t.to(self.dev), obs_batch)
                dev_acts = act_batch.to(self.dev)
                dev_reward = self.reward_model(dev_obs, dev_acts)
                reward_tensors.append(dev_reward.to(old_dev))

            # join together the batch results
            new_reward_flat = torch.cat(reward_tensors, 0)
            new_reward = restore_leading_dims(new_reward_flat, lead_dim, T, B)
            task_ids = obs_tuple.task_id
            assert new_reward.shape == task_ids.shape

        # put back into training mode if necessary
        if old_training:
            self.reward_model.train(old_training)

        # normalise if necessary
        if self.normalise:
            mus = []
            stds = []
            for task_id, averager in enumerate(self.rew_running_averages):
                if update_stats:
                    id_sub = task_ids.view((-1, )) == task_id
                    if not torch.any(id_sub):
                        continue
                    rew_sub = new_reward.view((-1, ))[id_sub]
                    averager.update(rew_sub)

                mus.append(averager.mean.item())
                stds.append(averager.std.item())

            mus = new_reward.new_tensor(mus)
            stds = new_reward.new_tensor(stds)
            denom = torch.max(stds.new_tensor(1e-3), stds / self.target_std)

            denom_sub = denom[task_ids]
            mu_sub = mus[task_ids]

            # only bother applying result if we've actually seen an update
            # before (otherwise reward will be insane)
            new_reward = (new_reward - mu_sub) / denom_sub

        return new_reward
    def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
        """
        Compute the loss for a batch of data.  This includes computing the model and reward losses on the given data,
        as well as using the dynamics model to generate additional rollouts, which are used for the actor and value
        components of the loss.
        :param samples: samples from replay
        :param sample_itr: sample iteration
        :param opt_itr: optimization iteration
        :return: FloatTensor containing the loss
        """
        model = self.agent.model

        observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
        action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = samples.all_reward[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = reward.unsqueeze(2)
        done = samples.done
        done = done.unsqueeze(2)

        # Extract tensors from the Samples object
        # They all have the batch_t dimension first, but we'll put the batch_b dimension first.
        # Also, we convert all tensors to floats so they can be fed into our models.

        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            observation, 3)
        # squeeze batch sizes to single batch dimension for imagination roll-out
        batch_size = batch_t * batch_b

        # normalize image
        observation = observation.type(self.type) / 255.0 - 0.5
        # embed the image
        embed = model.observation_encoder(observation)

        prev_state = model.representation.initial_state(batch_b,
                                                        device=action.device,
                                                        dtype=action.dtype)
        # Rollout model by taking the same series of actions as the real model
        prior, post = model.rollout.rollout_representation(
            batch_t, embed, action, prev_state)
        # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
        # since we're going to do a new rollout starting from each state visited in each batch.

        # Compute losses for each component of the model

        # Model Loss
        feat = get_feat(post)
        image_pred = model.observation_decoder(feat)
        reward_pred = model.reward_model(feat)
        reward_loss = -torch.mean(reward_pred.log_prob(reward))
        image_loss = -torch.mean(image_pred.log_prob(observation))
        pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
        if self.use_pcont:
            pcont_pred = model.pcont(feat)
            pcont_target = self.discount * (1 - done.float())
            pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
        prior_dist = get_dist(prior)
        post_dist = get_dist(post)
        div = torch.mean(
            torch.distributions.kl.kl_divergence(post_dist, prior_dist))
        div = torch.max(div, div.new_full(div.size(), self.free_nats))
        model_loss = self.kl_scale * div + reward_loss + image_loss
        if self.use_pcont:
            model_loss += self.pcont_scale * pcont_loss

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Actor Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            if self.use_pcont:
                # "Last step could be terminal." Done in TF2 code, but unclear why
                flat_post = buffer_method(post[:-1, :], 'reshape',
                                          (batch_t - 1) * (batch_b), -1)
            else:
                flat_post = buffer_method(post, 'reshape', batch_size, -1)
        # Rollout the policy for self.horizon steps. Variable names with imag_ indicate this data is imagined not real.
        # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
        with FreezeParameters(self.model_modules):
            imag_dist, _ = model.rollout.rollout_policy(
                self.horizon, model.policy, flat_post)

        # Use state features (deterministic and stochastic) to predict the image and reward
        imag_feat = get_feat(
            imag_dist)  # [horizon, batch_t * batch_b, feature_size]
        # Assumes these are normal distributions. In the TF code it uses the mode, but for a normal distribution mean = mode.
        # If we want to use other distributions we'll have to fix this.
        # We calculate the target here so no grad necessary

        # freeze model parameters as only action model gradients needed
        with FreezeParameters(self.model_modules + self.value_modules):
            imag_reward = model.reward_model(imag_feat).mean
            value = model.value_model(imag_feat).mean
        # Compute the exponential discounted sum of rewards
        if self.use_pcont:
            with FreezeParameters([model.pcont]):
                discount_arr = model.pcont(imag_feat).mean
        else:
            discount_arr = self.discount * torch.ones_like(imag_reward)
        returns = self.compute_return(imag_reward[:-1],
                                      value[:-1],
                                      discount_arr[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self.discount_lambda)
        # Make the top row 1 so the cumulative product starts with discount^0
        discount_arr = torch.cat(
            [torch.ones_like(discount_arr[:1]), discount_arr[1:]])
        discount = torch.cumprod(discount_arr[:-1], 0)
        actor_loss = -torch.mean(discount * returns)

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            value_feat = imag_feat[:-1].detach()
            value_discount = discount.detach()
            value_target = returns.detach()
        value_pred = model.value_model(value_feat)
        log_prob = value_pred.log_prob(value_target)
        value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # loss info
        with torch.no_grad():
            prior_ent = torch.mean(prior_dist.entropy())
            post_ent = torch.mean(post_dist.entropy())
            loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent,
                                 post_ent, div, reward_loss, image_loss,
                                 pcont_loss)

            if self.log_video:
                if opt_itr == self.train_steps - 1 and sample_itr % self.video_every == 0:
                    self.write_videos(observation,
                                      action,
                                      image_pred,
                                      post,
                                      step=sample_itr,
                                      n=self.video_summary_b,
                                      t=self.video_summary_t)

        return model_loss, actor_loss, value_loss, loss_info
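
The discount handling in the actor loss above forces the first row to 1 before the cumulative product, so the first imagined step is undiscounted; a small sketch of just that step with illustrative values:

import torch

# imagined per-step discounts for horizon 4, batch 1 (illustrative values)
discount_arr = torch.tensor([[0.99], [0.99], [0.50], [0.99]])

# force the first row to 1 so the cumulative product starts at discount^0,
# i.e. the first imagined step is never discounted
discount_arr = torch.cat([torch.ones_like(discount_arr[:1]), discount_arr[1:]])
discount = torch.cumprod(discount_arr[:-1], 0)
print(discount.squeeze(-1))  # tensor([1.0000, 0.9900, 0.4950])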
Example #29
 def forward(self, observation, prev_action, prev_reward):
     lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
     action = torch.randint(low=0, high=self.num_actions, size=(T * B, ))
     action = restore_leading_dims(action, lead_dim, T, B)
     return action
    def compute_bonus(self, observations, prev_actions, actions):
        #------------------------------------------------------------#
        lead_dim, T, B, img_shape = infer_leading_dims(observations, 3)

        # hacky dimension add for when you have only one environment
        if prev_actions.dim() == 1: 
            prev_actions = prev_actions.view(1, 1, -1)
        if actions.dim() == 1:
            actions = actions.view(1, 1, -1)
        #------------------------------------------------------------#

        # generate belief states
        belief_states, gru_output_states = self.forward(observations, prev_actions)
        self.gru_states = None  # only because we're processing exactly one episode per batch

        # slice beliefs and actions
        belief_states_t = belief_states.clone()[:-self.horizon] # slice off last timesteps
        belief_states_tm1 = belief_states.clone()[:-self.horizon-1]
        
        action_seqs_t = torch.zeros((T-self.horizon, B, self.horizon*self.action_size), device=self.device) # placeholder
        action_seqs_tm1 = torch.zeros((T-self.horizon-1, B, (self.horizon+1)*self.action_size), device=self.device) # placeholder
        for i in range(len(actions)-self.horizon):
            if i != len(actions)-self.horizon-1:
                action_seq_tm1 = actions.clone()[i:i+self.horizon+1]
                action_seq_tm1 = torch.transpose(action_seq_tm1, 0, 1)
                action_seq_tm1 = torch.reshape(action_seq_tm1, (action_seq_tm1.shape[0], -1))
                action_seqs_tm1[i] = action_seq_tm1
            action_seq_t = actions.clone()[i:i+self.horizon]
            action_seq_t = torch.transpose(action_seq_t, 0, 1)
            action_seq_t = torch.reshape(action_seq_t, (action_seq_t.shape[0], -1))
            action_seqs_t[i] = action_seq_t
        
        # make forward model predictions
        # forward_model_{h} predicts h steps ahead, so the t-branch uses self.horizon
        # and the tm1-branch uses self.horizon + 1 (horizons 1-10 supported, as before)
        flat_obs = img_shape[0] * img_shape[1] * img_shape[2]
        model_t = getattr(self, 'forward_model_{}'.format(self.horizon))
        model_tm1 = getattr(self, 'forward_model_{}'.format(self.horizon + 1))
        predicted_states_tm1 = model_tm1(belief_states_tm1, action_seqs_tm1.detach()).view(-1, B, flat_obs)  # (-1, B, prod(img_shape))
        predicted_states_t = model_t(belief_states_t, action_seqs_t.detach()).view(-1, B, flat_obs)  # (-1, B, prod(img_shape))

        predicted_states_tm1 = torch.sigmoid(predicted_states_tm1)
        true_obs_tm1 = observations.clone()[self.horizon:-1].view(-1, *predicted_states_tm1.shape[1:]).type(torch.float)
        predicted_states_t = torch.sigmoid(predicted_states_t)
        true_obs_t = observations.clone()[self.horizon:].view(-1, *predicted_states_t.shape[1:]).type(torch.float)

        # generate losses
        losses_tm1 = nn.functional.binary_cross_entropy(predicted_states_tm1, true_obs_tm1, reduction='none')
        losses_tm1 = torch.sum(losses_tm1, dim=-1)/losses_tm1.shape[-1] # average of each feature for each environment at each timestep (T, B, ave_loss_over_feature)
        losses_t = nn.functional.binary_cross_entropy(predicted_states_t, true_obs_t, reduction='none')
        losses_t = torch.sum(losses_t, dim=-1)/losses_t.shape[-1]


        
        # subtract losses to get rewards (r[t+H-1] = losses[t-1] - losses[t])
        r_int = torch.zeros((T, B), device=self.device)
        r_int[self.horizon:len(losses_t)+self.horizon-1] = losses_tm1 - losses_t[1:] # time zero reward is set to 0 (L[-1] doesn't exist)
        # r_int[self.horizon:len(losses_t)+self.horizon-1] = losses_t[1:] - losses_tm1
        # r_int[1:len(losses_t)] = losses_tm1 - losses_t[1:]
        # r_int[1:len(losses_t)] = losses_t[1:] - losses_tm1

        # r_int = nn.functional.relu(r_int)

        return r_int*self.beta