Example #1
    def test_init_with_scalar_params(self):
        normal = Normal(loc=0, scale=1, features_shape=[2])
        assert normal.sample()['x'].shape == torch.Size([1, 2])
        assert normal.features_shape == torch.Size([2])

        normal = Normal(loc=0, scale=1)
        assert normal.sample()['x'].shape == torch.Size([1])
        assert normal.features_shape == torch.Size([])
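
features_shape and batch_n compose. A minimal sketch, assuming pixyz prepends the batch dimension in front of features_shape (consistent with this example and with Example #6 below):

import torch
from pixyz.distributions import Normal

normal = Normal(loc=0, scale=1, features_shape=[2])
# batch dim first, then features_shape (assumed ordering)
assert normal.sample(batch_n=3)['x'].shape == torch.Size([3, 2])
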
Example #2
    def test_input_extra_var(self):
        normal = Normal(loc=0, scale=1)
        assert set(normal.sample({'y': torch.zeros(1)})) == set(('x', 'y'))
        assert normal.get_log_prob({
            'y': torch.zeros(1),
            'x': torch.zeros(1)
        }).shape == torch.Size([1])
        assert set(normal.sample({'x': torch.zeros(1)})) == set(('x',))
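
A minimal sketch of the pass-through behavior this test relies on, assuming a variable that is not part of the model is carried through sample() unchanged:

import torch
from pixyz.distributions import Normal

normal = Normal(loc=0, scale=1)
out = normal.sample({'y': torch.ones(1)})
# 'y' is not a variable of the model, so it should come back untouched,
# while 'x' is freshly sampled alongside it (assumed pixyz behavior).
assert torch.equal(out['y'], torch.ones(1))
assert out['x'].shape == torch.Size([1])
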
Example #3
    def test_set_option(self):
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
            var=['y'], loc=0, scale=1)
        dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
        sample = dist.sample()
        assert sample['y'].shape == torch.Size([2, 3, 4])
        assert sample['x'].shape == torch.Size([2, 3, 4])
        dist.graph.set_option({}, ['y'])
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=None).shape == torch.Size([2])
        assert dist.get_log_prob(
            sample, sum_features=False).shape == torch.Size([2, 3, 4])

        dist = Normal(var=['x'], cond_var=['y'], loc='y',
                      scale=1) * FactorizedBernoulli(
                          var=['y'], probs=torch.tensor([0.3, 0.8]))
        dist.graph.set_option(dict(batch_n=3, sample_shape=(4, )), ['y'])
        sample = dist.sample()
        assert sample['y'].shape == torch.Size([4, 3, 2])
        assert sample['x'].shape == torch.Size([4, 3, 2])
        dist.graph.set_option(dict(), ['y'])
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=[-1]).shape == torch.Size([4, 3])
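
The empty set_option call before get_log_prob matters: it clears the sampling options set earlier. A minimal sketch of the shape convention these assertions suggest (sample_shape dims first, then batch_n; shapes below are extrapolations, not taken from the source):

import torch
from pixyz.distributions import Normal

dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
    var=['y'], loc=0, scale=1)
dist.graph.set_option(dict(batch_n=5, sample_shape=(7,)), ['y'])
sample = dist.sample()
assert sample['x'].shape == torch.Size([7, 5])  # (sample_shape, batch_n)
dist.graph.set_option({}, ['y'])  # clear options before scoring
assert dist.get_log_prob(sample, sum_features=False).shape == torch.Size([7, 5])
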
Example #4
    def test_get_log_prob_feature_dims2(self):
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
            var=['y'], loc=0, scale=1)
        dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
        sample = dist.sample()
        assert sample['y'].shape == torch.Size([2, 3, 4])
        list(dist.graph._factors_from_variable('y'))[0].option = {}
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=None).shape == torch.Size([2])
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=[-2]).shape == torch.Size([2, 4])
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=[0, 1]).shape == torch.Size([4])
        assert dist.get_log_prob(sample, sum_features=True,
                                 feature_dims=[]).shape == torch.Size([2, 3, 4])
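
In other words, feature_dims selects which dimensions of the per-element log-probability get summed out (an empty list sums nothing; None falls back to a default). A minimal sketch applying the same convention to a single Normal; the shapes are assumptions extrapolated from the test above:

import torch
from pixyz.distributions import Normal

normal = Normal(loc=0, scale=1, features_shape=[2])
x = {'x': torch.zeros(3, 2)}  # batch of 3, features of 2
assert normal.get_log_prob(x, sum_features=True,
                           feature_dims=[-1]).shape == torch.Size([3])
assert normal.get_log_prob(x, sum_features=False).shape == torch.Size([3, 2])
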
Example #5
    def test_sample_mean(self):
        dist = Normal(loc=0, scale=1)
        assert dist.sample(sample_mean=True)['x'] == torch.zeros(1)
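
sample_mean=True returns the distribution's mean rather than a random draw. A minimal sketch with a nonzero loc to make that visible (assumed to generalize from the zero case above):

import torch
from pixyz.distributions import Normal

dist = Normal(loc=2., scale=1.)
assert dist.sample(sample_mean=True)['x'] == torch.tensor([2.])
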
Example #6
    def test_batch_n(self):
        normal = Normal(loc=0, scale=1)
        assert normal.sample(batch_n=3)['x'].shape == torch.Size([3])
Example #7
    def test_sample_mean(self):
        dist = Normal(var=['x'], loc=0, scale=1) * Normal(
            var=['y'], cond_var=['x'], loc='x', scale=1)
        assert dist.sample(sample_mean=True)['y'] == torch.zeros(1)
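
Here sample_mean propagates through the product: x is fixed to its mean, and the conditional mean loc='x' is evaluated at that value. A minimal sketch with a shifted prior, assuming the same propagation:

import torch
from pixyz.distributions import Normal

dist = Normal(var=['x'], loc=1., scale=1.) * Normal(
    var=['y'], cond_var=['x'], loc='x', scale=1.)
sample = dist.sample(sample_mean=True)
assert sample['x'] == torch.ones(1)
assert sample['y'] == torch.ones(1)  # mean of p(y|x) with x at its mean
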
Example #8
def main(smoke_test=False):
    epochs = 2 if smoke_test else 50
    batch_size = 128
    seed = 0

    x_ch = 1
    z_dim = 32

    # Initialize random seeds
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    date_and_time = datetime.datetime.now().strftime('%Y-%m%d-%H%M')
    save_root = f'./results/pixyz/{date_and_time}'
    if not os.path.exists(save_root):
        os.makedirs(save_root)

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    root = '/mnt/hdd/sika/Datasets'
    train_loader, test_loader = make_MNIST_loader(root, batch_size=batch_size)

    # Create the generative and inference models
    p = Generator(x_ch, z_dim).to(device)
    q = Inference(x_ch, z_dim).to(device)

    # Define the prior distribution over the latent variable
    p_prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                     var=['z'], features_shape=[z_dim], name='p_{prior}').to(device)

    # Define the loss function
    loss = (KullbackLeibler(q, p_prior) - Expectation(q, LogProb(p))).mean()

    # Set up the Model API
    model = Model(loss=loss, distributions=[p, q],
                  optimizer=optim.Adam, optimizer_params={"lr": 1e-3})

    x_fixed, _ = next(iter(test_loader))
    x_fixed = x_fixed[:8].to(device)
    z_fixed = p_prior.sample(batch_n=64)['z']

    train_loss_list, test_loss_list = [], []
    for epoch in range(1, epochs + 1):
        train_loss_list.append(learn(model, epoch, train_loader, device, train=True))
        test_loss_list.append(learn(model, epoch, test_loader, device, train=False))

        print(f'    [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f'    [Epoch {epoch}] test  loss {test_loss_list[-1]:.4f}\n')

        # Plot and save the loss curve
        plt.plot(list(range(1, epoch + 1)), train_loss_list, label='train')
        plt.plot(list(range(1, epoch + 1)), test_loss_list, label='test')
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(os.path.join(save_root, 'loss.png'))
        plt.close()

        # Reconstructed images
        x_reconst = reconstruct_image(p, q, x_fixed)
        save_image(torch.cat([x_fixed, x_reconst], dim=0), os.path.join(save_root, f'reconst_{epoch}.png'), nrow=8)

        # Interpolated images
        x_interpol = interpolate_image(p, q, x_fixed)
        save_image(x_interpol, os.path.join(save_root, f'interpol_{epoch}.png'), nrow=8)

        # Generated images (fixed latent variables)
        x_generate = generate_image(p, z_fixed)
        save_image(x_generate, os.path.join(save_root, f'generate_{epoch}.png'), nrow=8)

        # Generated images (random sampling)
        x_sample = sample_image(p_prior, p)
        save_image(x_sample, os.path.join(save_root, f'sample_{epoch}.png'), nrow=8)
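
The learn helper called in the loop above is not defined in this snippet. A minimal sketch of a compatible implementation, assuming pixyz's Model.train/Model.test API (one optimization or evaluation step per batch, returning that batch's loss):

def learn(model, epoch, loader, device, train=True):
    """Hypothetical helper matching Example #8's calls: run one epoch,
    return the average loss per sample (epoch is kept only for the call site)."""
    total = 0.
    for x, _ in loader:
        x = x.to(device)
        # Model.train runs one optimization step; Model.test only evaluates.
        loss = model.train({'x': x}) if train else model.test({'x': x})
        total += float(loss) * x.size(0)
    return total / len(loader.dataset)
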
Example #9
    model = Model(loss=loss,
                  distributions=[p, q],
                  optimizer=optim.Adam,
                  optimizer_params={"lr": 1e-3})
    # print(model)

    x_org, y_org = next(iter(test_loader))

    # Sample data for reconstruction
    x_fixed = x_org[:8].to(device)
    y_fixed = y_org[:8]
    y_fixed = torch.eye(CLASS_SIZE)[y_fixed].to(device)
    y_answers_1 = torch.argmax(y_fixed, dim=1)

    # Sample data for the classifier
    z_sample = prior.sample(batch_n=8)['z'].to(device)
    x_sample = x_org[8:16].to(device)
    y_answers_2 = y_org[8:16]

    z_samples = prior.sample(batch_n=BATCH_SIZE)['z'].to(device)

    train_loss_list = []
    test_loss_list = []
    train_accuracy_list = []
    test_accuracy_list = []
    for epoch in range(1, EPOCHS + 1):
        train_loss = learn(epoch, model, device, train_loader, "Train")
        test_loss = learn(epoch, model, device, test_loader, "Test")
        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)
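
The torch.eye(CLASS_SIZE)[y] idiom used above one-hot encodes integer labels by indexing rows of the identity matrix. A minimal standalone illustration (the CLASS_SIZE value is an assumption):

import torch

CLASS_SIZE = 10  # assumed; MNIST-style labels
y = torch.tensor([3, 7])
y_onehot = torch.eye(CLASS_SIZE)[y]  # row i of the identity = one-hot of i
assert y_onehot.shape == torch.Size([2, CLASS_SIZE])
assert torch.argmax(y_onehot, dim=1).tolist() == [3, 7]
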
Example #10
    q = net.Inference().to(device)

    # prior p(z)
    prior = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0),
                   var=["z"], features_shape=[net.Z_DIM], name="p_{prior}").to(device)

    loss = (KullbackLeibler(q, prior) - Expectation(q, LogProb(p))).mean()
    model = Model(loss=loss, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr": 1e-3})
    # print(model)

    x_fixed, y_fixed = next(iter(test_loader))
    x_fixed = x_fixed[:8].to(device)
    y_fixed = y_fixed[:8]
    y_fixed = torch.eye(CLASS_SIZE)[y_fixed].to(device)

    z_sample = prior.sample(batch_n=64)['z'].to(device)
    y_sample = torch.eye(CLASS_SIZE)[[PLOT_NUMBER] * 64].to(device)

    train_loss_list = []
    test_loss_list = []
    for epoch in range(1, EPOCHS + 1):
        train_loss = learn(epoch, model, device, train_loader, "Train")
        test_loss = learn(epoch, model, device, test_loader, "Test")
        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)

        print(f'    [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f'    [Epoch {epoch}] test  loss {test_loss_list[-1]:.4f}\n')

        # Plot the ELBO
        plot_figure(epoch, train_loss_list, test_loss_list)
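
plot_figure is not defined in this snippet. A minimal sketch of a compatible helper, mirroring the matplotlib plotting code from Example #8 (the function body and save path are assumptions):

import matplotlib.pyplot as plt

def plot_figure(epoch, train_loss_list, test_loss_list):
    """Hypothetical helper: plot train/test ELBO curves up to the current epoch."""
    epochs = list(range(1, epoch + 1))
    plt.plot(epochs, train_loss_list, label='train')
    plt.plot(epochs, test_loss_list, label='test')
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig('loss.png')  # save path is an assumption
    plt.close()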