Example 1
    def forward(self, imgs, extract_sym_features: bool = False):
        lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
        imgs = imgs.view(T * B, *img_shape)
        # NHWC -> NCHW
        imgs = imgs.transpose(3, 2).transpose(2, 1)

        # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
        if self.sym_extractor is not None and extract_sym_features:
            sym_feats = self.sym_extractor(imgs[:, -1:])
        else:
            sym_feats = None

        # convert from [0, 255] uint8 to [0, 1] float32
        imgs = imgs.to(torch.float32)
        imgs.mul_(1.0 / 255)
        m = self.normalize_mean.to(imgs.device)
        s = self.normalize_std.to(imgs.device)
        imgs.sub_(m).div_(s)

        feats = self.conv1(imgs)
        feats = self.bn1(feats)
        feats = self.relu(feats)
        feats = self.maxpool(feats)
        feats = self.layer1(feats)
        feats = self.relu(feats)
        feats = self.layer2(feats)
        feats = self.relu(feats)
        feats = self.layer3(feats)
        feats = self.relu(feats)
        feats = self.flatten(feats)
        feats = self.linear(feats)
        feats = self.relu(feats)

        value = self.value_predictor(feats).squeeze(-1)  # squeeze to the expected value shape

        if self.categorical:
            action_dist = self.action_predictor(feats)
            action_dist, value = restore_leading_dims((action_dist, value),
                                                      lead_dim, T, B)
            if self.sym_extractor is not None and extract_sym_features:
                # have to "restore dims" for sym_feats manually...
                sym_feats = sym_feats.reshape((*action_dist.shape[:-1], -1))
                # sym_feats = torch.rand_like(action_dist)
            return action_dist, value, sym_feats
        else:
            mu = self.mu_predictor(feats)
            log_std = self.log_std_predictor(feats)
            mu, log_std, value = restore_leading_dims((mu, log_std, value),
                                                      lead_dim, T, B)
            if self.sym_extractor is not None and extract_sym_features:
                # have to "restore dims" for sym_feats manually...
                sym_feats = sym_feats.reshape((*mu.shape[:-1], -1))
                # sym_feats = torch.rand_like(action_dist)
            return mu, log_std, value, sym_feats
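
Every example in this listing relies on the same leading-dims contract. A minimal sketch of that contract, assuming behavior equivalent to rlpyt's infer_leading_dims / restore_leading_dims (the real utilities handle a few more cases; names below carry a _sketch suffix to mark them as illustrative):

    import torch

    def infer_leading_dims_sketch(tensor, data_ndim):
        """Return (lead_dim, T, B, data_shape) for inputs with 0, 1, or 2 leading dims."""
        lead_dim = tensor.dim() - data_ndim  # 0: [], 1: [B], 2: [T,B]
        assert lead_dim in (0, 1, 2)
        T = tensor.shape[0] if lead_dim == 2 else 1
        B = tensor.shape[lead_dim - 1] if lead_dim >= 1 else 1
        return lead_dim, T, B, tensor.shape[lead_dim:]

    def restore_leading_dims_sketch(tensors, lead_dim, T, B):
        """Reshape outputs computed as [T*B, ...] back to the input's leading dims."""
        is_seq = isinstance(tensors, (tuple, list))
        tensors = tensors if is_seq else (tensors,)
        if lead_dim == 2:
            tensors = tuple(t.view((T, B) + tuple(t.shape[1:])) for t in tensors)
        elif lead_dim == 0:
            tensors = tuple(t.squeeze(0) for t in tensors)  # drop the dummy batch dim
        return tensors if is_seq else tensors[0]

The pattern in each forward is therefore: infer the leading dims, flatten to [T*B, ...] for feedforward layers (or [T, B, ...] for recurrent ones), then restore the original leading dims on the outputs.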
Example 2
    def forward(self, observation):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation
        conv, conv_layers = self.conv(img.view(T * B, *img_shape))  # lists all layers
        c = self.head(conv_layers[-1].view(T * B, -1))

        c, conv = restore_leading_dims((c, conv), lead_dim, T, B)
        conv_layers = restore_leading_dims(conv_layers, lead_dim, T, B)
        return c, conv, conv_layers  # include conv_outs for local-stdim losses
Example 3
 def forward(self,
             observation,
             prev_action,
             prev_reward,
             action,
             detach_encoder=False):  # dummy
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
     q_input = torch.cat(
         [observation.view(T * B, -1),
          action.view(T * B, -1)], dim=1)
     q1 = self.q1_mlp(q_input).squeeze(-1)
     q1 = restore_leading_dims(q1, lead_dim, T, B)
     q2 = self.q2_mlp(q_input).squeeze(-1)
     q2 = restore_leading_dims(q2, lead_dim, T, B)
     return q1, q2
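
The q1_mlp / q2_mlp modules themselves are not shown above. A minimal sketch of how such twin Q heads might be built (the helper name and hidden sizes are illustrative assumptions, not taken from the source):

    import torch.nn as nn

    def make_q_mlp(obs_size, action_size, hidden_sizes=(256, 256)):
        """Plain MLP mapping concat(observation, action) -> scalar Q-value."""
        layers, in_size = [], obs_size + action_size
        for h in hidden_sizes:
            layers += [nn.Linear(in_size, h), nn.ReLU()]
            in_size = h
        layers.append(nn.Linear(in_size, 1))
        return nn.Sequential(*layers)

    # e.g. in __init__ (hypothetical, matching the attribute names used above):
    # self.q1_mlp = make_q_mlp(obs_size, action_size)
    # self.q2_mlp = make_q_mlp(obs_size, action_size)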
Example 4
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.reshape(T * B,
                                         *img_shape))  # Fold if T dimension.
        features = conv_out.reshape(T * B, -1)

        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hidden_state,
                   cell_state) = self.lstm(features.reshape(T, B, -1),
                                           init_rnn_state)
        head_input = torch.cat((features, lstm_out.reshape(T * B, -1)), dim=-1)

        pi = self.pi_head(head_input)
        pi = torch.softmax(pi, dim=-1)
        # pi = torch.sigmoid(pi - 2)
        value = self.value_head(head_input).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, value = restore_leading_dims((pi, value), lead_dim, T, B)
        state = RnnState(h=hidden_state, c=cell_state)
        return pi, value, state
Example 5
 def forward(self, observation, prev_action, prev_reward):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
     output = self.mlp(observation.view(T * B, -1))
     alpha, beta = output[:, :self._action_size], output[:, self._action_size:]
     alpha, beta = restore_leading_dims((alpha, beta), lead_dim, T, B)
     return alpha, beta
Example 6
    def forward(self,
                observation,
                prev_action,
                prev_reward,
                init_rnn_state=None):
        """
        Compute mean, log_std, and value estimate from input state. Infers
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], with T=1,B=1 when not given. Used both in sampler
        and in algorithm (both via the agent).
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation.state,
                                               self._obs_ndim)
        assert not torch.any(torch.isnan(observation.state)), 'obs elem is nan'

        obs_flat = observation.state.reshape(T * B, -1)
        # obs_flat = self.layer_norm(obs_flat)
        # features = self.shared_mlp(obs_flat)
        action = self.mu_head(obs_flat)

        v = self.v_head(obs_flat).squeeze(-1)

        # mu, std = (action[:, :self.action_size], action[:, self.action_size:])
        # std = self.softplus(std)

        pi_output = self.softplus(action * 8) + 1
        mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, std, v = restore_leading_dims((mu, std, v), lead_dim, T, B)

        return mu, std, v, torch.zeros(1, B, 1), None  # return fake rnn state
Example 7
    def forward(self, image, prev_action, prev_reward, init_rnn_state):
        """Feedforward layers process as [T*B,H], recurrent ones as [T,B,H].
        Return same leading dims as input, can be [T,B], [B], or [].
        (Same forward used for sampling and training.)"""
        img = image.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        fc_out = self.conv(img.view(T * B, *img_shape))
        lstm_input = torch.cat([
            fc_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
            ], dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
        v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return pi, v, next_rnn_state
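
A hedged usage sketch of the [T,B] calling convention this forward expects; the image shape, action count, and the one-hot encoding of prev_action are assumptions consistent with the comments above, not taken from the model definition:

    import torch
    import torch.nn.functional as F

    T, B, A = 20, 8, 6                       # time steps, batch size, action count
    image = torch.randint(0, 256, (T, B, 4, 84, 84), dtype=torch.uint8)
    prev_action = F.one_hot(torch.randint(0, A, (T, B)), A).float()  # [T,B,A], one-hot
    prev_reward = torch.zeros(T, B)                                   # [T,B]
    init_rnn_state = None                                             # LSTM starts from zeros
    # pi, v, next_rnn_state = model(image, prev_action, prev_reward, init_rnn_state)
    # pi: [T,B,A], v: [T,B], next_rnn_state.h / .c: [num_layers,B,H]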
Example 8
    def forward(self, image, prev_action=None, prev_reward=None, init_rnn_state=None):
        img = image.type(torch.float)
        img = img.mul_(1. / 255)

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        fc_out = self.conv(img.view(T * B, *img_shape))
        lstm_input = torch.cat([
            fc_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
            ], dim=2)

        init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
        v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return pi, v, next_rnn_state
        
Example 9
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)

        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp(
                (observation - self.obs_rms.mean) / obs_var.sqrt(),
                -self.norm_obs_clip, self.norm_obs_clip)

        mlp_out = self.mlp(observation.view(T * B, -1))
        lstm_input = torch.cat(
            [
                mlp_out.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot.
                prev_reward.view(T, B, 1),
            ],
            dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
        v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return pi, v, next_rnn_state
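
The obs_rms module used for normalization is not shown. A minimal sketch of a running mean/variance tracker consistent with the clamping above, assuming it is updated from batches of observations elsewhere (class and attribute names are illustrative):

    import torch
    import torch.nn as nn

    class RunningMeanStdSketch(nn.Module):
        """Track element-wise mean/var of observations with parallel batch updates."""
        def __init__(self, shape):
            super().__init__()
            self.register_buffer("mean", torch.zeros(shape))
            self.register_buffer("var", torch.ones(shape))
            self.register_buffer("count", torch.tensor(1e-4))

        @torch.no_grad()
        def update(self, x):  # x: [N, *shape]
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            batch_count = x.shape[0]
            delta = batch_mean - self.mean
            tot = self.count + batch_count
            m2 = self.var * self.count + batch_var * batch_count \
                 + delta.pow(2) * self.count * batch_count / tot
            self.mean.copy_(self.mean + delta * batch_count / tot)
            self.var.copy_(m2 / tot)
            self.count.copy_(tot)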
Example 10
 def forward(self, observation, prev_action, prev_reward, init_rnn_state):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_dim)
     if init_rnn_state is not None:
         if self.rnn_is_lstm:
             init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> RnnState_pi, RnnState_v
             init_rnn_pi, init_rnn_v = tuple(init_rnn_pi), tuple(init_rnn_v)
         else:
             init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> h, h
     else:
         init_rnn_pi, init_rnn_v = None, None
     o_flat = self.preprocessor(observation.view(T*B))
     b_pi, b_v = self.body_pi(o_flat), self.body_v(o_flat)
     p_a, p_r = prev_action.view(T, B, -1), prev_reward.view(T, B, 1)
     pi_inp_list = [b_pi.view(T, B, -1)] + ([p_a] if self.p_a in [1, 3] else []) + ([p_r] if self.p_r in [1, 3] else [])
     v_inp_list = [b_v.view(T, B, -1)] + ([p_a] if self.p_a in [2, 3] else []) + ([p_r] if self.p_r in [2, 3] else [])
     rnn_input_pi = torch.cat(pi_inp_list, dim=2)
     rnn_input_v = torch.cat(v_inp_list, dim=2)
     rnn_pi, next_rnn_state_pi = self.rnn_pi(rnn_input_pi, init_rnn_pi)
     rnn_v, next_rnn_state_v = self.rnn_v(rnn_input_v, init_rnn_v)
     rnn_pi = rnn_pi.view(T * B, -1)
     rnn_v = rnn_v.view(T * B, -1)
     pi, v = self.pi(rnn_pi), self.v(rnn_v).squeeze(-1)
     pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
     if self.rnn_is_lstm:
         next_rnn_state = DualRnnState(RnnState(*next_rnn_state_pi), RnnState(*next_rnn_state_v))
     else:
         next_rnn_state = DualRnnState(next_rnn_state_pi, next_rnn_state_v)
     return pi, v, next_rnn_state
Example 11
    def forward(self, image, prev_action, prev_reward):
        """
        Compute action probabilities and value estimate from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        *image_shape], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32 in [0,1] (to minimize image data
        storage and transfer).  Used in both sampler and in algorithm (both
        via the agent).
        """
        # img = image.type(torch.float)  # Expect torch.uint8 inputs
        # img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
        # Expects [0-1] float

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(image, 3)

        fc_out = self.conv(image.view(T * B, *img_shape))
        mu = self.mu(fc_out)
        v = self.v(fc_out).squeeze(-1)
        log_std = self.log_std.repeat(T * B, 1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
        return mu, log_std, v
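
Here log_std is a state-independent parameter expanded to the batch rather than a network output. A minimal sketch of that piece in isolation (the action size and initial value are illustrative assumptions):

    import torch
    import torch.nn as nn

    action_size, init_log_std = 3, 0.0
    log_std = nn.Parameter(init_log_std * torch.ones(action_size))  # e.g. self.log_std

    T, B = 5, 4
    batch_log_std = log_std.repeat(T * B, 1)  # [T*B, action_size]; grads flow to the one parameter
    std = batch_log_std.exp()                 # used by a Gaussian action distribution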
Example 12
    def forward(self, observation):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        obs = observation.view(T * B, *img_shape)
        obs = (obs.permute(0, 3, 1, 2).float() / 255.) * 2 - 1

        z_fake = torch.normal(torch.zeros(obs.shape[0], self.zdim),
                              torch.ones(obs.shape[0], self.zdim)).to(obs.device)
        z_real = self.e(obs).reshape(obs.shape[0], self.zdim)
        x_fake = self.g(z_fake).reshape(obs.shape[0], -1)
        x_real = obs.view(obs.shape[0], -1)

        # extractor_in=z_fake
        # if self.detach_encoder:
        #     extractor_in = extractor_in.detach()
        # extractor_out = self.shared_extractor(extractor_in)

        # if self.detach_policy:
        #     policy_in = extractor_out.detach()
        # else:
        #     policy_in = extractor_out
        # if self.detach_value:
        #     value_in = extractor_out.detach()
        # else:
        #     value_in = extractor_out

        # act_dist = self.policy(policy_in)
        # value = self.value(value_in).squeeze(-1)
        act_dist, value = 0, 0
        label_real = self.d(z_real, x_real)
        label_fake = self.d(z_fake, x_fake)
        latent, reconstruction = restore_leading_dims((z_fake, x_fake),
                                                      lead_dim, T, B)
        return act_dist, value, latent, reconstruction, label_real, label_fake
Example 13
 def forward(self, observation, prev_action, prev_reward):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
     obs_flat = observation.view(T * B, -1)
     pi = F.softmax(self.pi(obs_flat), dim=-1)
     v = self.v(obs_flat).squeeze(-1)
     pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
     return pi, v
Example 14
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute mean, log_std, and value estimates from input state. Includes
        value estimates of both distinct intrinsic and extrinsic returns. See
        rlpyt MujocoFfModel for more information on this function.
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)

        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp((observation - self.obs_rms.mean) /
                obs_var.sqrt(), -self.norm_obs_clip, self.norm_obs_clip)

        obs_flat = observation.view(T * B, -1)
        mu = self.mu(obs_flat)
        ev = self.v(obs_flat).squeeze(-1)
        iv = self.iv(obs_flat).squeeze(-1)  # Added intrinsic value MLP forward pass
        log_std = self.log_std.repeat(T * B, 1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, ev, iv = restore_leading_dims((mu, log_std, ev, iv), lead_dim, T, B)

        return mu, log_std, ev, iv
Example 15
    def forward(self, image, prev_action=None, prev_reward=None):
        img = image.type(torch.float)
        img = img.mul_(1. / 255)

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        img = img.view(T * B, *img_shape)

        relu = torch.nn.functional.relu

        y = relu(self.conv1(img))
        y = self.attention(y)
        y = relu(self.conv2(y))
        y = relu(self.conv3(y))

        fc_out = self.fc(y.view(T * B, -1))

        pi = torch.nn.functional.softmax(self.pi(fc_out), dim=-1)
        v = self.value(fc_out).squeeze(-1)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        # T: time dimension, B: batch dimension
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)

        return pi, v
Example 16
    def forward(self, image, prev_action=None, prev_reward=None):
        img = image.type(torch.float)
        img = img.mul_(1. / 255)

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        img = img.view(T * B, *img_shape)

        relu = torch.nn.functional.relu

        l1 = relu(self.conv1(img))
        l2 = relu(self.conv2(l1))
        y = relu(self.conv3(l2))
        #fc_out = self.fc(y.view(T * B, -1))
        fc_out = self.fc(y)

        l1 = self.projector(l1)
        l2 = self.projector2(l2)

        c1, g1 = self.attn1(l1, fc_out)
        c2, g2 = self.attn2(l2, fc_out)
        g = torch.cat((g1, g2), dim=1)  # batch_sizexC
        # classification layer

        pi = torch.nn.functional.softmax(self.pi(g), dim=-1)
        v = self.value(g).squeeze(-1)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        # T: time dimension, B: batch dimension
        pi, v = restore_leading_dims((pi, v), lead_dim, T, B)

        return pi, v
Example 17
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute action Q-value estimates from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        *image_shape], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32 in [0,1] (to minimize image data
        storage and transfer).  Used in both sampler and in algorithm (both
        via the agent).
        """
        img = observation.type(torch.float)  # Expect torch.uint8 inputs

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        # features = self.mlp(img.reshape(T * B, -1))

        conv_out = self.conv(img.view(T * B,
                                      *img_shape))  # Fold if T dimension.
        # q = self.head(self.dropout(conv_out.view(T * B, -1)))
        q = self.head(conv_out.view(T * B, -1))
        # q = self.head(features)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        return q
Example 18
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute mean, log_std, q-value, and termination estimates from input state. Infers
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], with T=1,B=1 when not given. Used both in sampler
        and in algorithm (both via the agent).
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp((observation - self.obs_rms.mean) /
                obs_var.sqrt(), -self.norm_obs_clip, self.norm_obs_clip)

        obs_flat = observation.view(T * B, -1)
        mu, logstd = self.mu(obs_flat)
        q = self.q(obs_flat)
        log_std = logstd.repeat(T * B, 1, 1)
        beta = self.beta(obs_flat)
        pi = self.pi_omega(obs_flat)
        I = self.pi_omega_I(obs_flat)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, q, beta, pi, I = restore_leading_dims((mu, log_std, q, beta, pi, I), lead_dim, T, B)
        pi = pi * I  # Torch multinomial will normalize
        return mu, log_std, beta, q, pi
Example 19
    def forward(self, observation, prev_action, prev_reward):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""

        observation = self.encoder(observation)

        if self.pooling is not None:
            if self.pooling == 'average':
                pooled = observation.mean(-2)

            elif self.pooling == 'max':
                pooled = observation.max(-2).values

            pooled = pooled.unsqueeze(-2).expand_as(observation)
            observation = torch.cat([observation, pooled], dim=-1)

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        obs_flat = observation.view(T * B * self._n_pop, -1)

        mu = self.mu(obs_flat)
        v = self.v(obs_flat).squeeze(-1)
        log_std = self.log_std.repeat(T * B, 1)

        mu = mu.view(T * B, self._n_pop, -1)
        v = v.view(T * B, self._n_pop)
        log_std = log_std.view(T * B, self._n_pop, -1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)

        return mu, log_std, v
Example 20
    def forward(self, observation, prev_action, prev_reward, state):
        lead_dim, T, B, _ = infer_leading_dims(observation.state, 1)
        # print(f'step in episode {observation.step_in_episode}')
        aux_loss = None
        if T == 1:
            transformer_output, state = self.sample_forward(
                observation.state, state)
            value = torch.zeros(B)
        elif T == self.sequence_length:
            transformer_output, aux_loss = self.optim_forward(
                observation.state, state)
            value = self.value_head(transformer_output).reshape(T * B, -1)
        else:
            raise NotImplementedError

        pi_output = self.pi_head(transformer_output).view(T * B, -1)
        mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]
        std = self.softplus(std - 1)
        mu = torch.tanh(mu)

        # pi_output = self.softplus(pi_output * 1) + 1
        # mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]

        mu, std, value = restore_leading_dims((mu, std, value), lead_dim, T, B)
        return mu, std, value, state, aux_loss
Example 21
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        conv = self.conv(img.view(T * B, *img_shape))

        if self.stop_conv_grad:
            conv = conv.detach()

        fc1 = F.relu(self.fc1(conv.view(T * B, -1)))
        lstm_input = torch.cat(
            [
                fc1.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot
                prev_reward.view(T, B, 1),
            ],
            dim=2,
        )
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        if self._skip_lstm:
            lstm_out = lstm_out.view(T * B, -1) + fc1
        pi_v = self.pi_v_head(lstm_out.view(T * B, -1))
        pi = F.softmax(pi_v[:, :-1], dim=-1)
        v = pi_v[:, -1]
        pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
        next_rnn_state = RnnState(h=hn, c=cn)
        return pi, v, next_rnn_state, conv
Example 22
    def forward(self, observation, prev_action=None, prev_reward=None):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        obs = observation.view(T * B, *img_shape)

        z, mu, logsd = self.encoder(obs)
        reconstruction = self.decoder(z)
        extractor_in = mu if self.deterministic else z

        if self.detach_vae:
            extractor_in = extractor_in.detach()
        extractor_out = self.shared_extractor(extractor_in)

        if self.detach_policy:
            policy_in = extractor_out.detach()
        else:
            policy_in = extractor_out
        if self.detach_value:
            value_in = extractor_out.detach()
        else:
            value_in = extractor_out

        act_dist = self.policy(policy_in)
        value = self.value(value_in).squeeze(-1)

        if self.rae:
            latent = mu
        else:
            latent = torch.cat((mu, logsd), dim=1)

        act_dist, value, latent, reconstruction = restore_leading_dims(
            (act_dist, value, latent, reconstruction), lead_dim, T, B)
        return act_dist, value, latent, reconstruction
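
The encoder's (z, mu, logsd) return signature suggests the usual reparameterization trick; a hedged sketch of an encoder head consistent with that interface (the conv trunk, class name, and sizes are assumptions):

    import torch
    import torch.nn as nn

    class VaeEncoderHeadSketch(nn.Module):
        """Map flattened conv features to (z, mu, logsd) via reparameterization."""
        def __init__(self, in_size, zdim):
            super().__init__()
            self.mu_head = nn.Linear(in_size, zdim)
            self.logsd_head = nn.Linear(in_size, zdim)

        def forward(self, features):  # features: [N, in_size]
            mu = self.mu_head(features)
            logsd = self.logsd_head(features)
            z = mu + torch.randn_like(mu) * logsd.exp()  # z ~ N(mu, sd^2)
            return z, mu, logsd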
Example 23
    def forward(self, image, prev_action, prev_reward):
        """
        Compute action probabilities and value estimate from input state.
        Infers leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Convolution layers process as [T*B,
        *image_shape], with T=1,B=1 when not given.  Expects uint8 images in
        [0,255] and converts them to float32 in [0,1] (to minimize image data
        storage and transfer).  Used in both sampler and in algorithm (both
        via the agent).
        """
        img = image.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        fc_out = self.conv(img.view(T * B, *img_shape))
        pi = self.pi(fc_out)
        q = self.q(fc_out)
        beta = self.beta(fc_out)
        pi_omega_I = self.pi_omega(fc_out)
        I = self.pi_omega_I(fc_out)
        pi_omega_I = pi_omega_I * I
        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, q, beta, pi_omega_I = restore_leading_dims(
            (pi, q, beta, pi_omega_I), lead_dim, T, B)
        return pi, q, beta, pi_omega_I
Example 24
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """ Compute action probabilities and value estimate

        NOTE: Rnn concatenates previous action and reward to input
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        # Convert init_rnn_state appropriately
        if init_rnn_state is not None:
            if self.rnn_type == 'gru':
                init_rnn_state = init_rnn_state.h  # namedarraytuple -> h
            else:
                init_rnn_state = tuple(init_rnn_state)  # namedarraytuple -> tuple (h, c)
        oh = self.preprocessor(observation)  # Leave in TxB format for lstm
        rnn_input = torch.cat([
            oh.view(T,B,-1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
            ], dim=2)
        rnn_out, h = self.rnn(rnn_input, init_rnn_state)
        rnn_out = rnn_out.view(T*B, -1)
        pi, beta, q, pi_omega, q_ent = self.model(rnn_out)
        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, beta, q, pi_omega, q_ent = restore_leading_dims((pi, beta, q, pi_omega, q_ent), lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        if self.rnn_type == 'gru':
            next_rnn_state = GruState(h=h)
        else:
            next_rnn_state = RnnState(*h)
        return pi, beta, q, pi_omega, next_rnn_state
Example 25
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.view(T * B,
                                      *img_shape))  # Fold if T dimension.

        lstm_input = torch.cat(
            [
                conv_out.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot.
                prev_reward.view(T, B, 1),
            ],
            dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)

        q = self.head(lstm_out.view(T * B, -1))

        # Restore leading dimensions: [T,B], [B], or [], as input.
        q = restore_leading_dims(q, lead_dim, T, B)
        # Model should always leave B-dimension in rnn state: [N,B,H].
        next_rnn_state = RnnState(h=hn, c=cn)

        return q, next_rnn_state
Example 26
    def forward(self, observation, prev_action, prev_reward):
        """
        Compute mean, log_std, and value estimate from input state. Infers
        leading dimensions of input: can be [T,B], [B], or []; provides
        returns with same leading dims.  Intermediate feedforward layers
        process as [T*B,H], with T=1,B=1 when not given. Used both in sampler
        and in algorithm (both via the agent).
        """
        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)

        if self.normalize_observation:
            obs_var = self.obs_rms.var
            if self.norm_obs_var_clip is not None:
                obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
            observation = torch.clamp(
                (observation - self.obs_rms.mean) / obs_var.sqrt(),
                -self.norm_obs_clip, self.norm_obs_clip)

        obs_flat = observation.view(T * B, -1)
        mu = self.mu(obs_flat)
        v = self.v(obs_flat).squeeze(-1)
        log_std = self.log_std.repeat(T * B, 1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)

        return mu, log_std, v
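
A hedged usage note on the leading-dims behavior described in the docstring; model construction and sizes here are assumptions, only the shape handling is the point:

    import torch

    obs_dim, B, T = 11, 16, 32
    obs_batch = torch.randn(B, obs_dim)    # sampler-style input: [B, obs_dim]
    obs_seq = torch.randn(T, B, obs_dim)   # training-style input: [T, B, obs_dim]
    # mu, log_std, v = model(obs_batch, None, None)  # -> [B,A], [B,A], [B]
    # mu, log_std, v = model(obs_seq, None, None)    # -> [T,B,A], [T,B,A], [T,B]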
Example 27
    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or []."""

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        obs_shape, T, B, has_T, has_B = infer_leading_dims(
            observation, self._obs_n_dim)

        mlp_out = self.mlp(observation.view(T * B, -1))
        lstm_input = torch.cat([
            mlp_out.view(T, B, -1),
            prev_action.view(T, B, -1),
            prev_reward.view(T, B, 1),
        ],
                               dim=2)
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        outputs = self.head(lstm_out.view(T * B, -1))
        mu = outputs[:, :self._action_size]
        log_std = outputs[:, self._action_size:-1]
        v = outputs[:, -1].squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        mu, log_std, v = restore_leading_dims((mu, log_std, v), T, B, has_T,
                                              has_B)
        # Model should always leave B-dimension in rnn state: [N,B,H]
        next_rnn_state = RnnState(h=hn, c=cn)

        return mu, log_std, v, next_rnn_state
Example 28
    def forward(self, observation, prev_action, prev_reward):
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation

        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        conv = self.conv(img.view(T * B, *img_shape))

        if self.stop_conv_grad:
            conv = conv.detach()
        if self.normalize_conv_out:
            conv_var = self.conv_rms.var
            conv_var = torch.clamp(conv_var, min=self.var_clip)
            # stddev of uniform [a,b] = (b-a)/sqrt(12), 1/sqrt(12)~0.29
            # then allow [0, 10]?
            conv = torch.clamp(0.29 * conv / conv_var.sqrt(), 0, 10)

        pi_v = self.pi_v_mlp(conv.view(T * B, -1))
        pi = F.softmax(pi_v[:, :-1], dim=-1)
        v = pi_v[:, -1]

        pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
        return pi, v, conv
Example 29
 def forward(self, observation, prev_action, prev_reward):
     obs_shape, T, B, has_T, has_B = infer_leading_dims(
         observation, self._obs_ndim)
     mu = self._output_max * torch.tanh(
         self.mlp(observation.view(T * B, -1)))
     mu = restore_leading_dims(mu, T, B, has_T, has_B)
     return mu
Example 30
 def forward(self, observation, prev_action, prev_reward):
     lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
     output = self.mlp(observation.view(T * B, -1))
     mu, log_std = output[:, :self._action_size], output[:, self._action_size:]
     mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
     return mu, log_std
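
The mlp here emits mu and log_std concatenated along the last dimension; a minimal sketch of constructing such a head (helper name and hidden sizes are illustrative assumptions):

    import torch.nn as nn

    def make_gaussian_pi_mlp(obs_size, action_size, hidden=(256, 256)):
        """MLP whose output splits into mu (first half) and log_std (second half)."""
        layers, in_size = [], obs_size
        for h in hidden:
            layers += [nn.Linear(in_size, h), nn.Tanh()]
            in_size = h
        layers.append(nn.Linear(in_size, 2 * action_size))
        return nn.Sequential(*layers)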