Example #1
    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        """Deals with the instability of the gumbel_softmax for older versions of torch.

        For more details about the issue:
        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
        Args:
            logits:
                […, num_features] unnormalized log probabilities
            tau:
                non-negative scalar temperature
            hard:
                if True, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
            dim (int):
                a dimension along which softmax will be computed. Default: -1.
        Returns:
            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
        """
        if version.parse(torch.__version__) < version.parse("1.2.0"):
            for i in range(10):
                transformed = functional.gumbel_softmax(logits,
                                                        tau=tau,
                                                        hard=hard,
                                                        eps=eps,
                                                        dim=dim)
                if not torch.isnan(transformed).any():
                    return transformed
            raise ValueError("gumbel_softmax returning NaN.")

        return functional.gumbel_softmax(logits,
                                         tau=tau,
                                         hard=hard,
                                         eps=eps,
                                         dim=dim)
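A minimal usage sketch (not from the original project) of the wrapper above. It assumes the helper is available at module scope as _gumbel_softmax, together with the torch, packaging.version, and torch.nn.functional imports it relies on:

    import torch

    logits = torch.randn(4, 10)                         # [batch, num_features] unnormalized log-probs
    soft = _gumbel_softmax(logits, tau=0.5)             # differentiable soft sample
    hard = _gumbel_softmax(logits, tau=0.5, hard=True)  # one-hot forward pass, soft gradients
    assert hard.shape == logits.shape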
Example #2
 def forward(self, input, hidden0, previous_head, embedded_stack, temperature):
     if type(temperature) != tuple:
         temperature = (temperature, temperature, temperature)
     output, (ht, ct) = self.lstm(input, hidden0)
     d_select = F.gumbel_softmax(self.selectlinear(output.view(-1)),tau=temperature[2])
     
     read = (embedded_stack[:,0,:]*previous_head.view(-1,1).repeat(1,self.M_size[2])).sum(dim=0)
     ht = self.tanh(ht + self.tanh(self.plinear(read.view(1, -1)).view(1, 1, -1)))
     decision_input = output.view(-1)
     d_emb = F.gumbel_softmax(self.emblinear(decision_input), tau=temperature[0])
     d_cur = F.gumbel_softmax(self.curlinear(decision_input), tau=temperature[1])
     y = self.olinear(output.view(-1))
     
     estack_symb = torch.zeros([1,self.M_size[1],self.M_size[2]]).to(device)
     estack_symb[0,0,:] = self.sigmoid(self.esymblinear(ht)).view(-1)
     stack_symb = self.sigmoid(self.esymblinear(ht)).view(1,1,-1).repeat(self.M_size[0],1,1)
     
     emb_push = torch.cat([estack_symb,embedded_stack[0:self.M_size[0]-1,:,:]],0)
     emb_pop = torch.cat([embedded_stack[1:self.M_size[0],:,:],torch.zeros([1,self.M_size[1],self.M_size[2]]).to(device)],0)
     embedded_stack_1 = emb_push*d_emb[0] + embedded_stack*d_emb[1] + emb_pop*d_emb[2]
     #embedded_stack_1 = embedded_stack
     stack_push = torch.cat([stack_symb,embedded_stack_1[:,0:self.M_size[1]-1,:]],1)
     stack_pop = torch.cat([embedded_stack_1[:,1:self.M_size[1],:],torch.zeros([self.M_size[0],1,self.M_size[2]]).to(device)],1)
     embedded_stack_2 = stack_push*d_cur[0] + embedded_stack_1*d_cur[1] + stack_pop*d_cur[2]
     new_embedded_stack = embedded_stack_1 * (1 - previous_head.view(-1,1,1).repeat([1,self.M_size[1],self.M_size[2]])) + embedded_stack_2 * previous_head.view(-1,1,1).repeat([1,self.M_size[1],self.M_size[2]])
     
     shift_right = torch.cat([torch.zeros([1]).to(device),previous_head[0:self.M_size[0]-1]],0) 
     shift_left = torch.cat([previous_head[1:self.M_size[0]],torch.zeros([1]).to(device)],0)
     next_head = shift_right*d_select[0] + previous_head*d_select[1] + shift_left*d_select[2]
     next_head = next_head/next_head.sum()
     
     #next_read = (new_embedded_stack[:,0,:]*next_head.view(-1,1).repeat(1,self.M_size[2])).sum(dim=0)
     #ct = self.tanh(ct + self.tanh(self.plinear(next_read.view(1, -1)).view(1, 1, -1)))
     
     debug_tuple=torch.cat([d_emb,d_cur,d_select],0)
     return y, (ht.view(1,1,-1),ct), next_head,new_embedded_stack,debug_tuple
Example #3
    def postprocess(self, inputs, method, temperature=1.):
        def listify(x):
            return x if type(x) == list or type(x) == tuple else [x]

        def delistify(x):
            return x if len(x) > 1 else x[0]

        if method == 'soft_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=False).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        elif method == 'hard_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=True).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        else:
            softmax = [
                F.softmax(e_logits / temperature, -1)
                for e_logits in listify(inputs)
            ]

        return [delistify(e) for e in (softmax)]
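A minimal sketch of what the 'hard_gumbel' branch above does to a single logits tensor, written outside the class with made-up shapes:

    import torch
    import torch.nn.functional as F

    e_logits = torch.randn(2, 5, 7)     # e.g. [batch, nodes, edge_types]
    temperature = 0.5
    flat = e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature
    hard = F.gumbel_softmax(flat, hard=True).view(e_logits.size())
    # hard is one-hot along the last dim in the forward pass, soft in the backward pass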
Example #4
  def forward(self, x):
    """to understand more about this forward pass please refer to the VQVAE_v3.forward
    method which has much better documentation."""
    enc_out = self.enc(x)  # [B, (H*W)//16, n_embd]
    if self.codebook is not None:
      if self.training:
        softmax = F.gumbel_softmax(enc_out, tau=1., hard=True, dim=-1)
      else:
        softmax = F.softmax(enc_out, dim=-1)
        softmax = F.one_hot(torch.argmax(softmax, dim=-1))
      quantized_inputs = einsum("bdhw,dn->bnhw", softmax, self.codebook.weight)
    else:
      if self.training:
        softmax = F.gumbel_softmax(enc_out, tau=1., hard=True, dim=-1)
      else:
        softmax = F.softmax(enc_out, dim=-1)
        softmax = F.one_hot(torch.argmax(softmax, dim=-1))
      quantized_inputs = softmax
    
    encoding_ids = torch.argmax(softmax, dim=-1).view(enc_out.size(0), -1)
    dec_out = self.dec(quantized_inputs)
    loss = F.mse_loss(dec_out, x)

    # encoding_ids, loss, recons
    return encoding_ids, loss, dec_out
Example #5
    def sub_scheduler(self, sub_scheduler_mlp, hidden_state, agent_mask, directed=True):
        """
        Function to perform a sub-scheduler

        Arguments: 
            sub_scheduler_mlp (nn.Sequential): the MLP layers in a sub-scheduler
            hidden_state (tensor): the encoded messages input to the sub-scheduler [n * hid_size]
            agent_mask (tensor): [n * 1]
            directed (bool): decide if generate directed graphs

        Return:
            adj (tensor): a adjacency matrix which is the communication graph [n * n]  
        """

        # hidden_state: [n * hid_size]
        n = self.args.nagents
        hid_size = hidden_state.size(-1)
        # hard_attn_input: [n * n * (2*hid_size)]
        hard_attn_input = torch.cat([hidden_state.repeat(1, n).view(n * n, -1), hidden_state.repeat(n, 1)], dim=1).view(n, -1, 2 * hid_size)
        # hard_attn_output: [n * n * 2]
        if directed:
            hard_attn_output = F.gumbel_softmax(sub_scheduler_mlp(hard_attn_input), hard=True)
        else:
            hard_attn_output = F.gumbel_softmax(0.5*sub_scheduler_mlp(hard_attn_input)+0.5*sub_scheduler_mlp(hard_attn_input.permute(1,0,2)), hard=True)
        # hard_attn_output: [n * n * 1]
        hard_attn_output = torch.narrow(hard_attn_output, 2, 1, 1)
        # agent_mask and agent_mask_transpose: [n * n]
        agent_mask = agent_mask.expand(n, n)
        agent_mask_transpose = agent_mask.transpose(0, 1)
        # adj: [n * n]
        adj = hard_attn_output.squeeze() * agent_mask * agent_mask_transpose
        
        return adj
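A standalone sketch of the directed branch above, using a hypothetical stand-in MLP (nn.Linear) and made-up sizes; agent masking is omitted:

    import torch
    import torch.nn.functional as F

    n, hid_size = 4, 8
    hidden_state = torch.randn(n, hid_size)
    sub_scheduler_mlp = torch.nn.Linear(2 * hid_size, 2)   # stand-in for the real MLP
    hard_attn_input = torch.cat([hidden_state.repeat(1, n).view(n * n, -1),
                                 hidden_state.repeat(n, 1)], dim=1).view(n, n, 2 * hid_size)
    hard_attn_output = F.gumbel_softmax(sub_scheduler_mlp(hard_attn_input), hard=True)  # [n, n, 2]
    adj = torch.narrow(hard_attn_output, 2, 1, 1).squeeze(-1)                           # [n, n]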
Example #6
    def forward(self,
                user_query,
                item_query,
                user_context,
                item_context,
                user_key_mask,
                item_key_mask,
                mode="Train"):
        item_query = self.transform(item_query).unsqueeze(dim=1)
        user_output, user_weights = self.attention(item_query, user_context,
                                                   user_context, user_key_mask)

        user_query = self.transform(user_query).unsqueeze(dim=1)
        item_output, item_weights = self.attention(user_query, item_context,
                                                   item_context, item_key_mask)

        if mode == "Test":
            user_weights = torch.argmax(user_weights, dim=-1)
            item_weights = torch.argmax(item_weights, dim=-1)
            user_tensor, item_tensor = user_output, item_output
        else:
            user_weights = F.gumbel_softmax(user_weights, hard=True)
            user_tensor = torch.bmm(user_weights.float(), user_context)
            item_weights = F.gumbel_softmax(item_weights, hard=True)
            item_tensor = torch.bmm(item_weights, item_context)

        predicted = self.activation(user_tensor * item_tensor)
        return predicted, user_weights, item_weights
Example #7
    def forward(self, input, discrete=False):
        # NASBench only has one input to each cell
        s0 = self.stem(input)
        for i, cell in enumerate(self.cells):
            if i in [self._layers // 3, 2 * self._layers // 3]:
                # Perform down-sampling by factor 1/2
                # Equivalent to https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
                s0 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(s0)
            # If using a discrete architecture from the random_ws search with weight sharing,
            # pass the architecture weights through directly.
            # GDAS uses hard Gumbel-Softmax, so only a single operation per mixed block is evaluated.
            preprocess_op_mixed_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
            # Don't use hard for the rest, because it quickly produced exploding gradients.
            preprocess_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=False, dim=-1)

            # Normalize mixed_op weights for the choice blocks in the graph
            mixed_op_weights = preprocess_op_mixed_op(self._arch_parameters[0])
            # Normalize the output weights
            output_weights = preprocess_op(self._arch_parameters[1]) if self._output_weights else None
            # Normalize the input weights for the nodes in the cell
            input_weights = [preprocess_op(alpha) for alpha in self._arch_parameters[2:]]
            s0 = cell(s0, mixed_op_weights, output_weights, input_weights)

        # Include one more preprocessing step here
        s0 = self.postprocess(s0)  # [N, C_max * multiplier, w, h] -> [N, C_max, w, h]

        # Global Average Pooling by averaging over last two remaining spatial dimensions
        # Like in nasbench: https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
        out = s0.view(*s0.shape[:2], -1).mean(-1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits
Example #8
 def forward(self):
     dag = torch.zeros(self.n_nodes, self.n_nodes)   # the final dag
     sampled = np.zeros(self.n_nodes, dtype=bool)     # set of nodes that children were sampled for
     # sample roots
     log_p_roots = F.logsigmoid(self.root_probs)     # numerically stable
     p_log = torch.stack((log_p_roots, torch.log(1 - torch.exp(log_p_roots))))
     roots = gumbel_softmax(p_log, hard=True, dim=0)[0]
     self.log(f'sampled roots {roots}')
     to_sample = roots.nonzero().view(-1).tolist()  # list of nodes that will get children sampled
     # sample children
     log_p_edges = F.logsigmoid(self.edge_probs)
     ancestors = torch.eye(self.n_nodes, dtype=torch.uint8)
     count = 0
     while len(to_sample) > 0:
         # pick the next node to sample children for
         i = to_sample.pop(0)
         if sampled[i]:
             continue
         self.log(f'sampling children for {i}')
         # don't sample ancestors and roots as children
         candidates = (1-ancestors[i,:].float()) * (1-roots)
         # sample children for node i
         p_log = torch.stack((log_p_edges[i,:], torch.log(1 - torch.exp(log_p_edges[i,:]))))
         dag[i,:] = gumbel_softmax(p_log, hard=True, dim=0)[0] * candidates.float()
         for j in dag[i,:].nonzero().view(-1).tolist():
             self.log(f'sampled {j}')
             # add i to ancestors of j
             ancestors[j,i] = 1
             # add all ancestors of i to j
             ancestors[j,:][ancestors[i,:]] = 1
             to_sample.append(j)
         sampled[i] = True
     return dag
Example #9
    def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
        """
        Sample from prior
        :param output_sizes: Tensor([B,])
        :param hold_seed
        :param hold_initial_set
        :return: Tensor([B, N, D])
        """
        bsize = output_sizes.shape[0]
        if hold_initial_set:  # [B, N]
            x_mask = get_mask(output_sizes, self.max_outputs)
        else:
            x_mask = sample_mask(output_sizes, self.max_outputs)

        if hold_seed is not None:  # [B, N, Ds]
            torch.random.manual_seed(hold_seed)
            eps = torch.randn([1, self.max_outputs, self.dim_seed
                               ]).to(x_mask.device).repeat(bsize, 1, 1)
        else:
            eps = torch.randn([bsize, self.max_outputs,
                               self.dim_seed]).to(x_mask.device)

        if self.n_mixtures == 1:
            x = self.mu + torch.exp(self.logvar / 2.) * eps
        else:
            if self.train_gmm:
                if hold_seed is not None:
                    torch.random.manual_seed(hold_seed)
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      1, self.max_outputs,
                                                      1)  # [1, N, M]
                    onehot = F.gumbel_softmax(
                        logits, tau=self.tau,
                        hard=True).repeat(bsize, 1,
                                          1).unsqueeze(-1)  # [B, N, M, 1]
                else:
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      bsize, self.max_outputs,
                                                      1)  # [B, N, M]
                    onehot = F.gumbel_softmax(logits, tau=self.tau,
                                              hard=True).unsqueeze(
                                                  -1)  # [B, N, M, 1]
                mu = self.mu.reshape([1, 1, self.n_mixtures,
                                      self.dim_seed])  # [1, 1, M, D]
                sig = self.sig.reshape([1, 1, self.n_mixtures,
                                        self.dim_seed])  # [1, 1, M, D]
                mu = (mu * onehot).sum(2)  # [B, N, D]
                sig = (sig * onehot).sum(2)  # [B, N, D]
                x = mu + sig * eps
            else:
                mix = D.Categorical(self.logits)
                comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
                mixture = D.MixtureSameFamily(mix, comp)
                x = mixture.sample((output_sizes.size(0), self.max_outputs))

        x = self.output(x)  # [B, N, D]
        return x, x_mask
Example #10
def apply_activate(data, tanh_list, soft_list):
    # Sigmoid-activate the contiguous block of columns spanned by tanh_list, then apply a
    # Gumbel-Softmax (tau=0.2) to each categorical block delimited by the soft_list indices.
    temp = torch.sigmoid(data[:, tanh_list[0]:(tanh_list[-1] + 1)])
    for i in range(len(soft_list) - 1):
        tem_soft = F.gumbel_softmax(data[:, soft_list[i]:soft_list[i + 1]], tau=0.2)
        temp = torch.cat([temp, tem_soft], 1)
    tem_soft = F.gumbel_softmax(data[:, soft_list[-1]:], tau=0.2)
    temp = torch.cat([temp, tem_soft], 1)
    return temp
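A usage sketch with a hypothetical column layout: columns 0-1 are numeric outputs, and columns 2-4 and 5-8 are two one-hot blocks:

    import torch
    import torch.nn.functional as F

    data = torch.randn(16, 9)        # raw generator output, [batch, 9]
    tanh_list = [0, 1]               # numeric columns (sigmoid-activated above)
    soft_list = [2, 5]               # start index of each categorical block
    activated = apply_activate(data, tanh_list, soft_list)
    print(activated.shape)           # torch.Size([16, 9])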
Example #11
def parse_gumbel(alpha, beta, k):
    """
    parse continuous alpha to discrete gene.
    alpha is ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]

    beta is ParameterList:
    ParameterList [
        Parameter(n_edges1),
        Parameter(n_edges2),
        ...
    ]

    gene is list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    each node has two edges (k=2) in CNN.
    """

    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume last PRIMITIVE is 'none'

    # 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
    # 2) Choose top-k edges per node by edge score (top-1 weight in edge)
    # output the connect idx[(node_idx, connect_idx, op_idx).... () ()]
    connect_idx = []
    for edges, w in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops)
        discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1),
                                      tau=1,
                                      hard=True)
        for i in range(k - 1):
            discrete_a = discrete_a + F.gumbel_softmax(
                edges[:, :-1].reshape(-1), tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, len(PRIMITIVES) - 1)
        reserved_edge = (discrete_a > 0).nonzero()

        node_gene = []
        node_idx = []
        for i in range(reserved_edge.shape[0]):
            edge_idx = reserved_edge[i][0].item()
            prim_idx = reserved_edge[i][1].item()
            prim = PRIMITIVES[prim_idx]
            node_gene.append((prim, edge_idx))
            node_idx.append((edge_idx, prim_idx))

        gene.append(node_gene)
        connect_idx.append(node_idx)

    return gene, connect_idx
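A small usage sketch with a made-up PRIMITIVES list and two nodes; the alpha/beta shapes follow the docstring above (beta is accepted but not used by this particular parser):

    import torch
    import torch.nn as nn

    PRIMITIVES = ['max_pool_3x3', 'sep_conv_3x3', 'skip_connect', 'none']   # 'none' must be last
    alpha = nn.ParameterList([nn.Parameter(torch.randn(2, len(PRIMITIVES))),
                              nn.Parameter(torch.randn(3, len(PRIMITIVES)))])
    beta = nn.ParameterList([nn.Parameter(torch.randn(2)),
                             nn.Parameter(torch.randn(3))])
    gene, connect_idx = parse_gumbel(alpha, beta, k=2)
    # gene: [[('op_name', edge_idx), ...], ...]   connect_idx: [[(edge_idx, op_idx), ...], ...]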
Example #12
    def select_action_old(self, obs, valid_actions):
        '''
        Convert actor logits to a pysc2 action.
        :param obs: observation dict with keys 'categorical', 'screen1', 'screen2'
        :param valid_actions: mask of the currently valid function ids
        :return: FunctionCall form of action
        '''
        obs_torch = {'categorical': 0, 'screen1': 0, 'screen2': 0}
        for o in obs:
            x = obs[o].astype('float32')
            x = np.expand_dims(x, 0)
            obs_torch[o] = torch.from_numpy(x).to(arglist.DEVICE)

        logits = self.actor(obs_torch)
        logits['categorical'] = self._mask_unavailable_actions(
            logits['categorical'], valid_actions)
        tau = 1.0
        function_id = gumbel_softmax(logits['categorical'],
                                     tau=1e-10,
                                     hard=True)
        function_id = function_id.argmax().item()
        # select an action until it is valid.
        is_valid_action = self._test_valid_action(function_id, valid_actions)
        while not is_valid_action:
            tau *= 10
            function_id = gumbel_softmax(logits['categorical'],
                                         tau=tau,
                                         hard=True)
            function_id = function_id.argmax().item()
            is_valid_action = self._test_valid_action(function_id,
                                                      valid_actions)

        pos_screen1 = gumbel_softmax(logits['screen1'].view(1, -1),
                                     hard=True).argmax().item()
        pos_screen2 = gumbel_softmax(logits['screen2'].view(1, -1),
                                     hard=True).argmax().item()

        pos = [[
            int(pos_screen1 % arglist.FEAT2DSIZE),
            int(pos_screen1 // arglist.FEAT2DSIZE)
        ],
               [
                   int(pos_screen2 % arglist.FEAT2DSIZE),
                   int(pos_screen2 // arglist.FEAT2DSIZE)
               ]]  # (x, y)

        args = []
        cnt = 0
        for arg in actions.FUNCTIONS[function_id].args:
            if arg.name in ['screen', 'screen2', 'minimap']:
                args.append(pos[cnt])
                cnt += 1
            else:
                args.append([0])

        action = actions.FunctionCall(function_id, args)
        return action
Example #13
    def forward(self, x, y_, adj, non_label):
        if self.training:
            x = x.contiguous().view(-1, self.x_dim)
            y_ = y_.contiguous().view(-1, self.y_dim)
            # x2y
            y_encode, y_embedding = self.x_to_yu(x)

            q_dis_total, y_total, y_pred_total_total = [], [], []
            for i in range(1):
                y = y_.clone()
                y[non_label] = F.gumbel_softmax(y_encode[non_label],
                                                tau=1.0,
                                                hard=True)
                y_total.append(y)
                # encode
                r_nodes = self.xy_to_r(y_embedding, y)
                r_graph = self.r_aggregate(r_nodes)
                mu, sigma = self.r_to_musigma(r_graph)

                q_dis = Normal(mu, sigma)
                q_dis_total.append(q_dis)
                y_pred_total = []
                for _ in range(1):
                    z_sample = q_dis.rsample()
                    #Decode
                    y_pred = self.x_to_y(x, z_sample)
                    y_pred_total.append(y_pred)
                y_pred_total_total.append(y_pred_total)

            return y_pred_total_total, q_dis_total, y_total, y_encode
        else:
            x = x.contiguous().view(-1, self.x_dim)
            y_ = y_.contiguous().view(-1, self.y_dim)
            y_encode, y_embedding = self.x_to_yu(x)
            y_pred_total = []
            for i in range(40):
                y = y_.clone()
                y[non_label] = F.gumbel_softmax(y_encode[non_label],
                                                tau=1.0,
                                                hard=True)

                # encode
                r_nodes = self.xy_to_r(y_embedding, y)
                r_graph = self.r_aggregate(r_nodes)
                mu, sigma = self.r_to_musigma(r_graph)

                q_dis = Normal(mu, sigma)
                for _ in range(1):
                    z_sample = q_dis.rsample()
                    #Decode
                    y_pred = self.x_to_y(x, z_sample)
                    y_pred_total.append(y_pred)

            y_pred = sum(y_pred_total) / len(y_pred_total)

            return y_pred, y_encode
Example #14
    def encode(self, input, tau=1):
        enc_b = self.enc_b(input)
        enc_t = self.enc_t(enc_b)

        quant_t = self.quantize_conv_t(enc_t)

        latent = F.gumbel_softmax(quant_t, tau=tau, hard=True, dim=1)
        latent_distribution = F.gumbel_softmax(quant_t, tau=tau, hard=False, dim=1)

        return latent, latent_distribution
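A minimal sketch (made-up shapes) of the same hard/soft pair applied to a conv feature map, sampling codes over the channel dimension as in encode above:

    import torch
    import torch.nn.functional as F

    quant_t = torch.randn(2, 512, 8, 8)   # [B, codebook_size, H, W] logits
    latent = F.gumbel_softmax(quant_t, tau=1.0, hard=True, dim=1)    # one-hot code per position
    latent_distribution = F.gumbel_softmax(quant_t, tau=1.0, hard=False, dim=1)
    assert torch.allclose(latent.sum(dim=1), torch.ones(2, 8, 8))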
Example #15
 def forward(self, x, temp=1):
     x = F.relu(self.fc(x))
     #x = F.relu(self.fc_(x))
     action_score = self.a_head(x)
     #z = torch.nn.functional.one_hot(prob.max(1)[1], num_classes=self.categorical_dim).view(-1, self.categorical_dim)
     return F.gumbel_softmax(action_score, hard=True, dim=-1,
                             tau=temp), F.gumbel_softmax(action_score,
                                                         hard=False,
                                                         dim=-1,
                                                         tau=temp)
Example #16
    def forward(self, x):
        if config().sim.env.state.type == "simple":
            x = x.reshape(x.size(0), x.size(2))
            return F.gumbel_softmax(self.simple_fc(x),
                                    tau=config().learning.gumbel_softmax.tau)

        out = self.conv(x)
        out = out.view(x.size(0), -1)
        return F.gumbel_softmax(self.fc(out),
                                tau=config().learning.gumbel_softmax.tau)
Example #17
 def sample_search(self):
     result = dict()
     for mutable in self.mutables:
         if isinstance(mutable, LayerChoice):
             # result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1).bool()[:-1]
             result[mutable.key] = F.gumbel_softmax(
                 self.choices[mutable.key], hard=True, dim=-1).bool()
         elif isinstance(mutable, InputChoice):
             result[mutable.key] = F.gumbel_softmax(
                 self.choices[mutable.key], hard=True, dim=-1).bool()
     return result
Example #18
    def learn(self):
        self.learn_step += 1

        sample_index = np.random.choice(self.memory_capacity, self.batch_size)
        batch_memory = self.memory[sample_index, :]

        # In the memory, the first `state_num` columns are state_now, the next `action_num`
        # columns are the action, the next column is the reward, and the final `state_num`
        # columns are state_next.
        batch_s = Variable(torch.FloatTensor(
            batch_memory[:, :self.state_num])).to(self.device)
        batch_a = Variable(
            torch.LongTensor(batch_memory[:, self.state_num:self.state_num +
                                          self.action_num])).to(self.device)
        batch_r = Variable(
            torch.FloatTensor(
                batch_memory[:,
                             self.state_num + self.action_num:self.state_num +
                             self.action_num + 1])).to(self.device)
        batch_next_s = Variable(
            torch.FloatTensor(batch_memory[:,
                                           -self.state_num:])).to(self.device)

        batch_next_a_logits = self.actor_target(batch_next_s)
        batch_target_next_a = F.gumbel_softmax(batch_next_a_logits, dim=-1)

        y_true = batch_r + self.gamma * self.critic_target(
            batch_next_s, batch_target_next_a).detach()

        y_pred = self.critic(batch_s, batch_a.float())

        critic_loss = self.loss(y_pred, y_true)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        batch_a_logits = self.actor(batch_s)
        batch_target_a = F.gumbel_softmax(batch_a_logits, dim=-1)

        actor_loss = -torch.mean(self.critic(batch_s, batch_target_a))
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        for target_param, param in zip(self.critic_target.parameters(),
                                       self.critic.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.tau) +
                                    param.data * self.tau)

        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.tau) +
                                    param.data * self.tau)

        return critic_loss.item(), actor_loss.item()
Example #19
 def forward(self, input):
     batch, C, H, W = input.size()
     s0 = s1 = self.stem(input)
     for i, cell in enumerate(self.cells):
         if cell.reduction:
             weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True)
         else:
             weights = F.gumbel_softmax(self.alphas_normal, self.tau, True)
         s0, s1 = s1, cell(s0, s1, weights)
     out = self.global_pooling(s1)
     logits = self.classifier(out.view(out.size(0), -1))
     return logits
Example #20
def train(train_type):
    if train_type == 0:  # simple VGAE trained
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        model.psi.requires_grad = False

        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_recon = model.recon_loss(z, train_pos_edge_index)
        l_kl_c = 0
        loss = l_recon + l_kl_z
        loss.backward()
        optimizer.step()
    elif train_type == 1:
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        model.psi.requires_grad = True

        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        pc_given_Z, qc_given_ZA = model.community_dists_probs(
            z, train_pos_edge_index)
        c = F.gumbel_softmax(qc_given_ZA.logits, tau=1, hard=True)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_recon = model.recon_loss((z, c), train_pos_edge_index)
        l_kl_c = 1.0 * kl_divergence(qc_given_ZA, pc_given_Z).mean()
        loss = l_recon + l_kl_z
        loss.backward()
        optimizer.step()
    else:
        model.eval()
        for param in model.parameters():
            param.requires_grad = True

        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        pc_given_Z, qc_given_ZA = model.community_dists_probs(
            z, train_pos_edge_index)
        c = F.gumbel_softmax(qc_given_ZA.logits, tau=1, hard=True)
        l_kl_z = 1.0 * model.kl_loss() / data.num_nodes
        l_kl_c = 1.0 * kl_divergence(qc_given_ZA, pc_given_Z).mean()
        l_recon = model.recon_loss((z, c), train_pos_edge_index)
        loss = l_recon + l_kl_z + l_kl_c
        loss.backward()
        optimizer.step()

    return l_recon, l_kl_z, l_kl_c
Example #21
    def get_weight(self, tau):
        if self.mode == "softmax":
            weight_normal = f.softmax(self.alpha_normal / tau, dim=-1)
            weight_reduce = f.softmax(self.alpha_reduce / tau, dim=-1)
        elif self.mode == "sigmoid":
            weight_normal = f.sigmoid(self.alpha_normal / tau)
            weight_reduce = f.sigmoid(self.alpha_reduce / tau)
        elif self.mode == "sigmoid":
            weight_normal = f.gumbel_softmax(self.alpha_normal, tau, dim=-1)
            weight_reduce = f.gumbel_softmax(self.alpha_reduce, tau, dim=-1)
        else:
            raise NotImplementedError(f"{self.mode} not implemented.")

        return weight_normal, weight_reduce
Example #22
 def forward(self, x, decode=True):
   params = self._get_dists_params(x)
   zs = []
   if 'cont' in params.keys():
     zs = [self._reparam_gauss(*params['cont'])]
   if 'cat' in params.keys():
     for logits in params['cat']:
       if self.training: zs.append(F.gumbel_softmax(logits, tau=self.temp))
       else: zs.append(F.gumbel_softmax(logits, tau=self.temp, hard=True))
   z = torch.cat(zs, dim=1)
   
   if decode: recon = self.decoder(z)
   else: recon = None
   return recon, z, params
Example #23
    def forward(self, input, hidden, encoder_outputs, mask):

        embedded = self.embedding(input).permute(1, 0, 2)
        embedded = self.dropout(embedded)
        output, hidden = self.gru(embedded, hidden)
        #print(encoder_outputs.shape) #max_length, batch_size, dim
        #print(hidden[-1].shape) #batch_size, dim, 1
        encoder_outputs_s = encoder_outputs.transpose(
            0, 1)  #batch_size, max_len, dim
        #print(encoder_outputs.shape)
        #print(hidden.shape)
        #print(hidden[-1].shape)
        score = torch.bmm(encoder_outputs_s,
                          hidden[-1].unsqueeze(2)).squeeze(2)
        #print(score.shape) #bts, max_len
        #print(score)
        #print(mask)
        score = score.masked_fill(mask == 0, -1e10)
        #print(score.shape)
        #print(score)
        attn_weights = F.softmax(score, dim=1)
        attn_weights = attn_weights.unsqueeze(2)
        #(bts, dim, max_len)(bts, max_len, 1)
        c_t = torch.bmm(encoder_outputs_s.transpose(1, 2), attn_weights)
        #print(embedded.shape) #bts, 1, dim
        #print(attn_applied.shape) #bts, dim, 1
        W_hc_t = self.Wh(c_t.squeeze(2))
        #print(W_hc_t.shape)
        #print(hidden.shape)
        #print(hidden[-1].shape)
        U_hh_t = self.Uh(hidden[-1])
        g = self.out(torch.tanh(U_hh_t + W_hc_t)).unsqueeze(0)
        #print(g.shape)
        g_ls = F.log_softmax(g, dim=2)
        #print(g_ls)
        #print(g_ls.shape)
        abc = F.gumbel_softmax(g, tau=0.9, hard=False, eps=1e-10, dim=2)
        '''print(abc)
        print(abc.shape)
        abc = F.gumbel_softmax(g_ls, tau=0.9, hard=False, eps=1e-10, dim=2)
        print(abc)
        print(abc.shape)'''

        cba = F.gumbel_softmax(g, tau=0.9, hard=True, eps=1e-10, dim=2)
        #print(cba)
        #print(cba.shape)
        #print(1/0)
        #g = g.squeeze(0)
        return cba, g_ls, hidden, attn_weights
Example #24
    def forward(self, x_f, y_f, y_c):
        """ Computes the correspondences in the feature space based on the selected parameters.

        Args:
            x_f (torch.tensor): infered features of points x [b,n,c] 
            y_f (torch.tensor): infered features of points y [b,m,c] 
            y_c (torch.tensor): coordinates of point y [b,m,3]

        Returns:
            x_corr (torch.tensor): coordinates of the feature based correspondences of points x [b,n,3]
         
        """

        dist = pairwise_distance(x_f, y_f).detach()

        if self.corr_type == 'soft':

            y_soft = torch.softmax(-dist / (self.get_temp()), dim=2)

            if self.st:
                # Straight through.
                index = y_soft.max(dim=2, keepdim=True)[1]
                y_hard = torch.zeros_like(y_soft).scatter_(dim=2,
                                                           index=index,
                                                           value=1.0)
                ret = y_hard - y_soft.detach() + y_soft

            else:
                ret = y_soft

        elif self.corr_type == 'soft_gumbel':

            if self.st:
                # Straight through.
                ret = F.gumbel_softmax(-dist, tau=self.get_temp(), hard=True)
            else:
                ret = F.gumbel_softmax(-dist, tau=self.get_temp(), hard=False)

        else:
            index = dist.min(dim=2, keepdim=True)[1]
            ret = torch.zeros_like(dist).scatter_(dim=2,
                                                  index=index,
                                                  value=1.0)

        # Compute corresponding coordinates
        x_corr = torch.matmul(ret, y_c)

        return x_corr
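A standalone sketch of the 'soft_gumbel' straight-through branch above, with hypothetical shapes and a fixed temperature:

    import torch
    import torch.nn.functional as F

    b, n, m = 2, 100, 120
    dist = torch.rand(b, n, m)                        # pairwise feature distances
    y_c = torch.randn(b, m, 3)                        # candidate coordinates
    ret = F.gumbel_softmax(-dist, tau=0.1, hard=True) # [b, n, m], one-hot over the last dim
    x_corr = torch.matmul(ret, y_c)                   # [b, n, 3] selected correspondences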
Example #25
 def forward(self, x):
     xs = tuple(layer(x) for layer in self.layers)
     logits = tuple(F.log_softmax(x, dim=1) for x in xs)
     categorical_outputs = tuple(
         F.gumbel_softmax(logit, tau=self.tau, hard=True, eps=1e-10)
         for logit in logits)
     return torch.cat(categorical_outputs, 1)
Example #26
 def forward(self, hidden_vec):
     hs = F.softplus(self.l1(hidden_vec))
     logit = torch.log(hs + 1e-08).view(-1, self.M, self.K).view(-1, self.K)
     probs = F.gumbel_softmax(logit, self.tau).view(-1, self.M * self.K)
     # probs ==> batchsize, M * K
     code_sum = torch.matmul(probs, self.codebook)
     return code_sum
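A minimal sketch (hypothetical sizes) of the compositional code embedding above: M groups of K-way codes, each group selecting rows of a shared codebook:

    import torch
    import torch.nn.functional as F

    batch, M, K, d = 8, 4, 16, 32
    logit = torch.randn(batch * M, K)                  # per-group logits, [batch*M, K]
    codebook = torch.randn(M * K, d)                   # one embedding row per (group, code)
    probs = F.gumbel_softmax(logit, tau=1.0).view(-1, M * K)   # [batch, M*K]
    code_sum = torch.matmul(probs, codebook)           # [batch, d] summed code embedding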
Example #27
    def calculate_block_probability(self, arch_param, tau):
        """
        Encode arch param to probability for generator training
        """
        arch_param = arch_param.view(
            len(self.CONFIG.l_cfgs),
            self.CONFIG.split_blocks * self.CONFIG.kernels_nums)
        p_arch_param = torch.zeros_like(arch_param)

        for l_num, (l_cfg, l, p_l) in enumerate(
                zip(self.CONFIG.l_cfgs, arch_param, p_arch_param)):
            expansion, output_channel, kernels, stride, split_block, se = l_cfg

            for b in range(expansion):
                if b == 0 and l_num in self.CONFIG.static_layers:
                    end_index = (b + 1) * split_block - 1
                    split_arch_param = l[b *
                                         split_block:(b + 1) * split_block - 1]
                else:
                    end_index = (b + 1) * split_block
                    split_arch_param = l[b * split_block:(b + 1) * split_block]
                p_l[b * split_block:end_index] = \
                    F.gumbel_softmax(split_arch_param, tau=tau)

        return p_arch_param
Example #28
    def forward(self, ques):
        # input
        # ques - shape: (batch_size, num_rounds, word_embedding_size)
        # output
        # ques_gs - shape: (batch_size, num_rounds, 2)
        # Lambda - shape: (batch_size, num_rounds, 2)

        batch_size = ques.size(0)
        num_rounds = ques.size(1)

        ques_embed = self.embed(
            ques)  # shape: (batch_size, num_rounds, lstm_hidden_size)
        ques_embed = F.normalize(
            ques_embed, p=2,
            dim=-1)  # shape: (batch_size, num_rounds, lstm_hidden_size)
        ques_logits = self.att(
            ques_embed)  # shape: (batch_size, num_rounds, 2)

        logits = ques_logits.view(-1, 2)
        if self.training:
            ques_gs = F.gumbel_softmax(logits,
                                       hard=True)  # shape: (batch_size, 2)
        else:
            _, max_value_indexes = logits.detach().max(1, keepdim=True)
            ques_gs = logits.detach().clone().zero_().scatter_(
                1, max_value_indexes, 1)
        ques_gs = ques_gs.view(-1, num_rounds, 2)

        Lambda = self.softmax(ques_logits)

        return ques_gs, Lambda  # discrete, continuous
Example #29
 def input_to_sentence(self, x_hot):
     batch_size = x_hot.size()[0]
     h = self.i_h(x_hot)
     sr = self.s_r(h)
     #c = self.i_h(x_hot)
     c = torch.zeros_like(h)  #if lstm
     input_word = torch.zeros(batch_size,
                              self.vocab_size,
                              device=x_hot.device)
     output_words = torch.zeros(self.max_seq_len,
                                batch_size,
                                self.vocab_size,
                                device=x_hot.device)
     output_scores = torch.zeros(self.max_seq_len,
                                 batch_size,
                                 self.vocab_size,
                                 device=x_hot.device)
     for t in range(self.max_seq_len):
         h = self.sender_grucell(input_word, h)
         #h, c = self.sender_lstmcell(input_word, (h, c)) #if lstm
         output_score = self.h_w(h)
         output_scores[t] = F.log_softmax(output_score, dim=1)
         if self.eval_mode:
             output_word = torch.eye(output_score.size()[1])[torch.argmax(
                 output_score, dim=1)].to(device=x_hot.device)
         else:
             output_word = F.gumbel_softmax(output_score,
                                            hard=True,
                                            tau=self.tau)
         output_words[t] = output_word
         #input_word = output_word.detach()  #what if we don't detach?
         input_word = output_word
     return output_words, output_scores, sr
Example #30
 def forward(self, thetas):
     device = thetas.device
     x = thetas.permute(0, 2, 1)  # N x 6000 x T -> N x T x 6000
     # x = torch.log(x)
     x = (F.gumbel_softmax(x, hard=True)
          if self.phase == "train" else self.softmax(x * 1e9))
     indices = torch.arange(6000, device=device).float()  # (n_bpm)
     softargmax = torch.matmul(x, indices)  # N x T
     thetas = softargmax * 2 * math.pi / 6000
     batch_size = thetas.size()[0]
     zero = torch.zeros(batch_size, 1, device=device)
     thetas_t_1 = torch.cat([zero, thetas], axis=1)[:, :-1]
     diff1 = self.relu(thetas - thetas_t_1)
     diff2 = 2 * math.pi - self.relu(thetas_t_1 - thetas)  # N x T
     diff = torch.stack([diff1, diff2], dim=2)
     delta_beattheta, _ = torch.min(diff, dim=2)  # N x T
     kernel_size = 21
     padding = int((kernel_size - 1) / 2)
     delta_beattheta = delta_beattheta.unsqueeze(1)
     delta_beattheta = F.pad(delta_beattheta, (padding, padding), "reflect")
     delta_beattheta = delta_beattheta.unfold(-1, kernel_size,
                                              1)  # N x T x kernel
     delta_beattheta, _ = torch.median(delta_beattheta, dim=-1)  # N x T
     delta_beattheta = delta_beattheta.squeeze(1)
     return delta_beattheta