Example #1
    def sample_single_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
        """
        Input:
            pis[M]
            mus[M, 2]
            sigmas[M, 2]
            rhos[M]
            qs[3]
        Output:
            strokes[5]
        """
        comp_m = OneHotCategorical(logits=pis)
        comp_choice = (comp_m.sample() == 1)

        mu = mus[comp_choice].view(-1)
        sigma = sigmas[comp_choice].view(-1)
        rho = rhos[comp_choice].view(-1)
        q = qs.view(-1)

        # cov = [[sx^2, rho*sx*sy], [rho*sx*sy, sy^2]]
        cov = torch.diag(sigma * sigma) + \
            (1 - torch.eye(2, device=mu.device)) * rho * torch.prod(sigma)

        normal_m = MultivariateNormal(mu, cov)
        stroke_move = normal_m.sample().to(pis.device)  # [2]
        # One-hot pen state: the most likely of the three pen outcomes.
        pen_states = (q == q.max(dim=0, keepdim=True)[0]).to(dtype=torch.float)  # [3]
        stroke = torch.cat(
            [stroke_move.view(-1), pen_states.view(-1)], dim=0).to(pis.device)
        return stroke
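The covariance assembly above is the standard bivariate-GMM recipe used in Sketch-RNN-style decoders. A minimal, self-contained sketch of the same steps (all sizes and values below are illustrative, not from the original project):

import torch
from torch.distributions import MultivariateNormal, OneHotCategorical

pis = torch.log(torch.tensor([0.2, 0.5, 0.3]))   # mixture logits, M = 3
mus = torch.randn(3, 2)                          # per-component means
sigmas = torch.rand(3, 2) + 0.1                  # per-axis std devs
rhos = torch.rand(3) * 1.8 - 0.9                 # correlations in (-0.9, 0.9)

choice = OneHotCategorical(logits=pis).sample() == 1   # pick one component
mu = mus[choice].view(-1)
sigma = sigmas[choice].view(-1)
rho = rhos[choice].view(-1)

# cov = [[sx^2, rho*sx*sy], [rho*sx*sy, sy^2]]
cov = torch.diag(sigma * sigma) + (1 - torch.eye(2)) * rho * torch.prod(sigma)
move = MultivariateNormal(mu, cov).sample()      # one (dx, dy) pen offset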
Example #2
    def sample_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
        """
        Input:
            pis[batch, seq_len, M]
            mus[batch, seq_len, M, 2]
            sigmas[batch, seq_len, M, 2]
            rhos[batch, seq_len, M]
            qs[batch, seq_len, 3]
        Output:
            strokes[batch, seq_len, 5]
        """
        batch_size, seq_len, M = pis.size()
        strokes = []
        sigmas = sigmas * gamma
        # Sample for each sketch
        for i in range(batch_size):
            comp_m = OneHotCategorical(logits=pis[i, :, :])
            comp_choice = (comp_m.sample() == 1)

            mu = mus[i][comp_choice]        # [seq_len, 2]
            sigma = sigmas[i][comp_choice]  # [seq_len, 2]
            rho = rhos[i][comp_choice]      # [seq_len]
            q = qs[i]                       # [seq_len, 3]

            # One 2x2 covariance per time step.
            cov = torch.stack([
                torch.diag(sigma[j] * sigma[j]) +
                (1 - torch.eye(2, device=mu.device)) * rho[j] * torch.prod(sigma[j])
                for j in range(seq_len)
            ])

            normal_m = MultivariateNormal(mu, cov)
            stroke_move = normal_m.sample().to(pis.device) # [seq_len, 2]
            pen_states = (q == q.max(dim=1, keepdim=True)[0]).to(dtype=torch.float)  # [seq_len, 3]
            stroke = torch.cat([stroke_move, pen_states], dim=1).to(pis.device)
            strokes.append(stroke)

        return torch.stack(strokes)
Example #3
 def _prior(self, inference='sample'):
     device = self.z2hdec.weight.device
     batchSize = self.images.shape[0]
     if self.varType == 'cont':
         if inference == 'sample':
             sample = torch.randn(batchSize, self.latentSize, device=device)
         else:
             sample = torch.zeros(batchSize, self.latentSize, device=device)
         mean = torch.zeros_like(sample)
         logv = torch.zeros_like(sample)  # log-variance of a standard normal is log(1) = 0
         z = sample, (mean, logv)
     elif self.varType == 'gumbel':
         K = self.num_embeddings
         V = self.num_vars
         prior_probs = torch.tensor([1 / K] * K,
                                    dtype=torch.float,
                                    device=device)
         logprior = torch.log(prior_probs)
         if inference == 'sample':
             prior = OneHotCategorical(prior_probs)
             z_vs = prior.sample(sample_shape=(batchSize * V, ))
             z = z_vs.reshape([batchSize, -1])
         else:
             z_vs = prior_probs.expand(batchSize * V, -1)
             z = z_vs.reshape([batchSize, -1])
         z = (z, logprior)
     elif self.varType == 'none':
         raise Exception('Z has no prior for varType==none')
     return z
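As a minimal sketch of the 'gumbel' branch above: a uniform one-hot prior over K codes is drawn V times per batch element and flattened into a single latent vector (K, V, and the batch size here are illustrative):

import torch
from torch.distributions import OneHotCategorical

batch, V, K = 4, 8, 10
prior = OneHotCategorical(torch.full((K,), 1.0 / K))
z_vs = prior.sample(sample_shape=(batch * V,))  # [batch * V, K] one-hots
z = z_vs.reshape(batch, -1)                     # [batch, V * K] flat latent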
Example #4
def select_action(args, logits, status='train', exploration=True, info={}):
    if args.continuous:
        act_mean = logits
        act_std = cuda_wrapper(torch.ones_like(act_mean), args.cuda)
        if status == 'train':
            return Normal(act_mean, act_std).sample()
        elif status == 'test':
            return act_mean
    else:
        if status == 'train':
            if exploration:
                if args.epsilon_softmax:
                    eps = info['softmax_eps']
                    # Mix the softmax policy with a uniform distribution (epsilon-softmax).
                    p_a = (1 - eps) * torch.softmax(logits, dim=-1) + eps / logits.size(-1)
                    return OneHotCategorical(probs=p_a).sample()
                elif args.gumbel_softmax:
                    return GumbelSoftmax(logits=logits).sample()
                else:
                    return OneHotCategorical(logits=logits).sample()
            else:
                if args.gumbel_softmax:
                    temperature = 1.0
                    return torch.softmax(logits / temperature, dim=-1)
                else:
                    return OneHotCategorical(logits=logits).sample()
        elif status == 'test':
            p_a = torch.softmax(logits, dim=-1)
            # Greedy action: one-hot of the argmax probability.
            return (p_a == torch.max(p_a, dim=-1, keepdim=True)[0]).float()
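The epsilon-softmax branch mixes the policy's softmax with a uniform distribution, which guarantees every action keeps probability at least eps / n_actions. A small self-contained check (shapes illustrative):

import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(5, 3)   # batch of 5 states, 3 actions
eps = 0.1
p_a = (1 - eps) * torch.softmax(logits, dim=-1) + eps / logits.size(-1)
assert torch.allclose(p_a.sum(-1), torch.ones(5))  # still a valid distribution
action = OneHotCategorical(probs=p_a).sample()     # [5, 3] one-hot actions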
Example #5
File: model.py Project: kero-ly/AgnosticMC
    def __init__(self):
        super(Transform5, self).__init__()
        kernel = [[[[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]]]  # 3x3 sharpen kernel
        kernel = torch.from_numpy(np.array(kernel)).float()
        self.conv1 = nn.Conv2d(1,
                               1,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               padding=(1, 1),
                               bias=False)
        self.conv1.weight = nn.Parameter(kernel)

        self.conv_trans1 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans2 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans3 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans4 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)

        self.conv_smooth = nn.Conv2d(1,
                                     1,
                                     kernel_size=(3, 3),
                                     stride=(1, 1),
                                     padding=(1, 1),
                                     bias=False)
        self.conv_smooth.weight = nn.Parameter(
            torch.ones(9).cuda().view(1, 1, 3, 3) * 1 / 9.)
        self.drop = nn.Dropout(p=0.05)
        self.relu = nn.ReLU(inplace=True)

        self.one_hot1 = OneHotCategorical(torch.Tensor([0.6, 0.4]))
        # Uniform over 24 of the 25 positions in a 5x5 kernel; the center
        # (index 12, the identity shift) is excluded.
        trans_probs = torch.full((25,), 1 / 24.)
        trans_probs[12] = 0.
        self.one_hot2 = OneHotCategorical(trans_probs)

        for param in self.parameters():
            param.requires_grad = False
Example #6
 def __init__(self):
   super(Transform4, self).__init__()
   self.conv_trans = nn.Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False, groups=3)
   # Uniform over 24 of the 25 positions in a 5x5 kernel; center (index 12) excluded.
   trans_probs = torch.full((25,), 1 / 24.)
   trans_probs[12] = 0.
   self.one_hot2 = OneHotCategorical(trans_probs)
Example #7
File: layers.py Project: CyrilMa/ssqa
 def sample(self, probas, beta=1):
     batch_size = probas[0].size(0)
     phi = beta * sum([p.view(batch_size, self.q, self.N) for p in probas])
     phi += self.linear.weights.view(1, self.q, self.N)
     self.phi = phi
     distribution = OneHotCategorical(
         probs=F.softmax(phi, 1).permute(0, 2, 1))
     return distribution.sample().permute(0, 2, 1)
Example #8
 def get_action(self, state):
     all_hp_probs, all_anchor_probs = self.forward(state)
     all_anchor_act, all_hp_act = [], []
     for layer_anchor_probs in all_anchor_probs:
         anchor_sampler = Bernoulli(layer_anchor_probs)
         layer_anchor_act = anchor_sampler.sample()
         all_anchor_act.append(layer_anchor_act)
     for hp_probs in all_hp_probs:
         sampler = OneHotCategorical(logits=hp_probs)
         all_hp_act.append(sampler.sample())
     return all_hp_act, all_anchor_act
Example #9
def sample_z(args):
    # generate samples from the prior
    z_cat = OneHotCategorical(
        logits=torch.zeros(args.batch_size, args.cat_dim)).sample()
    z_noise = dist.Uniform(-1, 1).sample(
        torch.Size((args.batch_size, args.noise_dim)))
    z_cont = dist.Uniform(-1, 1).sample(
        torch.Size((args.batch_size, args.cont_dim)))

    # concatenate the incompressible noise, discrete latents, and continuous latents
    z = torch.cat([z_noise, z_cat, z_cont], dim=1)

    return z.to(args.device), z_cat.to(args.device), z_noise.to(
        args.device), z_cont.to(args.device)
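A hedged usage sketch for sample_z, assuming an argparse-style args object carrying the fields the function reads (the dimensions below are illustrative, roughly the InfoGAN MNIST defaults):

import argparse

args = argparse.Namespace(batch_size=16, cat_dim=10, noise_dim=62,
                          cont_dim=2, device='cpu')
z, z_cat, z_noise, z_cont = sample_z(args)
print(z.shape)  # torch.Size([16, 74]) = noise_dim + cat_dim + cont_dim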
Example #10
    def forward(self, rating_matrix):

        cores = F.normalize(self.k_embedding.weight, dim=1)
        items = F.normalize(self.item_embedding.weight, dim=1)

        rating_matrix = F.normalize(rating_matrix)
        rating_matrix = F.dropout(rating_matrix,
                                  self.drop_out,
                                  training=self.training)

        cates_logits = torch.matmul(items, cores.transpose(0, 1)) / self.tau

        if self.nogb:
            cates = torch.softmax(cates_logits, dim=1)
        else:
            cates_dist = OneHotCategorical(logits=cates_logits)
            cates_sample = cates_dist.sample()
            cates_mode = torch.softmax(cates_logits, dim=1)
            # Discrete sample at train time; softmax mode at eval time.
            cates = (self.training * cates_sample +
                     (1 - self.training) * cates_mode)

        probs = None
        mulist = []
        logvarlist = []
        for k in range(self.kfac):
            cates_k = cates[:, k].reshape(1, -1)
            # encoder
            x_k = rating_matrix * cates_k
            h = self.encoder(x_k)
            mu = h[:, :self.embedding_size]
            mu = F.normalize(mu, dim=1)
            logvar = h[:, self.embedding_size:]

            mulist.append(mu)
            logvarlist.append(logvar)

            z = self.reparameterize(mu, logvar)

            # decoder
            z_k = F.normalize(z, dim=1)
            logits_k = torch.matmul(z_k, items.transpose(0, 1)) / self.tau
            probs_k = torch.exp(logits_k)
            probs_k = probs_k * cates_k
            probs = (probs_k if (probs is None) else (probs + probs_k))

        logits = torch.log(probs)

        return logits, mulist, logvarlist
Example #11
def log_ce_with_pg(pred, truth, r, b):  # (bs, t, c)
    all_hp_pred, all_prob_anchors = pred
    all_hp_act, all_act_anchors = truth
    loss = torch.FloatTensor([0.0])
    for hp_pred, hp_act in list(zip(all_hp_pred, all_hp_act)):
        target = hp_act.detach()
        sampler = OneHotCategorical(logits=hp_pred)
        l = torch.mean(torch.sum(-sampler.log_prob(target), dim=-1) * (b - r))
        loss += l
    for anchors_pred, anchors_act in list(
            zip(all_prob_anchors, all_act_anchors)):
        target = anchors_act.detach()
        sampler = Bernoulli(logits=anchors_pred)
        l = torch.mean(torch.sum(-sampler.log_prob(target), dim=-1) * (b - r))
        loss += l
    return loss
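OneHotCategorical.log_prob takes a one-hot sample and returns the log-probability of its hot class, which is what makes the policy-gradient term above differentiable in the logits. A minimal sketch (shapes illustrative):

import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 6, requires_grad=True)
sampler = OneHotCategorical(logits=logits)
action = sampler.sample()              # [4, 6] one-hot, non-differentiable
advantage = torch.randn(4)             # stand-in for (b - r)
loss = (-sampler.log_prob(action) * advantage).mean()
loss.backward()                        # gradients flow into the logits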
Example #12
File: inheritance.py Project: robw4/xmen
 def distributions(self):
     """Generate one hot and normal samples"""
     from torch.distributions import Normal
     from torch.distributions.one_hot_categorical import OneHotCategorical
     pz = Normal(torch.zeros([self.cz]), torch.ones([self.cz]))
     py = OneHotCategorical(probs=torch.ones([self.cy]) / self.cy)
     return py, pz
Example #13
class Transform4(nn.Module):  # rand translation
    def __init__(self):
        super(Transform4, self).__init__()
        self.conv_trans = nn.Conv2d(3,
                                    3,
                                    kernel_size=(5, 5),
                                    stride=(1, 1),
                                    padding=(2, 2),
                                    bias=False,
                                    groups=3)
        # Uniform over 24 of the 25 positions in a 5x5 kernel; the center
        # (index 12, the identity shift) is excluded.
        trans_probs = torch.full((25,), 1 / 24.)
        trans_probs[12] = 0.
        self.one_hot2 = OneHotCategorical(trans_probs)

    def forward(self, x):
        kernel = self.one_hot2.sample().view(1, 5, 5)  # 1x5x5
        kernel = torch.stack([kernel] * 3).cuda()  # 3x1x5x5
        self.conv_trans.weight = nn.Parameter(kernel)
        y = self.conv_trans(x)
        self.conv_trans.weight.requires_grad_(False)  # freeze the sampled kernel
        return y
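In both Transform4 variants the 25-way one-hot picks a position in a 5x5 kernel, so the convolution translates the image by up to two pixels in each direction; zeroing index 12 excludes the identity (no-shift) kernel. A small single-channel sketch of the mechanism (shapes illustrative, CPU for simplicity):

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

probs = torch.full((25,), 1 / 24.)
probs[12] = 0.                                 # never sample the center
kernel = OneHotCategorical(probs).sample().view(1, 1, 5, 5)
x = torch.randn(1, 1, 8, 8)
y = F.conv2d(x, kernel, padding=2)             # x shifted by the sampled offset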
Example #14
 def gumbel_softmax(self, logits, temperature, hard=False):
     """Sample from the Gumbel-Softmax distribution and optionally discretize.
     Args:
         logits: [batch_size, n_class] unnormalized log-probs
         temperature: non-negative scalar
          hard: if True, draw a discrete one-hot sample from the soft distribution
      Returns:
          [batch_size, n_class] sample from the Gumbel-Softmax distribution.
          If hard=True, the returned sample is one-hot; otherwise it is a
          probability distribution that sums to 1 across classes.
     """
     prob = self.gumbel_softmax_sample(logits, temperature)
     if hard:
         sampler = OneHotCategorical(prob)
         prob = sampler.sample()
     return prob
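Note that the hard branch above draws a discrete sample rather than taking the argmax, and it does not pass gradients through the discretization. For the argmax-plus-straight-through behavior the docstring describes, torch.nn.functional.gumbel_softmax can be used directly; a minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
y = F.gumbel_softmax(logits, tau=0.5, hard=True)  # one-hot forward, soft backward
y.sum().backward()                                # gradients reach the logits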
Example #15
File: rejection.py Project: CyrilMa/ssqa
class IndependantSampler(Sampler):
    def __init__(self, p):
        super(IndependantSampler, self).__init__()
        self.sampler = OneHotCategorical(p)
        self.p = p

    def sample(self, data):
        n = data.size(0)
      return self.sampler.sample((n,)).permute(0, 2, 1)  # Distribution.sample_n is deprecated
Example #16
 def generate_spikes(self, neurons_group):
     # Prepend a zero logit for the "no spike" outcome, then sample one-hot spikes.
     potentials = self.potential[neurons_group - self.n_input_neurons]
     logits = torch.cat(
         (torch.zeros([len(neurons_group), 1]).to(self.device), potentials),
         dim=-1)
     spikes = OneHotCategorical(torch.softmax(logits, dim=-1)).sample()
     # Drop the "no spike" column when writing the spiking history.
     self.spiking_history[neurons_group, :, -1] = spikes[:, 1:].to(self.device)
Example #17
 def forward(self, logits, training=True, temperature=None):
     # gumbel-softmax (training and evaluation)
     if temperature is not None:
         return F.gumbel_softmax(logits, hard=not training, tau=temperature)
     # softmax training
     elif training:
         return F.softmax(logits, dim=1)
     # softmax evaluation
     else:
         return OneHotCategorical(logits=logits).sample()
Example #18
    def recon_sketches(self, pis, mus, sigmas, rhos, qs, gamma):
        """
        Input:
            pis[batch, seq_len, M]
            mus[batch, seq_len, M, 2]
            sigmas[batch, seq_len, M, 2]
            rhos[batch, seq_len, M]
            qs[batch, seq_len, 3]
        Output:
            strokes[batch, seq_len, 5]
        """
        batch_size, seq_len, M = pis.size()
        sketches = []
        sigmas = sigmas * gamma
        # Sample for each sketch
        for i in range(batch_size):
            strokes = []
            for j in range(seq_len):
                comp_m = OneHotCategorical(logits=pis[i, j])
                comp_choice = (comp_m.sample() == 1)

                mu = mus[i, j][comp_choice].view(-1)
                sigma = sigmas[i, j][comp_choice].view(-1)
                rho = rhos[i, j][comp_choice].view(-1)
                q = qs[i, j].view(-1)

                # cov = [[sx^2, rho*sx*sy], [rho*sx*sy, sy^2]]
                cov = torch.diag(sigma * sigma) + \
                    (1 - torch.eye(2, device=mu.device)) * rho * torch.prod(sigma)

                normal_m = MultivariateNormal(mu, cov)
                stroke_move = normal_m.sample().to(pis.device)  # [2]
                pen_states = (q == q.max(dim=0, keepdim=True)[0]).to(
                    dtype=torch.float)  # [3]
                stroke = torch.cat(
                    [stroke_move.view(-1), pen_states.view(-1)],
                    dim=0).to(pis.device)
                strokes.append(stroke)
            sketches.append(torch.stack(strokes))
        return torch.stack(sketches)
Example #19
    def forward(self, inputs):
        batch_size = inputs.size(0)

        h, c = self.lstm(
            inputs.expand(self.len_msg, batch_size,
                          self.embedding_size).transpose(0, 1))
        logits = self.linear(h)

        dists_speaker = OneHotCategorical(logits=F.log_softmax(logits, dim=2))

        return dists_speaker
Example #20
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(logits=logits)
        return cat_distr.sample()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, logits=logits)
        y_soft = cat_distr.rsample()
    elif mode == 'SOFTMAX':
        y_soft = F.softmax(logits, dim=dim)
    else:
        raise ValueError('Unknown decoding mode: {}'.format(mode))

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft

    return ret
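The y_hard - y_soft.detach() + y_soft line is the straight-through trick: the forward value equals the one-hot y_hard, while the backward pass sees only y_soft. A quick check (values illustrative):

import torch

y_soft = torch.tensor([0.2, 0.7, 0.1], requires_grad=True)
index = y_soft.argmax()
y_hard = torch.zeros_like(y_soft).scatter_(0, index.unsqueeze(0), 1.0)
ret = y_hard - y_soft.detach() + y_soft
print(ret.detach())   # tensor([0., 1., 0.]) -- one-hot in the forward pass
ret.sum().backward()
print(y_soft.grad)    # tensor([1., 1., 1.]) -- identity gradient to y_soft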
Example #21
def construct_samples(executor_reply):

    onehot_reply = {}
    sample_reply = {}

    for (k, v) in executor_reply.items():
        one_hot = OneHotCategorical(v).sample()
        onehot_reply[k] = one_hot
        sample_reply["sample_" + k] = Categorical(one_hot).sample()

    pre_log_prob_sum, pre_log_probs = compute_log_prob(sample_reply, executor_reply)
    return pre_log_prob_sum, pre_log_probs, onehot_reply, sample_reply
Example #22
def build_diayn(n_skills=4, env_name="MountainCar-v0", alpha=.1):
    '''
    :param n_skills: number of skills for the DIAYN prior
    :param env_name: "MountainCar-v0" or "Navigation2D"
    :return: a DIAYN model wrapping the environment
    '''
    env = gym.make(env_name)
    gamma = .9
    prior = OneHotCategorical(torch.ones((1, n_skills)))
    hidden_sizes = {s: [30, 30] for s in ("actor", "discriminator", "critic")}
    model = diayn.DIAYN(env, prior, hidden_sizes, alpha=alpha, gamma=gamma)
    return model
Example #23
    def forward(self, tgt_x):
        batch_size = tgt_x.shape[0]
        tgt_hid = self.x_to_embd(tgt_x)
        lstm_input = torch.zeros((batch_size, NG1)).cuda()
        lstm_hid = tgt_hid.squeeze(1)
        lstm_cell = tgt_hid.squeeze(1)
        msgs = []
        msgs_value = []
        logits = []
        log_probs = 0.

        for _ in range(2):
            lstm_hid, lstm_cell = self.lstm(lstm_input, (lstm_hid, lstm_cell))
            logit = self.out_layer(lstm_hid)
            logits.append(logit)
            probs = nn.functional.softmax(logit, dim=1)
            if self.training:
                cat_distr = OneHotCategorical(probs=probs)
                msg_oht, entropy = cat_distr.sample(), cat_distr.entropy()
                self.get_entropy = entropy
            else:
                msg_oht = nn.functional.one_hot(
                    torch.argmax(probs, dim=1),
                    num_classes=self.out_size).float()
            log_probs += torch.log((probs * msg_oht).sum(1))
            msgs.append(msg_oht)
            msgs_value.append(msg_oht.argmax(1))
            lstm_input = msg_oht

        msgs = torch.stack(msgs)
        msgs_value = torch.stack(msgs_value).transpose(0, 1)
        logits = torch.stack(logits)
        logits = logits.transpose(0, 1).reshape(batch_size, -1)

        self.get_log_probs = log_probs
        return logits, msgs_value
Example #24
    def generate_input(self) -> "tuple[torch.Tensor, torch.Tensor]":
        noise_input = torch.randn(self.feature_spec['noise'])

        categorical_input = []
        categorical_labels = []
        for n_cat in self.feature_spec['categorical']:
            categorical_input.append(OneHotCategorical(torch.ones(n_cat)/n_cat).sample())
            categorical_labels.append(torch.argmax(categorical_input[-1]))
        categorical_input = torch.hstack(categorical_input)
        gaussian_input = torch.randn(self.feature_spec['gaussian'])
        uniform_input = Uniform(-1, 1).sample((self.feature_spec['uniform'], ))

        gen_input = torch.hstack([noise_input, categorical_input, gaussian_input, uniform_input])
        gen_input = gen_input.to(self.device)
        return gen_input, torch.stack(categorical_labels)
Example #25
def build_diayn(n_skills=4, env_name="MountainCar-v0", alpha=0.1, gamma=0.1, seed = 101):
    '''
    :param n_skills: number of skills for the DIAYN prior
    :param env_name: "MountainCar-v0" or "Navigation2D"
    :param alpha: passed through to DIAYN (default 0.1)
    :param gamma: passed through to DIAYN (default 0.1)
    :param seed: random seed (default 101)
    :return: a DIAYN model wrapping the environment
    '''
    if env_name == "Navigation2D" :
        env = Navigation2D(20)
    else :
        env = gym.make(env_name)
    prior = OneHotCategorical(torch.ones((1, n_skills)))
    hidden_sizes = {s: [30, 30] for s in ("actor", "discriminator", "critic")}
    model = DIAYN(env, prior, hidden_sizes,  alpha, gamma, seed = seed)
    return model
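In both build_diayn variants the skill prior is a uniform OneHotCategorical, so each episode's skill is a length-n_skills one-hot vector; the constructor normalizes the all-ones tensor into uniform probabilities. Sampling a skill is a one-liner (n_skills = 4 here):

import torch
from torch.distributions import OneHotCategorical

prior = OneHotCategorical(torch.ones((1, 4)))  # probs normalized to 1/4 each
skill = prior.sample()                         # e.g. tensor([[0., 0., 1., 0.]])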
Example #26
    def forward(self, x):
        h1 = self.act(self.fc1(x))
        h2 = self.act(self.fc2(h1))
        h3 = self.act(self.fc3(h2))
        h4 = self.act(self.fc4(h3))
        out = self.fc5(h4)
        probs1 = F.softmax(out[:, :NG1], dim=1)
        probs2 = F.softmax(out[:, NG1:], dim=1)
        distr1 = OneHotCategorical(probs=probs1)
        distr2 = OneHotCategorical(probs=probs2)
        msg_oht1 = distr1.sample()
        msg_oht2 = distr2.sample()

        self.get_log_probs = torch.log((probs1 * msg_oht1).sum(1)) + torch.log(
            (probs2 * msg_oht2).sum(1))
        self.get_entropy = distr2.entropy()
        msg1 = msg_oht1.argmax(1)
        msg2 = msg_oht2.argmax(1)
        msgs_value = torch.cat((msg1.unsqueeze(1), msg2.unsqueeze(1)), dim=1)
        return out, msgs_value
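The manual torch.log((probs * msg_oht).sum(1)) is the log-probability of the sampled symbol, and it matches OneHotCategorical.log_prob up to floating-point error, as this small check shows (sizes illustrative):

import torch
from torch.distributions import OneHotCategorical

probs = torch.softmax(torch.randn(4, 7), dim=1)
distr = OneHotCategorical(probs=probs)
msg = distr.sample()
manual = torch.log((probs * msg).sum(1))
assert torch.allclose(manual, distr.log_prob(msg))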
Example #27
def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, probs=probs)
        y_soft = cat_distr.rsample()
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(probs).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret, ret
Example #28
def getNdiracs(data, N , sparse = False, flat = False, replace = True):
    
    if not sparse:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        graphcount = data.num_nodes       # number of graphs in the data/batch object
        totalnodecount = data.x.shape[1]  # total node count for each graph
        actualnodecount = 0               # cumulative number of nodes
        diracmatrix = torch.zeros((graphcount, totalnodecount, N), device=device)  # dirac pulses

        for k in range(graphcount):
            graph_nodes = data.mask[k].sum()  # number of nodes in the graph
            actualnodecount += graph_nodes
            probabilities = torch.ones((graph_nodes.item(), 1), device=device) / graph_nodes  # uniform probs
            node_distribution = OneHotCategorical(probs=probabilities.squeeze())
            node_sample = node_distribution.sample(sample_shape=(N,))
            # Pad with zeros so every graph matches the dataset's max node count.
            node_sample = torch.cat((node_sample, torch.zeros((N, totalnodecount - node_sample.shape[1]), device=device)), -1)
            diracmatrix[k, :] = torch.transpose(node_sample, dim0=-1, dim1=-2)
    
        return diracmatrix
    
    else:
        
            original_batch_index = data.batch
            original_edge_index = data.edge_index
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            graphcount = data.num_graphs  # number of graphs in the data/batch object
            diracmatrix = torch.zeros(0, device=device)
            batch_prime = torch.zeros(0, device=device).long()
            locationmatrix = torch.zeros(0, device=device).long()
            
            global_offset = 0
            for k in range(graphcount):
                graph_nodes = (data.batch == k).sum()
                # Sample N dirac (one-hot) positions among this graph's nodes.
                randInt = np.random.choice(range(graph_nodes), N, replace=replace)
                node_sample = torch.zeros(N * graph_nodes, device=device)
                offs = torch.arange(N, device=device) * graph_nodes
                dirac_locations = offs + torch.from_numpy(randInt).to(device)
                node_sample[dirac_locations] = 1

                dirac_locations2 = torch.from_numpy(randInt).to(device) + global_offset
                global_offset += graph_nodes
                
                diracmatrix = torch.cat((diracmatrix, node_sample), 0)
                locationmatrix = torch.cat((locationmatrix, dirac_locations2), 0)

                # Batch assignment for the N dirac copies of this graph.
                dirac_indices = torch.arange(N, device=device).unsqueeze(-1).expand(-1, graph_nodes).contiguous().view(-1)
                dirac_indices = dirac_indices.long()
                dirac_indices += k * N
                batch_prime = torch.cat((batch_prime, dirac_indices))

            # Replicate the edge index N times, offsetting node ids for each copy.
            offset = torch.arange(N).unsqueeze(-1).expand(-1, data.edge_index.size(1)).contiguous().view(-1) * data.batch.shape[0]
            offset_2 = torch.cat(2 * [offset.unsqueeze(0)], dim=0)
            edge_index_prime = torch.cat(N * [data.edge_index], dim=1) + offset_2
            normalization_indices = data.batch.unsqueeze(-1).expand(-1,N).contiguous().view(-1).to(device)

            return Batch(batch = batch_prime, x = diracmatrix, edge_index = edge_index_prime,
                         y = data.y, locations = locationmatrix, norm_index = normalization_indices, batch_old = original_batch_index, edge_index_old = original_edge_index)
Example #29
 def sample_action_with_prob(self, x):
     likelihood = self.forward(x)
     m = OneHotCategorical(likelihood)
     action = m.sample()
     return action, m.log_prob(action)
Example #30
File: utils.py Project: drewlinsley/genadv
def get_kl(dists, P, eps=1e-4, eta=1e-20, Fdiv='kl'):
    """Get KL divergences for different distributions."""
    kl = 0.
    for dst in dists:
        name = dst['name']
        family = dst['family']
        if family == 'gaussian' or family == 'mv_normal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = MultivariateNormal(loc=lambda_r_mu, covariance_matrix=lambda_r_scale)  # noqa
            lambda_0_dist = MultivariateNormal(loc=lambda_0_mu, covariance_matrix=lambda_0_scale)  # noqa
        elif family == 'low_mv_normal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
                lambda_r_factor = normalize_weights(P=P, name=name, prop='factor')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'factor')
                lambda_r_factor = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_0_factor = dst['lambda_0_factor']
            lambda_r_dist = LowRankMultivariateNormal(loc=lambda_r_mu, cov_diag=lambda_r_scale, cov_factor=lambda_r_factor)  # noqa
            lambda_0_dist = LowRankMultivariateNormal(loc=lambda_0_mu, cov_diag=lambda_0_scale, cov_factor=lambda_0_factor)  # noqa
        elif family == 'normal' or family == 'abs_normal' or family == 'cnormal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = Normal(loc=lambda_r_mu, scale=lambda_r_scale)
            lambda_0_dist = Normal(loc=lambda_0_mu, scale=lambda_0_scale)
        elif family == 'half_normal':
            if P.wn:
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = HalfNormal(scale=lambda_r_scale)
            lambda_0_dist = HalfNormal(scale=lambda_0_scale)
        elif family == 'categorical':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
            lambda_0 = dst['lambda_0']  # This is probs
            log_0 = (lambda_0 + eta).log()
            # noqa lambda_r_dist = RelaxedOneHotCategorical(temperature=1e-1, logits=lambda_r_mu)  # Log probs
            # noqa lambda_0_dist = RelaxedOneHotCategorical(temperature=1e-1, logits=log_0)
            lambda_r_dist = OneHotCategorical(logits=lambda_r_mu)  # Log probs
            lambda_0_dist = OneHotCategorical(logits=log_0)
        elif family == 'relaxed_bernoulli':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
            lambda_0 = dst['lambda_0']  # This is probs
            log_0 = (lambda_0 + eta).log()
            # noqa lambda_r_dist = RelaxedBernoulli(temperature=1e-1, logits=lambda_r_mu)  # Log probs
            # noqa lambda_0_dist = RelaxedBernoulli(temperature=1e-1, logits=log_0)
            lambda_r_dist = Bernoulli(logits=lambda_r_mu)  # Log probs
            lambda_0_dist = Bernoulli(logits=log_0)
        else:
            raise NotImplementedError(
                'KL for {} is not implemented.'.format(family))
        if Fdiv == 'kl':
            it_kl = kl_divergence(p=lambda_0_dist, q=lambda_r_dist).sum()
        elif Fdiv == 'js':
            raise RuntimeError('Needs per-distribution implementation.')
            # Unreachable sketch: the JS mixture should average, not multiply.
            m = 0.5 * (lambda_0_dist.probs + lambda_r_dist.probs)
            p = kl_divergence(p=lambda_0_dist, q=m).sum()
            q = kl_divergence(p=lambda_r_dist, q=m).sum()
            it_kl = 0.5 * p + 0.5 * q
        else:
            raise NotImplementedError(Fdiv)
        if it_kl < -1e-4 or torch.isnan(it_kl):  # Give a numerical margin
            print(kl)
        kl = kl + it_kl
    return kl
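For the 'categorical' branch, torch.distributions.kl.kl_divergence ships with a registered implementation for pairs of OneHotCategorical distributions (it delegates to the underlying Categorical KL), so the kl_divergence call above works directly. A minimal sketch (sizes illustrative):

import torch
from torch.distributions import OneHotCategorical
from torch.distributions.kl import kl_divergence

p = OneHotCategorical(logits=torch.randn(5))
q = OneHotCategorical(logits=torch.randn(5))
kl = kl_divergence(p, q)   # 0-dim tensor, always >= 0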