Exemplo n.º 1
0
 def test_input_extra_var(self):
     """Expectation.eval keeps inputs it does not itself consume.

     ``eval(..., return_dict=True)`` returns a pair whose second element is a
     dict; its keys must include an unrelated input ``w`` alongside ``x``,
     ``y`` and the variable ``z`` produced by ``q``.  Supplying ``z`` directly
     is also accepted.
     """
     q = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
     p = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
     e = Expectation(q, p.log_prob())

     with_unused = {'y': torch.zeros(1), 'x': torch.zeros(1),
                    'w': torch.zeros(1)}
     returned = e.eval(with_unused, return_dict=True)[1]
     assert set(returned) == {'w', 'x', 'y', 'z'}

     with_latent = {'y': torch.zeros(1), 'x': torch.zeros(1),
                    'z': torch.zeros(1)}
     returned = e.eval(with_latent, return_dict=True)[1]
     assert set(returned) == {'x', 'y', 'z'}
Exemplo n.º 2
0
 def test_input_extra_var(self):
     """A Normal tolerates inputs it does not use and passes them through.

     ``sample`` returns the sampled ``x`` together with any extra key it was
     given, and an extra key does not change the shape of ``get_log_prob``.
     """
     normal = Normal(loc=0, scale=1)
     assert set(normal.sample({'y': torch.zeros(1)})) == {'x', 'y'}
     assert normal.get_log_prob({
         'y': torch.zeros(1),
         'x': torch.zeros(1)
     }).shape == torch.Size([1])
     # Use a set literal: ``set(('x'))`` is ``set('x')`` because ``('x')`` is
     # a plain str, not a 1-tuple, and only matches for one-char names.
     assert set(normal.sample({'x': torch.zeros(1)})) == {'x'}
Exemplo n.º 3
0
    def test_get_entropy(self):
        """get_entropy follows variables renamed with replace_var.

        After renaming (``y -> z``, ``x -> y``) the entropy computed from the
        conditioning value under its new name equals the unrenamed result.
        Passing the value under the stale name ``y`` raises ``ValueError``.
        """
        def make_dist():
            return Normal(var=['x'], cond_var=['y'], loc='y', scale=1)

        expected = make_dist().get_entropy({'y': torch.ones(1)})

        swapped = make_dist().replace_var(y='z', x='y')
        assert swapped.get_entropy({'z': torch.ones(1)}) == expected

        renamed = make_dist().replace_var(y='z')
        with pytest.raises(ValueError):
            renamed.get_entropy({'y': torch.ones(1)})
Exemplo n.º 4
0
    def test_sample_variance(self):
        """sample_variance honours replace_var renaming of the scale input.

        With the scale input set to ones(1) the variance is ones(1), under
        both the original name ``y`` and the renamed ``z``.  Using the stale
        name after renaming raises ``ValueError``.
        """
        def make_dist():
            return Normal(var=['x'], cond_var=['y'], loc=2, scale='y')

        plain = make_dist()
        assert plain.sample_variance({'y': torch.ones(1)}) == torch.ones(1)

        renamed = make_dist().replace_var(y='z')
        assert renamed.sample_variance({'z': torch.ones(1)}) == torch.ones(1)

        stale = make_dist().replace_var(y='z')
        with pytest.raises(ValueError):
            stale.sample_variance({'y': torch.ones(1)})
Exemplo n.º 5
0
def test_save_dist(tmpdir, no_contiguous_tensor):
    """Round-trip the state_dict of a Normal whose loc is a non-contiguous tensor.

    Regression test for pull request #110.
    """
    ones = torch.ones_like(no_contiguous_tensor)
    checkpoint = pjoin(tmpdir, "tmp.pt")

    source = Normal(loc=no_contiguous_tensor, scale=ones)
    torch.save(source.state_dict(), checkpoint)

    # A freshly built distribution must not already hold the saved loc.
    fresh = Normal(loc=ones, scale=3 * ones)
    assert not torch.all(no_contiguous_tensor == fresh.loc).item()

    # Loading has to copy the saved tensor values into the new distribution.
    target = Normal(loc=ones, scale=ones)
    target.load_state_dict(torch.load(checkpoint))
    assert torch.all(no_contiguous_tensor == target.loc).item()
Exemplo n.º 6
0
    def test_init_with_scalar_params(self):
        """Scalar loc/scale: samples have shape (1, *features_shape)."""
        cases = [
            ({'features_shape': [2]}, torch.Size([1, 2]), torch.Size([2])),
            ({}, torch.Size([1]), torch.Size([])),
        ]
        for extra_kwargs, sample_shape, features_shape in cases:
            normal = Normal(loc=0, scale=1, **extra_kwargs)
            assert normal.sample()['x'].shape == sample_shape
            assert normal.features_shape == features_shape
Exemplo n.º 7
0
    def test_set_option(self):
        """Graph options set on ``y`` shape sampling of the whole joint.

        After ``graph.set_option(dict(batch_n=..., sample_shape=...), ['y'])``
        both ``y`` and ``x`` come out with shape ``(*sample_shape, batch_n,
        ...)``.  The option is then reset before checking the shape of
        ``get_log_prob`` on the drawn sample.
        """
        # Continuous parent: x | y ~ N(y, 1), y ~ N(0, 1).
        joint = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
            var=['y'], loc=0, scale=1)
        joint.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
        drawn = joint.sample()
        for name in ('y', 'x'):
            assert drawn[name].shape == torch.Size([2, 3, 4])
        joint.graph.set_option({}, ['y'])
        summed = joint.get_log_prob(drawn, sum_features=True,
                                    feature_dims=None)
        assert summed.shape == torch.Size([2])
        unsummed = joint.get_log_prob(drawn, sum_features=False)
        assert unsummed.shape == torch.Size([2, 3, 4])

        # Discrete parent with two probabilities, giving a trailing axis of 2.
        joint = Normal(var=['x'], cond_var=['y'], loc='y',
                       scale=1) * FactorizedBernoulli(
                           var=['y'], probs=torch.tensor([0.3, 0.8]))
        joint.graph.set_option(dict(batch_n=3, sample_shape=(4, )), ['y'])
        drawn = joint.sample()
        for name in ('y', 'x'):
            assert drawn[name].shape == torch.Size([4, 3, 2])
        joint.graph.set_option(dict(), ['y'])
        summed = joint.get_log_prob(drawn, sum_features=True,
                                    feature_dims=[-1])
        assert summed.shape == torch.Size([4, 3])
Exemplo n.º 8
0
 def test_get_log_prob_feature_dims2(self):
     """get_log_prob sums exactly over the requested ``feature_dims``.

     A (2, 3, 4) sample is drawn through a graph option on ``y``.  The option
     is then cleared directly on ``y``'s factor, and each ``feature_dims``
     choice is checked against the resulting log-prob shape (``None`` leaves
     only the leading axis).
     """
     dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
         var=['y'], loc=0, scale=1)
     dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
     sample = dist.sample()
     assert sample['y'].shape == torch.Size([2, 3, 4])

     # Reset the option on the factor owning ``y`` without going through set_option.
     y_factor = list(dist.graph._factors_from_variable('y'))[0]
     y_factor.option = {}

     expectations = [
         (None, [2]),
         ([-2], [2, 4]),
         ([0, 1], [4]),
         ([], [2, 3, 4]),
     ]
     for dims, shape in expectations:
         log_prob = dist.get_log_prob(sample, sum_features=True,
                                      feature_dims=dims)
         assert log_prob.shape == torch.Size(shape)
Exemplo n.º 9
0
 def test_sample_mean(self):
     """sample_mean=True on a two-component mixture is expected to give ones(1).

     The loc=1 component has the larger prior weight (2 vs 1).
     NOTE(review): presumably the mixture selects that component and returns
     its mean; confirm against MixtureModel.sample.
     """
     components = [Normal(loc=0, scale=1),
                   Normal(loc=1, scale=1)]
     weights = Categorical(probs=torch.tensor([1., 2.]))
     mixture = MixtureModel(components, weights)
     assert mixture.sample(sample_mean=True)['x'] == torch.ones(1)
Exemplo n.º 10
0
 def test_batch_n(self):
     """batch_n=3 on a scalar Normal yields samples of shape (3,)."""
     drawn = Normal(loc=0, scale=1).sample(batch_n=3)
     assert drawn['x'].shape == torch.Size([3])
Exemplo n.º 11
0
def main(smoke_test=False, root='/mnt/hdd/sika/Datasets',
         save_dir='./results/pixyz'):
    """Train a pixyz VAE on MNIST and save per-epoch diagnostics.

    Each epoch:
      - trains on the train loader and evaluates on the test loader;
      - prints both losses;
      - writes a loss curve plus reconstruction, interpolation,
        fixed-latent generation and random-sample image grids.

    Outputs go into a timestamped directory under ``save_dir``.

    Args:
        smoke_test: If truthy, run 2 epochs instead of 50 (quick check).
        root: Dataset directory passed to ``make_MNIST_loader``; defaults to
            the original machine-specific location.
        save_dir: Parent directory for the timestamped results folder.
    """
    epochs = 2 if smoke_test else 50
    batch_size = 128
    seed = 0

    x_ch = 1
    z_dim = 32

    # Seed the random number generators. torch.manual_seed seeds the CPU
    # generator and every CUDA device, so no separate seeding call is needed.
    torch.manual_seed(seed)

    date_and_time = datetime.datetime.now().strftime('%Y-%m%d-%H%M')
    save_root = os.path.join(save_dir, date_and_time)
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(save_root, exist_ok=True)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    train_loader, test_loader = make_MNIST_loader(root, batch_size=batch_size)

    # Build the generative model and the inference model.
    p = Generator(x_ch, z_dim).to(device)
    q = Inference(x_ch, z_dim).to(device)

    # Prior over the latent variable: standard Normal over z_dim features.
    p_prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                     var=['z'], features_shape=[z_dim], name='p_{prior}').to(device)

    # Loss: KL(q || p_prior) - E_q[log p] (negative ELBO), averaged by .mean().
    loss = (KullbackLeibler(q, p_prior) - Expectation(q, LogProb(p))).mean()

    # Configure the pixyz Model API: Adam (lr=1e-3) over distributions p and q.
    model = Model(loss=loss, distributions=[p, q],
                  optimizer=optim.Adam, optimizer_params={"lr": 1e-3})

    # Fixed inputs and latents so the image grids are comparable across epochs.
    x_fixed, _ = next(iter(test_loader))
    x_fixed = x_fixed[:8].to(device)
    z_fixed = p_prior.sample(batch_n=64)['z']

    train_loss_list, test_loss_list = [], []
    for epoch in range(1, epochs + 1):
        train_loss_list.append(learn(model, epoch, train_loader, device, train=True))
        test_loss_list.append(learn(model, epoch, test_loader, device, train=False))

        print(f'    [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f'    [Epoch {epoch}] test  loss {test_loss_list[-1]:.4f}\n')

        # Plot the loss curves so far; the same file is overwritten each epoch.
        _save_loss_plot(train_loss_list, test_loss_list,
                        os.path.join(save_root, 'loss.png'))

        # Reconstructions: the 8 inputs form the first row (nrow=8), followed
        # by their reconstructions.
        x_reconst = reconstruct_image(p, q, x_fixed)
        save_image(torch.cat([x_fixed, x_reconst], dim=0),
                   os.path.join(save_root, f'reconst_{epoch}.png'), nrow=8)

        # Interpolated images.
        x_interpol = interpolate_image(p, q, x_fixed)
        save_image(x_interpol, os.path.join(save_root, f'interpol_{epoch}.png'), nrow=8)

        # Generated images from the fixed latent samples.
        x_generate = generate_image(p, z_fixed)
        save_image(x_generate, os.path.join(save_root, f'generate_{epoch}.png'), nrow=8)

        # Generated images from freshly sampled latents.
        x_sample = sample_image(p_prior, p)
        save_image(x_sample, os.path.join(save_root, f'sample_{epoch}.png'), nrow=8)


def _save_loss_plot(train_losses, test_losses, path):
    """Plot per-epoch train/test losses (epochs numbered from 1) and save to ``path``."""
    epochs_axis = list(range(1, len(train_losses) + 1))
    plt.plot(epochs_axis, train_losses, label='train')
    plt.plot(epochs_axis, test_losses, label='test')
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(path)
    # Close the figure so the next epoch's plot starts from a clean canvas.
    plt.close()
Exemplo n.º 12
0
class TestDistributionBase:
    """Behaviour shared by pixyz distributions, exercised through ``Normal``."""

    def test_init_with_scalar_params(self):
        """Scalar loc/scale: samples have shape (1, *features_shape)."""
        normal = Normal(loc=0, scale=1, features_shape=[2])
        assert normal.sample()['x'].shape == torch.Size([1, 2])
        assert normal.features_shape == torch.Size([2])

        normal = Normal(loc=0, scale=1)
        assert normal.sample()['x'].shape == torch.Size([1])
        assert normal.features_shape == torch.Size([])

    def test_batch_n(self):
        """batch_n=3 gives samples of shape (3,)."""
        normal = Normal(loc=0, scale=1)
        assert normal.sample(batch_n=3)['x'].shape == torch.Size([3])

    def test_input_extra_var(self):
        """Unused inputs are tolerated and passed through by ``sample``."""
        normal = Normal(loc=0, scale=1)
        assert set(normal.sample({'y': torch.zeros(1)})) == {'x', 'y'}
        assert normal.get_log_prob({
            'y': torch.zeros(1),
            'x': torch.zeros(1)
        }).shape == torch.Size([1])
        # Use a set literal: ``set(('x'))`` is ``set('x')`` because ``('x')``
        # is a plain str, not a 1-tuple, and only matches for one-char names.
        assert set(normal.sample({'x': torch.zeros(1)})) == {'x'}

    def test_sample_mean(self):
        """sample_mean=True returns the mean (loc=0) rather than a random draw."""
        dist = Normal(loc=0, scale=1)
        assert dist.sample(sample_mean=True)['x'] == torch.zeros(1)

    @pytest.mark.parametrize(
        "dist",
        [
            Normal(loc=0, scale=1),
            Normal(var=['x'], loc=0, scale=1) *
            Normal(var=['y'], loc=0, scale=1),
            # NOTE(review): the conditional joint below is disabled here; the
            # same joint is exercised via graph options in
            # test_get_log_prob_feature_dims2.
            # Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1),
        ],
    )
    def test_get_log_prob_feature_dims(self, dist):
        """``feature_dims`` selects which axes get_log_prob sums over.

        A sample drawn with batch_n=4 and sample_shape=(2, 3) is checked
        against each ``feature_dims`` choice; ``None`` leaves only the
        leading axis.
        """
        # One draw is enough: only the shape of the log-prob is checked.
        x_dict = dist.sample(batch_n=4, sample_shape=(2, 3))
        cases = [
            (None, [2]),
            ([-2], [2, 4]),
            ([0, 1], [4]),
            ([], [2, 3, 4]),
        ]
        for feature_dims, expected_shape in cases:
            log_prob = dist.get_log_prob(x_dict,
                                         sum_features=True,
                                         feature_dims=feature_dims)
            assert log_prob.shape == torch.Size(expected_shape)

    def test_get_log_prob_feature_dims2(self):
        """Same ``feature_dims`` checks on a joint sampled through a graph option.

        The option on ``y`` is cleared directly on its factor after sampling
        so that get_log_prob sees the drawn (2, 3, 4) sample as-is.
        """
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
            var=['y'], loc=0, scale=1)
        dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
        sample = dist.sample()
        assert sample['y'].shape == torch.Size([2, 3, 4])
        list(dist.graph._factors_from_variable('y'))[0].option = {}
        cases = [
            (None, [2]),
            ([-2], [2, 4]),
            ([0, 1], [4]),
            ([], [2, 3, 4]),
        ]
        for feature_dims, expected_shape in cases:
            log_prob = dist.get_log_prob(sample,
                                         sum_features=True,
                                         feature_dims=feature_dims)
            assert log_prob.shape == torch.Size(expected_shape)

    @pytest.mark.parametrize("dist", [
        Normal(loc=0, scale=1),
        Normal(var=['x'], cond_var=['y'], loc='y', scale=1) *
        Normal(var=['y'], loc=0, scale=1),
    ])
    def test_unknown_option(self, dist):
        """Unrecognised keyword options to sample/get_log_prob must not raise."""
        x_dict = dist.sample(unknown_opt=None)
        dist.get_log_prob(x_dict, unknown_opt=None)
Exemplo n.º 13
0
 def test_sample_mean(self):
     """sample_mean=True on the chain x ~ N(0, 1), y | x ~ N(x, 1) gives y == 0."""
     prior_x = Normal(var=['x'], loc=0, scale=1)
     likelihood_y = Normal(var=['y'], cond_var=['x'], loc='x', scale=1)
     joint = prior_x * likelihood_y
     assert joint.sample(sample_mean=True)['y'] == torch.zeros(1)
Exemplo n.º 14
0
 def test_rename_atomdist(self):
     """Renaming a single distribution is reflected in its previously fetched graph."""
     dist = Normal(var=['x'], name='p')
     g = dist.graph
     assert g.name == 'p'
     dist.name = 'q'
     assert g.name == 'q'
Exemplo n.º 15
0
 def test_sample_mean(self):
     """Smoke test: Expectation.eval accepts sample_mean=True with no inputs.

     There is no assertion; the test passes as long as eval does not raise.
     """
     normal = Normal(loc=0, scale=1)
     expectation = Expectation(normal, normal.log_prob())
     expectation.eval({}, sample_mean=True)
Exemplo n.º 16
0
 def test_sample_mean(self):
     """With sample_mean=True a standard Normal returns its mean, zeros(1)."""
     drawn = Normal(loc=0, scale=1).sample(sample_mean=True)
     assert drawn['x'] == torch.zeros(1)
Exemplo n.º 17
0
    def test_get_params(self):
        """get_params returns ``loc`` and ``scale`` and honours replace_var.

        - The conditioning value must be supplied under its current name.
        - The stale name after renaming raises ``ValueError``.
        - Renaming only the output variable does not affect lookup.
        - A product of distributions raises ``NotImplementedError``.
        """
        expected_keys = ['loc', 'scale']

        def conditional():
            return Normal(var=['x'], cond_var=['y'], loc='y', scale=1)

        params = conditional().get_params({'y': torch.ones(1)})
        assert list(params.keys()) == expected_keys

        params = conditional().replace_var(y='z').get_params({'z': torch.ones(1)})
        assert list(params.keys()) == expected_keys

        renamed = conditional().replace_var(y='z')
        with pytest.raises(ValueError):
            renamed.get_params({'y': torch.ones(1)})

        params = conditional().replace_var(x='z').get_params({'y': torch.ones(1)})
        assert list(params.keys()) == expected_keys

        joint = conditional() * Normal(var=['y'], loc=0, scale=1)
        with pytest.raises(NotImplementedError):
            joint.get_params()
Exemplo n.º 18
0
 def test_print(self):
     """Printing a distribution's graph must not raise."""
     graph = Normal(var=['x'], name='p').graph
     print(graph)
Exemplo n.º 19
0
 def test_input_var(self):
     """E_{q(z|x)}[log p(y|z)] takes ``x`` and ``y`` as inputs; ``z`` comes from q."""
     q = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
     p = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
     expectation = Expectation(q, p.log_prob())
     assert set(expectation.input_var) == {'x', 'y'}
     value = expectation.eval({'y': torch.zeros(1), 'x': torch.zeros(1)})
     assert value.shape == torch.Size([1])
Exemplo n.º 20
0
 def prior():
     """Build the prior p(z): Normal(loc=0, scale=1) over 10 features, named p_{prior}."""
     zero = torch.tensor(0.)
     one = torch.tensor(1.)
     return Normal(loc=zero, scale=one, var=["z"],
                   features_shape=[10], name="p_{prior}")
Exemplo n.º 21
0
class Classifier(RelaxedCategorical):
    """Relaxed categorical over ``y`` conditioned on ``x``, named ``p``.

    A two-layer MLP (``x_dim`` -> 512 -> ``y_dim``) produces the category
    probabilities.  ``x_dim`` and ``y_dim`` are module-level globals defined
    outside this class.
    """

    def __init__(self):
        super().__init__(var=["y"], cond_var=["x"], name="p")
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, y_dim)

    def forward(self, x):
        """Return ``{"probs": ...}``: ReLU hidden layer, then softmax over dim 1."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return {"probs": F.softmax(logits, dim=1)}


# prior model p(z)
# Normal(loc=0, scale=1) over z with features_shape [z_dim], moved to `device`.
# z_dim and device are defined outside this view.
prior = Normal(loc=torch.tensor(0.),
               scale=torch.tensor(1.),
               var=["z"],
               features_shape=[z_dim],
               name="p_{prior}").to(device)

# distributions for supervised learning
# Generator and Inference are defined outside this view.
p = Generator().to(device)
q = Inference().to(device)
f = Classifier().to(device)
# Product of the generator and the latent prior.
p_joint = p * prior

# distributions for unsupervised learning
# Versions of q, p and f with variables x/y renamed to x_u/y_u via replace_var;
# presumably so unlabeled data can be fed under separate keys.
_q_u = q.replace_var(x="x_u", y="y_u")
p_u = p.replace_var(x="x_u", y="y_u")
f_u = f.replace_var(x="x_u", y="y_u")

# Unlabeled-data inference: renamed q multiplied by the renamed classifier,
# which supplies y_u (NOTE(review): confirm intended factorization).
q_u = _q_u * f_u
Exemplo n.º 22
0
    # Use the GPU when available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # p(y|x,z)
    p = net.Generator_().to(device)

    # q(z|x,y)
    q = net.Inference().to(device)

    # prior p(z)
    prior = Normal(loc=torch.tensor(0.0),
                   scale=torch.tensor(1.0),
                   var=["z"],
                   features_shape=[net.Z_DIM],
                   name="p_{prior}").to(device)

    # Loss: KL(q || prior) - E_q[log p], averaged with .mean().
    loss = (KullbackLeibler(q, prior) - Expectation(q, LogProb(p))).mean()
    # pixyz Model over distributions p and q, optimized with Adam at lr=1e-3.
    model = Model(loss=loss,
                  distributions=[p, q],
                  optimizer=optim.Adam,
                  optimizer_params={"lr": 1e-3})
    # print(model)

    # First batch of the test loader, used as fixed inputs/labels.
    x_org, y_org = next(iter(test_loader))

    # Fixed sample data for reconstruction (first 8 examples).
    # Note: only x_fixed is moved to `device`; y_fixed stays where the loader put it.
    x_fixed = x_org[:8].to(device)
    y_fixed = y_org[:8]