def temporal_model_step(self, mu_q_t, logscale_q_t, h, attribute=None):
     """
     generate z_t autoregressively
     """
     # mix with temporal info
     h_mu, h_scale = self.h_process(h, h)
     mu = torch.cat([mu_q_t, h_mu], dim=-1)
     logscale = torch.cat([logscale_q_t, h_scale], dim=-1)
     mu_z_t, logscale_z_t = self.psi_dy(mu, logscale)
     scale_z_t = logscale_z_t.exp()
     # final posterior distribution with rnn information
     posterior_t = Independent(Normal(mu_z_t, scale_z_t), 1)
     posterior_sample_t = posterior_t.rsample()
     # prior distribution with rnn information
     if attribute is not None:
         mixed_h_mu = torch.cat([h_mu, attribute], dim=-1)
         mixed_h_scale = torch.cat([h_scale, attribute], dim=-1)
         mu_p_t, logscale_p_t = self.psi_p(mixed_h_mu, mixed_h_scale)
     else:
         mu_p_t, logscale_p_t = self.psi_p(h_mu, h_scale)
     # scale_p_t = logscale_p_t.exp()
     scale_p_t = torch.ones_like(logscale_p_t)
     prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
     kl_div_t = torch.mean(kl_divergence(posterior_t, prior_t))
     return posterior_sample_t, kl_div_t, mu_z_t, scale_z_t
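A minimal sketch of how a per-frame step like this could be driven over a whole sequence, accumulating the KL terms; the helper names (psi_q for frame-wise encoder statistics, temporal_q for the recurrence) are assumptions for illustration, not part of the snippet above:

import torch

def run_temporal_model(model, z_enc, h_0):
    # z_enc: [batch, n_frames, latent_dims], h_0: initial RNN state
    batch_size, n_frames, _ = z_enc.shape
    h = h_0
    samples, kl_total = [], 0.0
    for t in range(n_frames):
        # hypothetical frame-wise encoder statistics
        mu_q_t, logscale_q_t = model.psi_q(z_enc[:, t, :], z_enc[:, t, :])
        z_t, kl_t, _, _ = model.temporal_model_step(mu_q_t, logscale_q_t, h)
        # hypothetical recurrence update feeding the next step
        h = model.temporal_q(z_t, h)
        samples.append(z_t)
        kl_total = kl_total + kl_t
    return torch.stack(samples, dim=1), kl_total / n_frames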
Example #2
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         x = logits[0]
     else:
         x = dist.rsample()
     y = torch.tanh(x)
     act = y * self._action_scale + self._action_bias
     y = self._action_scale * (1 - y.pow(2)) + self.__eps
     log_prob = dist.log_prob(x).unsqueeze(-1)
     log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
     if self._noise is not None and self.training and not self.updating:
         act += to_torch_as(self._noise(act.shape), act)
     act = act.clamp(self._range[0], self._range[1])
     return Batch(
         logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
Example #3
 def latent(self, conditioning):
     """
     z_enc: [batch, frames, latent_dims]
     attributes: [batch, frames, attribute_dims]
     """
     z_enc = conditioning['z']
     attributes = conditioning['attributes']
     batch_size, n_frames, _encdims = z_enc.shape
     if len(attributes.shape) < 3:
         # expand along frame dimension
         attributes = attributes.unsqueeze(1).expand(-1, n_frames, -1)
     mu_q, logscale_q = self.psi_q(z_enc, z_enc)
     # mix with latent
     mu_a, logscale_a = self.attribute_latent(mu_q, logscale_q, attributes)
     # feed into temporal model
     h_0 = torch.rand(1, batch_size, self.latent_dims).to(
         z_enc.device) * 0.01
     temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)
     # posterior distribution
     mu_z, scale_z = self.mix_with_temp(mu_a, logscale_a, temp_q)
     posterior = Independent(Normal(mu_z, scale_z), 1)
     posterior_sample = posterior.rsample()
     # prior
     output = torch.cat([temp_q, attributes], dim=-1)
     mu, scale = self.psi_p(output, output)
     scale = scale.exp()
     prior = Independent(Normal(mu, scale), 1)
     # sum over latent dim mean over batch
     kl_div = torch.mean(kl_divergence(posterior, prior))
     return posterior_sample, kl_div, [mu_z, scale_z]
Example #4
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         act = logits[0]
     else:
         act = dist.rsample()
     log_prob = dist.log_prob(act).unsqueeze(-1)
      # Apply the correction for tanh squashing when computing the log-prob from the Gaussian.
      # See Eq. (21) in Appendix C of the original SAC paper (arXiv:1801.01290) for the derivation.
     if self.action_scaling and self.action_space is not None:
         action_scale = to_torch_as(
             (self.action_space.high - self.action_space.low) / 2.0, act)
     else:
         action_scale = 1.0  # type: ignore
     squashed_action = torch.tanh(act)
     log_prob = log_prob - torch.log(action_scale *
                                     (1 - squashed_action.pow(2)) +
                                     self.__eps).sum(-1, keepdim=True)
     return Batch(logits=logits,
                  act=squashed_action,
                  state=h,
                  dist=dist,
                  log_prob=log_prob)
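The correction applied above can be cross-checked against PyTorch's TransformedDistribution with a TanhTransform; a small, self-contained sanity check (shapes and values are arbitrary, and the eps term is omitted for the comparison):

import torch
from torch.distributions import Normal, Independent, TransformedDistribution
from torch.distributions.transforms import TanhTransform

mu, std = torch.zeros(4, 3), torch.ones(4, 3) * 0.5
dist = Independent(Normal(mu, std), 1)
x = dist.rsample()
squashed = torch.tanh(x)
# manual correction, as in the snippet above
manual = dist.log_prob(x) - torch.log(1 - squashed.pow(2)).sum(-1)
# the same quantity via a TransformedDistribution
reference = TransformedDistribution(dist, TanhTransform(cache_size=1))
print(torch.allclose(manual, reference.log_prob(squashed), atol=1e-4))  # True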
Example #5
    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            x = logits[0]
        else:
            x = dist.rsample()
        y = torch.tanh(x)
        act = y * self._action_scale + self._action_bias
        # __eps is used to avoid log of zero/negative number.
        y = self._action_scale * (1 - y.pow(2)) + self.__eps
        # Compute the log-prob from the Gaussian, then apply the correction for tanh squashing.
        # See Eq. (21) in Appendix C of the original SAC paper (arXiv:1801.01290) for the derivation.
        log_prob = dist.log_prob(x).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        return Batch(logits=logits,
                     act=act,
                     state=h,
                     dist=dist,
                     log_prob=log_prob)
Example #6
    def forward(self, state, eval=False, with_log_prob=False):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        mu = self.fc3(x)
        log_sigma = self.fc4(x)
        # clip value of log_sigma, as was done in Haarnoja's implementation of SAC:
        # https://github.com/haarnoja/sac.git
        log_sigma = torch.clamp(log_sigma, -20.0, 2.0)

        sigma = torch.exp(log_sigma)
        distribution = Independent(Normal(mu, sigma), 1)

        if not eval:
            # use rsample() instead of sample(), as sample() does not allow back-propagation through params
            u = distribution.rsample()
            if with_log_prob:
                log_prob = distribution.log_prob(u)
                log_prob -= 2.0 * torch.sum(
                    (np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u -
                     F.softplus(-2.0 * u)),
                    dim=1)

            else:
                log_prob = None
        else:
            u = mu
            log_prob = None
        # apply tanh so that the resulting action lies in (-1, 1)^D
        a = self.ctrl_range * torch.tanh(u)

        return a, log_prob
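The softplus expression above is the numerically stable form of log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u)); the extra 0.5*log(ctrl_range) term (doubled by the leading factor of 2) accounts for the additional ctrl_range scaling of the action. A quick check of the identity:

import math
import torch
import torch.nn.functional as F

u = torch.linspace(-3.0, 3.0, steps=61, dtype=torch.float64)
direct = torch.log(1 - torch.tanh(u).pow(2))
stable = 2.0 * (math.log(2.0) - u - F.softplus(-2.0 * u))
print(torch.allclose(direct, stable))  # True
# the stable form stays finite even where the direct one underflows to -inf
print(torch.log(1 - torch.tanh(torch.tensor(20.0)).pow(2)))            # -inf
print(2.0 * (math.log(2.0) - 20.0 - F.softplus(torch.tensor(-40.0))))  # ~ -38.61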
Example #7
 def generate(self,
              synth,
              h_0,
              f0_hz,
              enc_frame_setting='fine',
              n_samples=16000):
     """
     synth:          synth to generate audio
     h_0:            initial state of RNN [batch, latent_dims]
     f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
     enc_frame_setting: fft/hop size
     n_samples:      output audio length in samples
     """
     h = h_0
     n_fft, hop_length = get_window_hop(enc_frame_setting)
     n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
     f0_hz = resample_frames(f0_hz,
                             n_frames)  # needs to have same dimension as z
     params_list = []
     z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)
     for t in range(n_frames):
         # prior distribution with rnn information
         mu_p_t, scale_p_t = self.get_prior(h)
         prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
         prior_sample_t = prior_t.rsample()
         h = self.temporal(prior_sample_t, h)
         z[:, t, :] = prior_sample_t
     cond = {}
     cond['z'] = z
     cond['f0_hz'] = f0_hz
     y_params = self.decode(cond)
     params = synth.fill_params(y_params, cond)
     resyn_audio, outputs = synth(params, n_samples)
     return params, resyn_audio
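The frame count follows the usual STFT relation between window size, hop length, and signal length; with hypothetical values standing in for whatever get_window_hop('fine') returns, the arithmetic works out as:

import math

n_samples = 16000
n_fft, hop_length = 1024, 256  # hypothetical values, for illustration only
n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
print(n_frames)  # ceil(14976 / 256) + 1 = 59 + 1 = 60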
Example #8
    def trajectory(self, current_state):
        '''
        This implementation may not make good use of the GPU, since the
        trajectory is generated one step at a time.

        Final output looks like:
        [(s_0, a_0, r_0, log_prob_0), ..., (s_L, a_L, r_L, log_prob_L)]
        '''
        output_history = []
        while True:
            mu, sigma = self.forward(current_state)
            distribution = Independent(Normal(mu, sigma), 1)
            picked_action = distribution.rsample()
            action = picked_action.detach()
            #print(action)
            new_state, reward = self.env.state_and_reward(
                current_state, action
            )  # Get the reward and the new state that the action resulted in.
            # new_state is None if the action caused death. TODO: build in environment
            output_history.append(
                (current_state, action, reward, distribution.log_prob(action)))
            if new_state is None:  #essentially, you died or finished your trajectory
                break
            else:
                current_state = new_state
        return output_history
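One common way to consume a trajectory in this format is a REINFORCE-style update; a hedged sketch, assuming the rewards are plain floats and that a discounted return is appropriate for the task:

import torch

def reinforce_loss(trajectory, gamma=0.99):
    # trajectory: list of (state, action, reward, log_prob) tuples as produced above;
    # rewards are assumed to be plain floats here
    g, returns = 0.0, []
    for _, _, reward, _ in reversed(trajectory):
        g = reward + gamma * g
        returns.insert(0, g)
    returns = torch.as_tensor(returns, dtype=torch.float32)
    log_probs = torch.stack([lp for _, _, _, lp in trajectory])
    return -(log_probs * returns).sum()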
Example #9
class TanhNormal(Distribution):
    """Copied from Kaixhi"""
    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self):
        return torch.tanh(self.normal.sample())

    # samples with re-parametrization trick (differentiable)
    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) +
                                                             1e-6).sum(dim=1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        return self.normal.stddev
Example #10
    def forward(self, x):
        kl = torch.zeros(1).to(device)
        z = 0.
        # Unet encoder result
        x_enc = self.unet_encoder(x)

        # VAE regularisation
        if not self.unet:
            emb = self.vae_encoder(x)

            # Split encoder outputs into a mean and variance vector
            mu, log_var = torch.chunk(emb, 2, dim=1)

            # Make sure that the log variance is positive
            log_var = softplus(log_var)
            sigma = torch.exp(log_var / 2)

            # Instantiate a diagonal Gaussian with mean=mu, std=sigma
            # This is the approximate latent distribution q(z|x)
            posterior = Independent(Normal(loc=mu, scale=sigma), 1)
            z = posterior.rsample()

            # Instantiate a standard Gaussian with mean=mu_0, std=sigma_0
            # This is the prior distribution p(z)
            prior = Independent(Normal(loc=self.mu_0, scale=self.sigma_0), 1)

            # Estimate the KLD between q(z|x)|| p(z)
            kl = KLD(posterior, prior).sum()

        # Outputs for MSE
        xHat = self.decoder(x_enc, z)

        return kl, xHat
Example #11
class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.base_dist = Independent(Normal(loc=loc,
                                            scale=scale,
                                            validate_args=validate_args),
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                                self.base_dist.event_shape,
                                                validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
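A small usage sketch showing the shapes this wrapper produces; the loc/scale values are arbitrary:

import torch

loc = torch.zeros(3, 4)
scale = torch.ones(3, 4)
d = IndependentNormal(loc, scale)
print(d.batch_shape, d.event_shape)   # torch.Size([3]) torch.Size([4])
x = d.rsample()
print(x.shape, d.log_prob(x).shape)   # torch.Size([3, 4]) torch.Size([3])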
Example #12
    def generate(self,
                 synth,
                 h_0,
                 f0_hz,
                 attributes,
                 enc_frame_setting='fine',
                 n_samples=16000):
        """
        synth:          synth to generate audio
        h_0:            initial seed of RNN [batch, latent_dims]
        f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
        attributes:     attributes [batch, n_frames, attribute_size]
        enc_frame_setting: fft/hop size
        n_samples:      output audio length in samples
        """
        if len(h_0.shape) == 2:
            h = h_0[None, :, :]  # 1, batch, latent_dims
        else:
            h = h_0
        n_fft, hop_length = get_window_hop(enc_frame_setting)
        n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
        f0_hz = resample_frames(f0_hz,
                                n_frames)  # needs to have same dimension as z
        params_list = []
        for i in range(n_frames):
            cond = {}
            output = torch.cat([h.permute(1, 0, 2), attributes], dim=-1)
            mu, logscale = self.psi_p(output, output)
            scale = logscale.exp()
            prior = Independent(Normal(mu, scale), 1)
            prior_sample = prior.rsample()
            cond['z'] = prior_sample
            cond['f0_hz'] = f0_hz[:, i, :].unsqueeze(1)
            cond['f0_scaled'] = hz_to_midi(cond['f0_hz']) / 127.0
            # generate x
            y = self.decode(cond)
            params = synth.fill_params(y, cond)
            params_list.append(params)
            x_tilde, _outputs = synth(
                params, n_samples=n_fft)  # write exactly one frame
            cond['audio'] = x_tilde
            # encode
            cond = self.encoder(cond)
            z_enc = cond['z']
            # get psi_q
            mu, logscale = self.psi_q(z_enc, z_enc)
            psi = torch.cat([mu, logscale], dim=-1)
            # temporal model
            temp_q, h = self.temporal_q(psi, h)  # one off

        param_names = params_list[0].keys()
        final_params = {}
        for pn in param_names:
            #cat over frames
            final_params[pn] = torch.cat([par[pn] for par in params_list],
                                         dim=1)

        final_audio, _outputs = synth(final_params, n_samples=n_samples)
        return final_params, final_audio
Example #13
 def forward(self, obs, act=None, deterministic=False):
     # Optionally pass in an action to get the log_prob of that action
     mu = self.mu_layer(obs)
     std = torch.exp(self.log_std_layer)
     pi = Independent(Normal(mu, std), 1)
     if act is None:
         act = pi.mean if deterministic else pi.rsample()
     log_prob = pi.log_prob(act)
     return pi, act, log_prob
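The snippet assumes mu_layer is a network head and log_std_layer is a state-independent parameter; a minimal module it could sit in might look like the following (layer sizes and names are illustrative assumptions):

import torch
import torch.nn as nn
from torch.distributions import Independent, Normal

class GaussianActor(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mu_layer = nn.Linear(obs_dim, act_dim)
        # one log-std per action dimension, shared across states
        self.log_std_layer = nn.Parameter(-0.5 * torch.ones(act_dim))

    def forward(self, obs, act=None, deterministic=False):
        mu = self.mu_layer(obs)
        std = torch.exp(self.log_std_layer)
        pi = Independent(Normal(mu, std), 1)
        if act is None:
            act = pi.mean if deterministic else pi.rsample()
        return pi, act, pi.log_prob(act)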
Example #14
    def generate(self,
                 synth,
                 h_0,
                 f0_hz,
                 enc_frame_setting='fine',
                 n_samples=16000):
        """
        synth:          synth to generate audio
        h_0:            initial state of RNN [batch, latent_dims]
        f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
        enc_frame_setting: fft/hop size
        n_samples:      output audio length in samples
        """
        h = h_0
        n_fft, hop_length = get_window_hop(enc_frame_setting)
        n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
        f0_hz = resample_frames(f0_hz,
                                n_frames)  # needs to have same dimension as z
        params_list = []
        z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)
        for t in range(n_frames):
            h_mu, h_scale = self.h_process(h, h)
            mu_t, logscale_t = self.psi_p(h_mu,
                                          h_scale)  # [batch, latent_size]
            scale_t = logscale_t.exp()
            prior_t = Independent(Normal(mu_t, scale_t), 1)
            prior_sample_t = prior_t.rsample()
            cond = {}
            z[:, t, :] = prior_sample_t
            cond['z'] = prior_sample_t.unsqueeze(1)
            cond['f0_hz'] = f0_hz[:, t, :].unsqueeze(1)
            cond['f0_scaled'] = hz_to_midi(cond['f0_hz']) / 127.0
            # generate x
            y = self.decode(cond)
            params = synth.fill_params(y, cond)
            params_list.append(params)
            x_tilde, _outputs = synth(
                params, n_samples=n_fft)  # write exactly one frame
            cond['audio'] = x_tilde
            # encode
            cond = self.encoder(cond)
            z_enc = cond['z'].squeeze(1)
            # get psi_q
            mu, logscale = self.psi_q(z_enc, z_enc)
            rnn_input = torch.cat([mu, logscale, prior_sample_t], dim=-1)
            # temporal model
            h = self.temporal_q(rnn_input, h)  # one off

        cond = {}
        cond['z'] = z
        cond['f0_hz'] = f0_hz
        y_params = self.decode(cond)
        params = synth.fill_params(y_params, cond)
        resyn_audio, outputs = synth(params, n_samples)
        return params, resyn_audio
Example #15
    def act(self, obs, deterministic=False):
        action_mean = self.forward(obs)
        normal = Normal(action_mean, torch.exp(self.log_scale))
        dist = Independent(normal, 1)
        if deterministic:
            action = action_mean
        else:
            action = dist.rsample()
        action_logprobs = dist.log_prob(torch.squeeze(action))

        return action, action_logprobs
Example #16
 def latent(self, conditioning):
     z_enc = conditioning['z']
     batch_size, num_frames, _ = z_enc.shape
     mu = self.linear_mu(z_enc)
     log_var = self.linear_logvar(z_enc)
     posterior = Independent(Normal(mu, log_var.exp().sqrt()), 1)
     posterior_sample = posterior.rsample()
     # Compute KL divergence
     prior = Independent(Normal(torch.zeros_like(mu), torch.ones_like(log_var)), 1)
     kl_div = torch.mean(kl_divergence(posterior, prior))
     return posterior_sample, kl_div, [mu, log_var.exp().sqrt()]
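For a diagonal Gaussian posterior against a standard-normal prior, the kl_divergence call above agrees with the familiar closed form 0.5 * sum(mu^2 + var - log var - 1); a quick check with arbitrary shapes:

import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

mu = torch.randn(2, 5, 8)
log_var = torch.randn(2, 5, 8)
posterior = Independent(Normal(mu, log_var.exp().sqrt()), 1)
prior = Independent(Normal(torch.zeros_like(mu), torch.ones_like(mu)), 1)
analytic = 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum(-1)
print(torch.allclose(kl_divergence(posterior, prior), analytic, atol=1e-5))  # True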
Example #17
 def temporal_model_step(self, z_enc_t, h, attribute=None):
     """
     generate z_t autoregressively
     """
     # mix with temporal info
     mu_z_t, scale_z_t = self.get_posterior(h, z_enc_t)
     scale_z_t = scale_z_t + 1e-4  # minimum
     # final posterior distribution with rnn information
     posterior_t = Independent(Normal(mu_z_t, scale_z_t), 1)
     posterior_sample_t = posterior_t.rsample()
     # prior distribution with rnn information
     mu_p_t, scale_p_t = self.get_prior(h)
     scale_p_t = scale_p_t + 1e-4  # minimum
     prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
     kl_div_t = torch.mean(kl_divergence(posterior_t, prior_t))
     return posterior_sample_t, kl_div_t, mu_z_t, scale_z_t
Example #18
    def forward(self, obs, deterministic=False):
        mu = self.mu_layer(obs)
        log_std = self.log_std_layer(obs)
        std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp()

        # Pre-squash distribution and sample
        pi_distribution = Independent(Normal(mu, std), 1)

        act = mu if deterministic else pi_distribution.rsample()

        log_prob = pi_distribution.log_prob(act)
        squashed_action = torch.tanh(act)

        log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
                                        self.__eps).sum(axis=-1)
        return squashed_action, log_prob
Example #19
    def _mcvi_forward(self, x):
        if self.certain or not self.deterministic:
            x_mean = x if not isinstance(x, tuple) else x[0]
            x_var = x_mean * x_mean
        else:
            x_mean = x[0]
            x_var = x[1]

        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        z_mean = F.conv2d(x_mean, self.W, self.bias, self.stride, self.padding)
        z_var = F.conv2d(x_var, W_var, bias_var, self.stride, self.padding)

        # Normal takes a standard deviation, so use the square root of the propagated variance
        dst = Independent(Normal(z_mean, z_var.sqrt()), 1)
        sample = dst.rsample()
        return sample, None
Example #20
class VAE(nn.Module):
    def __init__(self, encoder, decoder, device=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if device is None:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

    def forward(self, x):

        bs = x.size(0)
        ls = self.encoder.latent_dims

        mu, sigma = self.encoder(x)
        self.pz = Independent(Normal(loc=torch.zeros(bs, ls).to(self.device),
                                     scale=torch.ones(bs, ls).to(self.device)),
                              reinterpreted_batch_ndims=1)
        self.qz_x = Independent(Normal(loc=mu, scale=torch.exp(sigma)),
                                reinterpreted_batch_ndims=1)

        self.z = self.qz_x.rsample()
        decoded = self.decoder(self.z)

        return decoded

    def compute_loss(self, x, y, scale_kl=False):

        px_z = Independent(ContinuousBernoulli(logits=y),
                           reinterpreted_batch_ndims=3)

        px = px_z.log_prob(x)
        kl = self.pz.log_prob(self.z) - self.qz_x.log_prob(self.z)

        if scale_kl:
            kl = kl * scale_kl

        loss = -(px + kl).mean()

        return loss, kl.mean().item(), px.mean().item()

    def rmse(self, input, target):
        return torch.sqrt(F.mse_loss(input, target))
Example #21
    def _loss_vae(self, x):

        batch_size = x.size(0)
        encoder_output = self.encoder(x)
        pz = Independent(Normal(loc=torch.zeros(batch_size, self.latent_dim).to(self.device),
                                scale=torch.ones(batch_size, self.latent_dim).to(self.device)),
                         reinterpreted_batch_ndims=1)
        qz_x = Independent(Normal(loc=encoder_output[:, :self.latent_dim],
                                  scale=torch.exp(encoder_output[:, self.latent_dim:])),
                           reinterpreted_batch_ndims=1)

        z = qz_x.rsample()

        decoder_output = self.decoder(z)
        px_z = Independent(Bernoulli(logits=decoder_output),
                           reinterpreted_batch_ndims=1)
        loss = -(px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)).mean()
        return loss, decoder_output
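Both losses above estimate the KL term with a single sample, via pz.log_prob(z) - qz_x.log_prob(z) (the negative of the KL, added to the log-likelihood); averaged over many samples this approaches the analytic kl_divergence. A small illustration with arbitrary dimensions:

import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

mu, sigma = torch.randn(16), torch.rand(16) + 0.5
qz = Independent(Normal(mu, sigma), 1)
pz = Independent(Normal(torch.zeros(16), torch.ones(16)), 1)
z = qz.rsample((10000,))
mc_kl = (qz.log_prob(z) - pz.log_prob(z)).mean()
print(mc_kl.item(), kl_divergence(qz, pz).item())  # close to each other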
Example #22
 def generate(self,
              synth,
              h_0,
              f0_hz,
              attributes,
              enc_frame_setting='fine',
              n_samples=16000):
     """
     synth:          synth to generate audio
     h_0:            initial state of RNN [batch, latent_dims]
     f0_hz:          f0 conditioning of synth [batch, f0_n_frames, 1]
     attributes:     attributes [batch, attribute_size] or [batch, n_frames, attribute_size]
     enc_frame_setting: fft/hop size
     n_samples:      output audio length in samples
     """
     n_fft, hop_length = get_window_hop(enc_frame_setting)
     n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
     f0_hz = resample_frames(f0_hz, n_frames).to(
         h_0.device)  # needs to have same dimension as z
     params_list = []
     z = torch.zeros(h_0.shape[0], n_frames,
                     self.latent_dims).to(h_0.device)
     if len(attributes.shape) == 2:
         attributes = attributes[:, None, :].expand(-1, n_frames, -1)
     # set up initial prior with attributes
     z_t = torch.zeros(h_0.shape[0], self.latent_dims).to(h_0.device)
     rnn_input = torch.cat([z_t, attributes[:, 0, :]], dim=-1)
     h = self.temporal(rnn_input, h_0)
     for t in range(n_frames):
         # prior distribution with rnn information
         mu_p_t, scale_p_t = self.get_prior(h)
         prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)
         z_t = prior_t.rsample()
         rnn_input = torch.cat([z_t, attributes[:, t, :]], dim=-1)
         h = self.temporal(rnn_input, h)
         z[:, t, :] = z_t
     cond = {}
     cond['z'] = z
     cond['f0_hz'] = f0_hz
     y_params = self.decode(cond)
     params = synth.fill_params(y_params, cond)
     resyn_audio, outputs = synth(params, n_samples)
     return params, resyn_audio
Example #23
def loss_vae_normal(x, encoder, decoder):
    batch_size = x.size(0)
    encoder_output = encoder(x)
    d = encoder_output.shape[1] // 2
    pz_loc = F.sigmoid(torch.zeros(batch_size, d).to(device))
    pz_scale = torch.ones(batch_size, d).to(device)
    pz = Independent(Normal(loc=pz_loc, scale=pz_scale),
                     reinterpreted_batch_ndims=1)
    qz_x_loc = encoder_output[:, :d]
    qz_x_log_scale = encoder_output[:, d:]
    # exponentiate the log-scale so Normal receives a positive scale parameter
    qz_x = Independent(Normal(loc=qz_x_loc, scale=qz_x_log_scale.exp()),
                       reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    decoder_output = decoder(z)
    optimal_sigma_observed = ((x - decoder_output)**2).mean(
        [0, 1, 2, 3], keepdim=True).sqrt()
    px_z = Independent(Normal(loc=decoder_output,
                              scale=optimal_sigma_observed),
                       reinterpreted_batch_ndims=3)
    elbo = (px_z.log_prob(x) - kl_divergence(qz_x, pz)).mean()
    return -elbo, decoder_output
Example #24
class IndependentRescaledBeta(Distribution):
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(RescaledBeta(concentration1,
                                                  concentration0,
                                                  validate_args),
                                     len(concentration1.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentRescaledBeta,
              self).__init__(self.base_dist.batch_shape,
                             self.base_dist.event_shape,
                             validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
Example #25
 def latent(self, conditioning):
     """
     z_enc: [batch, frames, latent_dims]
     """
     z_enc = conditioning['z']
     batch_size, n_frames, _encdims = z_enc.shape
     mu_q, logscale_q = self.psi_q(z_enc, z_enc)
     # feed into temporal model
     h_0 = torch.randn(1, batch_size, self.latent_dims).to(
         z_enc.device) * 0.01
     temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)
     # final posterior distribution with rnn information
     mu_z, scale_z = self.mix_with_temp(mu_q, logscale_q, temp_q)
     posterior = Independent(Normal(mu_z, scale_z), 1)
     posterior_sample = posterior.rsample()
     # prior distribution with rnn information
     mu, scale = self.psi_p(temp_q, temp_q)
     scale = scale.exp()
     prior = Independent(Normal(mu, scale), 1)
     # prior = Independent(Normal(torch.zeros_like(mu), torch.ones_like(scale)), 1)
     # sum over latent dim mean over batch
     kl_div = torch.mean(kl_divergence(posterior, prior))
     return posterior_sample, kl_div, [mu_z, scale_z]
Example #26
    def _zero_mean_forward(self, x):
        if self.certain or not self.deterministic:
            x_mean = x if not isinstance(x, tuple) else x[0]
            x_var = x_mean * x_mean
        else:
            x_mean = x[0]
            x_var = x[1]

        W_var = torch.exp(self.log_alpha) * self.weight * self.weight

        z_mean = F.conv2d(x_mean, torch.zeros_like(self.weight), self.bias,
                          self.stride, self.padding)
        z_var = F.conv2d(x_var,
                         W_var,
                         bias=None,
                         stride=self.stride,
                         padding=self.padding)

        if self.deterministic:
            return z_mean, z_var
        else:
            # Normal takes a standard deviation, so use the square root of the propagated variance
            dst = Independent(Normal(z_mean, z_var.sqrt()), 1)
            sample = dst.rsample()
            return sample, None
Example #27
    def loss(self, x):
        """
        returns
        1. the avergave value of negative ELBO across the minibatch x
        2. and the output of the decoder
        """
        batch_size = x.size(0)
        encoder_output = self.encoder(x)
        pz = Independent(Normal(loc=torch.zeros(batch_size,
                                                self.z_dim).to(self.device),
                                scale=torch.ones(batch_size,
                                                 self.z_dim).to(self.device)),
                         reinterpreted_batch_ndims=1)
        qz_x = Independent(Normal(loc=encoder_output[:, :self.z_dim],
                                  scale=torch.exp(
                                      encoder_output[:, self.z_dim:])),
                           reinterpreted_batch_ndims=1)

        z = qz_x.rsample()
        decoder_output = self.decoder(z)
        px_z = Independent(Bernoulli(logits=decoder_output),
                           reinterpreted_batch_ndims=1)
        loss = -(px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)).mean()
        return loss, decoder_output
Example #28
class TorchGaussianMixtureDistribution(TorchDistributionWrapper):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return prod((2, model_config['custom_model_config']['num_gaussians']) +
                    action_space.shape)

    def __init__(self, inputs: List[torch.Tensor], model: TorchModelV2):
        super(TorchDistributionWrapper, self).__init__(inputs, model)
        assert len(inputs.shape) == 2
        self.batch_size = self.inputs.shape[0]
        self.num_gaussians = model.model_config['custom_model_config'][
            'num_gaussians']
        self.monte_samples = model.model_config['custom_model_config'][
            'monte_samples']
        inputs = torch.reshape(self.inputs,
                               (self.batch_size, 2, self.num_gaussians, -1))
        self.action_dim = inputs.shape[-1]
        assert not torch.isnan(inputs).any(), "Input nan aborting"
        self.means = inputs[:,
                            0, :, :]  # batch_size x num_gaussians x action_dim
        self.sigmas = torch.exp(
            inputs[:, 1, :, :])  # batch_size x num_gaussians x action_dim

        self.cat = Categorical(
            torch.ones(self.batch_size,
                       self.num_gaussians,
                       device=inputs.device,
                       requires_grad=False))
        self.normals = Independent(Normal(self.means, self.sigmas), 1)

    def logp(self, actions: torch.Tensor):
        actions = actions.view(
            self.batch_size, 1,
            -1)  # batch_size x 1 (broadcast to num gaussians) x action_dim
        mix_lps = self.cat.logits  # batch_size x num_gaussians
        normal_lps = self.normals.log_prob(
            actions)  # batch_size x num_gaussians (action_dim reduced by Independent)
        assert not torch.isnan(mix_lps).any(), "output nan aborting"
        assert not torch.isnan(normal_lps).any(), "output nan aborting"
        return torch.logsumexp(mix_lps + normal_lps,
                               dim=1)  # reduce along num gaussians

    def deterministic_sample(self) -> torch.Tensor:
        self.last_sample = self.means[:,
                                      0, :]  # select the mode of the first gaussian
        return self.last_sample

    def __rsamples(self):
        """ Compute samples that can be differentiated through
        """
        # Using reparameterization trick i.e. rsample
        normal_samples = self.normals.rsample(
            (self.monte_samples,
             ))  # monte_samples x batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample(
            (self.monte_samples, ))  # monte_samples x batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.reshape(self.monte_samples, -1, 1,
                                          1).expand(-1, -1, -1,
                                                    self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        return torch.gather(normal_samples, 2, cat_samples).squeeze(
            dim=2)  # monte_samples x batch_size x action_dim

    def kl(self, q: ActionDistribution) -> torch.Tensor:
        """ KL(self || q) estimated with monte carlo sampling
        """
        rsamples = self.__rsamples().unbind(0)
        log_ratios = torch.stack(
            [self.logp(rsample) - q.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ratios).any(), "output nan aborting"
        return log_ratios.mean(0)

    def entropy(self) -> torch.Tensor:
        """ H(self) estimated with monte carlo sampling
        """
        rsamples = self.__rsamples().unbind(0)
        log_ps = torch.stack([-self.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ps).any(), "output nan aborting"
        return log_ps.mean(0)

    def sample(self):
        normal_samples = self.normals.sample(
        )  # batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample()  # batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.view(-1, 1,
                                       1).expand(-1, -1, self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        self.last_sample = torch.gather(normal_samples, 1,
                                        cat_samples).squeeze(
                                            dim=1)  # batch_size x action_dim
        assert len(
            self.last_sample.shape) == 2, f"shape, {self.last_sample.shape}"
        return self.last_sample
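The logp method above is the standard mixture log-likelihood: a logsumexp over components of the mixture-weight logits plus the component log-probs. This can be cross-checked against torch.distributions.MixtureSameFamily (shapes and values are arbitrary):

import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

batch, K, D = 4, 3, 2
means, sigmas = torch.randn(batch, K, D), torch.rand(batch, K, D) + 0.1
mix = Categorical(torch.ones(batch, K))
comp = Independent(Normal(means, sigmas), 1)
gmm = MixtureSameFamily(mix, comp)
actions = torch.randn(batch, D)
manual = torch.logsumexp(mix.logits + comp.log_prob(actions.view(batch, 1, D)), dim=1)
print(torch.allclose(manual, gmm.log_prob(actions), atol=1e-5))  # True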