Example #1
 def latent(self, conditioning):
     """
     z_enc: [batch, frames, latent_dims]
     attributes: [batch, frames, attribute_dims]
     """
     z_enc = conditioning['z']
     attributes = conditioning['attributes']
     batch_size, n_frames, _encdims = z_enc.shape
     if len(attributes.shape) < 3:
         # expand along frame dimension
         attributes = attributes.unsqueeze(1).expand(-1, n_frames, -1)
     mu_q, logscale_q = self.psi_q(z_enc, z_enc)
     # mix with latent
     mu_a, logscale_a = self.attribute_latent(mu_q, logscale_q, attributes)
     # feed into temporal model
     h_0 = torch.rand(1, batch_size, self.latent_dims).to(
         z_enc.device) * 0.01
     temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)
     # posterior distribution
     mu_z, scale_z = self.mix_with_temp(mu_a, logscale_a, temp_q)
     posterior = Independent(Normal(mu_z, scale_z), 1)
     posterior_sample = posterior.rsample()
     # prior
     output = torch.cat([temp_q, attributes], dim=-1)
     mu, scale = self.psi_p(output, output)
     scale = scale.exp()
     prior = Independent(Normal(mu, scale), 1)
     # sum over latent dim mean over batch
     kl_div = torch.mean(kl_divergence(posterior, prior))
     return posterior_sample, kl_div, [mu_z, scale_z]
Example #2
    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            x = logits[0]
        else:
            x = dist.rsample()
        y = torch.tanh(x)
        act = y * self._action_scale + self._action_bias
        # __eps is used to avoid log of zero/negative number.
        y = self._action_scale * (1 - y.pow(2)) + self.__eps
        # Compute the log-prob under the Gaussian, then apply the correction for tanh squashing.
        # See Eq. 21 in Appendix C of the original SAC paper (arXiv 1801.01290)
        # for the derivation of this equation.
        log_prob = dist.log_prob(x).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        return Batch(logits=logits,
                     act=act,
                     state=h,
                     dist=dist,
                     log_prob=log_prob)
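For context, the two lines above implement the change-of-variables correction for the tanh squashing: log p(a) = log p(x) - sum_i log(1 - tanh(x_i)^2). A minimal self-contained sketch of the same correction, on toy shapes rather than the policy code above:

import torch
from torch.distributions import Independent, Normal

# toy diagonal Gaussian over a 3-dimensional action
dist = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
x = dist.rsample()                 # pre-squash sample
a = torch.tanh(x)                  # squashed action in (-1, 1)
eps = 1e-6                         # avoids log(0) at the tanh boundaries
# change of variables: log p(a) = log p(x) - sum log(1 - a^2)
log_prob = dist.log_prob(x) - torch.log(1.0 - a.pow(2) + eps).sum(-1)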
Example #3
    def trajectory(self, current_state):
        '''
        This implementation may not utilize GPUs very well, but I am not sure.

        Each element of the output history looks like:
        [(s_0, a_0, r_0, log_prob_0), ..., (s_L, a_L, r_L, log_prob_L)]
        '''
        output_history = []
        while True:
            mu, sigma = self.forward(current_state)
            distribution = Independent(Normal(mu, sigma), 1)
            picked_action = distribution.rsample()
            action = picked_action.detach()
            #print(action)
            new_state, reward = self.env.state_and_reward(
                current_state, action
            )  #Get the reward and the new state that the action in the environment resulted in. None if action caused death. TODO build in environment
            #Attempting this
            output_history.append(
                (current_state, action, reward, distribution.log_prob(action)))
            if new_state is None:  #essentially, you died or finished your trajectory
                break
            else:
                current_state = new_state
        return output_history
Example #4
    def sample_trajectories(self, init_std=1.0, min_std=1e-6, output_size=2):
        # Sample batch_size complete trajectories under the current policy
        observations = self.envs.reset()
        with torch.no_grad():
            while not self.envs.dones.all():
                observations_tensor = torch.from_numpy(observations)
                """
                ******************************************************************
                """
                output = self.policy(observations_tensor)

                min_log_std = math.log(min_std)
                sigma = nn.Parameter(torch.Tensor(output_size))
                sigma.data.fill_(math.log(init_std))

                scale = torch.exp(torch.clamp(sigma, min=min_log_std))
                # loc is the mean of the Gaussian
                # scale is the standard deviation of the Gaussian
                p_normal = Independent(Normal(loc=output, scale=scale), 1)

                actions_tensor = p_normal.sample()
                actions = actions_tensor.cpu().numpy()

                # pi = policy(observations_tensor)
                # actions_tensor = pi.sample()
                # actions = actions_tensor.cpu().numpy()

                new_observations, rewards, _, infos = self.envs.step(actions)
                batch_ids = infos['batch_ids']
                yield (observations, actions, rewards, batch_ids)
                observations = new_observations
Example #5
 def test_independent_expand(self):
     for Dist, params in EXAMPLES:
         for param in params:
             base_dist = Dist(**param)
             for reinterpreted_batch_ndims in range(
                     len(base_dist.batch_shape) + 1):
                 for s in [
                         torch.Size(),
                         torch.Size((2, )),
                         torch.Size((2, 3))
                 ]:
                     indep_dist = Independent(base_dist,
                                              reinterpreted_batch_ndims)
                     expanded_shape = s + indep_dist.batch_shape
                     expanded = indep_dist.expand(expanded_shape)
                     expanded_sample = expanded.sample()
                     expected_shape = expanded_shape + indep_dist.event_shape
                     self.assertEqual(expanded_sample.shape, expected_shape)
                     self.assertEqual(
                         expanded.log_prob(expanded_sample),
                         indep_dist.log_prob(expanded_sample),
                     )
                     self.assertEqual(expanded.event_shape,
                                      indep_dist.event_shape)
                     self.assertEqual(expanded.batch_shape, expanded_shape)
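The behaviour this test exercises can be summarized with a small standalone sketch (shapes are illustrative): Independent moves the rightmost batch dimensions into the event shape, and expand only grows the remaining batch dimensions.

import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(2, 3), torch.ones(2, 3))   # batch_shape (2, 3)
indep = Independent(base, 1)                          # batch_shape (2,), event_shape (3,)
expanded = indep.expand(torch.Size((5, 2)))           # prepend a batch dimension of 5
sample = expanded.sample()                            # shape (5, 2, 3)
print(expanded.log_prob(sample).shape)                # torch.Size([5, 2])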
Example #6
    def KLD(self, a, b, prior_alpha, prior_beta):
        eps = 5 * torch.finfo(torch.float).eps 
        a = a.clamp(eps)
        b = b.clamp(eps)
        if self.dist == "km":
            ab = (a * b) + eps
            kl = 1 / (1 + ab) * self.Beta(1 / a, b)
            kl += 1 / (2 + ab) * self.Beta(2 / a, b)
            kl += 1 / (3 + ab) * self.Beta(3 / a, b)
            kl += 1 / (4 + ab) * self.Beta(4 / a, b)
            kl += 1 / (5 + ab) * self.Beta(5 / a, b)
            kl += 1 / (6 + ab) * self.Beta(6 / a, b)
            kl += 1 / (7 + ab) * self.Beta(7 / a, b)
            kl += 1 / (8 + ab) * self.Beta(8 / a, b)
            kl += 1 / (9 + ab) * self.Beta(9 / a, b)
            kl += 1 / (10 + ab) * self.Beta(10 / a, b)
            kl *= (prior_beta - 1) * b

            kl += (a - prior_alpha) / a * (-np.euler_gamma - torch.digamma(
                b) - 1 / b)  # T.psi(self.posterior_b)

            # add normalization constants
            kl += torch.log(ab) + torch.log(self.Beta(prior_alpha, prior_beta))

            # final term
            kl += -(b - 1) / b
        elif self.dist == "gamma":
            kl = torch.distributions.kl.kl_divergence(Gamma(a, b), Gamma(prior_alpha, prior_beta))
        elif self.dist == "gl":
            prior_alpha_beta = prior_alpha/(prior_alpha + prior_beta)
            prior_beta_beta =  torch.sqrt(prior_alpha*prior_beta / ((prior_alpha + prior_beta)**2*(prior_alpha + prior_beta + 1)))
            kl = torch.distributions.kl.kl_divergence(Independent(Normal(a, b),1), Independent(Normal(prior_alpha_beta, prior_beta_beta),1)).unsqueeze(1)
        return kl
Example #7
 def temporal_model_step(self, mu_q_t, logscale_q_t, h, attribute=None):
     """
     generate z_t autoregressively
     """
     # mix with temporal info
     h_mu, h_scale = self.h_process(h, h)
     mu = torch.cat([mu_q_t, h_mu], dim=-1)
     logscale = torch.cat([logscale_q_t, h_scale], dim=-1)
     mu_z_t, logscale_z_t = self.psi_dy(mu, logscale)
     scale_z_t = logscale_z_t.exp()
     # final posterior distribution with rnn information
     posterior_t = Independent(Normal(mu_z_t, scale_z_t), 1)
     posterior_sample_t = posterior_t.rsample()
     # prior distribution with rnn information
     if attribute is not None:
         mixed_h_mu = torch.cat([h_mu, attribute], dim=-1)
         mixed_h_scale = torch.cat([h_scale, attribute], dim=-1)
         mu_p_t, logscale_p_t = self.psi_p(mixed_h_mu, mixed_h_scale)
     else:
         mu_p_t, logscale_p_t = self.psi_p(h_mu, h_scale)
     # scale_p_t = logscale_p_t.exp()
     scale_p_t = torch.ones_like(logscale_p_t)
     prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
     kl_div_t = torch.mean(kl_divergence(posterior_t, prior_t))
     return posterior_sample_t, kl_div_t, mu_z_t, scale_z_t
Example #8
    def forward(self, state, eval=False, with_log_prob=False):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        mu = self.fc3(x)
        log_sigma = self.fc4(x)
        # clip value of log_sigma, as was done in Haarnoja's implementation of SAC:
        # https://github.com/haarnoja/sac.git
        log_sigma = torch.clamp(log_sigma, -20.0, 2.0)

        sigma = torch.exp(log_sigma)
        distribution = Independent(Normal(mu, sigma), 1)

        if not eval:
            # use rsample() instead of sample(), as sample() does not allow back-propagation through params
            u = distribution.rsample()
            if with_log_prob:
                log_prob = distribution.log_prob(u)
                log_prob -= 2.0 * torch.sum(
                    (np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u -
                     F.softplus(-2.0 * u)),
                    dim=1)

            else:
                log_prob = None
        else:
            u = mu
            log_prob = None
        # apply tanh so that the resulting action lies in (-1, 1)^D
        a = self.ctrl_range * torch.tanh(u)

        return a, log_prob
Example #9
 def get_loss(self, output, target, ignore_label=None):
     loss = super().get_loss(output, target, ignore_label)
     eps = 1e-12
     # construct a diagonal Gaussian N(0, self.sigma) with the same shape as the logits
     normal = Independent(
         Normal(loc=torch.FloatTensor(
             output['logits'].size()).fill_(0).to(device),
                scale=torch.FloatTensor(output['logits'].size()).fill_(
                    self.sigma).to(device)), 1)
     # sum ( softmax (distorted softmax probs)) using predicted voxel variances (scale)
     # we then take the log of these
     sum_distorted_softmax = torch.sum(torch.stack([
         self.softmax(output['logits'] +
                      (output['sigma'] * normal.sample()))
         for _ in torch.arange(self.samples)
     ]),
                                       dim=0)
     # sum_distorted_softmax should have shape [batch, nclasses, x, y]
     one_hot = torch.zeros(output['logits'].shape).scatter_(
         1, target.unsqueeze(1), 1)
     # mask sum_distorted_softmax in order to obtain only the softmax probs for the gt class and take max
     # of the result, which will just select the prob of the gt class (reduce dim 1=nclasses)
     sum_distorted_softmax, _ = torch.max(sum_distorted_softmax * one_hot,
                                          1)
     # sum_distorted_softmax should now have shape [batch, x, y]
     # finally compute the categorical aleatoric loss
     aleatoric_loss = -0.0001 * torch.mean(torch.sum(
         torch.log(sum_distorted_softmax + eps) - np.log(self.samples),
         dim=(1, 2)),
                                           dim=0)
     output['sigma'] = output['sigma'].cpu().detach().numpy()
     output['logits'] = None
     self.current_aleatoric_loss = aleatoric_loss.detach().cpu().numpy()
     return loss + aleatoric_loss
Example #10
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         act = logits[0]
     else:
         act = dist.rsample()
     log_prob = dist.log_prob(act).unsqueeze(-1)
     # apply the correction for tanh squashing when computing the log-prob from the Gaussian.
     # See Eq. 21 in Appendix C of the original SAC paper (arXiv 1801.01290)
     # for the derivation of this equation.
     if self.action_scaling and self.action_space is not None:
         action_scale = to_torch_as(
             (self.action_space.high - self.action_space.low) / 2.0, act)
     else:
         action_scale = 1.0  # type: ignore
     squashed_action = torch.tanh(act)
     log_prob = log_prob - torch.log(action_scale *
                                     (1 - squashed_action.pow(2)) +
                                     self.__eps).sum(-1, keepdim=True)
     return Batch(logits=logits,
                  act=squashed_action,
                  state=h,
                  dist=dist,
                  log_prob=log_prob)
Example #11
 def choose_action(self, observation):
     mu, sigma  = self.actor.forward(observation)#.to(self.actor.device)
     sigma = T.exp(sigma)
     action_probs = Independent(Normal(mu, sigma),1)
     probs = action_probs.sample()
     self.log_probs = action_probs.log_prob(probs).to(self.actor.device)
     return probs
Example #12
 def generate(self,
              synth,
              h_0,
              f0_hz,
              enc_frame_setting='fine',
              n_samples=16000):
     """
     synth:          synth to generate audio
     h_0:            initial state of RNN [batch, latent_dims]
     f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
     enc_frame_setting: fft/hop size
     n_samples:      output audio length in samples
     """
     h = h_0
     n_fft, hop_length = get_window_hop(enc_frame_setting)
     n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
     f0_hz = resample_frames(f0_hz,
                             n_frames)  # needs to have same dimension as z
     params_list = []
     z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)
     for t in range(n_frames):
         # prior distribution with rnn information
         mu_p_t, scale_p_t = self.get_prior(h)
         prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
         prior_sample_t = prior_t.rsample()
         h = self.temporal(prior_sample_t, h)
         z[:, t, :] = prior_sample_t
     cond = {}
     cond['z'] = z
     cond['f0_hz'] = f0_hz
     y_params = self.decode(cond)
     params = synth.fill_params(y_params, cond)
     resyn_audio, outputs = synth(params, n_samples)
     return params, resyn_audio
Example #13
class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.base_dist = Independent(Normal(loc=loc,
                                            scale=scale,
                                            validate_args=validate_args),
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                                self.base_dist.event_shape,
                                                validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
Example #14
	def forward(self, input, kl_coef):
		b, m, train_m, test_m = input
		
		mean, std = self.ode_rnn(input)	# (batch_size, LO_hidden_size) * 2
		
		d = Normal(torch.tensor([0.0], device = self.param['device']), torch.tensor([1.0], device = self.param['device']))
		r = d.sample(mean.shape).squeeze(-1)
		z0 = mean + r * std
		
		z_out = odeint(self.ode_func, z0, b[0, :, 0], rtol = self.param['rtol'], atol = self.param['atol']) # (num_time_points, batch_size, LO_hidden_size)	
		z_out = z_out.permute(1, 0, 2)
		output = self.output_output(z_out).squeeze(2)

		z0_distr = Normal(mean, std)
		kl_div = kl_divergence(z0_distr, Normal(torch.tensor([0.0], device = self.param['device']), torch.tensor([1.0], device = self.param['device'])))
		kl_div = kl_div.mean(axis = 1)
		
		masked_output = output[test_m.bool()].reshape(self.param['batch_size'], (self.param['total_points'] - self.param['obs_points']))
		target = b[:, :, 1][test_m.bool()].reshape(self.param['batch_size'], (self.param['total_points'] - self.param['obs_points']))
		
		gaussian = Independent(Normal(loc = masked_output, scale = self.param['obsrv_std']), 1)
		log_prob = gaussian.log_prob(target)
		likelihood = log_prob / output.shape[1]
		
		loss = -torch.logsumexp(likelihood - kl_coef * kl_div, 0)
		mse = self.mse_func(masked_output, target)
		
		return loss, mse, masked_output
		
Example #15
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         x = logits[0]
     else:
         x = dist.rsample()
     y = torch.tanh(x)
     act = y * self._action_scale + self._action_bias
     y = self._action_scale * (1 - y.pow(2)) + self.__eps
     log_prob = dist.log_prob(x).unsqueeze(-1)
     log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
     if self._noise is not None and self.training and not self.updating:
         act += to_torch_as(self._noise(act.shape), act)
     act = act.clamp(self._range[0], self._range[1])
     return Batch(
         logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
Example #16
class TanhNormal(Distribution):
    """Copied from Kaixhi"""
    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self):
        return torch.tanh(self.normal.sample())

    # samples with re-parametrization trick (differentiable)
    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) +
                                                             1e-6).sum(dim=1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        return self.normal.stddev
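A short usage sketch for the class above, with illustrative shapes and assuming the class and its torch.distributions imports are in scope:

import torch

policy = TanhNormal(loc=torch.zeros(4, 2), scale=torch.ones(4, 2))
action = policy.rsample()            # differentiable sample in (-1, 1), shape (4, 2)
log_prob = policy.log_prob(action)   # shape (4,), summed over the action dimension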
Example #17
    def forward(self, x):
        kl = torch.zeros(1).to(device)
        z = 0.
        # Unet encoder result
        x_enc = self.unet_encoder(x)

        # VAE regularisation
        if not self.unet:
            emb = self.vae_encoder(x)

            # Split encoder outputs into a mean and variance vector
            mu, log_var = torch.chunk(emb, 2, dim=1)

            # Make sure that the log variance is positive
            log_var = softplus(log_var)
            sigma = torch.exp(log_var / 2)

            # Instantiate a diagonal Gaussian with mean=mu, std=sigma
            # This is the approximate latent distribution q(z|x)
            posterior = Independent(Normal(loc=mu, scale=sigma), 1)
            z = posterior.rsample()

            # Instantiate a standard Gaussian with mean=mu_0, std=sigma_0
            # This is the prior distribution p(z)
            prior = Independent(Normal(loc=self.mu_0, scale=self.sigma_0), 1)

            # Estimate the KLD between q(z|x)|| p(z)
            kl = KLD(posterior, prior).sum()

        # Outputs for MSE
        xHat = self.decoder(x_enc, z)

        return kl, xHat
Example #18
    def generate(self,
                 synth,
                 h_0,
                 f0_hz,
                 attributes,
                 enc_frame_setting='fine',
                 n_samples=16000):
        """
        synth:          synth to generate audio
        h_0:            initial seed of RNN [batch, latent_dims]
        f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
        attributes:     attributes [batch, n_frames, attribute_size]
        enc_frame_setting: fft/hop size
        n_samples:      output audio length in samples
        """
        if len(h_0.shape) == 2:
            h = h_0[None, :, :]  # 1, batch, latent_dims
        else:
            h = h_0
        n_fft, hop_length = get_window_hop(enc_frame_setting)
        n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
        f0_hz = resample_frames(f0_hz,
                                n_frames)  # needs to have same dimension as z
        params_list = []
        for i in range(n_frames):
            cond = {}
            output = torch.cat([h.permute(1, 0, 2), attributes], dim=-1)
            mu, logscale = self.psi_p(output, output)
            scale = logscale.exp()
            prior = Independent(Normal(mu, scale), 1)
            prior_sample = prior.rsample()
            cond['z'] = prior_sample
            cond['f0_hz'] = f0_hz[:, i, :].unsqueeze(1)
            cond['f0_scaled'] = hz_to_midi(cond['f0_hz']) / 127.0
            # generate x
            y = self.decode(cond)
            params = synth.fill_params(y, cond)
            params_list.append(params)
            x_tilde, _outputs = synth(
                params, n_samples=n_fft)  # write exactly one frame
            cond['audio'] = x_tilde
            # encode
            cond = self.encoder(cond)
            z_enc = cond['z']
            # get psi_q
            mu, logscale = self.psi_q(z_enc, z_enc)
            psi = torch.cat([mu, logscale], dim=-1)
            # temporal model
            temp_q, h = self.temporal_q(psi, h)  # one off

        param_names = params_list[0].keys()
        final_params = {}
        for pn in param_names:
            #cat over frames
            final_params[pn] = torch.cat([par[pn] for par in params_list],
                                         dim=1)

        final_audio, _outputs = synth(final_params, n_samples=n_samples)
        return final_params, final_audio
Example #19
    def initialize(self, data, ndim):
        self._mean = torch.zeros((data.shape[0] + 1, ndim), requires_grad=True)
        self._logstd = torch.zeros_like(self._mean, requires_grad=True)

        # ===== Start optimization ===== #
        self._sampledist = Independent(Normal(torch.zeros_like(self._mean), torch.ones_like(self._logstd)), 2)

        return self
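Note that the sampling distribution above reinterprets two batch dimensions, so the whole (data.shape[0] + 1, ndim) grid is treated as one event. A small illustration with arbitrary shapes:

import torch
from torch.distributions import Independent, Normal

mean = torch.zeros(6, 3)
dist = Independent(Normal(torch.zeros_like(mean), torch.ones_like(mean)), 2)
print(dist.batch_shape, dist.event_shape)   # torch.Size([]) torch.Size([6, 3])
print(dist.log_prob(dist.sample()).shape)   # torch.Size([]): a single scalar per draw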
Example #20
    def get_ELBO_per_obs(self, batch, beta=1.0):
        output = self(batch)
        px, pz, qz, z = [output[k] for k in ["px", "pz", "qz", "z"]]
        kl_term = kl_divergence(Independent(qz, 1), Independent(pz, 1))

        elbo = px.log_prob(batch).sum(-1) - beta * kl_term

        return elbo
Example #21
 def __init__(self, loc, scale, validate_args=None):
     self.base_dist = Independent(Normal(loc=loc,
                                         scale=scale,
                                         validate_args=validate_args),
                                  len(loc.shape) - 1,
                                  validate_args=validate_args)
     super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                             self.base_dist.event_shape,
                                             validate_args=validate_args)
Example #22
 def forward(self, obs, act=None, deterministic=False):
     # Optionally pass in an action to get the log_prob of that action
     mu = self.mu_layer(obs)
     std = torch.exp(self.log_std_layer)
     pi = Independent(Normal(mu, std), 1)
     if act is None:
         act = pi.mean if deterministic else pi.rsample()
     log_prob = pi.log_prob(act)
     return pi, act, log_prob
Example #23
def prop_state(x, f, g):
    bins = Independent(Binomial(x[:-1], f), 1)
    samp = bins.sample()

    s = x[0] - samp[..., 0]
    i = x[1] + samp[..., 0] - samp[..., 1]
    r = x[2] + samp[..., 1]

    return concater(s, i, r)
Example #24
def _kl_independent_independent(p, q):
    shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
    p_ndims = p.reinterpreted_batch_ndims - shared_ndims
    q_ndims = q.reinterpreted_batch_ndims - shared_ndims
    p = Independent(p.base_dist, p_ndims) if p_ndims else p.base_dist
    q = Independent(q.base_dist, q_ndims) if q_ndims else q.base_dist
    kl = kl_divergence(p, q)
    if shared_ndims:
        kl = sum_rightmost(kl, shared_ndims)
    return kl
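A function like this one is what lets kl_divergence handle two Independent distributions directly. A quick check of the resulting shape semantics, with illustrative tensors:

import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

p = Independent(Normal(torch.zeros(4, 3), torch.ones(4, 3)), 1)
q = Independent(Normal(torch.ones(4, 3), torch.ones(4, 3)), 1)
kl = kl_divergence(p, q)   # KL of the base Normals, summed over the 3 event dims
print(kl.shape)            # torch.Size([4])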
Example #25
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
	n_data_points = mu_2d.size()[-1]

	if n_data_points > 0:
		gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1)
		log_prob = gaussian.log_prob(data_2d) 
		log_prob = log_prob / n_data_points 
	else:
		log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
	return log_prob
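A minimal call of the helper above with hypothetical shapes, assuming the function and torch are in scope (the unused indices argument and the get_device fallback are left untouched):

import torch

mu_2d = torch.zeros(8, 10)      # (batch, data points)
data_2d = torch.randn(8, 10)
obsrv_std = torch.tensor([0.01])
logp = gaussian_log_likelihood(mu_2d, data_2d, obsrv_std)
print(logp.shape)               # torch.Size([8]): averaged over the 10 data points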
Example #26
 def __init__(self, concentration1, concentration0, validate_args=None):
     self.base_dist = Independent(RescaledBeta(concentration1,
                                               concentration0,
                                               validate_args),
                                  len(concentration1.shape) - 1,
                                  validate_args=validate_args)
     super(IndependentRescaledBeta,
           self).__init__(self.base_dist.batch_shape,
                          self.base_dist.event_shape,
                          validate_args=validate_args)
Example #27
    def generate(self,
                 synth,
                 h_0,
                 f0_hz,
                 enc_frame_setting='fine',
                 n_samples=16000):
        """
        synth:          synth to generate audio
        h_0:            initial state of RNN [batch, latent_dims]
        f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
        enc_frame_setting: fft/hop size
        n_samples:      output audio length in samples
        """
        h = h_0
        n_fft, hop_length = get_window_hop(enc_frame_setting)
        n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
        f0_hz = resample_frames(f0_hz,
                                n_frames)  # needs to have same dimension as z
        params_list = []
        z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)
        for t in range(n_frames):
            h_mu, h_scale = self.h_process(h, h)
            mu_t, logscale_t = self.psi_p(h_mu,
                                          h_scale)  # [batch, latent_size]
            scale_t = logscale_t.exp()
            prior_t = Independent(Normal(mu_t, scale_t), 1)
            prior_sample_t = prior_t.rsample()
            cond = {}
            z[:, t, :] = prior_sample_t
            cond['z'] = prior_sample_t.unsqueeze(1)
            cond['f0_hz'] = f0_hz[:, t, :].unsqueeze(1)
            cond['f0_scaled'] = hz_to_midi(cond['f0_hz']) / 127.0
            # generate x
            y = self.decode(cond)
            params = synth.fill_params(y, cond)
            params_list.append(params)
            x_tilde, _outputs = synth(
                params, n_samples=n_fft)  # write exactly one frame
            cond['audio'] = x_tilde
            # encode
            cond = self.encoder(cond)
            z_enc = cond['z'].squeeze(1)
            # get psi_q
            mu, logscale = self.psi_q(z_enc, z_enc)
            rnn_input = torch.cat([mu, logscale, prior_sample_t], dim=-1)
            # temporal model
            h = self.temporal_q(rnn_input, h)  # one off

        cond = {}
        cond['z'] = z
        cond['f0_hz'] = f0_hz
        y_params = self.decode(cond)
        params = synth.fill_params(y_params, cond)
        resyn_audio, outputs = synth(params, n_samples)
        return params, resyn_audio
Example #28
    def get_loss(self, x, return_kl=False, beta=1.0):
        output = self(x)
        px, pz, qz, z = [output[k] for k in ["px", "pz", "qz", "z"]]
        kl_term = kl_divergence(Independent(qz, 1), Independent(pz, 1))

        loss = -px.log_prob(x) + beta * kl_term

        if not return_kl:
            return loss.mean()
        else:
            return loss.mean(), kl_term.mean()
Example #29
def prop_state(x, beta, gamma, eta, dt):
    f = _f(x, beta, gamma, eta, dt)

    bins = Independent(Binomial(x[..., :-1], f), 1)
    samp = bins.sample()

    s = x[..., 0] - samp[..., 0]
    i = x[..., 1] + samp[..., 0] - samp[..., 1]
    r = x[..., 2] + samp[..., 1]

    return concater(s, i, r)
Example #30
    def act(self, obs, deterministic=False):
        action_mean = self.forward(obs)
        normal = Normal(action_mean, torch.exp(self.log_scale))
        dist = Independent(normal, 1)
        if deterministic:
            action = action_mean
        else:
            action = dist.rsample()
        action_logprobs = dist.log_prob(torch.squeeze(action))

        return action, action_logprobs