Example #1
0
    def sample_relax(logits, surrogate):
        """RELAX estimator helper: draw b ~ Categorical(logits) via Gumbel-max
        and evaluate the control-variate surrogate at z and at z_tilde.

        Returns (b, logprob, cz, cz_tilde).
        NOTE(review): relies on globals B, C, x and on sample_relax_given_b
        defined elsewhere — confirm they are in scope at the call site.
        """
        cat = Categorical(logits=logits)
        # Uniforms clamped away from {0,1} so -log(-log(u)) stays finite.
        u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
        gumbels = -torch.log(-torch.log(u))
        # Gumbel-max trick: argmax of (logits + Gumbel noise) ~ Categorical(logits).
        z = logits + gumbels
        b = torch.argmax(z, dim=1) #.view(B,1)
        logprob = cat.log_prob(b).view(B,1)


        # czs = []
        # for j in range(1):
        #     z = sample_relax_z(logits)
        #     surr_input = torch.cat([z, x, logits.detach()], dim=1)
        #     cz = surrogate.net(surr_input)
        #     czs.append(cz)
        # czs = torch.stack(czs)
        # cz = torch.mean(czs, dim=0)#.view(1,1)
        # Surrogate evaluated at the unconditional relaxed sample z.
        surr_input = torch.cat([z, x, logits.detach()], dim=1)
        cz = surrogate.net(surr_input)


        # Surrogate averaged over conditional samples z_tilde ~ p(z | argmax = b)
        # (the loop runs once, so this is a single draw).
        cz_tildes = []
        for j in range(1):
            z_tilde = sample_relax_given_b(logits, b)
            surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
            cz_tilde = surrogate.net(surr_input)
            cz_tildes.append(cz_tilde)
        cz_tildes = torch.stack(cz_tildes)
        cz_tilde = torch.mean(cz_tildes, dim=0) #.view(B,1)

        return b, logprob, cz, cz_tilde
def sample_relax_given_class(logits, samp):
    """Draw a coupled pair (z, z_tilde) of relaxations conditioned on class samp.

    The first conditional draw (reusing the unconditional uniforms u) is
    assigned to z, and a second draw with fresh uniforms becomes z_tilde, so
    the two returned relaxations are conditionally coupled through u_b.
    Returns (z, z_tilde, logprob).
    NOTE(review): relies on globals B and C — confirm they match logits' shape.
    """
    cat = Categorical(logits=logits)

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)


    # Gumbel value at the fixed class b.
    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))

    # Conditional Gumbel draw given argmax == b (REBAR/RELAX reparameterization).
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)


    # The first conditional draw replaces the unconditional z.
    z = z_tilde

    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))

    # Fresh uniforms for the second conditional draw — this is z_tilde proper.
    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, z_tilde, logprob
Example #3
0
def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """RELAX gradient of the objective w.r.t. logits for a fixed sample b.

    Returns (grad, pb): grad is d(net_loss)/d(logits) with shape [B,C], and
    pb = q(b|logits) with shape [B,1].
    NOTE(review): depends on module-level helpers myclamp,
    sample_relax_given_b and logprob_undercomponent defined elsewhere.
    """
    B = logits.shape[0]
    C = logits.shape[1]

    cat = Categorical(logits=logits)
    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))
    # Unconditional relaxed sample (b itself is given by the caller).
    z = logits + gumbels
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)

    # Control variate at the unconditional z.
    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)

    # Control variate at z_tilde ~ p(z | argmax = b).
    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]

    f = logpxz - logq 
    # RELAX objective: score-function term with the surrogate baseline, plus the
    # reparameterized surrogate difference (cz - cz_tilde).
    net_loss = - torch.mean( (f.detach() - cz_tilde.detach()) * logq - logq +  cz - cz_tilde )

    grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0] #[B,C]
    pb = torch.exp(logq)

    return grad, pb
def sample_relax(logits): #, k=1):
    """Draw b ~ Categorical(logits) via Gumbel-max together with an
    independently-seeded conditional relaxation z_tilde (RELAX estimator).

    Returns (z, b, logprob, z_tilde).
    NOTE(review): uses globals B and C. `torch.softmax(logits,...).repeat(B,1)`
    suggests logits is a single row broadcast to the batch — if logits is
    already [B,C] this produces [B*B,C]; confirm at the call site.
    """

    # u = torch.rand(B,C).clamp(1e-8, 1.-1e-8) #.cuda()
    u = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B,1)

    # Fresh Gumbel for the chosen class (independent of u, unlike the coupled
    # variant that gathers u at b).
    v_k = torch.rand(B,1).clamp(1e-12, 1.-1e-12)
    z_tilde_b = -torch.log(-torch.log(v_k))
    #this way seems biased even tho it shoudlnt be
    # v_k = torch.gather(input=u, dim=1, index=b.view(B,1))
    # z_tilde_b = torch.gather(input=z, dim=1, index=b.view(B,1))

    v = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    probs = torch.softmax(logits,dim=1).repeat(B,1)
    # print (probs.shape, torch.log(v_k).shape, torch.log(v).shape)
    # fasdfa

    # print (v.shape)
    # print (v.shape)
    # Conditional Gumbel draw given argmax == b.
    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))

    # print (z_tilde)
    # print (z_tilde_b)
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
    # print (z_tilde)
    # fasdfs

    return z, b, logprob, z_tilde
Example #5
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by `probs`.

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        # Delegate parameter handling to an underlying Categorical; the last
        # probs dimension becomes the one-hot event dimension.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.probs.size()[:-1]
        event_shape = self._categorical.probs.size()[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape)

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape sample_shape + batch_shape + event_shape."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        # Write a single 1 at each sampled index along the event dimension.
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of one-hot `value`; the argmax position is the class."""
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Entropy of the underlying categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return all one-hot vectors, shaped (n,) + batch_shape + (n,).

        NOTE(review): the `Variable` branch targets a legacy PyTorch release.
        """
        probs = self._categorical.probs
        n = self.event_shape[0]
        if isinstance(probs, Variable):
            values = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            values = torch.eye(n, out=probs.new(n, n))
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
Example #6
0
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE with a learned (input-dependent) baseline from `surrogate`.

    Draws k samples, forming the policy loss with the baseline subtracted, and
    a surrogate regression loss weighted by the score-function gradient.
    Returns a dict of intermediate quantities (last sample's values for the
    per-sample entries).
    NOTE(review): depends on logprob_undercomponent defined elsewhere;
    surr_loss in the output comes from the final loop iteration only.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}

    cat = Categorical(probs=probs)

    grads =[]
    # net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)
        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]

        # Baseline prediction depends on x only (not on the sampled component).
        surr_pred = surrogate.net(x)

        outputs['f'] = f = logpxz - logq - 1. 
        # outputs['net_loss'] = net_loss = net_loss - torch.mean((f.detach() ) * logq)
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        # Train the baseline to minimize the variance of the resulting gradient
        # (residual weighted by the score-function gradient, squared).
        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq )**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    # net_loss = net_loss/ k

    if get_grad:
        grads = torch.stack(grads)
        # print (grads.shape)
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0),dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    # return net_loss, f, logpx_given_z, logpz, logq
    return outputs
Example #7
0
    def get_action(self, x, action=None, prev_pol=None, prev_n=None):
        """Build the policy distribution for x and return a sampled (or given)
        action with its log-probability, entropy and auxiliary policy outputs.
        """
        logits = self.get_logits(x)
        mu = self.get_mu(x)
        scale = self.get_scale(x)
        z = self.forward(x)
        n = self.policy_repeat_sampler(mu, scale)

        # Optionally carry over the previous policy / repeat count.
        if not (prev_pol is None or prev_n is None):
            logits, n = self.repeat_policy(prev_pol, prev_n, logits, n)

        dist = Categorical(logits=logits)
        if action is None:
            action = dist.sample()

        logp = dist.log_prob(action)
        ent = dist.entropy()
        return action, logp, ent, n, logits, mu, scale, z
Example #8
0
    def sample_relax(probs):
        """Sample b ~ Categorical(probs) via Gumbel-max and draw the conditional
        relaxation z_tilde given b (RELAX estimator).

        Returns (z, b, logprob, z_tilde, gumbels).
        NOTE(review): relies on globals B (batch size) and C (num categories).
        """
        cat = Categorical(probs=probs)
        #Sample z
        u = torch.rand(B, C).cuda()
        u = u.clamp(1e-8, 1. - 1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B, 1)

        #Sample z_tilde
        u_b = torch.rand(B, 1).cuda()
        u_b = u_b.clamp(1e-8, 1. - 1e-8)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B, C).cuda()
        u = u.clamp(1e-8, 1. - 1e-8)
        z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        # BUG FIX: `z_tilde[:, b] = z_tilde_b` used advanced indexing, which
        # writes the b-columns of *every* row (broadcasting across the batch).
        # scatter_ writes row i's value only at its own column b[i].
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
        return z, b, logprob, z_tilde, gumbels
Example #9
0
    def sample_relax(probs):
        """Sample b ~ Categorical(probs) via Gumbel-max and draw the conditional
        relaxation z_tilde given b (RELAX estimator).

        Returns (z, b, logprob, z_tilde, gumbels).
        NOTE(review): relies on globals B (batch size) and C (num categories).
        """
        cat = Categorical(probs=probs)
        #Sample z
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B,1)

        #Sample z_tilde
        u_b = torch.rand(B,1).cuda()
        u_b = u_b.clamp(1e-8, 1.-1e-8)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))
        # BUG FIX: `z_tilde[:,b] = z_tilde_b` used advanced indexing, which
        # writes the b-columns of *every* row (broadcasting across the batch).
        # scatter_ writes row i's value only at its own column b[i].
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        return z, b, logprob, z_tilde, gumbels
Example #10
0
    def forward(self, encoder_inputs, hx, n_steps, greedy=False):
        """Pointer-network decoding loop: select one encoder position per step.

        :param encoder_inputs: encoder outputs; assumed (batch, seq_len, dim)
            from the indexing below — TODO confirm
        :param hx: initial recurrent cell state
        :param n_steps: number of positions to select
        :param greedy: take the argmax instead of sampling
        :return: (actions, per-sequence summed log-prob, per-sequence mean entropy)
        """
        _input = encoder_inputs.new_zeros(
            (encoder_inputs.size(0), encoder_inputs.size(2)))
        # mask marks positions already selected so attention can exclude them.
        mask = encoder_inputs.new_zeros(
            (encoder_inputs.size(0), encoder_inputs.size(1)))
        log_ps = []
        actions = []
        entropys = []

        for i in range(n_steps):
            hx = self.cell(_input, hx)
            #                 print (hx.size(),encoder_inputs.size(),mask.size())
            p = self.attn(hx, encoder_inputs, mask)
            dist = Categorical(p)
            entropy = dist.entropy()

            if greedy:
                _, index = p.max(dim=-1)
            else:
                index = dist.sample()

            actions.append(index)
            log_p = dist.log_prob(index)
            log_ps.append(log_p)
            entropys.append(entropy)

            # Mark the chosen position as used.
            mask = mask.scatter(1,
                                index.unsqueeze(-1).expand(mask.size(0), -1),
                                1)
            # Next decoder input is the encoding of the chosen position.
            _input = torch.gather(
                encoder_inputs, 1,
                index.unsqueeze(-1).unsqueeze(-1).expand(
                    encoder_inputs.size(0), -1,
                    encoder_inputs.size(2))).squeeze(1)

        log_ps = torch.stack(log_ps, 1)
        actions = torch.stack(actions, 1)
        entropys = torch.stack(entropys, 1)
        log_p = log_ps.sum(dim=1)
        entropy = entropys.mean(dim=1)
        return actions, log_p, entropy
def sample_relax_given_class_k(logits, samp, k):
    """Average k coupled (z, z_tilde) conditional relaxations for a fixed class.

    Each iteration draws a first conditional relaxation (assigned to z) and a
    second one with fresh uniforms (z_tilde), coupled through u_b; the k draws
    are averaged. Returns (z, z_tilde, logprob).
    NOTE(review): relies on globals B and C.
    """
    cat = Categorical(logits=logits)
    b = samp  #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    zs = []
    z_tildes = []
    for i in range(k):

        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        # Gumbel value at the fixed class b.
        u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        # Conditional Gumbel draw given argmax == b.
        z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) -
                             torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)

        # The first conditional draw replaces the unconditional z.
        z = z_tilde

        u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        # Fresh uniforms for the second conditional draw (z_tilde proper).
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) -
                             torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs = torch.stack(zs)
    z_tildes = torch.stack(z_tildes)

    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
Example #12
0
 def get_action(self, x):
     """Sample an action for observation x from the (global) policy network `pg`.

     Discrete spaces return (action, -log-prob, entropy). MultiDiscrete spaces
     split the logits per sub-action and accumulate neg-log-probs/entropies.
     NOTE(review): `pg` and `env` are globals. The guard
     `len(action) != env.action_space.shape` compares an int to a shape tuple
     and so is presumably always True — verify intent.
     """
     logits = pg.forward(x)
     # ALGO LOGIC: `env.action_space` specific logic
     if isinstance(env.action_space, Discrete):
         probs = Categorical(logits=logits)
         action = probs.sample()
         return action, -probs.log_prob(action), probs.entropy()
     elif isinstance(env.action_space, MultiDiscrete):
         # One Categorical per sub-action, from the split logits.
         logits_categories = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
         action = []
         probs_categories = []
         entropy = torch.zeros((logits.shape[0]))
         neglogprob = torch.zeros((logits.shape[0]))
         for i in range(len(logits_categories)):
             probs_categories.append(Categorical(logits=logits_categories[i]))
             if len(action) != env.action_space.shape:
                 action.append(probs_categories[i].sample())
             neglogprob -= probs_categories[i].log_prob(action[i])
             entropy += probs_categories[i].entropy()
         action = torch.stack(action).transpose(0, 1)
         return action, neglogprob, entropy
Example #13
0
    def forward(self, masked, lengths, unmasked, mask):
        """Sample token-by-token reconstructions of the masked positions.

        At each timestep a token is sampled from the decoder logits; positions
        where mask is 0 keep the ground-truth token from `unmasked`.
        Returns (samples, per-token log-probs, attention weights).
        """
        self.encoder.lstm.flatten_parameters()
        logits, attns = super().forward(masked, lengths, unmasked)
        bsz, seqlen, vocab_size = logits.size()

        # Sample from x converting it to probabilities
        samples = []
        log_probs = []
        for t in range(seqlen):
            logit = logits[:, t, :]
            distribution = Categorical(logits=logit)
            sampled = distribution.sample()
            # Use the sampled token only where the position was masked.
            fsampled = torch.where(mask[:, t].byte(), sampled, unmasked[:, t])
            log_prob = distribution.log_prob(fsampled)
            # flog_prob = torch.where(mask[:, t].byte(), log_prob, torch.zeros_like(log_prob))
            log_probs.append(log_prob)
            samples.append(fsampled)

        samples = torch.stack(samples, dim=1)
        log_probs = torch.stack(log_probs, dim=1)
        return (samples, log_probs, attns)
    def act(self, obs: np.ndarray, explore: bool):
        """Returns an action (should be called at every timestep)

        Select an action from the model's stochastic policy by sampling a
        discrete action from the distribution specified by the model output.

        :param obs (np.ndarray): observation vector from the environment
        :param explore (bool): flag indicating whether we should explore
            (NOTE(review): currently unused — the policy always samples)
        :return (sample from self.action_space): action the agent should perform
        """
        state = torch.from_numpy(obs).float().unsqueeze(0)
        probs = self.policy.forward(state)
        # BUG FIX: softmax without an explicit `dim` is deprecated and relies on
        # shape-dependent dim inference; normalize over the action dimension.
        probs = torch.nn.functional.softmax(probs, dim=-1)

        m = Categorical(probs)
        action = m.sample()
        # Stored log-prob keeps its graph so the REINFORCE update can backprop.
        self.save_policy_probs.append(m.log_prob(action))

        return action.item()
Example #15
0
    def forward(self,
                large_maps,
                small_maps,
                rgb_ims=None,
                hidden_state=None,
                action=None,
                deterministic=False):
        """Actor-critic forward pass over map (and optional RGB) observations.

        Encodes each input stream with its own CNN, merges them, runs a
        recurrent core, and returns the sampled (or given) action, its
        log-prob, the policy entropy, the value estimate, the detached hidden
        state, and the raw policy logits.
        NOTE(review): inputs appear to be (seq_len, batch, C, H, W) given the
        unpacking below, yet are flattened as batch*seq — confirm callers pass
        matching layouts.
        """
        seq_len, batch_size, C, H, W = large_maps.size()
        large_maps = large_maps.view(batch_size * seq_len, C, H, W)
        l_cnn_out = self.large_map_resnet_model(large_maps)
        l_cnn_out = l_cnn_out.view(seq_len, batch_size, -1)

        seq_len, batch_size, C, H, W = small_maps.size()
        small_maps = small_maps.view(batch_size * seq_len, C, H, W)
        s_cnn_out = self.small_map_resnet_model(small_maps)
        s_cnn_out = s_cnn_out.view(seq_len, batch_size, -1)

        if self.use_rgb:
            seq_len, batch_size, C, H, W = rgb_ims.size()
            rgb_ims = rgb_ims.view(batch_size * seq_len, C, H, W)
            rgb_cnn_out = self.rgb_resnet_model(rgb_ims)
            rgb_cnn_out = rgb_cnn_out.view(seq_len, batch_size, -1)
            cnn_out = torch.cat((rgb_cnn_out, l_cnn_out, s_cnn_out), dim=-1)
        else:
            cnn_out = torch.cat((l_cnn_out, s_cnn_out), dim=-1)

        rnn_in = F.elu(self.merge_fc(cnn_out))

        rnn_out, hidden_state = self.rnn(rnn_in, hidden_state)
        pi = self.actor_head(self.actor_fc(rnn_out))
        val = self.critic_head(self.critic_fc(rnn_out))
        cat_dist = Categorical(logits=pi)
        if action is None:
            if not deterministic:
                action = cat_dist.sample()
            else:
                # Greedy action: argmax over the action dimension of pi.
                action = torch.max(pi, dim=2)[1]
        log_prob = cat_dist.log_prob(action)
        # Hidden state is detached so gradients do not flow across rollouts.
        return action, log_prob, cat_dist.entropy(), val, hidden_state.detach(
        ), pi
def update_model(model, gamma, optim, rollouts, device, iteration, writer):
    """One-step actor-critic update over a list of transitions.

    :param model: network whose forward(state) returns (actor_logits, critic_value)
    :param gamma: discount factor
    :param optim: optimizer over the model parameters
    :param rollouts: list of (state, action, reward, next_state) tuples
    :param device: unused here (kept for interface compatibility)
    :param iteration: unused here (kept for interface compatibility)
    :param writer: unused here (kept for interface compatibility)
    :return: (accumulated actor loss, accumulated critic loss) as scalar tensors
    """
    actor_loss, critic_loss = 0., 0.
    for i in range(len(rollouts)):
        s, a, r, ns = rollouts[i]
        actor, critic = model.forward(s)
        n_actor, n_critic = model.forward(ns)
        # BUG FIX: the TD target and the advantage must be treated as constants.
        # Without detach(), the critic loss backpropagates through n_critic and
        # the actor loss backpropagates through both value estimates, which
        # corrupts the actor-critic update.
        target = (r + gamma * n_critic).detach()
        loss_c = F.mse_loss(critic, target)

        # TD error used as the advantage for the policy-gradient term.
        err = (r + gamma * n_critic - critic).detach()
        actor_dist = Categorical(logits=actor)
        loss_a = -actor_dist.log_prob(a) * err
        loss = loss_c + loss_a

        optim.zero_grad()
        loss.backward()
        optim.step()

        actor_loss += loss_a.view([])
        critic_loss += loss_c.view([])

    return actor_loss, critic_loss
Example #17
0
def reinforce(x, logits, mixtureweights, k=1):
    """REINFORCE objective for a mixture model, averaged over k samples.

    Returns (net_loss, f, logpx_given_z, logpz, logq), where the last four
    come from the final sample drawn.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    net_loss = 0
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)

        # log p(x, z) = log p(x|z) + log p(z)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        f = (logpx_given_z + logpz) - logq
        # Score-function estimator with a constant baseline of 1.
        net_loss = net_loss - torch.mean((f.detach() - 1.) * logq)

    net_loss = net_loss / k
    return net_loss, f, logpx_given_z, logpz, logq
def sample_relax_given_class_k(logits, samp, k):
    """Average k coupled (z, z_tilde) conditional relaxations for a fixed class.

    Each iteration draws a first conditional relaxation (assigned to z) and a
    second one with fresh uniforms (z_tilde), coupled through u_b; the k draws
    are averaged. Returns (z, z_tilde, logprob).
    NOTE(review): relies on globals B and C.
    """
    cat = Categorical(logits=logits)
    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    zs = []
    z_tildes = []
    for i in range(k):

        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        # Gumbel value at the fixed class b.
        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        # Conditional Gumbel draw given argmax == b.
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        # The first conditional draw replaces the unconditional z.
        z = z_tilde

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        # Fresh uniforms for the second conditional draw (z_tilde proper).
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs= torch.stack(zs)
    z_tildes= torch.stack(z_tildes)

    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
def get_action(state, policy_model, value_model, device):
    """Sample an action for a single state and evaluate its value.

    :param state: observation; a numpy array or an already-converted torch.Tensor
    :param policy_model: network mapping a state batch to action logits
    :param value_model: network mapping a state batch to a scalar value
    :param device: torch device for numpy-converted inputs
    :return: (action index, log-probability, state value) as Python scalars
    """
    policy_model.eval()
    value_model.eval()

    # BUG FIX: `if not state is torch.Tensor` compared the object against the
    # class itself (always True), so tensor inputs crashed inside
    # torch.from_numpy. Use isinstance to convert only non-tensor inputs.
    if not isinstance(state, torch.Tensor):
        state = torch.from_numpy(state).float().to(device)

    if state.shape[0] != 1:
        state = state.unsqueeze(0) # Create batch dimension

    logits = policy_model(state)

    m = Categorical(logits=logits)

    action = m.sample()

    log_probability = m.log_prob(action)

    value = value_model(state)

    return action.item(), log_probability.item(), value.item()
Example #20
0
def reinforce(x, logits, mixtureweights, k=1):
    """Score-function (REINFORCE) loss for a mixture model, averaged over k draws.

    Returns (net_loss, f, logpx_given_z, logpz, logq); the per-sample values
    are those of the last draw.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = Categorical(probs=probs)

    net_loss = 0
    for sample_idx in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        # Joint log-density of (x, z) under the mixture.  [B,1]
        logpxz = logpx_given_z + logpz
        f = logpxz - logq
        # Gradient estimator term with a constant baseline of 1.
        net_loss += -torch.mean((f.detach() - 1.) * logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq
Example #21
0
    def forward(self, iteration):
        '''Sample an architecture: one operation index per layer.

        Layers inside the current stage (from _get_stage_index) are sampled and
        contribute log-probs/entropies; layers before the stage are fixed to
        their argmax; layers after it are sampled without bookkeeping.
        Returns the sampled architecture tensor (layers as columns).
        '''

        entropys = []
        log_probs = []
        sampled_arcs = []

        start_idx, end_idx = self._get_stage_index(iteration)
        cur_layer_idx = list(range(start_idx, end_idx))
        self.op_dist = []
        for layer_id in range(self.num_layers):
            logit = self.alpha[layer_id]
            # if self.temperature > 0:
            #   logit /= self.temperature
            # if self.tanh_constant is not None:
            #   logit = self.tanh_constant * torch.tanh(logit)

            op_dist = Categorical(logits=logit)
            self.op_dist.append(op_dist)

            if layer_id in cur_layer_idx:
                # Layers being searched: sample and record log-prob + entropy.
                sampled_op = op_dist.sample()
                log_prob = op_dist.log_prob(sampled_op)
                log_probs.append(log_prob.view(-1, 1))
                entropy = op_dist.entropy()
                entropys.append(entropy.view(-1, 1))
            elif layer_id < start_idx:
                # Already-finalized layers: greedy choice.
                sampled_op = logit.argmax(-1)
            elif layer_id >= end_idx:
                # Future layers: sample but do not train on them.
                sampled_op = op_dist.sample()
            sampled_arcs.append(sampled_op.view(-1, 1))

        self.sampled_arcs = torch.cat(sampled_arcs, dim=1)
        self.sample_entropy = torch.cat(entropys, dim=1)
        self.sample_log_prob = torch.cat(log_probs, dim=1)

        return self.sampled_arcs
Example #22
0
    def get_samples_and_logp(self, x, n_samples, return_inermediate=False):
        """Draw n_samples from each Categorical head produced by forward(x).

        Returns (distribs, samples, logps); when return_inermediate is True the
        shared representation from forward is appended to the tuple.
        """
        if return_inermediate:
            logits, shared_out = self.forward(x, True)
        else:
            logits = self.forward(x, False)

        distribs, samples, logps = [], [], []
        for head_logits in logits:
            dist = Categorical(logits=head_logits)
            drawn = dist.sample((n_samples, ))
            distribs.append(dist)
            samples.append(drawn)
            logps.append(dist.log_prob(drawn))

        # Stack per-head results and transpose to (n_samples, num_heads).
        samples = torch.stack(samples, dim=0).T
        logps = torch.stack(logps, dim=0).T

        if return_inermediate:
            return distribs, samples, logps, shared_out
        return distribs, samples, logps
    def act(self, obs: np.ndarray, explore: bool):
        """Sample a discrete action from the actor's policy for observation obs.

        :param obs (np.ndarray): observation vector from the environment
        :param explore (bool): accepted for interface compatibility (the policy
            always samples here)
        :return: (action index, critic value estimate, log-prob of the action)
        """
        state = torch.from_numpy(obs).type(torch.FloatTensor)
        action_probs = self.soft_max(self.actor(state))

        dist = Categorical(action_probs)
        chosen = dist.sample()
        logp = dist.log_prob(chosen)

        value = self.critic(state)

        return chosen.item(), value, logp
Example #24
0
    def monte_carlo_sampling(self, h_j, dec_state, enc_idx=None):
        """Roll out the decoder by sampling, collecting neg-log-probs and attention.

        :param h_j: encoder hidden states (assumed (batch, enc_len, dim) — TODO confirm)
        :param dec_state: final encoder state used to initialize the decoder
        :param enc_idx: encoder token indices (for windowed/pointer attention)
        :return: (neg_log_probs, sampled ids, attention weights), each
                 transposed to (batch, dec_max_len, ...)
        """
        current_h_j = h_j
        current_enc_idx = enc_idx
        s_t = dec_state  # entire final encoding state - hidden (plus cell) for all layers from num_layers

        # Start every sequence with the start-of-decoding token embedding.
        y_prev = np.ones((h_j.size(0),), dtype=int) * self.vocab[START_DEC]
        y_prev = torch.from_numpy(y_prev).to(self.device)
        y_prev = self.embedding(y_prev)

        ys, neg_log_probs, weights = [], [], []

        if self.windower:
            # Restrict attention to a sliding window over the encoder states.
            current_h_j = h_j[:, 0:self.windower.ws, :]
            current_enc_idx = enc_idx[:, 0:self.windower.ws]
            enc_slider = EncoderSlider(h_j, enc_idx, self.windower)

        for t in range(self.dec_max_len):

            if self.windower:
                current_h_j, current_enc_idx = enc_slider.slide(current_h_j, current_enc_idx, t)

            dec_outputs, s_t, a_ij, _ = self.one_step_decode(current_h_j, s_t, y_prev, enc_idx,
                                                                       current_enc_idx)
            # dec_outputs are log-probabilities; exponentiate before sampling.
            sample_probs = torch.exp(dec_outputs)
            cat_dist = Categorical(sample_probs)
            sample = cat_dist.sample()
            ys.append(sample)
            neg_log_prob = -cat_dist.log_prob(sample)
            # OOV ids are mapped back into vocab before embedding lookup.
            sample_masked = mask_oov(sample, self.vocab)
            y_prev = self.embedding(sample_masked)

            weights.append(a_ij)
            neg_log_probs.append(neg_log_prob)

        return torch.stack(neg_log_probs).transpose(0, 1), \
               torch.stack(ys).transpose(0, 1), \
               torch.stack(weights).transpose(0, 1)  # (batch_size, max_dec_len, max_enc_len)
Example #25
0
def diagonal_FIM(agent, env, episode_len, model_name):
    """Estimate a diagonal Fisher information matrix for `agent` and pickle it.

    Runs 1000 episodes, accumulates reward-weighted negative log-likelihoods,
    and takes squared gradients of their mean as the diagonal FIM estimate,
    written to data-<model_name>/FIM.dat keyed by parameter name
    (with '.' replaced by '__').
    NOTE(review): the inner loop variable `step` shadows the outer episode
    counter — harmless here (the outer `for` rebinds it) but confusing.
    Depends on module-level `network`, `autograd` and `pickle`.
    """
    print('Estimating diagonal FIM...')
    episodes = 1000
    log_probs = []
    avg_reward = 0.0
    for step in range(episodes):
        # Run an episode.
        (states, actions,
         discounted_rewards) = network.run_episode(env, agent, episode_len)
        avg_reward += np.mean(discounted_rewards)
        if step % 100 == 0:
            print('Average reward @ episode {}: {}'.format(
                step, avg_reward / 100))
            avg_reward = 0.0

        # Repeat each action, and backpropagate discounted
        # rewards. This can probably be batched for efficiency with a
        # memoryless agent...
        for (step, a) in enumerate(actions):
            logits = agent(states[step])
            dist = Categorical(logits=logits)
            log_probs.append(-dist.log_prob(actions[step]) *
                             discounted_rewards[step])

    loglikelihoods = torch.cat(log_probs).mean(0)
    loglikelihood_grads = autograd.grad(loglikelihoods, agent.parameters())
    # torch.dot(loglikelihood_grads * loglikelihood_grads.T)
    # Diagonal FIM approximation: squared gradient per parameter.
    FIM = {
        n: g**2
        for n, g in zip([n for (
            n, _) in agent.named_parameters()], loglikelihood_grads)
    }
    for (n, _) in agent.named_parameters():
        FIM[n.replace(".", "__")] = FIM.pop(n)
    with open("data-{model}/FIM.dat".format(model=model_name), 'wb+') as f:
        pickle.dump(FIM, f)
        print("File dumped correctly.")
Example #26
0
def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Per-element squared RELAX gradient, used as the surrogate training loss.

    Unlike relax_grad2 (which differentiates the scalar loss), this assembles
    the estimator from the individual gradient pieces and squares it
    element-wise. Returns (surr_loss, q(b)).
    NOTE(review): depends on module-level helpers myclamp,
    sample_relax_given_b and logprob_undercomponent.
    """
    B = logits.shape[0]
    C = logits.shape[1]

    cat = Categorical(logits=logits)
    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))
    # Unconditional relaxed sample (b itself is given by the caller).
    z = logits + gumbels
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)

    # Control variate at the unconditional z.
    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)

    # Control variate at z_tilde ~ p(z | argmax = b).
    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]

    f = logpxz - logq 

    # Gradient pieces of the RELAX estimator w.r.t. logits.
    grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z =  torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z_tilde = torch.autograd.grad([torch.mean(cz_tilde)], [logits], create_graph=True, retain_graph=True)[0]
    # surr_loss = torch.mean(((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2, dim=1, keepdim=True)
    surr_loss = ((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2

    # print (surr_loss.shape)
    # print (logq.shape)
    # fasda

    # print (surr_loss,  torch.exp(logq))
    return surr_loss, torch.exp(logq)
Example #27
0
    def update(self, rewards: List[float], observations: List[np.ndarray],
               actions: List[int]) -> Dict[str, float]:
        """Update function for REINFORCE

        **YOU MUST IMPLEMENT THIS FUNCTION FOR Q3**

        :param rewards (List[float]): rewards of episode (from first to last)
        :param observations (List[np.ndarray]): observations of episode (from first to last)
        :param actions (List[int]): applied actions of episode (from first to last)
        :return (Dict[str, float]): dictionary mapping from loss names to loss values
        """
        # Return-to-go G_t for each timestep of the episode.
        G = self.compute_gt(rewards)
        p_loss = 0.0
        for i in range(len(rewards)):
            probs = self.policy.forward(Tensor(observations[i]))
            dist = torch.nn.functional.softmax(probs, dim=-1)
            m = Categorical(dist)
            # REINFORCE: accumulate -log pi(a_t|s_t) * G_t.
            p_loss -= m.log_prob(torch.FloatTensor([actions[i]])) * G[i]

        self.policy_optim.zero_grad()
        p_loss.backward()
        self.policy_optim.step()

        # NOTE(review): p_loss is returned as a tensor despite the
        # Dict[str, float] annotation — confirm callers' expectations.
        return {"p_loss": p_loss}
Example #28
0
    def forward(self, img1, img2, mex):
        """Score two candidate images against a one-hot message and pick one.

        Each image and the message are embedded, the message embedding is
        dot-multiplied with each image embedding to give one score per image,
        and a two-way softmax yields the choice distribution.
        Returns (probs, actions, logprobs, entropy): actions are sampled in
        training mode and greedy in eval mode.
        """
        mex = torch.nn.functional.one_hot(mex, num_classes=self.vocab_len)

        img1 = img1.view(img1.size(0), -1)
        img2 = img2.view(img2.size(0), -1)
        out1 = self.policy_single_img(img1)
        out2 = self.policy_single_img(img2)

        symbol = mex.view(mex.size(0), -1).float()
        symbol = self.policy_single_mex(symbol)

        # Batched dot product of the message embedding with each image
        # embedding: one number per image.
        out1 = torch.bmm(symbol.view(symbol.size(0), 1, symbol.size(1)),
                         out1.view(out1.size(0), symbol.size(1), 1))
        out2 = torch.bmm(symbol.view(symbol.size(0), 1, symbol.size(1)),
                         out2.view(out2.size(0), symbol.size(1), 1))

        out1 = out1.view(out1.size(0), -1)
        out2 = out2.view(out2.size(0), -1)

        combined = torch.cat((out1, out2), dim=1)
        probs = self.softmax(combined)

        dist = Categorical(probs=probs)

        if self.training:
            actions = dist.sample()
        else:
            # BUG FIX: Categorical has no argmax() method, so
            # dist.argmax(dim=1) raised AttributeError in eval mode; take the
            # greedy action from the probabilities instead.
            actions = probs.argmax(dim=1)

        logprobs = dist.log_prob(actions)

        entropy = dist.entropy()

        return probs, actions, logprobs, entropy
def sample_relax(logits):  #, k=1):
    """Draw b ~ Categorical(logits) via Gumbel-max together with an
    independently-seeded conditional relaxation z_tilde (RELAX estimator).

    Returns (z, b, logprob, z_tilde).
    NOTE(review): uses globals B and C. `torch.softmax(logits,...).repeat(B, 1)`
    suggests logits is a single row broadcast to the batch — if logits is
    already [B,C] this produces [B*B,C]; confirm at the call site.
    """

    # u = torch.rand(B,C).clamp(1e-8, 1.-1e-8) #.cuda()
    u = torch.rand(B, C).clamp(1e-12, 1. - 1e-12)  #.cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B, 1)

    # Fresh Gumbel for the chosen class (independent of u).
    v_k = torch.rand(B, 1).clamp(1e-12, 1. - 1e-12)
    z_tilde_b = -torch.log(-torch.log(v_k))

    # # v = torch.rand(B,C) #.clamp(1e-12, 1.-1e-12) #.cuda()
    # v_k = torch.gather(input=u, dim=1, index=b.view(B,1))
    # # z_tilde_b = -torch.log(-torch.log(v_k))
    # z_tilde_b = torch.gather(input=z, dim=1, index=b.view(B,1))
    # # print (z_tilde_b)

    v = torch.rand(B, C).clamp(1e-12, 1. - 1e-12)  #.cuda()
    probs = torch.softmax(logits, dim=1).repeat(B, 1)
    # print (probs.shape, torch.log(v_k).shape, torch.log(v).shape)
    # fasdfa

    # print (v.shape)
    # print (v.shape)
    # Conditional Gumbel draw given argmax == b.
    z_tilde = -torch.log((-torch.log(v) / probs) - torch.log(v_k))

    # print (z_tilde)
    # print (z_tilde_b)
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
    # print (z_tilde)
    # fasdfs

    return z, b, logprob, z_tilde
Example #30
0
def get_action(state):
    """Sample an action for `state` from the actor, masking out illegal actions.

    Returns (action_index, critic_value, log_prob). As a side effect the
    computed illegal-action mask is appended to `buffer.masks` (it is the
    same list used to mask the logits). Relies on module-level `actor`,
    `critic`, `env`, and `buffer`.
    """
    state = torch.from_numpy(state).float().cuda()

    logits = actor(state.unsqueeze(dim=0))

    # Build the set once so each membership test is O(1) instead of scanning
    # the legal-actions list for every action index.
    legal_actions = env.legal_actions()
    legal = set(legal_actions)
    mask = [a for a in range(buffer.action_space) if a not in legal]

    buffer.masks.append(mask)

    # -inf logits give illegal actions exactly zero probability under Categorical.
    logits[0][mask] = -float("Inf")

    m = Categorical(logits=logits)

    action = m.sample()

    log_probs = m.log_prob(action)

    value = critic(state.unsqueeze(dim=0))

    return action.item(), value, log_probs
Example #31
0
def sample_relax(x, logits, surrogate):
    """Gumbel-max hard sample plus RELAX control-variate terms.

    Returns (b, logprob, c(z), c(z_tilde), z, z_tilde, gumbels, u) where the
    surrogate c is evaluated on [z, x, logits.detach()] and on
    [z_tilde, x, logits.detach()].
    """
    B, C = logits.shape[0], logits.shape[1]

    cat = Categorical(logits=logits)

    # u -> Gumbel noise -> perturbed logits z -> hard sample b.
    u = myclamp(torch.rand(B, C).cuda())
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    # Surrogate at the unconditional relaxation z.
    cz = surrogate.net(torch.cat([z, x, logits.detach()], dim=1))

    # Surrogate at the relaxation conditioned on the sampled class b.
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, logits.detach()], dim=1))

    return b, logprob, cz, cz_tilde, z, z_tilde, gumbels, u
Example #32
0
    def forward(self):
        """Autoregressively sample one operation per edge with the LSTM controller.

        Returns (sum of log-probs, sum of entropies, decoded architecture).
        """
        inputs, hidden = self.input_vars, None
        log_probs = []
        entropies = []
        arch = []
        for _ in range(self.num_edge):
            outputs, hidden = self.w_lstm(inputs, hidden)

            # Temperature-scaled, tanh-bounded logits for this edge's op choice.
            logits = self.tanh_constant * torch.tanh(
                self.w_pred(outputs) / self.temperature)

            dist = Categorical(logits=logits)
            choice = dist.sample()
            arch.append(choice.item())

            log_probs.append(dist.log_prob(choice).view(-1))
            entropies.append(dist.entropy().view(-1))

            # The chosen op's embedding is the input for the next step.
            inputs = self.w_embd(choice)
        return (torch.sum(torch.cat(log_probs)),
                torch.sum(torch.cat(entropies)),
                self.convert_structure(arch))
    # --- REINFORCE gradient estimate for a 2-way categorical ---
    # Fragment: the enclosing function header is not visible in this chunk; it
    # relies on outer names bern_param, n, f, reinforce_cat_grad_means/stds.
    # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
    # dist_bernoulli = Bernoulli(bern_param)
    C= 2
    n_components = C
    B=1
    # Build the two-class probability vector [1 - p, p] from the Bernoulli
    # parameter (the torch.ones placeholder is immediately overwritten).
    probs = torch.ones(B,C)
    bern_param = bern_param.view(B,1)
    aa = 1 - bern_param
    probs = torch.cat([aa, bern_param], dim=1)

    cat = Categorical(probs= probs)

    grads = []
    for i in range(n):
        b = cat.sample()
        logprob = cat.log_prob(b.detach())
        # b_ = torch.argmax(z, dim=1)

        # Score-function estimator: grad = f(b) * d log q(b) / d bern_param.
        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        grad = f(b) * logprobgrad

        grads.append(grad[0][0].data.numpy())

    # ('Reinfoce' typo is in the original runtime string; left as-is.)
    print ('Grad Estimator: Reinfoce categorical')
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
    print ()

    reinforce_cat_grad_means.append(np.mean(grads))
    reinforce_cat_grad_stds.append(np.std(grads))
Example #34
0
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    """Hybrid LAX/REINFORCE gradient estimator for a categorical latent.

    Mixes a relaxed (LAX-style, surrogate-controlled) estimator with a plain
    REINFORCE estimator per sample, weighted by alpha = sigmoid(surrogate2(x)).
    Averages over k inner samples.

    Returns (net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif,
    grad_path, grad_score, mean alpha) — the last sample's per-sample tensors
    are returned for the diagnostics.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    # Relaxed distribution for the continuous sample, exact categorical for
    # the hard sample's log-probability.
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        # cluster_S: relaxed (soft) sample; cluster_H: its hard discretization.
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)


        # Joint log p(x, z) under the mixture; f_* are the learning signals
        # (ELBO terms minus a baseline of 1).
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        # alpha gates between the surrogate-controlled term and plain REINFORCE.
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += - torch.mean(     alpha.detach()*(f_z.detach()  - surr_pred.detach()) * logq_z  
                                    + alpha.detach()*surr_pred 
                                    + (1-alpha.detach())*(f_b.detach()  ) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        # Per-term gradient magnitudes w.r.t. logits; the surrogate is trained
        # to minimize the squared total gradient (variance-style objective).
        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b =  torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # fsdfa
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # fsdfasd
        surr_loss += torch.mean(
                                    (alpha*(f_z.detach() - surr_pred) * grad_logq_z 
                                    + alpha*grad_surr
                                    + (1-alpha)*(f_b.detach()) * grad_logq_b )**2
                                    )

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        # fdsf
        # Diagnostics: magnitude of the pathwise vs. score-function gradient parts.
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))


    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
def sample_reinforce_given_class(logits, samp):
    """Return log q(samp | logits) under a Categorical — the REINFORCE score term."""
    return Categorical(logits=logits).log_prob(samp)
Example #36
0
 def get_action(self, x, action=None):
     """Sample an action from the actor head (or score a provided one).

     Returns (action, log_prob, entropy).
     """
     dist = Categorical(logits=self.actor(self.forward(x)))
     chosen = dist.sample() if action is None else action
     return chosen, dist.log_prob(chosen), dist.entropy()
Example #37
0
if __name__ == "__main__":

    # REINFORCE-with-baseline data collection: roll out episodes, then compute
    # returns and advantages. (The parameter-update step is not visible in this
    # chunk.) Relies on module-level env, actor, critic, G, DISCOUNT, tensor.
    for episode in range(NUM_EPISODES):
        s, done = env.reset(), False
        states, rewards, log_probs = [], [], []

        while not done:
            s = torch.from_numpy(s).float()
            # NOTE(review): actor(s) is passed positionally, i.e. as `probs` —
            # assumes the actor outputs normalized probabilities, not logits.
            p = Categorical(actor(s))
            a = p.sample()
            with torch.no_grad():
                succ, r, done, _ = env.step(a.numpy())

            states.append(s)
            rewards.append(r)
            log_probs.append(p.log_prob(a))

            s = succ

        # Discount each reward by its timestep, then form returns-to-go.
        discounted_rewards = [DISCOUNT**t * r for t, r in enumerate(rewards)]
        cumulative_returns = [
            G(discounted_rewards, t) for t, _ in enumerate(discounted_rewards)
        ]

        states = torch.stack(states)
        state_values = critic(states).reshape(-1)

        # Advantage = return - critic baseline.
        cumulative_returns = tensor(cumulative_returns)
        Adv = cumulative_returns - state_values

        log_probs = torch.stack(log_probs).reshape(-1)
Example #38
0
    def run(self, episodes, steps, train=False, render_once=1e10, saveonce=10):
        """Run `episodes` rollouts of `steps` steps each.

        When `train` is True, rewards/log-probs are stored in self.trainer and
        an update is applied after every episode; the recorder checkpoints the
        model every `saveonce` episodes when the reward improves. `render_once`
        controls how often drawing is enabled during training.
        """
        if train:
            assert self.recorder.log_message is not None, "log_message is necessary during training, Instantiate Runner with log message"

        # Recurrent ("mem") models get their hidden state reset each episode.
        reset_model = False
        if hasattr(self.model, "type") and self.model.type == "mem":
            print("Recurrent Model")
            reset_model = True
        self.env.display_neural_image = self.visual_activations
        for _ in range(episodes):

            self.env.reset()
            # Draw every episode when evaluating; only every render_once-th when training.
            self.env.enable_draw = True if not train or _ % render_once == render_once - 1 else False

            if reset_model:
                self.model.reset()

            state = self.env.get_state().reshape(-1)
            bar = tqdm(range(steps),
                       bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
            trewards = 0

            for step in bar:

                # Policy forward pass; model output is used directly as
                # Categorical probabilities (assumes it is normalized).
                state = T.from_numpy(state).float()
                actions = self.model(state)

                c = Categorical(actions)
                action = c.sample()
                log_prob = c.log_prob(action)

                # One-hot action vector for the environment.
                u = np.zeros(self.nactions)
                u[action] = 1.0
                newstate, reward = self.env.act(u)
                state = newstate.reshape(-1)
                trewards += reward

                if train:
                    self.trainer.store_records(reward, log_prob)

                if self.visual_activations:
                    # Flatten collected activations for the env's visualization.
                    u = T.cat(self.activations, dim=0).reshape(-1)
                    self.env.neural_image_values = u.detach().numpy()
                    self.activations = []
                    # NOTE(review): `step / steps == 0` is only true at step 0,
                    # so weights refresh at step 0 of every 10th episode —
                    # possibly `step % steps == 0` was intended; confirm.
                    if _ % 10 == 0 and step / steps == 0:
                        self.update_weights()
                        self.env.neural_weights = self.weights
                        self.env.weight_change = True
                    if type(self.model.hidden_vectors) != type(None):
                        self.env.hidden_state = self.model.hidden_vectors

                bar.set_description(f"Episode: {_:4} Rewards : {trewards}")
                if train:
                    self.env.step()
                else:
                    self.env.step(speed=0)

            if train:
                # End-of-episode policy update and logging/checkpointing.
                self.trainer.update()
                self.trainer.clear_memory()
                self.recorder.newdata(trewards)
                if _ % saveonce == saveonce - 1:
                    self.recorder.save()
                    self.recorder.plot()

                if _ % saveonce == saveonce - 1 and self.recorder.final_reward >= self.current_max_reward:
                    self.recorder.save_model(self.model)
                    self.current_max_reward = self.recorder.final_reward
        print("******* Run Complete *******")
Example #39
0
    # --- REINFORCE training loop fragment ---
    # The enclosing function header is not visible in this chunk; relies on
    # outer names n_steps, batch_size, optim, encoder, sample_true,
    # logprob_undercomponent, needsoftmax_mixtureweight.
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            # print (logits.shape)
            # print (torch.softmax(logits, dim=0))
            # fsfd
            # softmax over dim=0: logits here appear to be a flat vector per sample.
            cat = Categorical(probs= torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            # print (logprob_cluster)
            pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
            # f = log p(x,z) - log q(z|x): the per-sample learning signal.
            f = pxz - logprob_cluster
            # print (f)
            # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
            # Score-function loss for the encoder; -f for the model parameters.
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        loss = loss / batch_size
        net_loss = net_loss / batch_size

        # print (loss, net_loss)

        loss.backward(retain_graph=True)
        optim.step()
def simplax():



    def show_surr_preds():
        """Plot surrogate predictions c(z, x) against the true log p(x|z) over a
        grid of x values, for a few sampled clusters; saves to exp_dir."""

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            # Draw one x, encode it, and take a relaxed cluster sample.
            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            # Evaluate the surrogate and the true log-density along a grid of x,
            # holding the sampled z (and hard cluster) fixed.
            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            # print (row)
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()


        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()




    def plot_dist():
        """Plot each Gaussian component scaled by its current (learned) mixture
        weight, plus their sum, and save the figure to exp_dir."""


        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)


        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))

        # C = 20
        for c in range(n_components):
            # Component c: Normal with mean 10*c and std 5.
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            for x in xs:
                # component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()


                ys.append(component_i)

            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')

        ax.plot(xs, sum_, label='')

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()
        


    def get_loss():
        """One surrogate-training loss: E|log p(x,z) - c(z_S, x)| over a fresh
        batch, with z_S a relaxed (Gumbel-softmax) sample from q(z|x)."""

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        # L1 regression of the surrogate onto the (detached) joint log-density.
        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))

        return surr_loss


    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        """Bar-plot the true posterior over components vs. the encoder's q(z|x)
        for one sampled x; the title shows the L2 error and KL. Saves to exp_dir."""

        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)

        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()

        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)


        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        # Side-by-side bars: true posterior vs. the encoder's approximation.
        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        # ax.bar(np.array(range(len(q_b)))+width+width, q_b, width=width)
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()
        



    def inference_error(needsoftmax_mixtureweight):
        """Average L2 error and KL between the true posterior and the encoder's
        q(z|x) over n=10 freshly sampled data points.

        Returns (mean_L2_error, mean_KL).
        """
        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):

            # if x is None:
            x = sample_true(1).cuda()
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)

            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

            error_sum+=error
            kl_sum += kl

        return error_sum/n, kl_sum/n
        # fsdfa



    #SIMPLAX
    # --- Main training loop: learn mixture weights + encoder by REINFORCE,
    # with a surrogate fit alongside for gradient diagnostics. ---
    # Mixture weights are optimized in logit space; softmax recovers the simplex.
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")#.cuda()

    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()

    # q(z|x) encoder and surrogate c(probs, x) used in the diagnostics below.
    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.00004)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0004)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.004)
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.0001)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0001)
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.005)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    # NOTE(review): surrugate_steps == 0 means the pretraining loop below never
    # runs (the optim_surr it references is commented out above) — confirm intended.
    surrugate_steps = 0
    k = 1
    # Histories for the periodic diagnostic plots.
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    # logprob_cluster_list = []
    for step in range(n_steps):

        # Optional surrogate pretraining (disabled: surrugate_steps == 0).
        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        # print (probs)
        # fsdafsa
        # cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        cat = Categorical(probs=probs)
        # cluster = cat.sample()
        # logprob_cluster = cat.log_prob(cluster.detach())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            # cluster_S = cat.rsample()
            # cluster_H = H(cluster_S)
            cluster_H = cat.sample()

            # print (cluster_H.shape)
            # print (cluster_H[0])
            # fsad


            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            # logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1).detach()

            # cat = RelaxedOneHotCategorical(probs=probs.detach(), temperature=torch.tensor([temp]).cuda())
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1)

            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster
            # print (f)

            # The surrogate regresses onto logpxz (see surr_loss below) but is
            # NOT used in the net_loss estimator in this configuration.
            # surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            surr_input = torch.cat([probs, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)

            # print (f.shape)
            # print (surr_pred.shape)
            # print (logprob_cluster.shape)
            # fsadfsa
            # net_loss += - torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster + surr_pred - logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach() - surr_pred.detach() - 1.) * logprob_cluster + surr_pred)
            # Plain REINFORCE with a constant baseline of 1 for the encoder.
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach()) * logprob_cluster - logprob_cluster)
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))

        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k



        # if step %2==0:
        # optim.zero_grad()
        # loss.backward(retain_graph=True)
        # optim.step()

        # Only the encoder is updated here; the mixture-weight and surrogate
        # optimizer steps are commented out in this configuration.
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()

        # optim_surr.zero_grad()
        # surr_loss.backward(retain_graph=True)
        # optim_surr.step()

        # print (torch.mean(f).cpu().data.numpy())
        # plot_posteriors(name=str(step))
        # fsdf

        # kl_batch = compute_kl_batch(x,probs)


        # Periodic logging, metric collection, and diagnostic plots.
        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(), 
                            'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax(
                                        needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
            # if step %5000==0:
            #     print (torch.softmax(needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()) 
            #     # test_samp, test_cluster = sample_true2() 
            #     # print (test_cluster.cpu().data.numpy(), test_samp.cpu().data.numpy(), torch.softmax(encoder.net(test_samp.cuda().view(1,1)), dim=1))           
            #     print ()

            if step > 0:
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax(
                                            needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())

                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)

                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)

                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)

                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())




                # i_feel_like_it = 1
                # if i_feel_like_it:

                # NOTE(review): len(inf_losses) > 0 is always true here (an
                # element was just appended) — the else branch is dead code.
                if len(inf_losses) > 0:
                    print ('probs', probs[0])
                    print('logpxz', logpxz[0])
                    print('pred', surr_pred[0])
                    print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                    print ('logq', logprob_cluster[0])
                    print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])



                    # Compare the magnitudes of the REINFORCE-style gradient,
                    # the surrogate (reparam-style) gradient, and the logq grad,
                    # all taken w.r.t. probs.
                    output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0] 
                    output2 = torch.mean(surr_pred, dim=0)[0]
                    output3 = torch.mean(logprob_cluster, dim=0)[0]
                    # input_ = torch.mean(probs, dim=0) #[0]
                    # print (probs.shape)
                    # print (output.shape)
                    # print (input_.shape)
                    grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                    grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                    grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]
                    # print (grad)
                    # print (grad_reinforce.shape)
                    # print (grad_reparam.shape)
                    grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                    grad_reparam = torch.mean(torch.abs(grad_reparam))
                    grad3 = torch.mean(torch.abs(grad3))
                    # print (grad_reinforce)
                    # print (grad_reparam)
                    # dfsfda
                    grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                    grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())
                    # grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())

                    print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                    print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                    print ('logqz grad:', grad3.cpu().data.detach().numpy())

                    print ('current mixuture weights')
                    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                    print()

                    # print ()
                else:
                    grad_reparam_list.append(0.)
                    grad_reinforce_list.append(0.)



            if len(surr_losses) > 3  and step %1000==0:
                plot_curve(steps=steps_list,  thetaloss=L2_losses, 
                            infloss=inf_losses, surrloss=surr_losses,
                            grad_reinforce_list=grad_reinforce_list, 
                            grad_reparam_list=grad_reparam_list,
                            f_list=f_list, logpxz_list=logpxz_list,
                            logprob_cluster_list=logprob_cluster_list,
                            inf_losses_kl=inf_losses_kl,
                            kl_losses_2=kl_losses_2,
                            logpx_list=logpx_list)


                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()


            # print (f)
            # print (surr_pred)

            #Understand surr preds
            # if step %5000==0:

            # if step ==0:

                # fasdf





    # Persist the collected curves for later analysis.
    data_dict = {}

    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
Example #41
0
    def forward(self, prev_tokens: Dict[str, torch.LongTensor],
                prev_tags: Dict[str, torch.LongTensor],
                fol_tokens: Dict[str, torch.LongTensor],
                fol_tags: Dict[str, torch.LongTensor],
                prev_labels: torch.Tensor = None,
                fol_labels: torch.Tensor = None,
                conflicts: List[Any] = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        prev_mask = get_text_field_mask(prev_tokens)
        # embedding sequence
        prev_embedding_seq = self.token_field_embedding(prev_tokens)
        # embedding tag
        prev_tag_embedding = self.char_field_embedding(prev_tags)

        fol_mask = get_text_field_mask(fol_tokens)
        # embedding sequence
        fol_embedding_seq = self.token_field_embedding(fol_tokens)
        # embedding tag
        fol_tag_embedding = self.char_field_embedding(fol_tags)

        batch_size, _ = prev_mask.size()

        # initialization in specific gpu devices
        gpu_device = prev_embedding_seq.device

        prev_phrase_tensor = torch.tensor([0.0], device=gpu_device)
        fol_phrase_tensor = torch.tensor([1.0], device=gpu_device)

        prev_phrase_embedding_seq = prev_phrase_tensor.repeat(
            prev_embedding_seq.size(0),
            prev_embedding_seq.size(1),
            1
        )

        fol_phrase_embedding_seq = fol_phrase_tensor.repeat(
            fol_embedding_seq.size(0),
            fol_embedding_seq.size(1),
            1
        )

        # concat embedding and phrase
        prev_embedding_seq = torch.cat([prev_embedding_seq, prev_phrase_embedding_seq, prev_tag_embedding],
                                       dim=2)
        fol_embedding_seq = torch.cat([fol_embedding_seq, fol_phrase_embedding_seq, fol_tag_embedding], dim=2)

        prev_embedding_seq = self.projection_layer(prev_embedding_seq)
        fol_embedding_seq = self.projection_layer(fol_embedding_seq)

        # embedding phrase label 0 means prev, 1 means follow-up
        if self.training:
            embedding = torch.cat([prev_embedding_seq, fol_embedding_seq], dim=1)
            embedding_var = self._variational_dropout(embedding)
            prev_mask_len = prev_mask.size(1)
            prev_embedding_seq_var = embedding_var[:, :prev_mask_len]
            fol_embedding_seq_var = embedding_var[:, prev_mask_len:]
        else:
            prev_embedding_seq_var = prev_embedding_seq
            fol_embedding_seq_var = fol_embedding_seq

        # encode sequence
        prev_encoder_out = self.tokens_encoder(prev_embedding_seq_var, prev_mask)
        fol_encoder_out = self.tokens_encoder(fol_embedding_seq_var, fol_mask)

        prev_forward_output = prev_encoder_out[:, :, :self.hidden_size]
        prev_backward_output = prev_encoder_out[:, :, self.hidden_size:]

        fol_forward_output = fol_encoder_out[:, :, :self.hidden_size]
        fol_backward_output = fol_encoder_out[:, :, self.hidden_size:]

        prev_attn_mask = prev_mask.view(batch_size, -1, 1) * fol_mask.view(batch_size, 1, -1)
        prev_forward_attn_matrix = self._self_attention(prev_forward_output, fol_forward_output) / self._scaled_value
        prev_backward_attn_matrix = self._self_attention(prev_backward_output, fol_backward_output) / self._scaled_value
        prev_mean_pooling_attn = util.masked_softmax(prev_forward_attn_matrix + prev_backward_attn_matrix,
                                                     prev_attn_mask)

        # take max pooling rather than average
        prev_attn_vec = torch.matmul(prev_mean_pooling_attn, fol_encoder_out)

        fol_attn_mask = fol_mask.view(batch_size, -1, 1) * prev_mask.view(batch_size, 1, -1)
        fol_forward_attn_matrix = self._self_attention(fol_forward_output, prev_forward_output) / self._scaled_value
        fol_backward_attn_matrix = self._self_attention(fol_backward_output, prev_backward_output) / self._scaled_value
        fol_mean_pooling_attn = util.masked_softmax(fol_forward_attn_matrix + fol_backward_attn_matrix, fol_attn_mask)

        # take max pooling rather than average
        fol_attn_vec = torch.matmul(fol_mean_pooling_attn, prev_encoder_out)

        # non_linear_output = self._non_linear(torch.cat([encoder_out, self_attention_vec], dim=2))
        # prev_linear = torch.cat([prev_encoder_out, prev_attn_vec], dim=2)
        # fol_linear = torch.cat([fol_encoder_out, fol_attn_vec], dim=2)
        prev_attn_multiply = prev_encoder_out * prev_attn_vec
        zero_tensor = torch.zeros((batch_size, 1, prev_attn_multiply.size(2)), device=gpu_device, dtype=torch.float)
        prev_attn_shift = torch.cat((zero_tensor,
                                     prev_attn_multiply[:, :-1, :]), dim=1)
        # shift attn vector to right, and then subtract them
        prev_linear = torch.cat([prev_encoder_out, prev_attn_multiply, prev_attn_shift], dim=2)

        fol_attn_multiply = fol_encoder_out * fol_attn_vec
        fol_attn_shift = torch.cat((zero_tensor,
                                    fol_attn_multiply[:, :-1, :]), dim=1)
        # shift attn vector to right, and then subtract them
        fol_linear = torch.cat([fol_encoder_out, fol_attn_multiply, fol_attn_shift], dim=2)

        prev_tag_logistics = self.policy_net(prev_linear)
        fol_tag_logistics = self.policy_net(fol_linear)

        # project to space
        prev_tag_prob = F.softmax(prev_tag_logistics, dim=2)
        prev_predict_labels = torch.argmax(prev_tag_prob, dim=2)

        fol_tag_prob = F.softmax(fol_tag_logistics, dim=2)
        fol_predict_labels = torch.argmax(fol_tag_prob, dim=2)

        predict_restate_str_list = []
        predict_restate_tag_list = []
        max_bleu_list = []

        # debug information
        _debug_batch_conflict_map = {}

        # using predict labels to cut utterance into span and fetch representations of span
        for batch_ind in range(batch_size):
            _debug_batch_conflict_map[batch_ind] = []

            # batch reference object
            batch_origin_obj = metadata[batch_ind]["origin_obj"]

            prev_start_end, fol_start_end = predict_span_start_end(
                prev_predict_labels[batch_ind, :sum(prev_mask[batch_ind])],
                fol_predict_labels[batch_ind, :sum(fol_mask[batch_ind])])

            # Phase 2: Predict actual fusion str via span start/end and similar gate
            predict_restate_str, predict_restate_tag \
                = self.predict_restate(batch_origin_obj,
                                       fol_start_end,
                                       prev_start_end,
                                       prev_forward_output,
                                       prev_backward_output,
                                       fol_forward_output,
                                       fol_backward_output,
                                       batch_ind,
                                       gpu_device,
                                       _debug_batch_conflict_map)

            # add it to batch
            predict_restate_str_list.append(predict_restate_str)
            predict_restate_tag_list.append(predict_restate_tag)

        batch_golden_restate_str = [" ".join(single_metadata["origin_obj"]["restate"].utterance)
                                    for single_metadata in metadata]

        batch_golden_restate_tag = [single_metadata["origin_obj"]["restate"].tags
                                    for single_metadata in metadata]
        output = {
            "probs": prev_tag_prob,
            "prev_labels": prev_predict_labels,
            "fol_labels": fol_predict_labels,
            "restate": predict_restate_str_list,
            "max_bleu": max_bleu_list
        }

        avg_bleu = self.metrics["bleu"](predict_restate_str_list, batch_golden_restate_str)
        avg_symbol = self.metrics["symbol"](predict_restate_tag_list, batch_golden_restate_tag)

        # overall measure
        self.metrics["overall"]([0.4 * avg_bleu + 0.6 * avg_symbol] * batch_size)

        conflict_confidences = []

        # condition on training to
        if self.training:
            if prev_labels is not None:

                labels = torch.cat([prev_labels, fol_labels], dim=1)
                # Initialization pre-training with longest common string
                logistics = torch.cat([prev_tag_logistics, fol_tag_logistics], dim=1)
                mask = torch.cat([prev_mask, fol_mask], dim=1)
                loss_snippet = sequence_cross_entropy_with_logits(logistics, labels, mask,
                                                                  label_smoothing=0.2)

                # for pre-training, we regard them as optimal ground truth
                conflict_confidences = [1.0] * batch_size
            else:
                if DEBUG:
                    rl_sample_count = 1
                else:
                    rl_sample_count = 20

                batch_loss_snippet = []
                batch_sample_conflicts = []

                # Training Phase 2: train conflict model via margin loss
                for batch_ind in range(batch_size):

                    dynamic_conflicts = []
                    dynamic_confidence = []

                    # batch reference object
                    batch_origin_obj = metadata[batch_ind]["origin_obj"]

                    prev_mask_len = prev_mask[batch_ind].sum().view(1).data.cpu().numpy()[0]
                    fol_mask_len = fol_mask[batch_ind].sum().view(1).data.cpu().numpy()[0]

                    sample_data = []

                    for _ in range(rl_sample_count):
                        prev_multi = Categorical(logits=prev_tag_logistics[batch_ind])
                        fol_multi = Categorical(logits=fol_tag_logistics[batch_ind])

                        prev_label_tensor = prev_multi.sample()
                        prev_label_tensor.data[0].fill_(1)
                        prev_sample_label = prev_label_tensor.data.cpu().numpy().astype(int)[:prev_mask_len]

                        fol_label_tensor = fol_multi.sample()
                        fol_label_tensor.data[0].fill_(1)
                        fol_sample_label = fol_label_tensor.data.cpu().numpy().astype(int)[:fol_mask_len]

                        log_prob = torch.cat(
                            [prev_multi.log_prob(prev_label_tensor), fol_multi.log_prob(fol_label_tensor)],
                            dim=-1)

                        conflict_prob_mat = self.calculate_conflict_prob_matrix(prev_sample_label,
                                                                                fol_sample_label,
                                                                                batch_ind,
                                                                                prev_forward_output,
                                                                                prev_backward_output,
                                                                                fol_forward_output,
                                                                                fol_backward_output,
                                                                                gpu_device)
                        self.policy_net.saved_log_probs.append(log_prob)
                        sample_data.append((prev_sample_label, fol_sample_label, batch_origin_obj, conflict_prob_mat))

                    if DEBUG:
                        ret_data = [sample_action(row) for row in sample_data]
                    else:
                        # Parallel to speed up the sampling process
                        with ThreadPool(4) as p:
                            chunk_size = rl_sample_count // 4
                            ret_data = p.map(sample_action, sample_data, chunksize=chunk_size)

                    for conflict_confidence, reinforce_reward, conflict_pair in ret_data:
                        self.policy_net.rewards.append(reinforce_reward)
                        dynamic_conflicts.append(conflict_pair)
                        dynamic_confidence.append(conflict_confidence)

                    rewards = torch.tensor(self.policy_net.rewards, device=gpu_device).float()
                    self.metrics["reward"](self.policy_net.rewards)
                    rewards -= rewards.mean().detach()
                    self.metrics["reward_var"]([rewards.std().data.cpu().numpy()])

                    loss_snippet = []
                    # reward high, optimize it; reward low, reversal optimization
                    for log_prob, reward in zip(self.policy_net.saved_log_probs,
                                                rewards):
                        loss_snippet.append((- log_prob * reward).unsqueeze(0))

                    loss_snippet = torch.cat(loss_snippet).mean(dim=1).sum().view(1)
                    batch_loss_snippet.append(loss_snippet)

                    # random select one
                    best_conflict_id = choice(range(rl_sample_count))
                    # best_conflict_id = np.argmax(self.policy_net.rewards)
                    batch_sample_conflicts.append(dynamic_conflicts[best_conflict_id])
                    conflict_confidences.append(dynamic_confidence[best_conflict_id])

                    self.policy_net.reset()

                loss_snippet = torch.cat(batch_loss_snippet).mean()

                # according to confidence
                conflicts = []
                for conflict_batch_id in range(batch_size):
                    conflicts.append(batch_sample_conflicts[conflict_batch_id])

            # Training Phase 1: train snippet model
            total_loss = loss_snippet

            border = torch.tensor([0.0], device=gpu_device)
            pos_target = torch.tensor([1.0], device=gpu_device)
            neg_target = torch.tensor([-1.0], device=gpu_device)

            # Training Phase 2: train conflict model via margin loss

            loss_conflict = torch.tensor([0.0], device=gpu_device)[0]
            # random decision on which to use

            for batch_ind in range(0, batch_size):
                batch_conflict_list = conflicts[batch_ind]
                # use prediction results to conflict

                temp_loss_conflict = torch.tensor([0.0], device=gpu_device)[0]

                if batch_conflict_list and len(batch_conflict_list) > 0:
                    for conflict in batch_conflict_list:
                        (prev_start, prev_end), (fol_start, fol_end), conflict_mode = conflict

                        fol_span_repr = get_span_repr(fol_forward_output[batch_ind],
                                                      fol_backward_output[batch_ind],
                                                      fol_start, fol_end)

                        prev_span_repr = get_span_repr(prev_forward_output[batch_ind],
                                                       prev_backward_output[batch_ind],
                                                       prev_start, prev_end)

                        inter_prob = self.cosine_similar(fol_span_repr, prev_span_repr).view(1)
                        # actual conflict
                        if conflict_mode == 1:
                            temp_loss_conflict += self.margin_loss(inter_prob,
                                                                   border,
                                                                   pos_target)
                        else:
                            temp_loss_conflict += self.margin_loss(inter_prob,
                                                                   border,
                                                                   neg_target)

                    temp_confidence = conflict_confidences[batch_ind]
                    loss_conflict += temp_confidence * temp_loss_conflict / len(batch_conflict_list)

            loss_conflict = loss_conflict / batch_size

            # for larger margin
            total_loss += loss_conflict

            output["loss"] = total_loss

        return output
Example #42
0
def compute_loss_pi(states, actions, advs):
    """Policy-gradient loss: negative sum of advantage-weighted log-probs.

    ``pi`` is the module-level policy network; it maps states to action
    probabilities (what older APIs called a multinomial).
    """
    action_probs = pi(states)
    dist = Categorical(action_probs)
    weighted_logp = dist.log_prob(actions) * advs
    return -weighted_logp.sum()
Example #43
0
def evaluate_actions(pi, actions):
    """Score ``actions`` under the categorical distribution given by ``pi``.

    Returns a ``(..., 1)``-shaped tensor of log-probabilities and the mean
    entropy of the distribution.
    """
    dist = Categorical(pi)
    log_probs = dist.log_prob(actions.squeeze(-1)).unsqueeze(-1)
    return log_probs, dist.entropy().mean()



# REINFORCE (score-function) gradient estimate of E[f(b)]:
#   grad = E[ f(b) * d/dtheta log p(b) ], estimated with N Monte-Carlo samples.
print ('REINFORCE')

grads = []
for _ in range(N):
    dist = Categorical(logits=logits)
    action = dist.sample()
    log_p = dist.log_prob(action)
    reward = f(action)
    # d log p(action) / d logits; keep the graph alive for the next sample.
    score = torch.autograd.grad(outputs=log_p, inputs=(logits), retain_graph=True)[0]
    grads.append(reward * score)

print ()
grads = torch.stack(grads).view(N, C)
grad_mean_reinforce = grads.mean(dim=0)
grad_std_reinforce = grads.std(dim=0)

print ('REINFORCE')
print ('mean:', grad_mean_reinforce)
print ('std:', grad_std_reinforce)
print ()
Example #45
0
class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution over ``probs.size(-1)`` categories,
    parameterized by :attr:`probs` or :attr:`logits` (but not both).

    Each sample is a one-hot coded vector of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # All index-space work is delegated to an internal Categorical.
        self._categorical = Categorical(probs, logits)
        super(OneHotCategorical, self).__init__(
            self._categorical.batch_shape,
            self._categorical.param_shape[-1:],
            validate_args=validate_args)

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        # E[one_hot(b)] is exactly the category probability vector.
        return self._categorical.probs

    @property
    def variance(self):
        # Per-coordinate Bernoulli variance: p * (1 - p).
        p = self._categorical.probs
        return p * (1 - p)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw category indices, then scatter them into one-hot vectors."""
        shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(shape)).zero_()
        indices = self._categorical.sample(shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Recover the category index from the one-hot encoding via argmax.
        return self._categorical.log_prob(value.max(-1)[1])

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return all ``n`` one-hot vectors, broadcast over the batch shape."""
        n = self.event_shape[0]
        eye = self._new((n, n))
        torch.eye(n, out=eye)
        eye = eye.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return eye.expand((n,) + self.batch_shape + (n,))
def sample_reinforce_given_class(logits, samp):
    """Log-probability of the fixed sample ``samp`` under Categorical(logits)."""
    return Categorical(logits=logits).log_prob(samp)