Exemplo n.º 1
0
def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """Differentiate a RELAX-style loss w.r.t. ``logits`` for a given sample ``b``.

    The loss combines a score-function term weighted by
    ``f - c(z_tilde)`` with the surrogate evaluated on an unconditional
    relaxed sample ``z`` and on a sample ``z_tilde`` obtained from
    ``sample_relax_given_b(logits, b)``.

    Args:
        x: data batch concatenated (along dim 1) into the surrogate input.
        logits: 2-D tensor; its first two dimensions are read as (B, C).
            Must be usable as an autograd input.
        b: category indices whose log-probability is reshaped to [B, 1].
        surrogate: object exposing a callable ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain the prior weights.

    Returns:
        Tuple ``(grad, pb)`` where ``grad`` is the gradient of the loss
        w.r.t. ``logits`` ([B, C]) and ``pb`` is ``exp(log q(b))``.
    """
    B, C = logits.shape[0], logits.shape[1]

    def surrogate_value(relaxed):
        # The surrogate sees the relaxed sample, the data and detached logits.
        return surrogate.net(torch.cat([relaxed, x, logits.detach()], dim=1))

    cat = Categorical(logits=logits)
    uniform = myclamp(torch.rand(B, C).cuda())
    z = logits + (-torch.log(-torch.log(uniform)))
    logq = cat.log_prob(b).view(B, 1)

    cz = surrogate_value(z)
    cz_tilde = surrogate_value(sample_relax_given_b(logits, b))

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B, 1)
    f = (logpx_given_z + logpz) - logq  # [B,1]

    # f and c(z_tilde) are detached inside the score-function term so that
    # only log q(b) carries gradient there.
    score_term = (f.detach() - cz_tilde.detach()) * logq
    net_loss = -torch.mean(score_term - logq + cz - cz_tilde)

    grad = torch.autograd.grad(
        [net_loss], [logits], create_graph=True, retain_graph=True)[0]  # [B,C]
    return grad, torch.exp(logq)
Exemplo n.º 2
0
    def sample_relax(logits, surrogate):
        """Draw a Gumbel-max sample and evaluate the surrogate on z and z_tilde.

        Reads ``B``, ``C`` and ``x`` from the enclosing scope.

        Returns:
            ``(b, logprob, cz, cz_tilde)``: the sampled indices, their
            log-probability reshaped to [B, 1], the surrogate output on the
            unconditional relaxed sample, and the surrogate output on the
            sample returned by ``sample_relax_given_b(logits, b)``.
        """
        def surrogate_value(relaxed):
            return surrogate.net(
                torch.cat([relaxed, x, logits.detach()], dim=1))

        cat = Categorical(logits=logits)
        uniform = torch.rand(B, C).clamp(1e-10, 1.-1e-10).cuda()
        z = logits + (-torch.log(-torch.log(uniform)))
        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B, 1)

        cz = surrogate_value(z)

        # A single conditional sample is drawn; the stack/mean keeps the
        # structure used for averaging over several samples.
        cz_tildes = torch.stack([
            surrogate_value(sample_relax_given_b(logits, b))
            for _ in range(1)
        ])
        cz_tilde = torch.mean(cz_tildes, dim=0)

        return b, logprob, cz, cz_tilde
def sample_relax(logits):
    """Sample ``b`` with the Gumbel-max trick plus a conditional relaxation.

    Relies on the module-level ``B`` (number of samples) and ``C`` (number of
    classes).

    The class probabilities are combined with the noise by broadcasting, so
    ``logits`` may be either ``[1, C]`` or ``[B, C]``. (The previous
    ``.repeat(B, 1)`` produced a ``[B*B, C]`` tensor for batched logits.)

    Args:
        logits: 2-D tensor of unnormalised log-probabilities.

    Returns:
        ``(z, b, logprob, z_tilde)``: the perturbed logits, their argmax,
        ``log q(b)`` reshaped to [B, 1], and the relaxed sample conditioned
        on ``b``.
    """
    eps = 1e-12

    u = torch.rand(B, C).clamp(eps, 1. - eps)
    z = logits + (-torch.log(-torch.log(u)))
    b = torch.argmax(z, dim=1)
    index = b.view(B, 1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B, 1)

    # Fresh noise is used for the entry of the sampled class. The original
    # author noted that gathering it from `u`/`z` instead "seems biased even
    # though it shouldn't be", so that variant is intentionally not used.
    v_k = torch.rand(B, 1).clamp(eps, 1. - eps)
    z_tilde_b = -torch.log(-torch.log(v_k))

    v = torch.rand(B, C).clamp(eps, 1. - eps)
    probs = torch.softmax(logits, dim=1)  # broadcasts against [B, C] noise
    z_tilde = -torch.log((-torch.log(v) / probs) - torch.log(v_k))

    # Overwrite, per row, the entry of the sampled class.
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, b, logprob, z_tilde
def sample_relax_given_class(logits, samp):
    """Draw two relaxed samples conditioned on the given class indices.

    Both returned samples share the same noise ``u_b`` for the entry of the
    given class; the remaining entries use independent uniform noise.

    The batch size and number of classes are taken from the arguments
    (``samp.numel()`` and ``logits.shape[-1]``) rather than from module-level
    globals.

    Args:
        logits: 2-D tensor of unnormalised log-probabilities.
        samp: integer tensor of class indices, one per batch row.

    Returns:
        ``(z, z_tilde, logprob)`` where ``z`` and ``z_tilde`` are [B, C]
        conditional samples and ``logprob`` is ``log q(samp)`` as [B, 1].
    """
    b = samp
    B = b.numel()
    C = logits.shape[-1]
    index = b.view(B, 1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    # First conditional sample: noise for the given class is gathered from u.
    u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
    u_b = torch.gather(input=u, dim=1, index=index)
    z_tilde_b = -torch.log(-torch.log(u_b))

    z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    z.scatter_(dim=1, index=index, src=z_tilde_b)

    # Second conditional sample: fresh noise elsewhere, same u_b.
    u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, z_tilde, logprob
Exemplo n.º 5
0
def sample_true2():
    """Draw one point from the mixture defined by ``true_mixture_weights``.

    A cluster index is sampled from the categorical over the module-level
    ``true_mixture_weights``; the point is then drawn from a Normal with mean
    ``cluster * 10`` and standard deviation 5.

    Returns:
        ``(samp, cluster)``: the sampled value and the cluster index.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample()

    mean = torch.tensor([cluster*10.]).float()
    std = torch.tensor([5.0]).float()
    samp = Normal(mean, std).sample()
    return samp, cluster
Exemplo n.º 6
0
def sample_gmm(batch_size, mixture_weights):
    """Sample ``batch_size`` points from a 1-D Gaussian mixture.

    Component ``c`` has mean ``10 * c`` and standard deviation 5; components
    are chosen according to ``mixture_weights``.

    The result lives on the CUDA device when one is available (as before) and
    otherwise falls back to the CPU instead of failing.

    Args:
        batch_size: number of points to draw.
        mixture_weights: tensor of component probabilities.

    Returns:
        Tensor of shape ``[batch_size, 1]``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cat = Categorical(probs=mixture_weights)
    cluster = cat.sample([batch_size])  # [B]

    mean = (cluster * 10.).float().to(device)
    std = torch.ones([batch_size]).to(device) * 5.

    samp = Normal(mean, std).sample()
    return samp.view(batch_size, 1)
Exemplo n.º 7
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by `probs`
    (or `logits`; both are forwarded to the wrapped
    :class:`torch.distributions.Categorical`).

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): passed to ``Categorical`` as its second
            positional argument
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        """Wrap a ``Categorical`` and derive batch/event shapes from its probs."""
        self._categorical = Categorical(probs, logits)
        # The last dimension of probs is the event (one-hot) dimension;
        # everything before it is the batch shape.
        batch_shape = self._categorical.probs.size()[:-1]
        event_shape = self._categorical.probs.size()[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape)

    def sample(self, sample_shape=torch.Size()):
        """Return one-hot samples built from indices drawn by the wrapped Categorical."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        # Zero-filled buffer created from `probs` via `.new(...)`.
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        # Add a trailing dimension so the indices can address the last axis.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Return the Categorical log-probability of the argmax index of ``value``."""
        # value.max(-1)[1] is the position of the largest entry on the last axis.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Return the entropy of the wrapped Categorical."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return all ``n`` one-hot vectors, expanded over the batch shape."""
        probs = self._categorical.probs
        n = self.event_shape[0]
        # Build the n x n identity in a buffer created from probs; a Variable
        # is unwrapped to its data first and the result re-wrapped.
        if isinstance(probs, Variable):
            values = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            values = torch.eye(n, out=probs.new(n, n))
        # Shape (n, 1, ..., 1, n) so it can be expanded to (n,) + batch_shape + (n,).
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
Exemplo n.º 8
0
    def test_gmm_loss(self):
        """Fit a 3-component, 2-D Gaussian mixture to synthetic data with ``gmm_loss``.

        Samples are generated from fixed means, standard deviations and
        mixing weights; a small parameter-only model is then optimised with
        Adam. The learned parameters are printed ten times during training
        for visual inspection (the test makes no assertions).

        The progress bar is used as a context manager so it is closed even
        if an iteration raises.
        """
        n_samples = 10000

        means = torch.Tensor([[0., 0.],
                              [1., 1.],
                              [-1., 1.]])
        stds = torch.Tensor([[.03, .05],
                             [.02, .1],
                             [.1, .03]])
        pi = torch.Tensor([.2, .3, .5])

        cat_dist = Categorical(pi)
        indices = cat_dist.sample((n_samples,)).long()
        rands = torch.randn(n_samples, 2)

        # Reparameterised draw from the selected component of each sample.
        samples = means[indices] + rands * stds[indices]

        class _model(nn.Module):
            """Holds free mixture parameters; ``forward`` ignores its inputs."""

            def __init__(self, gaussians):
                super().__init__()
                self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

            def forward(self, *inputs):
                # exp keeps the stds positive; softmax normalises the weights.
                return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

        model = _model(3)
        optimizer = torch.optim.Adam(model.parameters())

        iterations = 100000
        log_step = iterations // 10
        cum_loss = 0
        with tqdm(total=iterations) as pbar:
            for i in range(iterations):
                # Random mini-batch of 128 sample indices (with replacement).
                batch = samples[torch.LongTensor(128).random_(0, n_samples)]
                m, s, p = model.forward()
                loss = gmm_loss(batch, m, s, p)
                cum_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix_str("avg_loss={:10.6f}".format(
                    cum_loss / (i + 1)))
                pbar.update(1)
                if i % log_step == log_step - 1:
                    print(m)
                    print(s)
                    print(p)
Exemplo n.º 9
0
def sample_true(batch_size):
    """Draw ``batch_size`` points from the mixture given by ``true_mixture_weights``.

    Cluster indices come from a categorical over the module-level
    ``true_mixture_weights``; each point is drawn from a Normal with mean
    ``cluster * 10`` and standard deviation 5.

    Returns:
        Tensor of shape ``[batch_size, 1]``.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample([batch_size])  # [B]

    component_mean = (cluster*10.).float()
    component_std = torch.ones([batch_size]) *5.

    drawn = Normal(component_mean, component_std).sample()
    return drawn.view(batch_size, 1)
Exemplo n.º 10
0
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """Score-function estimator with a learned baseline ``surrogate.net(x)``.

    For each of ``k`` samples a cluster is drawn from ``softmax(logits)`` and
    the loss ``-mean((f - baseline) * log q)`` is formed, with
    ``f = log p(x, z) - log q(z) - 1``. The per-sample quantities stored in
    the result come from the last iteration.

    Args:
        surrogate: object exposing a callable ``net`` used as the baseline.
        x: data batch.
        logits: 2-D tensor; ``logits.shape[0]`` is taken as the batch size.
        mixtureweights: tensor indexed by the sampled cluster.
        k: number of samples; must be at least 1.
        get_grad: when true, also collect the gradient of the loss w.r.t.
            ``logits`` for each sample and report summary statistics.

    Returns:
        Dict with keys ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss``, ``surr_loss`` and, if ``get_grad``, ``grad_avg`` and
        ``grad_std``.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    outputs = {}
    grads = []
    for _ in range(k):
        cluster = cat.sample()

        logq = cat.log_prob(cluster).view(B, 1)
        outputs['logq'] = logq
        logpx_given_z = logprob_undercomponent(x, component=cluster)
        outputs['logpx_given_z'] = logpx_given_z
        logpz = torch.log(mixtureweights[cluster]).view(B, 1)
        outputs['logpz'] = logpz

        surr_pred = surrogate.net(x)

        f = (logpx_given_z + logpz) - logq - 1.  # [B,1]
        outputs['f'] = f

        # Baseline and f are detached: only log q carries gradient here.
        net_loss = -torch.mean((f.detach() - surr_pred.detach()) * logq)
        outputs['net_loss'] = net_loss

        # The surrogate is trained on the squared per-logit estimator.
        grad_logq = torch.autograd.grad(
            [torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq) ** 2)

        if get_grad:
            grads.append(torch.autograd.grad(
                [net_loss], [logits], create_graph=True, retain_graph=True)[0])

    if get_grad:
        stacked = torch.stack(grads)
        outputs['grad_avg'] = torch.mean(torch.mean(stacked, dim=0), dim=0)
        outputs['grad_std'] = torch.std(stacked, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
Exemplo n.º 11
0
    def sample_relax(probs):
        """Sample ``b`` via Gumbel-max from ``probs`` plus a conditional relaxation.

        Reads ``B`` and ``C`` from the enclosing scope and allocates noise on
        the CUDA device.

        The entry of the sampled class is written per row with ``scatter_``.
        The previous ``z_tilde[:, b] = z_tilde_b`` indexed the columns of
        *all* sampled classes in *every* row, corrupting other rows' entries
        whenever B > 1.

        Returns:
            ``(z, b, logprob, z_tilde, gumbels)``.
        """
        cat = Categorical(probs=probs)

        # Sample z
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B,1)

        # Sample z_tilde conditioned on b
        u_b = torch.rand(B,1).cuda()
        u_b = u_b.clamp(1e-8, 1.-1e-8)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))
        # Row i receives z_tilde_b[i] only at column b[i].
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        return z, b, logprob, z_tilde, gumbels
def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of relaxed samples conditioned on the given classes.

    In each of the ``k`` rounds two conditional samples are drawn that share
    the noise ``u_b`` for the entry of the given class and use independent
    uniform noise elsewhere. The rounds are then averaged element-wise.

    The batch size and number of classes are taken from the arguments
    (``samp.numel()`` and ``logits.shape[-1]``) rather than from module-level
    globals.

    Args:
        logits: 2-D tensor of unnormalised log-probabilities.
        samp: integer tensor of class indices, one per batch row.
        k: number of rounds to average; must be at least 1 because the
            per-round results are stacked.

    Returns:
        ``(z, z_tilde, logprob)`` where ``z`` and ``z_tilde`` are the [B, C]
        averages and ``logprob`` is ``log q(samp)`` as [B, 1].
    """
    b = samp
    B = b.numel()
    C = logits.shape[-1]
    index = b.view(B, 1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    zs = []
    z_tildes = []
    for _ in range(k):
        # First conditional sample: noise for the given class comes from u.
        u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
        u_b = torch.gather(input=u, dim=1, index=index)
        z_tilde_b = -torch.log(-torch.log(u_b))

        z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        z.scatter_(dim=1, index=index, src=z_tilde_b)

        # Second conditional sample: fresh noise elsewhere, same u_b.
        u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    z = torch.mean(torch.stack(zs), dim=0)
    z_tilde = torch.mean(torch.stack(z_tildes), dim=0)

    return z, z_tilde, logprob
Exemplo n.º 13
0
def reinforce(x, logits, mixtureweights, k=1):
    """Plain score-function (REINFORCE) loss averaged over ``k`` samples.

    Each sample draws a cluster from ``softmax(logits)`` and contributes
    ``-mean((f - 1) * log q)`` with ``f = log p(x, z) - log q(z)`` detached.

    Args:
        x: data batch.
        logits: 2-D tensor; ``logits.shape[0]`` is taken as the batch size.
        mixtureweights: tensor indexed by the sampled cluster.
        k: number of samples; must be at least 1.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)`` where everything except
        ``net_loss`` comes from the last sample.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    net_loss = 0
    for _ in range(k):
        cluster = cat.sample()
        logq = cat.log_prob(cluster).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster)
        logpz = torch.log(mixtureweights[cluster]).view(B, 1)
        f = (logpx_given_z + logpz) - logq  # [B,1]

        # f is detached so only log q carries gradient.
        net_loss += -torch.mean((f.detach() - 1.) * logq)

    return net_loss / k, f, logpx_given_z, logpz, logq
Exemplo n.º 14
0
def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Return the squared RELAX-style gradient estimate and the probability of ``b``.

    The estimate is assembled from three gradients w.r.t. ``logits``: of the
    mean ``log q(b)``, of the mean surrogate on an unconditional relaxed
    sample ``z``, and of the mean surrogate on the sample returned by
    ``sample_relax_given_b(logits, b)``.

    Args:
        x: data batch concatenated (along dim 1) into the surrogate input.
        logits: 2-D tensor; its first two dimensions are read as (B, C).
            Must be usable as an autograd input.
        b: category indices whose log-probability is reshaped to [B, 1].
        surrogate: object exposing a callable ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain the prior weights.

    Returns:
        Tuple ``(surr_loss, pb)``: the element-wise squared estimate and
        ``exp(log q(b))``.
    """
    B, C = logits.shape[0], logits.shape[1]

    def surrogate_value(relaxed):
        # The surrogate sees the relaxed sample, the data and detached logits.
        return surrogate.net(torch.cat([relaxed, x, logits.detach()], dim=1))

    def grad_of_mean(values):
        # Gradient of the batch mean w.r.t. logits, kept differentiable.
        return torch.autograd.grad(
            [torch.mean(values)], [logits],
            create_graph=True, retain_graph=True)[0]

    cat = Categorical(logits=logits)
    uniform = myclamp(torch.rand(B, C).cuda())
    z = logits + (-torch.log(-torch.log(uniform)))
    logq = cat.log_prob(b).view(B, 1)

    cz = surrogate_value(z)
    cz_tilde = surrogate_value(sample_relax_given_b(logits, b))

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B, 1)
    f = (logpx_given_z + logpz) - logq  # [B,1]

    grad_logq = grad_of_mean(logq)
    grad_surr_z = grad_of_mean(cz)
    grad_surr_z_tilde = grad_of_mean(cz_tilde)

    estimate = ((f.detach() - cz_tilde) * grad_logq
                - grad_logq + grad_surr_z - grad_surr_z_tilde)
    return estimate ** 2, torch.exp(logq)
Exemplo n.º 15
0
 def _distribution(self, obs):
     """Map an observation to a categorical action distribution.

     The output of ``self.logits_net`` is used as the distribution's logits.
     """
     return Categorical(logits=self.logits_net(obs))
Exemplo n.º 16
0
 def compute_run(self, obs):
     """Return the action distribution and value estimate for ``obs``.

     The observation is encoded once; the actor output is log-softmaxed over
     dim 1 to form the categorical logits, and the critic output has its
     dim 1 squeezed away.
     """
     features = self.encode(obs)
     log_probs = F.log_softmax(self.actor(features), dim=1)
     action_dist = Categorical(logits=log_probs)
     state_value = self.critic(features).squeeze(1)
     return action_dist, state_value
Exemplo n.º 17
0
    def forward(self,
                audio_feature,
                decode_step,
                tf_rate=0.0,
                teacher=None,
                state_len=None):
        """Encode the input and run the optional CTC and attention decoders.

        Args:
            audio_feature: input batch; its first dimension is the batch
                size. Passed to ``self.encoder`` together with ``state_len``.
            decode_step: number of attention-decoder steps to unroll.
            tf_rate: per-step probability of feeding the ground-truth token
                (teacher forcing) rather than a sampled one. Only consulted
                when ``teacher`` is given.
            teacher: optional ground-truth tokens; embedded with
                ``self.embed`` and read as ``teacher[:, t + 1, :]``.
            state_len: forwarded to ``self.encoder``.

        Returns:
            ``(ctc_output, encode_len, att_output, att_maps)``.
            ``ctc_output`` is ``None`` unless ``self.joint_ctc``;
            ``att_output`` and ``att_maps`` are ``None`` unless
            ``self.joint_att``. ``att_output`` stacks the per-step character
            scores along dim 1; ``att_maps`` holds one tensor per attention
            head with that head's scores stacked along dim 1.
        """
        bs = audio_feature.shape[0]
        # Encode
        encode_feature, encode_len = self.encoder(audio_feature, state_len)

        ctc_output = None
        att_output = None
        att_maps = None

        # CTC based decoding
        if self.joint_ctc:
            ctc_output = self.ctc_layer(encode_feature)

        # Attention based decoding
        if self.joint_att:
            if teacher is not None:
                teacher = self.embed(teacher)

            # Init (init char = <SOS>, reset all rnn state and cell)
            self.decoder.init_rnn(encode_feature)
            self.attention.reset_enc_mem()
            last_char = self.embed(
                torch.zeros((bs), dtype=torch.long).to(
                    next(self.decoder.parameters()).device))
            output_char_seq = []
            # One independent list per head. `[[]] * n` would alias a single
            # list, mixing every head's scores into every attention map.
            output_att_seq = [[] for _ in range(self.attention.num_head)]

            # Decode
            for t in range(decode_step):
                # Attend (inputs current state of first layer, encoded features)
                attention_score, context = self.attention(
                    self.decoder.state_list[0], encode_feature, encode_len)
                # Spell (inputs context + embedded last character)
                decoder_input = torch.cat([last_char, context], dim=-1)
                dec_out = self.decoder(decoder_input)

                # To char
                cur_char = self.char_trans(dec_out)

                # Teacher forcing
                if (teacher is not None):
                    if random.random() <= tf_rate:
                        last_char = teacher[:, t + 1, :]
                    else:
                        # Otherwise sample the next input from the model's
                        # own output distribution.
                        sampled_char = Categorical(F.softmax(cur_char,
                                                             dim=-1)).sample()
                        last_char = self.embed(sampled_char)
                else:
                    # Without a teacher, decode greedily.
                    last_char = self.embed(torch.argmax(cur_char, dim=-1))

                output_char_seq.append(cur_char)
                for head, a in enumerate(attention_score):
                    output_att_seq[head].append(a.cpu())

            att_output = torch.stack(output_char_seq, dim=1)
            att_maps = [torch.stack(att, dim=1) for att in output_att_seq]

        return ctc_output, encode_len, att_output, att_maps
Exemplo n.º 18
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
        validate_args: forwarded to the ``Distribution`` base initializer
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        """Wrap a ``Categorical`` and take batch/event shapes from it."""
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # The event shape is the last dimension of the parameter tensor.
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return an instance whose wrapped Categorical is expanded to ``batch_shape``."""
        new = self._get_checked_instance(OneHotCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new._categorical = self._categorical.expand(batch_shape)
        # Validation is skipped in the base initializer; the flag is then
        # copied from this instance.
        super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped Categorical."""
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        """The wrapped Categorical's ``_param``."""
        return self._categorical._param

    @property
    def probs(self):
        """The wrapped Categorical's ``probs``."""
        return self._categorical.probs

    @property
    def logits(self):
        """The wrapped Categorical's ``logits``."""
        return self._categorical.logits

    @property
    def mean(self):
        """Per-class mean of the one-hot vector, i.e. the class probabilities."""
        return self._categorical.probs

    @property
    def variance(self):
        """Per-class variance ``p * (1 - p)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        """The wrapped Categorical's ``param_shape``."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw category indices and return them one-hot encoded, cast like ``probs``."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        num_events = self._categorical._num_events
        indices = self._categorical.sample(sample_shape)
        # `.to(probs)` matches the dtype/device of the probabilities.
        return torch.nn.functional.one_hot(indices, num_events).to(probs)

    def log_prob(self, value):
        """Return the Categorical log-probability of the argmax index of ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        # value.max(-1)[1] is the position of the largest entry on the last axis.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Return the entropy of the wrapped Categorical."""
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Return all ``n`` one-hot vectors, optionally expanded over the batch shape."""
        n = self.event_shape[0]
        values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
        # Shape (n, 1, ..., 1, n) so it broadcasts against the batch shape.
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
Exemplo n.º 19
0
    # Training loop: every step accumulates, one sample at a time, a
    # score-function surrogate (net_loss) and the negated objective (loss)
    # over `batch_size` draws.
    n_steps = 100000
    L2_losses = []
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            # The softmax is taken over dim 0 of the encoder output.
            cat = Categorical(probs=torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            pxz = logprob_undercomponent(
                x,
                component=cluster,
                needsoftmax_mixtureweight=needsoftmax_mixtureweight,
                cuda=False)
            f = pxz - logprob_cluster
            # f is detached so that only log q(cluster) carries gradient in
            # the encoder's surrogate loss.
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        # Average both accumulators over the batch.
        loss = loss / batch_size
        net_loss = net_loss / batch_size





# REINFORCE: estimate the gradient w.r.t. `logits` from N single-sample
# score-function estimates and report their per-coordinate mean and std.
print ('REINFORCE')

# def sample_reinforce_given_class(logits, samp):    
#     return logprob

grads = []
for i in range (N):

    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp) 
    # Gradient of log q(samp) w.r.t. logits; the graph is retained because
    # `logits` is differentiated again on the next iteration.
    gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0]
    grads.append(reward*gradlogprob)
    
print ()
# One row per sample, one column per class.
grads = torch.stack(grads).view(N,C)
grad_mean_reinforce = torch.mean(grads,dim=0)
grad_std_reinforce = torch.std(grads,dim=0)

print ('REINFORCE')
print ('mean:', grad_mean_reinforce)
print ('std:', grad_std_reinforce)
Exemplo n.º 21
0
    # Training loop: every step accumulates, one sample at a time, a
    # score-function surrogate (net_loss) and the negated objective (loss)
    # over `batch_size` draws, then backpropagates `loss`.
    n_steps = 100000
    L2_losses = []
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            # The softmax is taken over dim 0 of the encoder output.
            cat = Categorical(probs= torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
            f = pxz - logprob_cluster
            # f is detached so that only log q(cluster) carries gradient in
            # the encoder's surrogate loss.
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        # Average both accumulators over the batch.
        loss = loss / batch_size
        net_loss = net_loss / batch_size

        # The graph is retained after this backward pass.
        loss.backward(retain_graph=True)  
Exemplo n.º 22
0
 def get_policy(self, obs):
     """Return the action distribution produced by ``self.policy`` for ``obs``.

     With ``self.discrete_action`` set, the policy output is used as
     categorical logits; otherwise it is the mean of a Normal whose scale is
     ``exp(self.log_std)``.
     """
     policy_out = self.policy(obs)
     if not self.discrete_action:
         return Normal(policy_out, self.log_std.exp())
     return Categorical(logits=policy_out)
Exemplo n.º 23
0
def valor(args):
    if not hasattr(args, "get"):
        args.get = args.__dict__.get
    env_fn = args.get('env_fn', lambda: gym.make('HalfCheetah-v2'))
    actor_critic = args.get('actor_critic', ActorCritic)
    ac_kwargs = args.get('ac_kwargs', {})
    disc = args.get('disc', Discriminator)
    dc_kwargs = args.get('dc_kwargs', {})
    seed = args.get('seed', 0)
    episodes_per_epoch = args.get('episodes_per_epoch', 40)
    epochs = args.get('epochs', 50)
    gamma = args.get('gamma', 0.99)
    pi_lr = args.get('pi_lr', 3e-4)
    vf_lr = args.get('vf_lr', 1e-3)
    dc_lr = args.get('dc_lr', 2e-3)
    train_v_iters = args.get('train_v_iters', 80)
    train_dc_iters = args.get('train_dc_iters', 50)
    train_dc_interv = args.get('train_dc_interv', 2)
    lam = args.get('lam', 0.97)
    max_ep_len = args.get('max_ep_len', 1000)
    logger_kwargs = args.get('logger_kwargs', {})
    context_dim = args.get('context_dim', 4)
    max_context_dim = args.get('max_context_dim', 64)
    save_freq = args.get('save_freq', 10)
    k = args.get('k', 1)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    ac_kwargs['action_space'] = env.action_space

    # Model
    actor_critic = actor_critic(input_dim=obs_dim[0] + max_context_dim,
                                **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=max_context_dim, **dc_kwargs)

    # Buffer
    local_episodes_per_epoch = episodes_per_epoch  # int(episodes_per_epoch / num_procs())
    buffer = Buffer(max_context_dim, obs_dim[0], act_dim[0],
                    local_episodes_per_epoch, max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(
        count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' %
               var_counts)

    # Optimizers
    #Optimizer for RL Policy
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)

    #Optimizer for value function (for actor-critic)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)

    #Optimizer for decoder
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    #pdb.set_trace()

    # Parameters Sync
    #sync_all_params(actor_critic.parameters())
    #sync_all_params(disc.parameters())
    '''
    Training function
    '''
    def update(e):
        obs, act, adv, pos, ret, logp_old = [
            torch.Tensor(x) for x in buffer.retrieve_all()
        ]

        # Policy
        #pdb.set_trace()
        _, logp, _ = actor_critic.policy(obs, act, batch=False)
        #pdb.set_trace()
        entropy = (-logp).mean()

        # Policy loss
        pi_loss = -(logp * (k * adv + pos)).mean()

        # Train policy (Go through policy update)
        train_pi.zero_grad()
        pi_loss.backward()
        # average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = actor_critic.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            # average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                # average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old),
                     LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    #Resets observations, rewards, done boolean
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    #Creates context distribution where each logit is equal to one (This is first place to make change)
    context_dim_prob_dict = {
        i: 1 / context_dim if i < context_dim else 0
        for i in range(max_context_dim)
    }
    last_phi_dict = {i: 0 for i in range(context_dim)}
    context_dist = Categorical(
        probs=torch.Tensor(list(context_dim_prob_dict.values())))
    total_t = 0

    for epoch in range(epochs):
        #Sets actor critic and decoder (discriminator) into eval mode
        actor_critic.eval()
        disc.eval()

        #Runs the policy local_episodes_per_epoch before updating the policy
        for index in range(local_episodes_per_epoch):
            # Sample from context distribution and one-hot encode it (Step 2)
            # Every time we run the policy we sample a new context

            c = context_dist.sample()
            c_onehot = F.one_hot(c, max_context_dim).squeeze().float()
            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)),
                     c_onehot.reshape(1, -1)], 1)
                '''
                Feeds in observation and context into actor_critic which spits out a distribution 
                Label is a sample from the observation
                pi is the action sampled
                logp is the log probability of some other action a
                logp_pi is the log probability of pi 
                v_t is the value function
                '''
                a, _, logp_t, v_t = actor_critic(concat_obs)

                #Stores context and all other info about the state in the buffer
                buffer.store(c,
                             concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    # Key stuff with discriminator
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    #Context
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    #Feed in differences between each state in your trajectory and a specific context
                    #Here, this is just the log probability of the label it thinks it is
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        # Sets actor_critic and discriminator into training mode
        actor_critic.train()
        disc.train()

        update(epoch)
        #Need to implement curriculum learning here to update context distribution
        ''' 
            #Psuedocode:
            Loop through each of d episodes taken in local_episodes_per_epoch and check log probability from discrimantor
            If >= 0.86, increase k in the following manner: k = min(int(1.5*k + 1), Kmax)
            Kmax = 64
        '''

        decoder_accs = []
        stag_num = 10
        stag_pct = 0.05

        if (epoch + 1) % train_dc_interv == 0 and epoch > 0:
            #pdb.set_trace()
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("Context: ", con)
            print("num_contexts", len(con))
            _, logp_dc, _ = disc(s_diff, con)
            log_p_context_sample = logp_dc.mean().detach().numpy()

            print("Log Probability context sample", log_p_context_sample)

            decoder_accuracy = np.exp(log_p_context_sample)
            print("Decoder Accuracy", decoder_accuracy)

            logger.store(LogProbabilityContext=log_p_context_sample,
                         DecoderAccuracy=decoder_accuracy)
            '''
            Create score (phi(i)) = -log_p_context_sample.mean() for each specific context 
            Assign phis to each specific context
            Get p(i) in the following manner: (phi(i) + epsilon)
            Get Probabilities by doing p(i)/sum of all p(i)'s 
            '''
            logp_np = logp_dc.detach().numpy()
            con_np = con.detach().numpy()
            phi_dict = {i: 0 for i in range(context_dim)}
            count_dict = {i: 0 for i in range(context_dim)}
            for i in range(len(logp_np)):
                current_con = con_np[i]
                phi_dict[current_con] += logp_np[i]
                count_dict[current_con] += 1
            print(phi_dict)

            phi_dict = {
                k: last_phi_dict[k] if count_dict[k] == 0 else
                (-1) * v / count_dict[k]
                for (k, v) in phi_dict.items()
            }
            sorted_dict = dict(
                sorted(phi_dict.items(),
                       key=lambda item: item[1],
                       reverse=True))
            sorted_dict_keys = list(sorted_dict.keys())
            rank_dict = {
                sorted_dict_keys[i]: 1 / (i + 1)
                for i in range(len(sorted_dict_keys))
            }
            rank_dict_sum = sum(list(rank_dict.values()))
            context_dim_prob_dict = {
                k: rank_dict[k] / rank_dict_sum if k < context_dim else 0
                for k in context_dim_prob_dict.keys()
            }
            print(context_dim_prob_dict)

            decoder_accs.append(decoder_accuracy)
            stagnated = (len(decoder_accs) > stag_num
                         and (decoder_accs[-stag_num - 1] - decoder_accuracy) /
                         stag_num < stag_pct)
            if stagnated:
                new_context_dim = max(int(0.75 * context_dim), 5)
            elif decoder_accuracy >= 0.86:
                new_context_dim = min(int(1.5 * context_dim + 1),
                                      max_context_dim)
            if stagnated or decoder_accuracy >= 0.86:
                print("new_context_dim: ", new_context_dim)
                new_context_prob_arr = np.array(
                    new_context_dim * [1 / new_context_dim] +
                    (max_context_dim - new_context_dim) * [0])
                context_dist = Categorical(
                    probs=ptu.from_numpy(new_context_prob_arr))
                context_dim = new_context_dim

            for i in range(context_dim):
                if i in phi_dict:
                    last_phi_dict[i] = phi_dict[i]
                elif i not in last_phi_dict:
                    last_phi_dict[i] = max(phi_dict.values())

            buffer.clear_dc_buff()
        else:
            logger.store(LogProbabilityContext=0, DecoderAccuracy=0)

        # Log
        logger.store(ContextDim=context_dim)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.log_tabular('LogProbabilityContext', average_only=True)
        logger.log_tabular('DecoderAccuracy', average_only=True)
        logger.log_tabular('ContextDim', average_only=True)
        logger.dump_tabular()
def sample_reinforce_given_class(logits, samp):
    """Return the log-probability of ``samp`` under a categorical distribution.

    Args:
        logits: Unnormalized log-probabilities parameterizing the distribution.
        samp: Class index (or tensor of indices) to score.

    Returns:
        ``Categorical(logits=logits).log_prob(samp)``.
    """
    return Categorical(logits=logits).log_prob(samp)
# mylogprobgrad = torch.autograd.grad(outputs=mylogprob, inputs=(probs), retain_graph=True)[0]

# Dump the problem setup. `rewards`, `probs` and `logits` are expected to be
# defined before this point (not visible in this part of the file).
print('rewards', rewards)
print('probs', probs)
print('logits', logits)

#REINFORCE
print('REINFORCE')

# def sample_reinforce_given_class(logits, samp):
#     return logprob

# Collect N single-sample score-function estimates of the gradient w.r.t.
# `logits`: each one is f(samp) * d log p(samp) / d logits.
grads = []
for i in range(N):

    # A fresh distribution is built every iteration from the same `logits`.
    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    # `f` is the reward function evaluated at the sampled class.
    reward = f(samp)
    # Gradient of the sample's log-probability with respect to the logits.
    gradlogprob = torch.autograd.grad(outputs=logprob,
                                      inputs=(logits),
                                      retain_graph=True)[0]
    grads.append(reward * gradlogprob)

print()
# One row per sample; the view requires N * C elements in total.
grads = torch.stack(grads).view(N, C)
# print (grads.shape)
# Per-logit mean and standard deviation of the estimator across the N samples.
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)

print('REINFORCE')

    #REINFORCE


    # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
    # dist_bernoulli = Bernoulli(bern_param)
    C= 2
    n_components = C
    B=1
    probs = torch.ones(B,C)
    bern_param = bern_param.view(B,1)
    aa = 1 - bern_param
    probs = torch.cat([aa, bern_param], dim=1)

    cat = Categorical(probs= probs)

    grads = []
    for i in range(n):
        b = cat.sample()
        logprob = cat.log_prob(b.detach())
        # b_ = torch.argmax(z, dim=1)

        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        grad = f(b) * logprobgrad

        grads.append(grad[0][0].data.numpy())

    print ('Grad Estimator: Reinfoce categorical')
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
Exemplo n.º 27
0
    def decode_where(self, input_coord_logits, input_attri_logits,
                     input_patch_vectors, sample_mode):
        """Pick a grid location, then a scale and an aspect ratio for it.

        Args (shapes as documented by the original author):
            input_coord_logits: (bsize, 1, grid_dim) location scores.
            input_attri_logits: (bsize, 1, scale_ratio_dim, grid_dim)
                attribute scores.
            input_patch_vectors: (bsize, 1, patch_feature_dim, grid_dim)
                patch features.
            sample_mode: 0 takes the top-1 entry; any other value draws from
                a ``Categorical`` that receives the scores positionally
                (i.e. as its first argument).

        Returns:
            Tuple ``(sample_inds, sample_patch_vectors)``: indices of shape
            (bsize, 3) ordered location / scale / ratio, and the patch
            vectors gathered at the chosen location,
            (bsize, patch_feature_dim).
        """
        take_top1 = sample_mode == 0

        def choose(scores):
            # Selected index along the last axis, kept as a trailing dim.
            if take_top1:
                return torch.max(scores + 1.0, dim=-1, keepdim=True)[1]
            return Categorical(scores).sample().unsqueeze(-1)

        def at_location(field, where):
            # Read the 3-D ``field`` at the chosen grid cell for every channel.
            n_rows, n_chan, _ = field.size()
            lookup = where.expand(n_rows, n_chan).unsqueeze(-1)
            return torch.gather(field, -1, lookup).squeeze(-1)

        # 1) Location.
        sample_coord_inds = choose(input_coord_logits.squeeze(1))

        # 2) Patch features and attribute scores at that location.
        sample_patch_vectors = at_location(input_patch_vectors.squeeze(1),
                                           sample_coord_inds)
        local_logits = at_location(input_attri_logits.squeeze(1),
                                   sample_coord_inds)

        # 3) The first ``num_scales`` attribute channels score the scale,
        #    the remaining ones score the ratio. Scale is drawn before ratio.
        n_scales = self.cfg.num_scales
        sample_scale_inds = choose(local_logits[:, :n_scales])
        sample_ratio_inds = choose(local_logits[:, n_scales:])

        sample_inds = torch.cat(
            [sample_coord_inds, sample_scale_inds, sample_ratio_inds], -1)
        return sample_inds, sample_patch_vectors
Exemplo n.º 28
0
    def forward(self, hidden_state=None):
        """
        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
        discrete sampling.
        Hidden state here represents the encoded image/metadata - initializes the RNN from it.

        The message is generated one token per step for ``self.output_len`` steps; which
        token-selection path runs depends on ``self.rl``, ``self.vqvae``,
        ``self.discrete_communication`` and ``self.gumbel_softmax``.

        Returns a 7-tuple:
            messages: all entries of ``output`` (the initial entry plus one token per
                step) stacked along dim 1.
            seq_lengths: int64 tensor of length ``batch_size``, initialised to
                ``output_len + 1`` and passed to ``_calculate_seq_len`` every step.
            entropy: (batch_size, output_len) buffer, written only in the RL path.
            embeds: the per-step input embeddings stacked along dim 1.
            sentence_probability: (batch_size, vocab_size) tensor, updated only in the
                non-RL, non-VQ-VAE path.
            loss_2_3_out: mean over steps of the per-step VQ-VAE loss buffer.
            message_logits: (batch_size, output_len) buffer, written only in the RL path.
        """

        # hidden_state = self.input_module(hidden_state)
        state, batch_size = self._init_state(hidden_state, type(self.rnn))

        # Init output
        # Both branches start from an all-zero (batch_size, vocab_size) row; they differ
        # only in whether the sos position is set to 1.
        if not (self.vqvae and not self.discrete_communication
                and not self.rl):
            output = [
                torch.zeros(
                    (batch_size, self.vocab_size),
                    dtype=torch.float32,
                    device=self.device,
                )
            ]
            output[0][:, self.sos_id] = 1.0
        else:
            # In vqvae case, there is no sos symbol, since all words come from the unordered embedding table.
            # It is not possible to index code words by sos or eos symbols, since the number of codewords
            # is not necessarily the vocab size!
            output = [
                torch.zeros(
                    (batch_size, self.vocab_size),
                    dtype=torch.float32,
                    device=self.device,
                )
            ]

        # Keep track of sequence lengths
        initial_length = self.output_len + 1  # add the sos token
        seq_lengths = (
            torch.ones([batch_size], dtype=torch.int64, device=self.device) *
            initial_length
        )  # [initial_length, initial_length, ..., initial_length]. This gets reduced whenever it ends somewhere.

        embeds = []  # keep track of the embedded sequence
        sentence_probability = torch.zeros((batch_size, self.vocab_size),
                                           device=self.device)
        # NOTE: the three buffers below are created with torch.empty, i.e. uninitialised.
        # losses_2_3 is only filled when self.vqvae is set; entropy and message_logits
        # are only filled when self.rl is set. In the other modes the returned values
        # derived from them are whatever happened to be in memory.
        losses_2_3 = torch.empty(self.output_len, device=self.device)
        entropy = torch.empty((batch_size, self.output_len),
                              device=self.device)
        message_logits = torch.empty((batch_size, self.output_len),
                                     device=self.device)

        if self.vqvae:
            distance_computer = EmbeddingtableDistances(self.e)

        for i in range(self.output_len):

            # Embed the previously emitted token (a row vector over the vocabulary).
            emb = torch.matmul(output[-1], self.embedding)

            embeds.append(emb)

            state = self.rnn.forward(emb, state)

            # An LSTMCell state is a pair; only its first element is used as h.
            if type(self.rnn) is nn.LSTMCell:
                h, _ = state
            else:
                h = state

            # Per-sample code-book indices; filled in by the VQ-VAE paths below.
            indices = [None] * batch_size

            if not self.rl:
                if not self.vqvae:
                    # That's the original baseline setting
                    p = F.softmax(self.linear_out(h), dim=1)
                    token, sentence_probability = self.calculate_token_gumbel_softmax(
                        p, self.tau, sentence_probability, batch_size)
                else:
                    pre_quant = self.linear_out(h)

                    if not self.discrete_communication:
                        # presumably fills `indices` in place like hard_max below - verify
                        token = self.vq.apply(pre_quant, self.e, indices)
                    else:
                        distances = distance_computer(pre_quant)
                        softmin = F.softmax(-distances, dim=1)
                        if not self.gumbel_softmax:
                            token = self.hard_max.apply(
                                softmin, indices, self.discrete_latent_number
                            )  # This also updates the indices
                        else:
                            # Fill `indices` in place with the per-row argmax of softmin.
                            _, indices[:] = torch.max(softmin, dim=1)
                            # The second return value is discarded in this path.
                            token, _ = self.calculate_token_gumbel_softmax(
                                softmin, self.tau, 0, batch_size)

            else:
                # RL path: build temperature-scaled log-probabilities over the vocabulary.
                if not self.vqvae:
                    all_logits = F.log_softmax(self.linear_out(h) / self.tau,
                                               dim=1)
                else:
                    pre_quant = self.linear_out(h)
                    distances = distance_computer(pre_quant)
                    all_logits = F.log_softmax(-distances / self.tau, dim=1)
                    # Fill `indices` in place with the per-row argmax.
                    _, indices[:] = torch.max(all_logits, dim=1)

                distr = Categorical(logits=all_logits)
                entropy[:, i] = distr.entropy()

                # Sample while training, take the argmax otherwise.
                if self.training:
                    token_index = distr.sample()
                    token = to_one_hot(token_index, n_dims=self.vocab_size)
                else:
                    token_index = all_logits.argmax(dim=1)
                    token = to_one_hot(token_index, n_dims=self.vocab_size)
                # Log-probability of the chosen token under the step's distribution.
                message_logits[:, i] = distr.log_prob(token_index)

            if not (self.vqvae and not self.discrete_communication
                    and not self.rl):
                # Whenever we have a meaningful eos symbol, we prune the messages in the end
                self._calculate_seq_len(seq_lengths,
                                        token,
                                        initial_length,
                                        seq_pos=i + 1)

            if self.vqvae:
                # loss_2 detaches pre_quant, so its gradient reaches only the code book
                # self.e; loss_3 detaches the code book, so its gradient reaches only
                # pre_quant.
                loss_2 = torch.mean(
                    torch.norm(pre_quant.detach() - self.e[indices], dim=1)**2)
                loss_3 = torch.mean(
                    torch.norm(pre_quant - self.e[indices].detach(), dim=1)**2)
                loss_2_3 = (
                    loss_2 + self.beta * loss_3
                )  # This corresponds to the second and third loss term in VQ-VAE
                losses_2_3[i] = loss_2_3

            token = token.to(self.device)
            output.append(token)

        messages = torch.stack(output, dim=1)
        # See the NOTE above: meaningful only when self.vqvae is set.
        loss_2_3_out = torch.mean(losses_2_3)

        return (
            messages,
            seq_lengths,
            entropy,
            torch.stack(embeds, dim=1),
            sentence_probability,
            loss_2_3_out,
            message_logits,
        )
def evaluate_actions_sil(pi, actions):
    """Score ``actions`` under the categorical policy ``pi``.

    Args:
        pi: Passed positionally to ``Categorical`` (its first argument).
        actions: Action indices; the trailing dimension is squeezed away
            before scoring.

    Returns:
        Tuple ``(log_probs, entropy)``, each with a trailing singleton
        dimension appended.
    """
    policy = Categorical(pi)
    log_probs = policy.log_prob(actions.squeeze(-1))
    entropies = policy.entropy()
    return log_probs.unsqueeze(-1), entropies.unsqueeze(-1)
def select_actions(pi, deterministic=False):
    """Choose an action from the categorical policy ``pi``.

    Args:
        pi: Passed positionally to ``Categorical`` (its first argument).
        deterministic: When true, take the arg-max along dim 1 instead of
            sampling.

    Returns:
        A Python scalar (via ``.item()``, so the arg-max must contain a
        single element) when ``deterministic`` is true; otherwise the sampled
        index tensor with a trailing singleton dimension appended.
    """
    # The distribution is built in both modes, so ``pi`` is always validated.
    policy = Categorical(pi)
    if not deterministic:
        return policy.sample().unsqueeze(-1)
    return torch.argmax(pi, dim=1).item()
Exemplo n.º 31
0
    def forward(self, obs, memory, instr_embedding=None):
        """Run the actor-critic on one batch of observations.

        Args:
            obs: Object exposing ``image`` and, when descriptions are used,
                ``instr``.
            memory: Recurrent memory; with ``use_memory`` it is split at
                ``semi_memory_size`` into the two RNN state halves.
            instr_embedding: Optional precomputed instruction embedding; it
                is computed from ``obs.instr`` when ``use_desc`` is set and
                none is given.

        Returns:
            Dict with keys ``'dist'`` (``Categorical`` over actions),
            ``'value'``, ``'memory'`` and ``'extra_predictions'``.
        """
        # Embed the instruction only when the caller did not provide one.
        if self.use_desc and instr_embedding is None:
            wants_pair = self.enable_instr and self.arch == "fusion"
            embedded = self._get_instr_embedding(obs.instr)
            if wants_pair:
                instr_embedding, instr_embedding2 = embedded
            else:
                instr_embedding = embedded

        if self.use_desc and self.lang_model == "attgru":
            # Attention over instruction tokens, keyed by the memory.
            # outputs: B x L x D
            # memory: B x M
            token_mask = (obs.instr != 0).float()
            instr_embedding = instr_embedding[:, :token_mask.shape[1]]
            query = self.memory2key(memory)
            scores = (query[:, None, :] *
                      instr_embedding).sum(2) + 1000 * token_mask
            weights = F.softmax(scores, dim=1)
            instr_embedding = (instr_embedding * weights[:, :, None]).sum(1)

        # Swap image axes 1<->3 and then 2<->3 before the convolutions.
        x = obs.image.transpose(1, 3).transpose(2, 3)

        arch = self.arch
        if arch.startswith("expert_filmcnn"):
            x = self.image_conv(x)
            for film_layer in self.controllers:
                x = film_layer(x, instr_embedding)
            x = F.relu(self.film_pool(x))
        elif arch == "fusion":
            # Separate CNNs feed the image features and the attention weights.
            x_feat = self.image_conv(x)
            w = self.w_conv(x)
            N, _, W, H = w.shape
            w = F.softmax(w.view([N, self.instr_sents + 1, -1]), dim=1)
            # The last attention slot is dropped before mixing sentences.
            y = torch.matmul(instr_embedding, w[:, :-1]).view([N, 128, W, H])

            x = self.combined_conv(torch.cat([x_feat, y], dim=1))
            x = x.view(x.shape[0], x.shape[1], 1, 1)
            if self.enable_instr:
                for film_layer in self.controllers:
                    x = film_layer(x, instr_embedding2)
                x = F.relu(x)
        else:
            x = self.image_conv(x)

        x = x.reshape(x.shape[0], -1)

        if not self.use_memory:
            embedding = x
        else:
            split = self.semi_memory_size
            hidden = self.memory_rnn(x, (memory[:, :split], memory[:, split:]))
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)

        # FiLM and fusion architectures already consumed the instruction.
        if self.use_desc and "filmcnn" not in arch and "fusion" not in arch:
            embedding = torch.cat((embedding, instr_embedding), dim=1)

        extra_predictions = {}
        if getattr(self, 'aux_info', False):
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }

        dist = Categorical(logits=F.log_softmax(self.actor(embedding), dim=1))
        value = self.critic(embedding).squeeze(1)

        return {
            'dist': dist,
            'value': value,
            'memory': memory,
            'extra_predictions': extra_predictions
        }
Exemplo n.º 32
0
    def update_parameters(self, exps):
        """Run the clipped-surrogate (PPO-style) update over ``exps``.

        For ``self.epochs`` passes, every batch returned by
        ``self._get_batches_starting_indexes()`` is unrolled for
        ``self.recurrence`` steps. The per-step losses are averaged, and one
        gradient-clipped optimizer step is taken per batch.

        Args:
            exps: Collected experiences. ``exps[idx]`` must return a
                sub-batch exposing ``obs``, ``action``, ``log_prob``,
                ``advantage``, ``value``, ``returnn`` (plus ``mask`` and, for
                variable view, ``gaze``); ``exps.memory`` is read and written
                for recurrent models.

        Returns:
            Dict of means over all batches: ``entropy``, ``value``,
            ``policy_loss``, ``value_loss`` and ``grad_norm``.
        """
        # Collect experiences

        for _ in range(self.epochs):
            # Initialize log values

            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []

            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                for i in range(self.recurrence):
                    # Create a sub-batch of experience
                    sb = exps[inds + i]

                    # Compute loss

                    if self.acmodel.recurrent:
                        if self.variable_view:
                            dist, gaze, value, memory = self.acmodel(
                                sb.obs, memory * sb.mask)
                            # Combine action and gaze distributions into single 1D distribution
                            # gaze_dist = gaze[0].probs.ger(gaze[1].probs)
                            # dist = Categorical(dist.probs.ger(gaze_dist.view(-1)))

                            # Per sample: outer product of the two gaze
                            # distributions, then outer product with the
                            # action distribution, flattened to one vector.
                            full_dist = []
                            for j in range(inds.shape[0]):
                                gaze_dist = gaze[0].probs[j].ger(
                                    gaze[1].probs[j]).view([-1])
                                full_action_space_dist = Categorical(
                                    (dist.probs[j].ger(gaze_dist)).view(-1))
                                full_dist.append(full_action_space_dist.probs)
                            dist = Categorical(torch.stack(full_dist))
                            # Flatten (action, gaze0, gaze1) into one index.
                            # NOTE(review): 22 and 11 presumably encode the
                            # gaze grid sizes (11 x 2?) - confirm.
                            sb.action = sb.action.long(
                            ) * 22 + sb.gaze[:, 0] * 11 + sb.gaze[:, 1]

                        else:
                            dist, value, memory = self.acmodel(
                                sb.obs, memory * sb.mask)
                    else:
                        dist, value = self.acmodel(sb.obs)

                    entropy = dist.entropy().mean()
                    # Probability ratio between current and behaviour policy.
                    ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                    surr1 = ratio * sb.advantage
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                        1.0 + self.clip_eps) * sb.advantage
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Clipped value loss: the value may move at most clip_eps
                    # away from the stored estimate.
                    value_clipped = sb.value + torch.clamp(
                        value - sb.value, -self.clip_eps, self.clip_eps)
                    surr1 = (value - sb.returnn).pow(2)
                    surr2 = (value_clipped - sb.returnn).pow(2)
                    value_loss = torch.max(surr1, surr2).mean()

                    loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

                    # Update batch values

                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Update memories for next epoch

                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                # Update batch values

                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic

                self.optimizer.zero_grad()
                batch_loss.backward()
                # Parameters that did not take part in the loss (for example
                # unused heads) have ``grad is None``; skip them instead of
                # crashing on ``None.data``.
                grad_norm = sum(
                    p.grad.data.norm(2).item()**2
                    for p in self.acmodel.parameters()
                    if p.grad is not None)**0.5
                torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()

                # Update log values

                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm)

        # Log some values

        logs = {
            "entropy": numpy.mean(log_entropies),
            "value": numpy.mean(log_values),
            "policy_loss": numpy.mean(log_policy_losses),
            "value_loss": numpy.mean(log_value_losses),
            "grad_norm": numpy.mean(log_grad_norms)
        }

        return logs
Exemplo n.º 33
0
    def forward(self, comm, obs, memory, instr_embedding=None):
        """Run the student policy on a message and, optionally, an observation.

        Args:
            comm: Message tensor; multiplied with ``self.comm_embed.weight``
                to embed it, then encoded by ``self.comm_encoder_rnn``.
            obs: Observation exposing ``image`` (and ``instr`` when
                instructions are used). Only read in ``"vision"`` mode.
            memory: Recurrent memory; with ``use_memory`` it is split at
                ``semi_memory_size`` into the two RNN state halves.
            instr_embedding: Optional precomputed instruction embedding.

        Returns:
            Dict with keys ``'dist'`` (``Categorical`` over actions),
            ``'value'``, ``'memory_student'`` and ``'extra_predictions'``.

        Raises:
            ValueError: If ``self.student_obs_type`` is neither ``"vision"``
                nor ``"blind"``.
        """

        message_embedding = torch.matmul(comm, self.comm_embed.weight)

        # The last entry of the encoder's final hidden state summarises the
        # message.
        _, hidden = self.comm_encoder_rnn(message_embedding)
        message_encoded = hidden[-1]

        if self.student_obs_type == "vision":
            # Calculating instruction embedding
            if self.use_instr and instr_embedding is None:
                if self.lang_model == 'gru':
                    _, hidden = self.instr_rnn(self.word_embedding(obs.instr))
                    instr_embedding = hidden[-1]

            # Calculating the image embedding
            x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
            if self.arch.startswith("expert_filmcnn"):
                image_embedding = self.image_conv(x)

                # Calculating FiLM_embedding from image and instruction embedding
                # NOTE(review): every controller receives ``image_embedding``
                # rather than the previous controller's output, so only the
                # last controller affects the result. Left unchanged because
                # chaining them would alter the outputs of trained models -
                # confirm the intent.
                for controler in self.controllers:
                    x = controler(image_embedding, instr_embedding)
                FiLM_embedding = F.relu(self.film_pool(x))
            else:
                FiLM_embedding = self.image_conv(x)

            FiLM_embedding = FiLM_embedding.reshape(FiLM_embedding.shape[0],
                                                    -1)

            # Going through the memory layer
            if self.use_memory:
                hidden = (memory[:, :self.semi_memory_size],
                          memory[:, self.semi_memory_size:])
                hidden = self.memory_rnn(FiLM_embedding, hidden)
                embedding = hidden[0]
                memory = torch.cat(hidden, dim=1)
            else:
                # Use the flattened features. ``x`` is still the 4-D
                # convolution input/output here and cannot be concatenated
                # with the 2-D message encoding below.
                embedding = FiLM_embedding

            if self.use_instr and "filmcnn" not in self.arch:
                embedding = torch.cat((embedding, instr_embedding), dim=1)

            if hasattr(self, 'aux_info') and self.aux_info:
                extra_predictions = {
                    info: self.extra_heads[info](embedding)
                    for info in self.extra_heads
                }
            else:
                extra_predictions = dict()

        elif self.student_obs_type == "blind":
            # A blind student sees no image: feed zeros, created on the same
            # device as the message so the concatenation below is valid
            # whether or not CUDA happens to be installed.
            embedding = torch.zeros(message_embedding.shape[0],
                                    self.semi_memory_size,
                                    device=message_embedding.device)
            extra_predictions = {}
        else:
            raise ValueError(
                "Student observation type must be either vision or blind")

        #Policy part
        policy_input = torch.cat((embedding, message_encoded), dim=1)
        policy_input = self.dropout(policy_input)
        x = self.actor(policy_input)
        dist = Categorical(logits=F.log_softmax(x, dim=1))

        x = self.critic(policy_input)
        value = x.squeeze(1)

        return {
            'dist': dist,
            'value': value,
            'memory_student': memory,
            'extra_predictions': extra_predictions
        }
Exemplo n.º 34
0
 def __init__(self, probs=None, logits=None, validate_args=None):
     """Wrap a ``Categorical`` and take the distribution's shapes from it.

     The batch shape is the wrapped categorical's batch shape; the event
     shape is the last dimension of its parameter shape.
     """
     base = Categorical(probs, logits)
     self._categorical = base
     super(OneHotCategorical, self).__init__(
         base.batch_shape,
         base.param_shape[-1:],
         validate_args=validate_args,
     )
Exemplo n.º 35
0
 def sample_action(self, state):
     """Sample an action from the categorical policy for ``state``.

     Calling the module on ``state`` yields the action logits; one index is
     drawn per batch entry.
     """
     return Categorical(logits=self(state)).sample()
Exemplo n.º 36
0
 def get_action_and_value(self, x, action=None):
     """Return an action together with its statistics and the state value.

     Args:
         x: Input passed to both ``self.actor`` and ``self.critic``.
         action: Action to score; when ``None`` one is sampled from the
             policy.

     Returns:
         Tuple ``(action, log_prob, entropy, value)``.
     """
     policy = Categorical(logits=self.actor(x))
     chosen = policy.sample() if action is None else action
     return chosen, policy.log_prob(chosen), policy.entropy(), self.critic(x)
Exemplo n.º 37
0
 def __init__(self, temperature, probs=None, logits=None, validate_args=None):
     """Store the temperature and wrap a ``Categorical`` for the parameters.

     The batch shape is the wrapped categorical's batch shape; the event
     shape is the last dimension of its parameter shape.
     """
     base = Categorical(probs, logits)
     self._categorical = base
     self.temperature = temperature
     super(ExpRelaxedCategorical, self).__init__(
         base.batch_shape,
         base.param_shape[-1:],
         validate_args=validate_args,
     )
Exemplo n.º 38
0
def trainAttention(model=NaiveAttention,dimbed=50,dimout=2,maxiter=epochs,epsilon=0.01,reg=None,verbose=True):
    """Train (or resume training) an attention classifier with checkpointing.

    Args:
        model: Model class; one of ``NaiveAttention``, ``SimpleAttention``,
            ``FurtherAttention``, ``LSTMAttention``, ``BILSTMAttention``.
        dimbed: First constructor argument of the model.
        dimout: Second constructor argument of the model.
        maxiter: Epoch index at which training stops.
        epsilon: SGD learning rate.
        reg: Optional weight of the entropy penalty on the attention
            weights; falsy disables it. It has no effect for
            ``NaiveAttention``, which returns no attention weights.
        verbose: Print loss and accuracy after every epoch.

    Raises:
        ValueError: If ``model`` is not one of the supported classes.

    Side effects: writes TensorBoard logs under ``runs`` and saves the
    training state to a per-model checkpoint file after every epoch.
    """
    # The `reg` parameter selects whether to regularize with an entropy criterion.
    if model == NaiveAttention:
        print("\n//////////////////// Attention-based LinNet : naive baseline /////////////////\n")
        name = "attentionbase.pch"
        etiq = "/Base_SGD"
    elif model == SimpleAttention:
        print("\n///////////////// Attention-based LinNet : basic implementation //////////////\n")
        name = "attentionclassic.pch"
        etiq = "/Classic_SGD"
    elif model == FurtherAttention:
        print("\n///////////////// Attention-based LinNet : further improvements //////////////\n")
        name = "attentionfurther.pch"
        etiq = "/Further_SGD_regul"
    elif model == LSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an LSTM /////////////////\n")
        name = "attentionlstm.pch"
        etiq = "/LSTM_SGD"
    elif model == BILSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an BiLSTM /////////////////\n")
        name = "attentionbilstm.pch"
        etiq = "/BiLSTM_SGD"
    else:
        # Fail early with a clear message instead of an UnboundLocalError on
        # `name` further down.
        raise ValueError("Unsupported attention model: %r" % (model,))

    # Created only once the model is known to be valid, so a bad argument
    # does not leave an open writer behind.
    writer = SummaryWriter("runs")

    # Creating a checkpointed model
    savepath = Path(name)
    if savepath.is_file():
        print("Restarting from previous state.")
        with savepath.open("rb") as fp :
            # NOTE: torch.load unpickles arbitrary objects; only resume from
            # checkpoint files that this script wrote itself.
            state = torch.load(fp)
    else:
        lin = model(dimbed,dimout).to(device)
        optim = torch.optim.SGD(params=lin.parameters(),lr=epsilon)
        state = State(lin,optim)

    loss = nn.CrossEntropyLoss()

    # Training the model
    for epoch in tqdm(range(state.epoch,maxiter)):
        state.model = state.model.train()
        losstrain = 0
        accytrain = 0
        divtrain = 0
        for x, y in train_loader:
            state.optim.zero_grad()
            y = y.to(device)
            if model == NaiveAttention:
                # No attention weights, hence nothing to regularize.
                preds = state.model(x)
                penalty = 0
            else:
                preds, attns = state.model(x)
                # Entropy of each sample's attention distribution.
                entropytrain = Categorical(probs = attns.squeeze(2).t()).entropy()
                penalty = reg * torch.sum(entropytrain) if reg else 0
            ltrain = loss(preds,y.long()) + penalty

            ltrain.backward()
            state.optim.step()
            state.iteration += 1
            acctr = (preds.argmax(1) == y).sum().item() / y.shape[0]
            # Detach so the running total does not keep every batch's
            # autograd graph alive for the whole epoch.
            losstrain += ltrain.detach()
            accytrain += acctr
            divtrain += 1

        state.model = state.model.eval()
        losstest = 0
        accytest = 0
        divtest = 0
        for x, y in test_loader:
            with torch.no_grad():
                y = y.to(device)
                if model == NaiveAttention :
                    preds = state.model(x)
                else :
                    preds, attns = state.model(x)
                ltest = loss(preds,y.long())
                accts = (preds.argmax(1) == y).sum().item() / y.shape[0]
            losstest += ltest
            accytest += accts
            divtest += 1

        # Saving the loss
        writer.add_scalars('Attention/Loss'+etiq,{'train':losstrain/divtrain,'test':losstest/divtest},epoch)
        writer.add_scalars('Attention/Accuracy'+etiq,{'train':accytrain/divtrain,'test':accytest/divtest},epoch)

        if model != NaiveAttention :
            # Histograms use the last test batch and the last training batch.
            entropytest = Categorical(probs = attns.squeeze(2).t()).entropy()
            writer.add_histogram('Attention/EntropyTest'+etiq,entropytest,epoch)
            writer.add_histogram('Attention/EntropyTrain'+etiq,entropytrain,epoch)

        if verbose:
            print('\nLOSS: \t\ttrain',(losstrain/divtrain).item(),'\t\ttest',(losstest/divtest).item())
            print('ACCURACY: \ttrain',accytrain/divtrain,'\t\ttest',accytest/divtest)

        # Saving the current state after each epoch
        with savepath.open ("wb") as fp:
            state.epoch = epoch+1
            torch.save(state, fp)

    print("\n\n\033[1mDone.\033[0m\n")
    writer.flush()
    writer.close()

# ////////////////////////////////////////////////////////////////////////////////////////////////// </training loop> ////
Exemplo n.º 39
0
 def __init__(self, probs=None, logits=None, validate_args=None):
     """Wrap a ``Categorical`` and take the distribution's shapes from it.

     The batch shape is the wrapped categorical's batch shape; the event
     shape is the last dimension of its parameter shape.
     """
     base = Categorical(probs, logits)
     self._categorical = base
     super(OneHotCategorical, self).__init__(
         base.batch_shape,
         base.param_shape[-1:],
         validate_args=validate_args,
     )
Exemplo n.º 40
0
def train(meta_decoder, decoder_optimizer, fclayers_for_hyper_params):
    """Run one policy-gradient update of ``meta_decoder``.

    The decoder is unrolled for three steps; the third step's softmax output
    is sampled to choose the interaction type.  Depending on that type the
    decoder is unrolled further (step 3 for type 0, nothing for type 1,
    steps 4-6 otherwise).  One index is then sampled from every stored
    softmax output, position 2 is overwritten with the interaction type that
    was already sampled, and the indices joined by ``_`` are scored with
    ``calc_reward_given_descriptor``.

    The reward is centred with the module-level ``moving_average`` (on the
    very first call, flagged by the sentinel ``-19013``, the centred reward
    is ``0.0``).  The loss is the negated sum of ``log p(index) * reward``
    over all stored steps; it is backpropagated and ``decoder_optimizer``
    takes one step.

    Reads the module-level ``hyper_params``, ``device`` and
    ``moving_average_alpha``; updates the module-level ``moving_average``.
    """
    global moving_average, moving_average_alpha

    hidden_state = meta_decoder.initHidden()
    decoder_optimizer.zero_grad()
    output = torch.zeros([1, 1, meta_decoder.output_size], device=device)
    to_probs = nn.Softmax(dim=1)
    step_distributions = []

    def unroll(step_index):
        # Advance the decoder by one step and store the softmax of the head
        # that belongs to hyper-parameter ``step_index``.
        nonlocal output, hidden_state
        output, hidden_state = meta_decoder(output, hidden_state)
        step_distributions.append(
            to_probs(fclayers_for_hyper_params[hyper_params[step_index][0]](output)))

    for step_index in range(3):
        unroll(step_index)

    # The third decision selects the interaction type, which determines
    # which further hyper-parameters have to be decoded.
    type_of_interaction = Categorical(step_distributions[-1]).sample().tolist()[0]
    if type_of_interaction == 0:
        # PairwiseEuDist
        follow_up_steps = range(3, 4)
    elif type_of_interaction == 1:
        # PairwiseLog: no hyper-params for this interaction type
        follow_up_steps = range(0)
    else:
        # PointwiseMLPCE
        follow_up_steps = range(4, 7)
    for step_index in follow_up_steps:
        unroll(step_index)

    sampled_indices = []
    for step_output in step_distributions:
        print("softmax_outputs: ", step_output)
        sampled_indices.append(Categorical(step_output).sample().tolist()[0])
    # the type of interaction has already been sampled before
    sampled_indices[2] = type_of_interaction
    descriptor = "_".join(map(str, sampled_indices))
    print("resulted_str: " + descriptor)

    raw_reward = calc_reward_given_descriptor(descriptor)
    if moving_average == -19013:
        # First call: initialise the baseline and use a zero centred reward.
        moving_average = raw_reward
        reward = 0.0
    else:
        # Centre with the old baseline, then update the baseline.
        reward = raw_reward - moving_average
        moving_average = moving_average_alpha * raw_reward + (
            1.0 - moving_average_alpha) * moving_average

    print("current reward: " + str(reward))
    print("current moving average: " + str(moving_average))

    expected_reward = 0
    for step_output, chosen in zip(step_distributions, sampled_indices):
        expected_reward += torch.log(step_output[0][chosen]) * reward
    loss = -expected_reward
    print('loss:', loss)
    # finally, backpropagate the loss according to the policy
    loss.backward()
    decoder_optimizer.step()
Exemplo n.º 41
0
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    """Build the encoder loss and the surrogate (variance) loss for ``k`` samples.

    For each of the ``k`` samples a relaxed one-hot sample ``cluster_S`` is
    drawn, ``H`` maps it to ``cluster_H``, and two learning signals are
    formed: ``f_z`` (using the relaxed log-probability) and ``f_b`` (using the
    ``Categorical`` log-probability).  ``alpha = sigmoid(surrogate2.net(x))``
    mixes the two estimators.

    Args:
        surrogate: object whose ``.net`` maps ``cat([cluster_S, x], dim=1)``
            to the control-variate prediction.
        surrogate2: object whose ``.net`` maps ``x`` to the pre-sigmoid
            mixing weight ``alpha``.
        x: batch of observations, concatenated with ``cluster_S`` along dim 1.
        logits: encoder logits of shape ``[B, C]``; ``torch.autograd.grad`` is
            taken with respect to this tensor, so it must be part of the
            autograd graph.
        mixtureweights: indexed with ``cluster_H`` to obtain the prior weight.
        k: number of samples; ``net_loss`` and ``surr_loss`` are averaged
            over it.  NOTE(review): all other returned values come from the
            last loop iteration only, and ``k`` must be at least 1.

    Returns:
        Tuple ``(net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss,
        surr_dif, grad_path, grad_score, mean(alpha))``.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    # Relaxed distribution (temperature 1) used for reparameterised samples.
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    # Despite the name, this is a Categorical over the same probabilities.
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        # presumably H discretises the relaxed sample to a component index — TODO confirm
        cluster_H = H(cluster_S)

        # Samples are detached so the log-probs only carry gradient through
        # the distribution parameters.
        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)


        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        # Encoder objective: alpha-weighted relaxed estimator with the
        # surrogate as control variate, plus (1 - alpha)-weighted
        # score-function term on the discrete sample.  alpha and the signals
        # are detached so only logq_z / surr_pred / logq_b carry gradient.
        net_loss += - torch.mean(     alpha.detach()*(f_z.detach()  - surr_pred.detach()) * logq_z  
                                    + alpha.detach()*surr_pred 
                                    + (1-alpha.detach())*(f_b.detach()  ) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        # Gradients w.r.t. the logits, averaged over the class dimension to
        # shape [B, 1]; create_graph=True keeps them differentiable so that
        # surr_loss can be backpropagated into the surrogate networks.
        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b =  torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # Surrogate objective: mean squared value of the combined gradient
        # estimate (alpha and surr_pred are NOT detached here).
        surr_loss += torch.mean(
                                    (alpha*(f_z.detach() - surr_pred) * grad_logq_z 
                                    + alpha*grad_surr
                                    + (1-alpha)*(f_b.detach()) * grad_logq_b )**2
                                    )

        # Diagnostics only (not part of either loss).
        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        # NOTE(review): grad_path repeats the autograd.grad call already made
        # for grad_surr above (mean of surr_pred w.r.t. logits).
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))


    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
Exemplo n.º 42
0
 def _distribution(self, obs):
     logits = self.logits_net(obs)
     return Categorical(logits=logits)
Exemplo n.º 43
0
 def get_pi(self, x):
     """Return the categorical policy for the batch ``x``.

     ``x`` is permuted with ``(0, 3, 1, 2)`` and divided by 255 before being
     passed through ``self.network`` and then ``self.actor``; the actor
     output is used as the logits.
     """
     scaled = x.permute((0, 3, 1, 2)) / 255.0
     features = self.network(scaled)
     policy_logits = self.actor(features)
     return Categorical(logits=policy_logits)
Exemplo n.º 44
0
 def get_action(self, x, action=None):
     """Return ``(action, log_prob, entropy)`` under the current policy.

     The policy logits are ``self.actor(self.forward(x))``.  When ``action``
     is ``None`` an action is sampled from the policy; otherwise the given
     action is evaluated.
     """
     policy = Categorical(logits=self.actor(self.forward(x)))
     chosen = policy.sample() if action is None else action
     return chosen, policy.log_prob(chosen), policy.entropy()
Exemplo n.º 45
0
 def get_policy(self, observations: np.array):
     """Map observations to a categorical action distribution.

     ``self.policy_net`` is applied to ``observations`` and its output is
     used as the logits of the returned ``Categorical``.
     """
     return Categorical(logits=self.policy_net(observations))
def simplax():
    """Train the encoder on the mixture-model task and log/plot diagnostics.

    The encoder is optimised with SGD on a score-function objective,
    ``-mean((logpxz.detach() - 1) * log q(cluster))``, using samples from a
    ``Categorical`` over the encoder's softmax output.  Every 500 steps
    diagnostics are printed and recorded, every 1000 steps (once more than
    three records exist) curves and plots are written under ``exp_dir``, and
    at the end the recorded steps and L2 losses are pickled to
    ``exp_dir + "data_simplax.p"``.

    Relies on names defined outside this function, e.g. ``n_components``,
    ``exp_dir``, ``true_mixture_weights``, ``sample_true``, ``NN3``, ``H``,
    ``logprob_undercomponent``, ``true_posterior``, ``plot_curve``.

    NOTE(review): in the code as written only ``optim_net`` exists.  The
    optimisers for the mixture weights (``optim``) and for the surrogate
    (``optim_surr``) are commented out, so ``needsoftmax_mixtureweight`` and
    the surrogate are never updated here; ``surr_loss`` is only logged.
    """



    def show_surr_preds():
        """Plot the surrogate prediction against ``f`` over a grid of x values.

        Three rows are drawn, each for a fresh relaxed sample; the figure is
        saved to ``exp_dir + 'gmm_surr.png'``.
        """

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            # Evaluate on a fixed grid of x values while holding the sampled
            # z / cluster constant (repeated once per grid point).
            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1) 

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            # print (row)
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()


        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()




    def plot_dist():
        """Plot every weighted mixture component and their sum.

        Components are ``Normal(c * 10, 5)`` weighted by the softmax of
        ``needsoftmax_mixtureweight``; the figure is saved to
        ``exp_dir + 'gmm_plot_dist.png'``.
        """


        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)


        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))

        # C = 20
        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            for x in xs:
                # component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()


                ys.append(component_i)

            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')

        ax.plot(xs, sum_, label='')

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()
        


    def get_loss():
        """Return the surrogate loss on a fresh batch.

        The loss is the mean absolute difference between the detached
        ``logpxz`` and the surrogate prediction.  ``batch_size`` and ``temp``
        are read from the enclosing ``simplax`` scope.
        """

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)
        
        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))

        return surr_loss


    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        """Bar-plot the true posterior against the encoder's q for one sample.

        The title shows the L2 error and the KL value; the figure is saved to
        ``exp_dir + 'posteriors' + name + '.png'``.
        """

        x = sample_true(1).cuda() 
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)

        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()

        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)


        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        # ax.bar(np.array(range(len(q_b)))+width+width, q_b, width=width)
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()
        



    def inference_error(needsoftmax_mixtureweight):
        """Return the mean L2 error and mean KL value over 10 single samples.

        Each sample compares ``true_posterior`` with the softmax of the
        encoder logits.
        """

        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):

            # if x is None:
            x = sample_true(1).cuda() 
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)

            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

            error_sum+=error
            kl_sum += kl
        
        return error_sum/n, kl_sum/n



    #SIMPLAX
    # Unnormalised mixture weights; softmax over dim 0 gives the weights.
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")#.cuda()
    
    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()

    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.00004)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0004)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.004)
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.0001)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0001)
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.005)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    surrugate_steps = 0
    k = 1
    # Histories recorded every 500 steps (see below).
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    # logprob_cluster_list = []
    for step in range(n_steps):

        # NOTE(review): this loop never runs because surrugate_steps is 0;
        # if it did, optim_surr would be undefined (its definition is
        # commented out above).
        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        # print (probs)
        # cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

        cat = Categorical(probs=probs)
        # cluster = cat.sample()
        # logprob_cluster = cat.log_prob(cluster.detach())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            # cluster_S = cat.rsample()
            # cluster_H = H(cluster_S)
            cluster_H = cat.sample()

            # print (cluster_H.shape)
            # print (cluster_H[0])


            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            # logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1).detach()

            # cat = RelaxedOneHotCategorical(probs=probs.detach(), temperature=torch.tensor([temp]).cuda())
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1)

            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster
            # print (f)

            # The surrogate is fed the class probabilities rather than a
            # relaxed sample in this variant.
            # surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            surr_input = torch.cat([probs, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)
            
            # print (f.shape)
            # print (surr_pred.shape)
            # print (logprob_cluster.shape)
            # net_loss += - torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster + surr_pred - logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach() - surr_pred.detach() - 1.) * logprob_cluster + surr_pred)
            # Active objective: score-function term only, no surrogate.
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach()) * logprob_cluster - logprob_cluster)
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))

        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k



        # if step %2==0:
        # optim.zero_grad()
        # loss.backward(retain_graph=True)  
        # optim.step()

        # retain_graph=True keeps the graph alive for the diagnostic
        # autograd.grad calls made below on logging steps.
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()

        # optim_surr.zero_grad()
        # surr_loss.backward(retain_graph=True)
        # optim_surr.step()

        # print (torch.mean(f).cpu().data.numpy())
        # plot_posteriors(name=str(step))

        # kl_batch = compute_kl_batch(x,probs)


        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(), 
                            'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax(
                                        needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
            # if step %5000==0:
            #     print (torch.softmax(needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()) 
            #     # test_samp, test_cluster = sample_true2() 
            #     # print (test_cluster.cpu().data.numpy(), test_samp.cpu().data.numpy(), torch.softmax(encoder.net(test_samp.cuda().view(1,1)), dim=1))           
            #     print ()

            if step > 0:
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax(
                                            needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())

                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)

                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)

                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)

                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())




                # i_feel_like_it = 1
                # if i_feel_like_it:

                # NOTE(review): inf_losses was appended to just above, so this
                # condition is always true here and the else branch is dead.
                if len(inf_losses) > 0:
                    print ('probs', probs[0])
                    print('logpxz', logpxz[0])
                    print('pred', surr_pred[0])
                    print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                    print ('logq', logprob_cluster[0])
                    print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])
                    
                    


                    # Scalars whose gradients w.r.t. probs are measured below.
                    output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0] 
                    output2 = torch.mean(surr_pred, dim=0)[0]
                    output3 = torch.mean(logprob_cluster, dim=0)[0]
                    # input_ = torch.mean(probs, dim=0) #[0]
                    # print (probs.shape)
                    # print (output.shape)
                    # print (input_.shape)
                    grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                    grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                    grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]
                    # print (grad)
                    # print (grad_reinforce.shape)
                    # print (grad_reparam.shape)
                    # Reduce each gradient to its mean absolute value.
                    grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                    grad_reparam = torch.mean(torch.abs(grad_reparam))
                    grad3 = torch.mean(torch.abs(grad3))
                    # print (grad_reinforce)
                    # print (grad_reparam)
                    grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                    grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())
                    # grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())

                    print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                    print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                    print ('logqz grad:', grad3.cpu().data.detach().numpy())

                    print ('current mixuture weights')
                    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                    print()

                    # print ()
                else:
                    grad_reparam_list.append(0.)
                    grad_reinforce_list.append(0.)                    



            # Plot only once more than three records exist, every 1000 steps.
            if len(surr_losses) > 3  and step %1000==0:
                plot_curve(steps=steps_list,  thetaloss=L2_losses, 
                            infloss=inf_losses, surrloss=surr_losses,
                            grad_reinforce_list=grad_reinforce_list, 
                            grad_reparam_list=grad_reparam_list,
                            f_list=f_list, logpxz_list=logpxz_list,
                            logprob_cluster_list=logprob_cluster_list,
                            inf_losses_kl=inf_losses_kl,
                            kl_losses_2=kl_losses_2,
                            logpx_list=logpx_list)


                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()
                

            # print (f)
            # print (surr_pred)

            #Understand surr preds
            # if step %5000==0:

            # if step ==0:
                
                




    # Persist the recorded steps and L2 losses of the mixture weights.
    data_dict = {}

    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    # Note: the file handle rebinds the name ``f`` used for the learning
    # signal inside the training loop.
    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
Exemplo n.º 47
0
 def get_policy(obs):
     """Return the categorical distribution whose logits are ``logits_net(obs)``."""
     return Categorical(logits=logits_net(obs))
Exemplo n.º 48
0
 def __init__(self, probs=None, logits=None):
     """Wrap an index-valued ``Categorical`` built from ``probs`` or ``logits``.

     The shape of the wrapped categorical's ``probs`` is split into the
     batch shape (all but the last dimension) and the event shape (the last
     dimension) for the base-class constructor.
     """
     self._categorical = Categorical(probs, logits)
     full_shape = self._categorical.probs.size()
     super(OneHotCategorical, self).__init__(full_shape[:-1], full_shape[-1:])
Exemplo n.º 49
0
def sample_candidates(model,
                      n_samples: int,
                      encoder_output,
                      masks: Dict[str, Tensor],
                      max_output_length: int,
                      labels: dict = None):
    """
    Sample ``n_samples`` sequences per batch element from the model.

    In each decoding step the next token of every hypothesis is drawn from
    the categorical distribution returned by ``model.decode``; the log
    probability of the drawn token is accumulated until the hypothesis has
    produced ``model.eos_index``. Decoding stops after ``max_output_length``
    steps or as soon as every hypothesis has produced EOS.

    The caller's ``masks`` dict is not modified: a tiled copy is built and any
    incoming ``"trg"`` entry is left out of it.

    :param model: model providing ``is_transformer``, ``decoder``,
        ``decode``, ``bos_index`` and ``eos_index``
    :param n_samples: number of sequences to sample per batch element
    :param encoder_output: encoder output; must provide ``device``, ``tile``
        and (for recurrent models) ``hidden``
    :param masks: masks keyed by name; each is tiled ``n_samples`` times
        along dim 0
    :param max_output_length: maximum number of decoding steps
    :param labels: passed through unchanged to ``model.decode``
    :return:
        - final_outputs: sampled token ids, batch x n_samples x len
          (each sequence starts with ``model.bos_index``; tokens keep being
          appended after EOS until decoding stops)
        - seq_probs: summed log probabilities of the sampled tokens up to and
          including EOS, batch x n_samples
    """

    # init
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)
    att_vectors = None  # not used for Transformer
    device = encoder_output.device

    # Recurrent models only: initialize RNN hidden state
    if not transformer and model.decoder.bridge_layer is not None:
        hidden = model.decoder.bridge_layer(encoder_output.hidden)
    else:
        hidden = None

    # tile encoder states and decoder initial states n_samples times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        hidden = tile(hidden, n_samples, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    # NOTE(review): the return value is discarded, so this presumably tiles
    # encoder_output in place -- confirm against its implementation.
    encoder_output.tile(n_samples, dim=0)

    # Build a new, tiled dict without the "trg" entry instead of popping it
    # from the dict owned by the caller.
    masks = {
        k: tile(v, n_samples, dim=0)
        for k, v in masks.items() if k != "trg"
    }

    # Transformer only: create target mask
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # the structure holding all batch_size * k partial hypotheses
    alive_seq = torch.full((batch_size * n_samples, 1),
                           model.bos_index,
                           dtype=torch.long,
                           device=device)
    # the structure indicating, for each hypothesis, whether it has
    # encountered eos yet (if it has, stop updating the hypothesis
    # likelihood)
    is_finished = torch.zeros(batch_size * n_samples,
                              dtype=torch.bool,
                              device=device)

    # for each (batch x n_samples) sequence, there is a log probability
    seq_probs = torch.zeros(batch_size * n_samples, device=device)

    for _ in range(max_output_length):
        # Transformers consume the whole prefix, recurrent decoders only the
        # most recent token.
        dec_input = alive_seq if transformer \
            else alive_seq[:, -1].view(-1, 1)

        # decode a step
        probs, hidden, att_scores, att_vectors = model.decode(
            trg_input=dec_input,
            encoder_output=encoder_output,
            masks=masks,
            decoder_hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            labels=labels,
            generate="true")

        # batch*k x trg_vocab
        # probs = model.decoder.gen_func(logits[:, -1], dim=-1).squeeze(1)

        next_ids = Categorical(probs).sample().unsqueeze(1)  # batch*k x 1
        next_scores = probs.gather(1, next_ids).squeeze(1)  # batch*k

        # only unfinished hypotheses accumulate the new token's log prob
        seq_probs = torch.where(is_finished, seq_probs,
                                seq_probs + next_scores.log())

        # append latest prediction
        # batch_size*k x hyp_len
        alive_seq = torch.cat([alive_seq, next_ids], -1)

        # update which hypotheses are finished
        is_finished = is_finished | next_ids.eq(model.eos_index).squeeze(1)

        if is_finished.all():
            break

    # final_outputs: batch x n_samples x len
    final_outputs = alive_seq.view(batch_size, n_samples, -1)
    seq_probs = seq_probs.view(batch_size, n_samples)

    return final_outputs, seq_probs
Exemplo n.º 50
0
 def loss(actions_logits, action, discounted_reward):
     """Return the policy-gradient loss per sample, flattened to 1-D.

     Each entry is the negated log-probability of ``action`` under the
     categorical given by ``actions_logits``, multiplied by
     ``discounted_reward``.
     """
     log_likelihood = Categorical(logits=actions_logits).log_prob(action)
     weighted = -log_likelihood * discounted_reward
     # TODO: an entropy bonus, "- (distribution.entropy() * (10 ** -2))", is
     # disabled here; this hyperparameter should be provided some other way.
     return weighted.view(-1)
def sample_reinforce_given_class(logits, samp):
    """Return the log-probability of ``samp`` under ``Categorical(logits=logits)``."""
    return Categorical(logits=logits).log_prob(samp)
Exemplo n.º 52
0
class ExpRelaxedCategorical(Distribution):
    r"""
    Distribution over the log of a point in the simplex, parameterized by
    `temperature` together with either `probs` or `logits`.
    The interface follows that of OneHotCategorical.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): the log probability of each event.

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.real
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        """Store the temperature and wrap a ``Categorical`` built from ``probs``/``logits``."""
        base = Categorical(probs, logits)
        self._categorical = base
        self.temperature = temperature
        super(ExpRelaxedCategorical, self).__init__(
            base.batch_shape,
            base.param_shape[-1:],
            validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped categorical."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped categorical."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample.

        Gumbel noise (from clamped uniforms) is added to the logits, the sum
        is divided by the temperature, and the result is normalized by
        subtracting its log-sum-exp.
        """
        full_shape = self._extended_shape(torch.Size(sample_shape))
        noise = clamp_probs(self.logits.new(full_shape).uniform_())
        gumbel_noise = -((-(noise.log())).log())
        perturbed = (self.logits + gumbel_noise) / self.temperature
        return perturbed - _log_sum_exp(perturbed)

    def log_prob(self, value):
        """Return the log density of ``value`` (a log-simplex point)."""
        num_events = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        # lgamma(K) minus log(temperature) * (-(K - 1))
        log_gamma_k = temperature.new(temperature.shape).fill_(num_events).lgamma()
        log_scale = log_gamma_k - temperature.log().mul(-(num_events - 1))
        unnormalized = logits - value.mul(temperature)
        normalized_sum = (unnormalized - _log_sum_exp(unnormalized)).sum(-1)
        return normalized_sum + log_scale
Exemplo n.º 53
0
class DIPTransform:
    """Image transform that fits a DIP network for a randomly drawn number of iterations."""

    def __init__(self,
                 image_shape,
                 num_iters,
                 stop_iters,
                 input_depth=32,
                 input_noise_std=0,
                 optimizer='adam',
                 lr=1e-2,
                 plot_every=0,
                 device='cpu'):
        """
        :param image_shape: image shape (CxHxW)
        :param num_iters: number of iterations to overfit
        :param stop_iters: dict int->float specifying a categorical distribution over percentages in (0, 100]; the value is the probability that the transform runs for KEY/100 * NUM_ITERS iters.
        :param input_depth: depth of random input. Default value taken from paper.
        :param input_noise_std: stdev of noise added to base random input at each iter. The paper says this helps sometimes (seemingly using 0.03), but default is 0 (no noise).
        :param optimizer: supported optimizers are 'adam' and 'LBFGS'
        :param lr: learning rate. Per paper and code, 1e-2 works best.
        :param plot_every: how often to save images. Doesn't save images if set to 0 (default).
        :param device: 'cuda' to use GPU if available.
        For example, with NUM_ITERS = 100 and STOP_ITERS = {10: 0.5, 50: 0.3, 100: 0.2}, there is a 50% chance that the transform runs for 10 iterations, a 30% chance it runs for 50 iterations, and a 20% chance it runs for the full 100 iterations.
        """
        # Split the dict into its percentages and their probabilities.
        stop_percentages, stop_weights = zip(*stop_iters.items())
        self.iters = stop_percentages
        self.probs = Categorical(torch.tensor(stop_weights))

        self.num_iters = num_iters
        self.image_shape = image_shape
        self.input_depth = input_depth
        self.input_noise_std = input_noise_std
        self.plot_every = plot_every
        self.device = device
        self.optimizer = optimizer
        self.lr = lr
        self.loss = torch.nn.MSELoss()

        # const strings from provided DIP code
        self.opt_over = 'net'
        self.const_input = 'noise'

        # initialize network
        network = skip(input_depth,
                       3,
                       num_channels_down=[8, 16, 32],
                       num_channels_up=[8, 16, 32],
                       num_channels_skip=[0, 0, 4],
                       upsample_mode='bilinear',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='zeros',
                       act_fun='LeakyReLU')
        self.net = network.to(self.device)

    def sample_iters(self):
        """Draw a stop percentage and convert it to an iteration count."""
        choice = self.probs.sample()
        return self.iters[choice] * self.num_iters // 100

    def run(self, image, num_iters):
        """Optimize the network on ``image`` for ``num_iters`` iterations and return its output."""
        assert image.shape[
            1:] == self.image_shape, 'Wrong shape. Expected {}, got {}.'.format(
                self.image_shape, image.shape[1:])
        # run net for num_iters iterations
        net_input = get_noise(self.input_depth, self.const_input,
                              self.image_shape[1:]).to(self.device)
        saved_input = net_input.detach().clone()
        noise = net_input.detach().clone()
        params = get_params(self.opt_over, self.net, net_input)

        def step_closure(iter_num):
            # Bind the per-run tensors so the optimizer only passes iter_num.
            return self.closure(saved_input, image, noise, iter_num)

        optimize(self.optimizer,
                 params,
                 step_closure,
                 self.lr,
                 num_iters,
                 pass_iter=True)
        return self.net(net_input)

    def closure(self, net_input_saved, image, noise, iter_num):
        """Compute the MSE loss for one iteration, backpropagate it, and optionally save a plot."""
        if self.input_noise_std > 0:
            # noise is resampled in place on every call
            perturbation = noise.normal_() * self.input_noise_std
            net_input = net_input_saved + perturbation
        else:
            net_input = net_input_saved
        out = self.net(net_input)
        total_loss = self.loss(out, image)
        total_loss.backward()
        #print(total_loss)
        if self.plot_every > 0 and iter_num % self.plot_every == 0 and total_loss < 0.01:
            clipped = np.clip(torch_to_np(out), 0, 1)
            plot_image_grid([clipped],
                            factor=4,
                            nrow=1,
                            show=False,
                            save_path=f'results_dip/imgs/{iter_num}.png')
        # maybe log loss here?

    def __call__(self, sample):
        """
        Transform the PIL image SAMPLE and return the result as a PIL image.

        The transformation runs DIP for a random number of iterations, drawn
        from the distribution given when the transform was initialized.
        """
        torch_img = np_to_torch(tmp := pil_to_np(sample)).to(self.device)
        plot_image_grid([np.clip(tmp, 0, 1)],
                        factor=4,
                        nrow=1,
                        show=False,
                        save_path='results_dip/imgs/true.png')  # TODO remove
        drawn_iters = self.sample_iters()
        result = self.run(torch_img, drawn_iters)
        return np_to_pil(torch_to_np(result))
Exemplo n.º 54
0
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Parameter handling is delegated to an index-valued Categorical; this
        # class only converts between indices and one-hot vectors.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension enumerates the categories, so it becomes
        # the event dimension of the one-hot vectors.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Forward to the wrapped Categorical's ``_new`` tensor factory."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        """Event probabilities of the wrapped Categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Event log-probabilities of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def mean(self):
        """Per-category mean of the one-hot vector, i.e. ``probs``."""
        return self._categorical.probs

    @property
    def variance(self):
        """Per-category variance ``probs * (1 - probs)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        """Shape of the parameter tensor of the wrapped Categorical."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """
        Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``.

        Category indices are drawn from the wrapped Categorical and scattered
        into a zero tensor that shares dtype and device with ``probs``.
        """
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new_zeros(self._extended_shape(sample_shape))
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            # Indices lack the trailing event dimension that scatter_ requires.
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Return the log-probability of the one-hot encoded ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        # The position of the maximum along the last dim is the category index.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Return the entropy of the wrapped Categorical."""
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """
        Return every one-hot vector in the support, stacked along dim 0.

        Args:
            expand (bool): if True (default) the result is expanded over the
                batch dimensions to shape ``(n,) + batch_shape + (n,)``;
                otherwise the batch dimensions are left as singletons.
        """
        n = self.event_shape[0]
        probs = self._categorical.probs
        # Build the identity directly with the parameters' dtype/device rather
        # than via ``Tensor.new((n, n))``, which treats a plain tuple as data
        # and then relies on deprecated resizing of a non-empty ``out=`` tensor.
        values = torch.eye(n, dtype=probs.dtype, device=probs.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
Exemplo n.º 55
0
 def get_pi_value_and_aux_value(self, x):
     """
     Return ``(policy, value, aux_value)`` for the observation batch ``x``.

     ``x`` is permuted with ``(0, 3, 1, 2)`` and divided by 255 before being
     fed to ``self.network`` (presumably channels-last uint8 images -- confirm
     against the caller). The critic head receives a detached copy of the
     features, while the actor and auxiliary critic see them with gradients.
     """
     scaled = x.permute((0, 3, 1, 2)) / 255.0
     hidden = self.network(scaled)
     policy = Categorical(logits=self.actor(hidden))
     value = self.critic(hidden.detach())
     aux_value = self.aux_critic(hidden)
     return policy, value, aux_value