Example No. 1
    def kl(self, dist_a, prior=None):
        """ KL divergence of dist_a against a prior, if none then Cat(1/k)

        :param dist_a: the distribution parameters
        :param prior: prior parameters (or None)
        :returns: batch_size kl-div tensor
        :rtype: torch.Tensor

        """
        if prior is None:  # use a standard uniform Cat(1/k) prior
            uniform_logits = torch.zeros_like(dist_a['discrete']['log_q_z'])
            return torch.sum(GumbelSoftmax._kl_tf_version(
                D.OneHotCategorical(logits=dist_a['discrete']['log_q_z']),
                D.OneHotCategorical(logits=uniform_logits)
            ), -1)

        # we have two distributions provided (eg: VRNN)
        return torch.sum(GumbelSoftmax._kl_tf_version(
            D.OneHotCategorical(logits=dist_a['discrete']['log_q_z']),
            D.OneHotCategorical(logits=prior['discrete']['log_q_z'])
        ), -1)
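For reference, the same Cat(1/k) case can be written with PyTorch's built-in torch.distributions.kl_divergence instead of the repository's GumbelSoftmax._kl_tf_version helper. A minimal self-contained sketch with illustrative tensor shapes:

import torch
import torch.distributions as D

log_q_z = torch.randn(8, 4, 10)                                 # e.g. (batch, groups, categories)
posterior = D.OneHotCategorical(logits=log_q_z)
prior = D.OneHotCategorical(logits=torch.zeros_like(log_q_z))   # uniform Cat(1/k)
kl = torch.sum(D.kl_divergence(posterior, prior), -1)           # one value per batch element
print(kl.shape)                                                 # torch.Size([8])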
Example No. 2
    def _get_sender_lstm_output(self, inputs):
        samples = []
        batch_size = inputs.shape[0]
        sample_loss = torch.zeros(batch_size, device=self.config['device'])
        total_kl = torch.zeros(batch_size, device=self.config['device'])
        hx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])
        cx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])

        for num in range(self.config['num_binary_messages']):
            hx, cx = self.sender_cell(inputs, (hx, cx))
            output = self.sender_project(hx)
            pre_logits = self.sender_out(output)

            sample = utils.gumbel_softmax(
                pre_logits,
                self.temperature[num],
                self.config['device'],
            )

            logits_dist = dists.OneHotCategorical(logits=pre_logits)
            prior_logits = self.prior[num].unsqueeze(0)
            prior_logits = prior_logits.expand(batch_size, self.output_size)
            prior_dist = dists.OneHotCategorical(logits=prior_logits)
            kl = dists.kl_divergence(logits_dist, prior_dist)
            total_kl += kl

            samples.append(sample)
        return samples, sample_loss, total_kl
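The per-step pattern above (a relaxed one-hot sample plus an exact KL against a prior) can be reproduced with the stock torch.nn.functional.gumbel_softmax; utils.gumbel_softmax is repository-specific, so the snippet below is only a rough stand-in with made-up sizes:

import torch
import torch.nn.functional as F
import torch.distributions as dists

batch_size, vocab = 4, 2
pre_logits = torch.randn(batch_size, vocab, requires_grad=True)
prior_logits = torch.zeros(batch_size, vocab)                   # uniform prior over the vocab

sample = F.gumbel_softmax(pre_logits, tau=1.0, hard=True)       # one-hot forward, soft backward
kl = dists.kl_divergence(dists.OneHotCategorical(logits=pre_logits),
                         dists.OneHotCategorical(logits=prior_logits))
print(sample.shape, kl.shape)                                   # torch.Size([4, 2]) torch.Size([4])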
Example No. 3
    def decode_x(self, w, z):
        params = self.decoder_x(torch.cat((w, z), dim=-1))

        px_wz = []
        samples = []

        for indices in self.likelihood_partition:
            data_type = self.likelihood_partition[indices]

            params_subset = params[:, indices[0]:(indices[1] + 1)]

            if data_type == 'real':
                cov_diag = self.likelihood_params['lik_var'] * torch.ones_like(
                    params_subset).to(self.device)

                dist = D.Normal(loc=params_subset, scale=cov_diag.sqrt())

            elif data_type == 'categorical':
                dist = D.OneHotCategorical(logits=params_subset)
            elif data_type == 'binary':
                dist = D.Bernoulli(logits=params_subset)
            elif data_type == 'positive':
                lognormal_var = self.likelihood_params[
                    'lik_var_lognormal'] * torch.ones_like(params_subset).to(
                        self.device)

                dist = D.LogNormal(loc=params_subset,
                                   scale=lognormal_var.sqrt())
            elif data_type == 'count':
                positive_params_subset = F.softplus(params_subset)
                dist = D.Poisson(rate=positive_params_subset)
            elif data_type == 'binomial':
                num_trials = self.likelihood_params['binomial_num_trials']
                dist = D.Binomial(total_count=num_trials, logits=params_subset)
            elif data_type == 'ordinal':
                # cumulative-link construction: ordered cutpoints from a cumsum of
                # softplus increments, then adjacent differences of P(x <= k)
                h = params_subset[:, 0:1]
                thetas = torch.cumsum(F.softplus(params_subset[:, 1:]), dim=1)

                prob_lessthans = torch.sigmoid(thetas - h)
                ones = torch.ones(len(prob_lessthans), 1, device=self.device)
                zeros = torch.zeros(len(prob_lessthans), 1, device=self.device)
                probs = torch.cat((prob_lessthans, ones), dim=1) - \
                        torch.cat((zeros, prob_lessthans), dim=1)

                dist = D.OneHotCategorical(probs=probs)
            else:
                raise NotImplementedError

            samples.append(dist.sample())
            px_wz.append(dist)

        sample_x = torch.cat(samples, axis=1)

        return params, sample_x, px_wz
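The 'ordinal' branch is the least obvious likelihood: it builds a cumulative-link categorical from one location term and K-1 ordered cutpoints. A small standalone sketch of the same construction (the shapes are illustrative assumptions):

import torch
import torch.nn.functional as F
import torch.distributions as D

params_subset = torch.randn(8, 5)             # 1 location term + 4 cutpoint increments -> 5 levels
h = params_subset[:, 0:1]
thetas = torch.cumsum(F.softplus(params_subset[:, 1:]), dim=1)   # strictly increasing cutpoints
prob_lessthans = torch.sigmoid(thetas - h)                       # P(x <= k) for k = 0..3
ones = torch.ones(len(prob_lessthans), 1)
zeros = torch.zeros(len(prob_lessthans), 1)
probs = torch.cat((prob_lessthans, ones), dim=1) - torch.cat((zeros, prob_lessthans), dim=1)
dist = D.OneHotCategorical(probs=probs)       # probs are non-negative and sum to 1 per row
print(dist.sample().shape)                    # torch.Size([8, 5])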
Example No. 4
    def step(self, x, model):

        x_cur = x
        a_s = []
        m_terms = []
        prop_terms = []

        for i in range(self.n_steps):
            forward_delta = self.diff_fn(x_cur, model)
            # make sure we don't choose to stay where we are!
            forward_logits = forward_delta - 1e9 * x_cur
            cd_forward = dists.OneHotCategorical(
                logits=forward_logits.view(x_cur.size(0), -1))
            changes = cd_forward.sample()

            # compute probability of sampling this change
            lp_forward = cd_forward.log_prob(changes)
            # reshape to (bs, dim, nout)
            changes_r = changes.view(x_cur.size())
            # get binary indicator (bs, dim) indicating which dim was changed
            changed_ind = changes_r.sum(-1)
            # mask out the changed dim and add in the change
            x_delta = x_cur.clone() * (1. -
                                       changed_ind[:, :, None]) + changes_r

            reverse_delta = self.diff_fn(x_delta, model)
            reverse_logits = reverse_delta - 1e9 * x_delta
            cd_reverse = dists.OneHotCategorical(
                logits=reverse_logits.view(x_delta.size(0), -1))
            reverse_changes = x_cur * changed_ind[:, :, None]

            lp_reverse = cd_reverse.log_prob(
                reverse_changes.view(x_delta.size(0), -1))

            m_term = (model(x_delta).squeeze() - model(x_cur).squeeze())
            la = m_term + lp_reverse - lp_forward
            a = (la.exp() > torch.rand_like(la)).float()
            x_cur = x_delta * a[:, None,
                                None] + x_cur * (1. - a[:, None, None])
            a_s.append(a.mean().item())
            m_terms.append(m_term.mean().item())
            prop_terms.append((lp_reverse - lp_forward).mean().item())
        self._ar = np.mean(a_s)
        self._mt = np.mean(m_terms)
        self._pt = np.mean(prop_terms)

        self._hops = (x != x_cur).float().sum(-1).sum(-1).mean().item()
        return x_cur
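self.diff_fn is not shown; in Gibbs-with-gradients style samplers it is typically a first-order Taylor estimate of how the model's log-probability changes when each one-hot coordinate is toggled. A plausible sketch for inputs of shape (batch, dim, n_out), which may differ from the repository's actual helper:

import torch

def approx_difference_function_multi_dim(x, model):
    # d(i, k) ~= log p(x with dim i set to category k) - log p(x), to first order
    x = x.requires_grad_()
    gx = torch.autograd.grad(model(x).sum(), x)[0]
    gx_cur = (gx * x).sum(-1, keepdim=True)   # gradient at the currently active category
    return gx - gx_cur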
Example No. 5
    def mutual_info(self, params, eps=1e-9):
        """ Returns Ent + xent where xent is taken against hard targets.

        :param params: distribution parameters
        :param eps: tolerance
        :returns: batch_size tensor of mutual info
        :rtype: torch.Tensor

        """
        return self.config['discrete_mut_info'] * self.mutual_info_monte_carlo(params)

        # NOTE: the early return above makes the analytic variant below unreachable;
        # it is kept from the original source for reference.
        targets = torch.argmax(params['q_z_given_xhat']['discrete']['z_hard'].type(
            long_type(self.config['cuda'])), dim=-1)
        crossent_loss = -F.cross_entropy(input=params['discrete']['logits'],
                                         target=targets, reduce=False)
        ent_loss = -torch.sum(D.OneHotCategorical(logits=params['discrete']['logits']).entropy(), -1)
        return -self.config['discrete_mut_info'] * (ent_loss + crossent_loss)
Example No. 6
 def forward(self, x, return_latents=False):
     x = self.model(x)
     critic_score = self.critic(x)
     x = self.dist_conv(x).view(-1, x.size(1))
     dist_dis = distributions.OneHotCategorical(logits=self.dis_categorical(x))
     dist_cont = distributions.Normal(loc=self.cont_mean(x), scale=torch.exp(0.5 * self.cont_logvar(x)))
     # return the latent distributions only when explicitly requested
     return (critic_score, dist_dis, dist_cont) if return_latents else critic_score
Example No. 7
    def step(self, x, model):
        if len(self._inds) == 0:  # ran out of inds
            self._inds = self._init_inds()

        inds = self._inds[:self.block_size]
        self._inds = self._inds[self.block_size:]
        # bit flips in the hamming ball
        H = torch.tensor(hamming_ball(len(inds), min(self.hamming_dist, len(inds)))).float().to(x.device)
        H_inds = list(range(H.size(0)))
        chosen_H_inds = np.random.choice(H_inds, x.size(0))
        changes = H[chosen_H_inds]
        u = x.clone()
        u[:, inds] = changes * (1. - u[:, inds]) + (1. - changes) * u[:, inds]  # apply sampled changes U ~ p(U | X)

        logits = []
        xs = []
        for c in H:
            xc = u.clone()
            c = c.float().to(xc.device)[None]
            xc[:, inds] = c * (1. - xc[:, inds]) + (1. - c) * xc[:, inds]  # apply all changes
            l = model(xc).squeeze()
            xs.append(xc[:, :, None])
            logits.append(l[:, None])

        logits = torch.cat(logits, 1)
        xs = torch.cat(xs, 2)
        dist = dists.OneHotCategorical(logits=logits)
        choices = dist.sample()

        x_new = (xs * choices[:, None, :]).sum(-1)
        return x_new
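hamming_ball(n, k) is a repository helper; a plausible reading, consistent with how H is used above, is "all binary vectors of length n with at most k ones", i.e. the Hamming ball of radius k around the zero vector. A sketch under that assumption:

import itertools
import numpy as np

def hamming_ball(n, k):
    # the all-zeros vector plus every vector with 1..k ones
    ball = [np.zeros(n)]
    for d in range(1, k + 1):
        for ones in itertools.combinations(range(n), d):
            v = np.zeros(n)
            v[list(ones)] = 1.
            ball.append(v)
    return np.stack(ball)

print(hamming_ball(4, 1).shape)   # (5, 4): the zero vector plus the 4 single-bit flips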
Example No. 8
    def rsample(self, sample_shape=torch.Size([])):
        a_sampler = D.OneHotCategorical(probs=torch.ones(len(self.a_domain)))

        probs = {}
        x_a_dists = {}
        for a in self.a_domain:
            probs[a] = {}
            for x in self.x_support:
                probs[a][x] = math.exp(self.log_prob(x, a))

            normalise = sum(probs[a].values())
            for x in self.x_support:
                probs[a][x] = probs[a][x] / normalise

            x_a_dists[a] = EmpiricalDistribution(None, domain=self.x_support, probs=probs[a])
            self.test = EmpiricalDistribution(None, domain=self.x_support, probs=probs[a])

        a_vals = a_sampler.sample((sample_shape.numel(),))

        a_counts = torch.sum(a_vals, axis=0)

        x_samples = []
        a_samples = []
        for a_c, a_val in zip(a_counts, self.a_domain):
            a_c = int(a_c)
            x_samples.append(x_a_dists[a_val].sample_n(a_c))
            a_samples.append(torch.Tensor([a_val] * a_c))

        x_samples = torch.cat(x_samples).view(*sample_shape, -1)
        a_samples = torch.cat(a_samples).view(*sample_shape, -1)

        return x_samples, a_samples
Example No. 9
    def step(self, x, model):
        sample = x.clone()
        lp_keep = model(sample).squeeze()
        if self.rand:
            changes = dists.OneHotCategorical(
                logits=torch.zeros((self.dim, ))).sample(
                    (x.size(0), )).to(x.device)
        else:
            changes = torch.zeros((x.size(0), self.dim)).to(x.device)
            changes[:, self._i] = 1.

        sample_change = (1. - changes) * sample + changes * (1. - sample)

        lp_change = model(sample_change).squeeze()

        lp_update = lp_change - lp_keep
        update_dist = dists.Bernoulli(logits=lp_update)
        updates = update_dist.sample()
        sample = sample_change * updates[:, None] + sample * (1. -
                                                              updates[:, None])
        self.changes[self._i] = updates.mean()
        self._i = (self._i + 1) % self.dim
        self._hops = (x != sample).float().sum(-1).mean().item()
        self._ar = self._hops
        return sample
Example No. 10
    def step(self, x, model):
        if self.rand:
            i = np.random.randint(0, self.dim)
        else:
            i = self._i

        logits = []
        ndim = x.size(-1)

        for k in range(ndim):
            sample = x.clone()
            sample_i = torch.zeros((ndim, ))
            sample_i[k] = 1.
            sample[:, i, :] = sample_i
            lp_k = model(sample).squeeze()
            logits.append(lp_k[:, None])
        logits = torch.cat(logits, 1)
        dist = dists.OneHotCategorical(logits=logits)
        updates = dist.sample()
        sample = x.clone()
        sample[:, i, :] = updates
        self._i = (self._i + 1) % self.dim
        self._hops = ((x != sample).float().sum(-1) / 2.).sum(-1).mean().item()
        self._ar = self._hops
        return sample
Example No. 11
    def step(self, x, model):
        H = self.H.to(x.device)
        x_cur = x
        forward_delta = self.diff_fn(x_cur, model)
        forward_logits = forward_delta @ H.t()

        cd_forward = dists.Categorical(logits=forward_logits.detach())
        changes = cd_forward.sample()

        lp_forward = cd_forward.log_prob(changes)

        x_changes = H[changes]
        x_delta = (1. - x_cur) * x_changes + x_cur * (1. - x_changes)

        reverse_delta = self.diff_fn(x_delta.detach(), model)
        reverse_logits = reverse_delta @ H.t()
        cd_reverse = dists.OneHotCategorical(logits=reverse_logits.detach())

        # `changes` holds category indices sampled from the forward Categorical;
        # convert them to one-hot so OneHotCategorical.log_prob sees the expected format
        changes_onehot = torch.nn.functional.one_hot(
            changes, num_classes=reverse_logits.size(-1)).float()
        lp_reverse = cd_reverse.log_prob(changes_onehot)

        m_term = (model(x_delta).squeeze() - model(x_cur).squeeze())
        la = m_term + lp_reverse - lp_forward
        a = (la.exp() > torch.rand_like(la)).float()
        x_cur = x_delta * a[:, None] + x_cur * (1. - a[:, None])
        return x_cur
    def generate(self,
                 decode_fn,
                 prior: torch.Tensor,
                 length=2048,
                 tf_board_writer: SummaryWriter = None):
        decode_array = prior
        for i in Bar('generating').iter(range(min(self.max_seq, length))):
            if decode_array.shape[1] >= self.max_seq:
                break
            _, _, look_ahead_mask = \
                utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array)

            # result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
            result, _ = decode_fn(decode_array, look_ahead_mask)
            result = self.fc(result)
            result = result.softmax(-1)

            if tf_board_writer:
                tf_board_writer.add_image("logits", result, global_step=i)

            u = random.uniform(0, 1)
            if u > 1:  # never true: the greedy argmax branch is effectively disabled
                result = result[:, -1].argmax(-1).to(torch.int32)
                decode_array = torch.cat(
                    [decode_array, result.unsqueeze(-1)], -1)
            else:
                pdf = dist.OneHotCategorical(probs=result[:, -1])
                # draw a one-hot token and convert it back to an id of shape (batch, 1)
                result = pdf.sample().argmax(-1).unsqueeze(-1)
                decode_array = torch.cat((decode_array, result), dim=-1)
            del look_ahead_mask
        decode_array = decode_array[0]
        return decode_array
Example No. 13
    def log_likelihood(self, z, params):
        """ Log-likelihood of z induced under params.

        :param z: inferred latent z
        :param params: the params of the distribution
        :returns: log-likelihood
        :rtype: torch.Tensor

        """
        return D.OneHotCategorical(logits=params['discrete']['logits']).log_prob(z)
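A minimal illustration of what this log-likelihood evaluates: OneHotCategorical.log_prob consumes a one-hot value and returns one log-probability per batch element (sizes below are arbitrary):

import torch
import torch.distributions as D

logits = torch.randn(4, 10)                  # batch of 4, 10 categories
dist = D.OneHotCategorical(logits=logits)
z = dist.sample()                            # one-hot, shape (4, 10)
print(dist.log_prob(z).shape)                # torch.Size([4])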
Example No. 14
    def dist_from_h(self, h, mode):
        logits_separated = torch.reshape(h, (-1, self.N, self.K))
        logits_separated_mean_zero = logits_separated - torch.mean(logits_separated, dim=-1, keepdim=True)
        if self.z_logit_clip is not None and mode == ModeKeys.TRAIN:
            c = self.z_logit_clip
            logits = torch.clamp(logits_separated_mean_zero, min=-c, max=c)
        else:
            logits = logits_separated_mean_zero

        return td.OneHotCategorical(logits=logits)
Example No. 15
    def prior_distribution(self, batch_size, **kwargs):
        """ get a torch distrbiution prior

        :param batch_size: size of the prior
        :returns: uniform categorical
        :rtype: torch.distribution

        """
        params = self.prior_params(batch_size, **kwargs)
        return D.OneHotCategorical(logits=params['discrete']['logits'])
    def _sample_batch_from_proposal(self,
                                    batch_size,
                                    return_log_density_of_samples=False):
        # need to do n_samples passes through autoregressive net
        samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
        log_density_of_samples = torch.zeros(batch_size,
                                             self.autoregressive_net.input_dim)
        for dim in range(self.autoregressive_net.input_dim):
            # compute autoregressive outputs
            autoregressive_outputs = self.autoregressive_net(samples).reshape(
                -1, self.dim, self.autoregressive_net.output_dim_multiplier)

            # grab proposal params for dth dimensions
            proposal_params = autoregressive_outputs[..., dim,
                                                     self.context_dim:]

            # make mixture coefficients, locs, and scales for proposal
            logits = proposal_params[
                ..., :self.n_proposal_mixture_components]  # [B, D, M]
            if logits.shape[0] == 1:
                logits = logits.reshape(self.dim,
                                        self.n_proposal_mixture_components)
            locs = proposal_params[..., self.n_proposal_mixture_components:(
                2 * self.n_proposal_mixture_components)]  # [B, D, M]
            scales = self.mixture_component_min_scale + self.scale_activation(
                proposal_params[..., (
                    2 * self.n_proposal_mixture_components):])  # [B, D, M]

            # create proposal
            if self.Component is not None:
                mixture_distribution = distributions.OneHotCategorical(
                    logits=logits, validate_args=True)
                components_distribution = self.Component(loc=locs,
                                                         scale=scales)
                self.proposal = distributions_.MixtureSameFamily(
                    mixture_distribution=mixture_distribution,
                    components_distribution=components_distribution)
                proposal_samples = self.proposal.sample((1, ))  # [S, B, D]

            else:
                self.proposal = distributions.Uniform(low=-4, high=4)
                proposal_samples = self.proposal.sample((1, batch_size, 1))
            proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
            proposal_log_density = self.proposal.log_prob(proposal_samples)
            log_density_of_samples[:, dim] += proposal_log_density.reshape(
                -1).detach()
            samples[:, dim] += proposal_samples.reshape(-1).detach()

        if return_log_density_of_samples:
            return samples, torch.sum(log_density_of_samples, dim=-1)
        else:
            return samples
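Note that distributions_.MixtureSameFamily above is a project-local module that accepts a OneHotCategorical as the mixing distribution; the stock torch.distributions.MixtureSameFamily expects a Categorical instead. A minimal sketch of the stock equivalent, with made-up shapes:

import torch
import torch.distributions as D

logits = torch.randn(7, 3)                   # 7 batch entries, 3 mixture components
locs = torch.randn(7, 3)
scales = torch.ones(7, 3)
gmm = D.MixtureSameFamily(D.Categorical(logits=logits),
                          D.Normal(loc=locs, scale=scales))
print(gmm.sample((5,)).shape)                # torch.Size([5, 7])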
Example No. 17
    def step(self, x, model):

        x_cur = x
        a_s = []
        m_terms = []
        prop_terms = []

        for i in range(self.n_steps):
            forward_delta = self.diff_fn(x_cur, model)
            cd_forward = dists.OneHotCategorical(logits=forward_delta)
            changes_all = cd_forward.sample((self.n_samples, ))

            lp_forward = cd_forward.log_prob(changes_all).sum(0)

            changes = (changes_all.sum(0) > 0.).float()

            x_delta = (1. - x_cur) * changes + x_cur * (1. - changes)
            self._phops = (x_delta != x).float().sum(-1).mean().item()

            reverse_delta = self.diff_fn(x_delta, model)
            cd_reverse = dists.OneHotCategorical(logits=reverse_delta)

            lp_reverse = cd_reverse.log_prob(changes_all).sum(0)

            m_term = (model(x_delta).squeeze() - model(x_cur).squeeze())
            la = m_term + lp_reverse - lp_forward
            a = (la.exp() > torch.rand_like(la)).float()
            x_cur = x_delta * a[:, None] + x_cur * (1. - a[:, None])
            a_s.append(a.mean().item())
            m_terms.append(m_term.mean().item())
            prop_terms.append((lp_reverse - lp_forward).mean().item())
        self._ar = np.mean(a_s)
        self._mt = np.mean(m_terms)
        self._pt = np.mean(prop_terms)

        self._hops = (x != x_cur).float().sum(-1).mean().item()
        return x_cur
Example No. 18
 def forward(self, x):
     raw_init_std = np.log(np.exp(self.init_std) - 1)
     x = self.model(x)
     if self.dist == "tanh_normal":
         mean, std = torch.chunk(x, 2, dim=-1)
         mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
         std = self.softplus(std + raw_init_std) + self.min_std
         dist = td.Normal(mean, std)
         transforms = [TanhBijector()]
         dist = td.transformed_distribution.TransformedDistribution(dist, transforms)
         dist = td.Independent(dist, 1)
     elif self.dist == "onehot":
         dist = td.OneHotCategorical(logits=x)
         raise NotImplementedError("Atari not implemented yet!")
     return dist
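If the "onehot" branch were ever filled in, one option in recent PyTorch is OneHotCategoricalStraightThrough, which yields hard one-hot samples with a straight-through gradient to the logits. This is an assumption about how the NotImplementedError branch could be completed, not the repository's implementation:

import torch
import torch.distributions as td

logits = torch.randn(4, 6, requires_grad=True)
dist = td.OneHotCategoricalStraightThrough(logits=logits)
action = dist.rsample()                      # one-hot forward pass, soft gradients backward
action.sum().backward()
print(logits.grad is not None)               # True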
Example No. 19
    def mutual_info_analytic(self, params, eps=1e-9):
        """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior), i.e. analytic version.

        :param params: parameters of distribution
        :param eps: tolerance
        :returns: batch_size mutual information (prop-to) tensor.
        :rtype: torch.Tensor

        """
        targets = torch.argmax(
            F.softmax(params['discrete']['logits'], -1), dim=-1
        ).type(long_type(self.config['cuda']))
        crossent_loss = F.cross_entropy(input=params['q_z_given_xhat']['discrete']['logits'],
                                        target=targets, reduce=False)
        ent_loss = -torch.sum(D.OneHotCategorical(logits=params['discrete']['z_hard']).entropy(), -1)
        return ent_loss + crossent_loss
Example No. 20
    def prior_params(self, batch_size, **kwargs):
        """ Helper to get prior parameters

        :param batch_size: the size of the batch
        :returns: a dictionary of parameters
        :rtype: dict

        """
        uniform_probs = same_type(self.config['half'], self.config['cuda'])(
            batch_size, self.output_size).zero_()
        uniform_probs += 1.0 / self.output_size
        return {
            'discrete': {
                'logits': D.OneHotCategorical(probs=uniform_probs).logits
            }
        }
    def forward(self, h, z_logit_clip=None):
        '''
        h: hidden state used to compute distribution parameter, (batch, self.K)
        '''
        self.device = h.device
        h = self.h_to_logit(h)
        logits_separated = torch.reshape(h, (-1, self.N, self.K))
        logits_separated_mean_zero = logits_separated - torch.mean(
            logits_separated, dim=-1, keepdim=True)
        if z_logit_clip is not None and self.training:
            logits = torch.clamp(logits_separated_mean_zero,
                                 min=-z_logit_clip,
                                 max=z_logit_clip)
        else:
            logits = logits_separated_mean_zero

        self.dist = td.OneHotCategorical(logits=logits)
Example No. 22
 def mutual_info_analytic(self, params, eps=1e-9):
     # I(z_d; x) ~ H(z_prior, z_d) + H(z_prior)
     targets = torch.argmax(params['discrete']['z_hard'].type(
         long_type(self.config['cuda'])),
                            dim=-1)
     crossent_loss = -F.cross_entropy(
         input=params['q_z_given_xhat']['discrete']['logits'],
         target=targets,
         reduce=False)
     ent_loss = -torch.sum(
         D.OneHotCategorical(logits=params['discrete']['z_hard']).entropy(),
         -1)
     return ent_loss + crossent_loss
Example No. 23
    def check_test_acc(self):
        messages = []
        log_probs = np.zeros(5000)

        for num in range(self.config['num_binary_messages']):
            prior_dst = dists.OneHotCategorical(logits=self.prior[num])
            samples = prior_dst.sample((5000, ))
            log_prob = prior_dst.log_prob(samples).data.cpu().numpy()
            messages.append(samples)
            log_probs += log_prob

        messages = torch.stack(messages).permute(1, 0, 2)
        maxz = torch.argmax(messages, dim=-1, keepdim=True)
        h_z = torch.zeros(messages.shape,
                          device=self.config['device']).scatter_(-1, maxz, 1)
        _, final_preds = self.test_forward(h_z)
        final_preds = final_preds.data.cpu().numpy()
        no_rep = utils.check_correct_preds(final_preds)
        return no_rep / self.config['batch_size']
Example No. 24
    def __init__(self, args):
        super().__init__()
        C, H, W = args.image_dims
        x_dim = C * H * W

        # --------------------
        # p model -- SSL paper generative semi supervised model M2
        # --------------------

        self.p_y = D.OneHotCategorical(probs=1 / args.y_dim * torch.ones(1,args.y_dim, device=args.device))
        self.p_z = D.Normal(torch.tensor(0., device=args.device), torch.tensor(1., device=args.device))

        # parametrized data likelihood p(x|y,z)
        self.decoder = nn.Sequential(nn.Linear(args.z_dim + args.y_dim, args.hidden_dim),
                                     nn.Softplus(),
                                     nn.Linear(args.hidden_dim, args.hidden_dim),
                                     nn.Softplus(),
                                     nn.Linear(args.hidden_dim, x_dim))

        # --------------------
        # q model -- SSL paper eq 4
        # --------------------

        # parametrized q(y|x) = Cat(y|pi_phi(x)) -- outputs parametrization of categorical distribution
        self.encoder_y = nn.Sequential(nn.Linear(x_dim, args.hidden_dim),
                                       nn.Softplus(),
                                       nn.Linear(args.hidden_dim, args.hidden_dim),
                                       nn.Softplus(),
                                       nn.Linear(args.hidden_dim, args.y_dim))

        # parametrized q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) -- output parametrizations for mean and diagonal variance of a Normal distribution
        self.encoder_z = nn.Sequential(nn.Linear(x_dim + args.y_dim, args.hidden_dim),
                                       nn.Softplus(),
                                       nn.Linear(args.hidden_dim, args.hidden_dim),
                                       nn.Softplus(),
                                       nn.Linear(args.hidden_dim, 2 * args.z_dim))


        # initialize weights to N(0, 0.001) and biases to 0 (cf SSL section 4.4)
        for p in self.parameters():
            p.data.normal_(0, 0.001)
            if p.ndimension() == 1: p.data.fill_(0.)
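For reference, the uniform label prior p_y defined above behaves as follows; a tiny sketch with y_dim fixed at 10 (the actual value comes from args):

import torch
import torch.distributions as D

y_dim = 10
p_y = D.OneHotCategorical(probs=torch.ones(1, y_dim) / y_dim)
y = p_y.sample((4,))                          # shape (4, 1, 10): four one-hot labels
print(p_y.log_prob(y))                        # every entry is log(1/10) ~= -2.3026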
Example No. 25
    def test_prior(self, data):
        batch_size = data.shape[0]

        input_embs = self.sender_embedding(data)
        inputs = input_embs.view(
            batch_size,
            self.config['num_digits'] * self.config['embedding_size_sender'])
        hx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])
        cx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])

        samples = []
        log_probs = 0
        post_probs = 0
        for num in range(self.config['num_binary_messages']):
            hx, cx = self.sender_cell(inputs, (hx, cx))
            output = self.sender_project(hx)
            pre_logits = self.sender_out(output)
            posterior_prob = torch.log_softmax(pre_logits, -1)
            sample = utils.gumbel_softmax(pre_logits, self.temperature[num],
                                          self.config['device'])
            samples.append(sample)

            maxz = torch.argmax(sample, dim=-1, keepdim=True)
            h_z = torch.zeros(sample.shape,
                              device=self.config['device']).scatter_(
                                  -1, maxz, 1)
            prior_dst = dists.OneHotCategorical(logits=self.prior[num])
            log_prob = prior_dst.log_prob(h_z).detach().cpu().numpy()
            log_probs += log_prob
            post_probs += posterior_prob[torch.arange(batch_size),
                                         maxz.squeeze()]

        samples = torch.stack(samples).permute(1, 0, 2)
        prior_prob = log_probs / self.config['num_binary_messages']
        post_prob = post_probs.detach().cpu().numpy() / self.config['num_binary_messages']
        return post_prob, prior_prob, samples
Example No. 26
    def generate(self,
                 prior: torch.Tensor,
                 length=2048,
                 tf_board_writer: SummaryWriter = None):
        decode_array = prior
        result_array = prior
        print(config)
        print(length)
        for i in Bar('generating').iter(range(length)):
            if decode_array.size(1) >= config.threshold_len:
                decode_array = decode_array[:, 1:]
            _, _, look_ahead_mask = \
                utils.get_masked_with_pad_tensor(decode_array.size(1), decode_array, decode_array, pad_token=config.pad_token)

            # result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
            # result, _ = decode_fn(decode_array, look_ahead_mask)
            result, _ = self.Decoder(decode_array, None)
            result = self.fc(result)
            result = result.softmax(-1)

            if tf_board_writer:
                tf_board_writer.add_image("logits", result, global_step=i)

            u = 0
            if u > 1:  # never true: the greedy argmax branch is effectively disabled
                result = result[:, -1].argmax(-1).to(decode_array.dtype)
                decode_array = torch.cat((decode_array, result.unsqueeze(-1)),
                                         -1)
            else:
                pdf = dist.OneHotCategorical(probs=result[:, -1])
                # draw a one-hot token and convert it back to an id of shape (batch, 1)
                result = pdf.sample().argmax(-1).unsqueeze(-1)
                decode_array = torch.cat((decode_array, result), dim=-1)
                result_array = torch.cat((result_array, result), dim=-1)
            del look_ahead_mask
        result_array = result_array[0]
        return result_array
Example No. 27
 def __init__(self,
              dim,
              n_out=3,
              init_sigma=.15,
              init_bias=0.,
              learn_G=False,
              learn_sigma=False,
              learn_bias=False):
     super().__init__()
     g = ig.Graph.Lattice(dim=[dim, dim],
                          circular=True)  # Boundary conditions
     A = np.asarray(g.get_adjacency().data)  # g.get_sparse_adjacency()
     self.G = nn.Parameter(torch.tensor(A).float(), requires_grad=learn_G)
     self.sigma = nn.Parameter(torch.tensor(init_sigma).float(),
                               requires_grad=learn_sigma)
     self.bias = nn.Parameter(torch.ones(
         (dim**2, n_out)).float() * init_bias,
                              requires_grad=learn_bias)
     self.init_dist = dists.OneHotCategorical(logits=self.bias)
     self.dim = dim
     self.n_out = n_out
     self.data_dim = dim**2
Example No. 28
    def generate(self, prior, length=2048):
        decoded = prior
        outputs = prior

        for i in range(length):
            _, _, mask = get_masked_with_pad_tensor(decoded.size(1), decoded,
                                                    decoded, self.pad_token)

            result, _ = self.Decoder(decoded, mask)
            result = self.fc(result)
            result = result.softmax(dim=-1)

            pdf = dist.OneHotCategorical(probs=result[:, -1])
            result = pdf.sample().argmax(-1).unsqueeze(-1)

            decoded = torch.cat((decoded, result), dim=-1)
            outputs = torch.cat((outputs, result), dim=-1)

            del mask

        outputs = outputs[0]

        return outputs
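The sampling step in these decoders draws a one-hot token and immediately converts it back to an index; for ids alone a Categorical is equivalent. A small sketch with a hypothetical vocabulary size:

import torch
import torch.distributions as dist

probs = torch.softmax(torch.randn(2, 388), dim=-1)     # last-step distribution, vocab of 388 (illustrative)
one_hot = dist.OneHotCategorical(probs=probs).sample()
next_id = one_hot.argmax(-1).unsqueeze(-1)              # token ids, shape (2, 1)
# equivalent ids without the one-hot intermediate:
next_id_alt = dist.Categorical(probs=probs).sample().unsqueeze(-1)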
Example No. 29
    def step(self, x, model):
        if len(self._inds) == 0:  # ran out of inds
            self._inds = self._init_inds()

        inds = self._inds[:self.block_size]
        self._inds = self._inds[self.block_size:]
        logits = []
        xs = []
        for c in itertools.product(*([[0., 1.]] * len(inds))):
            xc = x.clone()
            c = torch.tensor(c).float().to(xc.device)
            xc[:, inds] = c
            l = model(xc).squeeze()
            xs.append(xc[:, :, None])
            logits.append(l[:, None])

        logits = torch.cat(logits, 1)
        xs = torch.cat(xs, 2)
        dist = dists.OneHotCategorical(logits=logits)
        choices = dist.sample()

        x_new = (xs * choices[:, None, :]).sum(-1)
        return x_new
Example No. 30
import storch
import torch
import torch.distributions as td
method1 = storch.method.Reparameterization
method2 = storch.method.ScoreFunction

method1 = method1(plate_name="1",n_samples=25)
method2 = method2(plate_name="1",n_samples=25)
p1 = td.Independent(td.Normal(loc=torch.zeros([1000, 2]), scale=torch.ones([1000, 2])), 0)
p2 = td.Independent(td.OneHotCategorical(probs=torch.zeros([1000, 3]).uniform_()), 0)

samp1 = method1(p1)
samp2 = method2(p2)
# torch.Size([25, 1000, 2])
# torch.Size([25, 1000, 3])

print(storch.cat([samp1,samp2], 2).shape)
# torch.Size([25, 1000, 5])

method1 = storch.method.Reparameterization
method2 = storch.method.UnorderedSetEstimator

method1 = method1(plate_name="1",n_samples=25)
method2 = method2(plate_name="2",k=25)
p1 = td.Independent(td.Normal(loc=torch.zeros([1000, 2]), scale=torch.ones([1000, 2])), 0)
p2 = td.Independent(td.OneHotCategorical(probs=torch.zeros([1000, 3]).uniform_()), 0)

samp1 = method1(p1)
samp2 = method2(p2)
# torch.Size([25, 1000, 2])
# torch.Size([25, 1000, 3])