Example #1
    def __init__(self, h_dim, action_encoding, observation_type, nr_inputs,
                 cnn_channels, encoder_batch_norm):
        super().__init__()
        self.observation_type = observation_type
        self.action_encoding = action_encoding
        # Decoder input: h, phi_z and the encoded action
        encoding_dimension = h_dim + h_dim + action_encoding

        # For observation
        self.dec, self.dec_mean, self.dec_std = encoder_decoder.get_decoder(
            observation_type,
            nr_inputs,
            cnn_channels,
            batch_norm=encoder_batch_norm)

        self.cnn_output_dimension = encoder_decoder.get_cnn_output_dimension(
            observation_type, cnn_channels)
        self.cnn_output_number = reduce(mul, self.cnn_output_dimension, 1)

        if encoder_batch_norm:
            self.linear_obs_decoder = nn.Sequential(
                nn.Linear(encoding_dimension, self.cnn_output_number),
                nn.BatchNorm1d(self.cnn_output_number), nn.ReLU())
        else:
            self.linear_obs_decoder = nn.Sequential(
                nn.Linear(encoding_dimension, self.cnn_output_number),
                nn.ReLU())
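
A minimal, self-contained sketch (not part of the original module) of the projection pattern used above: the flattened CNN feature size is computed with reduce(mul, ...), the concatenated [h, phi_z, action] vector is projected up with a Linear + ReLU, and the result is reshaped into the feature-map shape expected by the CNN decoder. The concrete shape (64, 7, 7) and the h_dim / action_encoding values below are assumptions for illustration, not taken from encoder_decoder.py.

from functools import reduce
from operator import mul

import torch
import torch.nn as nn

h_dim, action_encoding = 128, 16
cnn_output_dimension = (64, 7, 7)                          # assumed decoder input shape
cnn_output_number = reduce(mul, cnn_output_dimension, 1)   # 64 * 7 * 7 = 3136

encoding_dimension = h_dim + h_dim + action_encoding       # h, phi_z, encoded action
linear_obs_decoder = nn.Sequential(
    nn.Linear(encoding_dimension, cnn_output_number), nn.ReLU())

latent = torch.randn(8, encoding_dimension)                # batch of 8 latent vectors
flat = linear_obs_decoder(latent)                          # (8, 3136)
feature_map = flat.view(-1, *cnn_output_dimension)         # (8, 64, 7, 7), ready for the CNN decoder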
Example #2
    def __init__(self, phi_x_dim, nr_actions, action_encoding,
                 observation_type, nr_inputs, cnn_channels,
                 encoder_batch_norm):
        super().__init__()
        self.action_encoding = action_encoding
        self.phi_x_dim = phi_x_dim
        assert action_encoding > 0

        self.phi_x = encoder_decoder.get_encoder(observation_type,
                                                 nr_inputs,
                                                 cnn_channels,
                                                 batch_norm=encoder_batch_norm)

        self.cnn_output_dimension = encoder_decoder.get_cnn_output_dimension(
            observation_type, cnn_channels)
        self.cnn_output_number = reduce(mul, self.cnn_output_dimension, 1)

        if encoder_batch_norm:
            self.action_encoder = nn.Sequential(
                nn.Linear(nr_actions, action_encoding),
                nn.BatchNorm1d(action_encoding), nn.ReLU())
        else:
            self.action_encoder = nn.Sequential(
                nn.Linear(nr_actions, action_encoding), nn.ReLU())
        self.nr_actions = nr_actions
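
As a hedged usage sketch, discrete actions would first be one-hot encoded before being passed through the action_encoder defined above. The values nr_actions=6 and action_encoding=16 are assumptions for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

nr_actions, action_encoding = 6, 16
action_encoder = nn.Sequential(
    nn.Linear(nr_actions, action_encoding), nn.ReLU())

actions = torch.tensor([0, 3, 5])                            # batch of discrete action indices
one_hot = F.one_hot(actions, num_classes=nr_actions).float() # (3, 6)
encoded = action_encoder(one_hot)                            # (3, 16)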
Example #3
    def __init__(
            self,
            action_space,
            nr_inputs,
            observation_type,
            action_encoding,
            # obs_encoding,
            cnn_channels,
            h_dim,
            init_function,
            encoder_batch_norm,
            policy_batch_norm,
            prior_loss_coef,
            obs_loss_coef,
            detach_encoder,
            batch_size,
            num_particles,
            particle_aggregation,
            z_dim,
            resample):
        super().__init__(action_space, encoding_dimension=h_dim)
        self.init_function = init_function
        self.num_particles = num_particles
        self.particle_aggregation = particle_aggregation
        self.batch_size = batch_size
        self.obs_loss_coef = float(obs_loss_coef)
        self.prior_loss_coef = float(prior_loss_coef)
        self.observation_type = observation_type
        self.encoder_batch_norm = encoder_batch_norm
        self.policy_batch_norm = policy_batch_norm
        self.detach_encoder = detach_encoder
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.resample = resample

        # All encoder/decoders are defined in the encoder_decoder.py file
        self.cnn_output_dimension = encoder_decoder.get_cnn_output_dimension(
            observation_type, cnn_channels)
        self.cnn_output_number = reduce(mul, self.cnn_output_dimension, 1)

        # Naming conventions
        phi_x_dim = self.cnn_output_number

        if action_space.__class__.__name__ == "Discrete":
            action_shape = action_space.n
        else:
            action_shape = action_space.shape[0]

        ## Create all relevant networks

        # Encodes actions and observations into a latent state
        self.encoding_network = VRNN_encoding(
            phi_x_dim=phi_x_dim,
            nr_actions=action_shape,
            action_encoding=action_encoding,
            observation_type=observation_type,
            nr_inputs=nr_inputs,
            cnn_channels=cnn_channels,
            encoder_batch_norm=encoder_batch_norm)

        # Computes p(z_t|h_{t-1}, a_{t-1})
        self.transition_network = VRNN_transition(
            h_dim=h_dim, z_dim=z_dim, action_encoding=action_encoding)

        # Computes h_t=f(h_{t-1}, z_t, a_{t-1}, o_t)
        self.deterministic_transition_network = VRNN_deterministic_transition(
            z_dim=z_dim,
            phi_x_dim=phi_x_dim,
            h_dim=h_dim,
            action_encoding=action_encoding)

        # Computes p(o_t|h_t, z_t, a_{t-1})
        self.emission_network = VRNN_emission(
            h_dim=h_dim,
            action_encoding=action_encoding,
            observation_type=observation_type,
            nr_inputs=nr_inputs,
            cnn_channels=cnn_channels,
            encoder_batch_norm=encoder_batch_norm)

        # Computes q(z_t|h_{t-1}, a_{t-1}, o_t)
        self.proposal_network = VRNN_proposal(
            z_dim=z_dim,
            h_dim=h_dim,
            phi_x_dim=phi_x_dim,
            action_encoding=action_encoding,
            encoder_batch_norm=encoder_batch_norm)

        # dim covers z, h and w, where z and h each have h_dim and w is a scalar
        dim = 2 * h_dim + 1
        if particle_aggregation == 'rnn' and self.num_particles > 1:
            self.particle_gru = nn.GRU(dim, h_dim, batch_first=True)

        elif self.num_particles == 1:
            self.particle_gru = nn.Linear(dim, h_dim)

        self.reset_parameters()
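
The particle_gru set up above aggregates per-particle summaries into a single belief encoding for the policy. Below is a minimal sketch of that wiring with assumed batch size, particle count and h_dim; the real forward pass lives elsewhere in the model.

import torch
import torch.nn as nn

batch_size, num_particles, h_dim = 4, 15, 128
dim = 2 * h_dim + 1                                   # per-particle [summary, h, log-weight]

particle_gru = nn.GRU(dim, h_dim, batch_first=True)

summaries = torch.randn(batch_size, num_particles, h_dim)
hiddens = torch.randn(batch_size, num_particles, h_dim)
log_weights = torch.randn(batch_size, num_particles, 1)

particles = torch.cat([summaries, hiddens, log_weights], dim=-1)  # (4, 15, 257)
_, last_hidden = particle_gru(particles)                          # (1, 4, 128)
belief_encoding = last_hidden.squeeze(0)                          # (4, 128), fed to the policy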
Example #4
    def __init__(self, action_space, nr_inputs, observation_type,
                 action_encoding, cnn_channels, h_dim, init_function,
                 encoder_batch_norm, policy_batch_norm, obs_loss_coef,
                 detach_encoder, batch_size, resample):

        super().__init__(action_space, encoding_dimension=h_dim)
        self.h_dim = h_dim
        self.init_function = init_function
        self.batch_size = batch_size
        self.obs_loss_coef = float(obs_loss_coef)
        self.encoder_batch_norm = encoder_batch_norm
        self.policy_batch_norm = policy_batch_norm
        self.observation_type = observation_type
        self.detach_encoder = detach_encoder
        self.resample = resample

        # All encoders and decoders are defined centrally in one file
        self.encoder = encoder_decoder.get_encoder(
            observation_type,
            nr_inputs,
            cnn_channels,
            batch_norm=encoder_batch_norm)

        self.cnn_output_dimension = encoder_decoder.get_cnn_output_dimension(
            observation_type, cnn_channels)
        self.cnn_output_number = reduce(mul, self.cnn_output_dimension, 1)

        # The decoder takes latent_state + action_encoding.
        # linear_obs_decoder is an FC network projecting the latent state onto
        # the dimensionality expected by the CNN decoder.
        encoding_dimension = h_dim + action_encoding
        if encoder_batch_norm:
            self.linear_obs_decoder = nn.Sequential(
                nn.Linear(encoding_dimension, self.cnn_output_number),
                nn.BatchNorm1d(self.cnn_output_number), nn.ReLU())
        else:
            self.linear_obs_decoder = nn.Sequential(
                nn.Linear(encoding_dimension, self.cnn_output_number),
                nn.ReLU())

        self.decoder = encoder_decoder.get_decoder(
            observation_type,
            nr_inputs,
            cnn_channels,
            batch_norm=encoder_batch_norm)

        # Actions are encoded using one FC layer.
        if action_encoding > 0:
            if action_space.__class__.__name__ == "Discrete":
                action_shape = action_space.n
            else:
                action_shape = action_space.shape[0]
            if encoder_batch_norm:
                self.action_encoder = nn.Sequential(
                    nn.Linear(action_shape, action_encoding),
                    nn.BatchNorm1d(action_encoding), nn.ReLU())
            else:
                self.action_encoder = nn.Sequential(
                    nn.Linear(action_shape, action_encoding), nn.ReLU())

        self.gru = nn.GRUCell(self.cnn_output_number + action_encoding, h_dim)
        # if self.encoder_batch_norm:
        #     self.gru_bn = nn.BatchNorm1d(h_dim)

        if observation_type == 'fc':
            self.obs_criterion = nn.MSELoss()
        else:
            self.obs_criterion = nn.BCEWithLogitsLoss()

        self.train()
        self.reset_parameters()
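
A minimal sketch of the recurrent update this constructor sets up: the flattened observation features and the encoded action are concatenated and fed to the GRUCell together with the previous hidden state. All shapes below are assumptions for illustration.

import torch
import torch.nn as nn

batch_size, cnn_output_number, action_encoding, h_dim = 4, 3136, 16, 128

gru = nn.GRUCell(cnn_output_number + action_encoding, h_dim)

encoded_obs = torch.randn(batch_size, cnn_output_number)     # flattened CNN features
encoded_action = torch.randn(batch_size, action_encoding)    # output of the action encoder
h_prev = torch.zeros(batch_size, h_dim)                      # previous hidden state

h_next = gru(torch.cat([encoded_obs, encoded_action], dim=1), h_prev)  # (4, 128)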