Example #1
def generating(conf):

    net = VRNN(conf.x_dim, conf.h_dim, conf.z_dim)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    # Wrap in DataParallel before loading: the checkpoint keys carry the
    # 'module.' prefix that DataParallel adds.
    net = torch.nn.DataParallel(net, device_ids=conf.device_ids)
    net.load_state_dict(torch.load(conf.checkpoint_path,
                                   map_location=device))
    print('Restored model from ' + conf.checkpoint_path)

    with torch.no_grad():
        n = 15  # figure with 15x15 digits
        digit_size = 28
        figure = np.zeros((digit_size * n, digit_size * n))

        # Sample one digit per grid cell and paste it into the canvas.
        for i in range(n):
            for j in range(n):
                x_decoded = net.module.sampling(digit_size, device)
                x_decoded = x_decoded.cpu().numpy()
                digit = x_decoded.reshape(digit_size, digit_size)
                figure[i * digit_size:(i + 1) * digit_size,
                       j * digit_size:(j + 1) * digit_size] = digit

        plt.figure(figsize=(10, 10))
        plt.imshow(figure, cmap='Greys_r')
        plt.show()
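A minimal sketch of how generating might be invoked; the layer sizes, device
list, and checkpoint path below are assumptions for illustration, not values
from the original project:

from types import SimpleNamespace

conf = SimpleNamespace(x_dim=28, h_dim=100, z_dim=16,      # assumed sizes
                       device_ids=[0],
                       checkpoint_path='../checkpoint/Epoch_50.pth')
generating(conf)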
Example #2
def main(opt):

    device = torch.device(opt.device if torch.cuda.is_available() else "cpu")
    train_set = dataset.SBU_Dataset(opt, training=True)
    test_set = dataset.SBU_Dataset(opt, training=False)
    iters_per_epoch = math.ceil(len(train_set) / opt.batch_size)
    max_epoch = math.ceil(opt.max_iter / iters_per_epoch)

    train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True,
                             num_workers=opt.job, pin_memory=True,
                             collate_fn=dataset.collate_fn, drop_last=False)
    test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=False,
                             num_workers=opt.job, pin_memory=True,
                             collate_fn=dataset.collate_fn, drop_last=False)

    writer = SummaryWriter()
    print("loading the model.......")
    net = VRNN(opt)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(0.9, 0.999))
    best_loss = float('inf')  # best test loss seen so far
    bar = tqdm(range(max_epoch))
    for epoch in bar:
        bar.set_description('train epoch %06d' % epoch)
        train(train_loader, net, device, optimizer, writer, epoch)
        test_loss = test(test_loader, net, device, writer, epoch)
        if test_loss < best_loss:
            best_loss = test_loss
            save_model(net, optimizer, epoch)

    writer.close()
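save_model is defined elsewhere in this project. Judging by Example #3, which
restores './models/bestmodel' and reads a 'state_dict' key, it plausibly looks
like the sketch below; the extra keys are assumptions:

def save_model(net, optimizer, epoch):
    # Bundle everything needed to resume; Example #3 only consumes 'state_dict'.
    torch.save({'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch},
               './models/bestmodel')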
Example #3
def generate(opt, txt_path):

    device = torch.device(opt.device if torch.cuda.is_available() else "cpu")
    test_set = dataset.SBU_Dataset(opt, training=False)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=opt.job,
                             pin_memory=True,
                             collate_fn=dataset.collate_fn,
                             drop_last=False)
    data_mean = opt.train_mean
    data_std = opt.train_std
    net = VRNN(opt)
    net.to(device)
    net.load_state_dict(
        torch.load('./models/bestmodel', map_location=device)["state_dict"])
    net.eval()

    with torch.no_grad():
        for n_iter, (batch_x_pack, batch_y_pack) in enumerate(test_loader):
            batch_x_pack = batch_x_pack.float().to(device)
            batch_y_pack = batch_y_pack.float().to(device)
            output = net(batch_x_pack, batch_y_pack)
            _, _, _, _, decoder_all_1, decoder_all_2, _, _, _, _ = output
            # decoder_all_1: list of length max_step; each element has shape [1, 45]
            seq = []

            for t in range(len(decoder_all_1)):
                joints = np.concatenate(
                    (decoder_all_1[t].squeeze(dim=0).cpu().numpy(),
                     decoder_all_2[t].squeeze(dim=0).cpu().numpy()),
                    axis=0)
                joints = utils.unNormalizeData(joints, data_mean, data_std)
                seq.append(joints)
            np.savetxt(os.path.join(txt_path, '%03d.txt' % (n_iter + 1)),
                       np.array(seq),
                       fmt="%.4f",
                       delimiter=',')
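utils.unNormalizeData is likewise external. If the joints were standardized as
(x - mean) / std during preprocessing, the inverse is a one-liner, sketched
here under that assumption:

def unNormalizeData(normalized, data_mean, data_std):
    # Assumed inverse of per-dimension standardization; the real helper may
    # also guard against dimensions with zero variance.
    return normalized * data_std + data_mean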
Example #4
def train(conf):

    train_loader, test_loader = load_dataset(512)
    net = VRNN(conf.x_dim, conf.h_dim, conf.z_dim)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.manual_seed_all(112858)
    net.to(device)
    net = torch.nn.DataParallel(net, device_ids=[0, 1])
    if conf.restore:
        net.load_state_dict(
            torch.load(conf.checkpoint_path, map_location=device))
        print('Restored model from ' + conf.checkpoint_path)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for ep in range(1, conf.train_epoch + 1):
        prog = Progbar(target=len(train_loader))  # batches per epoch
        print("At epoch:{}".format(str(ep)))
        for i, (data, target) in enumerate(train_loader):
            data = data.squeeze(1)
            data = (data / 255).to(device)
            package = net(data)
            loss = Loss(package, data)
            net.zero_grad()
            loss.backward()
            _ = torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
            optimizer.step()
            prog.update(i, exact=[("Training Loss", loss.item())])

        with torch.no_grad():
            x_decoded = net.module.sampling(conf.x_dim, device)
            x_decoded = x_decoded.cpu().numpy()
            digit = x_decoded.reshape(conf.x_dim, conf.x_dim)
            plt.imshow(digit, cmap='Greys_r')
            plt.pause(1e-6)

        if ep % conf.save_every == 0:
            torch.save(net.state_dict(),
                       '../checkpoint/Epoch_' + str(ep) + '.pth')
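Loss(package, data) is not shown. The standard VRNN objective sums, per
timestep, a reconstruction term and the KL divergence between the approximate
posterior and the learned prior; the Gaussian KL term is usually implemented
as below (a sketch of the textbook formula, not necessarily this project's
code):

def kld_gauss(mean_q, std_q, mean_p, std_p):
    # KL( N(mean_q, std_q^2) || N(mean_p, std_p^2) ) for diagonal Gaussians,
    # summed over all dimensions.
    kld = (2 * torch.log(std_p / std_q)
           + (std_q.pow(2) + (mean_q - mean_p).pow(2)) / std_p.pow(2)
           - 1)
    return 0.5 * torch.sum(kld)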
Example #5
def run(config, device):  # assumed signature; the excerpt begins mid-function

    dir_name = mk_dir(config.data + 'experiment')

    print(config, "DEVICE", device)

    data = read_data('data/pianorolls/{}.pkl'.format(config.data))

    train_data, test_data = data2seq(data=data,
                                     split='train',
                                     seq_len=config.seq_len)

    if config.model == "VRNN":
        model = VRNN(config, device)
    else:
        raise NotImplementedError(config.model)

    model.to(device)

    epoch = 0

    while epoch < config.epochs:

        train_loader = iter(train_data)

        # running sums for the three loss terms over the epoch
        RANGE_LOSS1 = 0
        RANGE_LOSS2 = 0
        RANGE_LOSS3 = 0

        for idx, train_mat in train_loader:

            if idx % 20 == 0:
                print("{}/{} BATCH".format(idx + 1, len(train_data)))
Example #6

torch.manual_seed(seed)
plt.ion()

# initialize datasets, model, and optimizer

train_loader = torch.utils.data.DataLoader(datasets.MNIST(
    'data', train=True, download=True, transform=transforms.ToTensor()),
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(datasets.MNIST(
    'data', train=False, transform=transforms.ToTensor()),
                                          batch_size=batch_size,
                                          shuffle=True)

model = VRNN(x_dim, h_dim, z_dim, n_layers)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, n_epochs + 1):

    # training and testing passes
    train(epoch)
    test(epoch)

    # checkpoint at epoch 1 and every save_every epochs thereafter
    if epoch % save_every == 1:
        fn = 'saves/vrnn_state_dict_' + str(epoch) + '.pth'
        torch.save(model.state_dict(), fn)
        print('Saved model to ' + fn)
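train(epoch) and test(epoch) are defined earlier in the same script. A minimal
sketch of the training step, assuming model(data) returns the summed KL and
reconstruction losses as in common VRNN implementations; the return signature,
tensor layout, and clip value are assumptions:

def train(epoch):
    model.train()
    for batch_idx, (data, _) in enumerate(train_loader):
        # (batch, 1, 28, 28) -> (28, batch, 28): each image becomes a
        # length-28 sequence of 28-dimensional rows (layout assumed).
        data = data.squeeze(1).transpose(0, 1).to(device)
        optimizer.zero_grad()
        kld_loss, nll_loss, _, _ = model(data)  # assumed return signature
        loss = kld_loss + nll_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)  # assumed clip
        optimizer.step()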