Beispiel #1
0
# Name the output directory after the run's hyper-parameters so different
# configurations write to different folders.
save_dir = 'out/gmm/%d_%d_%d_%d_%.2f' % (
    INPUT_SIZE,
    OUTPUT_SIZE,
    LATENT_SIZE,
    N_INFERENCE,
    ALPHA,
)
os.makedirs(save_dir, exist_ok=True)

# The CSV writer is created lazily inside the training loop (it checks
# `csv_writer is None`); the file is deliberately left open for that loop.
csv_writer = None
file = open(os.path.join(save_dir, 'progress.csv'), mode='wt')

for it in range(int(2e4)):
    inds = np.random.choice(nb_train, BATCH_SIZE, replace=False)
    x, y, prob = xs_train[inds], ys_train[inds], probs_train[inds]
    loss = cvae.update(x[:, None], y[:, None])

    if it % int(1e3) == 0:
        inds = np.random.choice(nb_train, BATCH_SIZE, replace=False)
        x, y, prob = xs_train[inds], ys_train[inds], probs_train[inds]
        prob_pred = []
        for _ in range(N_INFERENCE):
            mean, log_var = cvae.sample(x[:, None])
            _log_prob = gaussian_log_density(mean, log_var, torch.tensor(y[:, None], dtype=torch.float32))
            _prob = torch.exp(_log_prob).detach().numpy()
            prob_pred.append(_prob)
        prob_pred = np.mean(np.asarray(prob_pred), axis=0)
        test_loss = np.mean((prob - prob_pred)**2)
        print('Iter-{}; Train Loss: {:.4f}, Test Loss:{:.4f} fps:{}'.format(it, loss.data, test_loss,
                                                                            (it+1) // (time.time() - time_st)))
        print(np.round(prob[:3], 3), np.round(prob_pred[:3], 3))

        logs = dict(
            step=it,
            train_loss=loss.data.numpy(),
            test_loss=test_loss,
        )
        if csv_writer is None:
Beispiel #2
0
    decoder_layer_sizes=(200, 200),
    dataset_name='mnist'
)

# Wall-clock start of training; used below to report iterations per second.
time_st = time.time()

for it in range(int(2e4)):
    # NOTE(review): the unpacking order suggests next_batch returns
    # (images, labels), so `x` is the conditioning batch and `y` the data the
    # CVAE models -- confirm against the data loader.
    y, x = mnist.train.next_batch(BATCH_SIZE)

    loss = cvae.update(x, y)
    if it % int(1e3) == 0:
        # fps = iterations completed (it + 1) per elapsed second, floored.
        print('Iter-{}; Loss: {:.4f}, fps:{}'.format(it, loss.data, (it+1) // (time.time() - time_st)))

        # Build a one-hot conditioning batch: every row gets a 1 in the same
        # randomly chosen column 0..9 (presumably the digit class -- confirm
        # INPUT_SIZE >= 10).
        x = np.zeros(shape=[BATCH_SIZE, INPUT_SIZE], dtype=np.float32)
        x[:, np.random.randint(0, 10)] = 1.
        samples = cvae.sample(x).data.numpy()[:16]

        # Lay out up to 16 samples on a tight 4x4 grid with no axes.
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok=True replaces the check-then-create pattern, which could
        # race if another process created the directory in between.
        # NOTE(review): `fig` is never closed in the visible code; if it is not
        # saved and plt.close()d afterwards, figures accumulate every 1000 iters.
        os.makedirs('out/mnist', exist_ok=True)
Beispiel #3
0
import torch
import cv2

from cvae import Data_Manager,CVAE

# --------------------------------------------------------------------------------
if __name__ == '__main__':
    # Train a C-VAE, reload weights from 'cvae_state.pt', and write one
    # generated sample to 'sample_test.png'.
    device_id = '0'

    # Use the requested CUDA device when one is available, otherwise the CPU.
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(device_id))

    # Data Manager
    data_manager = Data_Manager()

    # C-VAE
    model = CVAE(data_manager, device, epochs=10)
    model.fit()

    # Single Sample
    # NOTE(review): this script never writes 'cvae_state.pt' itself; it assumes
    # fit() (or an earlier run) saved it there -- confirm.
    model.load_model('cvae_state.pt')
    sample = model.sample(idx=5)

    # cv2.imwrite can signal failure (e.g. an unwritable path) by returning
    # False instead of raising, so check the result rather than exit silently.
    output_path = 'sample_test.png'
    if not cv2.imwrite(output_path, sample):
        raise RuntimeError('cv2.imwrite failed to write {}'.format(output_path))