Exemple #1
0
 def forward(self, x: torch.Tensor) -> dist.Normal:  # type: ignore
     """Build the Gaussian ``N(x, self.scale^2)`` centred on ``x``.

     ``x`` is first passed to ``_validate_input``; the returned distribution
     uses ``x`` as its location and ``self.scale`` as its scale.
     """
     _validate_input(x)
     gaussian = dist.Normal(x, self.scale)
     return gaussian
Exemple #2
0
 def prior(self):
     """Return the prior: a Normal with mean ``prior_mu`` and scale ``exp(prior_logstd)``."""
     prior_std = torch.exp(self.prior_logstd)
     return distributions.Normal(self.prior_mu, prior_std)
    def __init__(
        self,
        n_in,
        n_latent,
        n_cat_list,
        n_hidden,
        n_layers,
        t,
        dropout_rate=0.05,
        use_batch_norm=True,
        do_h=True,
        device=None,
    ):
        """
        Encoder using h representation as described in IAF paper.

        Builds a ``ModuleList`` of ``t`` ``EncoderH`` modules (one initial
        encoder plus ``t - 1`` follow-up encoders) and a standard Normal
        ``dist0`` over ``n_latent`` dimensions.

        :param n_in: ``n_in`` of the first ``EncoderH``.
        :param n_latent: ``n_out`` of every ``EncoderH``; also the size of ``dist0``.
        :param n_cat_list: ``n_cat_list`` of the first ``EncoderH`` (the
            follow-up encoders receive ``None``).
        :param n_hidden: ``n_hidden`` forwarded to every ``EncoderH``.
        :param n_layers: ``n_layers`` forwarded to every ``EncoderH``.
        :param t: total number of encoders to stack.
        :param dropout_rate: forwarded to every ``EncoderH``.
        :param use_batch_norm: forwarded to every ``EncoderH``.
        :param do_h: forwarded to the first ``EncoderH``; when true the
            follow-up encoders take ``2 * n_latent`` inputs instead of ``n_latent``.
        :param device: device on which the ``dist0`` parameters are created.
            ``None`` (default) selects ``"cuda"`` when available, else ``"cpu"``.
        """
        super().__init__()
        self.do_h = do_h
        msg = "" if do_h else "Not "
        logger.info(msg="{}Using Hidden State".format(msg))
        self.n_latent = n_latent
        self.encoders = torch.nn.ModuleList()
        self.encoders.append(
            EncoderH(
                n_in=n_in,
                n_out=n_latent,
                n_cat_list=n_cat_list,
                n_layers=n_layers,
                n_hidden=n_hidden,
                do_h=do_h,
                dropout_rate=dropout_rate,
                use_batch_norm=use_batch_norm,
                do_sigmoid=False,
            )
        )

        # Follow-up encoders: input width doubles when the hidden state is used.
        n_in = 2 * n_latent if do_h else n_latent
        for _ in range(t - 1):
            self.encoders.append(
                EncoderH(
                    n_in=n_in,
                    n_out=n_latent,
                    n_cat_list=None,
                    n_layers=n_layers,
                    n_hidden=n_hidden,
                    do_h=False,
                    dropout_rate=dropout_rate,
                    use_batch_norm=use_batch_norm,
                    do_sigmoid=True,
                )
            )

        # Previously hard-coded to "cuda", which made construction fail on
        # CPU-only hosts; CUDA is still preferred whenever it is available.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dist0 = db.Normal(
            loc=torch.zeros(n_latent, device=device),
            scale=torch.ones(n_latent, device=device),
        )
Exemple #4
0
 def base_dist(self, major_label):
     """Return the Normal base distribution selected by ``major_label``.

     ``major_label`` indexes the first axis of ``base_dist_mean`` (location)
     and of ``base_dist_var`` (passed as the scale).
     """
     # NOTE(review): ``base_dist_var`` is handed to ``D.Normal`` as a scale
     # (standard deviation), not a variance -- confirm this is intended.
     loc = self.base_dist_mean[major_label, :]
     scale = self.base_dist_var[major_label, :]
     return D.Normal(loc, scale)
Exemple #5
0
    data[:, 0:2] = data[:, 0:2]/20

    return data[:, 0:2], data[:, 2].astype(np.int)


# %%
# Generate the pinwheel points and their labels, then wrap both as tensors.
X, c = make_pinwheel_data(0.3, 0.05, 1, 512, 0.25)
X, c = torch.Tensor(X), torch.Tensor(c)

plt.figure(figsize=(10, 10))
plt.scatter(X[:, 0].numpy(), X[:, 1].numpy(), c=c.numpy(), s=5)
plt.show()

# %%
base_dist = distrib.Normal(loc=torch.zeros(2), scale=torch.ones(2))

# %%
X0 = base_dist.sample((1000,)).numpy()

# %%
# Label every base sample with the index of the quadrant it falls into.
colors = np.zeros(len(X0))

idx_0 = (X0[:, 0] < 0) & (X0[:, 1] < 0)
idx_1 = (X0[:, 0] >= 0) & (X0[:, 1] < 0)
idx_2 = (X0[:, 0] >= 0) & (X0[:, 1] >= 0)
idx_3 = (X0[:, 0] < 0) & (X0[:, 1] >= 0)
for quadrant, mask in enumerate((idx_0, idx_1, idx_2, idx_3)):
    colors[mask] = quadrant
Exemple #6
0
    def __init__(self, args):
        """Build the priors, decoder and encoders of the semi-supervised model (M2).

        Reads from ``args``: ``image_dims`` (unpacked as ``C, H, W``), ``y_dim``,
        ``z_dim``, ``hidden_dim`` and ``device``.

        Creates:
          * ``p_y``: uniform ``OneHotCategorical`` prior over ``y_dim`` classes.
          * ``p_z``: Normal prior with scalar loc 0 and scale 1.
          * ``decoder``: MLP from ``z_dim + y_dim`` inputs to ``C * H * W`` outputs.
          * ``decoder_tcnn``: (transposed-)convolution stack mapping ``C`` channels
            back to ``C`` channels.
          * ``encoder_y_cnn`` / ``encoder_y``: CNN features followed by an MLP
            with ``y_dim`` outputs.
          * ``encoder_z_cnn`` / ``encoder_z``: CNN features (concatenated with
            ``y_dim`` extra inputs) followed by an MLP with ``2 * z_dim`` outputs.

        Finally initialises Linear/Conv weights to N(0, 0.001) and their biases
        to 0, leaving BatchNorm layers at their defaults.
        """
        super().__init__()
        C, H, W = args.image_dims
        x_dim = C * H * W

        # Every convolution in the encoder CNNs keeps the spatial size and each
        # of their two 2x2/stride-2 max-pools halves it (floor), so they emit
        # 20 channels of (H // 4) x (W // 4); this equals 20*7*7 for 28x28 input.
        cnn_out_dim = 20 * (H // 4) * (W // 4)

        # --------------------
        # p model -- SSL paper generative semi supervised model M2
        # --------------------

        self.p_y = D.OneHotCategorical(
            probs=1 / args.y_dim * torch.ones(1, args.y_dim, device=args.device))
        self.p_z = D.Normal(torch.tensor(0., device=args.device),
                            torch.tensor(1., device=args.device))

        # parametrized data likelihood p(x|y,z)
        self.decoder = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            nn.Linear(args.z_dim + args.y_dim, args.hidden_dim),
            nn.BatchNorm1d(args.hidden_dim),
            nn.Softplus(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.BatchNorm1d(args.hidden_dim),
            nn.Softplus(),
            # nn.Dropout(0.5),  (disabled)
            nn.Linear(args.hidden_dim, x_dim),
            nn.Softplus(),
        )

        # Transposed Conv test
        # Before: 1 -> 10 -> 20 -> 1
        # Batch normalization is placed before the activation function, as
        # suggested in Ioffe and Szegedy 2015.
        self.decoder_tcnn = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            # TODO: possibly add more channels
            nn.ConvTranspose2d(in_channels=args.image_dims[0], out_channels=10,
                               kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(10),
            nn.Softplus(),
            nn.Conv2d(in_channels=10, out_channels=20,
                      kernel_size=5, stride=1, padding=2),
            nn.Softplus(),
            nn.ConvTranspose2d(in_channels=20, out_channels=20,
                               kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(20),
            nn.Softplus(),
            # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  (disabled)
            # nn.Dropout(0.4),  (disabled)
            nn.ConvTranspose2d(in_channels=20, out_channels=args.image_dims[0],
                               kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(args.image_dims[0]) / nn.Softplus()  (disabled)
        )

        # --------------------
        # q model -- SSL paper eq 4
        # --------------------

        # parametrized q(y|x) = Cat(y|pi_phi(x)) -- outputs parametrization of
        # categorical distribution
        # before: 1 -> 10 -> 20
        self.encoder_y_cnn = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            nn.Conv2d(in_channels=args.image_dims[0], out_channels=10,
                      kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(10),
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # nn.Dropout(0.4),  (disabled)
            nn.Conv2d(in_channels=10, out_channels=20,
                      kernel_size=5, stride=1, padding=2),
            # nn.BatchNorm2d(20),  (disabled)
            nn.Softplus(),
            nn.Conv2d(in_channels=20, out_channels=20,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(20),
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

        self.encoder_y = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            nn.Linear(cnn_out_dim, args.hidden_dim),  # CNN features instead of x_dim
            # nn.BatchNorm1d(args.hidden_dim),  (disabled)
            nn.Softplus(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            # nn.BatchNorm1d(args.hidden_dim),  (disabled)
            nn.Softplus(),
            # nn.Dropout(0.5),  (disabled)
            nn.Linear(args.hidden_dim, args.y_dim),
        )

        # parametrized q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) --
        # output parametrizations for mean and diagonal variance of a Normal
        # distribution
        # before: 1 -> 10 -> 20
        self.encoder_z_cnn = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            nn.Conv2d(in_channels=args.image_dims[0], out_channels=10,
                      kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(10),
            nn.Softplus(),
            nn.Conv2d(in_channels=10, out_channels=20,
                      kernel_size=5, stride=1, padding=2),
            # nn.BatchNorm2d(20),  (disabled)
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # nn.Dropout(0.4),  (disabled)
            nn.Conv2d(in_channels=20, out_channels=20,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(20),
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

        self.encoder_z = nn.Sequential(
            # nn.Dropout(0.5),  (disabled)
            # CNN features + y instead of x_dim + args.y_dim
            nn.Linear(cnn_out_dim + args.y_dim, args.hidden_dim),
            # nn.BatchNorm1d(args.hidden_dim),  (disabled)
            nn.Softplus(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            # nn.BatchNorm1d(args.hidden_dim),  (disabled)
            nn.Softplus(),
            # nn.Dropout(0.5),  (disabled)
            nn.Linear(args.hidden_dim, 2 * args.z_dim),
        )

        # initialize weights to N(0, 0.001) and biases to 0 (cf SSL section 4.4).
        # Only Linear/Conv layers are touched: the former "zero every 1-D
        # parameter" rule also zeroed the BatchNorm scale (gamma), so every
        # BatchNorm layer started out emitting a constant 0.
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                module.weight.data.normal_(0, 0.001)
                if module.bias is not None:
                    module.bias.data.fill_(0.)
 def __init__(self, loc, scale):
     """Create a tanh-transformed Normal.

     Stores ``loc`` and ``scale``, builds the underlying ``pyd.Normal`` as
     ``base_dist`` and passes it to the parent constructor together with a
     single ``TanhTransform``.
     """
     self.loc, self.scale = loc, scale
     self.base_dist = pyd.Normal(loc, scale)
     super().__init__(self.base_dist, [TanhTransform()])
Exemple #8
0
def _save_message_surface(X, Y, Z, save_dir, tag, azim):
    """Plot ``Z`` over the ``X``/``Y`` grid as a 3D surface and save three views.

    Files written (``save_dir`` is used as a plain string prefix):
    ``<tag>.png`` (default view), ``<tag>_0_<azim>.png`` (elevation 0, azimuth
    ``azim``) and ``<tag>_90_0.png`` (elevation 90, azimuth 0).
    """
    fig = plt.figure()
    # ``Figure.gca(projection=...)`` was deprecated in Matplotlib 3.4 and later
    # removed; ``add_subplot`` is the supported way to request 3D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + tag + '.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("%s_0_%i.png" % (tag, int(azim))))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + tag + "_90_0.png")
    # plt.show()
    plt.close(fig)


def draw_message_distributions_tracker1_2d(mu, sigma, save_dir, azim):
    """Print statistics of 2-D message Gaussians and save 3D density plots.

    ``mu`` and ``sigma`` are reshaped to ``(-1, 2)``; rows 0-5 are used, so at
    least six rows are required. The function

    * prints the entropy of the Normals built from rows 0, 2, 3, 4 and 5,
    * prints, for each consecutive row pair (0,1), (2,3), (4,5), the summed KL
      divergence between the pair and from each member to a standard Normal,
      plus the mean ``sigma`` of each member,
    * prints ``mu``/``sigma`` of rows 0, 2, 3, 4 and 5,
    * saves surface plots for "agent1" (row 0), "agent2" (rows 2, 3),
      "agent3" (rows 4, 5) and "overall" (all of them), each summed with the
      standard-Normal density, on a 5000x5000 grid over [-5, 5]^2.

    :param mu: tensor of means, viewable as ``(-1, 2)``.
    :param sigma: tensor of scales, viewable as ``(-1, 2)``.
    :param save_dir: string prefix prepended to every output file name.
    :param azim: azimuth (degrees) used for the elevation-0 view.
    """
    # sigma *= torch.where(sigma<0.01, torch.ones(sigma.shape).cuda(), 10*torch.ones(sigma.shape).cuda())
    mu = mu.view(-1, 2)
    sigma = sigma.view(-1, 2)

    # Reference standard Normal, created on the inputs' device so the KL
    # computations below also work when ``mu``/``sigma`` live on a GPU.
    s_mu = torch.tensor([0.0, 0.0], device=mu.device)
    s_sigma = torch.tensor([1.0, 1.0], device=mu.device)
    d = D.Normal(s_mu, s_sigma)

    plotted_rows = (0, 2, 3, 4, 5)
    row_dists = [D.Normal(mu[i], sigma[i]) for i in plotted_rows]

    print('Entropy')
    for row_dist in row_dists:
        print(row_dist.entropy().detach().cpu().numpy())

    print('KL Divergence')

    for tt_i in range(3):
        first, second = tt_i * 2, tt_i * 2 + 1
        d1 = D.Normal(mu[first], sigma[first])
        d2 = D.Normal(mu[second], sigma[second])
        print(tt_i,
              D.kl_divergence(d1, d2).sum().detach().cpu().numpy(),
              D.kl_divergence(d1, d).sum().detach().cpu().numpy(),
              D.kl_divergence(d2, d).sum().detach().cpu().numpy(),
              sigma[first].mean().detach().cpu().numpy(),
              sigma[second].mean().detach().cpu().numpy())

    # Numpy array of mu and sigma
    s_mu_ = s_mu.detach().cpu().numpy()
    s_sigma_ = s_sigma.detach().cpu().numpy()
    mu_np = {i: mu[i].detach().cpu().numpy() for i in plotted_rows}
    sigma_np = {i: sigma[i].detach().cpu().numpy() for i in plotted_rows}

    # Print
    print('mu and sigma')
    for i in plotted_rows:
        print(mu_np[i], sigma_np[i])

    # Create grid
    x = np.linspace(-5, 5, 5000)
    y = np.linspace(-5, 5, 5000)
    X, Y = np.meshgrid(x, y)
    pos = np.empty(X.shape + (2, ))
    pos[:, :, 0] = X
    pos[:, :, 1] = Y

    # NOTE(review): sigma is placed directly on the covariance diagonal, while
    # ``D.Normal`` above treats the same values as standard deviations --
    # confirm whether sigma**2 was intended for the plots.
    rv = multivariate_normal(s_mu_, np.diag(s_sigma_))
    rvs = {i: multivariate_normal(mu_np[i], np.diag(sigma_np[i]))
           for i in plotted_rows}

    # Agent 1
    _save_message_surface(X, Y, rv.pdf(pos) + rvs[0].pdf(pos),
                          save_dir, 'agent1', azim)

    # Agent 2
    _save_message_surface(X, Y,
                          rv.pdf(pos) + rvs[2].pdf(pos) + rvs[3].pdf(pos),
                          save_dir, 'agent2', azim)

    # Agent 3
    _save_message_surface(X, Y,
                          rv.pdf(pos) + rvs[4].pdf(pos) + rvs[5].pdf(pos),
                          save_dir, 'agent3', azim)

    # Overall
    _save_message_surface(X, Y,
                          rv.pdf(pos) + rvs[0].pdf(pos) + rvs[2].pdf(pos) +
                          rvs[3].pdf(pos) + rvs[4].pdf(pos) + rvs[5].pdf(pos),
                          save_dir, 'overall', azim)
 def base_dist(self):
     """Return the Normal base distribution built from ``base_dist_mean`` and ``base_dist_var``."""
     # NOTE(review): ``base_dist_var`` is passed as the scale (standard
     # deviation) of the Normal -- confirm this is intended.
     loc, scale = self.base_dist_mean, self.base_dist_var
     return D.Normal(loc, scale)
Exemple #10
0
    def __init__(
            self,
            x_dim,
            y_dim,
            z_dim,
            w_dim,
            predict_x_var=False,
            labels_mutually_exclusive=True,
            use_bernoulli_y=False,
            z_prior_var=1,
            likelihood_partition=None,
            likelihood_params=None,
            pw_y_means=torch.tensor([[-1.], [1.]]),
            pw_y_vars=torch.tensor([[0.5**2], [0.5**2]]),
            show_val_loss=False,
            beta1=20,  # reconstruction_error,
            beta2=1,  # how strongly we want the identification network for q(w|x,y) to mimic the prior p(w|y)
            beta3=0.2,  # how strongly we want the identification network q(z|x) to mimic the prior p(z)
            beta4=10,  # beta4 and beta5 control how strongly we want to enforce the mutual information minimization between
            beta5=1,  # w and z. If too small, z will encode information about w. should be tweaked.
            optimizer1=None,
            optimizer2=None,
            device='cpu'):
        """Set up priors, encoder/decoder networks and the two optimizers.

        :param x_dim: size of the observation ``x``.
        :param y_dim: size of the label ``y``.
        :param z_dim: size of the latent ``z``.
        :param w_dim: size of the latent ``w``.
        :param predict_x_var: stored as ``self.predict_x_var``.
        :param labels_mutually_exclusive: when true ``p_y`` is a Categorical
            (or a single Bernoulli if ``use_bernoulli_y``); otherwise ``p_y``
            is a Bernoulli with ``y_dim`` independent probabilities of 0.5.
        :param use_bernoulli_y: use a single Bernoulli(0.5) prior on ``y``;
            only valid with ``y_dim == 2``.
        :param z_prior_var: variance of the prior on ``z`` (its square root is
            used as the Normal scale).
        :param likelihood_partition: stored as ``self.likelihood_partition``;
            ``None`` selects ``{(0, x_dim - 1): 'real'}``.
        :param likelihood_params: stored as ``self.likelihood_params``;
            ``None`` selects ``{'lik_var': 0.1**2, 'lik_var_lognormal': 0.1**2}``.
        :param pw_y_means: stored as ``self.pw_y_means``.
        :param pw_y_vars: stored as ``self.pw_y_vars``.
        :param show_val_loss: stored as ``self.show_val_loss``.
        :param beta1: stored loss weight (see inline notes on the signature).
        :param beta2: stored loss weight.
        :param beta3: stored loss weight.
        :param beta4: stored loss weight.
        :param beta5: stored loss weight.
        :param optimizer1: optimizer for ``decoder_x``, ``encoder_w`` and
            ``encoder_z``; ``None`` creates ``Adam(lr=1e-3)``.
        :param optimizer2: optimizer for ``decoder_y``; ``None`` creates
            ``Adam(lr=1e-3)``.
        :param device: device for the parameters of the prior on ``z``.
        :raises ValueError: if ``use_bernoulli_y`` is requested (with mutually
            exclusive labels) and ``y_dim != 2``.
        """

        super().__init__()

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.w_dim = w_dim

        print(f"x_dim: {x_dim}")
        print(f"y_dim: {y_dim}")
        print(f"z_dim: {z_dim}")
        print(f"w_dim: {w_dim}")

        print(f"beta1: {beta1}")
        print(f"beta2: {beta2}")
        print(f"beta3: {beta3}")
        print(f"beta4: {beta4}")
        print(f"beta5: {beta5}")

        if likelihood_partition is None:
            # range for feature types
            self.likelihood_partition = {(0, x_dim - 1): 'real'}
        else:
            self.likelihood_partition = likelihood_partition

        # A fresh dict per instance: a dict literal in the signature would be
        # shared (and mutable) across every model using the default.
        if likelihood_params is None:
            likelihood_params = {
                'lik_var': 0.1**2,
                'lik_var_lognormal': 0.1**2
            }
        self.likelihood_params = likelihood_params

        self.z_prior_var = z_prior_var

        self.device = device

        self.predict_x_var = predict_x_var
        self.labels_mutually_exclusive = labels_mutually_exclusive

        # prior on z

        p_z_loc = torch.zeros(z_dim).to(device)
        cov_diag = z_prior_var * torch.ones(z_dim).to(device)
        self.p_z = D.Normal(loc=p_z_loc, scale=cov_diag.sqrt())

        # prior on y
        if labels_mutually_exclusive:
            if not use_bernoulli_y:
                self.p_y = D.Categorical(probs=1 / y_dim *
                                         torch.ones(1, y_dim))
            else:
                if y_dim != 2:
                    raise ValueError("using bernoulli y with a y_dim != 2")
                self.p_y = D.Bernoulli(probs=0.5)

        else:
            self.p_y = D.Bernoulli(probs=0.5 * torch.ones(y_dim))

        self.decoder_x = nn.Sequential(nn.Linear(z_dim + w_dim, 64), nn.ReLU(),
                                       nn.Linear(64, 64), nn.ReLU(),
                                       nn.Linear(64, x_dim))

        # With y_dim == 2 the label contributes a single input feature.
        self.encoder_w = nn.Sequential(
            nn.Linear(x_dim + ((y_dim - 1) if y_dim == 2 else y_dim), 64),
            nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2 * w_dim))

        self.encoder_z = nn.Sequential(nn.Linear(x_dim, 64), nn.ReLU(),
                                       nn.Linear(64, 64), nn.ReLU(),
                                       nn.Linear(64, 2 * z_dim))

        # Materialised as lists: a bare ``chain`` iterator is exhausted by the
        # optimizer constructor, leaving the stored attribute empty afterwards.
        self.optim_1_params = list(chain(self.decoder_x.parameters(),
                                         self.encoder_w.parameters(),
                                         self.encoder_z.parameters()))

        self.decoder_y = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(),
                                       nn.Linear(64, 64), nn.ReLU(),
                                       nn.Linear(64, y_dim))

        self.optim_2_params = list(chain(self.decoder_y.parameters()))

        if optimizer1 is None:
            self.optimizer1 = torch.optim.Adam(self.optim_1_params, lr=1e-3)
        else:
            self.optimizer1 = optimizer1

        if optimizer2 is None:
            self.optimizer2 = torch.optim.Adam(self.optim_2_params, lr=1e-3)
        else:
            self.optimizer2 = optimizer2

        self.pw_y_means = pw_y_means
        self.pw_y_vars = pw_y_vars
        self.show_val_loss = show_val_loss

        self.beta1 = beta1
        self.beta2 = beta2
        self.beta3 = beta3
        self.beta4 = beta4
        self.beta5 = beta5

        self.pw_y0 = self.get_pw_y(torch.tensor(0))
        self.pw_y1 = self.get_pw_y(torch.tensor(1))
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    """Build a DataLoader of points sampled from per-batch density estimates.

    For every batch of ``dm.train_dataloader()`` a spherical
    ``BayesianGaussianMixture`` is fitted, components with weight below 1e-5
    are dropped, samples are drawn from the resulting mixture and refined with
    ``density_gradient_descent``. All refined samples are concatenated and
    wrapped in a shuffled ``DataLoader``.

    :param dm: data module whose training batches unpack as ``(x, _)``.
    :param batch_size: number of mixture components per fit and batch size of
        the returned loader.
    :param N_hat_multiplier: scales the number of samples drawn per batch; the
        resulting count is truncated to an integer.
    :return: shuffled ``DataLoader`` over a ``TensorDataset`` of the samples.
    """
    # %
    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005

    # Spherical = each component has single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )

    x_hat = torch.Tensor()
    for x, _ in dm.train_dataloader():
        device = x.device
        x = x.detach().cpu().numpy()
        # Last batch might have less elements than origin n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )
        # Estimate KDE
        bgm.fit(x)
        # [N_components], [N_components, N_features], [N_components]
        weights, means, variances = (
            torch.Tensor(bgm.weights_).to(device),
            torch.Tensor(bgm.means_).to(device),
            torch.Tensor(bgm.covariances_).to(device),
        )
        # Drop negligible components; the variance gets a trailing axis so it
        # broadcasts against the [N_components, N_features] means.
        filter_weights_idx = weights >= 1e-5
        weights, means, variances = (
            weights[filter_weights_idx],
            means[filter_weights_idx],
            variances[filter_weights_idx][:, None],
        )
        n_selected_components = weights.shape[0]
        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        mix = D.Categorical(weights)
        p_x = D.MixtureSameFamily(mix, p_x)
        # Sample according to multiplier. The count must be an int: a float
        # ``N_hat_multiplier`` would otherwise produce an invalid sample shape.
        n_samples = int(
            n_selected_components
            * ((batch_size // n_selected_components) + 1)
            * N_hat_multiplier
        )
        x_start = p_x.sample((n_samples,)).reshape(-1, x.shape[1])
        # Use GD
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )
        # Ensure same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)
        x_hat = torch.cat((x_hat, _x_hat.detach()))

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
    def forward(self, query, key, value):
        """Multi-head attention with a selectable energy function.

        ``self.args.att_type`` selects how the attention energy is computed:
        ``'dot'`` (scaled dot product), ``'ikandirect'`` (``ika_ns`` with the
        shared ``w1``/``w2`` projections) or ``'mikan'`` (``ika_ns`` with
        projections generated from copula-based samples, which also yields a
        regularisation term).

        :param query: tensor whose first two axes are (batch, query len).
        :param key: tensor whose first two axes are (batch, key len).
        :param value: tensor whose first two axes are (batch, value len).
        :return: ``(x, attention, KLD)`` -- the projected output, the softmaxed
            attention weights and the regularisation term (``0.0`` unless
            ``att_type == 'mikan'``).
        :raises ValueError: if ``self.args.att_type`` is not one of the
            supported values.
        """

        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Q = [batch size, query len, hid dim]
        # K = [batch size, key len, hid dim]
        # V = [batch size, value len, hid dim]

        # The sequence axis is inferred (-1) per tensor: reusing the query
        # length for K and V failed whenever key/value length differed from it.
        Q = Q.view(batch_size, -1, self.n_heads,
                   self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads,
                   self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads,
                   self.head_dim).permute(0, 2, 1, 3)
        # Q = [batch size, n heads, query len, head dim]
        # K = [batch size, n heads, key len, head dim]
        # V = [batch size, n heads, value len, head dim]

        KLD = torch.tensor(0.0)
        if self.args.att_type == 'dot':
            energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        elif self.args.att_type == 'ikandirect':
            w1_proj = self.attsharedw.w1
            w2_proj = self.attsharedw.w2
            scores, norm = ika_ns(Q, K, self.args, self.scale, w1_proj,
                                  w2_proj, 2 * np.pi, self.training)
            energy = torch.log(scores + (1e-5)) + norm
        elif self.args.att_type == 'mikan':
            # copula augmented estimation
            mu, logvar, L = self.attsharedw.copulanet(Q, K)
            mu = mu.squeeze(-1)
            logvar = logvar.squeeze(-1)
            var = torch.exp(logvar)

            # ``num_head`` is taken from the last axis of L (the original
            # double assignment kept that value as well).
            dim_batch_size, _, num_head = L.size()
            dim = int(dim_batch_size / batch_size)

            # Noise is moved to L's device instead of an unconditional
            # ``.cuda()`` so the branch also runs on CPU.
            # NOTE(review): 'ijk,ijl->ijl' sums L over k and scales eps
            # element-wise; confirm a matmul ('ijk,ikl->ijl') was not intended.
            pos_eps = torch.randn([dim, num_head,
                                   self.args.M // 2]).to(L.device)
            X_pos = torch.einsum('ijk,ijl->ijl', L, pos_eps)
            X_pos = torch.clamp(X_pos, min=-2.0, max=2.0)
            U_pos = self.standard_normal_dist.cdf(X_pos)

            neg_eps = torch.randn([dim, num_head,
                                   self.args.M // 2]).to(L.device)
            X_neg = torch.einsum('ijk,ijl->ijl', L, neg_eps)
            X_neg = torch.clamp(X_neg, min=-2.0, max=2.0)
            U_neg = self.standard_normal_dist.cdf(X_neg)

            # NOTE(review): ``var`` (= exp(logvar)) is passed as the Normal
            # *scale* here and for ``q_dist`` below -- confirm whether the
            # standard deviation exp(0.5 * logvar) was intended.
            marginal_pos = Normal(mu.unsqueeze(-1), var.unsqueeze(-1))
            marginal_neg = Normal(-1 * mu.unsqueeze(-1), var.unsqueeze(-1))
            Y_pos = marginal_pos.icdf(U_pos)
            Y_neg = marginal_neg.icdf(U_neg)
            U = torch.cat([U_pos, U_neg])
            ent_copula = -1 * torch.sum(torch.mul(U, torch.log(U + (1e-5))))
            # kernel and norm calculation
            z = torch.cat([Y_pos, Y_neg], -1)
            w1_proj = self.attsharedw.wnet1(z)
            w2_proj = self.attsharedw.wnet2(z)
            scores, norm = ika_ns(Q, K, self.args, self.scale, w1_proj,
                                  w2_proj, 2 * np.pi, self.training)
            energy = torch.log(scores + (1e-5)) + norm
            # energy = [batch size, n heads, query len, key len]

            q_dist = tdist.Normal(mu, logvar.exp())
            KLD = torch.distributions.kl_divergence(q_dist, self.p_dist)
            KLD = self.args.kl_lambda * torch.sum(
                KLD) + self.args.copula_lambda * ent_copula
        else:
            # Previously fell through to a NameError on ``energy``.
            raise ValueError(
                "Unsupported att_type: {!r}".format(self.args.att_type))

        attention = torch.softmax(energy, dim=-1)
        # attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)
        # x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()
        # x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.args.KEY_DIM)
        # x = [batch size, query len, hid dim]

        x = self.fc_o(x)
        # x = [batch size, query len, hid dim]

        return x, attention, KLD
 def likelihood(self, z):
     """Decode ``z`` into a Normal whose loc and scale are the decoder's two outputs."""
     decoded = self.decoder(z)
     loc, scale = decoded
     return dist.Normal(loc, scale)
Exemple #14
0
def _train(replay_buffer, net, teacher_net, criterion, coord_converter, logger,
           config, episode):
    """Train the student ``net`` against a frozen teacher for one episode.

    For each of ``config['epoch_per_episode']`` epochs the replay buffer is
    iterated in shuffled mini-batches: the student is fitted to the teacher's
    predicted locations, per-sample resampling weights are written back to the
    buffer, and visuals are logged every ``config['log_iterations']`` steps.
    After each epoch the 32 highest-weighted samples are re-evaluated and
    logged. A checkpoint is saved when ``episode`` is in ``SAVE_EPISODES``.

    Args:
        replay_buffer: Dataset-like buffer providing ``init_new_weights``,
            ``update_weights``, ``normalize_weights`` and ``get_highest_k``.
        net: Student network, called as ``net(rgb_image, speed, command)``.
        teacher_net: Teacher network, called as
            ``teacher_net(birdview, speed, command)``; inference only.
        criterion: Loss callable applied to the converted student locations
            and the teacher locations.
        coord_converter: Callable applied to the student outputs before they
            are compared with the teacher's; its ``_img_size`` attribute is
            read when building the logged visuals.
        logger: Object exposing ``scalar``, ``image`` and ``end_epoch``.
        config: Mapping with the keys ``speed_noise``, ``epoch_per_episode``,
            ``batch_size``, ``device``, ``log_iterations`` and ``log_dir``.
        episode: Index of the current episode.
    """
    import torch
    from phase2_utils import _log_visuals, get_weight, repeat

    import torch.distributions as tdist

    device = config['device']

    # Only build the noise distribution when noise is requested: Normal
    # validates that its scale is strictly positive, so constructing it with
    # speed_noise == 0 would raise before the guard below is ever reached.
    noiser = None
    if config['speed_noise'] > 0:
        noiser = tdist.Normal(torch.tensor(0.0),
                              torch.tensor(config['speed_noise']))

    teacher_net.eval()

    for epoch in range(config['epoch_per_episode']):

        # NOTE(review): a fresh Adam optimizer is created every epoch, which
        # resets its moment estimates each time - confirm this is intended.
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

        net.train()
        replay_buffer.init_new_weights()
        loader = torch.utils.data.DataLoader(replay_buffer,
                                             batch_size=config['batch_size'],
                                             num_workers=4,
                                             shuffle=True,
                                             drop_last=True)

        for i, (idxes, rgb_image, command, speed, target,
                birdview) in enumerate(loader):
            if i % 100 == 0:
                print("ITER: %d" % i)
            rgb_image = rgb_image.to(device).float()
            birdview = birdview.to(device).float()
            command = one_hot(command).to(device).float()
            speed = speed.to(device).float()

            if noiser is not None:
                speed += noiser.sample(speed.size()).to(speed.device)
                speed = torch.clamp(speed, 0, 10)

            if len(rgb_image.size()) > 4:
                # A fifth dimension holds several augmentations per sample:
                # fold it into the batch and repeat the other inputs to match.
                B, batch_aug, c, h, w = rgb_image.size()
                rgb_image = rgb_image.view(B * batch_aug, c, h, w)
                birdview = repeat(birdview, batch_aug)
                command = repeat(command, batch_aug)
                speed = repeat(speed, batch_aug)
            else:
                B = rgb_image.size(0)
                batch_aug = 1

            with torch.no_grad():
                _teac_location, _teac_locations = teacher_net(
                    birdview, speed, command)

            _pred_location, _pred_locations = net(rgb_image, speed, command)
            pred_location = coord_converter(_pred_location)
            pred_locations = coord_converter(_pred_locations)

            optimizer.zero_grad()
            loss = criterion(pred_locations, _teac_locations)

            # Compute resample weights
            pred_location_normed = pred_location / (0.5 * CROP_SIZE) - 1.
            weights = get_weight(pred_location_normed, _teac_location)
            # Split into B chunks and average across them so that one weight
            # is reported per position of ``idxes``.
            weights = torch.mean(torch.stack(torch.chunk(weights, B)), dim=0)

            replay_buffer.update_weights(idxes, weights)

            loss_mean = loss.mean()

            loss_mean.backward()
            optimizer.step()

            if i % config['log_iterations'] == 0:
                images = _log_visuals(
                    rgb_image, birdview, speed, command, loss, pred_location,
                    (_pred_location + 1) * coord_converter._img_size / 2,
                    _teac_location)

                logger.scalar(loss_mean=loss_mean.item())
                logger.image(birdview=images)

        replay_buffer.normalize_weights()

        rgb_image, birdview, command, speed, target = replay_buffer.get_highest_k(
            32)
        rgb_image = rgb_image.to(device).float()
        birdview = birdview.to(device).float()
        command = one_hot(command).to(device).float()
        speed = speed.to(device).float()

        net.eval()
        # The results below are only used for logging, so no autograd graph
        # is needed for either network.
        with torch.no_grad():
            _teac_location, _teac_locations = teacher_net(
                birdview, speed, command)
            _pred_location, _ = net(rgb_image, speed, command)
            pred_location = coord_converter(_pred_location)
            pred_location_normed = pred_location / (0.5 * CROP_SIZE) - 1.
            weights = get_weight(pred_location_normed, _teac_location)

        # TODO: Plot highest
        images = _log_visuals(rgb_image, birdview, speed, command, weights,
                              pred_location, (_pred_location + 1) *
                              coord_converter._img_size / 2, _teac_location)

        logger.image(topk=images)

        logger.end_epoch()

    if episode in SAVE_EPISODES:
        torch.save(net.state_dict(),
                   str(Path(config['log_dir']) / ('model-%d.th' % episode)))
Exemple #15
0
 def forward(self, x):
     """Return a Normal located at ``x`` whose scale is ``exp(self.log_scale)``."""
     scale = self.log_scale.exp()
     return td.Normal(loc=x, scale=scale)
 def __call__(self, wav):
     """Randomly add Gaussian noise to a waveform and clamp it to ``[-1, 1]``.

     With probability 0.5 zero-mean Gaussian noise is added to ``wav``; the
     result is always clamped to ``[-1, 1]``. The input tensor itself is
     never modified.

     Args:
         wav: Waveform tensor.

     Returns:
         A new tensor with the shape of ``wav`` and values in ``[-1, 1]``.

     Note:
         Despite its name, ``self.var`` is passed to ``Normal`` as the scale
         (standard deviation), not as a variance.
     """
     noiser = distributions.Normal(0, self.var)
     if np.random.uniform() < 0.5:
         # Out-of-place add: an in-place ``wav += ...`` writes the noise into
         # the caller's tensor, so a waveform that is kept around (e.g. cached
         # by a dataset) would accumulate noise on every call. ``.to(wav)``
         # matches dtype and device so the result keeps the input's dtype.
         wav = wav + noiser.sample(wav.size()).to(wav)
     return wav.clamp(-1, 1)
Exemple #17
0
                                           retain_graph=True,
                                           create_graph=True)[0]
                eps = torch.randn_like(dfdx)
                epsH = torch.autograd.grad(dfdx,
                                           x,
                                           grad_outputs=eps,
                                           create_graph=True,
                                           retain_graph=True)[0]

                trH = (epsH * eps).sum(1)
                norm_s = (dfdx * dfdx).sum(1)

                loss = (trH + .5 * norm_s).mean()

            elif args.mode == "nce":
                noise_dist = distributions.Normal(init_mu, init_std)
                x_fake = noise_dist.sample_n(x.size(0))

                pos_logits = modelICA(
                    x) + approx_normalizing_const - noise_dist.log_prob(x).sum(
                        1)
                neg_logits = modelICA(
                    x_fake) + approx_normalizing_const - noise_dist.log_prob(
                        x_fake).sum(1)

                pos_loss = nn.BCEWithLogitsLoss()(pos_logits,
                                                  torch.ones_like(pos_logits))
                neg_loss = nn.BCEWithLogitsLoss()(neg_logits,
                                                  torch.zeros_like(neg_logits))
                loss = pos_loss + neg_loss
def MIWAE(X_miss, h=128, d=1, K=1000, L=20, bs=64, n_epochs=201):
    """Train a MIWAE model on ``X_miss`` and return the data with NaNs imputed.

    An encoder/decoder pair is trained on the GPU with ``miwae_loss`` using
    randomly reshuffled mini-batches, then ``miwae_impute`` fills in the
    missing entries with the trained model.

    Args:
        X_miss: 2-D data with NaN marking missing entries. The code calls
            ``.numpy()`` on ``np.isnan(X_miss)``, so this is presumably a
            torch tensor - TODO confirm against callers.
        h: Width of the hidden layers of both networks.
        d: Dimension of the latent space.
        K: Forwarded to ``miwae_loss``.
        L: Forwarded to ``miwae_impute``.
        bs: Target mini-batch size.
        n_epochs: Training runs for ``n_epochs - 1`` epochs
            (``range(1, n_epochs)``).

    Returns:
        A NumPy array equal to ``X_miss`` where it is observed and to the
        model's imputation where it is NaN. If no epoch is run, missing
        entries are filled with 0.
    """
    mask = (1 - np.isnan(X_miss).numpy()).astype(bool)
    xhat_0 = np.where(np.isnan(X_miss), 0, X_miss)
    n = np.shape(X_miss)[0]
    p = np.shape(X_miss)[1]
    p_z = td.Independent(
        td.Normal(loc=torch.zeros(d).cuda(), scale=torch.ones(d).cuda()), 1)
    decoder = nn.Sequential(
        torch.nn.Linear(d, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, h),
        torch.nn.ReLU(),
        torch.nn.Linear(
            h, 3 * p
        ),  # the decoder will output both the mean, the scale, and the number of degrees of freedoms (hence the 3*p)
    )
    encoder = nn.Sequential(
        torch.nn.Linear(p, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, h),
        torch.nn.ReLU(),
        torch.nn.Linear(
            h, 2 * d
        ),  # the encoder will output both the mean and the diagonal covariance
    )
    encoder.cuda()  # we'll use the GPU
    decoder.cuda()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()),
                           lr=1e-3)
    xhat = np.copy(xhat_0)  # This will be our imputed data matrix

    encoder.apply(weights_init)
    decoder.apply(weights_init)

    # Use at least one batch: with fewer than ``bs`` rows, ``n / bs`` would
    # truncate to 0 sections and ``np.array_split`` would raise ValueError.
    n_batches = max(1, n // bs)

    for ep in range(1, n_epochs):
        perm = np.random.permutation(
            n)  # We use the "random reshuffling" version of SGD
        batches_data = np.array_split(xhat_0[perm, ], n_batches)
        batches_mask = np.array_split(mask[perm, ], n_batches)
        for batch_data, batch_mask in zip(batches_data, batches_mask):
            optimizer.zero_grad()
            encoder.zero_grad()
            decoder.zero_grad()
            b_data = torch.from_numpy(batch_data).float().cuda()
            b_mask = torch.from_numpy(batch_mask).float().cuda()
            loss = miwae_loss(iota_x=b_data,
                              mask=b_mask,
                              d=d,
                              K=K,
                              p_z=p_z,
                              encoder=encoder,
                              decoder=decoder)
            loss.backward()
            optimizer.step()

    ### Imputation
    # Run once with the final model. Imputing the full dataset after every
    # optimizer step only produced results that were overwritten by the next
    # step. Skipped when no training step ran, so ``xhat`` stays at its
    # zero-filled initial value in that case.
    if n_epochs > 1:
        xhat[~mask] = miwae_impute(
            iota_x=torch.from_numpy(xhat_0).float().cuda(),
            mask=torch.from_numpy(mask).float().cuda(),
            L=L,
            d=d,
            p_z=p_z,
            encoder=encoder,
            decoder=decoder).cpu().data.numpy()[~mask]
    x_miwae = np.where(np.isnan(X_miss), xhat, X_miss)
    return x_miwae
Exemple #19
0
                screens, obs, acts, rewards, dones, next_screens, next_obs = batch
                batch = []
                k = 0
                
                
                # Double-Q(s,a)
                q1, q2 = crt_net(screens, obs, acts)
                q_min = torch.min(q1,q2).squeeze()
                tracker.track("Q(s,a)_batch_mean", q_min.mean(), exp_count)

                # Double-Q(s',a') 
                next_mu, next_std = act_net(next_screens, next_obs) # no action cliping here
                next_q1, next_q2 = tgt_crt_net(next_screens, next_obs, next_mu) # take expected action, what else?

                # Q_ref
                next_acts_distr = distr.Normal(next_mu, next_std)
                v_next = torch.min(next_q1, next_q2).squeeze() - \
                    ENTROPY_ALPHA*next_acts_distr.log_prob(next_mu).sum(dim=1) # eqn (3) in 2nd SAC paper

                v_next[dones] = 0.0 # if terminated, local reward makes up the value already
                q_ref = (rewards + v_next*(GAMMA**UNROLL)).unsqueeze(dim=-1)
                tracker.track("Q_ref(s,a)_batch_mean", q_ref.mean(), exp_count)

                # Critic loss (on both Double-Q heads to update all parameters)
                q1_loss = F.mse_loss(q1, q_ref.detach()) # eqn (5) in 2nd SAC paper
                q2_loss = F.mse_loss(q2, q_ref.detach())
                q_loss = q1_loss + q2_loss
                q_loss.backward(retain_graph=True)
                crt_opt.step()
                tracker.track("Loss_crt", q_loss, exp_count)
Exemple #20
0
 def forward(self, x: torch.Tensor):
     """Return the Gaussian negative log-likelihood of ``x``, averaged over its first two dims.

     The likelihood is ``Normal(self.mu, self.sigma)``; the summed negative
     log-probability is divided by ``size(0) * size(1)`` of the loss tensor.
     """
     nll = tdist.Normal(self.mu, self.sigma).log_prob(x).neg()
     n_rows, n_cols = nll.size(0), nll.size(1)
     return nll.sum() / (n_rows * n_cols)
Exemple #21
0
def sample_normal(means, **kwargs):
    """Build ``D.Normal(means, **kwargs)`` and return what ``sample_`` yields for it."""
    normal = D.Normal(means, **kwargs)
    return sample_(normal)
Exemple #22
0
def test_non_empty_bit_vector(batch_shape=tuple(), K=3):
    """Check ``NonEmptyBitVector`` against explicit enumeration of its support.

    A uniform instance (all-zero scores) is checked for shapes, expansion,
    sampling constraints, probabilities and entropy; two instances with
    standard-normal scores are then checked for entropy, cross-entropy, KL
    and the non-empty sampling constraint.
    """
    assert K<= 10, "I test against explicit enumeration, K>10 might be too slow for that"

    scores_shape = batch_shape + (K,)
    state_shape = batch_shape + (K+2,3)
    wide_batch = (2,3) + batch_shape
    empty_msg = "I found an empty vector"

    def random_scores():
        # One standard-normal score per bit.
        return td.Normal(torch.zeros(scores_shape), torch.ones(scores_shape)).sample()

    # ---- Uniform case: every score is zero ----
    uniform = NonEmptyBitVector(torch.zeros(scores_shape))

    # Shapes
    assert uniform.batch_shape == batch_shape, "NonEmptyBitVector has the wrong batch_shape"
    assert uniform.dim == K, "NonEmptyBitVector has the wrong dim"
    assert uniform.event_shape == (K,), "NonEmptyBitVector has the wrong event_shape"
    assert uniform.scores.shape == scores_shape, "NonEmptyBitVector.score has the wrong shape"
    assert uniform.arc_weight.shape == batch_shape + (K+1,3,3), "NonEmptyBitVector.arc_weight has the wrong shape"
    assert uniform.state_value.shape == state_shape, "NonEmptyBitVector.state_value has the wrong shape"
    assert uniform.state_rvalue.shape == state_shape, "NonEmptyBitVector.state_rvalue has the wrong shape"

    # Enumerated support, shape: [num_faces] + batch_shape + [K]
    support = uniform.enumerate_support()
    assert support.shape == (2**K-1,) + scores_shape, "The support has the wrong shape"

    # Expansion to a larger batch
    assert uniform.expand(wide_batch).batch_shape == wide_batch, "Bad expand batch_shape"
    assert uniform.expand(wide_batch).event_shape == (K,), "Bad expand event_shape"
    assert uniform.expand(wide_batch).sample().shape == wide_batch + (K,), "Bad expand single sample"
    assert uniform.expand(wide_batch).sample((13,)).shape == (13,) + wide_batch + (K,), "Bad expand multiple samples"

    # Constraints: neither the support nor any sample may contain the empty vector
    assert (support.sum(-1) > 0).all(), "The support has an empty bit vector"
    for _ in range(100):  # one sample at a time
        assert uniform.sample().sum(-1).all(), empty_msg
    # a flat batch of samples, then a nested batch of samples
    for sample_shape in ((100,), (2, 100,)):
        assert uniform.sample(sample_shape).sum(-1).all(), empty_msg

    # Distribution: probabilities and marginals must be uniform
    uniform_prob = torch.tensor(1./uniform.support_size)
    assert torch.isclose(uniform.log_prob(support).exp(), uniform_prob).all(), "Non-uniform"
    support_mean = support.mean(0)
    assert torch.isclose(uniform.sample((10000,)).float().mean(0), support_mean, atol=1e-1).all(), "Bad MC marginals"
    assert torch.isclose(uniform.marginals(), support.mean(0)).all(), "Bad exact marginals"

    # Entropy by enumeration; log-probabilities have shape [num_faces, B]
    log_u = uniform.log_prob(support)
    assert torch.isclose(uniform.entropy(), (-(log_u.exp() * log_u).sum(0)), atol=1e-2).all(), "Problem in the entropy DP"

    # ---- Non-uniform case: random scores ----

    # Entropy
    p_dist = NonEmptyBitVector(random_scores())
    log_p = p_dist.log_prob(support)
    assert torch.isclose(p_dist.entropy(), (-(log_p.exp() * log_p).sum(0)), atol=1e-2).all(), "Problem in the entropy DP"
    # Cross-Entropy
    q_dist = NonEmptyBitVector(random_scores())
    log_q = q_dist.log_prob(support)
    assert torch.isclose(p_dist.cross_entropy(q_dist), -(log_p.exp() * log_q).sum(0), atol=1e-2).all(), "Problem in the cross-entropy DP"
    # KL
    assert torch.isclose(td.kl_divergence(p_dist, q_dist), (log_p.exp() * (log_p - log_q)).sum(0), atol=1e-2).all(), "Problem in KL"

    # Constraints, alternating between the two distributions
    for _ in range(100):  # one sample at a time
        for member in (p_dist, q_dist):
            assert member.sample().sum(-1).all(), empty_msg
    # a flat batch of samples, then a nested batch of samples
    for sample_shape in ((100,), (2, 100,)):
        for member in (p_dist, q_dist):
            assert member.sample(sample_shape).sum(-1).all(), empty_msg
Exemple #23
0
def simulate_Y(H, Params):
    """Draw one sample of Y from a Normal whose parameters are linear in ``H``.

    The location is ``H @ Params.W_mu_y`` and the scale is
    ``H @ Params.W_sigma_y``.
    """
    loc = torch.matmul(H, Params.W_mu_y)
    scale = torch.matmul(H, Params.W_sigma_y)
    return dist.Normal(loc, scale).sample()
            continue

        #generate labels for the real batch of data...the (k+1)th element is 1...rest are zero
        D_label_real = utils.get_labels(num_generators, -1, real_b_size,
                                        device)

        #forward pass for the real batch of data and then resize

        gen_input_noise = utils.generate_noise_for_generator(
            real_b_size // num_generators, n_z, device)
        gen_output = generator(
            gen_input_noise)  #, real_b_size//num_generators)

        gen_out_d_in = gen_output.detach()
        ##############################################################
        norm = dist.Normal(torch.tensor([NOISE_MEAN]),
                           torch.tensor([NOISE_DEV]))

        if add_noise == 1:
            x_noise = norm.sample(gen_out_d_in.size()).view(
                gen_out_d_in.size()).to(device)
            gen_out_d_in = gen_out_d_in + x_noise

        #################################################################
        D_Label_Fake = []
        for g in range(num_generators):
            D_Label_Fake.append(
                utils.get_labels(num_generators, g,
                                 real_b_size // num_generators, device))

        D_Label_Fake = torch.cat(D_Label_Fake)
        D_Labels = torch.cat([D_label_real, D_Label_Fake])
Exemple #25
0
    def _train_continuous(self, BATCH):
        """Update both critics, then the actor, then (optionally) alpha on one batch.

        Despite the name, both action types are handled: the branch taken
        depends on ``self.is_continuous``.

        Args:
            BATCH: Batch object providing ``obs``, ``action``, ``obs_``,
                ``reward``, ``done``, ``begin_mask`` and a ``get`` method
                (used to read the optional ``'isw'`` entry).

        Returns:
            A tuple ``(td_error, summaries)`` where ``td_error`` is the mean
            of the two critics' TD errors and ``summaries`` is a dict of
            values for logging.
        """
        q1 = self.critic(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action,
                          begin_mask=BATCH.begin_mask)  # [T, B, 1]
        # ---- Critic target: sample an action for the next observation ----
        if self.is_continuous:
            target_mu, target_log_std = self.actor(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()),
                                  1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(
                target_pi,
                dist.log_prob(target_pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            target_logits = self.actor(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(
                -1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        # ``.t`` is presumably the target copy of each critic - TODO confirm.
        q1_target = self.critic.t(BATCH.obs_,
                                  target_pi,
                                  begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_target = self.critic2.t(BATCH.obs_,
                                   target_pi,
                                   begin_mask=BATCH.begin_mask)  # [T, B, 1]
        # Use the smaller of the two target estimates.
        q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
        # The return target includes the entropy term and is detached, so the
        # critic loss does not backpropagate through the target computation.
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                             (q_target - self.alpha * target_log_pi),
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        # 'isw' falls back to 1.0 when absent (presumably importance-sampling
        # weights from a prioritized buffer - TODO confirm).
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        # ---- Actor update: needs an action that is differentiable ----
        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            # rsample() (reparameterized) keeps the gradient path to mu/log_std.
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            # Gumbel-softmax relaxation with temperature self.discrete_tau.
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            # Straight-through: the forward value is the one-hot argmax while
            # the gradient flows through the soft sample _pi.
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q_s_pi = th.minimum(
            self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
            self.critic2(BATCH.obs, pi,
                         begin_mask=BATCH.begin_mask))  # [T, B, 1]

        actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1

        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
        # ---- Optional update of the entropy coefficient ----
        if self.auto_adaption:
            # log_pi is detached, so only self.alpha receives gradient here.
            alpha_loss = -(self.alpha *
                           (log_pi + self.target_entropy).detach()).mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
Exemple #26
0
 def __init__(self, loc, scale_diag, validate_args=None):
     """Wrap ``Normal(loc, scale_diag)`` and reinterpret its last batch dim as an event dim.

     ``validate_args`` is forwarded to the parent constructor.
     """
     super().__init__(
         dists.Normal(loc=loc, scale=scale_diag),
         reinterpreted_batch_ndims=1,
         validate_args=validate_args,
     )
Exemple #27
0
 def _prior(self, batch_size, sequence_len):
   """Return a standard-normal prior over ``(batch_size, sequence_len, self._z_dim)``.

   The last dimension is treated as the event dimension.
   """
   shape = (batch_size, sequence_len, self._z_dim)
   loc = torch.zeros(*shape)
   scale = torch.ones_like(loc)
   base = td.Normal(loc, scale)
   return td.Independent(base, 1)
Exemple #28
0
 def __init__(self, loc, scale, reinterpreted_batch_ndims=None, **kwargs):
     """Wrap ``Normal(loc, scale, **kwargs)`` with a default number of event dims.

     When ``reinterpreted_batch_ndims`` is not given it defaults to 1 if
     ``loc`` is a tensor and to 0 otherwise (e.g. for ``Normal(0, 1)``).
     """
     if reinterpreted_batch_ndims is None:
         if isinstance(loc, torch.Tensor):
             reinterpreted_batch_ndims = 1
         else:
             # Plain Python numbers have no batch dimension to reinterpret.
             reinterpreted_batch_ndims = 0
     base = td.Normal(loc, scale, **kwargs)
     super().__init__(base, reinterpreted_batch_ndims)
Exemple #29
0
 def forward(self, x=None):
     """Return ``Normal(self.mu, self.sigma)``; the argument ``x`` is ignored."""
     loc, scale = self.mu, self.sigma
     return dist.Normal(loc, scale)
Exemple #30
0
 def _component_sample(self, idxs):
     """Sample from the Normal components selected by ``idxs``.

     Locations come from ``self.mean[idxs]`` and scales from
     ``exp(self.log_std[idxs])``.
     """
     loc = self.mean[idxs]
     scale = self.log_std[idxs].exp()
     component = distributions.Normal(loc, scale)
     return component.sample()