Exemplo n.º 1
0
    def sample_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
        """Sample one stroke per time step for every sketch in a batch.

        For each sketch ``i`` a mixture component is drawn per step from
        ``OneHotCategorical(logits=pis[i])``; the (dx, dy) offset is then drawn
        from that component's bivariate normal, and the pen state is the
        one-hot arg-max of ``qs``.

        Input:
            pis[batch, seq_len, M]        mixture logits
            mus[batch, seq_len, M, 2]     component means
            sigmas[batch, seq_len, M, 2]  component standard deviations
            rhos[batch, seq_len, M]       component correlations
            qs[batch, seq_len, 3]         pen-state scores
            gamma                         multiplier applied to ``sigmas``
        Output:
            strokes[batch, seq_len, 5]: offset (2) followed by pen state (3)
        """
        batch_size = pis.size(0)
        sigmas = sigmas * gamma
        # [[0, 1], [1, 0]] mask for the covariance off-diagonal, built once.
        off_diag = 1 - torch.eye(2, dtype=sigmas.dtype, device=mus.device)
        strokes = []
        for i in range(batch_size):
            # Each row of comp_choice has exactly one True entry, so boolean
            # indexing keeps one component per step.
            comp_choice = OneHotCategorical(logits=pis[i]).sample() == 1  # [seq_len, M]
            mu = mus[i][comp_choice]        # [seq_len, 2]
            sigma = sigmas[i][comp_choice]  # [seq_len, 2]
            rho = rhos[i][comp_choice]      # [seq_len]

            # cov_j = diag(sigma_j^2) with rho_j * sigma_jx * sigma_jy on the
            # off-diagonal, built for all steps at once.
            cov = (torch.diag_embed(sigma * sigma)
                   + off_diag * (rho * sigma.prod(dim=-1))[:, None, None])
            stroke_move = MultivariateNormal(mu, cov).sample()  # [seq_len, 2]

            # argmax + one_hot yields exactly one active pen state even when
            # scores tie (an equality test against the max would not).
            pen_states = torch.nn.functional.one_hot(
                qs[i].argmax(dim=-1),
                num_classes=qs.size(-1)).to(dtype=torch.float)  # [seq_len, 3]
            strokes.append(
                torch.cat([stroke_move.to(pis.device),
                           pen_states.to(pis.device)], dim=1))
        return torch.stack(strokes)
Exemplo n.º 2
0
    def sample_single_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
        """Sample one stroke from a single step of mixture-density output.

        A component is drawn from ``OneHotCategorical(logits=pis)``; the
        (dx, dy) offset is drawn from that component's bivariate normal with
        its standard deviations scaled by ``gamma``, and the pen state is the
        one-hot arg-max of ``qs``.

        Input:
            pis[M]        mixture logits
            mus[M, 2]     component means
            sigmas[M, 2]  component standard deviations
            rhos[M]       component correlations
            qs[3]         pen-state scores
            gamma         multiplier applied to the chosen ``sigma``
        Output:
            strokes[5]: offset (2) followed by pen state (3)
        """
        comp_choice = OneHotCategorical(logits=pis).sample() == 1
        mu = mus[comp_choice].view(-1)                  # [2]
        # gamma scales the chosen standard deviations (it was previously
        # accepted but never applied in this function).
        sigma = (sigmas[comp_choice] * gamma).view(-1)  # [2]
        rho = rhos[comp_choice].view(-1)                # [1]
        q = qs.view(-1)                                 # [3]

        off_diag = 1 - torch.eye(2, dtype=sigma.dtype, device=mu.device)
        cov = torch.diag(sigma * sigma) + off_diag * rho * torch.prod(sigma)

        stroke_move = MultivariateNormal(mu, cov).sample().to(pis.device)  # [2]
        # argmax + one_hot keeps the pen state one-hot even when scores tie.
        pen_states = torch.nn.functional.one_hot(
            q.argmax(), num_classes=q.numel()).to(dtype=torch.float)  # [3]
        return torch.cat([stroke_move.view(-1), pen_states.view(-1)],
                         dim=0).to(pis.device)
Exemplo n.º 3
0
class Transform4(nn.Module):  # rand translation
    """Translate a 3-channel image batch by a random offset of up to 2 pixels.

    On every forward pass a 5x5 one-hot kernel is drawn over the 25 taps,
    with the centre tap (index 12) given probability 0 so the identity shift
    never occurs. The same kernel is applied depthwise to each of the 3
    channels with zero padding of 2.
    """

    def __init__(self):
        super(Transform4, self).__init__()
        # Holds the depthwise-conv hyper-parameters (stride/padding/groups).
        # Its weight is frozen; forward() convolves with a freshly sampled
        # kernel instead of training or overwriting this weight.
        self.conv_trans = nn.Conv2d(3,
                                    3,
                                    kernel_size=(5, 5),
                                    stride=(1, 1),
                                    padding=(2, 2),
                                    bias=False,
                                    groups=3)
        self.conv_trans.weight.requires_grad_(False)
        # Uniform over the 24 non-centre taps of a 5x5 grid.
        probs = torch.full((25, ), 1 / 24.)
        probs[12] = 0.0
        self.one_hot2 = OneHotCategorical(probs)

    def forward(self, x):
        # 3x1x5x5 kernel: the same one-hot tap for every channel, created on
        # x's device/dtype so the module also works on CPU.
        kernel = self.one_hot2.sample().view(1, 1, 5, 5).repeat(3, 1, 1, 1)
        kernel = kernel.to(device=x.device, dtype=x.dtype)
        # Convolve with the sampled kernel directly: no nn.Parameter is
        # created per call, so nothing here is trainable.
        return torch.nn.functional.conv2d(x,
                                          kernel,
                                          stride=self.conv_trans.stride,
                                          padding=self.conv_trans.padding,
                                          groups=self.conv_trans.groups)
Exemplo n.º 4
0
 def _prior(self, inference='sample'):
     """Draw (or take a deterministic value of) the latent prior for a batch.

     The batch size is ``self.images.shape[0]`` and tensors are created on
     the device of ``self.z2hdec.weight``.

     Args:
         inference: ``'sample'`` draws from the prior; any other value
             returns a deterministic value (zeros for ``'cont'``, the
             uniform probabilities for ``'gumbel'``).

     Returns:
         ``varType == 'cont'``: ``(sample, (mean, logv))``, each
         ``[batch, latentSize]``.
         ``varType == 'gumbel'``: ``(z, logprior)`` with ``z`` of shape
         ``[batch, num_vars * num_embeddings]`` and ``logprior`` the
         ``[num_embeddings]`` log of the uniform prior.

     Raises:
         Exception: for ``varType == 'none'``.
         ValueError: for any other unrecognised ``varType``.
     """
     device = self.z2hdec.weight.device
     batchSize = self.images.shape[0]
     if self.varType == 'cont':
         if inference == 'sample':
             sample = torch.randn(batchSize, self.latentSize, device=device)
         else:
             sample = torch.zeros(batchSize, self.latentSize, device=device)
         mean = torch.zeros_like(sample)
         # NOTE(review): a standard normal has log-variance 0, yet ones are
         # returned here; kept unchanged since callers may rely on it.
         logv = torch.ones_like(sample)
         return sample, (mean, logv)
     if self.varType == 'gumbel':
         K = self.num_embeddings
         V = self.num_vars
         prior_probs = torch.full((K, ), 1 / K, dtype=torch.float, device=device)
         logprior = torch.log(prior_probs)
         if inference == 'sample':
             # V independent one-hot draws per batch element, then flattened.
             z_vs = OneHotCategorical(prior_probs).sample(
                 sample_shape=(batchSize * V, ))
         else:
             z_vs = prior_probs.expand(batchSize * V, -1)
         return z_vs.reshape([batchSize, -1]), logprior
     if self.varType == 'none':
         raise Exception('Z has no prior for varType==none')
     # Any other value used to fall through to an UnboundLocalError.
     raise ValueError('Unknown varType: %r' % (self.varType, ))
Exemplo n.º 5
0
 def sample(self, probas, beta=1):
     """Draw a one-hot state over ``self.q`` values for each of ``self.N`` sites.

     The logits ``phi`` are ``beta`` times the sum of the reshaped ``probas``
     plus ``self.linear.weights`` broadcast over the batch; ``phi`` is also
     cached on ``self.phi``. A softmax over the ``q`` axis gives per-site
     probabilities.

     Returns:
         Tensor ``[batch, q, N]``, one-hot along dim 1.
     """
     n_batch = probas[0].size(0)
     grid = (n_batch, self.q, self.N)
     phi = beta * sum(p.view(grid) for p in probas)
     phi += self.linear.weights.view(1, self.q, self.N)
     self.phi = phi
     # Put the q axis last so each of the N sites gets its own categorical.
     site_probs = F.softmax(phi, 1).permute(0, 2, 1)
     draws = OneHotCategorical(probs=site_probs).sample()  # [batch, N, q]
     return draws.permute(0, 2, 1)
Exemplo n.º 6
0
    def forward(self, x):
        """Encode ``x`` and sample two one-hot message symbols.

        ``fc1``..``fc4`` (each followed by ``self.act``) and ``fc5`` produce
        ``out``, which is split at column ``NG1``; each part is
        softmax-normalised and one symbol is sampled from each.

        Side effects:
            self.get_log_probs: sum of both symbols' log-probabilities.
            self.get_entropy: entropy of the second head's distribution.

        Returns:
            (out, msgs_value): the raw ``fc5`` output and a ``[batch, 2]``
            tensor of sampled symbol indices.
        """
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = self.act(layer(hidden))
        out = self.fc5(hidden)

        first_probs = F.softmax(out[:, :NG1], dim=1)
        second_probs = F.softmax(out[:, NG1:], dim=1)
        first_dist = OneHotCategorical(probs=first_probs)
        second_dist = OneHotCategorical(probs=second_probs)
        first_oht = first_dist.sample()
        second_oht = second_dist.sample()

        self.get_log_probs = (torch.log((first_probs * first_oht).sum(1))
                              + torch.log((second_probs * second_oht).sum(1)))
        # NOTE(review): only the second head's entropy is recorded although
        # the log-probability covers both heads — confirm this is intended.
        self.get_entropy = second_dist.entropy()
        msgs_value = torch.stack([first_oht.argmax(1), second_oht.argmax(1)], dim=1)
        return out, msgs_value
Exemplo n.º 7
0
 def get_action(self, state):
     """Sample anchor and hyper-parameter actions from the policy outputs.

     Each anchor action is a ``Bernoulli`` draw using the matching entry of
     the anchor output as probabilities; each hyper-parameter action is a
     ``OneHotCategorical`` draw using the matching hp output as ``logits``.
     All anchor actions are sampled before the hyper-parameter actions.

     Returns:
         (all_hp_act, all_anchor_act): lists of sampled tensors.
     """
     hp_scores, anchor_probs = self.forward(state)
     # NOTE(review): the hp outputs are named "probs" in the original code but
     # are passed as logits — confirm what forward() returns.
     anchor_actions = [Bernoulli(p).sample() for p in anchor_probs]
     hp_actions = [OneHotCategorical(logits=s).sample() for s in hp_scores]
     return hp_actions, anchor_actions
Exemplo n.º 8
0
    def forward(self, rating_matrix):
        """Encode and decode ``rating_matrix`` once per facet (``self.kfac``).

        Items are assigned to the ``kfac`` prototypes softly (``nogb``) or,
        otherwise, by a sampled one-hot in training / softmax in eval. Each
        facet masks the ratings by its item assignment, encodes them, draws
        ``z`` via ``self.reparameterize`` and scores all items. The
        exponentiated facet scores, masked by the assignment, are summed.

        Returns:
            (logits, mulist, logvarlist): log of the summed facet scores and
            the per-facet encoder means / log-variances.
        """
        prototypes = F.normalize(self.k_embedding.weight, dim=1)
        item_vecs = F.normalize(self.item_embedding.weight, dim=1)

        inputs = F.dropout(F.normalize(rating_matrix),
                           self.drop_out,
                           training=self.training)

        # Rows are unit-norm, so this is item-prototype cosine similarity.
        cates_logits = torch.matmul(item_vecs, prototypes.transpose(0, 1)) / self.tau
        if self.nogb:
            cates = torch.softmax(cates_logits, dim=1)
        else:
            sampled = OneHotCategorical(logits=cates_logits).sample()
            soft = torch.softmax(cates_logits, dim=1)
            # Arithmetic blend rather than if/else: `soft` stays in the graph
            # in both modes (sampled assignment in training, soft in eval).
            cates = (self.training * sampled + (1 - self.training) * soft)

        mulist, logvarlist = [], []
        probs = None
        for k in range(self.kfac):
            facet_mask = cates[:, k].reshape(1, -1)
            h = self.encoder(inputs * facet_mask)
            mu = F.normalize(h[:, :self.embedding_size], dim=1)
            logvar = h[:, self.embedding_size:]
            mulist.append(mu)
            logvarlist.append(logvar)

            z_k = F.normalize(self.reparameterize(mu, logvar), dim=1)
            facet_logits = torch.matmul(z_k, item_vecs.transpose(0, 1)) / self.tau
            facet_probs = torch.exp(facet_logits) * facet_mask
            probs = facet_probs if probs is None else probs + facet_probs

        return torch.log(probs), mulist, logvarlist
Exemplo n.º 9
0
 def gumbel_softmax(self, logits, temperature, hard=False):
     """Sample from the Gumbel-Softmax distribution and optionally discretize.

     Args:
         logits: [batch_size, n_class] unnormalized log-probs
         temperature: non-negative scalar
         hard: if True, take argmax, but differentiate w.r.t. soft sample y
     Returns:
         [batch_size, n_class] sample from the Gumbel-Softmax distribution.
         If hard=True, then the returned sample will be one-hot, otherwise it will
         be a probability distribution that sums to 1 across classes
     """
     y_soft = self.gumbel_softmax_sample(logits, temperature)
     if not hard:
         return y_soft
     # Straight-through estimator: the forward value is the one-hot arg-max
     # of y_soft while gradients flow through y_soft. (Drawing a second
     # categorical sample here would be neither the arg-max nor differentiable.)
     index = y_soft.argmax(dim=-1, keepdim=True)
     y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
     return y_hard - y_soft.detach() + y_soft
Exemplo n.º 10
0
def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    """Sample a one-hot (or relaxed one-hot) vector from ``probs``.

    Args:
        probs: category probabilities; the last axis indexes categories.
        mode: ``'REINFORCE'``/``'SCST'`` for an exact one-hot sample, or
            ``'GUMBEL'`` for a relaxed (Gumbel-softmax) sample.
        tau: relaxation temperature (``'GUMBEL'`` only).
        hard: in ``'GUMBEL'`` mode, return a straight-through one-hot whose
            gradient is that of the relaxed sample.
        dim: axis along which ``hard`` takes the arg-max.

    Returns:
        ``'REINFORCE'``/``'SCST'``: ``(sample, entropy)``.
        ``'GUMBEL'``: ``(ret, ret)`` — the same tensor twice.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    if mode != 'GUMBEL':
        # Previously this fell through to an UnboundLocalError on y_soft.
        raise ValueError('Unknown mode: %r' % (mode, ))

    y_soft = RelaxedOneHotCategorical(tau, probs=probs).rsample()
    if hard:
        # Straight through: one-hot forward value, gradient via y_soft.
        index = y_soft.max(dim, keepdim=True)[1]
        # zeros_like(y_soft) already matches y_soft's device; no global needed.
        y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    # NOTE(review): the second element mirrors the (sample, entropy) shape of
    # the REINFORCE branch but is not an entropy here — confirm callers.
    return ret, ret
Exemplo n.º 11
0
    def recon_sketches(self, pis, mus, sigmas, rhos, qs, gamma):
        """Sample a stroke for every (sketch, step) pair, one step at a time.

        For each sketch ``i`` and step ``j`` a mixture component is drawn from
        ``OneHotCategorical(logits=pis[i, j])``; an offset is drawn from that
        component's bivariate normal, and the pen state is the one-hot
        arg-max of ``qs[i, j]``.

        Input:
            pis[batch, seq_len, M]        mixture logits
            mus[batch, seq_len, M, 2]     component means
            sigmas[batch, seq_len, M, 2]  component standard deviations
            rhos[batch, seq_len, M]       component correlations
            qs[batch, seq_len, 3]         pen-state scores
            gamma                         multiplier applied to ``sigmas``
        Output:
            strokes[batch, seq_len, 5]: offset (2) followed by pen state (3)
        """
        batch_size, seq_len, _ = pis.size()
        sigmas = sigmas * gamma
        # [[0, 1], [1, 0]] off-diagonal mask, built once rather than per step.
        off_diag = 1 - torch.eye(2, dtype=sigmas.dtype, device=mus.device)
        sketches = []
        for i in range(batch_size):
            strokes = []
            for j in range(seq_len):
                comp_choice = OneHotCategorical(logits=pis[i, j]).sample() == 1
                mu = mus[i, j][comp_choice].view(-1)        # [2]
                sigma = sigmas[i, j][comp_choice].view(-1)  # [2]
                rho = rhos[i, j][comp_choice].view(-1)      # [1]
                q = qs[i, j].view(-1)                       # [3]

                cov = torch.diag(sigma * sigma) + off_diag * rho * torch.prod(sigma)
                stroke_move = MultivariateNormal(mu, cov).sample().to(pis.device)  # [2]
                # argmax + one_hot keeps the pen state one-hot even on ties.
                pen_states = torch.nn.functional.one_hot(
                    q.argmax(), num_classes=q.numel()).to(dtype=torch.float)  # [3]
                strokes.append(
                    torch.cat([stroke_move.view(-1),
                               pen_states.view(-1).to(pis.device)], dim=0))
            sketches.append(torch.stack(strokes))
        return torch.stack(sketches)
Exemplo n.º 12
0
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    """Turn ``logits`` into a one-hot, relaxed or soft categorical vector.

    Args:
        logits: unnormalised category scores.
        mode: ``'REINFORCE'``/``'SCST'`` (exact one-hot sample, returned
            directly), ``'GUMBEL'`` (relaxed sample) or ``'SOFTMAX'``
            (deterministic softmax).
        tau: relaxation temperature (``'GUMBEL'`` only).
        hard: for ``'GUMBEL'``/``'SOFTMAX'``, return a straight-through
            one-hot whose gradient is that of the soft vector.
        dim: category axis for ``'SOFTMAX'`` and for the ``hard`` arg-max.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    if mode == 'REINFORCE' or mode == 'SCST':
        return OneHotCategorical(logits=logits).sample()
    if mode == 'GUMBEL':
        y_soft = RelaxedOneHotCategorical(tau, logits=logits).rsample()
    elif mode == 'SOFTMAX':
        # Normalise along `dim`, the same axis the hard arg-max uses below.
        y_soft = F.softmax(logits, dim=dim)
    else:
        # Previously this fell through to an UnboundLocalError on y_soft.
        raise ValueError('Unknown mode: %r' % (mode, ))

    if not hard:
        # Reparametrization trick.
        return y_soft
    # Straight through; zeros_like(y_soft) already sits on the right device.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
Exemplo n.º 13
0
    def forward(self, tgt_x):
        """Generate a two-symbol message for ``tgt_x`` with an LSTM cell.

        The embedded target initialises both the hidden and the cell state.
        At each of the two steps the previous one-hot symbol (zeros at the
        first step) is fed to ``self.lstm``. In training mode the symbol is
        sampled from the softmax of ``self.out_layer``'s logits; otherwise
        it is the arg-max.

        Side effects:
            self.get_log_probs: summed log-probability of both symbols.
            self.get_entropy: (training only) entropy of the last step.

        Returns:
            (logits, msgs_value): ``[batch, 2 * out_size]`` logits of both
            steps and ``[batch, 2]`` symbol indices.
        """
        batch_size = tgt_x.shape[0]
        tgt_hid = self.x_to_embd(tgt_x)
        # Allocate on the model's device instead of a hard-coded .cuda().
        lstm_input = torch.zeros((batch_size, NG1), device=tgt_hid.device)
        lstm_hid = tgt_hid.squeeze(1)
        lstm_cell = tgt_hid.squeeze(1)
        msgs_value = []
        logits = []
        log_probs = 0.

        for _ in range(2):
            lstm_hid, lstm_cell = self.lstm(lstm_input, (lstm_hid, lstm_cell))
            logit = self.out_layer(lstm_hid)
            logits.append(logit)
            probs = nn.functional.softmax(logit, dim=1)
            if self.training:
                cat_distr = OneHotCategorical(probs=probs)
                msg_oht = cat_distr.sample()
                # NOTE(review): overwritten each step, so only the last step's
                # entropy survives, and it is not refreshed in eval mode.
                self.get_entropy = cat_distr.entropy()
            else:
                msg_oht = nn.functional.one_hot(
                    torch.argmax(probs, dim=1),
                    num_classes=self.out_size).float()
            log_probs += torch.log((probs * msg_oht).sum(1))
            msgs_value.append(msg_oht.argmax(1))
            lstm_input = msg_oht

        self.get_log_probs = log_probs
        msgs_value = torch.stack(msgs_value).transpose(0, 1)
        # [2, batch, out] -> [batch, 2 * out]
        logits = torch.stack(logits).transpose(0, 1).reshape(batch_size, -1)
        return logits, msgs_value
Exemplo n.º 14
0
class Transform3(nn.Module):  # 8-direction translation
    """Shift a 1-channel image by 2 pixels in one of 8 random directions.

    Eight frozen 1->1 5x5 convolutions (zero padding 2, no bias) each hold a
    single 1 at one border tap: left/right/up/down and the four corners.
    Each forward pass applies exactly one of them, chosen uniformly.
    """

    def __init__(self):
        super(Transform3, self).__init__()
        # (row, col) of the single 1 in each kernel; attribute names and the
        # registration order are those the state_dict expects.
        self.conv_left = self._shift_conv(2, 4)
        self.conv_right = self._shift_conv(2, 0)
        self.conv_up = self._shift_conv(4, 2)
        self.conv_down = self._shift_conv(0, 2)
        self.conv5 = self._shift_conv(0, 0)
        self.conv6 = self._shift_conv(0, 4)
        self.conv7 = self._shift_conv(4, 0)
        self.conv8 = self._shift_conv(4, 4)
        self.one_hot1 = OneHotCategorical(torch.Tensor([1. / 8] * 8))

        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _shift_conv(row, col):
        """Build a bias-free 1->1 5x5 conv whose kernel is 1 at (row, col)."""
        conv = nn.Conv2d(1,
                         1,
                         kernel_size=(5, 5),
                         stride=(1, 1),
                         padding=(2, 2),
                         bias=False)
        kernel = torch.zeros(1, 1, 5, 5)
        kernel[0, 0, row, col] = 1.0
        conv.weight = nn.Parameter(kernel)
        return conv

    def forward(self, x):
        convs = (self.conv_left, self.conv_right, self.conv_up, self.conv_down,
                 self.conv5, self.conv6, self.conv7, self.conv8)
        # The switch is one-hot, so only the selected shift contributes:
        # evaluate just that convolution (no hard-coded .cuda() needed).
        choice = int(self.one_hot1.sample().argmax())
        return convs[choice](x)
Exemplo n.º 15
0
def getNdiracs(data, N , sparse = False, flat = False, replace = True):
    """Draw ``N`` random one-hot ("dirac") node indicators for every graph.

    Args:
        data: graph batch. The dense path reads ``num_nodes``, ``x`` and
            ``mask``; the sparse path reads ``batch``, ``edge_index``,
            ``num_graphs`` and ``y``.
        N: number of dirac samples per graph.
        sparse: use the sparse (``Batch``-building) path instead of the
            dense tensor path.
        flat: accepted but not used by either path.
        replace: passed to ``np.random.choice`` in the sparse path
            (sample node indices with or without replacement).

    Returns:
        Dense path: tensor ``[data.num_nodes, data.x.shape[1], N]``. Slice
        ``k`` holds ``N`` one-hot columns, each uniform over the first
        ``data.mask[k].sum()`` rows and zero on the remaining rows.
        Sparse path: a ``Batch`` whose ``x`` holds ``N`` copies of each
        graph's nodes, ordered graph-major (all copies of graph 0, then
        graph 1, ...), with a single 1 per copy at the sampled node. Its
        ``batch`` numbers each (graph, copy) pair as ``k * N + copy``, and
        ``edge_index`` is the original edges tiled ``N`` times. Also set are
        ``locations`` (the sampled node ids in the original node numbering),
        ``norm_index``, ``batch_old`` and ``edge_index_old``.

    Tensors are created on CUDA when available, otherwise on CPU.
    """
    
    if not sparse:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        graphcount =data.num_nodes #used as the number of graphs; NOTE(review): reads num_nodes, not num_graphs — confirm
        totalnodecount = data.x.shape[1] #padded node slots per graph (presumably) — TODO confirm
        actualnodecount = 0 #cumulative number of nodes (accumulated but not used afterwards)
        diracmatrix= torch.zeros((graphcount,totalnodecount,N),device=device) #matrix with dirac pulses


        for k in range(graphcount):
            graph_nodes = data.mask[k].sum() #number of nodes in the graph
            actualnodecount += graph_nodes #might not need this, we'll see
            probabilities= torch.ones((graph_nodes.item(),1),device=device)/graph_nodes #uniform probs
            node_distribution=OneHotCategorical(probs=probabilities.squeeze())
            node_sample= node_distribution.sample(sample_shape=(N,)) # [N, graph_nodes], one-hot rows
            node_sample= torch.cat((node_sample,torch.zeros((N,totalnodecount-node_sample.shape[1]),device=device)),-1) #concat zeros to fit dataset shape
            diracmatrix[k,:]= torch.transpose(node_sample,dim0=-1,dim1=-2) #add everything to the final matrix
    
        return diracmatrix
    
    else:
        
            original_batch_index = data.batch
            original_edge_index = data.edge_index
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            graphcount =data.num_graphs #number of graphs in data/batch object
            diracmatrix = torch.zeros(0,device=device)
            batch_prime = torch.zeros(0,device=device).long()
            locationmatrix = torch.zeros(0,device=device).long()
            
            # Running count of nodes in the graphs already processed; turns
            # per-graph node ids into ids in the original batch numbering.
            global_offset = 0
            for k in range(graphcount):
                graph_nodes = (data.batch == k).sum()
                #probabilities = torch.ones((graph_nodes.item(),1),device=device)/graph_nodes #uniform probs
                #node_distribution = OneHotCategorical(probs=probabilities.squeeze())
                #node_sample = node_distribution.sample(sample_shape=(N,))
                
                
#                 if flat:
#                     diracmatrix = torch.cat((diracmatrix, node_sample.view(-1)),0)
#                 else:
#                     diracmatrix = torch.cat((diracmatrix, node_sample.t(),0))
                
                #for diracmatrix: N copies of this graph's nodes, copy n has a 1 at node randInt[n]
                randInt = np.random.choice(range(graph_nodes), N, replace = replace)
                node_sample = torch.zeros(N*graph_nodes,device=device)
                offs  = torch.arange(N, device=device)*graph_nodes
                dirac_locations = (offs + torch.from_numpy(randInt).to(device))
                node_sample[dirac_locations] = 1
                
                dirac_locations2 = torch.from_numpy(randInt).to(device) + global_offset
                global_offset += graph_nodes
                
                diracmatrix = torch.cat((diracmatrix, node_sample),0)
                locationmatrix = torch.cat((locationmatrix, dirac_locations2),0)
            
                
            
                #for batch prime: every node of copy n of graph k gets batch id k*N + n
                dirac_indices = torch.arange(N, device=device).unsqueeze(-1).expand(-1, graph_nodes).contiguous().view(-1)
                dirac_indices = dirac_indices.long()
                dirac_indices += k*N
                batch_prime = torch.cat((batch_prime, dirac_indices))



            
#             locationmatrix = diracmatrix.nonzero()
            # NOTE(review): this first edge_index_prime is overwritten below and never used.
            edge_index_prime = torch.arange(N).unsqueeze(-1).expand(-1,data.edge_index.shape[1]).contiguous().view(-1)*data.batch.shape[0]
            # Copy n of the edge list is shifted by n * (total node count).
            # NOTE(review): that offset assumes a copy-major node layout (copy n of
            # ALL graphs at n * total_nodes), while x/batch above are graph-major;
            # the two agree only for a single graph — verify.
            # NOTE(review): torch.arange(N) is created on CPU; adding it to a CUDA
            # edge_index may raise a device mismatch — verify.
            offset = torch.arange(N).unsqueeze(-1).expand(-1,data.edge_index.size()[1]).contiguous().view(-1)*data.batch.shape[0]
            offset_2 = torch.cat(2*[offset.unsqueeze(0)],dim = 0)
            edge_index_prime = torch.cat(N*[data.edge_index], dim = 1) + offset_2
            # Each original node's graph id repeated N times (node-major order).
            normalization_indices = data.batch.unsqueeze(-1).expand(-1,N).contiguous().view(-1).to(device)

            return Batch(batch = batch_prime, x = diracmatrix, edge_index = edge_index_prime,
                         y = data.y, locations = locationmatrix, norm_index = normalization_indices, batch_old = original_batch_index, edge_index_old = original_edge_index)
Exemplo n.º 16
0
 def sample_action_with_prob(self, x):
     """Sample a one-hot action and return it with its log-probability.

     ``self.forward(x)`` supplies the ``probs`` of a ``OneHotCategorical``.

     Returns:
         (action, log_prob)
     """
     policy = OneHotCategorical(probs=self.forward(x))
     chosen = policy.sample()
     return chosen, policy.log_prob(chosen)
Exemplo n.º 17
0
 def sample(self, dist_info):
     """Return a one-hot draw from the probabilities stored under ``"prob"``."""
     return OneHotCategorical(probs=dist_info["prob"]).sample()
Exemplo n.º 18
0
    def forward(self, state, encoding=None, device='cpu'):
        """
            - At the first time step, pass in the encoding vector from Encoder with shape (batch_size, hidden_size)
                using the optional argument encoding= . h_list and c_list will be reset to 0s
            - At the following time steps, DO NOT pass in any value to the optional argument encoding=

        Runs one step of the stacked LSTM (``self.cell_list``), samples a
        one-hot decision over ``self.num_actions`` dimensions and, for every
        dimension except the last, a Normal-distributed value. The last
        dimension has value 0.0 and log-probability 0.0.

        Returns:
            (decision, final_action_values, log_prob): the one-hot decision,
            the chosen dimension's value scaled by ``self.act_lim``, and
            ``log Pr(decision) + log Pr(value | decision)``.
        """
        # On reset, every layer starts from zero state; layer 0's hidden
        # state is the encoding.
        if encoding is not None:
            self.h_list = [
                torch.zeros((self.batch_size, self.hidden_size), device=device)
                for _ in range(self.num_layers)
            ]
            self.c_list = [
                torch.zeros((self.batch_size, self.hidden_size), device=device)
                for _ in range(self.num_layers)
            ]
            self.h_list[0] = encoding

        # Stacked LSTM: layer i consumes layer i-1's output and ITS OWN state.
        h_1, c_1 = self.cell_list[0](state, (self.h_list[0], self.c_list[0]))
        h1_list = [h_1]
        c1_list = [c_1]
        for i in range(1, self.num_layers):
            h_1, c_1 = self.cell_list[i](h_1, (self.h_list[i], self.c_list[i]))
            h1_list.append(h_1)
            c1_list.append(c_1)
        self.h_list = h1_list
        self.c_list = c1_list

        decision_logit = self.FC_decision(h_1)
        values_mean = self.FC_values_mean(h_1)
        values_std = torch.exp(self.FC_values_logstd(h_1))

        # Which action dimension to act on.
        m_decision = OneHotCategorical(logits=decision_logit)
        decision = m_decision.sample()
        decision_log_prob = m_decision.log_prob(decision)

        # One Normal value (and its log-prob) per action except the last;
        # columns are collected and concatenated once.
        value_cols = []
        logp_cols = []
        for i in range(self.num_actions - 1):
            m_value = Normal(values_mean[:, i], values_std[:, i])
            drawn = m_value.sample()
            value_cols.append(drawn.unsqueeze(1))
            logp_cols.append(m_value.log_prob(drawn).unsqueeze(1))
        # The last action ("do nothing") has value 0.0 and log-probability 0.0.
        zero_col = torch.zeros((self.batch_size, 1), device=device)
        action_values = torch.cat(value_cols + [zero_col], dim=1)
        actions_log_prob = torch.cat(logp_cols + [zero_col], dim=1)

        # Keep only the chosen dimension, then scale by act_lim.
        final_action_values = (action_values * decision).sum(dim=1) * self.act_lim
        final_action_log_prob = (actions_log_prob * decision).sum(dim=1)

        # log Pr(action) = log Pr(dimension i chosen) + log Pr(value | dimension i)
        log_prob = decision_log_prob + final_action_log_prob

        return decision, final_action_values, log_prob
Exemplo n.º 19
0
class Transform5(nn.Module):  # combine
    """Randomly apply one of two augmentations to a 1-channel image batch.

    With probability 0.6: two successive random translations, each a 5x5
    one-hot kernel drawn uniformly from the 24 non-centre taps.
    With probability 0.4: ``(dropout(conv1(x)) + x) / 2``, where ``conv1``
    is a fixed 3x3 kernel (-1 around a centre of 9) and dropout has p=0.05.
    Both branches are computed; the sampled switch selects one. All
    parameters are frozen. ``conv_trans3``, ``conv_trans4``, ``conv_smooth``
    and ``relu`` are created but not used by ``forward``.
    """

    def __init__(self):
        super(Transform5, self).__init__()
        kernel = [[[[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]]]
        kernel = torch.from_numpy(np.array(kernel)).float()
        self.conv1 = nn.Conv2d(1,
                               1,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               padding=(1, 1),
                               bias=False)
        self.conv1.weight = nn.Parameter(kernel)

        self.conv_trans1 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans2 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans3 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)
        self.conv_trans4 = nn.Conv2d(1,
                                     1,
                                     kernel_size=(5, 5),
                                     stride=(1, 1),
                                     padding=(2, 2),
                                     bias=False)

        self.conv_smooth = nn.Conv2d(1,
                                     1,
                                     kernel_size=(3, 3),
                                     stride=(1, 1),
                                     padding=(1, 1),
                                     bias=False)
        # Created on CPU; moving the module (.to/.cuda) moves it with the rest.
        self.conv_smooth.weight = nn.Parameter(
            torch.ones(9).view(1, 1, 3, 3) * 1 / 9.)
        self.drop = nn.Dropout(p=0.05)
        self.relu = nn.ReLU(inplace=True)

        self.one_hot1 = OneHotCategorical(torch.Tensor([0.6, 0.4]))
        # Uniform over the 24 non-centre taps of a 5x5 grid (centre = 12).
        self.one_hot2 = OneHotCategorical(
            torch.Tensor([
                1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24.,
                1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24., 0.000, 1 / 24.,
                1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24., 1 / 24.,
                1 / 24., 1 / 24., 1 / 24., 1 / 24.
            ]))

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        # random translation: write two freshly sampled kernels into the
        # frozen weights in place. No new (trainable) Parameters are created,
        # and copy_ handles the device, so no hard-coded .cuda() is needed.
        with torch.no_grad():
            self.conv_trans1.weight.copy_(self.one_hot2.sample().view(1, 1, 5, 5))
            self.conv_trans2.weight.copy_(self.one_hot2.sample().view(1, 1, 5, 5))
        y1 = self.conv_trans2(self.conv_trans1(x))  # equivalent to random crop

        # data dropout
        y3 = (self.drop(self.conv1(x)) + x) / 2.

        switch = self.one_hot1.sample().to(x.device)
        return y1 * switch[0] + y3 * switch[1]
Exemplo n.º 20
0
def dcgan(
    root: xmen.Root,  #
    # first argument is always an experiment instance.
    # can be unused (specify with _) in experiments
    # syntax practice. can be named whatever depending
    # on use case. Eg. logger, root, experiment ...
    b: int = 128,  # the batch size per gpu
    hw0: Tuple[int, int] = (4, 4),  # the height and width of the image
    nl: int = 4,  # the number of levels in the discriminator.
    data_root: str = os.getenv("HOME") +
    '/data/mnist',  # @p the root data directory
    cx: int = 1,
    cy: int = 10,  # the dimensionality of the conditioning vector
    cf:
    int = 512,  # the number of features after the first conv in the discriminator
    cz: int = 100,  # the dimensionality of the noise vector
    ncpus: int = 8,  # the number of threads to use for data loading
    ngpus: int = 1,  # the number of gpus to run the model on
    epochs: int = 20,  # no. of epochs to train for
    lr: float = 0.0002,  # learning rate
    betas: Tuple[float, float] = (0.5, 0.999),  # the beta parameters for the
    # monitoring parameters
    checkpoint: str = 'nn_.*@1e',  # checkpoint at this modulo string
    log: str = 'loss_.*@20s',  # log scalars
    sca: str = 'loss_.*@20s',  # tensorboard scalars
    img: str = '_x_|x$@20s',  # tensorboard images
    nimg: int = 64,  # the maximum number of images to display to tensorboard
    ns: int = 5  # the number of samples to generate at inference)
):
    """Train a conditional GAN to predict MNIST digits.

    To viusalise the results run::

        tensorboard --logdir ...

    """
    from xmen.monitor import TorchMonitor, TensorboardLogger
    from xmen.examples.models import weights_init, set_requires_grad, GeneratorNet, DiscriminatorNet
    from torch.distributions import Normal
    from torch.distributions.one_hot_categorical import OneHotCategorical
    from torch.optim import Adam
    import logging

    hw = [d * 2**nl for d in hw0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logger = logging.getLogger()
    logger.setLevel('INFO')
    # dataset
    datasets = get_datasets(cy, cz, b, ngpus, ncpus, ns, data_root, hw)
    # models
    nn_g = GeneratorNet(cy, cz, cx, cf, hw0, nl)
    nn_d = DiscriminatorNet(cx, cy, cf, hw0, nl)
    nn_g = nn_g.to(device).float().apply(weights_init)
    nn_d = nn_d.to(device).float().apply(weights_init)
    # distributions
    pz = Normal(torch.zeros([cz]), torch.ones([cz]))
    py = OneHotCategorical(probs=torch.ones([cy]) / cy)
    # optimisers
    op_d = Adam(nn_d.parameters(), lr=lr, betas=betas)
    op_g = Adam(nn_g.parameters(), lr=lr, betas=betas)
    # monitor
    m = TorchMonitor(root.directory,
                     ckpt=checkpoint,
                     log=log,
                     sca=sca,
                     img=img,
                     time=('@20s', '@1e'),
                     msg='root@100s',
                     img_fn=lambda x: x[:min(nimg, x.shape[0])],
                     hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])

    for _ in m(range(epochs)):
        # (1) train
        for x, y in m(datasets['train']):
            # process input
            x, y = x.to(device), y.to(device).float()
            b = x.shape[0]
            # discriminator step
            set_requires_grad([nn_d], True)
            op_d.zero_grad()
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            loss_d = nn_d((x, y), True) + nn_d(
                (_x_.detach(), y.detach()), False)
            loss_d.backward()
            op_d.step()
            # generator step
            op_g.zero_grad()
            y = py.sample([b]).reshape([b, cy, 1, 1]).to(device)
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            set_requires_grad([nn_d], False)
            loss_g = nn_d((_x_, y), True)
            loss_g.backward()
            op_g.step()
        # (2) inference
        if 'inference' in datasets:
            with torch.no_grad():
                for yi, zi in datasets['inference']:
                    yi, zi = yi.to(device), zi.to(device)
                    _xi_ = nn_g(yi, zi)