Example #1
 def forward(self, x, cold=False):
     batch_size = x.shape[0]
     # take fft (B, 1, width*height, 2)
     x_freq = torch.rfft(x, signal_ndim=2,
                         onesided=False).view(batch_size, 1,
                                              self.input_size, 2)
     if cold:
         temperature = 0.0
     else:
         temperature = self.temperature
     logits = self.logits.expand(
         (batch_size, self.output_size, self.input_size))
     dist = RelaxedBernoulli(temperature=temperature, logits=logits)
     if cold:
         samples = dist.sample()
     else:
         samples = dist.rsample()
     # reshape so broadcasting works properly
     samples = samples.view(batch_size, self.output_size, self.input_size)
     # mask the frequencies (B, output_size, width*height, 2)
     sensed_freq = x_freq * samples
     # reshape for ifft
     sensed_freq = sensed_freq.view(-1, int(np.sqrt(self.input_size)))
     sensed_freq = sensed_freq.view(-1, self.resolution, self.resolution, 2)
     # (B*output_size, resolution, resolution)
     sensed_images = torch.irfft(sensed_freq,
                                 2,
                                 normalized=False,
                                 onesided=False)
     sensed = torch.sum(sensed_images.view(batch_size, self.output_size,
                                           self.input_size),
                        axis=-1)
     noise_scale = self.noise * torch.sqrt(sensed.detach())
     sensed += torch.randn_like(sensed) * noise_scale
     return sensed
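
A note on the pattern used above: rsample() is used during training so that gradients flow through the reparameterized relaxed sample, while sample() with a very small (but strictly positive) temperature is used for near-binary draws at evaluation time. (torch.rfft/torch.irfft are the pre-1.8 PyTorch FFT API.) A minimal, self-contained sketch of just that switch, with made-up shapes:

import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(4, requires_grad=True)

# Training: reparameterized relaxed samples, gradients flow back to `logits`.
warm = RelaxedBernoulli(temperature=torch.tensor(0.5), logits=logits)
soft_mask = warm.rsample()          # values in (0, 1)
soft_mask.sum().backward()
print(logits.grad)

# Evaluation ("cold"): a tiny but strictly positive temperature gives
# near-binary samples; sample() detaches them from the graph.
cold = RelaxedBernoulli(temperature=torch.tensor(1e-3), logits=logits)
print(cold.sample())
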
    def init_prior(self):
        # Set up priors
        K = self.max_lag
        m = self.num_X
        rho = self.prior_rho_A
        sigma = self.prior_sigma_W

        temperature = torch.tensor([self.temperature], device=self.device)
        prior_A = rho * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
        prior_A[:, m:, :] = torch.tensor([0], device=self.device)

        # Set the diagonal
        for i in range(m):
            prior_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)

        self.prior_A = RelaxedBernoulli(temperature=temperature, probs=prior_A)

        prior_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m),
                                           device=self.device)
        prior_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            prior_W_scale[:, m + i, m + i] = torch.tensor([sigma],
                                                          device=self.device)

        self.prior_W = Normal(loc=torch.zeros_like(prior_W_scale,
                                                   device=self.device),
                              scale=prior_W_scale)
    def sample(self, x):

        logits = self.inference_net(x)
        # logits = torch.tanh(logits) * 4.
        logits = torch.clamp(logits, min=-8., max=4.)

        dist = RelaxedBernoulli(torch.Tensor([1.]).cuda(), logits=logits)

        z = dist.rsample() #[B,Z]

        z = torch.clamp(z, min=.00000001, max=.9999999)

        logqz = dist.log_prob(z.detach())
        logqz = torch.sum(logqz,1)


        # z= z.cuda()
        # print (z.shape)

        # logqz = torch.sum( dist.log_prob(z.detach()), dim=1) #[B]

        # print (logqz.shape)
        # fasd
        # z, logqz = self.dist.sample(logits)

        return z, logits, logqz#, #logqz
Example #4
 def forward(self, x, rein_flag=False):
     h1 = self.Alice(x)
     if rein_flag:
         h_distribution = RelaxedBernoulli(self.temp, h1)
         h1_nograd = h_distribution.sample()
         h1_nograd = h1_nograd.float()
         out = self.Bob(h1_nograd)
     else:
         out = self.Bob(h1)
     return out, h1
Example #5
 def gumbel_sample(self, s):
     """
     Reparametrisation of Bernoulli distribution
     s : [batch, output_dim]
     """
     # s=F.relu(self.linear1(s))
     # p_logits=self.linear2(s)
     p_logits = self.linear(s)  # [batch, 1]
     action_dis = RelaxedBernoulli(temperature=0.8, logits=p_logits)
     action = action_dis.rsample()
     hard_action = (action > 0.5).float()
     return action + (hard_action - action).detach()  #
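
The return line above is the straight-through trick: the forward pass sees the hard 0/1 value, while gradients flow through the relaxed sample. A minimal sketch of the same pattern in isolation (shapes and temperature are placeholders):

import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(8, requires_grad=True)
dist = RelaxedBernoulli(temperature=torch.tensor(0.8), logits=logits)

soft = dist.rsample()               # differentiable sample in (0, 1)
hard = (soft > 0.5).float()         # non-differentiable threshold
st = soft + (hard - soft).detach()  # value == hard, gradient == d(soft)

st.sum().backward()                 # gradients reach `logits` through `soft`
print(st, logits.grad)
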
Example #6
 def get_weight(self, x):
     if self.training:
         pi = self.sample_pi()
         temp = torch.Tensor([0.1])
         if torch.cuda.is_available():
             temp = temp.cuda()
         p_z = RelaxedBernoulli(temp, probs=pi)
         z = p_z.rsample(torch.Size([x.size(0)]))
     else:
         Epi = self.get_Epi()
         mask = self.get_mask(Epi=Epi)
         z = Epi * mask
     return z
Example #7
        def forward(self, z):
            map1 = self.l1(z)
            map1 = map1.view(map1.shape[0], 128, self.init_size)
            out = self.model(map1)

            img = RelaxedBernoulli(torch.tensor([opt.temp]).type(TENSOR),
                                   probs=out).rsample()

            return img
    def __init__(self, w, p, l, temperature=0.1, validate_args=None):
        relaxed_bernoulli = RelaxedBernoulli(temperature, p)
        affine_transform = AffineTransform(w, l - w)
        super(ToeplitzBernoulliDistribution,
              self).__init__(relaxed_bernoulli, affine_transform,
                             validate_args)

        self.relaxed_bernoulli = relaxed_bernoulli
        self.affine_transform = affine_transform
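
This constructor appears to delegate to torch.distributions.TransformedDistribution, rescaling a relaxed Bernoulli sample s in (0, 1) to w + s * (l - w). A minimal sketch of that composition under the same assumption, with placeholder parameter values:

import torch
from torch.distributions import RelaxedBernoulli, TransformedDistribution
from torch.distributions.transforms import AffineTransform

w, l, p = 0.1, 1.0, torch.tensor([0.7])

base = RelaxedBernoulli(temperature=torch.tensor(0.1), probs=p)
dist = TransformedDistribution(base, AffineTransform(loc=w, scale=l - w))

x = dist.rsample()                        # values interpolated between w and l
print(x, dist.log_prob(torch.tensor([0.5])))
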
Example #9
    def forward(self, x, mask=None, loss=None):
        """
    Args:
      x (torch.FloatTensor):
        class-wise representation.
        torch.Size([n_cls, feature_dim])
      mask: torch.Size([n_cls, 1])
      loss: torch.Size([n_cls, 1])
    Returns:
      x (torch.FloatTensor):
        class-wise mask layout.
        torch.Size([n_cls, 1])
    """
        # generate states from the features at the first loop.
        if self.state is None:
            self.state = self.state_linear(x.detach())
            # state = self.static_init_state(x.size(0))

        if self.input_more:
            # detach from the graph
            mask = mask.detach()
            loss = loss.detach()

            # averaged and relative loss
            loss = loss / np.log(loss.size(0))  # scaling by log
            loss_mean = loss.mean().repeat(loss.size(0), 1)  # averaged loss
            loss_rel = loss - loss_mean  # relative loss
            loss_mean = self.preprocess(loss_mean).detach()
            loss_rel = self.preprocess(loss_rel).detach()
            step = self.step.repeat(loss.size(0), 1)
            x = torch.cat([x, mask, loss_mean, loss_rel, step], dim=1)
            # [n_cls , feature_dim + 1 + 2 + 2]
        else:
            step = self.step.repeat(loss.size(0), 1)
            x = torch.cat([x, step], dim=1)

        self.state = self.gru(x, self.state)  # [n_cls , rnn_h_dim]
        x = self.out_linear(self.state)  # [n_cls , 1]

        if self.output_more:
            mask = x[:, 0].unsqueeze(1)
            lr = (x[:, 1].mean() + self.c).exp()
        else:
            mask = x

        if self.sample_mode:
            mask = RelaxedBernoulli(self.temp, mask).rsample()
        else:
            mask = self.sigmoid(mask / self.temp)  # 0.2

        if self.output_more:
            return mask, lr
        else:
            return mask, None  # [n_cls , 1]
    def __init__(self, w, p, temperature=0.1, validate_args=None):
        relaxed_bernoulli = RelaxedBernoulli(temperature, p)
        affine_transform = AffineTransform(0, w)
        one_minus_p = AffineTransform(1, -1)
        super(BernoulliDropoutDistribution,
              self).__init__(relaxed_bernoulli,
                             ComposeTransform([one_minus_p, affine_transform]),
                             validate_args)

        self.relaxed_bernoulli = relaxed_bernoulli
        self.affine_transform = affine_transform
Example #11
 def forward(self, x, logits, cold=False):
     batch_size = x.shape[0]
     # x_local is (B, 1, resolution*resolution)
     x_local = x.view(batch_size, -1, self.input_size)
     if cold:
         temperature = 0.001
     else:
         temperature = self.temperature
     # samples are (B, output_size, 1, resolution*resolution)
     dist = RelaxedBernoulli(temperature=temperature, logits=logits)
     if cold:
         samples = dist.sample()
     else:
         samples = dist.rsample()
     # now "mask" the pixels of the input spatially by elementwise multiply
     # masked_x is (B, output_size, C, resolution*resolution)
     masked_x = samples * x_local
     # get sensor values, shape of (B, output_size), summing across both channel and pixels
     sensed = masked_x.sum(axis=-1)
     return sensed
    def init_posterior(self):
        # Set up posteriors
        K = self.max_lag
        m = self.num_X

        rho = self.prior_rho_A
        sigma = self.prior_sigma_W

        temperature = torch.tensor([self.temperature], device=self.device)
        estimate_A = rho * torch.rand(size=(K, 2 * m, 2 * m),
                                      device=self.device)
        estimate_A[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_A[:, m + i, m + i] = torch.tensor([rho],
                                                       device=self.device)

        estimate_A = estimate_A.requires_grad_(True)

        self.posterior_A = RelaxedBernoulli(temperature=temperature,
                                            probs=estimate_A)

        estimate_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m),
                                              device=self.device)
        estimate_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_W_scale[:, m + i,
                             m + i] = torch.tensor([sigma], device=self.device)

        estimate_W_scale = estimate_W_scale.requires_grad_(True)

        estimate_W_loc = torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
        estimate_W_loc[:, m:, :] = torch.tensor([0], device=self.device)
        # # Set the diagonal
        # for i in range(m):
        #     estimate_W_loc[:,m+i,m+i] = torch.rand(size=[1],device=self.device)

        estimate_W_loc = estimate_W_loc.requires_grad_(True)
        self.posterior_W = Normal(loc=estimate_W_loc, scale=estimate_W_scale)
Example #13
    def forward(self, a, b):
        """
    Args:
      a, b (tensor): beta distribution parameters.

    Returns:
      z (tensor): Relaxed Beta-Bernoulli binary masks.
      pi (tensor): Bernoulli distribution parameters
    """
        # a, b = self.get_params(x)
        pi = self.sample_pi(a, b)
        temp = C(torch.Tensor([0.001]))
        z = RelaxedBernoulli(temp, probs=pi).rsample()
        return z, pi
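
A minimal sketch of the relaxed Beta-Bernoulli sampling this forward appears to perform (the Beta draw stands in for whatever self.sample_pi does; shapes and concentrations are placeholders):

import torch
from torch.distributions import Beta, RelaxedBernoulli

a = torch.full((10,), 2.0)
b = torch.full((10,), 3.0)

pi = Beta(a, b).rsample()                       # Bernoulli parameters in (0, 1)
z = RelaxedBernoulli(torch.tensor(0.001), probs=pi).rsample()  # near-binary masks
print(pi, z)
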
Example #14
    def forward(self, x, z_in):
        if len(x.size()) == 4:
            h = F.avg_pool2d(x, [x.size(2), x.size(3)])
            h = h.view(h.size(0), -1)
        else:
            h = x
        h = self.bn(h.detach())

        if self.training:
            sigma = F.softplus(self.sigma_uc)
            eps = torch.randn(self.num_gates)
            if torch.cuda.is_available():
                eps = eps.cuda()
            h = h + sigma * eps
            temp = torch.Tensor([0.1])
            if torch.cuda.is_available():
                temp = temp.cuda()
            p_z = RelaxedBernoulli(temp, probs=h.clamp(1e-10, 1 - 1e-10))
            z = p_z.rsample()
        else:
            z = h.clamp(1e-10, 1 - 1e-10)

        if len(x.size()) == 4:
            z = z.view(-1, self.num_gates, 1, 1)
            z_in = z_in.view(-1, self.num_gates, 1, 1)
        else:
            z_in = z_in.view(-1, self.num_gates)
        z = z * z_in

        if not self.training:
            num_active = (z > self.thres).float().sum(1).mean(0).item()
            self.num_active = (self.num_active*self.counter + num_active) / \
                    (self.counter + 1)
            self.counter += 1
            z[z <= self.thres] = 0.

        return x * z
    def forward(self, x):

        if self.gumbel:
            prior_exemplars = RelaxedBernoulli(self.t, logits=self.U).rsample()
        else:
            prior_exemplars = torch.sigmoid(self.U)

        prior_mu, prior_var = self.enc(prior_exemplars).chunk(2, -1)
        prior = Normal(prior_mu, (prior_var * 0.5).clamp(-5, 4).exp())

        posterior_mu, posterior_var = self.enc(x).chunk(2, -1)
        posterior = Normal(posterior_mu, (posterior_var * 0.5).clamp(-5,
                                                                     4).exp())

        z = posterior.rsample()

        x_hat = self.dec(z)

        return {'prior': prior, 'posterior': posterior, 'z': z, 'x_hat': x_hat}
Example #16
    def get_weight(self, num_samps, training, samp_type='rel_ber'):
        temp = torch.Tensor([0.67])
        if torch.cuda.is_available():
            temp = temp.cuda()

        if training:
            pi = self.sample_pi()
            p_z = RelaxedBernoulli(temp, probs=pi)
            z = p_z.rsample(torch.Size([num_samps]))
        else:
            if samp_type == 'rel_ber':
                pi = self.sample_pi()
                p_z = RelaxedBernoulli(temp, probs=pi)
                z = p_z.rsample(torch.Size([num_samps]))
            elif samp_type == 'ber':
                pi = self.sample_pi()
                p_z = torch.distributions.Bernoulli(probs=pi)
                z = p_z.sample(torch.Size([num_samps]))
        return z, pi
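
Both branches draw a whole batch of gate samples at once by passing a sample_shape; a compact, self-contained sketch of that pattern (gate count, batch size and probabilities are placeholders):

import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

pi = torch.rand(16, requires_grad=True)     # per-gate keep probabilities
temp = torch.tensor(0.67)

relaxed = RelaxedBernoulli(temp, probs=pi)
z_soft = relaxed.rsample(torch.Size([32]))  # (32, 16), differentiable w.r.t. pi

hard = Bernoulli(probs=pi)
z_hard = hard.sample(torch.Size([32]))      # (32, 16), exact 0/1 draws, no gradient
print(z_soft.shape, z_hard.shape)
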
Example #17
    def f(self, x, z, logits, hard=False):

        B = x.shape[0]

        # image likelihood given b
        # b = harden(z).detach()
        x_hat = self.generator.forward(z)
        alpha = torch.sigmoid(x_hat)
        beta = Beta(alpha*self.beta_scale, (1.-alpha)*self.beta_scale)
        x_noise = torch.clamp(x + torch.FloatTensor(x.shape).uniform_(0., 1./256.).cuda(), min=1e-5, max=1-1e-5)
        logpx = beta.log_prob(x_noise) #[120,3,112,112]  # add uniform noise here
        logpx = torch.sum(logpx.view(B, -1),1) # [PB]  * self.w_logpx

        # prior is constant I think 
        # for q(b|x), we just want to increase its entropy 
        if hard:
            dist = Bernoulli(logits=logits)
        else:
            dist = RelaxedBernoulli(torch.Tensor([1.]).cuda(), logits=logits)
            
        logqb = dist.log_prob(z.detach())
        logqb = torch.sum(logqb,1)

        return logpx, logqb, alpha
    grads.append(f(samp) * logprobgrad.numpy())

print ('Grad Estimator: REINFORCE H(z), temp=10')
print ('Avg samp', np.mean(samps))
print ('Grad mean', np.mean(grads))
print ('Grad std', np.std(grads))
print ()







# REINFORCE but sampling p(z)
dist_relaxedbern = RelaxedBernoulli(torch.Tensor([1.]), bern_param)
dist_bern = Bernoulli(bern_param)

samps = []
grads = []
for i in range(n):
    samp = dist_relaxedbern.sample()
    Hsamp = Hpy(samp)

    logprob = dist_bern.log_prob(Hsamp)
    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]

    samps.append(Hsamp.numpy())
    grads.append(f(Hsamp.numpy()) * logprobgrad.numpy())

print ('Grad Estimator: REINFORCE but sampling p(z)')
    grads.append((f(samp.numpy()) - 0.) * logprobgrad.numpy())
    logprobgrads.append(logprobgrad.numpy())

# print (grads[:10])

print('Grad Estimator: REINFORCE')
print('Avg samp', np.mean(samps))
print('Grad mean', np.mean(grads))
print('Grad std', np.std(grads))
print('Avg logprobgrad', np.mean(logprobgrads))
print('Std logprobgrad', np.std(logprobgrads))
print()

# n=1000
# print (n)
dist = RelaxedBernoulli(torch.Tensor([5.]), bern_param)

samps = []
grads = []
logprobgrads = []
for i in range(n):
    samp = dist.sample()

    samp = torch.clamp(samp, min=.0000001, max=.99999999999999)
    # print (smp)

    logprob = dist.log_prob(samp)

    if logprob != logprob:
        print(samp, logprob)
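
The clamp above guards against log_prob blowing up when a relaxed sample lands numerically at exactly 0 or 1 (hence the NaN check that follows). A self-contained sketch of the same guard:

import torch
from torch.distributions import RelaxedBernoulli

bern_param = torch.tensor([0.3], requires_grad=True)
dist = RelaxedBernoulli(torch.Tensor([5.]), probs=bern_param)

samp = dist.sample()
# A relaxed sample can saturate to exactly 0.0 or 1.0 in float32 (especially at
# low temperatures), where log_prob is -inf/NaN; clamp into the open interval.
samp = torch.clamp(samp, min=1e-7, max=1. - 1e-7)
logprob = dist.log_prob(samp)
assert logprob == logprob, 'log_prob is NaN'
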
ax = plt.subplot2grid((rows, cols), (row, col),
                      frameon=False,
                      colspan=1,
                      rowspan=1)

# sum_ = 0
# for i in range(20):
#     sum_+= i+5
# print (sum_)
# fds 290

# print (m.log_prob(2.))
xs = np.linspace(.001, .999, 30)

# dist = RelaxedBernoulli(temperature=torch.Tensor([1.]), logits=torch.tensor([0.]))
dist = RelaxedBernoulli(temperature=torch.Tensor([.2]),
                        probs=torch.tensor([0.3]))

ys = []
samps = []
for x in xs:
    # samp = dist.sample()
    # samps.append(samp.data.numpy()[0])
    # print (samp)
    # print (torch.exp(dist.log_prob(samp)))
    # print ()
    # print (m.log_prob(x))

    prob = torch.exp(dist.log_prob(torch.tensor([x]))).numpy()[0]
    # print (x)
    # print (prob)
    # component_i = ().numpy()[0]
Example #21
    def sample_lambda0_r(self,
                         d,
                         batch_size,
                         offset=0,
                         object_locations=None,
                         object_margin=None,
                         num_objects=None,
                         gau=None,
                         max_rejections=1000,
                         margin_offset=2):
        """Sample dataset parameters perturbed by r."""
        name = d['name']
        family = d['family']
        attr_name = '{}_{}'.format(name, 'center')
        if self.wn:
            lambda_r = self.normalize_weights(name=name, prop='center')
        elif family != 'half_normal':
            lambda_r = getattr(self, attr_name)
        parameters = []
        if family == 'gaussian':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(constraints.greater_than(
            #     1.))(lambda_r)
            # lambda_r_scale = transform_to(constraints.greater_than(
            #     self.minimum_spatial_scale))(lambda_r_scale)
            # TODO: Add constraint function here
            # w=module.weight.data
            # w=w.clamp(0.5,0.7)
            # module.weight.data=w

            if gau is None:
                gau = MultivariateNormal(loc=lambda_r,
                                         covariance_matrix=lambda_r_scale)
            if d['return_sampler']:
                return gau
            if name == 'object_location':
                if not len(object_locations):
                    return gau.rsample(), gau
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=gau)
            else:
                raise NotImplementedError(name)
        elif family == 'normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'cnormal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)

            # Explicitly clamp the scale!
            lambda_r_scale = torch.clamp(lambda_r_scale,
                                         self.minimum_spatial_scale, 999.)
            nor = CNormal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'abs_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(Normal.arg_constraints['loc'])(lambda_r)
            # lambda_r_scale = transform_to(Normal.arg_constraints['scale'])(lambda_r_scale)  # noqa
            # lambda_r = transforms.AbsTransform()(lambda_r)
            # lambda_r_scale = transforms.AbsTransform()(lambda_r_scale)
            # These kill grads!! # lambda_r = torch.abs(lambda_r)
            # These kill grads!! lambda_r_scale = torch.abs(lambda_r_scale)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'half_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = HalfNormal(scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'categorical':
            if d['return_sampler']:
                gum = RelaxedOneHotCategorical(1e-1, logits=lambda_r)
                return gum
                # return lambda sample_size: self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset   # noqa
            for _ in range(batch_size):
                parameters.append(
                    self.argmax(self.gumbel_fun(lambda_r, name=name)) +
                    offset)  # noqa Use default temperature -> max
        elif family == 'relaxed_bernoulli':
            bern = RelaxedBernoulli(temperature=1e-1, logits=lambda_r)
            if d['return_sampler']:
                return bern
            else:
                parameters = bern.rsample([batch_size])
        else:
            raise NotImplementedError(
                '{} not implemented in sampling.'.format(family))
        return parameters



class TimeLatent(object):
    _logger = logging.getLogger(__name__)

    def __init__(self, num_X, max_lag, num_samples, device, prior_rho_A,
                 prior_sigma_W, temperature, sigma_Z, sigma_X):

        self.num_X = num_X
        self.max_lag = max_lag
        self.num_samples = num_samples
        self.device = device

        self.prior_rho_A = prior_rho_A
        self.temperature = temperature
        self.prior_sigma_W = prior_sigma_W
        self.prior_sigma_Z = sigma_Z * torch.ones(size=[num_X], device=device)

        self.likelihood_sigma_X = sigma_X * torch.ones(size=[num_X],
                                                       device=device)

        self.posterior_sigma_Z = sigma_Z * torch.ones(size=[num_X],
                                                      device=device)

        self.init_prior()
        self.init_posterior()
        self._logger.debug('Finished building model')

    def init_prior(self):
        # Set up priors
        K = self.max_lag
        m = self.num_X
        rho = self.prior_rho_A
        sigma = self.prior_sigma_W

        temperature = torch.tensor([self.temperature], device=self.device)
        prior_A = rho * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
        prior_A[:, m:, :] = torch.tensor([0], device=self.device)

        # Set the diagonal
        for i in range(m):
            prior_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)

        self.prior_A = RelaxedBernoulli(temperature=temperature, probs=prior_A)

        prior_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m),
                                           device=self.device)
        prior_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            prior_W_scale[:, m + i, m + i] = torch.tensor([sigma],
                                                          device=self.device)

        self.prior_W = Normal(loc=torch.zeros_like(prior_W_scale,
                                                   device=self.device),
                              scale=prior_W_scale)

        # The prior over Z depends on the sampled A and W, so we don't set it up here.

    def init_posterior(self):
        # Set up posteriors
        K = self.max_lag
        m = self.num_X

        rho = self.prior_rho_A
        sigma = self.prior_sigma_W

        temperature = torch.tensor([self.temperature], device=self.device)
        estimate_A = rho * torch.rand(size=(K, 2 * m, 2 * m),
                                      device=self.device)
        estimate_A[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_A[:, m + i, m + i] = torch.tensor([rho],
                                                       device=self.device)

        estimate_A = estimate_A.requires_grad_(True)

        self.posterior_A = RelaxedBernoulli(temperature=temperature,
                                            probs=estimate_A)

        estimate_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m),
                                              device=self.device)
        estimate_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_W_scale[:, m + i,
                             m + i] = torch.tensor([sigma], device=self.device)

        estimate_W_scale = estimate_W_scale.requires_grad_(True)

        estimate_W_loc = torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
        estimate_W_loc[:, m:, :] = torch.tensor([0], device=self.device)
        # # Set the diagonal
        # for i in range(m):
        #     estimate_W_loc[:,m+i,m+i] = torch.rand(size=[1],device=self.device)

        estimate_W_loc = estimate_W_loc.requires_grad_(True)
        self.posterior_W = Normal(loc=estimate_W_loc, scale=estimate_W_scale)

    def ln_p_AWZ(self, A, W, Z):

        K = self.max_lag
        m = self.num_X
        ln_p_A = self.prior_A.log_prob(A)[:, :m, :].sum() + sum([
            torch.diagonal((self.prior_A.log_prob(A)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        ln_p_W = self.prior_W.log_prob(W)[:, :m, :].sum() + sum([
            torch.diagonal((self.prior_W.log_prob(W)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()

        # store the distributions from pZ(1),pZ(2)....pZ(T)
        # p_Z = []

        sigma_Z = self.prior_sigma_Z
        p_Z1 = Normal(loc=torch.zeros_like(sigma_Z, device=self.device),
                      scale=sigma_Z)
        # p_Z.append(p_Z1)
        ln_p_Z1 = p_Z1.log_prob(Z[0])

        ln_p_ZK = torch.zeros(size=[m], device=self.device)

        for t in range(2, K + 1):
            A_22 = A[:t - 1, m:, m:]
            W_22 = W[:t - 1, m:, m:]

            mean_t = []
            for i in range(1, t):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)

            p_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            # p_Z.append(p_Zt)
            ln_p_ZK += p_Zt.log_prob(Z[t - 1])

        ln_p_ZT = torch.zeros(size=[m], device=self.device)
        T = self.num_samples
        A_22 = A[:, m:, m:]
        W_22 = W[:, m:, m:]

        mean_t = []
        for i in range(1, K + 1):
            A_22_i = torch.diagonal(A_22[i - 1])
            W_22_i = torch.diagonal(W_22[i - 1])
            mean_t.append(Z[K - i:T - i] * A_22_i * W_22_i)

        p_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
        ln_p_ZT = p_Zt.log_prob(Z[K:]).sum(0)

        return (ln_p_ZT.sum() + ln_p_ZK.sum() +
                ln_p_Z1.sum()) + ln_p_A + ln_p_W

    def ln_p_X_AWZ(self, X, A, W, Z):

        sigma = self.likelihood_sigma_X
        T = self.num_samples

        K = self.max_lag
        m = self.num_X

        Sum_X_mu = torch.tensor([0.0], device=self.device)

        # t=1
        Sum_X_mu += (X[0]**2).sum()

        # 2<=t<=K
        for t in range(2, K + 1):
            A_11 = A[:, :m, :m]
            W_11 = W[:, :m, :m]

            A_12 = A[:, :m, m:]
            W_12 = W[:, :m, m:]

            mu = torch.zeros(size=[m], device=self.device)
            for i in range(1, t):
                A_11_i = A_11[i - 1]
                W_11_i = W_11[i - 1]
                A_12_i = A_12[i - 1]
                W_12_i = W_12[i - 1]

                mu += torch.matmul(X[t - 1 - i],
                                   (A_11_i * W_11_i).t()) + torch.matmul(
                                       Z[t - 1 - i], (A_12_i * W_12_i).t())

            Sum_X_mu += ((X[t - 1] - mu)**2).sum()

        # K+1 <= t <= T
        A_11 = A[:, :m, :m]
        W_11 = W[:, :m, :m]

        A_12 = A[:, :m, m:]
        W_12 = W[:, :m, m:]
        mu = []
        for i in range(1, K + 1):
            A_11_i = A_11[i - 1]
            W_11_i = W_11[i - 1]
            A_12_i = A_12[i - 1]
            W_12_i = W_12[i - 1]
            mu.append(
                torch.matmul(X[K - i:T - i], (A_11_i * W_11_i).t()) +
                torch.matmul(Z[K - i:T - i], (A_12_i * W_12_i).t()))

        Sum_X_mu += ((X[K:T + 1] - sum(mu))**2).sum()

        return -T / 2 * torch.log(2 * torch.tensor(
            [np.math.pi], device=self.device)) - T / 2 * torch.log(
                sigma * sigma).sum() - 1 / (2 *
                                            (sigma * sigma).sum()) * Sum_X_mu

    def sample_Z(self, A, W):
        # we sample Z from q(Z)

        # store the distributions from qZ(1),qZ(2)....qZ(T)
        # q_Z = []
        # sample Z(1)
        m = self.num_X
        sigma_Z = self.posterior_sigma_Z
        q_Z1 = Normal(loc=torch.zeros_like(sigma_Z, device=self.device),
                      scale=sigma_Z)
        # q_Z.append(q_Z1)
        Z1 = q_Z1.rsample()

        ln_q_Z = torch.tensor([q_Z1.log_prob(Z1).sum()], device=self.device)

        # store the sample Z(1:T)
        T = self.num_samples
        Z = []
        Z.append(Z1)

        # sample Z(2:K)
        K = self.max_lag
        for t in range(2, K + 1):
            A_22 = A[:t - 1, m:, m:]
            W_22 = W[:t - 1, m:, m:]

            mean_t = []
            for i in range(1, t):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)

            q_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            # q_Z.append(q_Zt)
            Z_t = q_Zt.rsample()
            ln_q_Z += q_Zt.log_prob(Z_t).sum()

            # Normalize Z_t; otherwise it grows too large and leads to large, even NaN, values of Z(t).
            Z_t = F.normalize(Z_t, dim=0)
            Z.append(Z_t)

        # sample Z(K+1:T)
        for t in range(K + 1, T + 1):
            A_22 = A[:, m:, m:]
            W_22 = W[:, m:, m:]
            mean_t = []

            for i in range(1, K + 1):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)

            q_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            # q_Z.append(q_Zt)
            Z_t = q_Zt.rsample()
            ln_q_Z += q_Zt.log_prob(Z_t).sum()

            # Normalize Z_t; otherwise it grows too large and leads to large, even NaN, values of Z(t).
            Z_t = F.normalize(Z_t, dim=0)
            Z.append(Z_t)

        # self.q_Z = q_Z
        return torch.stack(Z, 0), ln_q_Z

    def ln_q_Z(self, Z):

        loss = torch.tensor([0.0], device=self.device)
        T = self.num_samples
        for i in range(T):
            Zt = Z[i]
            log_prob = self.q_Z[i].log_prob(Zt)
            loss += log_prob.sum()

        return loss

    def loss(self, X):
        """
        return: the negative ELBO
        """
        # sample
        A = self.posterior_A.rsample()
        W = self.posterior_W.rsample()
        # ln_q_Z = self.ln_q_Z(Z)
        # We compute ln_q_Z while sampling Z so that we don't need to store
        # q_Z, which reduces the running memory.
        Z, ln_q_Z = self.sample_Z(A, W)

        # Because we assume that X does not cause Z and that the Z_i are mutually
        # independent, A_22 is always a diagonal matrix and A_21 is always a zero
        # matrix.
        m = self.num_X
        K = self.max_lag
        ln_q_A = self.posterior_A.log_prob(A)[:, :m, :].sum() + sum([
            torch.diagonal((self.posterior_A.log_prob(A)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        ln_q_W = self.posterior_W.log_prob(W)[:, :m, :].sum() + sum([
            torch.diagonal((self.posterior_W.log_prob(W)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()

        ln_q_AWZ = ln_q_Z + ln_q_A + ln_q_W

        # Calculating L_kl
        L_kl = -(ln_q_AWZ - self.ln_p_AWZ(A, W, Z))

        # Calculating L_ell
        L_ell = self.ln_p_X_AWZ(X, A, W, Z)

        ELBO = L_kl + L_ell
        # ELBO =  L_ell
        # self._logger.info("ln_q_Z:{}, ln_q_A: {}, ln_q_W: {}, ln_p_AWZ(A,W,Z):{}, L_ell:{} ".format(ln_q_Z.item(), ln_q_A.item(), ln_q_W.item() ,self.ln_p_AWZ(A,W,Z).item(), L_ell.item()))
        loss = -ELBO
        return loss

    @property
    def logger(self):
        try:
            return self._logger
        except:
            raise NotImplementedError('self._logger does not exist!')
    def forward(self, h_t, h_s, auxi_hs, memory_lengths=None, coverage=None):
        """
        Args:
          auxi_hs (FloatTensor): auxiliary output vectors ``(batch, src_len, dim)``
          h_t (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          h_s (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if h_t.dim() == 2:
            one_step = True
            h_t = h_t.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = h_s.size()
        batch_, target_l, dim_ = h_t.size()

        aeq(batch, batch_)  #  Assert all the arguments have the same value.
        aeq(dim, dim_)
        aeq(self.dim, dim)

        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            h_s += self.linear_cover(cover).view_as(h_s)
            h_s = torch.tanh(h_s)

        # print('h_t',h_t,file=filename)
        '''Main Modification'''
        # Expand the target hidden state to concatenate with the source hidden state.
        H_t = h_t.expand(-1, source_l, -1)
        concat_h = torch.cat([auxi_hs, H_t],
                             2).view(batch * source_l,
                                     dim * 2)  # (batch*source_l,dim*2)
        new_concat_h = self.linear_map(concat_h).view(batch, source_l, 1)
        # Calculate the probability from the output of the auxiliary network.
        p = self.sigmoid(new_concat_h)  # (batch,source_l,1)

        # Get the gate, which follows a Bernoulli distribution with probability p.
        # For training:
        G = RelaxedBernoulli(
            torch.tensor([1]).cuda(),
            p).sample()  # hyperparameter--temperature. (batch,source_l,1)
        # For testing:
        # G = Bernoulli(p)

        # e = MLP(h_s) to extract information from the source hidden state.
        e = self.mlp_h(h_s)
        e = e.view(batch, source_l, 1)

        # Calculate the alignment score.
        align_score = (self.softmax(e, G)).transpose(
            1, 2)  # align_score--(batch,1,source)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align_score.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            # align_score.masked_fill_(~mask, -float('inf')) # the original one--it may cause nan.
            align_score.masked_fill_(~mask, -100000)  # my settings.

        align_vectors = align_score / torch.sum(
            align_score, 2, keepdim=True)  # alpha--(batch,1,source_l)
        print(align_vectors, file=filename)
        # Calculate the context vectors.
        c = torch.bmm(align_vectors, h_s)  # context_vec--(batch,1,dim)

        # Concatenate the context vectors with the current target hidden state.
        concat_c = torch.cat([c, h_t], 2).view(batch * target_l, dim * 2)

        # Get the final output hidden state.
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)
            # print('align_vectors', align_vectors.size())
            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
Example #25
    zeros[np.arange(len(zeros)), samples] = 1.
    # print (zeros)
    # print (zeros.shape)
    samples = np.sum(zeros, axis=0) / n_samples
    # print (samples)
    ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)
    ax.bar(['0', '1', '2', '3'], samples)
    ax.text(-.5, .5, s=r'Samples Hard Gumbel', fontsize=10, family='serif')
    cur_row += 1

    #Plot gumbel pdf
    ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)

    # plt.plot(x,y)

    dist = RelaxedBernoulli(probs=torch.Tensor([0.5]),
                            temperature=torch.Tensor([0.5]))
    # samp = dist.sample()
    # print (samp)
    # logprob = dist.log_prob(samp)

    x = [[.01], [.1], [.3], [.5], [.7], [.9], [.99]]
    x_len = len(x)
    logprob = dist.log_prob(torch.Tensor(x))
    logprob = np.reshape(logprob.numpy(), [x_len])  #[5]
    # print (samp, torch.exp(logprob))
    x = np.reshape(np.array(x), [x_len])
    plt.plot(x, logprob)
    cur_row += 1

    # ax = plt.subplot2grid((rows,cols), (cur_row,0), frameon=False, colspan=2)
    # plt.plot(x,np.exp(logprob))
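
For reference, a self-contained version of the density plot this fragment builds up (the subplot-grid layout is omitted; the parameters follow the example):

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.distributions import RelaxedBernoulli

dist = RelaxedBernoulli(temperature=torch.Tensor([0.5]), probs=torch.Tensor([0.5]))

xs = np.linspace(0.01, 0.99, 99)
log_probs = dist.log_prob(torch.tensor(xs, dtype=torch.float32).unsqueeze(1))

plt.plot(xs, log_probs.squeeze(1).numpy())
plt.xlabel('x')
plt.ylabel('log p(x)')
plt.show()
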
Example #26
    def forward(self, mask_mode, feature_extraction_fn=None):
        assert isinstance(mask_mode, MaskMode)

        if feature_extraction_fn is not None:
            assert callable(feature_extraction_fn)
            output = feature_extraction_fn()
            pairwise_dist = output.pairwise_dist.detach()
            classwise_loss = output.classwise_loss.detach()
            classwise_acc = output.classwise_acc.detach()
            n_classes = output.n_classes

        # attention-based pairwise distance information
        atten = self.softmax(self.pairwse_attention(pairwise_dist))
        # [n_classes, hidden_dim]
        p = self.tanh((pairwise_dist * atten).sum(dim=1))

        # loss
        loss_class = classwise_loss / np.log(n_classes)
        loss_mean = loss_class.mean()
        loss_rel = (loss_class - loss_mean) / loss_class.std()  # [n_classes]

        # acc
        acc_class = classwise_acc
        acc_mean = classwise_acc.mean()
        acc_rel = (acc_class - acc_mean) / acc_class.std()  # [n_classes]

        # running mean of loss & acc
        loss_mean = loss_mean.repeat(len(self.m))
        acc_mean = acc_mean.repeat(len(self.m))

        if self.loss_mean is None:
            self.loss_mean = loss_mean  # [n_momentum]
            self.acc_mean = acc_mean  # [n_momentum]
        else:
            self.loss_mean = self.m * self.loss_mean + (1 - self.m) * loss_mean
            self.acc_mean = self.m * self.acc_mean + (1 - self.m) * acc_mean

        # time encoding
        self.t += 1
        time = C(torch.tensor(self.t_encoder(self.t)))

        # shared feature encoding
        #   log_loss_mean: add 1, then take the log, to suppress very large
        #                  losses while avoiding log(0) = -infinity
        log_loss_mean = (self.loss_mean + 1).log()
        s = torch.cat([log_loss_mean, self.acc_mean, time], dim=0).detach()
        s = self.tanh(self.shared_encoder(s)).repeat(n_classes, 1)
        # s: [n_classes, hidden_dim]

        # relative feature encoding
        r = torch.stack([loss_rel, acc_rel], dim=1).detach()
        r = self.tanh(self.relative_encoder(r))
        # r: [n_classes, hidden_dim]

        # mask generation (binary classification)
        h = torch.cat([p, s + r], dim=1)
        mask_logits = self.mask_generator(h)

        # learning rate generation
        h_mean = self.tanh(h.mean(dim=0))
        lr_logits = self.lr_generator(h_mean)
        lr = F.softplus(lr_logits) * 0.1

        #####################################################################
        # on_off_style = ['softmax', 'sigmoid'][1]  # TODO: global argument

        # soft mask
        if mask_mode.dist is MaskDist.SOFT:
            # if on_off_style == 'sigmoid':
            mask = self.sigmoid(
                mask_logits[:, 0])  # TODO: logit[:, 1] is redundant
        elif mask_mode.dist is MaskDist.RL:
            # elif on_off_style == 'softmax':
            mask = self.softmax(mask_logits)  # This guy is for RL case
        # discrete mask
        #####################################################################
        elif mask_mode.dist is MaskDist.DISCRETE:
            mask = lr_logits[:,
                             0].max(dim=1)[1]  # TODO: logit[:, 1] is redundant
        # concrete mask
        elif mask_mode.dist is MaskDist.CONCRETE:
            # infer Bernoulli parameter
            mean = self.sigmoid(mask_logits[:, 0])
            sigma = F.softplus(mask_logits[:, 1]) * 0.1
            eps = torch.randn(mean.size()).to(mean.device)
            # continuously relaxed Bernoulli
            probs = mean + (sigma * eps)
            temp = torch.tensor([0.1]).to(mean.device)
            mask = RelaxedBernoulli(temp, probs=probs)
            # mask = mask.rsample()
            if torch.isnan(mask.rsample()).sum() > 0:
                import pdb
                pdb.set_trace()

        return mask, lr
Example #27
    print()
    print ('REINFORCE H(z)')
    print ('Value:', val)
    print()

    # net = NN()

    optim = torch.optim.Adam([bern_param], lr=.004)
    # optim_NN = torch.optim.Adam([net.parameters()], lr=.0004)


    steps = []
    losses4= []
    for step in range(total_steps):

        dist = RelaxedBernoulli(torch.Tensor([1.]), logits=bern_param)

        optim.zero_grad()

        zs = []
        for i in range(20):
            z = dist.rsample()
            zs.append(z)
        zs = torch.FloatTensor(zs).unsqueeze(1)

        logprob = dist.log_prob(zs.detach())
        # logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]

        H_z = Hpy(zs)

        # print (H_z)