Example 1
def main(args):
    print("Loading data")
    dataset = args.data.rstrip('/').split('/')[-1]
    torch.cuda.set_device(args.cuda)
    device = args.device
    if dataset == 'mnist':
        train_loader, test_loader = get_mnist(args.batch_size, 'data/mnist')
        num = 10
    elif dataset == 'fashion':
        train_loader, test_loader = get_fashion_mnist(args.batch_size,
                                                      'data/fashion')
        num = 10
    elif dataset == 'svhn':
        train_loader, test_loader, _ = get_svhn(args.batch_size, 'data/svhn')
        num = 10
    elif dataset == 'stl':
        train_loader, test_loader, _ = get_stl10(args.batch_size, 'data/stl10')
        num = 10  # STL-10 also has 10 classes
    elif dataset == 'cifar':
        train_loader, test_loader = get_cifar(args.batch_size, 'data/cifar')
        num = 10
    elif dataset == 'chair':
        train_loader, test_loader = get_chair(args.batch_size,
                                              '~/data/rendered_chairs')
        num = 1393
    elif dataset == 'yale':
        train_loader, test_loader = get_yale(args.batch_size, 'data/yale')
        num = 38
    model = VAE(28 * 28, args.code_dim, args.batch_size, num,
                dataset).to(device)
    phi = nn.Sequential(
        nn.Linear(args.code_dim, args.phi_dim),
        nn.LeakyReLU(0.2, True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_phi = torch.optim.Adam(phi.parameters(), lr=args.lr)
    criterion = nn.MSELoss(reduction='sum')
    for epoch in range(args.epochs):
        re_loss = 0
        kl_div = 0
        size = len(train_loader.dataset)
        for data, target in train_loader:
            data, target = data.squeeze(1).to(device), target.to(device)
            c = F.one_hot(target.long(), num_classes=num).float()
            output, q_z, p_z, z = model(data, c)
            hsic = HSIC(phi(z), target.long(), num)
            if dataset in ('mnist', 'fashion'):
                reloss = recon_loss(output, data.view(-1, 28 * 28))
            else:
                reloss = criterion(output, data)
            kld = total_kld(q_z, p_z)
            loss = reloss + kld + args.c * hsic

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer_phi.zero_grad()
            neg = -HSIC(phi(z.detach()), target.long(), num)
            neg.backward()
            optimizer_phi.step()

            re_loss += reloss.item() / size
            kl_div += kld.item() / size
        print('-' * 50)
        print(
            "Epoch {} | re loss {:5.2f} | kl div {:5.2f} | hsic {:5.2f}".format(
                epoch, re_loss, kl_div, hsic.item()))
    for data, target in test_loader:
        data, target = data.squeeze(1).to(device), target.to(device)
        c = F.one_hot(target.long(), num_classes=num).float()
        output, _, _, z = model(data, c)
        break
    if dataset in ('mnist', 'fashion'):
        img_size = [data.size(0), 1, 28, 28]
    else:
        img_size = [data.size(0), 3, 32, 32]
    images = [data.view(img_size)[:30].cpu()]
    for i in range(10):
        c = F.one_hot(torch.ones(z.size(0)).long() * i,
                      num_classes=num).float().to(device)
        output = model.decoder(torch.cat((z, c), dim=-1))
        images.append(output.view(img_size)[:30].cpu())
    images = torch.cat(images, dim=0)
    save_image(images,
               'imgs/recon_c{}_{}.png'.format(int(args.c), dataset),
               nrow=30)
    torch.save(model.state_dict(),
               'vae_c{}_{}.pt'.format(int(args.c), dataset))
    # z = p_z.sample()
    # for i in range(10):
    #     c = F.one_hot(torch.ones(z.size(0)).long()*i, num_classes=10).float().to(device)
    #     output = model.decoder(torch.cat((z, c), dim=-1))
    #     n = min(z.size(0), 8)
    #     save_image(output.view(z.size(0), 1, 28, 28)[:n].cpu(), 'imgs/recon_{}.png'.format(i), nrow=n)
    if args.tsne:
        datas, targets = [], []
        for i, (data, target) in enumerate(test_loader):
            datas.append(data), targets.append(target)
            if i >= 5:
                break
        data, target = torch.cat(datas, dim=0), torch.cat(targets, dim=0)
        c = F.one_hot(target.long(), num_classes=num).float()
        _, _, _, z = model(data.to(device), c.to(device))
        z, target = z.detach().cpu().numpy(), target.cpu().numpy()
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        z_2d = tsne.fit_transform(z)
        plt.figure(figsize=(6, 5))
        plot_embedding(z_2d, target)
        plt.savefig('tsnes/tsne_c{}_{}.png'.format(int(args.c), dataset))
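The `main` function above expects an `args` namespace. A minimal sketch of a matching command-line entry point follows; the argument names mirror the attributes `main` reads, while every default value is only an illustrative assumption.

# Assumed argparse entry point; argument names mirror the attributes read by
# main() above, and all default values are placeholders.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='data/mnist')        # last path component selects the dataset
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--code_dim', type=int, default=20)    # size of the latent code
    parser.add_argument('--phi_dim', type=int, default=64)     # width of the phi network
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--c', type=float, default=1.0)        # weight of the HSIC penalty
    parser.add_argument('--cuda', type=int, default=0)         # GPU index
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--tsne', action='store_true')         # also plot a t-SNE of the latents
    main(parser.parse_args())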
Example 2
def main(args):
    conf = None
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
        conf = config['combine']
        model_params = config['model']
        preprocess_params = config['preprocessor']
    path = conf['save_path']

    model = VAE(model_params['roll_dim'], model_params['hidden_dim'],
                model_params['infor_dim'], model_params['time_step'], 12)

    model.load_state_dict(torch.load(conf['model_path']))
    if torch.cuda.is_available():
        print('Using: ',
              torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')
    model.eval()
    pitch_path = conf['p_path'] + ".txt"
    rhythm_path = conf['r_path'] + ".txt"
    #chord_path = conf['chord_path'] + ".txt"
    name1 = pitch_path.split("/")[-3]
    name2 = rhythm_path.split("/")[-3]
    name = name1 + "+" + name2 + ".mid"
    name2 = name1 + "+" + name2 + ".txt"

    pitch = np.loadtxt(pitch_path)
    print(pitch)
    rhythm = np.loadtxt(rhythm_path)
    print(rhythm)

    print("Importing " + name1 + " pitch and " + name2 + " rhythm")

    #line_graph(pitch,rhythm)
    #bar_graph(pitch,rhythm)

    pitch = torch.from_numpy(pitch).float()
    rhythm = torch.from_numpy(rhythm).float()
    if torch.cuda.is_available():
        # keep the inputs on the same device as the model loaded above
        pitch, rhythm = pitch.cuda(), rhythm.cuda()
    recon = model.decoder(pitch, rhythm)

    recon = torch.squeeze(recon, 0)
    recon = mf._sampling(recon)
    recon = recon.detach().cpu().numpy()
    length = int(torch.sum(rhythm).item())
    recon = recon[:length]
    # print the distribution of the generated notes
    note = recon[:, :-2]
    note = np.nonzero(note)[1]
    note = np.bincount(note, minlength=34).astype(float)
    recon = mf.modify_pianoroll_dimentions(recon,
                                           preprocess_params['low_crop'],
                                           preprocess_params['high_crop'],
                                           "add")

    #bar_graph(pitch,rhythm)
    mf.numpy_to_midi(recon, 120, path, name,
                     preprocess_params['smallest_note'])

    #pitch_rhythm(recon,path,name2) # write pitch information

    print("combine succeed")
Example 3
    with torch.no_grad():
        model.eval()
        frame_idx = 0

        for test_data in test_loader:
            test_data = test_data.to(device)
            # test_data -= mean[None, ...]
            rec, penalty = model(test_data)
            mu, log_var = model.encoder(test_data)
            mus.append(mu.clone().detach().cpu().numpy())
            log_vars.append(log_var.clone().detach().cpu().numpy())
            if reparam:
                latent = reparameterize(mu, log_var)
            else:
                latent = mu
            rec = model.decoder(latent).cpu().numpy()
            recs.append(rec)
            if (epoch + 1) % 10 == 0:
                fnames = []
                print("Plotting reconstructions...")
                test_data = test_data.cpu().numpy()
                for local_frame_idx in range(test_data.shape[0]):
                    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2,
                                                   figsize=(16, 8))
                    ax0.imshow(test_data[local_frame_idx, 0, :, :],
                               cmap=plt.cm.RdBu)
                    ax0.set_title("Original data (voxel-wise standardized)")

                    ax1.imshow(rec[local_frame_idx, 0, :, :], cmap=plt.cm.RdBu)
                    ax1.set_title("Reconstructed")
                    fname = f"reconstructions_{epoch:05d}_{frame_idx:05d}.png"
Example 4
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image

# VAE is the project's model class, e.g. imported via `from model import VAE`
image_size = 784
batch_size = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE()
model.load_state_dict(torch.load('trained_model.ckpt', map_location=device))
model.to(device)
model.eval()
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)

# Generate one batch of testing sample
test_sample, _ = next(iter(test_loader))
test_sample = test_sample.to(device).view(-1, 784)

with torch.no_grad():
    # Sampled Image from decoder
    z = torch.randn(batch_size, 20).to(device)
    sampled = model.decoder(z).view(-1, 1, 28, 28)
    save_image(sampled, 'sampled.png')

    # Reconstructed from test sample
    reconstructed, _, _ = model(test_sample)
    reconstructed = reconstructed.view(-1, 1, 28, 28)
    test_sample = test_sample.view(-1, 1, 28, 28)
    x_concat = torch.cat([reconstructed, test_sample], dim=2)
    print(x_concat.shape)
    save_image(x_concat, 'reconstructed.png')
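The `VAE` class used in this script is defined elsewhere. Below is a minimal sketch consistent with how it is called here: 784-dimensional flattened MNIST inputs, a 20-dimensional latent, a `decoder` that maps latents back to pixels, and a forward pass returning the reconstruction together with `mu` and `log_var`. The hidden width and layer layout are assumptions.

# Minimal VAE consistent with the calls above; architecture details are assumed.
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encoder(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def decoder(self, z):
        return torch.sigmoid(self.fc3(F.relu(self.fc2(z))))

    def forward(self, x):
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = mu + std * torch.randn_like(std)  # reparameterization trick
        return self.decoder(z), mu, log_var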