Example #1
def test_model():
    args = get_setup_args()
    word_vectors = util.torch_from_json(args.word_emb_file)
    with open(args.char2idx_file, "r") as f:
        char2idx = json_load(f)
    model = QANet(word_vectors, char2idx)
    # Random word/char index batches; ids 0 and 1 are reserved
    # (padding / OOV in the starter vocabulary), so real ids start at 2
    cw_idxs = torch.randint(2, 1000, (64, 374))
    cc_idxs = torch.randint(2, 50, (64, 374, 200))
    qw_idxs = torch.randint(2, 1000, (64, 70))
    qc_idxs = torch.randint(2, 50, (64, 70, 200))
    cw_idxs[:, 0] = 1   # force a real (OOV) token at the start of each context
    cw_idxs[3, -1] = 0  # inject trailing padding to exercise the mask
    qw_idxs[:, 0] = 1
    qw_idxs[3, -1] = 0
    out = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
    print(out)
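A note on the index pokes above: they exercise the padding masks the model
derives from the word ids. A minimal sketch of that derivation, assuming index
0 is the padding token (as in the starter vocabulary):

    import torch

    cw_idxs = torch.randint(2, 1000, (64, 374))
    cw_idxs[3, -1] = 0        # simulate padding at the end of one context
    c_mask = cw_idxs != 0     # True for real tokens, False for padding
    lengths = c_mask.sum(-1)  # per-example context lengths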
Example #2
def test_input_embedding():
    args = get_setup_args()
    d_char = 200
    word_dropout = 0.1
    char_dropout = 0.05
    with open(args.char2idx_file, "r") as f:
        char2idx = json_load(f)
    hidden_size = 500
    highway_dropout = 0.1
    word_vectors = util.torch_from_json(args.word_emb_file)
    input_embedding = InputEmbedding(word_vectors, d_char, char2idx,
                                     hidden_size, word_dropout, char_dropout,
                                     highway_dropout)

    word_inputs = torch.tensor([[1, 2, 0], [1, 2, 4]], dtype=torch.long)
    char_inputs = torch.tensor([[[1, 2, 2, 0], [1, 3, 2, 3], [0, 0, 0, 0]],
                                [[1, 5, 2, 0], [1, 3, 6, 3], [3, 4, 2, 1]]],
                               dtype=torch.long)
    emb = input_embedding(word_inputs, char_inputs)
    with open('input_emb.pickle', 'wb') as pickle_out:
        pickle.dump(emb, pickle_out)
    assert emb.size() == (2, 3, 500)
    return emb
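Reading the pickled embedding back is symmetric; a small sketch (file name
taken from the test above; torch must be importable when unpickling a tensor):

    import pickle

    with open('input_emb.pickle', 'rb') as f:
        emb = pickle.load(f)
    print(emb.size())  # expected: torch.Size([2, 3, 500])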
Example #3
    save(args.word_emb_file, word_emb_mat, message="word embedding")  # word embedding matrix (word_voc_size, embedding_size)
    save(args.char_emb_file, char_emb_mat, message="char embedding")  # char embedding matrix (char_voc_size, embedding_size)
    save(args.train_eval_file, train_eval, message="train eval")  # processed training set
    save(args.dev_eval_file, dev_eval, message="dev eval")  # processed dev set
    save(args.word2idx_file, word2idx_dict, message="word dictionary")  # word-to-index mapping
    save(args.char2idx_file, char2idx_dict, message="char dictionary")  # char-to-index mapping
    save(args.dev_meta_file, dev_meta, message="dev meta")  # meta info for the dev set


if __name__ == '__main__':
    # Get command-line args
    args_ = get_setup_args()

    # Download resources
    # download(args_)

    # Import spacy language model
    nlp = spacy.blank("en")

    # Preprocess dataset
    args_.train_file = url_to_data_path(args_.train_url)
    args_.dev_file = url_to_data_path(args_.dev_url)
    if args_.include_test_examples:
        args_.test_file = url_to_data_path(args_.test_url)
    glove_dir = url_to_data_path(args_.glove_url.replace('.zip', ''))
    glove_ext = '.txt' if glove_dir.endswith(
        'd') else f'.{args_.glove_dim}d.txt'
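The glove_ext logic above covers GloVe folders whose names may or may not
already end in the dimension suffix. A hedged illustration of the resulting
path, assuming the standard glove.840B.300d naming (local path hypothetical):

    import os

    glove_dir = './data/glove.840B.300d'
    glove_dim = 300
    glove_ext = '.txt' if glove_dir.endswith('d') else f'.{glove_dim}d.txt'
    print(os.path.join(glove_dir, os.path.basename(glove_dir) + glove_ext))
    # ./data/glove.840B.300d/glove.840B.300d.txt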
Example #4
def main():

    set_random_seed()

    #torch.backends.cudnn.enabled = False

    # Arguments
    opt = args.get_setup_args()

    #cuda = True if torch.cuda.is_available() else False
    device, gpu_ids = util.get_available_devices()

    num_classes = opt.num_classes
    noise_dim = opt.latent_dim + opt.num_classes

    # WGAN hyperparams

    # number of discriminator training steps per generator step
    n_critic = 5
    # Gradient penalty lambda hyperparameter
    lambda_gp = 10

    def weights_init(m):
        if isinstance(m, cwgan.MyConvo2d): 
            if m.conv.weight is not None:
                if m.he_init:
                    nn.init.kaiming_uniform_(m.conv.weight)
                else:
                    nn.init.xavier_uniform_(m.conv.weight)
            if m.conv.bias is not None:
                nn.init.constant_(m.conv.bias, 0.0)
        if isinstance(m, nn.Linear):
            if m.weight is not None:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    
    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version, "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version, "sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)

    train_set = datasets.ImageFolder(root=train_images_path,
                                transform=transforms.Compose([
                                    transforms.Resize((opt.img_size, opt.img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ]))

    dataloader = torch.utils.data.DataLoader(train_set,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.num_workers)

    gen = cwgan.Generator(noise_dim, 64).to(device)
    disc = cwgan.Discriminator(64, num_classes).to(device)

    gen.apply(weights_init)
    disc.apply(weights_init)

    optimG = optim.Adam(gen.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimD = optim.Adam(disc.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    #optimG = optim.RMSprop(gen.parameters(), lr=opt.lr)
    #optimD = optim.RMSprop(disc.parameters(), lr=opt.lr)

    #adversarial_loss = torch.nn.BCELoss()
    auxiliary_loss = torch.nn.CrossEntropyLoss()

    # Keep track of losses, accuracy, FID
    G_losses = []
    D_losses = []
    D_acc = []
    FIDs = []
    val_epochs = []

    def print_labels():
        for class_name in train_set.classes:
            print("{} -> {}".format(class_name, train_set.class_to_idx[class_name]))



    def eval_fid(gen_images_path, eval_images_path):        
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths((gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    
    def validate(keep_images=True):
        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                                 transforms.Resize((opt.img_size, opt.img_size)),
                                                 transforms.ToTensor()
                            ]))
        
        val_loader = torch.utils.data.DataLoader(val_set,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.num_workers)
        
        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        output_source_images_path = val_images_path + "_" + str(opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, labels = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            labels = torch.randint(0, num_classes, (batch_size,)).to(device)
            labels_onehot = F.one_hot(labels, num_classes)

            noise = torch.cat((noise, labels_onehot.to(dtype=torch.float)), 1)
            gen_images = gen(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :], "{}/{}.jpg".format(output_images_path, i), normalize=True)       
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :], "{}/{}.jpg".format(output_source_images_path, i), normalize=True)     
            images_done += batch_size
        
        fid = eval_fid(output_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_images_path)
        return fid


    def sample_images(num_images, batches_done):
        # Sample noise
        z = torch.randn((num_classes * num_images, opt.latent_dim)).to(device)
        # Get labels ranging from 0 to n_classes for n rows
        labels = torch.zeros((num_classes * num_images,), dtype=torch.long).to(device)

        for i in range(num_classes):
            for j in range(num_images):
                labels[i*num_images + j] = i
        
        labels_onehot = F.one_hot(labels, num_classes)
        z = torch.cat((z, labels_onehot.to(dtype=torch.float)), 1)        
        sample_imgs = gen(z)
        vutils.save_image(sample_imgs.data, "{}/{}.png".format(output_sample_images_path, batches_done), nrow=num_images, padding=2, normalize=True)

    
    def save_loss_plot(path):
        plt.figure(figsize=(10,5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses,label="G")
        plt.plot(D_losses,label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)
        plt.close()

    def save_acc_plot(path):
        plt.figure(figsize=(10,5))
        plt.title("Discriminator Accuracy")
        plt.plot(D_acc)
        plt.xlabel("iterations")
        plt.ylabel("accuracy")
        plt.savefig(path)
        plt.close()
    
    def save_fid_plot(FIDs, epochs, path):
        #N = len(FIDs)
        plt.figure(figsize=(10,5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        #plt.xticks([i * 49 for i in range(1, N+1)])    
        plt.savefig(path)
        plt.close()

    
    def calc_gradient_penalty(netD, real_data, fake_data):
        batch_size = real_data.size(0)
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, int(real_data.nelement()/batch_size)).contiguous()
        alpha = alpha.view(batch_size, 3, opt.img_size, opt.img_size)
        alpha = alpha.to(device)

        #fake_data = fake_data.view(batch_size, 3, opt.img_size, opt.img_size)
        interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())

        interpolates = interpolates.to(device)
        interpolates.requires_grad_(True)   

        disc_interpolates, _ = netD(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradients = gradients.view(gradients.size(0), -1)                              
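        # WGAN-GP regularizer: lambda_gp * E[(||grad_xhat D(xhat)||_2 - 1)^2],
        # where xhat interpolates real and fake samples; it softly enforces the
        # 1-Lipschitz constraint required by the Wasserstein critic.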
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
        return gradient_penalty

    
    print("Label to class mapping:")
    print_labels()

    for epoch in range(1, opt.num_epochs + 1):
        for i, data in enumerate(dataloader, 0):

            images, class_labels = data
            images = images.to(device)
            class_labels = class_labels.to(device)

            batch_size = images.size(0)

            ############################
            # Train Discriminator
            ###########################
            
            ## Train with all-real batch

            optimD.zero_grad()

            real_pred, real_aux = disc(images)
            
            d_real_aux_loss = auxiliary_loss(real_aux, class_labels)

            # Train with fake batch
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            gen_class_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
            gen_class_labels_onehot = F.one_hot(gen_class_labels, num_classes)

            noise = torch.cat((noise, gen_class_labels_onehot.to(dtype=torch.float)), 1)
            gen_images = gen(noise).detach()
            fake_pred, fake_aux = disc(gen_images)

            #d_fake_aux_loss = auxiliary_loss(fake_aux, gen_class_labels)

            gradient_penalty = calc_gradient_penalty(disc, images, gen_images)

            # Total discriminator loss
            d_aux_loss = d_real_aux_loss
            d_loss = fake_pred.mean() - real_pred.mean() + gradient_penalty + d_aux_loss
            
            # Calculate discriminator accuracy
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            gt = np.concatenate([class_labels.data.cpu().numpy(), gen_class_labels.data.cpu().numpy()], axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)

            d_loss.backward()
            optimD.step()

            if i % n_critic == 0:
                ############################
                # Train Generator
                ###########################

                optimG.zero_grad()

                gen_images = gen(noise)

                gen_pred, aux_scores = disc(gen_images)
                g_aux_loss = auxiliary_loss(aux_scores, gen_class_labels)
                g_loss = g_aux_loss - gen_pred.mean()

                g_loss.backward()
                optimG.step()

            # Save losses and accuracy for plotting
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            D_acc.append(d_acc)

            # Output training stats
            if i % opt.print_every == 0:
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f, acc:  %d%%] [G loss: %.4f]"
                % (epoch, opt.num_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
                )

            batches_done = epoch * len(dataloader) + i
            
            # Generate and save sample images
            if (batches_done % opt.sample_interval == 0) or ((epoch == opt.num_epochs) and (i == len(dataloader) - 1)):
                # Put G in eval mode
                gen.eval()
                
                with torch.no_grad():                
                    sample_images(opt.num_sample_images, batches_done)
                vutils.save_image(gen_images.data[:36], "{}/{}.png".format(output_train_images_path, batches_done), nrow=6, padding=2, normalize=True)
                
                # Put G back in train mode
                gen.train()
            
        # Save model checkpoint
        if (epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0):
            print("Checkpoint at epoch {}".format(epoch))

            print("Saving G & D loss plot...")
            save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot_{}.png".format(epoch)))
            print("Saving D accuracy plot...")
            save_acc_plot(os.path.join(opt.output_path, opt.version, "accuracy_plot_{}.png".format(epoch)))
            
            print("Validating model...")
            gen.eval()
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(FIDs, val_epochs, os.path.join(opt.output_path, opt.version, "fid_plot_{}.png".format(epoch)))
            gen.train()

            print("Saving model checkpoint...")
            torch.save({
                'epoch': epoch,
                'g_state_dict': gen.state_dict(),
                'd_state_dict': disc.state_dict(),
                'g_optimizer_state_dict': optimG.state_dict(),
                'd_optimizer_state_dict': optimD.state_dict(),
                'g_loss': g_loss.item(),
                'd_loss': d_loss.item(),
                'd_accuracy': d_acc,
                'val_fid': fid
            }, os.path.join(output_model_path, "model_checkpoint_{}.tar".format(epoch)))

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Saving final D accuracy plot...")
    save_acc_plot(os.path.join(opt.output_path, opt.version, "accuracy_plot.png"))
    print("Done!")

    print("Validating final model...")
    gen.eval()    
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs, os.path.join(opt.output_path, opt.version, "fid_plot"))
    print("Done!")

    print("Saving final model...")
    torch.save({
        'epoch': epoch,
        'g_state_dict': gen.state_dict(),
        'd_state_dict': disc.state_dict(),
        'g_optimizer_state_dict': optimG.state_dict(),
        'd_optimizer_state_dict': optimD.state_dict(),
        'g_loss': g_loss.item(),
        'd_loss': d_loss.item(),
        'd_accuracy': d_acc,
        'val_fid': fid
    }, os.path.join(output_model_path, "model.tar"))
    print("Done!")
Example #5
def pre_process():
    # Process the dataset and use it to decide on the word/character vocabularies
    word_counter, char_counter = Counter(), Counter()

    # setup.process_file (originally driven by args.train_file) returns:
    #   all_examples = [dicts]
    #   all_eval = {id -> dict}

    # POTENTIAL BUG: it may be bad to build the counters on the ENTIRE dataset
    # rather than just train, as in the original setup.py
    all_examples, all_eval = setup.process_file("./adversarial_dataset.json", "all", word_counter, char_counter)
    all_indices = list(map(lambda e: e['id'], all_examples))
    
    # Debug aids (disabled):
    # import pdb; pdb.set_trace()
    # print(all_examples[0]["context_tokens"], all_examples[0]["ques_tokens"])
    # print(all_examples[1]["context_tokens"], all_examples[1]["ques_tokens"])
    # print(type(all_examples)); print(type(all_eval))

    # indices are from 0 to 3559 (3560 questions total)

    
    # 2136 total questions and answers in train
    # 712 questions + answers in dev
    # 712 questions + answers in test
    train_examples, residual_examples = train_test_split(all_examples, test_size=0.4)
    dev_examples, test_examples = train_test_split(residual_examples, test_size=0.5)
    
    train_eval = {str(e['id']) : all_eval[str(e['id'])] for e in train_examples}
    dev_eval = {str(e['id']) : all_eval[str(e['id'])] for e in dev_examples}
    test_eval = {str(e['id']) : all_eval[str(e['id'])] for e in test_examples}


    # IMPORTANT: Ensure that we do not split corresponding question and answers into different datasets
    assert set([str(e['id']) for e in train_examples]) == set(train_eval.keys())
    assert set([str(e['id']) for e in dev_examples]) == set(dev_eval.keys())
    assert set([str(e['id']) for e in test_examples]) == set(test_eval.keys())

    # TODO: Call the rest of the setup.py to get the .npz files
    # TODO: Once we have the .npz, we can call test on the adversarial data
    # TODO: Re-train BiDAF on adversarial dataset
    # TODO: Data augmentation
    # TODO: Auxiliary Model to predict sentence relevancy
    
    
    # ========= FROM SETUP.PY =========== #
    # Need to create the .npz, .json files for dev, test, and train
    # this is desired structure for training/testing
    
    args = get_setup_args()
    
    
    # Setup glove path for adversarial dataset
    glove_dir = setup.url_to_data_path(args.glove_url.replace('.zip', ''))
    glove_ext = '.txt' if glove_dir.endswith('d') else f'.{args.glove_dim}d.txt'
    args.glove_file = os.path.join(glove_dir, os.path.basename(glove_dir) + glove_ext)
    
    
    # Setup word, char embeddings for adversarial data
    word_emb_mat, word2idx_dict = setup.get_embedding(word_counter, 'word', emb_file=args.glove_file, vec_size=args.glove_dim, num_vectors=args.glove_num_vecs)
    char_emb_mat, char2idx_dict = setup.get_embedding(char_counter, 'char', emb_file=None, vec_size=args.char_dim)
      
      
    #args.train_record_file is the .npz file path that we want to save stuff to
    setup.build_features(args, train_examples, "train", "./adv_data/train.npz", word2idx_dict, char2idx_dict)
    dev_meta = setup.build_features(args, dev_examples, "dev", "./adv_data/dev.npz", word2idx_dict, char2idx_dict)
      
    # True by default
    if args.include_test_examples:
        # Step done above
#        test_examples, test_eval = process_file("./adversarial_dataset/test-v2.0.json", "adv test", word_counter, char_counter)
        setup.save("./adv_data/test_eval.json", test_eval, message="adv test eval")
        test_meta = setup.build_features(args, test_examples, "adv test", "./adv_data/test.npz", word2idx_dict, char2idx_dict, is_test=True)
        setup.save("./adv_data/test_meta.json", test_meta, message="adv test meta")

    setup.save("./adv_data/word_emb.json", word_emb_mat, message="word embedding")
    setup.save("./adv_data/char_emb.json", char_emb_mat, message="char embedding")
    setup.save("./adv_data/train_eval.json", train_eval, message="adv train eval")
    setup.save("./adv_data/dev_eval.json", dev_val, message="adv dev eval")
    setup.save("./adv_data/word2idx.json", word2idx_dict, message="word dictionary")
    setup.save("./adv_data/char2idx.json", char2idx_dict, message="char dictionary")
    setup.save("./adv_data/dev_meta.json", dev_meta, message="adv dev meta")
Example #6
def main():
    args, log = get_setup_args()

    train = flatten_json(args.trn_file, 'train')
    dev = flatten_json(args.dev_file, 'dev')
    test = flatten_json(args.tst_file, 'test')
    log.info('json data flattened.')

    # tokenize & annotate
    with Pool(args.threads, initializer=init) as p:
        annotate_ = partial(annotate, wv_cased=args.wv_cased)
        train = list(tqdm(p.imap(annotate_, train, chunksize=args.batch_size), total=len(train), desc='train'))
        dev = list(tqdm(p.imap(annotate_, dev, chunksize=args.batch_size), total=len(dev), desc='dev'))
        test = list(tqdm(p.imap(annotate_, test, chunksize=args.batch_size), total=len(test), desc='test'))
    train = list(map(index_answer, train))
    initial_len = len(train)
    train = list(filter(lambda x: x[-1] is not None, train))
    log.info('drop {} inconsistent samples.'.format(initial_len - len(train)))
    log.info('tokens generated')

    # load vocabulary from word vector files
    wv_vocab = set()
    with open(args.wv_file) as f:
        for line in f:
            token = normalize_text(line.rstrip().split(' ')[0])
            wv_vocab.add(token)
    log.info('glove vocab loaded.')

    # build vocabulary
    full = train + dev + test
    vocab, counter = build_vocab([row[5] for row in full], [row[1] for row in full], wv_vocab, args.sort_all)
    total = sum(counter.values())
    matched = sum(counter[t] for t in vocab)
    log.info('vocab coverage {1}/{0} | OOV occurrence {2}/{3} ({4:.4f}%)'.format(
        len(counter), len(vocab), (total - matched), total, (total - matched) / total * 100))
    counter_tag = collections.Counter(w for row in full for w in row[3])
    vocab_tag = sorted(counter_tag, key=counter_tag.get, reverse=True)
    counter_ent = collections.Counter(w for row in full for w in row[4])
    vocab_ent = sorted(counter_ent, key=counter_ent.get, reverse=True)
    w2id = {w: i for i, w in enumerate(vocab)}
    tag2id = {w: i for i, w in enumerate(vocab_tag)}
    ent2id = {w: i for i, w in enumerate(vocab_ent)}
    log.info('Vocabulary size: {}'.format(len(vocab)))
    log.info('Found {} POS tags.'.format(len(vocab_tag)))
    log.info('Found {} entity tags: {}'.format(len(vocab_ent), vocab_ent))

    to_id_ = partial(to_id, w2id=w2id, tag2id=tag2id, ent2id=ent2id)
    train = list(map(to_id_, train))
    dev = list(map(to_id_, dev))
    test = list(map(to_id_, test))
    log.info('converted to ids.')

    vocab_size = len(vocab)
    embeddings = np.zeros((vocab_size, args.wv_dim))
    embed_counts = np.zeros(vocab_size)
    embed_counts[:2] = 1  # PADDING & UNK
    with open(args.wv_file) as f:
        for line in f:
            elems = line.rstrip().split(' ')
            token = normalize_text(elems[0])
            if token in w2id:
                word_id = w2id[token]
                embed_counts[word_id] += 1
                embeddings[word_id] += [float(v) for v in elems[1:]]
    embeddings /= embed_counts.reshape((-1, 1))
    log.info('got embedding matrix.')

    meta = {
        'vocab': vocab,
        'vocab_tag': vocab_tag,
        'vocab_ent': vocab_ent,
        'embedding': embeddings.tolist(),
        'wv_cased': args.wv_cased,
    }
    with open('data/meta.msgpack', 'wb') as f:
        msgpack.dump(meta, f)
    result = {
        'train': train,
        'dev': dev,
        'test': test
    }
    # train: id, context_id, context_features, tag_id, ent_id,
    #        question_id, context, context_token_span, answer_start, answer_end
    # dev:   id, context_id, context_features, tag_id, ent_id,
    #        question_id, context, context_token_span, answer
    # test:   id, context_id, context_features, tag_id, ent_id,
    #        question_id, context, context_token_span, answer
    with open('data/data.msgpack', 'wb') as f:
        msgpack.dump(result, f)
    if args.sample_size:
        sample = {
            'train': train[:args.sample_size],
            'dev': dev[:args.sample_size],
            'test': test[:args.sample_size]
        }
        with open('data/sample.msgpack', 'wb') as f:
            msgpack.dump(sample, f)
    log.info('saved to disk.')
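Because normalize_text can map several vector-file rows (e.g. different
casings) onto one vocabulary id, the loop above sums the vectors and divides by
the hit count, i.e. it averages them. A standalone sketch of that accumulation
(names hypothetical; lowercasing stands in for normalize_text):

    import numpy as np

    w2id = {'<PAD>': 0, '<UNK>': 1, 'the': 2}
    rows = [('the', [0.1, 0.2]), ('The', [0.3, 0.4])]  # two casings, one id

    embeddings = np.zeros((len(w2id), 2))
    embed_counts = np.zeros(len(w2id))
    embed_counts[:2] = 1  # PAD/UNK keep their zero vectors

    for token, vec in rows:
        token = token.lower()
        if token in w2id:
            embed_counts[w2id[token]] += 1
            embeddings[w2id[token]] += vec

    embeddings /= embed_counts.reshape((-1, 1))
    print(embeddings[2])  # [0.2 0.3], the average of the two rows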
Example #7
def main():
    #set_random_seed()

    # Change the following comments for CPU
    #device, gpu_ids = util.get_available_devices()
    device = torch.device('cpu')

    # Arguments
    opt = args.get_setup_args()

    num_classes = opt.num_classes
    noise_dim = opt.latent_dim + opt.num_classes

    train_images_path = os.path.join(opt.data_path, "train")
    output_train_images_path = train_images_path + "_" + str(opt.img_size)
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample_eval")
    output_nn_pixel_images_path = os.path.join(opt.output_path, opt.version,
                                               "nn_eval_pixel")
    output_nn_inception_images_path = os.path.join(opt.output_path,
                                                   opt.version,
                                                   "nn_eval_inception")

    os.makedirs(output_sample_images_path, exist_ok=True)
    os.makedirs(output_nn_pixel_images_path, exist_ok=True)

    #os.makedirs(output_nn_inception_images_path, exist_ok=True)

    def get_nn_pixels(sample_images, train_images):
        nn = [None] * len(sample_images)
        pdist = torch.nn.PairwiseDistance(p=2)
        N, C, H, W = train_images.shape
        for i in range(len(sample_images)):
            sample_image = sample_images[i].unsqueeze(0)
            sample_image = torch.cat(N * [sample_image])
            distances = pdist(sample_image.view(-1, C * H * W),
                              train_images.view(-1, C * H * W))
            min_index = torch.argmin(distances)
            nn[i] = train_images[min_index]

        r = torch.stack(nn, dim=0).squeeze().to(device)
        return r

    def get_nn_inception(sample_activations, train_activations, train_images):
        nn = [None] * len(sample_activations)
        pdist = torch.nn.PairwiseDistance(p=2)
        N = train_activations.size(0)
        for i in range(len(sample_activations)):
            sample_act = sample_activations[i].unsqueeze(0)
            sample_act = torch.cat(N * [sample_act])
            distances = pdist(sample_act, train_activations)
            min_index = torch.argmin(distances)
            nn[i] = train_images[min_index]

        r = torch.stack(nn, dim=0).squeeze().to(device)
        return r

    def get_nearest_neighbour_pixels(sample_images, num_images, train_images,
                                     train_labels):
        all_nn = []
        for i in range(num_classes):
            train_imgs = train_images[train_labels[:] == i]
            nearest_n = get_nn_pixels(
                sample_images[i * num_images:(i + 1) * num_images], train_imgs)
            class_nn = torch.stack([
                sample_images[i * num_images:(i + 1) * num_images], nearest_n
            ],
                                   dim=0).squeeze().view(
                                       -1, 3, opt.img_size,
                                       opt.img_size).to(device)
            all_nn.append(class_nn)
        #r = torch.stack(nn, dim=0).squeeze().view(-1, 3, opt.img_size, opt.img_size).to(device)
        #print(r.shape)
        return all_nn

    def get_nearest_neighbour_inception(sample_images, num_images,
                                        train_images, train_labels):
        print("Getting sample activations...")
        sample_activations = fid_score.get_activations_given_path(
            output_sample_images_path, opt.batch_size, device)
        sample_activations = torch.from_numpy(sample_activations).type(
            torch.FloatTensor).to(device)

        print("Getting train activations...")
        train_activations = fid_score.get_activations_given_path(
            output_train_images_path, opt.batch_size, device)
        train_activations = torch.from_numpy(train_activations).type(
            torch.FloatTensor).to(device)

        all_nn = []
        for i in range(num_classes):
            train_imgs = train_images[train_labels[:] == i]
            train_act = train_activations[train_labels[:] == i]
            nearest_n = get_nn_inception(
                sample_activations[i * num_images:(i + 1) * num_images],
                train_act, train_images)
            class_nn = torch.stack([
                sample_images[i * num_images:(i + 1) * num_images], nearest_n
            ],
                                   dim=0).squeeze().view(
                                       -1, 3, opt.img_size,
                                       opt.img_size).to(device)
            all_nn.append(class_nn)
        #r = torch.stack(nn, dim=0).squeeze().view(-1, 3, opt.img_size, opt.img_size).to(device)
        #print(r.shape)
        return all_nn

    def get_onehot_labels(num_images):
        labels = torch.zeros(num_images, 1).to(device)
        for i in range(num_classes - 1):
            temp = torch.ones(num_images, 1).to(device) + i
            labels = torch.cat([labels, temp], 0)

        labels_onehot = torch.zeros(num_images * num_classes,
                                    num_classes).to(device)
        labels_onehot.scatter_(1, labels.to(torch.long), 1)

        return labels_onehot

    def sample_images(num_images, itr):
        '''
        labels = torch.zeros((num_classes * num_images,), dtype=torch.long).to(device)

        for i in range(num_classes):
            for j in range(num_images):
                labels[i*num_images + j] = i
        
        labels_onehot = F.one_hot(labels, num_classes)        
        '''

        train_set = datasets.ImageFolder(root=train_images_path,
                                         transform=transforms.Compose([
                                             transforms.Resize(
                                                 (opt.img_size, opt.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(
                                                 (0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))
                                         ]))
        '''
        source_images_available = True

        if (not os.path.exists(output_train_images_path)):
            os.makedirs(output_train_images_path)
            source_images_available = False
        
        
        
        
        if (not source_images_available):
            train_loader = torch.utils.data.DataLoader(train_set,
                                            batch_size=1,
                                            num_workers=opt.num_workers)
        else:
            train_loader = torch.utils.data.DataLoader(train_set,
                                            batch_size=opt.batch_size,
                                            num_workers=opt.num_workers)
        '''

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=opt.batch_size,
                                                   num_workers=opt.num_workers)

        train_images = torch.FloatTensor().to(device)
        train_labels = torch.LongTensor().to(device)

        print("Loading train images...")

        for i, data in enumerate(train_loader, 0):
            img, label = data
            img = img.to(device)
            label = label.to(device)
            train_images = torch.cat([train_images, img], 0)
            train_labels = torch.cat([train_labels, label], 0)
            #if (not source_images_available):
            #    vutils.save_image(img, "{}/{}.jpg".format(output_train_images_path, i), normalize=True)

        print(
            "Estimating nearest neighbors in pixel space, this takes a few minutes..."
        )

        for it in range(itr):
            z = torch.randn(
                (num_classes * num_images, opt.latent_dim)).to(device)
            labels_onehot = get_onehot_labels(num_images)
            z = torch.cat((z, labels_onehot.to(dtype=torch.float)), 1)
            sample_imgs = gen(z)
            for i in range(len(sample_imgs)):
                vutils.save_image(sample_imgs[i],
                                  "{}/{}.png".format(output_sample_images_path,
                                                     i),
                                  normalize=True)
            nearest_neighbour_imgs_list = get_nearest_neighbour_pixels(
                sample_imgs, num_images, train_images, train_labels)
            for label, nn_imgs in enumerate(nearest_neighbour_imgs_list):
                vutils.save_image(nn_imgs.data,
                                  "{}/iter{}-{}.png".format(
                                      output_nn_pixel_images_path, it, label),
                                  nrow=num_images,
                                  padding=2,
                                  normalize=True)
        print("Saved nearest neighbors.")
        '''
        print("Estimating nearest neighbors in feature space, this takes a few minutes...")
        nearest_neighbour_imgs_list = get_nearest_neighbour_inception(sample_imgs, num_images, train_images, train_labels)
        for label, nn_imgs in enumerate(nearest_neighbour_imgs_list):
            vutils.save_image(nn_imgs.data, "{}/{}.png".format(output_nn_inception_images_path, label), nrow=num_images, padding=2, normalize=True)
        print("Saved nearest neighbors.")
        '''

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def evaluate(source_images_path, keep_images=True):
        dataset = datasets.ImageFolder(root=source_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_gen_images_path = os.path.join(opt.output_path, opt.version,
                                              opt.eval_mode)
        os.makedirs(output_gen_images_path, exist_ok=True)

        output_source_images_path = source_images_path + "_" + str(
            opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(dataloader, 0):
            images, labels = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            labels = torch.randint(0, num_classes, (batch_size, )).to(device)
            labels_onehot = F.one_hot(labels, num_classes)

            noise = torch.cat((noise, labels_onehot.to(dtype=torch.float)), 1)
            gen_images = gen(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_gen_images_path,
                                                     i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_gen_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_gen_images_path)
        return fid

    test_images_path = os.path.join(opt.data_path, "test")
    val_images_path = os.path.join(opt.data_path, "val")
    model_path = os.path.join(opt.output_path, opt.version, opt.model_file)

    gen = acgan.Generator(noise_dim).to(device)

    if (opt.model_file.endswith(".pt")):
        gen.load_state_dict(torch.load(model_path, map_location=device))
    elif (opt.model_file.endswith(".tar")):
        checkpoint = torch.load(model_path, map_location=device)
        gen.load_state_dict(checkpoint['g_state_dict'])

    gen.eval()

    if opt.eval_mode == "val":
        source_images_path = val_images_path
    elif opt.eval_mode == "test":
        source_images_path = test_images_path

    if opt.eval_mode == "val" or opt.eval_mode == "test":
        print("Evaluating model...")
        fid = evaluate(source_images_path)
        print("FID: {}".format(fid))
    elif opt.eval_mode == "nn":
        sample_images(opt.num_sample_images, 50)
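get_onehot_labels in this example (and in Example #8 below) builds its one-hot
block with scatter_; the same layout can be produced more directly. A sketch,
assuming the same blocks-of-num_images ordering:

    import torch
    import torch.nn.functional as F

    num_classes, num_images = 4, 3
    labels = torch.arange(num_classes).repeat_interleave(num_images)
    labels_onehot = F.one_hot(labels, num_classes).to(torch.float)
    print(labels_onehot.shape)  # torch.Size([12, 4])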
Example #8
def main():

    set_random_seed()

    # Arguments
    opt = args.get_setup_args()

    #cuda = True if torch.cuda.is_available() else False
    device, gpu_ids = util.get_available_devices()

    num_classes = opt.num_classes
    noise_dim = opt.latent_dim + opt.num_classes

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version,
                                            "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample")
    output_nn_images_path = os.path.join(opt.output_path, opt.version, "nn")
    output_const_images_path = os.path.join(opt.output_path, opt.version,
                                            "constant_sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)
    os.makedirs(output_nn_images_path, exist_ok=True)
    os.makedirs(output_const_images_path, exist_ok=True)

    train_set = datasets.ImageFolder(root=train_images_path,
                                     transform=transforms.Compose([
                                         transforms.Resize(
                                             (opt.img_size, opt.img_size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers)

    dataloader_nn = torch.utils.data.DataLoader(train_set,
                                                batch_size=1,
                                                num_workers=opt.num_workers)

    gen = fcgan.Generator(noise_dim).to(device)
    disc = fcgan.Discriminator(num_classes).to(device)

    gen.apply(weights_init)
    disc.apply(weights_init)

    optimG = optim.Adam(gen.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimD = optim.Adam(disc.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    #optimD = optim.SGD(disc.parameters(), lr=opt.lr_sgd)

    adversarial_loss = torch.nn.BCELoss()
    auxiliary_loss = torch.nn.CrossEntropyLoss()

    real_label_val = 1.0  # float, since BCELoss targets must be float tensors
    #real_label_smooth_val = 0.9
    real_label_low = 0.75
    real_label_high = 1.0
    fake_label_val = 0.0
    c_fake_label = opt.num_classes

    # Probability of adding label noise during discriminator training
    label_noise_prob = 0.05

    # Keep track of losses, accuracy, FID
    G_losses = []
    D_losses = []
    D_acc = []
    FIDs = []
    val_epochs = []

    # Define a fixed noise vector for consistent samples
    z_const = torch.randn(
        (num_classes * opt.num_sample_images, opt.latent_dim)).to(device)

    def print_labels():
        for class_name in train_set.classes:
            print("{} -> {}".format(class_name,
                                    train_set.class_to_idx[class_name]))

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def validate(keep_images=True):
        # Put G in eval mode
        gen.eval()

        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        output_source_images_path = val_images_path + "_" + str(opt.img_size)
        source_images_available = True
        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, labels = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            labels = torch.randint(0, num_classes, (batch_size, )).to(device)
            labels_onehot = F.one_hot(labels, num_classes)

            noise = torch.cat((noise, labels_onehot.to(dtype=torch.float)), 1)
            gen_images = gen(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_images_path, i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        # Put G back in train mode
        gen.train()

        fid = eval_fid(output_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_images_path)
        return fid

    def get_dist(img1, img2):
        return torch.dist(img1, img2, p=1)

    def get_nn(images, class_label):
        nn = [None] * len(images)
        dist = [np.inf] * len(images)
        for e, data in enumerate(dataloader_nn, 0):
            img, label = data
            if label != class_label:
                continue
            img = img.to(device)
            for i in range(len(images)):
                d = get_dist(images[i], img)
                if d < dist[i]:
                    dist[i] = d
                    nn[i] = img
        r = torch.stack(nn, dim=0).squeeze().to(device)
        #print(r.shape)
        return r

    def get_nearest_neighbour(sample_images, num_images):
        all_nn = []
        for i in range(num_classes):
            nearest_n = get_nn(
                sample_images[i * num_images:(i + 1) * num_images], i)
            class_nn = torch.stack([
                sample_images[i * num_images:(i + 1) * num_images], nearest_n
            ],
                                   dim=0).squeeze().view(
                                       -1, 3, opt.img_size,
                                       opt.img_size).to(device)
            all_nn.append(class_nn)
        #r = torch.stack(nn, dim=0).squeeze().view(-1, 3, opt.img_size, opt.img_size).to(device)
        #print(r.shape)
        return all_nn

    def get_onehot_labels(num_images):
        labels = torch.zeros(num_images, 1).to(device)
        for i in range(num_classes - 1):
            temp = torch.ones(num_images, 1).to(device) + i
            labels = torch.cat([labels, temp], 0)

        labels_onehot = torch.zeros(num_images * num_classes,
                                    num_classes).to(device)
        labels_onehot.scatter_(1, labels.to(torch.long), 1)

        return labels_onehot

    def sample_images(num_images, batches_done, isLast):
        # Sample noise - declared once at the top to maintain consistency of samples
        z = torch.randn((num_classes * num_images, opt.latent_dim)).to(device)
        '''
        labels = torch.zeros((num_classes * num_images,), dtype=torch.long).to(device)

        for i in range(num_classes):
            for j in range(num_images):
                labels[i*num_images + j] = i
        
        labels_onehot = F.one_hot(labels, num_classes)        
        '''

        labels_onehot = get_onehot_labels(num_images)
        z = torch.cat((z, labels_onehot.to(dtype=torch.float)), 1)
        sample_imgs = gen(z)
        z_const_cat = torch.cat((z_const, labels_onehot.to(dtype=torch.float)),
                                1)
        const_sample_imgs = gen(z_const_cat)
        vutils.save_image(sample_imgs.data,
                          "{}/{}.png".format(output_sample_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)
        vutils.save_image(const_sample_imgs.data,
                          "{}/{}.png".format(output_const_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)

        if isLast:
            print(
                "Estimating nearest neighbors for the last samples, this takes a few minutes..."
            )
            nearest_neighbour_imgs_list = get_nearest_neighbour(
                sample_imgs, num_images)
            for label, nn_imgs in enumerate(nearest_neighbour_imgs_list):
                vutils.save_image(nn_imgs.data,
                                  "{}/{}_{}.png".format(
                                      output_nn_images_path, batches_done,
                                      label),
                                  nrow=num_images,
                                  padding=2,
                                  normalize=True)
            nearest_neighbour_imgs_list = get_nearest_neighbour(
                const_sample_imgs, num_images)
            for label, nn_imgs in enumerate(nearest_neighbour_imgs_list):
                vutils.save_image(nn_imgs.data,
                                  "{}/const_{}_{}.png".format(
                                      output_nn_images_path, batches_done,
                                      label),
                                  nrow=num_images,
                                  padding=2,
                                  normalize=True)
            print("Saved nearest neighbors.")

    def save_loss_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses, label="G")
        plt.plot(D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)
        plt.close()

    def save_acc_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Discriminator Accuracy")
        plt.plot(D_acc)
        plt.xlabel("iterations")
        plt.ylabel("accuracy")
        plt.savefig(path)
        plt.close()

    def save_fid_plot(FIDs, epochs, path):
        #N = len(FIDs)
        plt.figure(figsize=(10, 5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        #plt.xticks([i * 49 for i in range(1, N+1)])
        plt.savefig(path)
        plt.close()

    def expectation_loss(real_feature, fake_feature):
        norm = torch.norm(real_feature - fake_feature)
        total = torch.abs(norm).sum()
        return norm / total

    print("Label to class mapping:")
    print_labels()

    for epoch in range(1, opt.num_epochs + 1):
        for i, data in enumerate(dataloader, 0):

            images, class_labels = data
            images = images.to(device)
            class_labels = class_labels.to(device)

            batch_size = images.size(0)

            #real_label_smooth = torch.full((batch_size,), real_label_smooth_val, device=device)
            real_label_smooth = (
                real_label_low - real_label_high) * torch.rand(
                    (batch_size, ), device=device) + real_label_high
            real_label = torch.full((batch_size, ),
                                    real_label_val,
                                    device=device)
            fake_label = torch.full((batch_size, ),
                                    fake_label_val,
                                    device=device)

            ############################
            # Train Discriminator
            ###########################

            ## Train with all-real batch

            optimD.zero_grad()

            real_pred, real_aux = disc(images)

            mask = torch.rand(
                (batch_size, ), device=device) <= label_noise_prob
            mask = mask.type(torch.float)
            noisy_label = torch.mul(1 - mask, real_label_smooth) + torch.mul(
                mask, fake_label)

            d_real_loss = (adversarial_loss(real_pred, noisy_label) +
                           auxiliary_loss(real_aux, class_labels)) / 2

            # Train with fake batch
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            gen_class_labels = torch.randint(0, num_classes,
                                             (batch_size, )).to(device)
            gen_class_labels_onehot = F.one_hot(gen_class_labels, num_classes)

            noise = torch.cat(
                (noise, gen_class_labels_onehot.to(dtype=torch.float)), 1)
            gen_images = gen(noise)
            fake_pred, fake_aux = disc(gen_images.detach())

            mask = torch.rand(
                (batch_size, ), device=device) <= label_noise_prob
            mask = mask.type(torch.float)
            noisy_label = torch.mul(1 - mask, fake_label) + torch.mul(
                mask, real_label_smooth)

            c_fake = c_fake_label * torch.ones_like(gen_class_labels).to(
                device)
            d_fake_loss = (adversarial_loss(fake_pred, noisy_label) +
                           auxiliary_loss(fake_aux, c_fake)) / 2

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Calculate discriminator accuracy
            pred = np.concatenate(
                [real_aux.data.cpu().numpy(),
                 fake_aux.data.cpu().numpy()],
                axis=0)
            gt = np.concatenate([
                class_labels.data.cpu().numpy(),
                gen_class_labels.data.cpu().numpy()
            ],
                                axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)

            d_loss.backward()
            optimD.step()

            ############################
            # Train Generator
            ###########################

            optimG.zero_grad()

            validity, aux_scores = disc(gen_images)
            g_loss = 0.5 * (adversarial_loss(validity, real_label) +
                            auxiliary_loss(aux_scores, gen_class_labels)
                            )  # + expectation_loss(gen_features, r_f1)

            g_loss.backward()
            optimG.step()

            # Save losses and accuracy for plotting
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            D_acc.append(d_acc)

            # Output training stats
            if i % opt.print_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f, acc: %d%%] [G loss: %.4f]"
                    % (epoch, opt.num_epochs, i, len(dataloader),
                       d_loss.item(), 100 * d_acc, g_loss.item()))

            batches_done = epoch * len(dataloader) + i

            # Generate and save sample images
            isLast = ((epoch == opt.num_epochs)
                      and (i == len(dataloader) - 1))
            if (batches_done % opt.sample_interval == 0) or isLast:
                # Put G in eval mode
                gen.eval()

                with torch.no_grad():
                    sample_images(opt.num_sample_images, batches_done, isLast)
                vutils.save_image(gen_images.data[:36],
                                  "{}/{}.png".format(output_train_images_path,
                                                     batches_done),
                                  nrow=6,
                                  padding=2,
                                  normalize=True)

                # Put G back in train mode
                gen.train()

        # Save model checkpoint
        if (epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0):
            print("Checkpoint at epoch {}".format(epoch))

            print("Saving G & D loss plot...")
            save_loss_plot(
                os.path.join(opt.output_path, opt.version,
                             "loss_plot_{}.png".format(epoch)))
            print("Saving D accuracy plot...")
            save_acc_plot(
                os.path.join(opt.output_path, opt.version,
                             "accuracy_plot_{}.png".format(epoch)))

            print("Validating model...")
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
                      "a") as f:
                f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(
                FIDs, val_epochs,
                os.path.join(opt.output_path, opt.version,
                             "fid_plot_{}.png".format(epoch)))

            print("Saving model checkpoint...")
            torch.save(
                {
                    'epoch': epoch,
                    'g_state_dict': gen.state_dict(),
                    'd_state_dict': disc.state_dict(),
                    'g_optimizer_state_dict': optimG.state_dict(),
                    'd_optimizer_state_dict': optimD.state_dict(),
                    'g_loss': g_loss.item(),
                    'd_loss': d_loss.item(),
                    'd_accuracy': d_acc,
                    'val_fid': fid
                },
                os.path.join(output_model_path,
                             "model_checkpoint_{}.tar".format(epoch)))

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Saving final D accuracy plot...")
    save_acc_plot(
        os.path.join(opt.output_path, opt.version, "accuracy_plot.png"))
    print("Done!")

    print("Validating final model...")
    gen.eval()
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
              "a") as f:
        f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs,
                  os.path.join(opt.output_path, opt.version, "fid_plot"))
    print("Done!")

    print("Saving final model...")
    torch.save(
        {
            'epoch': epoch,
            'g_state_dict': gen.state_dict(),
            'd_state_dict': disc.state_dict(),
            'g_optimizer_state_dict': optimG.state_dict(),
            'd_optimizer_state_dict': optimD.state_dict(),
            'g_loss': g_loss.item(),
            'd_loss': d_loss.item(),
            'd_accuracy': d_acc,
            'val_fid': fid
        }, os.path.join(output_model_path, "model.tar"))
    print("Done!")
Example #9
import json
from args import get_setup_args

def char2idx_dic():
    """Load the char-to-index dictionary written during preprocessing."""
    char2idx_file_path = get_setup_args().char2idx_file
    with open(char2idx_file_path) as char2idx_data:
        dic = json.load(char2idx_data)
    return dic
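
A minimal usage sketch (hedged: the "--OOV--" token name and its fallback index are assumptions about the setup vocabulary, not shown in this snippet):

char2idx = char2idx_dic()
oov_idx = char2idx.get("--OOV--", 1)  # assumed OOV token name and index
a_idx = char2idx.get("a", oov_idx)    # unknown characters fall back to OOV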
Example #10
import ujson as json
import numpy as np
import spacy
import setup
from args import get_setup_args
from collections import Counter
import time
import torch

args = get_setup_args()

# word_emb_mat, word2idx_dict = setup.get_embedding(word_counter, 'word', emb_file=args.glove_file, vec_size=args.glove_dim, num_vectors=args.glove_num_vecs)
# tokens = nlp(words)

# for token in tokens:
#     # Printing the following attributes of each token.
#     # text: the word string, has_vector: if it contains
#     # a vector representation in the model,
#     # vector_norm: the algebraic norm of the vector,
#     # is_oov: if the word is out of vocabulary.
#     print(token.text, token.has_vector, token.vector_norm, token.is_oov)

# token1, token2 = tokens[0], tokens[1]

# print("Similarity:", token1.similarity(token2))
# word2idx_path = './data/word2idx.json'
# word2idx_file = open(word2idx_path, 'r')
# word2idx = json.load(word2idx_file)
# idx2word = {v: k for k, v in word2idx.items()}
# EXAMPLE_INDICES = range(100)
# index = 753  # this is the word 'performance'
Example #11
def main():
    set_random_seed()

    device, gpu_ids = util.get_available_devices()

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def evaluate(source_images_path, keep_images=True):
        dataset = datasets.ImageFolder(root=source_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_gen_images_path = os.path.join(opt.output_path, opt.version,
                                              opt.eval_mode)
        os.makedirs(output_gen_images_path, exist_ok=True)

        output_source_images_path = source_images_path + "_" + str(
            opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(dataloader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1)).to(device)

            gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_gen_images_path,
                                                     i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_gen_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_gen_images_path)
        return fid

    test_images_path = os.path.join(opt.data_path, "test")
    val_images_path = os.path.join(opt.data_path, "val")
    model_path = os.path.join(opt.output_path, opt.version, opt.model_file)

    netG = dcgan.Generator(nc, nz, ngf).to(device)

    if (opt.model_file.endswith(".pt")):
        netG.load_state_dict(torch.load(model_path))
    elif (opt.model_file.endswith(".tar")):
        checkpoint = torch.load(model_path)
        netG.load_state_dict(checkpoint['g_state_dict'])

    netG.eval()

    if opt.eval_mode == "val":
        source_images_path = val_images_path
    else:
        source_images_path = test_images_path

    if opt.eval_mode == "val" or opt.eval_mode == "test":
        print("Evaluating model...")
        with torch.no_grad():  # no gradients needed for FID evaluation
            fid = evaluate(source_images_path)
        print("FID: {}".format(fid))
Example #12
def main():
    cuda = True if torch.cuda.is_available() else False
    device, _ = util.get_available_devices()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

    # Arguments
    opt = args.get_setup_args()

    img_shape = (opt.channels, opt.img_size, opt.img_size)

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            self.label_emb = nn.Embedding(opt.num_classes, opt.num_classes)

            def block(in_feat, out_feat, normalize=True):
                layers = [nn.Linear(in_feat, out_feat)]
                if normalize:
                    layers.append(nn.BatchNorm1d(out_feat, 0.8))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers

            self.model = nn.Sequential(
                *block(opt.latent_dim + opt.num_classes, 128, normalize=False),
                *block(128, 256), *block(256, 512), *block(512, 1024),
                nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())

        def forward(self, noise, labels):
            # Concatenate label embedding and image to produce input
            gen_input = torch.cat((self.label_emb(labels), noise), -1)
            img = self.model(gen_input)
            img = img.view(img.size(0), *img_shape)
            return img

    def eval_fid(fake_images):
        output_images_path = os.path.join(opt.output_path, opt.version, "test")
        os.makedirs(output_images_path, exist_ok=True)
        print("Saving images generated for testing...")
        for i in range(fake_images.size(0)):
            save_image(fake_images[i, :, :, :],
                       "{}/{}.jpg".format(output_images_path, i))
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (output_images_path, test_images_path), opt.batch_size, device)
        return fid

    test_images_path = os.path.join(opt.data_path, "test")
    model_path = os.path.join(opt.output_path, opt.version)

    test_set = datasets.ImageFolder(root=test_images_path,
                                    transform=transforms.Compose([
                                        transforms.Resize(
                                            (opt.img_size, opt.img_size)),
                                        transforms.ToTensor()
                                    ]))

    netG = Generator().to(device)
    netG.load_state_dict(torch.load(os.path.join(model_path, "model.pt")))
    netG.eval()

    noise = FloatTensor(np.random.normal(0, 1,
                                         (len(test_set), opt.latent_dim)))
    gen_labels = LongTensor(
        np.random.randint(0, opt.num_classes, len(test_set)))

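    # Note: the whole test set is generated in a single forward pass below;
    # for a large test set, batch this to bound GPU memory.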
    # Generate fake image batch with G (no gradients needed at test time)
    with torch.no_grad():
        gen_imgs = netG(noise, gen_labels)

    fid = eval_fid(gen_imgs)

    print("FID: {}".format(fid))
Example #13
def main():

    set_random_seed()

    # Arguments
    opt = args.get_setup_args()

    #cuda = True if torch.cuda.is_available() else False
    device, gpu_ids = util.get_available_devices()

    num_classes = opt.num_classes
    # Size of feature maps in generator
    ngf = 64
    # Size of feature maps in discriminator
    ndf = 64

    # Label preprocessing: 'onehot' holds one-hot class vectors shaped
    # (num_classes, num_classes, 1, 1) for concatenation with the latent
    # vector fed to G; 'fill' holds per-class label maps shaped
    # (num_classes, num_classes, img_size, img_size) for conditioning D.
    label_vals = list(range(num_classes))
    onehot = torch.zeros(num_classes, num_classes).to(device)
    onehot = onehot.scatter_(
        1,
        torch.LongTensor(label_vals).view(num_classes, 1).to(device),
        1).view(num_classes, num_classes, 1, 1)
    fill = torch.zeros([num_classes, num_classes, opt.img_size,
                        opt.img_size]).to(device)
    for i in range(num_classes):
        fill[i, i, :, :] = 1
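    # Example: with num_classes == 3, onehot[2] is [0, 0, 1] reshaped to
    # (3, 1, 1), and fill[2] is a (3, img_size, img_size) stack whose
    # channel 2 is all ones and the other channels all zeros.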

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version,
                                            "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)

    train_set = datasets.ImageFolder(root=train_images_path,
                                     transform=transforms.Compose([
                                         transforms.Resize(
                                             (opt.img_size, opt.img_size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers)

    gen = cdcgan.Generator(opt.channels, opt.latent_dim, num_classes,
                           ngf).to(device)
    disc = cdcgan.Discriminator(opt.channels, opt.num_classes, ndf).to(device)

    gen.apply(weights_init)
    disc.apply(weights_init)

    optimG = optim.Adam(gen.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimD = optim.Adam(disc.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    adversarial_loss = torch.nn.BCELoss()

    # BCE targets: D's real targets are smoothed (drawn uniformly from
    # (real_label_low, real_label_high]); G trains against the hard value
    # real_label_val.
    real_label_val = 1
    real_label_low = 0.75
    real_label_high = 1.0
    fake_label_val = 0

    # Probability of flipping a target (label noise) while training D
    label_noise_prob = 0.05

    # Keep track of losses, accuracy, FID
    G_losses = []
    D_losses = []
    D_acc = []
    FIDs = []
    val_epochs = []

    def print_labels():
        for class_name in train_set.classes:
            print("{} -> {}".format(class_name,
                                    train_set.class_to_idx[class_name]))

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def validate(keep_images=True):
        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        output_source_images_path = val_images_path + "_" + str(opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            labels = torch.randint(0, num_classes, (batch_size, )).to(device)
            labels_onehot = F.one_hot(labels, num_classes)
            #labels_onehot = onehot[labels]

            noise = torch.cat((noise, labels_onehot.to(dtype=torch.float)),
                              1).view(-1, opt.latent_dim + num_classes, 1, 1)
            gen_images = gen(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_images_path, i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_images_path)
        return fid

    def sample_images(num_images, batches_done):
        # Sample noise
        z = torch.randn((num_classes * num_images, opt.latent_dim)).to(device)
        # Get labels ranging from 0 to n_classes for n rows
        labels = torch.zeros((num_classes * num_images, ),
                             dtype=torch.long).to(device)

        for i in range(num_classes):
            for j in range(num_images):
                labels[i * num_images + j] = i
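        # Labels are laid out class-major, so with nrow=num_images below each
        # row of the saved grid shows num_images samples of a single class.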

        labels_onehot = F.one_hot(labels, num_classes)
        #labels_onehot = onehot[labels]
        z = torch.cat((z, labels_onehot.to(dtype=torch.float)),
                      1).view(-1, opt.latent_dim + num_classes, 1, 1)
        sample_imgs = gen(z)
        vutils.save_image(sample_imgs.data,
                          "{}/{}.png".format(output_sample_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)

    def save_loss_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses, label="G")
        plt.plot(D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)
        plt.close()

    def save_acc_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Discriminator Accuracy")
        plt.plot(D_acc)
        plt.xlabel("iterations")
        plt.ylabel("accuracy")
        plt.savefig(path)
        plt.close()

    def save_fid_plot(FIDs, epochs, path):
        #N = len(FIDs)
        plt.figure(figsize=(10, 5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        #plt.xticks([i * 49 for i in range(1, N+1)])
        plt.savefig(path)
        plt.close()

    print("Label to class mapping:")
    print_labels()

    for epoch in range(1, opt.num_epochs + 1):
        for i, data in enumerate(dataloader, 0):

            images, class_labels = data
            images = images.to(device)
            class_labels = class_labels.to(device)
            class_labels_fill = fill[class_labels]

            batch_size = images.size(0)

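            # One-sided label smoothing: real targets for D are drawn
            # uniformly from (real_label_low, real_label_high] instead of 1.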
            real_label_smooth = (
                real_label_low - real_label_high) * torch.rand(
                    (batch_size, ), device=device) + real_label_high
            # BCELoss expects float targets, so fix the dtype explicitly
            real_label = torch.full((batch_size, ),
                                    real_label_val,
                                    dtype=torch.float,
                                    device=device)
            fake_label = torch.full((batch_size, ),
                                    fake_label_val,
                                    dtype=torch.float,
                                    device=device)

            ############################
            # Train Discriminator
            ###########################

            ## Train with all-real batch

            optimD.zero_grad()

            real_pred = disc(images, class_labels_fill).view(-1)

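            # With probability label_noise_prob, flip a real target to a fake
            # one so D occasionally sees mislabeled examples.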
            mask = torch.rand(
                (batch_size, ), device=device) <= label_noise_prob
            mask = mask.type(torch.float)
            noisy_label = torch.mul(1 - mask, real_label_smooth) + torch.mul(
                mask, fake_label)

            d_real_loss = adversarial_loss(real_pred, noisy_label)

            # Train with fake batch
            noise = torch.randn((batch_size, opt.latent_dim)).to(device)
            gen_class_labels = torch.randint(0, num_classes,
                                             (batch_size, )).to(device)
            gen_class_labels_onehot = F.one_hot(gen_class_labels, num_classes)
            #gen_class_labels_onehot = onehot[gen_class_labels]
            gen_class_labels_fill = fill[gen_class_labels]

            noise = torch.cat(
                (noise, gen_class_labels_onehot.to(dtype=torch.float)),
                1).view(-1, opt.latent_dim + num_classes, 1, 1)
            gen_images = gen(noise)
            fake_pred = disc(gen_images.detach(),
                             gen_class_labels_fill).view(-1)

            mask = torch.rand(
                (batch_size, ), device=device) <= label_noise_prob
            mask = mask.type(torch.float)
            noisy_label = torch.mul(1 - mask, fake_label) + torch.mul(
                mask, real_label_smooth)

            d_fake_loss = adversarial_loss(fake_pred, noisy_label)

            # Total discriminator loss
            d_loss = d_real_loss + d_fake_loss

            d_loss.backward()
            optimD.step()

            ############################
            # Train Generator
            ###########################

            optimG.zero_grad()

            validity = disc(gen_images, gen_class_labels_fill).view(-1)
            g_loss = adversarial_loss(validity, real_label)

            g_loss.backward()
            optimG.step()

            # Save losses and accuracy for plotting
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())

            # Output training stats
            if i % opt.print_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                    % (epoch, opt.num_epochs, i, len(dataloader),
                       d_loss.item(), g_loss.item()))

            batches_done = epoch * len(dataloader) + i

            # Generate and save sample images
            if (batches_done % opt.sample_interval
                    == 0) or ((epoch == opt.num_epochs) and
                              (i == len(dataloader) - 1)):
                # Put G in eval mode
                gen.eval()

                with torch.no_grad():
                    sample_images(opt.num_sample_images, batches_done)
                vutils.save_image(gen_images.data[:36],
                                  "{}/{}.png".format(output_train_images_path,
                                                     batches_done),
                                  nrow=6,
                                  padding=2,
                                  normalize=True)

                # Put G back in train mode
                gen.train()

        # Save model checkpoint
        if (epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0):
            print("Checkpoint at epoch {}".format(epoch))
            print("Saving generator model...")
            torch.save(
                gen.state_dict(),
                os.path.join(output_model_path,
                             "model_checkpoint_{}.pt".format(epoch)))
            print("Saving G & D loss plot...")
            save_loss_plot(
                os.path.join(opt.output_path, opt.version,
                             "loss_plot_{}.png".format(epoch)))
            print("Validating model...")
            gen.eval()
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(
                FIDs, val_epochs,
                os.path.join(opt.output_path, opt.version,
                             "fid_plot_{}.png".format(epoch)))
            gen.train()

    print("Saving final generator model...")
    torch.save(gen.state_dict(), os.path.join(output_model_path, "model.pt"))
    print("Done!")

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Validating final model...")
    gen.eval()
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs,
                  os.path.join(opt.output_path, opt.version, "fid_plot"))
Example #14
def get_and_process_songs(args):
    # stock_filename is computed up front so it is defined even when
    # --download_lyrics is not passed and the default file is read below.
    stock_filename = 'Lyrics_' + args.artist_name.replace(' ', '') + '.json'
    stock_filename = sanitize_filename(stock_filename)
    if args.download_lyrics:
        start = datetime.now(tz=TIMEZONE)
        print('{}| Beginning download'.format(start))
        genius = lyricsgenius.Genius(GENIUS_ACCESS_TOKEN, sleep_time=1)
        artist_tracks = force_search_artist(genius, args.artist_name, sleep_time=600,
                                            max_songs=None,
                                            sort='popularity', per_page=20,
                                            get_full_info=True,
                                            allow_name_change=True)
        print('{}| Finished download in {}'.format(datetime.now(tz=TIMEZONE), 
                                                   datetime.now(tz=TIMEZONE) - start))
        artist_tracks.save_lyrics()
    if args.load_path:
        if os.path.isdir(args.load_path):
            print('{}| Preparing mixed lyrics files'.format(datetime.now(tz=TIMEZONE)))
            text_and_target_from_dir(args.load_path, args.lookback)
            print('{}| Finished.'.format(datetime.now(tz=TIMEZONE)))
            sys.exit()
        else:
            genius_file = read_json(args.load_path)
    else:
        genius_file = read_json('./'+stock_filename)
    if not args.download_only:
        artist_lyrics = get_lyrics_from_json(genius_file, SONG_PART_REGEX)
        create_text_and_target(artist_lyrics, lookback=args.lookback)

if __name__ == "__main__":
    get_and_process_songs(get_setup_args())
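
A hedged usage sketch: assuming get_setup_args() returns an argparse Namespace, the pipeline can also be driven directly. The attribute names come from the snippet above; the values are placeholders.

from argparse import Namespace

get_and_process_songs(Namespace(
    artist_name="Some Artist",  # placeholder
    download_lyrics=False,      # skip the Genius download step
    load_path="./lyrics_dir",   # a directory triggers the mixed-lyrics path
    lookback=40,                # placeholder lookback window
    download_only=False,
))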
Example #15
    # save eval files
    save(args.train_eval_file, train_eval, message="train eval")
    save(args.test_eval_file, test_eval, message="test eval")
    save(args.dev_eval_file, dev_eval, message="dev eval")


def init():
    """initialize spacy in each process"""
    global nlp
    nlp = spacy.load('en', parser=False)


if __name__ == '__main__':
    # Get command-line args
    args, log = get_setup_args()

    # Download resources
    # download(args)

    # Import spacy language model
    nlp = spacy.load("en", parser=False)

    # Preprocess dataset
    args.train_file = url_to_data_path(args.train_url)
    args.dev_file = url_to_data_path(args.dev_url)
    args.test_file = url_to_data_path(args.test_url)
    glove_dir = url_to_data_path(args.glove_url.replace('.zip', ''))
    glove_ext = '.txt' if glove_dir.endswith(
        'd') else f'.{args.glove_dim}d.txt'
    args.glove_file = os.path.join(glove_dir,
Example #16
def main():

    set_random_seed()

    #cuda = True if torch.cuda.is_available() else False
    device, gpu_ids = util.get_available_devices()

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64
    # Size of feature maps in discriminator
    ndf = 64

    num_classes = opt.num_classes

    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version,
                                            "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Initialize generator and discriminator
    netG = dcgan.Generator(nc, nz, ngf).to(device)
    netG.apply(weights_init)
    netD = dcgan.Discriminator(nc, ndf).to(device)
    netD.apply(weights_init)

    # Create batch of latent vectors to visualize
    # the progress of the generator
    # sample_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1
    fake_label = 0

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.b1, opt.b2))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.b1, opt.b2))

    train_set = datasets.ImageFolder(root=train_images_path,
                                     transform=transforms.Compose([
                                         transforms.Resize(
                                             (opt.img_size, opt.img_size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers)

    # ----------
    #  Training
    # ----------

    G_losses = []
    D_losses = []
    FIDs = []
    val_epochs = []

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def validate(keep_images=True):
        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        output_source_images_path = val_images_path + "_" + str(opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1)).to(device)

            gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_images_path, i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_images_path)
        return fid

    def sample_images(num_images, batches_done):
        # Sample noise
        z = torch.randn((num_classes * num_images, nz, 1, 1)).to(device)
        sample_imgs = netG(z)
        vutils.save_image(sample_imgs.data,
                          "{}/{}.png".format(output_sample_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)

    def save_loss_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses, label="G")
        plt.plot(D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)
        plt.close()  # close the figure so repeated checkpoints don't leak

    def save_fid_plot(FIDs, epochs, path):
        #N = len(FIDs)
        plt.figure(figsize=(10, 5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        #plt.xticks([i * 49 for i in range(1, N+1)])
        plt.savefig(path)
        plt.close()

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(1, opt.num_epochs + 1):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_imgs = data[0].to(device)
            batch_size = real_imgs.size(0)
            # BCELoss expects float targets
            label = torch.full((batch_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)

            # Forward pass real batch through D
            output = netD(real_imgs).view(-1)

            # Calculate loss on all-real batch
            errD_real = criterion(output, label)

            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(batch_size, nz, 1, 1, device=device)

            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)

            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)

            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)

            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Save Losses for plotting
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Output training stats
            if i % opt.print_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f] [D(x): %.4f] [D(G(z)): %.4f / %.4f]"
                    % (epoch, opt.num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            batches_done = epoch * len(dataloader) + i

            if (batches_done % opt.sample_interval
                    == 0) or ((epoch == opt.num_epochs) and
                              (i == len(dataloader) - 1)):
                # Put G in eval mode
                netG.eval()

                with torch.no_grad():
                    sample_images(opt.num_sample_images, batches_done)
                vutils.save_image(fake.data[:25],
                                  "{}/{}.png".format(output_train_images_path,
                                                     batches_done),
                                  nrow=5,
                                  padding=2,
                                  normalize=True)

                # Put G back in train mode
                netG.train()

        # Save model checkpoint
        if (epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0):
            print("Checkpoint at epoch {}".format(epoch))
            print("Saving generator model...")
            torch.save(
                netG.state_dict(),
                os.path.join(output_model_path,
                             "model_checkpoint_{}.pt".format(epoch)))
            print("Saving G & D loss plot...")
            save_loss_plot(
                os.path.join(opt.output_path, opt.version,
                             "loss_plot_{}.png".format(epoch)))

            print("Validating model...")
            netG.eval()
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
                      "a") as f:
                f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(
                FIDs, val_epochs,
                os.path.join(opt.output_path, opt.version,
                             "fid_plot_{}.png".format(epoch)))
            netG.train()

    print("Saving final generator model...")
    torch.save(netG.state_dict(), os.path.join(output_model_path, "model.pt"))
    print("Done!")

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Validating final model...")
    netG.eval()
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
              "a") as f:
        f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs,
                  os.path.join(opt.output_path, opt.version, "fid_plot"))