Beispiel #1
0
 def choose_action(self, observation):
     """Sample an action from the actor's diagonal Gaussian policy.

     The actor's second output is exponentiated before it is used as the
     Normal scale.  The log-probability of the drawn action is stored on
     ``self.log_probs`` (moved to ``self.actor.device``).

     Returns:
         The sampled action tensor.
     """
     mean, log_std = self.actor.forward(observation)
     policy = Independent(Normal(mean, T.exp(log_std)), 1)
     action = policy.sample()
     self.log_probs = policy.log_prob(action).to(self.actor.device)
     return action
Beispiel #2
0
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     """Compute a tanh-squashed Gaussian action and its log-probability.

     The actor must return a tuple that is unpacked into ``Normal``.  In
     deterministic evaluation mode the first tuple element is used instead
     of a reparameterised sample.  The result is squashed with tanh,
     rescaled with ``_action_scale`` / ``_action_bias``, optionally
     perturbed with exploration noise and finally clamped to ``_range``.

     Returns:
         A ``Batch`` with ``logits``, ``act``, ``state``, ``dist`` and
         ``log_prob`` (shape ``(..., 1)``).
     """
     observation = batch[input]
     logits, hidden = self.actor(observation, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)

     deterministic = self._deterministic_eval and not self.training
     pre_tanh = logits[0] if deterministic else dist.rsample()
     squashed = torch.tanh(pre_tanh)
     act = squashed * self._action_scale + self._action_bias

     # Change-of-variables term for the tanh + affine map; the epsilon
     # keeps the logarithm's argument away from zero.
     jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
     gaussian_log_prob = dist.log_prob(pre_tanh).unsqueeze(-1)
     log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

     # Noise is added after the log-probability has been computed, and
     # only while training and not inside an update.
     add_noise = (self._noise is not None and self.training
                  and not self.updating)
     if add_noise:
         act += to_torch_as(self._noise(act.shape), act)
     low, high = self._range[0], self._range[1]
     act = act.clamp(low, high)
     return Batch(
         logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
	def forward(self, input, kl_coef):
		"""Encode the batch, integrate the latent ODE and score held-out points.

		``input`` is the 4-tuple ``(b, m, train_m, test_m)``; only ``b`` and
		``test_m`` are used directly here, the full tuple is forwarded to
		``self.ode_rnn``.

		Returns:
			``(loss, mse, masked_output)`` where ``loss`` is the negated
			log-sum-exp over the batch of (scaled log-likelihood minus
			``kl_coef`` times the KL term).
		"""
		b, _, _, test_m = input

		# Posterior parameters of the initial latent state.
		mean, std = self.ode_rnn(input)	# (batch_size, LO_hidden_size) * 2

		device = self.param['device']
		prior = Normal(torch.tensor([0.0], device = device), torch.tensor([1.0], device = device))

		# Reparameterised draw: z0 = mean + eps * std with eps ~ N(0, 1).
		eps = prior.sample(mean.shape).squeeze(-1)
		z0 = mean + eps * std

		latent_path = odeint(self.ode_func, z0, b[0, :, 0], rtol = self.param['rtol'], atol = self.param['atol']) # (num_time_points, batch_size, LO_hidden_size)
		output = self.output_output(latent_path.permute(1, 0, 2)).squeeze(2)

		# KL(q(z0) || N(0, 1)), averaged over axis 1.
		kl_div = kl_divergence(Normal(mean, std), prior).mean(axis = 1)

		# Keep only the positions selected by the test mask.
		held_out = test_m.bool()
		out_shape = (self.param['batch_size'], self.param['total_points'] - self.param['obs_points'])
		masked_output = output[held_out].reshape(*out_shape)
		target = b[:, :, 1][held_out].reshape(*out_shape)

		gaussian = Independent(Normal(loc = masked_output, scale = self.param['obsrv_std']), 1)
		likelihood = gaussian.log_prob(target) / output.shape[1]

		loss = -torch.logsumexp(likelihood - kl_coef * kl_div, 0)
		mse = self.mse_func(masked_output, target)

		return loss, mse, masked_output
		
		
Beispiel #4
0
class IndependentNormal(Distribution):
    """Diagonal Gaussian whose non-leading dimensions form one event.

    Thin wrapper around ``Independent(Normal(loc, scale), loc.dim() - 1)``:
    every dimension of ``loc`` except the first is reinterpreted as part of
    the event, so ``log_prob`` and ``entropy`` are summed over them.

    Args:
        loc: Mean tensor; must have at least one dimension.
        scale: Standard-deviation tensor (positive).
        validate_args: Forwarded to the wrapped distributions and to
            ``Distribution.__init__``.
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # base_dist has to exist before Distribution.__init__ runs: argument
        # validation looks up every key of ``arg_constraints`` as an
        # attribute, which is served by the ``loc`` / ``scale`` properties.
        self.base_dist = Independent(Normal(loc=loc,
                                            scale=scale,
                                            validate_args=validate_args),
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        super().__init__(self.base_dist.batch_shape,
                         self.base_dist.event_shape,
                         validate_args=validate_args)

    @property
    def loc(self):
        """Mean of the wrapped Normal."""
        return self.base_dist.base_dist.loc

    @property
    def scale(self):
        """Standard deviation of the wrapped Normal."""
        return self.base_dist.base_dist.scale

    @property
    def support(self):
        """Support of the wrapped distribution (real-valued, not positive)."""
        return self.base_dist.support

    def log_prob(self, value):
        """Return the log-density of ``value`` summed over the event dims."""
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample without gradient flow to the parameters."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterised (differentiable) sample."""
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        """Return the entropy summed over the event dims."""
        return self.base_dist.entropy()
Beispiel #5
0
def get_true_posterior_samples_linear_gaussian_uniform_prior(
    observation: torch.Tensor, prior: Independent, num_samples: int = 1000, std=1,
):
    """Rejection-sample a Gaussian centred on the observation within the prior.

    Candidates are drawn from ``MultivariateNormal(observation, std * I)``
    and kept only where ``prior.log_prob`` is finite.  Sampling repeats until
    ``num_samples`` candidates have been accepted.

    NOTE(review): ``std`` multiplies the identity *covariance* directly, so
    it acts as a variance — confirm this is intended.  The loop does not
    terminate if the prior never accepts a candidate.

    Returns:
        The accepted samples concatenated along the first dimension.
    """
    observation = utils.torchutils.atleast_2d(observation)
    assert observation.ndim == 2, "needs batch dimension in observation"
    dim = observation.shape[1]
    posterior = MultivariateNormal(
        loc=observation, covariance_matrix=std * torch.eye(dim)
    )

    accepted = []
    still_needed = num_samples
    while still_needed > 0:
        # Only request as many candidates as are still missing.
        proposals = posterior.sample(sample_shape=(still_needed,))
        inside_prior = torch.isfinite(prior.log_prob(proposals))
        n_inside = inside_prior.sum()
        if not n_inside:
            continue
        accepted.append(proposals[inside_prior,])
        still_needed -= n_inside.item()

    return torch.cat(accepted)
Beispiel #6
0
 def test_independent_expand(self):
     """Check shapes and log-probs of ``Independent.expand`` on all examples.

     For every example distribution, every valid number of reinterpreted
     batch dims and a few leading shapes, the expanded distribution must
     sample with the expected shape and agree with the unexpanded one on
     ``log_prob``, ``event_shape`` and ``batch_shape``.
     """
     leading_shapes = [
         torch.Size(),
         torch.Size((2, )),
         torch.Size((2, 3)),
     ]
     for Dist, params in EXAMPLES:
         for param in params:
             base = Dist(**param)
             for ndims in range(len(base.batch_shape) + 1):
                 for leading in leading_shapes:
                     indep = Independent(base, ndims)
                     target_shape = leading + indep.batch_shape
                     expanded = indep.expand(target_shape)
                     draw = expanded.sample()
                     self.assertEqual(draw.shape,
                                      target_shape + indep.event_shape)
                     self.assertEqual(
                         expanded.log_prob(draw),
                         indep.log_prob(draw),
                     )
                     self.assertEqual(expanded.event_shape,
                                      indep.event_shape)
                     self.assertEqual(expanded.batch_shape, target_shape)
Beispiel #7
0
    def trajectory(self, current_state):
        '''
        Roll the policy forward from ``current_state`` until the environment
        returns ``None`` as the next state.

        Each step queries ``self.forward`` for the Gaussian parameters, draws
        an action (detached from the graph) and asks ``self.env`` for the
        resulting state and reward.

        Returns a list with one tuple per step:
        [(state, action, reward, log_prob(action)), ...]
        '''
        history = []
        state = current_state
        while True:
            mu, sigma = self.forward(state)
            policy = Independent(Normal(mu, sigma), 1)
            # The action itself carries no gradient; the stored log-prob
            # still depends on mu and sigma.
            action = policy.rsample().detach()
            next_state, reward = self.env.state_and_reward(state, action)
            history.append((state, action, reward, policy.log_prob(action)))
            if next_state is None:
                # The environment signalled the end of the trajectory.
                return history
            state = next_state
Beispiel #8
0
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     """Compute a tanh-squashed Gaussian action and its log-probability.

     The actor must return a tuple that is unpacked into ``Normal``.  In
     deterministic evaluation mode the first tuple element is used instead
     of a reparameterised sample.  The returned ``act`` is the tanh of that
     value; the log-probability is corrected for the squashing (and for the
     action-space half-range when ``action_scaling`` is enabled).

     Returns:
         A ``Batch`` with ``logits``, ``act``, ``state``, ``dist`` and
         ``log_prob`` (shape ``(..., 1)``).
     """
     observation = batch[input]
     logits, hidden = self.actor(observation, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)

     deterministic = self._deterministic_eval and not self.training
     act = logits[0] if deterministic else dist.rsample()
     gaussian_log_prob = dist.log_prob(act).unsqueeze(-1)

     # Correction for tanh squashing of a Gaussian sample; see the SAC
     # paper (arXiv 1801.01290), appendix C, Eq. 21.
     if self.action_scaling and self.action_space is not None:
         half_range = (self.action_space.high - self.action_space.low) / 2.0
         action_scale = to_torch_as(half_range, act)
     else:
         action_scale = 1.0  # type: ignore
     squashed_action = torch.tanh(act)
     jacobian = action_scale * (1 - squashed_action.pow(2)) + self.__eps
     log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

     return Batch(logits=logits,
                  act=squashed_action,
                  state=hidden,
                  dist=dist,
                  log_prob=log_prob)
Beispiel #9
0
    def forward(self, state, eval=False, with_log_prob=False):
        """Map a state to a bounded action and, optionally, its log-probability.

        Args:
            state: Input to the first layer ``fc1``.
            eval: When true, use the Gaussian mean instead of sampling and
                return ``None`` as log-probability.
            with_log_prob: When sampling, also return the log-probability of
                the squashed, scaled action.

        Returns:
            ``(a, log_prob)`` with ``a = ctrl_range * tanh(u)``; ``log_prob``
            is ``None`` unless ``with_log_prob`` is set and ``eval`` is false.
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(state))))

        mu = self.fc3(hidden)
        # Clip log_sigma as done in Haarnoja's SAC implementation:
        # https://github.com/haarnoja/sac.git
        log_sigma = torch.clamp(self.fc4(hidden), -20.0, 2.0)
        distribution = Independent(Normal(mu, torch.exp(log_sigma)), 1)

        log_prob = None
        if eval:
            u = mu
        else:
            # rsample() keeps the draw differentiable w.r.t. mu and sigma.
            u = distribution.rsample()
            if with_log_prob:
                # Per-dimension half log-Jacobian of u -> ctrl_range * tanh(u),
                # written with softplus for numerical stability.
                half_log_jac = (np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u -
                                F.softplus(-2.0 * u))
                log_prob = distribution.log_prob(u) - 2.0 * torch.sum(
                    half_log_jac, dim=1)

        # tanh bounds the action to (-ctrl_range, ctrl_range) per dimension.
        return self.ctrl_range * torch.tanh(u), log_prob
Beispiel #10
0
    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute a tanh-squashed, rescaled Gaussian action and its log-prob.

        The actor must return a tuple that is unpacked into ``Normal``.  In
        deterministic evaluation mode the first tuple element is used instead
        of a reparameterised sample.

        Returns:
            A ``Batch`` with ``logits``, ``act``, ``state``, ``dist`` and
            ``log_prob`` (shape ``(..., 1)``).
        """
        observation = batch[input]
        logits, hidden = self.actor(observation, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)

        deterministic = self._deterministic_eval and not self.training
        pre_tanh = logits[0] if deterministic else dist.rsample()
        squashed = torch.tanh(pre_tanh)
        act = squashed * self._action_scale + self._action_bias

        # Gaussian log-prob followed by the correction for tanh squashing;
        # see the SAC paper (arXiv 1801.01290), appendix C, Eq. 21.
        # __eps avoids taking the log of zero or a negative number.
        jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
        gaussian_log_prob = dist.log_prob(pre_tanh).unsqueeze(-1)
        log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

        return Batch(logits=logits,
                     act=act,
                     state=hidden,
                     dist=dist,
                     log_prob=log_prob)
Beispiel #11
0
class TanhNormal(Distribution):
    """Diagonal Gaussian squashed through tanh (copied from Kaixhi).

    The last dimension of ``loc`` / ``scale`` is the event dimension; all
    leading dimensions are batch dimensions.

    Args:
        loc: Mean of the pre-squash Gaussian.
        scale: Standard deviation of the pre-squash Gaussian.
    """
    # No constrained constructor arguments are exposed as attributes; an
    # explicit empty mapping keeps Distribution.__init__ from warning.
    arg_constraints = {}
    has_rsample = True
    # Margin that keeps artanh and the log-Jacobian finite at the boundary.
    _EPS = 1e-6

    def __init__(self, loc, scale):
        self.normal = Independent(Normal(loc, scale), 1)
        # Report the wrapped distribution's shapes instead of empty ones.
        super().__init__(self.normal.batch_shape, self.normal.event_shape)

    def sample(self, sample_shape=torch.Size()):
        """Draw a squashed sample without gradient flow to the parameters."""
        return torch.tanh(self.normal.sample(sample_shape))

    # samples with re-parametrization trick (differentiable)
    def rsample(self, sample_shape=torch.Size()):
        """Draw a squashed, differentiable sample."""
        return torch.tanh(self.normal.rsample(sample_shape))

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        """Return the log-density of a squashed ``value`` in ``[-1, 1]``.

        Values are clamped slightly inside the open interval first: tanh
        saturates to exactly +/-1 in floating point, and artanh(+/-1) is
        infinite.
        """
        eps = self._EPS
        value = value.clamp(-1 + eps, 1 - eps)
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|, summed over the event dim
        log_det = torch.log1p(-value.pow(2) + eps).sum(dim=-1)
        return self.normal.log_prob(inv_value) - log_det

    @property
    def mean(self):
        """tanh of the Gaussian mean (not the true mean of the squashed law)."""
        return torch.tanh(self.normal.mean)

    def get_std(self):
        """Return the standard deviation of the pre-squash Gaussian."""
        return self.normal.stddev
Beispiel #12
0
class VAE(nn.Module):
    """Variational auto-encoder wrapper around an injected encoder/decoder.

    ``forward`` stores the prior ``pz``, the approximate posterior ``qz_x``
    and the drawn latent ``z`` on the instance; ``compute_loss`` reads them,
    so it must be called after ``forward`` on the same batch.
    """

    def __init__(self, encoder, decoder, device=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Fall back to the first CUDA device when one is available.
        self.device = device if device is not None else torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    def forward(self, x=None):
        """Encode ``x``, draw a reparameterised latent and decode it.

        The encoder's second output is exponentiated to obtain the scale of
        the approximate posterior.
        """
        batch = x.size(0)
        latent = self.encoder.latent_dims

        mu, sigma = self.encoder(x)
        prior_loc = torch.zeros(batch, latent).to(self.device)
        prior_scale = torch.ones(batch, latent).to(self.device)
        self.pz = Independent(Normal(loc=prior_loc, scale=prior_scale),
                              reinterpreted_batch_ndims=1)
        self.qz_x = Independent(Normal(loc=mu, scale=torch.exp(sigma)),
                                reinterpreted_batch_ndims=1)

        self.z = self.qz_x.rsample()
        return self.decoder(self.z)

    def compute_loss(self, x, y, scale_kl=False):
        """Return the negative single-sample ELBO and its two averaged parts.

        Args:
            x: Target batch; scored under a ContinuousBernoulli whose last
                three dimensions form one event.
            y: Decoder logits.
            scale_kl: Falsy for no scaling, otherwise a multiplier applied
                to the prior/posterior term.

        Returns:
            ``(loss, kl_mean, px_mean)``; the last two are Python floats.
            ``kl`` here is ``log p(z) - log q(z|x)``, i.e. a sample estimate
            of the *negative* KL divergence.
        """
        likelihood = Independent(ContinuousBernoulli(logits=y),
                                 reinterpreted_batch_ndims=3)
        px = likelihood.log_prob(x)

        latent = self.z
        kl = self.pz.log_prob(latent) - self.qz_x.log_prob(latent)
        if scale_kl:
            kl = kl * scale_kl

        elbo = px + kl
        return -elbo.mean(), kl.mean().item(), px.mean().item()

    def rmse(self, input, target):
        """Return the root of the mean squared error between the tensors."""
        return F.mse_loss(input, target).sqrt()
    def _loss_vae(self, x):
        """Return the batch-averaged negative ELBO and the decoder output.

        The encoder output is split along dim 1 at ``self.latent_dim``: the
        first part is the posterior mean, the rest is exponentiated to give
        the posterior scale.  The likelihood is a Bernoulli over the last
        dimension, parameterised by the decoder's logits.
        """
        n = x.size(0)
        k = self.latent_dim
        enc = self.encoder(x)

        prior = Independent(Normal(loc=torch.zeros(n, k).to(self.device),
                                   scale=torch.ones(n, k).to(self.device)),
                            reinterpreted_batch_ndims=1)
        posterior = Independent(Normal(loc=enc[:, :k],
                                       scale=torch.exp(enc[:, k:])),
                                reinterpreted_batch_ndims=1)

        z = posterior.rsample()
        decoder_output = self.decoder(z)
        likelihood = Independent(Bernoulli(logits=decoder_output),
                                 reinterpreted_batch_ndims=1)

        # Single-sample ELBO: log p(x|z) + log p(z) - log q(z|x).
        elbo = likelihood.log_prob(x) + prior.log_prob(z) - posterior.log_prob(z)
        return -elbo.mean(), decoder_output
Beispiel #14
0
 def forward(self, obs, act=None, deterministic=False):
     """Build the Gaussian policy for ``obs`` and score an action under it.

     Args:
         obs: Input to ``self.mu_layer``.
         act: Optional action to evaluate.  When omitted, the distribution
             mean (``deterministic``) or a reparameterised sample is used.
         deterministic: Selects the mean instead of sampling.

     Returns:
         ``(pi, act, log_prob)``.
     """
     mean = self.mu_layer(obs)
     # log_std_layer is used directly as a tensor, not called on obs.
     pi = Independent(Normal(mean, torch.exp(self.log_std_layer)), 1)
     if act is None:
         if deterministic:
             act = pi.mean
         else:
             act = pi.rsample()
     return pi, act, pi.log_prob(act)
Beispiel #15
0
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
	"""Return the per-row Gaussian log-likelihood divided by the row length.

	``obsrv_std`` is repeated along its first dimension to match the last
	dimension of ``mu_2d``.  When that dimension is empty, a scalar zero on
	the device of ``data_2d`` is returned instead.  ``indices`` is unused.
	"""
	n_data_points = mu_2d.size()[-1]

	if n_data_points <= 0:
		# Nothing to score.
		return torch.zeros([1]).to(get_device(data_2d)).squeeze()

	scale = obsrv_std.repeat(n_data_points)
	gaussian = Independent(Normal(loc = mu_2d, scale = scale), 1)
	return gaussian.log_prob(data_2d) / n_data_points
Beispiel #16
0
    def act(self, obs, deterministic=False):
        """Choose an action for ``obs`` and return it with its log-probability.

        Args:
            obs: Input to ``self.forward``, which yields the action mean.
            deterministic: Return the mean instead of a reparameterised
                sample.

        Returns:
            ``(action, action_logprobs)``; the log-probability is summed over
            the last (action) dimension.
        """
        action_mean = self.forward(obs)
        dist = Independent(Normal(action_mean, torch.exp(self.log_scale)), 1)
        action = action_mean if deterministic else dist.rsample()
        # Score the action with its shape intact.  Squeezing it first drops a
        # size-1 action dimension, after which log_prob broadcasts the batch
        # against itself and sums unrelated entries.
        action_logprobs = dist.log_prob(action)

        return action, action_logprobs
Beispiel #17
0
def calc_loglikelihood(inputs: torch.Tensor, outputs: torch.Tensor,
                       sigma_prior: float):
    """Return the per-sample Gaussian log-likelihood of outputs around inputs.

    Both tensors are flattened from dim 1 onward.  The flattened ``outputs``
    are scored under independent Normals centred on the flattened ``inputs``
    with constant scale ``sigma_prior``, and the result is summed over the
    flattened dimension.

    Returns:
        Tensor with one log-likelihood per element of the first dimension.
    """
    predicted = outputs.flatten(1)
    true_mu = inputs.flatten(1)
    # Build the constant float32 scale without wrapping an existing tensor in
    # torch.tensor(), which copy-constructs and emits a UserWarning per call.
    scale = (torch.ones_like(true_mu) * sigma_prior).detach().to(torch.float32)
    diag_normal = Independent(Normal(true_mu, scale), 1)
    return diag_normal.log_prob(predicted)
Beispiel #18
0
    def compute_loss(self, x, y, scale_kl=False):
        """Return the negative single-sample ELBO and its two averaged parts.

        Reads ``self.pz``, ``self.qz_x`` and ``self.z``, which must already
        be set on the instance.

        Args:
            x: Target batch; scored under a ContinuousBernoulli whose last
                three dimensions form one event.
            y: Decoder logits.
            scale_kl: Falsy for no scaling, otherwise a multiplier applied
                to the prior/posterior term.

        Returns:
            ``(loss, kl_mean, px_mean)``; the last two are Python floats.
            ``kl`` here is ``log p(z) - log q(z|x)``, i.e. a sample estimate
            of the *negative* KL divergence.
        """
        likelihood = Independent(ContinuousBernoulli(logits=y),
                                 reinterpreted_batch_ndims=3)
        px = likelihood.log_prob(x)

        latent = self.z
        kl = self.pz.log_prob(latent) - self.qz_x.log_prob(latent)
        if scale_kl:
            kl = kl * scale_kl

        elbo = px + kl
        return -elbo.mean(), kl.mean().item(), px.mean().item()
Beispiel #19
0
    def loss(self, x):
        """
        returns
        1. the average value of the negative ELBO across the minibatch x
        2. the output of the decoder

        The encoder output is split along dim 1 at ``self.z_dim``: the first
        part is the posterior mean, the rest is exponentiated to give the
        posterior scale.
        """
        n = x.size(0)
        k = self.z_dim
        enc = self.encoder(x)

        prior = Independent(Normal(loc=torch.zeros(n, k).to(self.device),
                                   scale=torch.ones(n, k).to(self.device)),
                            reinterpreted_batch_ndims=1)
        posterior = Independent(Normal(loc=enc[:, :k],
                                       scale=torch.exp(enc[:, k:])),
                                reinterpreted_batch_ndims=1)

        z = posterior.rsample()
        decoder_output = self.decoder(z)
        likelihood = Independent(Bernoulli(logits=decoder_output),
                                 reinterpreted_batch_ndims=1)

        # Single-sample ELBO: log p(x|z) + log p(z) - log q(z|x).
        elbo = likelihood.log_prob(x) + prior.log_prob(z) - posterior.log_prob(z)
        return -elbo.mean(), decoder_output
Beispiel #20
0
    def get_action(self, obs):
        """Sample an action for ``obs`` and return it with log-prob and value.

        Everything is computed without gradient tracking and returned as
        NumPy arrays.

        Returns:
            ``(action, log_prob, val)``.
        """
        obs = torch.tensor(obs, dtype=torch.float).to(self.device)
        with torch.no_grad():
            mu, sigma = self.pi(obs)
            policy = Independent(Normal(mu, sigma), 1)
            action = policy.sample()
            log_prob = policy.log_prob(action)
            val = self.V(obs)

        # Move results to host memory before converting to NumPy.
        return (action.cpu().numpy(),
                log_prob.cpu().numpy(),
                val.cpu().numpy())
Beispiel #21
0
 def mdn_loss_fn(self, mu, sigma, y, epsilon=1e-9):
     """Return the mean negative log-likelihood of ``y`` under the mixture.

     Vectorised over mixture components: ``sigma`` and ``y`` are reshaped to
     ``(-1, n_outputs, 1)`` / ``(-1, mu.shape[1], 1)`` and spread over the
     last dimension of ``mu``.  Per-component log-densities are summed over
     dim 1, the log mixture weights ``self.pi.log()`` are added, and the
     components are combined with ``logsumexp``.  ``epsilon`` is unused.
     """
     n_components = mu.shape[2]
     scale = sigma.reshape(-1, self.n_outputs, 1).expand(-1, -1, n_components)
     target = y.reshape([-1, mu.shape[1], 1]).expand(-1, -1, self.n_gaussians)

     # Zero reinterpreted dims: log_prob stays element-wise.
     components = Independent(Normal(loc=mu, scale=scale), 0)
     per_component = components.log_prob(target).sum(dim=1) + self.pi.log()
     mixture_log_lik = torch.logsumexp(per_component, dim=1)
     return -torch.mean(mixture_log_lik)
Beispiel #22
0
    def forward(self, obs, deterministic=False):
        """Return a tanh-squashed action and its log-probability.

        The log standard deviation predicted from ``obs`` is clamped to
        ``[LOG_STD_MIN, LOG_STD_MAX]`` before exponentiation.  With
        ``deterministic`` set, the Gaussian mean is squashed instead of a
        reparameterised sample.
        """
        mean = self.mu_layer(obs)
        clamped_log_std = torch.clamp(self.log_std_layer(obs),
                                      LOG_STD_MIN, LOG_STD_MAX)

        # Distribution before squashing.
        pre_squash = Independent(Normal(mean, clamped_log_std.exp()), 1)
        if deterministic:
            raw_action = mean
        else:
            raw_action = pre_squash.rsample()

        squashed_action = torch.tanh(raw_action)
        # log|d tanh(u)/du| summed over the action dimension; the epsilon
        # keeps the logarithm finite when tanh saturates.
        log_det = torch.log((1 - squashed_action.pow(2)) +
                            self.__eps).sum(axis=-1)
        return squashed_action, pre_squash.log_prob(raw_action) - log_det
Beispiel #23
0
def loss_vae_normal(x, encoder, decoder):
    """Return the negative ELBO with a Gaussian likelihood, plus the decoding.

    The encoder output is split in half along dim 1 into posterior location
    and a second half that is *squared* to obtain the posterior scale.  The
    likelihood scale is the root-mean-square residual over the whole batch.

    NOTE(review): the prior location is ``sigmoid(0) = 0.5`` rather than 0,
    and the squared second half can be exactly zero — confirm both are
    intended.
    """
    n = x.size(0)
    enc = encoder(x)
    d = enc.shape[1] // 2

    prior_loc = F.sigmoid(torch.zeros(n, d).to(device))
    prior_scale = torch.ones(n, d).to(device)
    pz = Independent(Normal(loc=prior_loc, scale=prior_scale),
                     reinterpreted_batch_ndims=1)

    post_loc, post_raw_scale = enc[:, :d], enc[:, d:]
    qz_x = Independent(Normal(loc=post_loc, scale=post_raw_scale**2),
                       reinterpreted_batch_ndims=1)

    z = qz_x.rsample()
    decoder_output = decoder(z)

    # One shared sigma for the whole batch: RMS of the reconstruction error.
    residual_sq = (x - decoder_output)**2
    optimal_sigma_observed = residual_sq.mean([0, 1, 2, 3], keepdim=True).sqrt()
    px_z = Independent(Normal(loc=decoder_output,
                              scale=optimal_sigma_observed),
                       reinterpreted_batch_ndims=3)

    elbo = (px_z.log_prob(x) - kl_divergence(qz_x, pz)).mean()
    return -elbo, decoder_output
Beispiel #24
0
 def test_independent_shape(self):
     """Check that ``Independent`` only changes log-prob/entropy shapes.

     For every example distribution and every valid number of reinterpreted
     batch dims, samples, mean, variance and enumerated support keep the
     base distribution's shape, while ``log_prob`` and ``entropy`` lose the
     reinterpreted trailing dims.  Unimplemented properties are skipped.
     """
     for Dist, params in EXAMPLES:
         for param in params:
             base = Dist(**param)
             draw = base.sample()
             base_lp_shape = base.log_prob(draw).shape
             for ndims in range(len(base.batch_shape) + 1):
                 indep = Independent(base, ndims)
                 # Trailing `ndims` dims of the base log-prob are summed out.
                 reduced_shape = base_lp_shape[:len(base_lp_shape) - ndims]
                 self.assertEqual(indep.log_prob(draw).shape, reduced_shape)
                 self.assertEqual(indep.sample().shape,
                                  base.sample().shape)
                 self.assertEqual(indep.has_rsample, base.has_rsample)
                 if indep.has_rsample:
                     self.assertEqual(indep.sample().shape,
                                      base.sample().shape)
                 try:
                     # Both checks share one guard: if enumerate_support is
                     # not implemented the mean check is skipped as well.
                     self.assertEqual(
                         indep.enumerate_support().shape,
                         base.enumerate_support().shape,
                     )
                     self.assertEqual(indep.mean.shape, base.mean.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep.variance.shape,
                                      base.variance.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep.entropy().shape, reduced_shape)
                 except NotImplementedError:
                     pass
def reinforce_loss(policy,
                   episodes,
                   init_std=1.0,
                   min_std=1e-6,
                   output_size=2
                   ):
    """Return the advantage-weighted negative log-likelihood of the actions.

    The policy output is used as the mean of a diagonal Gaussian whose log
    standard deviation is initialised to ``log(init_std)`` and clamped from
    below at ``log(min_std)``.

    NOTE(review): the log-std parameter is created anew on every call, so it
    always equals ``log(init_std)`` and is never optimised — confirm this is
    intended.
    """
    flat_obs = episodes.observations.view((-1, *episodes.observation_shape))
    output = policy(flat_obs)

    log_std_floor = math.log(min_std)
    sigma = nn.Parameter(torch.Tensor(output_size))
    sigma.data.fill_(math.log(init_std))

    scale = torch.clamp(sigma, min=log_std_floor).exp()
    pi = Independent(Normal(loc=output, scale=scale), 1)

    flat_actions = episodes.actions.view((-1, *episodes.action_shape))
    log_probs = pi.log_prob(flat_actions).view(len(episodes),
                                               episodes.batch_size)

    weighted = weighted_mean(log_probs * episodes.advantages,
                             lengths=episodes.lengths)
    losses = -weighted

    return losses.mean()
Beispiel #26
0
class IndependentRescaledBeta(Distribution):
    """``RescaledBeta`` whose non-leading dimensions form one event.

    Thin wrapper around
    ``Independent(RescaledBeta(c1, c0), c1.dim() - 1)``: every dimension of
    ``concentration1`` except the first is reinterpreted as part of the
    event, so ``log_prob`` and ``entropy`` are reduced over them.

    Args:
        concentration1: First concentration tensor (positive).
        concentration0: Second concentration tensor (positive).
        validate_args: Forwarded to the wrapped distributions and to
            ``Distribution.__init__``.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        # Distribution.__init__ validates every key of ``arg_constraints`` by
        # reading the attribute of the same name, so the parameters must be
        # present on the instance before it runs.
        self.concentration1 = concentration1
        self.concentration0 = concentration0
        self.base_dist = Independent(RescaledBeta(concentration1,
                                                  concentration0,
                                                  validate_args),
                                     len(concentration1.shape) - 1,
                                     validate_args=validate_args)
        super().__init__(self.base_dist.batch_shape,
                         self.base_dist.event_shape,
                         validate_args=validate_args)

    def log_prob(self, value):
        """Return the log-density of ``value`` reduced over the event dims."""
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample without gradient flow to the parameters."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterised (differentiable) sample."""
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        """Return the entropy reduced over the event dims."""
        return self.base_dist.entropy()
def loss_function(x_hat, x, q_z, z, epoch):
    """Return ``(total, BCE, KLD)`` for one batch.

    The reconstruction term is selected by ``args.loss`` ('mixture' or
    'CE').  The KL term is a sample estimate ``mean(log q(z) - log p(z))``
    against a prior built from ``distri`` with zero/one parameters.  Its
    weight ``beta`` ramps linearly to ``beta_final`` over the first third of
    training when ``beta_final >= 1`` and is 1 otherwise.

    NOTE(review): ``BCE`` stays undefined when ``args.loss`` is neither
    'mixture' nor 'CE'.
    """
    if args.loss == 'mixture':
        BCE = torch.mean(-log_mix_dep_Logistic_256(x, x_hat, average=True, n_comps=10))
    elif args.loss == 'CE':
        # Move the 256-way class axis last and flatten to (N, 256).
        logits = x_hat.view(-1, 3, 256, 64, 64).permute(0, 1, 3, 4, 2)
        x_hat = logits.contiguous().view(-1, 256)
        target = Variable(x.data.view(-1) * 255).long()
        BCE = loss(x_hat, target)

    z_sqrt = int(np.sqrt(z_dim))
    n = z.size(0)
    if arch in ('resnet', 'convlin'):
        prior_shape = (n, z_dim)
    else:
        prior_shape = (n, 1, z_sqrt, z_sqrt)
    p_x_dist = Independent(
        distri(torch.zeros(*prior_shape).to(device),
               torch.ones(*prior_shape).to(device)), 1)

    one_third = round(args.epochs/3)
    if beta_final >= 1:
        # Linear warm-up of the KL weight during the first third of training.
        beta = (beta_final*epoch)/one_third if epoch <= one_third else beta_final
    else:
        beta = 1

    KLD = torch.mean(q_z.log_prob(z) - p_x_dist.log_prob(z))

    print(BCE, KLD, beta)

    return (BCE + beta*KLD), BCE, KLD
Beispiel #28
0
    def __call__(self, x, out_keys=None, info=None, **kwargs):
        """Run the policy on ``x`` and return the requested outputs.

        Args:
            x: Input to the feature network.
            out_keys: Names of the outputs to compute.  ``'action'`` is
                always returned; ``'state_value'``, ``'action_logprob'``,
                ``'entropy'`` and ``'perplexity'`` are optional.  Defaults to
                ``['action']``.
            info: Optional dict; its ``'mask'`` entry is passed to a
                recurrent network.  Defaults to an empty dict.

        Returns:
            Dict mapping each produced key to its tensor.  ``'action'`` has
            been passed through ``self.constraint_action``.

        Raises:
            ValueError: If ``self.std_style`` is not ``'exp'`` or
                ``'softplus'``, or if the sampled action contains NaN.
        """
        # None stands in for the former mutable defaults ['action'] and {}.
        if out_keys is None:
            out_keys = ['action']
        if info is None:
            info = {}

        out_policy = {}

        # Forward pass of the feature network.
        if self.recurrent:
            out_network = self.network(x=x,
                                       hidden_states=self.rnn_states,
                                       mask=info.get('mask', None))
            features = out_network['output']
            # Keep track of the current RNN hidden states.
            self.rnn_states = out_network['hidden_states']
        else:
            features = self.network(x)

        # Mean of the Gaussian action distribution.
        mean = self.network.mean_head(features)
        # logvar is either predicted by a linear layer or a stored
        # Tensor/nn.Parameter that is expanded to the shape of the mean.
        if isinstance(self.network.logvar_head, nn.Linear):
            logvar = self.network.logvar_head(features)
        else:
            logvar = self.network.logvar_head.expand_as(mean)

        if 'state_value' in out_keys:
            # Squeeze the final singleton dimension.
            out_policy['state_value'] = self.network.value_head(
                features).squeeze(-1)

        # Convert logvar to a standard deviation.
        if self.std_style == 'exp':
            std = torch.exp(0.5 * logvar)
        elif self.std_style == 'softplus':
            std = F.softplus(logvar)
        else:
            raise ValueError(f'unsupported std_style: {self.std_style!r}')

        # Lower-bound the standard deviation.
        std = torch.max(std, torch.full_like(std, self.min_std))

        # Independent Gaussians, i.e. a diagonal Gaussian.
        action_dist = Independent(Normal(loc=mean, scale=std), 1)

        # Sample without gradient.
        # Do not use `rsample()`, it leads to zero gradient of mean head !
        action = action_dist.sample()
        out_policy['action'] = action

        if 'action_logprob' in out_keys:
            out_policy['action_logprob'] = action_dist.log_prob(action)

        if 'entropy' in out_keys:
            out_policy['entropy'] = action_dist.entropy()

        # Perplexity is exp(entropy).
        if 'perplexity' in out_keys:
            out_policy['perplexity'] = action_dist.perplexity()

        # Fail fast on NaN instead of printing the diagnostic forever.
        if torch.any(torch.isnan(action)):
            msg = 'NaN ! A workaround is to learn state-independent std or use tanh rather than relu'
            msg2 = f'check: \n\t mean: {mean}, logvar: {logvar}'
            raise ValueError(msg + msg2)

        # Constrain the action to the valid range.
        out_policy['action'] = self.constraint_action(action)

        return out_policy
Beispiel #29
0
def normal_log_density(means, stds, actions):
    """Return the diagonal-Gaussian log-density of ``actions``.

    The per-element Normal log-probabilities are summed over the last
    dimension.
    """
    return Independent(Normal(means, stds), 1).log_prob(actions)
Beispiel #30
0
        b, m, train_m, test_m = make_batch_mask(batch, param)

        input_tuple = (b, m, train_m, test_m)
        #tec = time.time()
        #print('Batch got in %.2f sec' % (tec - tic))
        optimizer.zero_grad()
        output = model.forward(input_tuple)
        masked_output = output[test_m.bool()].reshape(
            param['batch_size'], (param['total_points'] - param['obs_points']))
        target = b[:, :, 1][test_m.bool()].reshape(
            param['batch_size'], (param['total_points'] - param['obs_points']))

        log_likelihood = torch.tensor(0.0)
        for i in range(masked_output.shape[0]):
            gaussian = Independent(Normal(masked_output[i], param['sigma']), 1)
            ll = gaussian.log_prob(target[i]) / masked_output.shape[1]
            log_likelihood += ll
        log_likelihood /= masked_output.shape[0]
        loss = -log_likelihood
        mse_loss = mse(masked_output, target)

        #tac = time.time()
        #print('Forward finished in %.2f sec' % (tac - tec))
        loss.backward()
        #tuc = time.time()
        #print('Backward fininshed in %.2f sec' % (tuc - tac))
        optimizer.step()
        toc = time.time()

        for k in range(param['figure_per_batch']):
            plt.clf()