def test_init_with_scalar_params(self):
    """Check Normal built from scalar ``loc``/``scale``, with and without ``features_shape``."""
    # With an explicit features_shape of [2], a default sample is expected
    # to have shape [1, 2] and the distribution reports features_shape [2].
    with_features = Normal(loc=0, scale=1, features_shape=[2])
    drawn = with_features.sample()['x']
    assert drawn.shape == torch.Size([1, 2])
    assert with_features.features_shape == torch.Size([2])

    # Without features_shape, the sample is expected to have shape [1] and
    # features_shape is expected to be empty.
    without_features = Normal(loc=0, scale=1)
    drawn = without_features.sample()['x']
    assert drawn.shape == torch.Size([1])
    assert without_features.features_shape == torch.Size([])
def test_input_extra_var(self):
    """Check that inputs unrelated to the distribution are tolerated.

    Covers three cases for a Normal over ``x``: sampling with an extra
    ``y`` input, evaluating the log-probability with an extra ``y`` input,
    and sampling when ``x`` itself is supplied.
    """
    normal = Normal(loc=0, scale=1)

    # The extra 'y' entry must appear in the result next to the sampled 'x'.
    assert set(normal.sample({'y': torch.zeros(1)})) == {'x', 'y'}

    # get_log_prob must accept a dict that also carries the extra 'y'.
    log_prob = normal.get_log_prob({
        'y': torch.zeros(1),
        'x': torch.zeros(1)
    })
    assert log_prob.shape == torch.Size([1])

    # Set literal on purpose: the previous ``set(('x'))`` iterated over the
    # *string* 'x' (no trailing comma, so not a tuple) and was only correct
    # because the variable name is a single character.
    assert set(normal.sample({'x': torch.zeros(1)})) == {'x'}
def test_set_option(self):
    """Check that options set on the graph for ``y`` shape the drawn samples.

    Two product distributions p(x|y)p(y) are exercised: one whose prior on
    ``y`` is a Normal, and one whose prior is a FactorizedBernoulli.
    """
    # --- Normal prior on y -------------------------------------------------
    x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
    y_prior = Normal(var=['y'], loc=0, scale=1)
    dist = x_given_y * y_prior

    dist.graph.set_option({'batch_n': 4, 'sample_shape': (2, 3)}, ['y'])
    sample = dist.sample()
    expected_shape = torch.Size([2, 3, 4])
    assert sample['y'].shape == expected_shape
    assert sample['x'].shape == expected_shape

    # Clear the options again before evaluating log-probabilities.
    dist.graph.set_option({}, ['y'])
    summed = dist.get_log_prob(sample, sum_features=True, feature_dims=None)
    assert summed.shape == torch.Size([2])
    unsummed = dist.get_log_prob(sample, sum_features=False)
    assert unsummed.shape == torch.Size([2, 3, 4])

    # --- FactorizedBernoulli prior on y ------------------------------------
    x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
    y_prior = FactorizedBernoulli(var=['y'], probs=torch.tensor([0.3, 0.8]))
    dist = x_given_y * y_prior

    dist.graph.set_option({'batch_n': 3, 'sample_shape': (4, )}, ['y'])
    sample = dist.sample()
    expected_shape = torch.Size([4, 3, 2])
    assert sample['y'].shape == expected_shape
    assert sample['x'].shape == expected_shape

    dist.graph.set_option({}, ['y'])
    summed = dist.get_log_prob(sample, sum_features=True, feature_dims=[-1])
    assert summed.shape == torch.Size([4, 3])
def test_get_log_prob_feature_dims2(self):
    """Check the shape returned by ``get_log_prob`` for several ``feature_dims``."""
    x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
    y_prior = Normal(var=['y'], loc=0, scale=1)
    dist = x_given_y * y_prior

    dist.graph.set_option({'batch_n': 4, 'sample_shape': (2, 3)}, ['y'])
    sample = dist.sample()
    assert sample['y'].shape == torch.Size([2, 3, 4])

    # Reset the option directly on the first factor associated with 'y'.
    y_factors = list(dist.graph._factors_from_variable('y'))
    y_factors[0].option = {}

    # (feature_dims, expected log-prob shape), checked in the original order.
    cases = (
        (None, [2]),
        ([-2], [2, 4]),
        ([0, 1], [4]),
        ([], [2, 3, 4]),
    )
    for feature_dims, expected in cases:
        log_prob = dist.get_log_prob(sample, sum_features=True,
                                     feature_dims=feature_dims)
        assert log_prob.shape == torch.Size(expected)
def test_sample_mean(self):
    """Check that ``sample_mean=True`` on Normal(loc=0, scale=1) gives zero for ``x``.

    NOTE(review): another ``test_sample_mean`` exists in this file; if both
    live in the same class, the later definition replaces this one and this
    test never runs -- confirm they belong to different classes.
    """
    standard_normal = Normal(loc=0, scale=1)
    mean_sample = standard_normal.sample(sample_mean=True)
    expected = torch.zeros(1)
    assert mean_sample['x'] == expected
def test_batch_n(self):
    """Check that ``sample(batch_n=3)`` returns ``x`` with shape [3]."""
    standard_normal = Normal(loc=0, scale=1)
    batch = standard_normal.sample(batch_n=3)
    expected_shape = torch.Size([3])
    assert batch['x'].shape == expected_shape
def test_sample_mean(self):
    """Check that ``sample_mean=True`` on the product p(x)p(y|x) gives zero for ``y``.

    NOTE(review): another ``test_sample_mean`` exists in this file; if both
    live in the same class, only the later definition is kept -- confirm
    they belong to different classes.
    """
    x_prior = Normal(var=['x'], loc=0, scale=1)
    y_given_x = Normal(var=['y'], cond_var=['x'], loc='x', scale=1)
    joint = x_prior * y_given_x

    mean_sample = joint.sample(sample_mean=True)
    expected = torch.zeros(1)
    assert mean_sample['y'] == expected
def main(smoke_test=False):
    """Train the Generator/Inference pair and save diagnostics every epoch.

    The loss minimised is ``KL(q || p_prior) - E_q[log p]`` averaged over the
    batch. After each epoch the loss curves and four image grids
    (reconstruction, interpolation, fixed-latent generation, random samples)
    are written under ``./results/pixyz/<timestamp>``.

    Args:
        smoke_test: When truthy, train for 2 epochs instead of 50.
    """
    epochs = 2 if smoke_test else 50
    batch_size = 128
    seed = 0
    x_ch = 1
    z_dim = 32

    # Initialise the random seeds. ``torch.random.manual_seed`` is the same
    # function as ``torch.manual_seed``, so it is called only once.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    date_and_time = datetime.datetime.now().strftime('%Y-%m%d-%H%M')
    save_root = f'./results/pixyz/{date_and_time}'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_root, exist_ok=True)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    root = '/mnt/hdd/sika/Datasets'
    train_loader, test_loader = make_MNIST_loader(root, batch_size=batch_size)

    # Create the generative model and the inference model.
    p = Generator(x_ch, z_dim).to(device)
    q = Inference(x_ch, z_dim).to(device)

    # Define the prior over the latent variable z.
    p_prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                     var=['z'], features_shape=[z_dim],
                     name='p_{prior}').to(device)

    # Define the loss function.
    loss = (KullbackLeibler(q, p_prior) - Expectation(q, LogProb(p))).mean()

    # Configure the Model API.
    model = Model(loss=loss, distributions=[p, q],
                  optimizer=optim.Adam, optimizer_params={"lr": 1e-3})

    # Fixed inputs reused every epoch so the saved images are comparable.
    x_fixed, _ = next(iter(test_loader))
    x_fixed = x_fixed[:8].to(device)
    z_fixed = p_prior.sample(batch_n=64)['z']

    train_loss_list, test_loss_list = [], []
    for epoch in range(1, epochs + 1):
        train_loss_list.append(
            learn(model, epoch, train_loader, device, train=True))
        test_loss_list.append(
            learn(model, epoch, test_loader, device, train=False))

        print(f'    [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f'    [Epoch {epoch}] test  loss {test_loss_list[-1]:.4f}\n')

        # Plot the loss curves so far and save them (overwritten each epoch).
        epoch_axis = list(range(1, epoch + 1))
        plt.plot(epoch_axis, train_loss_list, label='train')
        plt.plot(epoch_axis, test_loss_list, label='test')
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(os.path.join(save_root, 'loss.png'))
        plt.close()

        # Reconstructed images, saved together with their inputs.
        x_reconst = reconstruct_image(p, q, x_fixed)
        save_image(torch.cat([x_fixed, x_reconst], dim=0),
                   os.path.join(save_root, f'reconst_{epoch}.png'),
                   nrow=8)

        # Interpolated images.
        x_interpol = interpolate_image(p, q, x_fixed)
        save_image(x_interpol,
                   os.path.join(save_root, f'interpol_{epoch}.png'),
                   nrow=8)

        # Generated images (fixed latent variables).
        x_generate = generate_image(p, z_fixed)
        save_image(x_generate,
                   os.path.join(save_root, f'generate_{epoch}.png'),
                   nrow=8)

        # Generated images (random sampling).
        x_sample = sample_image(p_prior, p)
        save_image(x_sample,
                   os.path.join(save_root, f'sample_{epoch}.png'),
                   nrow=8)
model = Model(loss=loss, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr": 1e-3}) # print(model) x_org, y_org = next(iter(test_loader)) # 再構築用サンプルデータ x_fixed = x_org[:8].to(device) y_fixed = y_org[:8] y_fixed = torch.eye(CLASS_SIZE)[y_fixed].to(device) y_answers_1 = torch.argmax(y_fixed, dim=1) # 識別器用サンプルデータ z_sample = prior.sample(batch_n=8)['z'].to(device) x_sample = x_org[8:16].to(device) y_answers_2 = y_org[8:16] z_samples = prior.sample(batch_n=BATCH_SIZE)['z'].to(device) train_loss_list = [] test_loss_list = [] train_accuracy_list = [] test_accuracy_list = [] for epoch in range(1, EPOCHS + 1): train_loss = learn(epoch, model, device, train_loader, "Train") test_loss = learn(epoch, model, device, test_loader, "Test") train_loss_list.append(train_loss) test_loss_list.append(test_loss)
q = net.Inference().to(device) # prior p(z) prior = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0), var=["z"], features_shape=[net.Z_DIM], name="p_{prior}").to(device) loss = (KullbackLeibler(q, prior) - Expectation(q, LogProb(p))).mean() model = Model(loss=loss, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr": 1e-3}) # print(model) x_fixed, y_fixed = next(iter(test_loader)) x_fixed = x_fixed[:8].to(device) y_fixed = y_fixed[:8] y_fixed = torch.eye(CLASS_SIZE)[y_fixed].to(device) z_sample = prior.sample(batch_n=64)['z'].to(device) y_sample = torch.eye(CLASS_SIZE)[[PLOT_NUMBER] * 64].to(device) train_loss_list = [] test_loss_list = [] for epoch in range(1, EPOCHS + 1): train_loss = learn(epoch, model, device, train_loader, "Train") test_loss = learn(epoch, model, device, test_loader, "Test") train_loss_list.append(train_loss) test_loss_list.append(test_loss) print(f' [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}') print(f' [Epoch {epoch}] test loss {test_loss_list[-1]:.4f}\n') # ELBOを描画する。 plot_figure(epoch, train_loss_list, test_loss_list)