Example #1
def generating(conf):

    net = VRNN(conf.x_dim, conf.h_dim, conf.z_dim)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    net = torch.nn.DataParallel(net, device_ids=conf.device_ids)
    net.load_state_dict(torch.load(conf.checkpoint_path,
                                   map_location=device))
    net.eval()
    print('Restored model from ' + conf.checkpoint_path)

    with torch.no_grad():
        n = 15  # figure with 15x15 digits
        digit_size = 28
        figure = np.zeros((digit_size * n, digit_size * n))

        for i in range(n):
            for j in range(n):
                x_decoded = net.module.sampling(digit_size, device)
                x_decoded = x_decoded.cpu().numpy()
                digit = x_decoded.reshape(digit_size, digit_size)
                figure[i * digit_size:(i + 1) * digit_size,
                       j * digit_size:(j + 1) * digit_size] = digit

        plt.figure(figsize=(10, 10))
        plt.imshow(figure, cmap='Greys_r')
        plt.show()
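
A hypothetical way to drive this function. The `conf` fields below are exactly the ones `generating()` reads; the dimensions mirror the hyperparameters of Example #4 below, and the checkpoint path is illustrative only:

from argparse import Namespace

# Hypothetical configuration; every value here is an assumption for illustration.
conf = Namespace(x_dim=28, h_dim=100, z_dim=16,
                 device_ids=[0],
                 checkpoint_path='../checkpoint/Epoch_10.pth')
generating(conf)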
Example #2
def generate(opt, txt_path):

    device = torch.device(opt.device if torch.cuda.is_available() else "cpu")
    test_set = dataset.SBU_Dataset(opt, training=False)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=opt.job,
                             pin_memory=True,
                             collate_fn=dataset.collate_fn,
                             drop_last=False)
    data_mean = opt.train_mean
    data_std = opt.train_std
    net = VRNN(opt)
    net.to(device)
    net.load_state_dict(
        torch.load('./models/bestmodel', map_location=device)["state_dict"])
    net.eval()

    with torch.no_grad():
        for n_iter, [batch_x_pack, batch_y_pack] in enumerate(test_loader, 0):
            batch_x_pack = batch_x_pack.float().to(device)
            batch_y_pack = batch_y_pack.float().to(device)
            output = net(batch_x_pack, batch_y_pack)
            _, _, _, _, decoder_all_1, decoder_all_2, _, _, _, _ = output
            # decoder_all_1: list of length max_step; each element has shape [1, 45]
            seq = []

            for t in range(len(decoder_all_1)):
                joints = np.concatenate(
                    (decoder_all_1[t].squeeze(dim=0).cpu().numpy(),
                     decoder_all_2[t].squeeze(dim=0).cpu().numpy()),
                    axis=0)
                joints = utils.unNormalizeData(joints, data_mean, data_std)
                seq.append(joints)
            np.savetxt(os.path.join(txt_path, '%03d.txt' % (n_iter + 1)),
                       np.array(seq),
                       fmt="%.4f",
                       delimiter=',')
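
`utils.unNormalizeData` is not shown in this snippet. A minimal sketch, assuming it simply inverts the per-dimension z-score normalization applied at training time (the name and signature follow the call above):

import numpy as np

def unNormalizeData(normalized, data_mean, data_std):
    # Invert z-score normalization: x = x_norm * std + mean.
    # Assumes data_mean and data_std are 1-D arrays matching the joint vector.
    return np.asarray(normalized) * data_std + data_mean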
Example #3
def train(conf):

    train_loader, test_loader = load_dataset(512)
    net = VRNN(conf.x_dim, conf.h_dim, conf.z_dim)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.manual_seed_all(112858)
    net.to(device)
    net = torch.nn.DataParallel(net, device_ids=[0, 1])
    if conf.restore:
        net.load_state_dict(
            torch.load(conf.checkpoint_path, map_location=device))
        print('Restored model from ' + conf.checkpoint_path)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for ep in range(1, conf.train_epoch + 1):
        prog = Progbar(target=117)  # hardcoded number of batches per epoch
        print("At epoch: {}".format(ep))
        for i, (data, target) in enumerate(train_loader):
            data = data.squeeze(1)
            data = (data / 255).to(device)
            package = net(data)
            loss = Loss(package, data)
            net.zero_grad()
            loss.backward()
            _ = torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
            optimizer.step()
            prog.update(i, exact=[("Training Loss", loss.item())])

        with torch.no_grad():
            x_decoded = net.module.sampling(conf.x_dim, device)
            x_decoded = x_decoded.cpu().numpy()
            digit = x_decoded.reshape(conf.x_dim, conf.x_dim)
            plt.imshow(digit, cmap='Greys_r')
            plt.pause(1e-6)

        if ep % conf.save_every == 0:
            # name the checkpoint after the epoch that just finished
            torch.save(net.state_dict(),
                       '../checkpoint/Epoch_' + str(ep) + '.pth')
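
`Loss` is not defined in this snippet. A minimal sketch of the usual VRNN objective, under the assumption that `net(data)` returns the KL and reconstruction terms already summed over time steps:

def Loss(package, data):
    # Negative ELBO: reconstruction loss plus KL divergence.
    # The structure of `package` is assumed; `data` is unused here because
    # the reconstruction term is assumed to be computed in the forward pass.
    kld_loss, nll_loss = package
    return nll_loss + kld_loss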
Example #4
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from model import VRNN

# hyperparameters
x_dim = 28
h_dim = 100
z_dim = 16
n_layers = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

state_dict = torch.load('saves/vrnn_state_dict_41.pth', map_location=device)
model = VRNN(x_dim, h_dim, z_dim, n_layers)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

with torch.no_grad():
    sample = model.sample(28 * 6)
plt.imshow(sample.cpu().numpy(), cmap='gray')
plt.show()
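
For context, a checkpoint in the format this script loads would have been written during training with a plain state-dict save; a sketch (the `epoch` variable and the per-epoch naming are assumptions inferred from the `_41` suffix above):

torch.save(model.state_dict(), 'saves/vrnn_state_dict_{}.pth'.format(epoch))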
Example #5
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    #val_loader = get_loader('./data/val_resized2014/', './data/annotations/captions_val2014.json',
    #                         vocab, transform, 1, False, 1)

    start_epoch = 0

    encoder_state = args.encoder
    decoder_state = args.decoder

    # Build the models
    encoder = EncoderCNN(args.embed_size)
    if not args.train_encoder:
        encoder.eval()
    decoder = VRNN(args.embed_size, args.hidden_size, len(vocab),
                   args.latent_size, args.num_layers)

    if args.restart:
        encoder_state, decoder_state = 'new', 'new'

    if encoder_state == '':
        encoder_state = 'new'
    if decoder_state == '':
        decoder_state = 'new'

    print("Using encoder: {}".format(encoder_state))
    print("Using decoder: {}".format(decoder_state))

    try:
        start_epoch = int(float(decoder_state.split('-')[1]))
    except (IndexError, ValueError):
        pass

    if encoder_state != 'new':
        encoder.load_state_dict(torch.load(encoder_state))
    if decoder_state != 'new':
        decoder.load_state_dict(torch.load(decoder_state))

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)
    """ Make logfile and log output """
    with open(args.model_path + args.logfile, 'a+') as f:
        f.write("Using encoder: new\nUsing decoder: new\n\n")

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    # Optimizer
    cross_entropy = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    batch_loss = []
    batch_loss_det = []
    batch_kl = []
    batch_ml = []
    batch_acc = []

    # Train the Models
    total_step = len(data_loader)
    for epoch in range(start_epoch, args.num_epochs):
        for i, (images, captions, lengths, _, _) in enumerate(data_loader):

            # get lengths excluding <start> symbol
            lengths = [l - 1 for l in lengths]

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)

            # all captions are assumed to be long enough to split at z_step
            assert min(lengths) > args.z_step + 2

            # get targets from captions (excluding <start> tokens)
            #targets = pack_padded_sequence(captions[:,1:], lengths, batch_first=True)[0]
            targets_var = captions[:, args.z_step + 1]
            targets_det = pack_padded_sequence(
                captions[:, args.z_step + 2:],
                [l - args.z_step - 1 for l in lengths],
                batch_first=True)[0]

            # Get prior and approximate distributions
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            prior, q_z, q_x, det_x = decoder(features,
                                             captions,
                                             lengths,
                                             z_step=args.z_step)

            # Calculate KL Divergence
            kl = torch.mean(kl_divergence(*q_z + prior))

            # Get marginal likelihood from log likelihood of the correct symbol
            index = (torch.cuda.LongTensor(range(q_x.shape[0])), targets_var)
            ml = torch.mean(q_x[index])

            # Get Cross-Entropy loss for deterministic decoder
            ce = cross_entropy(det_x, targets_det)

            elbo = ml - kl
            loss_var = -elbo

            loss_det = ce

            loss = loss_var + loss_det

            batch_loss.append(loss.item())
            batch_loss_det.append(loss_det.item())
            batch_kl.append(kl.item())
            batch_ml.append(ml.item())

            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch, args.num_epochs, i, total_step, loss.item(),
                       np.exp(loss.item())))

                with open(args.model_path + args.logfile, 'a') as f:
                    f.write(
                        'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f\n'
                        % (epoch, args.num_epochs, i, total_step, loss.item(),
                           np.exp(loss.item())))

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                if args.train_encoder:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join(args.model_path,
                                     'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                with open(args.model_path + 'training_loss.pkl', 'wb') as f:
                    pickle.dump(batch_loss, f)
                with open(args.model_path + 'training_val.pkl', 'wb') as f:
                    pickle.dump(batch_acc, f)

    with open(args.model_path + args.logfile, 'a') as f:
        f.write("Training finished at {} .\n\n".format(str(datetime.now())))