Example #1
    def test_beta_shape_tensor_params(self):
        dist = Beta(torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
                    torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
Example #2
    def test_beta_shape_scalar_params(self):
        dist = Beta(0.1, 0.1)
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size((1,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
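Together, these two tests pin down the Beta shape semantics: tensor parameters define `batch_shape`, the event shape is always scalar, and `sample(shape)` prepends `shape` to `batch_shape`. A minimal standalone sketch of the same behaviour (assuming a recent PyTorch; note that on current versions scalar parameters give an empty `sample()` shape rather than the `(1,)` asserted in the older test above):

import torch
from torch.distributions import Beta

# Tensor parameters of shape (3, 2) -> batch_shape (3, 2), scalar event_shape.
dist = Beta(torch.rand(3, 2) + 0.1, torch.rand(3, 2) + 0.1)
assert dist.batch_shape == torch.Size((3, 2))
assert dist.event_shape == torch.Size(())
# sample(shape) prepends shape to batch_shape.
assert dist.sample((5,)).shape == torch.Size((5, 3, 2))
# log_prob broadcasts its argument against batch_shape.
assert dist.log_prob(torch.full((3, 2), 0.5)).shape == torch.Size((3, 2))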
Example #3
def test_e_log_stick():
    """
    This test DOES NOT PASS, and maybe should not
    """
    model = InfiniteIBP(4., 10, 0.1, 0.5, 36)
    model.init_z(10)

    K = model.K

    # take a lot of samples to get something working
    dist = Beta(model.tau.detach()[:, 0], model.tau.detach()[:, 1])
    samples = dist.sample((100000, ))
    f = (1. - samples.cumprod(1)).log().mean(0)
    log_stick, q = model._E_log_stick(model.tau, model.K)

    jeffrey_q = np.zeros((K, K))
    jeffrey_log_stick = np.zeros((K, ))
    for k in range(K):
        a, b = compute_q_Elogstick(model.tau.detach().numpy().T, k)
        jeffrey_q[k, :k + 1] = a
        jeffrey_log_stick[k] = b

    print("old:     {}".format(jeffrey_log_stick))
    print("new:     {}".format(log_stick.detach().numpy()))
    print("samples: {}".format(f.detach().numpy()))

    import ipdb
    ipdb.set_trace()
Example #4
    def sample(self, datas):
        alpha, beta = datas

        distribution    = Beta(alpha, beta)
        action          = distribution.sample().float().to(set_device(self.use_gpu))

        return action
Example #5
class MixUp(Callback):
    run_valid = False

    def __init__(self, alpha=0.4, onehot=False):
        self.alpha = alpha
        self.distrib = Beta(alpha, alpha)
        self.onehot = onehot

    def before_batch(self):
        bs = self.xb[0].shape[0]
        device = self.xb[0].device
        lambd = self.distrib.sample(
            (self.y.size(0), )).squeeze().to(self.x.device)
        lambd = torch.stack([lambd, 1 - lambd], 1).max(1)[0]
        shuffle = torch.randperm(bs).to(device)
        xb1, yb1 = self.xb[0][shuffle], self.yb[0][shuffle]
        a = tensor(lambd).float().view(-1, 1, 1, 1).to(device)
        self.learn.xb = tuple([a * self.xb[0] + (1 - a) * xb1])
        a = a.view(-1)
        if self.onehot:
            while len(a.shape) < len(yb1.shape):
                a = a[..., None]
            self.learn.yb = tuple([a * self.learn.yb[0] + (1 - a) * yb1])
        else:
            self.learn.yb = tuple([{
                'yb': self.learn.yb[0],
                'yb1': yb1,
                'a': a
            }])
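The essential trick in `before_batch` is drawing one λ per example from Beta(α, α) and keeping max(λ, 1 − λ), so each mixed image stays dominated by its own label. A minimal sketch of just that step, outside any fastai `Callback` (the 4-D image batch `x` is an assumption):

import torch
from torch.distributions import Beta

def mixup_batch(x, alpha=0.4):
    # One lambda per example, folded onto [0.5, 1) so x keeps the majority share.
    lambd = Beta(alpha, alpha).sample((x.size(0),)).to(x.device)
    lambd = torch.max(lambd, 1 - lambd)
    shuffle = torch.randperm(x.size(0), device=x.device)
    a = lambd.view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)
    return a * x + (1 - a) * x[shuffle], lambd, shuffle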
Example #6
    def chooseActionTrain(self, state):
        """ Choose an action during training mode
        
            Parameters
            -------
            state:
                The current state of the car.

            Returns
            -------
            action : np.ndarray
                The actions to run on the track
            coefficient : float
                The logarithmic probability for an action

            Notes
            -------
                This function is only called when the --train flag IS provided.
        """
        state = torch.from_numpy(state).double().to(
            self.hardwareDevice).unsqueeze(0)
        with torch.no_grad():
            alpha, beta = self.nn(state)[0]
        dist = Beta(alpha, beta)
        action = dist.sample()
        coefficient = dist.log_prob(action).sum(dim=1)

        action = action.squeeze().cpu().numpy()
        coefficient = coefficient.item()

        return action, coefficient
Example #7
def test_beta_likelihood(concentration1: float, concentration0: float) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters
    """

    # generate samples
    concentration1s = torch.zeros((NUM_SAMPLES, )) + concentration1
    concentration0s = torch.zeros((NUM_SAMPLES, )) + concentration0

    distr = Beta(concentration1s, concentration0s)
    samples = distr.sample()

    init_biases = [
        inv_softplus(concentration1 -
                     START_TOL_MULTIPLE * TOL * concentration1),
        inv_softplus(concentration0 -
                     START_TOL_MULTIPLE * TOL * concentration0),
    ]

    concentration1_hat, concentration0_hat = maximum_likelihood_estimate_sgd(
        BetaOutput(),
        samples,
        init_biases=init_biases,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    assert (
        np.abs(concentration1_hat - concentration1) < TOL * concentration1
    ), f"concentration1 did not match: concentration1 = {concentration1}, concentration1_hat = {concentration1_hat}"
    assert (
        np.abs(concentration0_hat - concentration0) < TOL * concentration0
    ), f"concentration0 did not match: concentration0 = {concentration0}, concentration0_hat = {concentration0_hat}"
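The same recovery check works without the GluonTS helpers (`BetaOutput`, `maximum_likelihood_estimate_sgd`): optimize the negative log-likelihood directly, with a softplus keeping both concentrations positive. A rough sketch, not the test's actual implementation:

import torch
from torch.distributions import Beta

def fit_beta_mle(samples, lr=0.05, steps=500):
    raw = torch.zeros(2, requires_grad=True)  # unconstrained parameters
    opt = torch.optim.Adam([raw], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        c1, c0 = torch.nn.functional.softplus(raw) + 1e-6
        loss = -Beta(c1, c0).log_prob(samples).mean()
        loss.backward()
        opt.step()
    return torch.nn.functional.softplus(raw).detach() + 1e-6

# fit_beta_mle(Beta(3.0, 1.5).sample((10000,))) should land near (3.0, 1.5)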
Example #8
class CutMix(Callback):
    """ Cutmix callback which replaces a random patch of image data with the corresponding patch from another image.
    This callback also converts labels to one hot before combining them according to the lambda parameters, sampled from
    a beta distribution as is done in the paper.

    Example: ::

        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import CutMix

        # Example Trial which does CutMix regularisation
        >>> cutmix = CutMix(1, classes=10)
        >>> trial = Trial(None, callbacks=[cutmix], metrics=['acc'])

    Args:
        alpha (float): The alpha value for the beta distribution.
        classes (int): The number of classes for conversion to one hot.

    State Requirements:
        - :attr:`torchbearer.state.X`: State should have the current data stored
        - :attr:`torchbearer.state.Y_TRUE`: State should have the current data stored
    """
    def __init__(self, alpha, classes=-1):
        super(CutMix, self).__init__()
        self.classes = classes
        self.dist = Beta(torch.tensor([float(alpha)]),
                         torch.tensor([float(alpha)]))

    def _to_one_hot(self, target):
        if target.dim() == 1:
            target = target.unsqueeze(1)
            one_hot = torch.zeros_like(target).repeat(1, self.classes)
            one_hot.scatter_(1, target, 1)
            return one_hot
        return target

    def on_sample(self, state):
        super(CutMix, self).on_sample(state)

        lam = self.dist.sample().to(state[torchbearer.DEVICE])
        length = (1 - lam).sqrt()
        cutter = BatchCutout(
            1, (length * state[torchbearer.X].size(-1)).round().item(),
            (length * state[torchbearer.X].size(-2)).round().item())
        mask = cutter(state[torchbearer.X])
        erase_locations = mask == 0

        permutation = torch.randperm(state[torchbearer.X].size(0))

        state[torchbearer.X][erase_locations] = state[
            torchbearer.X][permutation][erase_locations]

        target = self._to_one_hot(state[torchbearer.TARGET]).float()
        state[torchbearer.
              TARGET] = lam * target + (1 - lam) * target[permutation]

    def on_sample_validation(self, state):
        super(CutMix, self).on_sample_validation(state)
        state[torchbearer.TARGET] = self._to_one_hot(
            state[torchbearer.TARGET]).float()
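The geometry worth noting in `on_sample`: the patch side is √(1 − λ) of each image side, so the erased area fraction matches the label weight 1 − λ. The arithmetic in isolation (the image size is an assumed example):

import torch
from torch.distributions import Beta

lam = Beta(torch.tensor([1.0]), torch.tensor([1.0])).sample()
h, w = 224, 224  # assumed image size
cut_h = ((1 - lam).sqrt() * h).round().item()
cut_w = ((1 - lam).sqrt() * w).round().item()
# cut_h * cut_w / (h * w) is approximately 1 - lam, the weight of the pasted image.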
Example #9
def kl_bernoulli(pi, step, args):
    cap = min(args.h_cap, step * args.h_cap / args.total_steps)
    beta_dist = Beta(torch.ones_like(pi) * args.alpha_0, torch.ones_like(pi))
    pi_prior = Bernoulli(torch.cumprod(beta_dist.sample(), dim=-1))
    pi_posterior = Bernoulli(pi)
    klh_loss = kl_divergence(pi_posterior, pi_prior).sum(dim=1).mean()
    cap_klh_loss = args.gamma_h * (klh_loss - cap).abs()
    return cap_klh_loss
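The prior here is the stick-breaking construction of the Indian Buffet Process: with v_k ~ Beta(α₀, 1), feature k is active with probability π_k = v_1 · … · v_k, which is exactly the `cumprod` of a `Beta` sample. A minimal sketch of that construction on its own (`alpha_0` and `K` are assumed inputs):

import torch
from torch.distributions import Beta, Bernoulli

def ibp_stick_breaking_prior(alpha_0, K):
    v = Beta(torch.full((K,), alpha_0), torch.ones(K)).sample()
    pi = torch.cumprod(v, dim=-1)  # monotonically decreasing activation probabilities
    return Bernoulli(pi)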
Example #10
    def test_beta_log_prob(self):
        for _ in range(100):
            alpha = np.exp(np.random.normal())
            beta = np.exp(np.random.normal())
            dist = Beta(alpha, beta)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)[0]
            self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3, allow_inf=True)
Example #11
    def sample(self, device, epoch, num=64):
        sample = torch.randn(num, self.latent_dim).to(device)
        x_alpha, x_beta = self.decode(sample)
        beta = Beta(x_alpha, x_beta)
        p = beta.sample()
        binomial = Binomial(255, p)
        x_sample = binomial.sample()
        x_sample = x_sample.float() / 255.
        save_image(x_sample.view(num, 1, 28, 28),
                   'results/epoch_{}_samples.png'.format(epoch))
Example #12
    def select_action(self, state):
        state = torch.from_numpy(state).double().to(device).unsqueeze(0)
        with torch.no_grad():
            alpha, beta = self.net(state)[0]
        dist = Beta(alpha, beta)
        action = dist.sample()  # 3 values in [0,1]
        a_logp = dist.log_prob(action).sum(dim=1)  # For PPO
        action = action.squeeze().cpu().numpy()
        a_logp = a_logp.item()

        return action, a_logp
Example #13
    def select_action(self, state):
        state = torch.from_numpy(state).double().to(device).unsqueeze(0)
        with torch.no_grad():
            (alpha, beta), _, rcrc_s = self.net(state)
        dist = Beta(alpha, beta)
        action = dist.sample()
        a_logp = dist.log_prob(action).sum(dim=1)

        action = action.squeeze().cpu().numpy()
        a_logp = a_logp.item()
        return action, a_logp, rcrc_s
Example #14
    def select_action(self, state):
        # deal with datatype of state and transform it
        state = torch.from_numpy(state).double().unsqueeze(0)
        with torch.no_grad():
            alpha, beta = self.net(state)[0]
        dist = Beta(alpha, beta)
        action = dist.sample()  # sampled action in interval (0, 1)
        a_logp = dist.log_prob(action).sum(dim=1)  # add the log probability densities of the 3-stack
        action = action.squeeze().numpy()
        a_logp = a_logp.item()
        return action, a_logp
Example #15
def mixup(x, y, num_classes, gamma=0.2, smooth_eps=0.1):
    if gamma == 0 and smooth_eps == 0:
        return x, y
    m = Beta(torch.tensor([gamma]), torch.tensor([gamma]))
    lambdas = m.sample((x.size(0),)).to(x)  # shape (B, 1): one lambda per example
    my = onehot(y, num_classes).to(x)
    # label smoothing: keep probability 1 - smooth_eps on the true class
    true_class, false_class = 1. - smooth_eps * num_classes / (num_classes - 1), smooth_eps / (num_classes - 1)
    my = my * true_class + torch.ones_like(my) * false_class
    perm = torch.randperm(x.size(0))
    x2 = x[perm]
    y2 = my[perm]
    lambdas_x = lambdas.view(-1, 1, 1, 1)  # broadcast over the (B, C, H, W) images
    return x * (1 - lambdas_x) + x2 * lambdas_x, my * (1 - lambdas) + y2 * lambdas
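A quick usage sketch for the function above (shapes are assumptions; the `onehot` helper must be in scope as in the source):

import torch

x = torch.randn(32, 3, 32, 32)   # (B, C, H, W) image batch
y = torch.randint(0, 10, (32,))  # (B,) integer labels
mixed_x, mixed_y = mixup(x, y, num_classes=10, gamma=0.2, smooth_eps=0.1)
# mixed_x keeps the shape of x; mixed_y is a (B, 10) batch of soft labels.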
Example #16
    def reconstruct(self, x, device, epoch):
        x = x.view(-1, 784).float().to(device)
        z_mu, z_logvar = self.encode(x)
        z = self.reparameterize(z_mu, z_logvar)  # sample zs
        x_alpha, x_beta = self.decode(z)
        beta = Beta(x_alpha, x_beta)
        p = beta.sample()
        binomial = Binomial(255, p)
        x_recon = binomial.sample()
        x_recon = x_recon.float() / 255.
        x_with_recon = torch.cat((x, x_recon))
        save_image(x_with_recon.view(64, 1, 28, 28),
                   'results/epoch_{}_recon.png'.format(epoch))
Example #17
    def forward(self, x):
        with torch.no_grad():
            features = self.main(x)
            actor_features = self.actor(features)
            alpha = self.alpha_head(actor_features) + 1
            beta = self.beta_head(actor_features) + 1
        dist = Beta(alpha, beta)
        if not self.deterministic_sample:
            action = dist.sample().squeeze().numpy()
        else:
            action = dist.mean.squeeze().numpy()
        action[0] = action[0] * 2 - 1  # rescale the first action dim from (0, 1) to (-1, 1)
        return action
Example #18
def kl_categorical_dp(eta, step, args):
    cap = min(args.m_cap, step * args.m_cap / args.total_steps)
    # cap = min(cap, np.log(args.disc))
    beta_dist = Beta(torch.ones_like(eta), torch.ones_like(eta) * args.beta_0)
    beta_sample = beta_dist.sample()
    neg_prod = torch.cumprod(1.0-beta_sample, dim=-1)
    beta_sample[:, 1:] = beta_sample[:, :-1] * neg_prod[:, :-1]
    beta_sample = F.softmax(beta_sample, dim=-1)
    cat_prior = Categorical(probs=beta_sample)
    cat_posterior = Categorical(probs=eta)
    klm_loss = kl_divergence(cat_posterior, cat_prior).mean()
    cap_klm_loss = args.gamma_m * (klm_loss - cap).abs()
    return cap_klm_loss
Example #19
    def select_action(self, state, hidden):
        
        with torch.no_grad():
            _, latent_mu, _ = self.vae(state)
            alpha, beta = self.net(latent_mu, hidden[0])[0]
        
        dist = Beta(alpha, beta)
        action = dist.sample()
        a_logp = dist.log_prob(action).sum(dim=1)

        a_logp = a_logp.item()
        _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
        
        return action.squeeze().cpu().numpy(), a_logp, latent_mu, next_hidden
Example #20
def sample_from_beta_dist(y_hat):
    """
    y_hat (batch_size x seq_len x 2):
    
    """
    # take exponentional to ensure positive
    loc_y = y_hat.exp()
    alpha = loc_y[:, :, 0].unsqueeze(-1)
    beta = loc_y[:, :, 1].unsqueeze(-1)
    dist = Beta(alpha, beta)
    sample = dist.sample()
    # rescale sample from [0,1] to [-1, 1]
    sample = 2.0 * sample - 1.0
    return sample
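A short usage sketch (the shapes are assumptions, following the usual setup where the network emits two channels per timestep):

import torch

y_hat = torch.randn(4, 100, 2)        # (batch, seq_len, 2) raw network output
audio = sample_from_beta_dist(y_hat)  # (batch, seq_len, 1), values in (-1, 1)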
Example #21
class CarlaImgPolicy(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_layer=[400, 300]):
        super(CarlaImgPolicy, self).__init__()
        self.main_actor = CarlaSimpleEncoder(latent_size=input_dim - 1)
        self.main_critic = CarlaSimpleEncoder(latent_size=input_dim - 1)
        actor_layer_size = [input_dim] + hidden_layer
        actor_feature_layers = nn.ModuleList([])
        for i in range(len(actor_layer_size) - 1):
            actor_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            actor_feature_layers.append(nn.ReLU())
        self.actor = nn.Sequential(*actor_feature_layers)
        self.alpha_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(hidden_layer[-1], action_dim),
                                       nn.Softplus())

        critic_layer_size = [input_dim] + hidden_layer
        critic_layers = nn.ModuleList([])
        for i in range(len(critic_layer_size) - 1):
            critic_layers.append(
                nn.Linear(critic_layer_size[i], critic_layer_size[i + 1]))
            critic_layers.append(nn.ReLU())
        critic_layers.append(layer_init(nn.Linear(hidden_layer[-1], 1),
                                        gain=1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        speed = x[:, -1:]
        x = x[:, :-1].view(-1, 3, 128,
                           128)  # image size in carla driving task is 128x128
        x1 = self.main_actor(x)
        x1 = torch.cat([x1, speed], dim=1)

        x2 = self.main_critic(x)
        x2 = torch.cat([x2, speed], dim=1)

        actor_features = self.actor(x1)
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(x2)
        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy
Example #22
def sq_log_posterior_predictive_eval(x_new, kappa, tau_0, tau_1, S):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])
    log_prob = 0
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        for t in range(post_pred_clusters.shape[0]):
            log_prob -= post_pred_weights[t] * torch.exp(
                Poisson(post_pred_clusters[t]).log_prob(x_new))**2
    log_prob /= S
    return log_prob
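Both this function and `posterior_predictive_sample` below rely on a `mix_weights` helper that is not shown. A plausible sketch, following the usual stick-breaking definition w_t = β_t · ∏_{j<t}(1 − β_j) with the leftover stick mass as the last weight (an assumption, not code from this source):

import torch
import torch.nn.functional as F

def mix_weights(beta):
    # beta: (T-1,) stick-breaking fractions -> (T,) mixture weights summing to 1.
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)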
Example #23
    def experience(self, steps):

        total_obs = np.zeros((steps, ) + self.last_ob.shape +
                             (self.stack_size, ))

        total_rewards = np.zeros((steps, 1))
        total_actions = np.zeros((steps, 3))
        total_values = np.zeros((steps + 1, 1))
        masks = np.zeros((steps, 1))

        for step in range(steps):
            total_obs[step] = np.roll(total_obs[step], shift=-1, axis=-1)
            total_obs[step, :, :, -1] = self.last_ob

            alpha, beta, values = self.network(
                torch.from_numpy(total_obs[step]).type(
                    torch.FloatTensor).unsqueeze(0))
            total_values[step] = values.view(-1).detach().numpy()

            m = Beta(alpha, beta)
            actions = m.sample()

            total_actions[step] = actions.numpy()
            actions = actions.numpy() * np.array([2., 1., 1.]) - np.array(
                [1., 0., 0.])
            actions = actions.reshape((-1))

            self.last_ob, rews, dones_, _ = self.env.step(actions)
            self.env.render()

            self.last_ob = rgb2gris(self.last_ob)
            dones = np.logical_not(dones_) * 1
            total_rewards[step] = rews
            masks[step] = dones
            if dones_:
                self.env.reset()

        temp_ob = np.roll(total_obs[step], shift=-1, axis=-1)
        temp_ob[..., -1] = self.last_ob
        _, _, values = self.network(
            torch.from_numpy(temp_ob).type(torch.FloatTensor).unsqueeze(0))
        total_values[steps] = values.view(-1).detach().numpy()

        advantage, real_values = gae(total_rewards, masks, total_values)
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)

        return (total_obs, total_values, total_rewards, total_actions, masks,
                advantage, real_values)
Example #24
def posterior_predictive_sample(kappa, tau_0, tau_1, S, M):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])

    hallucinated_samples = torch.zeros(S, M)
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        hallucinated_samples[s, :] = MixtureSameFamily(
            Categorical(post_pred_weights),
            Poisson(post_pred_clusters)).sample([M])

    return hallucinated_samples
Example #25
    def select_action(self, state, deterministic=False):
        """
        Compute an action or vector of actions given a state or vector of states
        :param state: the input state(s)
        :param deterministic: whether the policy should be considered deterministic or not
        :return: the resulting action(s)
        """
        with torch.no_grad():
            alpha, beta = self.forward(state)
            if deterministic:
                # deterministic action: the Beta mean alpha / (alpha + beta)
                return alpha.data.numpy() / (alpha.data.numpy() +
                                             beta.data.numpy()).astype(float)
                # return np.clip(alpha.data.numpy().astype(float),-2,2)
            else:
                n = Beta(alpha, beta)
                action = n.sample()
                return action.data.numpy().astype(float)
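The deterministic branch is just the mean of the Beta distribution, α/(α+β); `torch.distributions.Beta` exposes the same quantity as `dist.mean`, as a quick check shows:

import torch
from torch.distributions import Beta

alpha, beta = torch.tensor([2.0, 5.0]), torch.tensor([3.0, 1.0])
assert torch.allclose(Beta(alpha, beta).mean, alpha / (alpha + beta))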
Example #26
    def select_action(self, state):
        if args.action_vec > 0:
            state = (torch.from_numpy(
                state[0]).float().to(device).unsqueeze(0),
                     torch.from_numpy(
                         state[1]).float().to(device).unsqueeze(0))
        else:
            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
        #TODO CHANGE FOR VECTOR ACTIONS
        with torch.no_grad():
            alpha, beta = self.net(state)[0]
        dist = Beta(alpha, beta)
        action = dist.sample()
        a_logp = dist.log_prob(action).sum(dim=1)

        action = action.squeeze().cpu().numpy()
        a_logp = a_logp.item()
        return action, a_logp
Example #27
    def traverse_grid(self,
                      cont_dim=0,
                      nrow=8,
                      ncol=8,
                      traverse=True,
                      use_prior=True,
                      set_zero=True,
                      file_name=None):
        if traverse and self.args.disc != 0:
            nrow = self.args.disc
        if set_zero:
            cont_samples = torch.zeros(nrow * ncol, self.args.cont).cuda()
        else:
            cont_samples = torch.randn(nrow * ncol, self.args.cont).cuda()
        fixed_value = torch.linspace(-2, 2, ncol).cuda()
        for row in range(nrow):
            for i in range(ncol):
                cont_samples[i + row * ncol, cont_dim] = fixed_value[i]

        if use_prior:
            v_prior = Beta(
                torch.ones_like(cont_samples) * self.args.alpha_0,
                torch.ones_like(cont_samples))
            mask_prob = torch.cumprod(v_prior.sample(), dim=1)
            mask = Bernoulli(mask_prob).sample()
            cont_samples = cont_samples * mask

        if self.args.model_type != 'ibp':
            disc_samples = torch.zeros(nrow * ncol, self.args.disc).cuda()
            for i in range(nrow):
                for j in range(ncol):
                    disc_samples[j + i * ncol, i] = 1.0
            samples = torch.cat([cont_samples, disc_samples], dim=-1)
        else:
            samples = cont_samples
        with torch.no_grad():
            x = self.model.decoder(samples).view(-1, self.args.img_channel,
                                                 self.args.img_size,
                                                 self.args.img_size)
        if self.save_img:
            save_image(x.data, file_name, nrow=ncol, padding=0, pad_value=0.0)
        else:
            return make_grid(x.data, nrow=ncol, padding=0, pad_value=0.0)
Example #28
class BetaSeparatedPolicy(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_layer=[64, 64]):
        super(BetaSeparatedPolicy, self).__init__()
        actor_layer_size = [input_dim] + hidden_layer
        alpha_feature_layers = nn.ModuleList([])
        beta_feature_layers = nn.ModuleList([])
        for i in range(len(actor_layer_size) - 1):
            alpha_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            alpha_feature_layers.append(nn.ReLU())
            beta_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            beta_feature_layers.append(nn.ReLU())
        self.alpha_body = nn.Sequential(*alpha_feature_layers)
        self.beta_body = nn.Sequential(*beta_feature_layers)
        self.alpha_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(hidden_layer[-1], action_dim),
                                       nn.Softplus())

        critic_layer_size = [input_dim] + hidden_layer
        critic_layers = nn.ModuleList([])
        for i in range(len(critic_layer_size) - 1):
            critic_layers.append(
                nn.Linear(critic_layer_size[i], critic_layer_size[i + 1]))
            critic_layers.append(nn.ReLU())
        critic_layers.append(nn.Linear(hidden_layer[-1], 1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        alpha = self.alpha_head(self.alpha_body(x)) + 1
        beta = self.beta_head(self.beta_body(x)) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(x)

        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy
Example #29
class MyDist(ActionDistribution):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 6

    def __init__(self, inputs, model):
        super(MyDist, self).__init__(inputs, model)
        self.dist = Beta(inputs[:, :3], inputs[:, 3:])

    def sample(self):
        self.sampled_action = self.dist.sample()
        return self.sampled_action

    def deterministic_sample(self):
        return self.dist.mean

    def sampled_action_logp(self):
        return self.logp(self.sampled_action)

    def logp(self, actions):
        return self.dist.log_prob(actions).sum(-1)

    # adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributions/kl.py
    def kl(self, other):
        p, q = self.dist, other.dist
        sum_params_p = p.concentration1 + p.concentration0
        sum_params_q = q.concentration1 + q.concentration0
        t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (
            sum_params_p).lgamma()
        t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (
            sum_params_q).lgamma()
        t3 = (p.concentration1 - q.concentration1) * torch.digamma(
            p.concentration1)
        t4 = (p.concentration0 - q.concentration0) * torch.digamma(
            p.concentration0)
        t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
        return (t1 - t2 + t3 + t4 + t5).sum(-1)

    def entropy(self):
        return self.dist.entropy().sum(-1)
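Since `kl` transcribes the closed form from PyTorch's KL registry, it can be sanity-checked against `torch.distributions.kl_divergence` directly. A small standalone sketch of that check, without the RLlib wrapper:

import torch
from torch.distributions import Beta, kl_divergence

p = Beta(torch.tensor([2.0, 0.5, 4.0]), torch.tensor([3.0, 0.7, 1.0]))
q = Beta(torch.tensor([1.0, 1.5, 2.0]), torch.tensor([1.0, 2.0, 2.0]))
sum_p = p.concentration1 + p.concentration0
sum_q = q.concentration1 + q.concentration0
t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + sum_p.lgamma()
t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + sum_q.lgamma()
t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
t5 = (sum_q - sum_p) * torch.digamma(sum_p)
assert torch.allclose(t1 - t2 + t3 + t4 + t5, kl_divergence(p, q))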
Example #30
    def traverse_line(self,
                      cont_dim=0,
                      disc_dim=None,
                      size=10,
                      use_prior=True,
                      set_zero=True,
                      traverse=False,
                      file_name=None):
        if set_zero:
            cont_samples = torch.zeros(size, self.args.cont).cuda()
        else:
            cont_samples = torch.randn(size, self.args.cont).cuda()
        fixed_value = torch.linspace(-2, 2, size).cuda()
        cont_samples[:, cont_dim] = fixed_value
        if use_prior:
            v_prior = Beta(
                torch.ones_like(cont_samples) * self.args.alpha_0,
                torch.ones_like(cont_samples))
            mask_prob = torch.cumprod(v_prior.sample(), dim=1)
            mask = Bernoulli(mask_prob).sample()
            cont_samples = cont_samples * mask
        if self.args.model_type != 'ibp':
            disc_samples = torch.zeros(size, self.args.disc).cuda()
            if traverse:
                for i in range(size):
                    disc_samples[i, i % self.args.disc] = 1.0
            else:
                disc_samples[:, disc_dim] = 1.0
            samples = torch.cat([cont_samples, disc_samples], dim=-1)
        else:
            samples = cont_samples
        with torch.no_grad():
            x = self.model.decoder(samples).view(-1, self.args.img_channel,
                                                 self.args.img_size,
                                                 self.args.img_size)
        if self.save_img:
            save_image(x.data, file_name, nrow=size, padding=0, pad_value=0.0)
        else:
            return make_grid(x.data, nrow=size, padding=0, pad_value=0.0)