def sample_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
    """Sample one stroke per time step for every sketch in the batch.

    Input:
        pis[batch, seq_len, M]: mixture-component logits
        mus[batch, seq_len, M, 2]: component means
        sigmas[batch, seq_len, M, 2]: component standard deviations
        rhos[batch, seq_len, M]: component correlation coefficients
        qs[batch, seq_len, 3]: pen-state scores
        gamma: factor multiplied into every sigma before sampling
    Output:
        strokes[batch, seq_len, 5]: the sampled 2-D move followed by a one-hot
        pen state (the arg-max of qs; on ties the first maximum wins).
    """
    batch_size, seq_len, _ = pis.size()
    sigmas = sigmas * gamma
    # Ones off the diagonal, zeros on it: places rho * sx * sy in the covariance.
    off_mask = 1 - torch.eye(2, dtype=sigmas.dtype, device=sigmas.device)

    strokes = []
    for i in range(batch_size):
        # One-hot component choice per time step -> boolean mask [seq_len, M].
        comp_choice = OneHotCategorical(logits=pis[i]).sample() == 1
        mu = mus[i][comp_choice]        # [seq_len, 2]
        sigma = sigmas[i][comp_choice]  # [seq_len, 2]
        rho = rhos[i][comp_choice]      # [seq_len]
        q = qs[i]                       # [seq_len, 3]

        # Per-step 2x2 covariance: diag(sx^2, sy^2) with rho * sx * sy off-diagonal.
        cross = rho * sigma.prod(dim=-1)
        cov = torch.diag_embed(sigma * sigma) + off_mask * cross[:, None, None]
        stroke_move = MultivariateNormal(mu, cov.to(device=mu.device)).sample().to(pis.device)  # [seq_len, 2]

        # Exactly one active pen state per step, even when q has tied maxima.
        pen_states = torch.nn.functional.one_hot(
            q.argmax(dim=1), num_classes=q.size(1)).to(dtype=torch.float)  # [seq_len, 3]
        strokes.append(torch.cat([stroke_move, pen_states], dim=1).to(pis.device))
    return torch.stack(strokes)
def sample_single_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
    """Sample a single stroke from one mixture-density output.

    Input:
        pis[M]: mixture-component logits
        mus[M, 2]: component means
        sigmas[M, 2]: component standard deviations
        rhos[M]: component correlation coefficients
        qs[3]: pen-state scores
        gamma: factor multiplied into every sigma before sampling
    Output:
        stroke[5]: the sampled 2-D move followed by a one-hot pen state
        (the arg-max of qs; on ties the first maximum wins).
    """
    # gamma must scale the standard deviations here too, exactly as the batched
    # samplers do; it was previously accepted but never applied.
    sigmas = sigmas * gamma

    comp_choice = OneHotCategorical(logits=pis).sample() == 1  # boolean mask [M]
    mu = mus[comp_choice].view(-1)        # [2]
    sigma = sigmas[comp_choice].view(-1)  # [2]
    rho = rhos[comp_choice].view(-1)      # [1]
    q = qs.view(-1)                       # [3]

    # 2x2 covariance: diag(sx^2, sy^2) with rho * sx * sy off-diagonal.
    off_mask = 1 - torch.eye(2, dtype=sigma.dtype, device=sigma.device)
    cov = (torch.diag(sigma * sigma) + off_mask * rho * torch.prod(sigma)).to(device=mu.device)
    stroke_move = MultivariateNormal(mu, cov).sample().to(pis.device)  # [2]

    # Exactly one active pen state, even when q has tied maxima.
    pen_states = torch.nn.functional.one_hot(
        q.argmax(), num_classes=q.numel()).to(dtype=torch.float)  # [3]
    return torch.cat([stroke_move.view(-1), pen_states.view(-1)], dim=0).to(pis.device)
class Transform4(nn.Module):  # rand translation
    """Random translation of a 3-channel image by up to 2 pixels.

    A 5x5 one-hot kernel is sampled from the 24 non-centre positions (the
    centre, i.e. the identity shift, has probability 0). The same kernel is
    applied to all 3 channels with a depth-wise convolution (groups=3, zero
    padding 2), so every channel moves by the same offset.
    """

    def __init__(self):
        super(Transform4, self).__init__()
        self.conv_trans = nn.Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1),
                                    padding=(2, 2), bias=False, groups=3)
        # 25 kernel positions with uniform probability; index 12 (centre) excluded.
        shift_probs = torch.full((25,), 1 / 24.)
        shift_probs[12] = 0.
        self.one_hot2 = OneHotCategorical(shift_probs)

    def forward(self, x):
        kernel = self.one_hot2.sample().view(1, 5, 5)  # 1x5x5
        # 3x1x5x5 for the depth-wise conv; follow x's device/dtype instead of
        # hard-coding CUDA.
        kernel = torch.stack([kernel] * 3).to(device=x.device, dtype=x.dtype)
        # The random shift is not a learnable weight: create it frozen.
        # (Setting `module.requires_grad = False` only adds an unused attribute.)
        self.conv_trans.weight = nn.Parameter(kernel, requires_grad=False)
        return self.conv_trans(x)
def _prior(self, inference='sample'): device = self.z2hdec.weight.device batchSize = self.images.shape[0] if self.varType == 'cont': if inference == 'sample': sample = torch.randn(batchSize, self.latentSize, device=device) else: sample = torch.zeros(batchSize, self.latentSize, device=device) mean = torch.zeros_like(sample) logv = torch.ones_like(sample) z = sample, (mean, logv) elif self.varType == 'gumbel': K = self.num_embeddings V = self.num_vars prior_probs = torch.tensor([1 / K] * K, dtype=torch.float, device=device) logprior = torch.log(prior_probs) if inference == 'sample': prior = OneHotCategorical(prior_probs) z_vs = prior.sample(sample_shape=(batchSize * V, )) z = z_vs.reshape([batchSize, -1]) else: z_vs = prior_probs.expand(batchSize * V, -1) z = z_vs.reshape([batchSize, -1]) z = (z, logprior) elif self.varType == 'none': raise Exception('Z has no prior for varType==none') return z
def sample(self, probas, beta=1):
    """Draw a one-hot sample over q states at each of N sites.

    Every tensor in ``probas`` is viewed as (batch, q, N) and summed. The sum is
    scaled by ``beta`` and offset by ``self.linear.weights`` viewed as
    (1, q, N). The result is kept on ``self.phi`` and softmaxed over the q axis,
    and one state is sampled independently per site.

    Returns:
        Tensor of shape (batch, q, N), one-hot along dimension 1.
    """
    n_batch = probas[0].size(0)
    grid = (n_batch, self.q, self.N)

    total = 0
    for field in probas:
        total = total + field.view(grid)

    phi = beta * total + self.linear.weights.view(1, self.q, self.N)
    self.phi = phi

    per_site = F.softmax(phi, 1).permute(0, 2, 1)  # (batch, N, q)
    draws = OneHotCategorical(probs=per_site).sample()
    return draws.permute(0, 2, 1)
def forward(self, x):
    """Map ``x`` to logits and sample a two-symbol message from them.

    Columns ``[:NG1]`` of the last layer's output parameterise the first symbol
    and columns ``[NG1:]`` the second. Each group is softmaxed and one one-hot
    symbol is sampled from it.

    Side effects:
        self.get_log_probs: log p(symbol 1) + log p(symbol 2), shape (batch,).
        self.get_entropy: entropy of the SECOND symbol's distribution only.

    Returns:
        out: the last layer's raw output (both logit groups).
        msgs_value: (batch, 2) indices of the two sampled symbols.
    """
    hidden = x
    for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
        hidden = self.act(layer(hidden))
    out = self.fc5(hidden)

    first_probs = F.softmax(out[:, :NG1], dim=1)
    second_probs = F.softmax(out[:, NG1:], dim=1)
    first_dist = OneHotCategorical(probs=first_probs)
    second_dist = OneHotCategorical(probs=second_probs)
    first_msg = first_dist.sample()
    second_msg = second_dist.sample()

    first_lp = torch.log((first_probs * first_msg).sum(1))
    second_lp = torch.log((second_probs * second_msg).sum(1))
    self.get_log_probs = first_lp + second_lp
    # NOTE(review): the log-prob covers both symbols but the entropy only the
    # second; confirm whether first_dist.entropy() should be added.
    self.get_entropy = second_dist.entropy()

    msgs_value = torch.stack((first_msg.argmax(1), second_msg.argmax(1)), dim=1)
    return out, msgs_value
def get_action(self, state):
    """Sample per-layer actions from the policy output for ``state``.

    ``self.forward(state)`` returns two sequences with one tensor per layer:
    hyper-parameter scores and anchor probabilities. Anchor actions are
    Bernoulli samples; hyper-parameter actions are one-hot samples.

    Returns:
        (all_hp_act, all_anchor_act): lists of sampled tensors in layer order.
    """
    all_hp_probs, all_anchor_probs = self.forward(state)
    all_anchor_act = [Bernoulli(p).sample() for p in all_anchor_probs]
    # NOTE(review): these are named "probs" but passed as *logits*. If forward()
    # returns probabilities (e.g. softmax outputs) this flattens the
    # distribution; confirm against forward().
    all_hp_act = [OneHotCategorical(logits=p).sample() for p in all_hp_probs]
    return all_hp_act, all_anchor_act
def forward(self, rating_matrix):
    """Per-concept VAE forward pass over a user-item rating matrix.

    Items are assigned to ``self.kfac`` concept prototypes by the softmax of
    their cosine similarity over ``self.tau``. In training mode without
    ``nogb`` the assignment is sampled one-hot instead. For each concept k:
      1. the (normalised, dropped-out) ratings are masked by the items' weight
         for k;
      2. the masked ratings are encoded into (mu, logvar);
      3. a latent is drawn with ``self.reparameterize``;
      4. exp(cos(z_k, item) / tau) weighted by concept membership is added to
         the running score.

    Returns:
        logits: log of the accumulated item scores, shape (batch, n_items).
        mulist: the kfac L2-normalised means.
        logvarlist: the kfac log-variances.
    """
    prototypes = F.normalize(self.k_embedding.weight, dim=1)
    item_vecs = F.normalize(self.item_embedding.weight, dim=1)
    item_vecs_t = item_vecs.transpose(0, 1)

    ratings = F.normalize(rating_matrix)
    ratings = F.dropout(ratings, self.drop_out, training=self.training)

    concept_logits = torch.matmul(item_vecs, prototypes.transpose(0, 1)) / self.tau
    if not self.nogb:
        sampled = OneHotCategorical(logits=concept_logits).sample()
        soft = torch.softmax(concept_logits, dim=1)
        # training -> hard sample, eval -> soft assignment
        cates = (self.training * sampled + (1 - self.training) * soft)
    else:
        cates = torch.softmax(concept_logits, dim=1)

    mulist, logvarlist = [], []
    score_sum = None
    for k in range(self.kfac):
        membership = cates[:, k].reshape(1, -1)
        h = self.encoder(ratings * membership)
        mu = F.normalize(h[:, :self.embedding_size], dim=1)
        logvar = h[:, self.embedding_size:]
        mulist.append(mu)
        logvarlist.append(logvar)

        z_k = F.normalize(self.reparameterize(mu, logvar), dim=1)
        scores = torch.exp(torch.matmul(z_k, item_vecs_t) / self.tau) * membership
        score_sum = scores if score_sum is None else score_sum + scores

    return torch.log(score_sum), mulist, logvarlist
def gumbel_softmax(self, logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, return the one-hot arg-max of the soft sample in the
            forward pass while back-propagating through the soft sample
            (straight-through estimator)

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True it is one-hot; otherwise it is a probability
        distribution that sums to 1 across classes.
    """
    y_soft = self.gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y_soft
    # Take the arg-max of the soft sample. Drawing a fresh categorical sample
    # from it would both re-randomise and cut the gradient.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Straight-through: forward value of y_hard, gradient of y_soft.
    return y_hard - y_soft.detach() + y_soft
def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    """Draw a categorical sample from ``probs`` according to ``mode``.

    Args:
        probs: class probabilities, categories along the last dimension.
        mode: 'REINFORCE' or 'SCST' for an exact one-hot sample plus its
            entropy; 'GUMBEL' for a relaxed (Gumbel-softmax) sample at
            temperature ``tau``.
        tau: Gumbel-softmax temperature.
        hard: GUMBEL only. If True, return the one-hot arg-max in the forward
            pass with straight-through gradients from the relaxed sample.
        dim: dimension along which the hard arg-max is taken.

    Returns:
        REINFORCE/SCST: (one_hot_sample, entropy).
        GUMBEL: (sample, sample), the same tensor twice.

    Raises:
        ValueError: if ``mode`` is not recognised (previously returned None).
    """
    if mode in ('REINFORCE', 'SCST'):
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    if mode != 'GUMBEL':
        raise ValueError(f'Unknown sampling mode: {mode!r}')

    y_soft = RelaxedOneHotCategorical(tau, probs=probs).rsample()
    if hard:
        # Straight through. The one-hot is built on y_soft's own device/dtype
        # instead of a global DEVICE.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret, ret
def recon_sketches(self, pis, mus, sigmas, rhos, qs, gamma):
    """Sample a reconstruction of every sketch, one stroke per time step.

    Input:
        pis[batch, seq_len, M]: mixture-component logits
        mus[batch, seq_len, M, 2]: component means
        sigmas[batch, seq_len, M, 2]: component standard deviations
        rhos[batch, seq_len, M]: component correlation coefficients
        qs[batch, seq_len, 3]: pen-state scores
        gamma: factor multiplied into every sigma before sampling
    Output:
        strokes[batch, seq_len, 5]: the sampled 2-D move followed by a one-hot
        pen state (the arg-max of qs; on ties the first maximum wins).
    """
    batch_size, seq_len, _ = pis.size()
    sigmas = sigmas * gamma
    # Ones off the diagonal, zeros on it; built once instead of per step.
    off_mask = 1 - torch.eye(2, dtype=sigmas.dtype, device=sigmas.device)

    sketches = []
    for i in range(batch_size):
        strokes = []
        for j in range(seq_len):
            comp_choice = OneHotCategorical(logits=pis[i, j]).sample() == 1  # mask [M]
            mu = mus[i, j][comp_choice].view(-1)        # [2]
            sigma = sigmas[i, j][comp_choice].view(-1)  # [2]
            rho = rhos[i, j][comp_choice].view(-1)      # [1]
            q = qs[i, j].view(-1)                       # [3]

            # 2x2 covariance: diag(sx^2, sy^2) with rho * sx * sy off-diagonal.
            cov = (torch.diag(sigma * sigma) + off_mask * rho * torch.prod(sigma)).to(device=mu.device)
            stroke_move = MultivariateNormal(mu, cov).sample().to(pis.device)  # [2]

            # Exactly one active pen state, even when q has tied maxima.
            pen_states = torch.nn.functional.one_hot(
                q.argmax(), num_classes=q.numel()).to(dtype=torch.float)  # [3]
            strokes.append(torch.cat([stroke_move.view(-1), pen_states.view(-1)], dim=0).to(pis.device))
        sketches.append(torch.stack(strokes))
    return torch.stack(sketches)
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    """Turn decoder ``logits`` into a (possibly relaxed) categorical choice.

    Args:
        logits: unnormalised scores; categories along ``dim``.
        mode: 'REINFORCE' or 'SCST' for an exact one-hot sample (no gradient);
            'GUMBEL' for a relaxed sample at temperature ``tau``;
            'SOFTMAX' for a plain softmax along ``dim``.
        tau: Gumbel-softmax temperature.
        hard: GUMBEL/SOFTMAX only. If True, return the one-hot arg-max in the
            forward pass with straight-through gradients.
        dim: category dimension, used for both the softmax and the arg-max.

    Returns:
        The sampled / relaxed tensor, same shape as ``logits``.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    if mode in ('REINFORCE', 'SCST'):
        return OneHotCategorical(logits=logits).sample()
    if mode == 'GUMBEL':
        y_soft = RelaxedOneHotCategorical(tau, logits=logits).rsample()
    elif mode == 'SOFTMAX':
        # Normalise along `dim`, the same axis the hard arg-max uses
        # (was hard-coded to dim=1, which is wrong for rank-3 logits).
        y_soft = F.softmax(logits, dim=dim)
    else:
        raise ValueError(f'Unknown sampling mode: {mode!r}')

    if not hard:
        # Reparametrization trick.
        return y_soft
    # Straight through; one-hot built on y_soft's own device (no global args.device).
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
def forward(self, tgt_x):
    """Generate a two-symbol message conditioned on ``tgt_x``.

    ``tgt_x`` is embedded by ``self.x_to_embd``. The embedding (squeezed on
    dim 1) initialises both the hidden and the cell state of ``self.lstm``.
    Starting from an all-zero input of width NG1, the LSTM is unrolled for two
    steps. At each step the output logits are softmaxed and a symbol is chosen:
    sampled in training mode, arg-max otherwise. The chosen one-hot symbol is
    fed as the next step's input.

    Side effects:
        self.get_log_probs: sum over both steps of log p(chosen symbol), (batch,).
        self.get_entropy: (training mode only) entropy of the LAST step.

    Returns:
        logits: (batch, 2 * out_size), both steps' logits per sample.
        msgs_value: (batch, 2) chosen symbol indices.
    """
    batch_size = tgt_x.shape[0]
    tgt_hid = self.x_to_embd(tgt_x)
    # Zero start input on tgt_hid's device/dtype instead of a hard-coded .cuda().
    lstm_input = tgt_hid.new_zeros((batch_size, NG1))
    lstm_hid = tgt_hid.squeeze(1)
    lstm_cell = tgt_hid.squeeze(1)

    msgs_value = []
    logits = []
    log_probs = 0.
    for _ in range(2):
        lstm_hid, lstm_cell = self.lstm(lstm_input, (lstm_hid, lstm_cell))
        logit = self.out_layer(lstm_hid)
        logits.append(logit)
        probs = nn.functional.softmax(logit, dim=1)
        if self.training:
            cat_distr = OneHotCategorical(probs=probs)
            msg_oht = cat_distr.sample()
            # NOTE(review): overwritten every step (only the last step's entropy
            # survives) and left stale in eval mode; confirm this is intended.
            self.get_entropy = cat_distr.entropy()
        else:
            msg_oht = nn.functional.one_hot(
                torch.argmax(probs, dim=1), num_classes=self.out_size).float()
        log_probs += torch.log((probs * msg_oht).sum(1))
        msgs_value.append(msg_oht.argmax(1))
        lstm_input = msg_oht  # the chosen symbol drives the next step

    msgs_value = torch.stack(msgs_value).transpose(0, 1)
    logits = torch.stack(logits).transpose(0, 1).reshape(batch_size, -1)
    self.get_log_probs = log_probs
    return logits, msgs_value
class Transform3(nn.Module):  # 8-direction translation
    """Translate a single-channel image by 2 px in one of 8 random directions.

    Each of the eight frozen 5x5 convolutions holds a single 1, so it copies
    the input shifted by 2 pixels. There are four axis-aligned and four
    diagonal directions, and zero padding fills the vacated border. forward()
    picks one of them uniformly at random on every call.
    """

    # Attribute name and (row, col) of the 1 in its 5x5 kernel, in sampling order.
    _SHIFTS = (
        ('conv_left', 2, 4),
        ('conv_right', 2, 0),
        ('conv_up', 4, 2),
        ('conv_down', 0, 2),
        ('conv5', 0, 0),
        ('conv6', 0, 4),
        ('conv7', 4, 0),
        ('conv8', 4, 4),
    )

    def __init__(self):
        super(Transform3, self).__init__()
        for name, row, col in self._SHIFTS:
            kernel = torch.zeros(1, 1, 5, 5)
            kernel[0, 0, row, col] = 1.0
            conv = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
            conv.weight = nn.Parameter(kernel)
            setattr(self, name, conv)
        self.one_hot1 = OneHotCategorical(torch.Tensor([1. / 8] * 8))
        # The shift kernels are fixed, never trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        switch = self.one_hot1.sample()
        # Only the selected shift contributes to the output, so run just that
        # convolution instead of all eight weighted by a one-hot switch. This
        # also removes the hard-coded .cuda() and the 0 * inf = nan that unused
        # branches produce on non-finite input.
        name = self._SHIFTS[int(switch.argmax())][0]
        return getattr(self, name)(x)
def getNdiracs(data, N , sparse = False, flat = False, replace = True):
    """Place N random one-hot "dirac" pulses on the nodes of each graph in ``data``.

    Dense mode (``sparse=False``): returns a tensor of shape
    (data.num_nodes, data.x.shape[1], N). For row k, N node indices are drawn
    uniformly (with replacement, via OneHotCategorical) among the first
    ``data.mask[k].sum()`` positions, and column n is one-hot at the n-th draw.

    Sparse mode: for each graph k (nodes with ``data.batch == k``), N node
    indices are drawn with ``np.random.choice`` (numpy's global RNG;
    ``replace`` controls replacement). The graph is replicated N times, and
    copy n carries a single 1 on its n-th drawn node. The result is returned
    as a ``Batch`` with fields batch, x, edge_index, y, locations, norm_index,
    batch_old and edge_index_old.

    ``flat`` is accepted but not used by the active code path.
    """
    if not sparse:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): reads data.num_nodes although the comment calls it the
        # graph count; confirm this dense data format stores graphs that way.
        graphcount =data.num_nodes #number of graphs in data/batch object
        totalnodecount = data.x.shape[1] #number of total nodes for each graph
        actualnodecount = 0 #cumulative number of nodes
        diracmatrix= torch.zeros((graphcount,totalnodecount,N),device=device) #matrix with dirac pulses
        for k in range(graphcount):
            graph_nodes = data.mask[k].sum() #number of nodes in the graph
            actualnodecount += graph_nodes #might not need this, we'll see
            probabilities= torch.ones((graph_nodes.item(),1),device=device)/graph_nodes #uniform probs
            # NOTE(review): for a 1-node graph squeeze() yields a 0-d tensor,
            # which OneHotCategorical presumably rejects; confirm.
            node_distribution=OneHotCategorical(probs=probabilities.squeeze())
            node_sample= node_distribution.sample(sample_shape=(N,))
            # Diracs only ever land on the first graph_nodes positions; the rest is zero padding.
            node_sample= torch.cat((node_sample,torch.zeros((N,totalnodecount-node_sample.shape[1]),device=device)),-1) #concat zeros to fit dataset shape
            diracmatrix[k,:]= torch.transpose(node_sample,dim0=-1,dim1=-2) #add everything to the final matrix
        return diracmatrix
    else:
        original_batch_index = data.batch
        original_edge_index = data.edge_index
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        graphcount =data.num_graphs #number of graphs in data/batch object
        diracmatrix = torch.zeros(0,device=device)
        batch_prime = torch.zeros(0,device=device).long()
        locationmatrix = torch.zeros(0,device=device).long()
        global_offset = 0  # index of graph k's first node in the original node ordering
        for k in range(graphcount):
            graph_nodes = (data.batch == k).sum()
            #probabilities = torch.ones((graph_nodes.item(),1),device=device)/graph_nodes #uniform probs
            #node_distribution = OneHotCategorical(probs=probabilities.squeeze())
            #node_sample = node_distribution.sample(sample_shape=(N,))
            # if flat:
            #     diracmatrix = torch.cat((diracmatrix, node_sample.view(-1)),0)
            # else:
            #     diracmatrix = torch.cat((diracmatrix, node_sample.t(),0))
            #for diracmatrix
            randInt = np.random.choice(range(graph_nodes), N, replace = replace)
            # Flat [N * graph_nodes] block, copy-major: copy n occupies
            # [n*graph_nodes, (n+1)*graph_nodes) and has one 1 at randInt[n].
            node_sample = torch.zeros(N*graph_nodes,device=device)
            offs = torch.arange(N, device=device)*graph_nodes
            dirac_locations = (offs + torch.from_numpy(randInt).to(device))
            node_sample[dirac_locations] = 1
            # Position of each dirac in the ORIGINAL (un-replicated) node ordering.
            dirac_locations2 = torch.from_numpy(randInt).to(device) + global_offset
            global_offset += graph_nodes
            diracmatrix = torch.cat((diracmatrix, node_sample),0)
            locationmatrix = torch.cat((locationmatrix, dirac_locations2),0)
            #for batch prime
            # Copy n of graph k gets batch id k*N + n, repeated for each of its nodes.
            dirac_indices = torch.arange(N, device=device).unsqueeze(-1).expand(-1, graph_nodes).contiguous().view(-1)
            dirac_indices = dirac_indices.long()
            dirac_indices += k*N
            batch_prime = torch.cat((batch_prime, dirac_indices))
        # locationmatrix = diracmatrix.nonzero()
        # NOTE(review): this first assignment is dead (overwritten below).
        edge_index_prime = torch.arange(N).unsqueeze(-1).expand(-1,data.edge_index.shape[1]).contiguous().view(-1)*data.batch.shape[0]
        # Copy n of every edge is shifted by n * (total node count).
        # NOTE(review): that is a copy-major layout over ALL graphs, whereas x /
        # batch_prime above are graph-major (graph k's N copies are contiguous);
        # the two agree only for a single graph. Also, torch.arange(N) here is
        # created on the CPU and may not match data.edge_index's device. Verify.
        offset = torch.arange(N).unsqueeze(-1).expand(-1,data.edge_index.size()[1]).contiguous().view(-1)*data.batch.shape[0]
        offset_2 = torch.cat(2*[offset.unsqueeze(0)],dim = 0)
        edge_index_prime = torch.cat(N*[data.edge_index], dim = 1) + offset_2
        # Original graph id of each (node, copy) pair, node-major.
        normalization_indices = data.batch.unsqueeze(-1).expand(-1,N).contiguous().view(-1).to(device)
        return Batch(batch = batch_prime, x = diracmatrix, edge_index = edge_index_prime, y = data.y, locations = locationmatrix, norm_index = normalization_indices, batch_old = original_batch_index, edge_index_old = original_edge_index)
def sample_action_with_prob(self, x): likelihood = self.forward(x) m = OneHotCategorical(likelihood) action = m.sample() return action, m.log_prob(action)
def sample(self, dist_info):
    """Draw one one-hot sample from the categorical whose probabilities are ``dist_info["prob"]``."""
    return OneHotCategorical(probs=dist_info["prob"]).sample()
def forward(self, state, encoding=None, device='cpu'):
    """Advance the multilayer LSTM policy by one step and sample an action.

    - At the first time step, pass the Encoder output of shape
      (batch_size, hidden_size) as ``encoding=``. This resets every layer's
      state: layer 0's hidden state becomes ``encoding``, and all other hidden
      states and all cell states are zeros created on ``device``.
    - At the following time steps, do NOT pass ``encoding``. Each layer then
      continues from its OWN hidden/cell state of the previous step.

    The action is hierarchical. A one-hot ``decision`` picks an action
    dimension. Dimensions 0..num_actions-2 carry a Normal-distributed value;
    the last dimension ("do nothing") has value 0 and log-probability 0.

    Returns:
        decision: one-hot sample, shape (batch, num_actions).
        final_action_values: value of the chosen dimension times ``self.act_lim``,
            shape (batch,).
        log_prob: log Pr(decision) + log Pr(value | decision), shape (batch,).
    """
    if encoding is not None:
        # One zero state per layer. (`[torch.zeros(...) * num_layers]` built a
        # single tensor, leaving the deeper layers without state of their own.)
        state_shape = (self.batch_size, self.hidden_size)
        self.h_list = [torch.zeros(state_shape, device=device) for _ in range(self.num_layers)]
        self.c_list = [torch.zeros(state_shape, device=device) for _ in range(self.num_layers)]
        self.h_list[0] = encoding

    # Run the stacked cells; layer i reads its own previous (h, c), not layer 0's.
    new_h, new_c = [], []
    layer_input = state
    for i in range(self.num_layers):
        h, c = self.cell_list[i](layer_input, (self.h_list[i], self.c_list[i]))
        new_h.append(h)
        new_c.append(c)
        layer_input = h
    self.h_list = new_h
    self.c_list = new_c
    top = layer_input

    decision_logit = self.FC_decision(top)
    values_mean = self.FC_values_mean(top)
    values_std = torch.exp(self.FC_values_logstd(top))

    # Which action dimension to act on.
    m_decision = OneHotCategorical(logits=decision_logit)
    decision = m_decision.sample()
    decision_log_prob = m_decision.log_prob(decision)

    # One independent Normal per continuous dimension (an empty slice when
    # num_actions == 1, which previously crashed on torch.cat(None)).
    n_cont = self.num_actions - 1
    m_values = Normal(values_mean[:, :n_cont], values_std[:, :n_cont])
    cont_values = m_values.sample()
    cont_log_prob = m_values.log_prob(cont_values)

    # The last action ("do nothing") has value 0.0 and log-probability 0.0.
    pad = cont_values.new_zeros((cont_values.shape[0], 1))
    action_values = torch.cat([cont_values, pad], dim=1)
    actions_log_prob = torch.cat([cont_log_prob, pad], dim=1)

    # Keep only the chosen dimension, then scale the value by act_lim.
    final_action_values = (action_values * decision).sum(dim=1) * self.act_lim
    final_action_log_prob = (actions_log_prob * decision).sum(dim=1)

    # Pr(action) = Pr(choose dimension i) * Pr(value | dimension i)
    log_prob = decision_log_prob + final_action_log_prob
    return decision, final_action_values, log_prob
class Transform5(nn.Module):  # combine
    """Randomly apply one of two augmentations to a single-channel image batch.

    forward() computes
      * y1: a random translation, i.e. two successive convolutions with random
        one-hot 5x5 kernels (each shifts by up to 2 px; the centre, i.e. the
        identity, is never picked);
      * y3: the mean of ``x`` and ``dropout(sharpen(x))``, using the 3x3
        sharpening kernel below.
    It returns y1 with probability 0.6, otherwise y3.

    conv_trans3, conv_trans4, conv_smooth and relu are not used by forward();
    they are kept unchanged (e.g. for state_dict compatibility).
    """

    def __init__(self):
        super(Transform5, self).__init__()
        sharpen = torch.tensor([[[[-1., -1., -1.], [-1., 9., -1.], [-1., -1., -1.]]]])
        self.conv1 = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv1.weight = nn.Parameter(sharpen)
        self.conv_trans1 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        self.conv_trans2 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        self.conv_trans3 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        self.conv_trans4 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        self.conv_smooth = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        # 3x3 box filter. Created like every other parameter; moving the module
        # (e.g. .cuda()) moves it too, so construction no longer requires a GPU.
        self.conv_smooth.weight = nn.Parameter(torch.ones(9).view(1, 1, 3, 3) * 1 / 9.)
        self.drop = nn.Dropout(p=0.05)
        self.relu = nn.ReLU(inplace=True)
        self.one_hot1 = OneHotCategorical(torch.Tensor([0.6, 0.4]))
        # 25 kernel positions with uniform probability; index 12 (centre) excluded.
        shift_probs = torch.full((25,), 1 / 24.)
        shift_probs[12] = 0.
        self.one_hot2 = OneHotCategorical(shift_probs)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Random translation (equivalent to a random crop). Kernels follow x's
        # device/dtype and are frozen from the start so autograd never tracks them.
        k1 = self.one_hot2.sample().view(1, 1, 5, 5).to(device=x.device, dtype=x.dtype)
        k2 = self.one_hot2.sample().view(1, 1, 5, 5).to(device=x.device, dtype=x.dtype)
        self.conv_trans1.weight = nn.Parameter(k1, requires_grad=False)
        self.conv_trans2.weight = nn.Parameter(k2, requires_grad=False)
        y1 = self.conv_trans2(self.conv_trans1(x))
        # Data dropout on the sharpened image, averaged with the input.
        y3 = (self.drop(self.conv1(x)) + x) / 2.
        switch = self.one_hot1.sample().to(x.device)
        y = y1 * switch[0] + y3 * switch[1]
        # Keep every parameter frozen even if it was unfrozen externally.
        for param in self.parameters():
            param.requires_grad = False
        return y
def dcgan(
        root: xmen.Root,  # first argument is always an experiment instance.
        # It can be unused (specify with _) in experiments
        # syntax practice. It can be named whatever depending
        # on use case. Eg. logger, root, experiment ...
        b: int = 128,  # the batch size per gpu
        hw0: Tuple[int, int] = (4, 4),  # the height and width at the lowest resolution (images are hw0 * 2**nl)
        nl: int = 4,  # the number of levels in the discriminator.
        data_root: str = os.getenv("HOME") + '/data/mnist',  # @p the root data directory
        cx: int = 1,  # presumably the number of image channels (1 for MNIST); passed to both networks
        cy: int = 10,  # the dimensionality of the conditioning vector
        cf: int = 512,  # the number of features after the first conv in the discriminator
        cz: int = 100,  # the dimensionality of the noise vector
        ncpus: int = 8,  # the number of threads to use for data loading
        ngpus: int = 1,  # the number of gpus to run the model on
        epochs: int = 20,  # no. of epochs to train for
        lr: float = 0.0002,  # learning rate
        betas: Tuple[float, float] = (0.5, 0.999),  # the beta parameters for the Adam optimisers
        # monitoring parameters
        checkpoint: str = 'nn_.*@1e',  # checkpoint at this modulo string
        log: str = 'loss_.*@20s',  # log scalars
        sca: str = 'loss_.*@20s',  # tensorboard scalars
        img: str = '_x_|x$@20s',  # tensorboard images
        nimg: int = 64,  # the maximum number of images to display to tensorboard
        ns: int = 5  # the number of samples to generate at inference
):
    """Train a conditional GAN to generate MNIST digits conditioned on a class label.

    To visualise the results run::

        tensorboard --logdir ...
    """
    from xmen.monitor import TorchMonitor, TensorboardLogger
    from xmen.examples.models import weights_init, set_requires_grad, GeneratorNet, DiscriminatorNet
    from torch.distributions import Normal
    from torch.distributions.one_hot_categorical import OneHotCategorical
    from torch.optim import Adam
    import logging

    # Full image size: hw0 doubled once per level.
    hw = [d * 2**nl for d in hw0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = logging.getLogger()
    logger.setLevel('INFO')

    # dataset
    datasets = get_datasets(cy, cz, b, ngpus, ncpus, ns, data_root, hw)

    # models
    nn_g = GeneratorNet(cy, cz, cx, cf, hw0, nl)
    nn_d = DiscriminatorNet(cx, cy, cf, hw0, nl)
    nn_g = nn_g.to(device).float().apply(weights_init)
    nn_d = nn_d.to(device).float().apply(weights_init)

    # distributions: standard-normal noise and a uniform one-hot label prior
    pz = Normal(torch.zeros([cz]), torch.ones([cz]))
    py = OneHotCategorical(probs=torch.ones([cy]) / cy)

    # optimisers
    op_d = Adam(nn_d.parameters(), lr=lr, betas=betas)
    op_g = Adam(nn_g.parameters(), lr=lr, betas=betas)

    # monitor
    # NOTE(review): the checkpoint/log/img patterns ('nn_.*', 'loss_.*',
    # '_x_|x$', '_xi_$') appear to select local variables of this function by
    # name, so renaming nn_g/nn_d/loss_*/x/_x_/_xi_ would presumably silently
    # stop checkpointing/logging; confirm against xmen.monitor.
    m = TorchMonitor(root.directory, ckpt=checkpoint, log=log, sca=sca, img=img, time=('@20s', '@1e'), msg='root@100s', img_fn=lambda x: x[:min(nimg, x.shape[0])], hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])

    for _ in m(range(epochs)):
        # (1) train
        for x, y in m(datasets['train']):
            # process input
            x, y = x.to(device), y.to(device).float()
            b = x.shape[0]  # actual batch size (the last batch may be smaller)
            # discriminator step: real (x, y) vs. generated samples conditioned on the same labels
            set_requires_grad([nn_d], True)
            op_d.zero_grad()
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            # presumably the boolean is the real/fake target of the discriminator loss
            loss_d = nn_d((x, y), True) + nn_d(
                (_x_.detach(), y.detach()), False)
            loss_d.backward()
            op_d.step()
            # generator step: fresh labels drawn uniformly from py, fresh noise
            op_g.zero_grad()
            y = py.sample([b]).reshape([b, cy, 1, 1]).to(device)
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            # presumably freezes D so loss_g.backward() only produces generator gradients
            set_requires_grad([nn_d], False)
            loss_g = nn_d((_x_, y), True)
            loss_g.backward()
            op_g.step()
        # (2) inference
        if 'inference' in datasets:
            with torch.no_grad():
                for yi, zi in datasets['inference']:
                    yi, zi = yi.to(device), zi.to(device)
                    _xi_ = nn_g(yi, zi)