def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    """Deals with the instability of the gumbel_softmax for older versions of torch.

    For more details about the issue:
    https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing

    Args:
        logits: [..., num_features] unnormalized log probabilities
        tau: non-negative scalar temperature
        hard: if True, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
        dim (int): a dimension along which softmax will be computed. Default: -1.

    Returns:
        Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
    """
    if version.parse(torch.__version__) < version.parse("1.2.0"):
        for _ in range(10):
            transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
            if not torch.isnan(transformed).any():
                return transformed
        raise ValueError("gumbel_softmax returning NaN.")

    return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
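# Hypothetical usage sketch for _gumbel_softmax above (not part of the original
# source); assumes the same `torch` / `packaging.version` / `functional` imports.
logits = torch.randn(4, 10)                          # [batch, num_features]
soft = _gumbel_softmax(logits, tau=0.5)              # relaxed samples
hard = _gumbel_softmax(logits, tau=0.5, hard=True)   # straight-through one-hot
assert torch.allclose(hard.sum(dim=-1), torch.ones(4))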
def forward(self, input, hidden0, previous_head, embedded_stack, temperature):
    if not isinstance(temperature, tuple):
        temperature = (temperature, temperature, temperature)

    output, (ht, ct) = self.lstm(input, hidden0)
    d_select = F.gumbel_softmax(self.selectlinear(output.view(-1)), tau=temperature[2])
    read = (embedded_stack[:, 0, :] * previous_head.view(-1, 1).repeat(1, self.M_size[2])).sum(dim=0)
    ht = self.tanh(ht + self.tanh(self.plinear(read.view(1, -1)).view(1, 1, -1)))

    decision_input = output.view(-1)
    d_emb = F.gumbel_softmax(self.emblinear(decision_input), tau=temperature[0])
    d_cur = F.gumbel_softmax(self.curlinear(decision_input), tau=temperature[1])
    y = self.olinear(output.view(-1))

    # Push/pop along the embedding axis of the stack, mixed by d_emb.
    estack_symb = torch.zeros([1, self.M_size[1], self.M_size[2]]).to(device)
    estack_symb[0, 0, :] = self.sigmoid(self.esymblinear(ht)).view(-1)
    stack_symb = self.sigmoid(self.esymblinear(ht)).view(1, 1, -1).repeat(self.M_size[0], 1, 1)
    emb_push = torch.cat([estack_symb, embedded_stack[0:self.M_size[0] - 1, :, :]], 0)
    emb_pop = torch.cat([embedded_stack[1:self.M_size[0], :, :],
                         torch.zeros([1, self.M_size[1], self.M_size[2]]).to(device)], 0)
    embedded_stack_1 = emb_push * d_emb[0] + embedded_stack * d_emb[1] + emb_pop * d_emb[2]
    #embedded_stack_1 = embedded_stack

    # Push/pop along the cell axis, applied only under the current head.
    stack_push = torch.cat([stack_symb, embedded_stack_1[:, 0:self.M_size[1] - 1, :]], 1)
    stack_pop = torch.cat([embedded_stack_1[:, 1:self.M_size[1], :],
                           torch.zeros([self.M_size[0], 1, self.M_size[2]]).to(device)], 1)
    embedded_stack_2 = stack_push * d_cur[0] + embedded_stack_1 * d_cur[1] + stack_pop * d_cur[2]
    new_embedded_stack = (embedded_stack_1 * (1 - previous_head.view(-1, 1, 1).repeat([1, self.M_size[1], self.M_size[2]]))
                          + embedded_stack_2 * previous_head.view(-1, 1, 1).repeat([1, self.M_size[1], self.M_size[2]]))

    # Move the head left/right/stay, mixed by d_select, then renormalize.
    shift_right = torch.cat([torch.zeros([1]).to(device), previous_head[0:self.M_size[0] - 1]], 0)
    shift_left = torch.cat([previous_head[1:self.M_size[0]], torch.zeros([1]).to(device)], 0)
    next_head = shift_right * d_select[0] + previous_head * d_select[1] + shift_left * d_select[2]
    next_head = next_head / next_head.sum()
    #next_read = (new_embedded_stack[:,0,:]*next_head.view(-1,1).repeat(1,self.M_size[2])).sum(dim=0)
    #ct = self.tanh(ct + self.tanh(self.plinear(next_read.view(1, -1)).view(1, 1, -1)))

    debug_tuple = torch.cat([d_emb, d_cur, d_select], 0)
    return y, (ht.view(1, 1, -1), ct), next_head, new_embedded_stack, debug_tuple
def postprocess(self, inputs, method, temperature=1.):
    def listify(x):
        return x if type(x) == list or type(x) == tuple else [x]

    def delistify(x):
        return x if len(x) > 1 else x[0]

    if method == 'soft_gumbel':
        softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature,
                                    hard=False).view(e_logits.size())
                   for e_logits in listify(inputs)]
    elif method == 'hard_gumbel':
        softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature,
                                    hard=True).view(e_logits.size())
                   for e_logits in listify(inputs)]
    else:
        softmax = [F.softmax(e_logits / temperature, -1) for e_logits in listify(inputs)]

    return [delistify(e) for e in softmax]
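# Hypothetical usage sketch for postprocess above (`model` and the shapes are
# illustrative, not from the original source): with 'hard_gumbel' every
# last-dim slice of each output is a straight-through one-hot sample.
node_logits = torch.randn(8, 9, 5)       # e.g. [batch, nodes, node_types]
edge_logits = torch.randn(8, 9, 9, 4)    # e.g. [batch, nodes, nodes, edge_types]
nodes, edges = model.postprocess((node_logits, edge_logits), 'hard_gumbel', temperature=1.0)
# nodes.shape == (8, 9, 5) and edges.shape == (8, 9, 9, 4); each block sums to 1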
def forward(self, x): """to understand more about this forward pass please refer to the VQVAE_v3.forward method which has much better documentation.""" enc_out = self.enc(x) # [B, (H*W)//16, n_embd] if self.codebook is not None: if self.training: softmax = F.gumbel_softmax(enc_out, tau=1., hard=True, dim=-1) else: softmax = F.softmax(enc_out, dim=-1) softmax = F.one_hot(torch.argmax(softmax, dim=-1)) quantized_inputs = einsum("bdhw,dn->bnhw", softmax, self.codebook.weight) else: if self.training: softmax = F.gumbel_softmax(enc_out, tau=1., hard=True, dim=-1) else: softmax = F.softmax(enc_out, dim=-1) softmax = F.one_hot(torch.argmax(softmax, dim=-1)) quantized_inputs = softmax encoding_ids = torch.argmax(softmax, dim=-1).view(enc_out.size(0), -1) dec_out = self.dec(quantized_inputs) loss = F.mse_loss(dec_out, x) # encoding_ids, loss, recons return encoding_ids, loss, dec_out
def sub_scheduler(self, sub_scheduler_mlp, hidden_state, agent_mask, directed=True):
    """
    Function to perform a sub-scheduler

    Arguments:
        sub_scheduler_mlp (nn.Sequential): the MLP layers in a sub-scheduler
        hidden_state (tensor): the encoded messages input to the sub-scheduler [n * hid_size]
        agent_mask (tensor): [n * 1]
        directed (bool): whether to generate directed graphs

    Return:
        adj (tensor): an adjacency matrix representing the communication graph [n * n]
    """
    # hidden_state: [n * hid_size]
    n = self.args.nagents
    hid_size = hidden_state.size(-1)

    # hard_attn_input: [n * n * (2*hid_size)]
    hard_attn_input = torch.cat([hidden_state.repeat(1, n).view(n * n, -1),
                                 hidden_state.repeat(n, 1)], dim=1).view(n, -1, 2 * hid_size)

    # hard_attn_output: [n * n * 2]
    if directed:
        hard_attn_output = F.gumbel_softmax(sub_scheduler_mlp(hard_attn_input), hard=True)
    else:
        # symmetrize the logits so the sampled graph is undirected
        hard_attn_output = F.gumbel_softmax(0.5 * sub_scheduler_mlp(hard_attn_input)
                                            + 0.5 * sub_scheduler_mlp(hard_attn_input.permute(1, 0, 2)),
                                            hard=True)

    # hard_attn_output: [n * n * 1]
    hard_attn_output = torch.narrow(hard_attn_output, 2, 1, 1)

    # agent_mask and agent_mask_transpose: [n * n]
    agent_mask = agent_mask.expand(n, n)
    agent_mask_transpose = agent_mask.transpose(0, 1)

    # adj: [n * n]
    adj = hard_attn_output.squeeze() * agent_mask * agent_mask_transpose
    return adj
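# Sketch (plain torch, not from the original source) of the pairwise feature
# construction used in sub_scheduler above: entry (i, j) of the resulting
# [n, n, 2*hid] tensor is the concatenation [h_i, h_j].
n, hid = 3, 4
h = torch.arange(n * hid, dtype=torch.float).view(n, hid)
pair = torch.cat([h.repeat(1, n).view(n * n, -1), h.repeat(n, 1)], dim=1).view(n, n, 2 * hid)
assert torch.equal(pair[1, 2], torch.cat([h[1], h[2]]))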
def forward(self, user_query, item_query, user_context, item_context,
            user_key_mask, item_key_mask, mode="Train"):
    item_query = self.transform(item_query).unsqueeze(dim=1)
    user_output, user_weights = self.attention(item_query, user_context, user_context, user_key_mask)
    user_query = self.transform(user_query).unsqueeze(dim=1)
    item_output, item_weights = self.attention(user_query, item_context, item_context, item_key_mask)

    if mode == "Test":
        user_weights = torch.argmax(user_weights, dim=-1)
        item_weights = torch.argmax(item_weights, dim=-1)
        user_tensor, item_tensor = user_output, item_output
    else:
        user_weights = F.gumbel_softmax(user_weights, hard=True)
        user_tensor = torch.bmm(user_weights.float(), user_context)
        item_weights = F.gumbel_softmax(item_weights, hard=True)
        item_tensor = torch.bmm(item_weights, item_context)

    predicted = self.activation(user_tensor * item_tensor)
    return predicted, user_weights, item_weights
def forward(self, input, discrete=False):
    # NASBench only has one input to each cell
    s0 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if i in [self._layers // 3, 2 * self._layers // 3]:
            # Perform down-sampling by factor 1/2
            # Equivalent to https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
            s0 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(s0)

        # If using a discrete architecture from random_ws search with weight sharing,
        # pass the architecture weights through directly.
        # For GDAS use hard gumbel softmax, so per mixed block only a single operation is evaluated.
        preprocess_op_mixed_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
        # Don't use hard for the rest, because it very quickly gave exploding gradients
        preprocess_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=False, dim=-1)

        # Normalize mixed_op weights for the choice blocks in the graph
        mixed_op_weights = preprocess_op_mixed_op(self._arch_parameters[0])
        # Normalize the output weights
        output_weights = preprocess_op(self._arch_parameters[1]) if self._output_weights else None
        # Normalize the input weights for the nodes in the cell
        input_weights = [preprocess_op(alpha) for alpha in self._arch_parameters[2:]]

        s0 = cell(s0, mixed_op_weights, output_weights, input_weights)

    # Include one more preprocessing step here
    s0 = self.postprocess(s0)  # [N, C_max * multiplier, w, h] -> [N, C_max, w, h]

    # Global Average Pooling by averaging over the last two remaining spatial dimensions
    # Like in nasbench: https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
    out = s0.view(*s0.shape[:2], -1).mean(-1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits
def forward(self):
    dag = torch.zeros(self.n_nodes, self.n_nodes)  # the final dag
    sampled = np.zeros(self.n_nodes, dtype=bool)   # set of nodes that children were sampled for

    # sample roots
    log_p_roots = F.logsigmoid(self.root_probs)    # numerically stable
    p_log = torch.stack((log_p_roots, torch.log(1 - torch.exp(log_p_roots))))
    roots = gumbel_softmax(p_log, hard=True, dim=0)[0]
    self.log(f'sampled roots {roots}')
    to_sample = roots.nonzero().view(-1).tolist()  # list of nodes that will get children sampled

    # sample children
    log_p_edges = F.logsigmoid(self.edge_probs)
    ancestors = torch.eye(self.n_nodes, dtype=torch.uint8)
    while to_sample:
        # take the next node to sample children for
        i = to_sample.pop(0)
        if sampled[i]:
            continue
        self.log(f'sampling children for {i}')
        # don't sample ancestors and roots as children
        candidates = (1 - ancestors[i, :].float()) * (1 - roots)
        # sample children for node i
        p_log = torch.stack((log_p_edges[i, :], torch.log(1 - torch.exp(log_p_edges[i, :]))))
        dag[i, :] = gumbel_softmax(p_log, hard=True, dim=0)[0] * candidates.float()
        for j in dag[i, :].nonzero().view(-1).tolist():
            self.log(f'sampled {j}')
            # add i to ancestors of j
            ancestors[j, i] = 1
            # add all ancestors of i to j
            ancestors[j, :][ancestors[i, :]] = 1
            to_sample.append(j)
        sampled[i] = True

    return dag
def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
    """
    Sample from prior
    :param output_sizes: Tensor([B,])
    :param hold_seed: optional int; if given, seeds the RNG so the same noise is drawn
    :param hold_initial_set: if True, use a deterministic mask instead of sampling one
    :return: Tensor([B, N, D])
    """
    bsize = output_sizes.shape[0]
    if hold_initial_set:  # [B, N]
        x_mask = get_mask(output_sizes, self.max_outputs)
    else:
        x_mask = sample_mask(output_sizes, self.max_outputs)

    if hold_seed is not None:  # [B, N, Ds]
        torch.random.manual_seed(hold_seed)
        eps = torch.randn([1, self.max_outputs, self.dim_seed]).to(x_mask.device).repeat(bsize, 1, 1)
    else:
        eps = torch.randn([bsize, self.max_outputs, self.dim_seed]).to(x_mask.device)

    if self.n_mixtures == 1:
        x = self.mu + torch.exp(self.logvar / 2.) * eps
    else:
        if self.train_gmm:
            if hold_seed is not None:
                torch.random.manual_seed(hold_seed)
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(1, self.max_outputs, 1)  # [1, N, M]
                onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True).repeat(bsize, 1, 1).unsqueeze(-1)  # [B, N, M, 1]
            else:
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(bsize, self.max_outputs, 1)  # [B, N, M]
                onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True).unsqueeze(-1)  # [B, N, M, 1]
            mu = self.mu.reshape([1, 1, self.n_mixtures, self.dim_seed])    # [1, 1, M, D]
            sig = self.sig.reshape([1, 1, self.n_mixtures, self.dim_seed])  # [1, 1, M, D]
            mu = (mu * onehot).sum(2)    # [B, N, D]
            sig = (sig * onehot).sum(2)  # [B, N, D]
            x = mu + sig * eps
        else:
            mix = D.Categorical(self.logits)
            comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
            mixture = D.MixtureSameFamily(mix, comp)
            x = mixture.sample((output_sizes.size(0), self.max_outputs))

    x = self.output(x)  # [B, N, D]
    return x, x_mask
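# Sketch (plain torch, illustrative shapes) of the mixture-component selection
# used above: multiplying per-component parameters by a hard Gumbel-Softmax
# one-hot and summing over the mixture axis picks one component per point
# while keeping the choice differentiable.
B, N, M, D = 2, 7, 5, 3
mu = torch.randn(1, 1, M, D)                                                        # per-component means
onehot = F.gumbel_softmax(torch.zeros(B, N, M), tau=1.0, hard=True).unsqueeze(-1)   # [B, N, M, 1]
picked_mu = (mu * onehot).sum(2)                                                    # [B, N, D]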
def apply_activate(data, tanh_list, soft_list):
    temp = torch.sigmoid(data[:, tanh_list[0]:(tanh_list[-1] + 1)])
    for i in range(len(soft_list) - 1):
        tem_soft = F.gumbel_softmax(data[:, soft_list[i]:soft_list[i + 1]], tau=0.2)
        temp = torch.cat([temp, tem_soft], 1)
    tem_soft = F.gumbel_softmax(data[:, soft_list[-1]:], tau=0.2)
    temp = torch.cat([temp, tem_soft], 1)
    return temp
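# Hypothetical usage sketch for apply_activate above (the column layout is
# illustrative): the sigmoid covers the contiguous continuous columns, and each
# categorical block is sampled with a low-temperature Gumbel-Softmax.
raw = torch.randn(32, 10)     # generator output
tanh_list = [0, 1]            # continuous columns 0..1
soft_list = [2, 6]            # categorical blocks [2:6) and [6:10)
activated = apply_activate(raw, tanh_list, soft_list)  # [32, 10]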
def parse_gumbel(alpha, beta, k):
    """
    parse continuous alpha to discrete gene.
    alpha is ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]
    beta is ParameterList:
    ParameterList [
        Parameter(n_edges1),
        Parameter(n_edges2),
        ...
    ]

    gene is list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    each node has two edges (k=2) in CNN.
    """
    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume last PRIMITIVE is 'none'

    # 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
    # 2) Choose top-k edges per node by edge score (top-1 weight in edge)
    # output the connect idx [(node_idx, connect_idx, op_idx), ..., (), ()]
    connect_idx = []
    for edges, w in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops)
        discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        for i in range(k - 1):
            discrete_a = discrete_a + F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, len(PRIMITIVES) - 1)
        reserved_edge = (discrete_a > 0).nonzero()

        node_gene = []
        node_idx = []
        for i in range(reserved_edge.shape[0]):
            edge_idx = reserved_edge[i][0].item()
            prim_idx = reserved_edge[i][1].item()
            prim = PRIMITIVES[prim_idx]
            node_gene.append((prim, edge_idx))
            node_idx.append((edge_idx, prim_idx))

        gene.append(node_gene)
        connect_idx.append(node_idx)

    return gene, connect_idx
def select_action_old(self, obs, valid_actions):
    '''
    from logits to pysc2 actions
    :param obs: observation dict with 'categorical', 'screen1', 'screen2' entries
    :param valid_actions: mask of currently available actions
    :return: FunctionCall form of action
    '''
    obs_torch = {'categorical': 0, 'screen1': 0, 'screen2': 0}
    for o in obs:
        x = obs[o].astype('float32')
        x = np.expand_dims(x, 0)
        obs_torch[o] = torch.from_numpy(x).to(arglist.DEVICE)

    logits = self.actor(obs_torch)
    logits['categorical'] = self._mask_unavailable_actions(logits['categorical'], valid_actions)

    tau = 1.0
    function_id = gumbel_softmax(logits['categorical'], tau=1e-10, hard=True)
    function_id = function_id.argmax().item()

    # resample until the action is valid, flattening the distribution on each retry
    is_valid_action = self._test_valid_action(function_id, valid_actions)
    while not is_valid_action:
        tau *= 10
        function_id = gumbel_softmax(logits['categorical'], tau=tau, hard=True)
        function_id = function_id.argmax().item()
        is_valid_action = self._test_valid_action(function_id, valid_actions)

    pos_screen1 = gumbel_softmax(logits['screen1'].view(1, -1), hard=True).argmax().item()
    pos_screen2 = gumbel_softmax(logits['screen2'].view(1, -1), hard=True).argmax().item()

    pos = [[int(pos_screen1 % arglist.FEAT2DSIZE), int(pos_screen1 // arglist.FEAT2DSIZE)],
           [int(pos_screen2 % arglist.FEAT2DSIZE), int(pos_screen2 // arglist.FEAT2DSIZE)]]  # (x, y)

    args = []
    cnt = 0
    for arg in actions.FUNCTIONS[function_id].args:
        if arg.name in ['screen', 'screen2', 'minimap']:
            args.append(pos[cnt])
            cnt += 1
        else:
            args.append([0])

    action = actions.FunctionCall(function_id, args)
    return action
def forward(self, x, y_, adj, non_label):
    if self.training:
        x = x.contiguous().view(-1, self.x_dim)
        y_ = y_.contiguous().view(-1, self.y_dim)

        # x2y
        y_encode, y_embedding = self.x_to_yu(x)
        q_dis_total, y_total, y_pred_total_total = [], [], []
        for i in range(1):
            y = y_.clone()
            y[non_label] = F.gumbel_softmax(y_encode[non_label], tau=1.0, hard=True)
            y_total.append(y)
            # encode
            r_nodes = self.xy_to_r(y_embedding, y)
            r_graph = self.r_aggregate(r_nodes)
            mu, sigma = self.r_to_musigma(r_graph)
            q_dis = Normal(mu, sigma)
            q_dis_total.append(q_dis)

            y_pred_total = []
            for _ in range(1):
                z_sample = q_dis.rsample()
                # decode
                y_pred = self.x_to_y(x, z_sample)
                y_pred_total.append(y_pred)
            y_pred_total_total.append(y_pred_total)

        return y_pred_total_total, q_dis_total, y_total, y_encode
    else:
        x = x.contiguous().view(-1, self.x_dim)
        y_ = y_.contiguous().view(-1, self.y_dim)
        y_encode, y_embedding = self.x_to_yu(x)

        y_pred_total = []
        for i in range(40):
            y = y_.clone()
            y[non_label] = F.gumbel_softmax(y_encode[non_label], tau=1.0, hard=True)
            # encode
            r_nodes = self.xy_to_r(y_embedding, y)
            r_graph = self.r_aggregate(r_nodes)
            mu, sigma = self.r_to_musigma(r_graph)
            q_dis = Normal(mu, sigma)
            for _ in range(1):
                z_sample = q_dis.rsample()
                # decode
                y_pred = self.x_to_y(x, z_sample)
                y_pred_total.append(y_pred)

        y_pred = sum(y_pred_total) / len(y_pred_total)
        return y_pred, y_encode
def encode(self, input, tau=1):
    enc_b = self.enc_b(input)
    enc_t = self.enc_t(enc_b)
    quant_t = self.quantize_conv_t(enc_t)
    latent = F.gumbel_softmax(quant_t, tau=tau, hard=True, dim=1)
    latent_distribution = F.gumbel_softmax(quant_t, tau=tau, hard=False, dim=1)
    return latent, latent_distribution
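# Hypothetical usage sketch for encode above (`model` and the input shape are
# assumptions): the hard sample is one-hot over the codebook dimension (dim=1)
# at every spatial position; the soft sample is its differentiable relaxation.
img = torch.randn(2, 3, 64, 64)
latent, latent_distribution = model.encode(img, tau=0.5)
# latent.sum(dim=1) is 1 everywhere, and so is latent_distribution.sum(dim=1)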
def forward(self, x, temp=1):
    x = F.relu(self.fc(x))
    # x = F.relu(self.fc_(x))
    action_score = self.a_head(x)
    # z = torch.nn.functional.one_hot(prob.max(1)[1], num_classes=self.categorical_dim).view(-1, self.categorical_dim)
    return (F.gumbel_softmax(action_score, hard=True, dim=-1, tau=temp),
            F.gumbel_softmax(action_score, hard=False, dim=-1, tau=temp))
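# Sketch (plain torch) of the straight-through behaviour the hard sample above
# relies on: with hard=True the forward output is one-hot, but gradients flow
# through the underlying soft sample.
scores = torch.randn(1, 4, requires_grad=True)
hard = F.gumbel_softmax(scores, tau=0.5, hard=True, dim=-1)
loss = (hard * torch.arange(4.0)).sum()  # stand-in downstream loss
loss.backward()
print(scores.grad)                       # non-trivial gradient despite discrete output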
def forward(self, x):
    if config().sim.env.state.type == "simple":
        x = x.reshape(x.size(0), x.size(2))
        return F.gumbel_softmax(self.simple_fc(x), tau=config().learning.gumbel_softmax.tau)
    out = self.conv(x)
    out = out.view(x.size(0), -1)
    return F.gumbel_softmax(self.fc(out), tau=config().learning.gumbel_softmax.tau)
def sample_search(self):
    result = dict()
    for mutable in self.mutables:
        if isinstance(mutable, LayerChoice):
            # result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1).bool()[:-1]
            result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1).bool()
        elif isinstance(mutable, InputChoice):
            result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1).bool()
    return result
def learn(self):
    self.learn_step += 1
    sample_index = np.random.choice(self.memory_capacity, self.batch_size)
    batch_memory = self.memory[sample_index, :]
    # memory layout: the first state_num columns are the current state, the next
    # action_num columns the action, the following column the reward, and the
    # final state_num columns the next state
    batch_s = Variable(torch.FloatTensor(batch_memory[:, :self.state_num])).to(self.device)
    batch_a = Variable(torch.LongTensor(
        batch_memory[:, self.state_num:self.state_num + self.action_num])).to(self.device)
    batch_r = Variable(torch.FloatTensor(
        batch_memory[:, self.state_num + self.action_num:self.state_num + self.action_num + 1])).to(self.device)
    batch_next_s = Variable(torch.FloatTensor(batch_memory[:, -self.state_num:])).to(self.device)

    # critic update
    batch_next_a_logits = self.actor_target(batch_next_s)
    batch_target_next_a = F.gumbel_softmax(batch_next_a_logits, dim=-1)
    y_true = batch_r + self.gamma * self.critic_target(batch_next_s, batch_target_next_a).detach()
    y_pred = self.critic(batch_s, batch_a.float())
    critic_loss = self.loss(y_pred, y_true)
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()

    # actor update
    batch_a_logits = self.actor(batch_s)
    batch_target_a = F.gumbel_softmax(batch_a_logits, dim=-1)
    actor_loss = -torch.mean(self.critic(batch_s, batch_target_a))
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()

    # soft (Polyak) update of the target networks
    for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)
    for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)

    return critic_loss.item(), actor_loss.item()
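# Standalone sketch (plain torch) of the Polyak soft update used at the end of
# learn(): target <- (1 - tau) * target + tau * online, applied parameter-wise.
import torch.nn as nn

online, target = nn.Linear(2, 2), nn.Linear(2, 2)
tau = 0.01
for tp, p in zip(target.parameters(), online.parameters()):
    tp.data.copy_(tp.data * (1.0 - tau) + p.data * tau)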
def forward(self, input):
    batch, C, H, W = input.size()
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if cell.reduction:
            weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True)
        else:
            weights = F.gumbel_softmax(self.alphas_normal, self.tau, True)
        s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits
def train(train_type):
    if train_type == 0:  # plain VGAE training
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        model.psi.requires_grad = False
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_recon = model.recon_loss(z, train_pos_edge_index)
        l_kl_c = 0
        loss = l_recon + l_kl_z
        loss.backward()
        optimizer.step()
    elif train_type == 1:  # train only the community parameters psi
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        model.psi.requires_grad = True
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        pc_given_Z, qc_given_ZA = model.community_dists_probs(z, train_pos_edge_index)
        c = F.gumbel_softmax(qc_given_ZA.logits, tau=1, hard=True)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_recon = model.recon_loss((z, c), train_pos_edge_index)
        l_kl_c = 1.0 * kl_divergence(qc_given_ZA, pc_given_Z).mean()
        loss = l_recon + l_kl_z
        loss.backward()
        optimizer.step()
    else:  # train everything jointly
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        pc_given_Z, qc_given_ZA = model.community_dists_probs(z, train_pos_edge_index)
        c = F.gumbel_softmax(qc_given_ZA.logits, tau=1, hard=True)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_kl_c = 1.0 * kl_divergence(qc_given_ZA, pc_given_Z).mean()
        l_recon = model.recon_loss((z, c), train_pos_edge_index)
        loss = l_recon + l_kl_z + l_kl_c
        loss.backward()
        optimizer.step()
    return l_recon, l_kl_z, l_kl_c
def get_weight(self, tau):
    if self.mode == "softmax":
        weight_normal = f.softmax(self.alpha_normal / tau, dim=-1)
        weight_reduce = f.softmax(self.alpha_reduce / tau, dim=-1)
    elif self.mode == "sigmoid":
        weight_normal = f.sigmoid(self.alpha_normal / tau)
        weight_reduce = f.sigmoid(self.alpha_reduce / tau)
    elif self.mode == "gumbel":  # assumed mode name; the original duplicated the "sigmoid" condition here, making this branch unreachable
        weight_normal = f.gumbel_softmax(self.alpha_normal, tau, dim=-1)
        weight_reduce = f.gumbel_softmax(self.alpha_reduce, tau, dim=-1)
    else:
        raise NotImplementedError(f"{self.mode} not implemented.")
    return weight_normal, weight_reduce
def forward(self, x, decode=True):
    params = self._get_dists_params(x)

    zs = []
    if 'cont' in params.keys():
        zs = [self._reparam_gauss(*params['cont'])]
    if 'cat' in params.keys():
        for logits in params['cat']:
            if self.training:
                zs.append(F.gumbel_softmax(logits, tau=self.temp))
            else:
                zs.append(F.gumbel_softmax(logits, tau=self.temp, hard=True))
    z = torch.cat(zs, dim=1)

    if decode:
        recon = self.decoder(z)
    else:
        recon = None
    return recon, z, params
def forward(self, input, hidden, encoder_outputs, mask):
    embedded = self.embedding(input).permute(1, 0, 2)
    embedded = self.dropout(embedded)
    output, hidden = self.gru(embedded, hidden)

    # encoder_outputs: (max_length, batch_size, dim); hidden[-1]: (batch_size, dim)
    encoder_outputs_s = encoder_outputs.transpose(0, 1)  # (batch_size, max_len, dim)

    # attention scores: (batch_size, max_len)
    score = torch.bmm(encoder_outputs_s, hidden[-1].unsqueeze(2)).squeeze(2)
    score = score.masked_fill(mask == 0, -1e10)
    attn_weights = F.softmax(score, dim=1)
    attn_weights = attn_weights.unsqueeze(2)

    # context vector: (batch_size, dim, max_len) x (batch_size, max_len, 1)
    c_t = torch.bmm(encoder_outputs_s.transpose(1, 2), attn_weights)

    W_hc_t = self.Wh(c_t.squeeze(2))
    U_hh_t = self.Uh(hidden[-1])
    g = self.out(torch.tanh(U_hh_t + W_hc_t)).unsqueeze(0)

    g_ls = F.log_softmax(g, dim=2)
    cba = F.gumbel_softmax(g, tau=0.9, hard=True, eps=1e-10, dim=2)
    return cba, g_ls, hidden, attn_weights
def forward(self, x_f, y_f, y_c):
    """
    Computes the correspondences in the feature space based on the selected parameters.

    Args:
        x_f (torch.tensor): inferred features of points x [b,n,c]
        y_f (torch.tensor): inferred features of points y [b,m,c]
        y_c (torch.tensor): coordinates of points y [b,m,3]

    Returns:
        x_corr (torch.tensor): coordinates of the feature-based correspondences of points x [b,n,3]
    """
    dist = pairwise_distance(x_f, y_f).detach()

    if self.corr_type == 'soft':
        y_soft = torch.softmax(-dist / self.get_temp(), dim=2)
        if self.st:
            # straight through
            index = y_soft.max(dim=2, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft).scatter_(dim=2, index=index, value=1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            ret = y_soft
    elif self.corr_type == 'soft_gumbel':
        if self.st:
            # straight through
            ret = F.gumbel_softmax(-dist, tau=self.get_temp(), hard=True)
        else:
            ret = F.gumbel_softmax(-dist, tau=self.get_temp(), hard=False)
    else:
        index = dist.min(dim=2, keepdim=True)[1]
        ret = torch.zeros_like(dist).scatter_(dim=2, index=index, value=1.0)

    # compute corresponding coordinates
    x_corr = torch.matmul(ret, y_c)
    return x_corr
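# Hypothetical usage sketch for forward above (`matcher` is an assumption):
# each row of x_corr is a convex combination of y coordinates weighted by the
# (relaxed or hard) correspondence matrix.
x_f = torch.randn(2, 100, 32)    # [b, n, c] features of x
y_f = torch.randn(2, 120, 32)    # [b, m, c] features of y
y_c = torch.randn(2, 120, 3)     # [b, m, 3] coordinates of y
x_corr = matcher(x_f, y_f, y_c)  # [2, 100, 3]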
def forward(self, x):
    xs = tuple(layer(x) for layer in self.layers)
    logits = tuple(F.log_softmax(h, dim=1) for h in xs)
    categorical_outputs = tuple(
        F.gumbel_softmax(logit, tau=self.tau, hard=True, eps=1e-10)
        for logit in logits)
    return torch.cat(categorical_outputs, 1)
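# Sketch (plain torch) of why feeding log-probabilities into gumbel_softmax,
# as above, is sound: log_softmax only subtracts a per-row constant, and
# Gumbel-Softmax is invariant to per-row constant shifts of its logits.
logits = torch.randn(3, 5)
torch.manual_seed(0)
a = F.gumbel_softmax(F.log_softmax(logits, dim=1), tau=1.0, hard=True)
torch.manual_seed(0)
b = F.gumbel_softmax(logits, tau=1.0, hard=True)
assert torch.allclose(a, b)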
def forward(self, hidden_vec):
    hs = F.softplus(self.l1(hidden_vec))
    logit = torch.log(hs + 1e-08).view(-1, self.M, self.K).view(-1, self.K)
    probs = F.gumbel_softmax(logit, self.tau).view(-1, self.M * self.K)  # probs: (batch_size, M * K)
    code_sum = torch.matmul(probs, self.codebook)
    return code_sum
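# Sketch (plain torch, illustrative sizes) of the reshape trick above:
# flattening [B, M, K] to [B*M, K] makes F.gumbel_softmax sample each of the
# M codebooks independently; every K-sized block of the result sums to 1.
B, M, K = 8, 4, 16
logit = torch.randn(B, M, K)
probs = F.gumbel_softmax(logit.view(-1, K), tau=1.0).view(B, M * K)
assert torch.allclose(probs.view(B, M, K).sum(-1), torch.ones(B, M))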
def calculate_block_probability(self, arch_param, tau):
    """Encode arch param to probability for generator training."""
    arch_param = arch_param.view(len(self.CONFIG.l_cfgs),
                                 self.CONFIG.split_blocks * self.CONFIG.kernels_nums)
    p_arch_param = torch.zeros_like(arch_param)

    for l_num, (l_cfg, l, p_l) in enumerate(zip(self.CONFIG.l_cfgs, arch_param, p_arch_param)):
        expansion, output_channel, kernels, stride, split_block, se = l_cfg
        for b in range(expansion):
            if b == 0 and l_num in self.CONFIG.static_layers:
                end_index = (b + 1) * split_block - 1
                split_arch_param = l[b * split_block:(b + 1) * split_block - 1]
            else:
                end_index = (b + 1) * split_block
                split_arch_param = l[b * split_block:(b + 1) * split_block]
            p_l[b * split_block:end_index] = F.gumbel_softmax(split_arch_param, tau=tau)

    return p_arch_param
def forward(self, ques):
    # input
    #   ques - shape: (batch_size, num_rounds, word_embedding_size)
    # output
    #   ques_gs - shape: (batch_size, num_rounds, 2)
    #   Lambda - shape: (batch_size, num_rounds, 2)
    batch_size = ques.size(0)
    num_rounds = ques.size(1)

    ques_embed = self.embed(ques)  # shape: (batch_size, num_rounds, lstm_hidden_size)
    ques_embed = F.normalize(ques_embed, p=2, dim=-1)  # shape: (batch_size, num_rounds, lstm_hidden_size)
    ques_logits = self.att(ques_embed)  # shape: (batch_size, num_rounds, 2)
    logits = ques_logits.view(-1, 2)

    if self.training:
        ques_gs = F.gumbel_softmax(logits, hard=True)  # shape: (batch_size * num_rounds, 2)
    else:
        # at eval time, take a deterministic one-hot of the argmax instead of sampling
        _, max_value_indexes = logits.detach().max(1, keepdim=True)
        ques_gs = logits.detach().clone().zero_().scatter_(1, max_value_indexes, 1)

    ques_gs = ques_gs.view(-1, num_rounds, 2)
    Lambda = self.softmax(ques_logits)
    return ques_gs, Lambda  # discrete, continuous
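# Sketch (plain torch) of the eval-time branch above: the scatter_
# construction is exactly a one-hot of the row-wise argmax.
logits = torch.randn(6, 2)
_, idx = logits.detach().max(1, keepdim=True)
onehot = logits.detach().clone().zero_().scatter_(1, idx, 1)
assert torch.equal(onehot, F.one_hot(logits.argmax(1), num_classes=2).float())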
def input_to_sentence(self, x_hot):
    batch_size = x_hot.size()[0]
    h = self.i_h(x_hot)
    sr = self.s_r(h)
    # c = self.i_h(x_hot)
    c = torch.zeros_like(h)  # if lstm

    input_word = torch.zeros(batch_size, self.vocab_size, device=x_hot.device)
    output_words = torch.zeros(self.max_seq_len, batch_size, self.vocab_size, device=x_hot.device)
    output_scores = torch.zeros(self.max_seq_len, batch_size, self.vocab_size, device=x_hot.device)

    for t in range(self.max_seq_len):
        h = self.sender_grucell(input_word, h)
        # h, c = self.sender_lstmcell(input_word, (h, c))  # if lstm
        output_score = self.h_w(h)
        output_scores[t] = F.log_softmax(output_score, dim=1)
        if self.eval_mode:
            output_word = torch.eye(output_score.size()[1])[torch.argmax(output_score, dim=1)].to(device=x_hot.device)
        else:
            output_word = F.gumbel_softmax(output_score, hard=True, tau=self.tau)
        output_words[t] = output_word
        # input_word = output_word.detach()  # what if we don't detach?
        input_word = output_word

    return output_words, output_scores, sr
def forward(self, thetas):
    device = thetas.device
    x = thetas.permute(0, 2, 1)  # N x 6000 x T -> N x T x 6000
    # x = torch.log(x)
    x = (F.gumbel_softmax(x, hard=True) if self.phase == "train"
         else self.softmax(x * 1e9))
    indices = torch.arange(6000, device=device).float()  # (n_bpm)
    softargmax = torch.matmul(x, indices)  # N x T
    thetas = softargmax * 2 * math.pi / 6000

    batch_size = thetas.size()[0]
    zero = torch.zeros(batch_size, 1, device=device)
    thetas_t_1 = torch.cat([zero, thetas], dim=1)[:, :-1]
    diff1 = self.relu(thetas - thetas_t_1)
    diff2 = 2 * math.pi - self.relu(thetas_t_1 - thetas)  # N x T
    diff = torch.stack([diff1, diff2], dim=2)
    delta_beattheta, _ = torch.min(diff, dim=2)  # N x T

    # median filter over a sliding window along time
    kernel_size = 21
    padding = int((kernel_size - 1) / 2)
    delta_beattheta = delta_beattheta.unsqueeze(1)
    delta_beattheta = F.pad(delta_beattheta, (padding, padding), "reflect")
    delta_beattheta = delta_beattheta.unfold(-1, kernel_size, 1)  # N x T x kernel
    delta_beattheta, _ = torch.median(delta_beattheta, dim=-1)  # N x T
    delta_beattheta = delta_beattheta.squeeze(1)
    return delta_beattheta
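# Sketch (plain torch) of the soft-argmax above: a hard one-hot (or a sharply
# peaked softmax) matrix-multiplied with the index vector recovers the argmax.
onehot = F.gumbel_softmax(torch.randn(2, 4, 10), hard=True)  # [N, T, 10]
indices = torch.arange(10).float()
soft_argmax = torch.matmul(onehot, indices)                  # [N, T]
assert torch.equal(soft_argmax, onehot.argmax(-1).float())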