Exemple #1
0
    def __init__(self, type, dataset, split, lr, save_path, l1_coef, l2_coef,
                 pre_trained_gen, pre_trained_disc, batch_size, num_workers,
                 epochs):
        """Set up the generator/discriminator pair, data loader and optimizers.

        Args:
            type: model variant name passed to ``gan_factory``; also stored as
                ``self.type`` (shadows the builtin; kept for caller
                compatibility).
            dataset: ``'birds'`` or ``'flowers'``; selects the dataset path
                from ``config.yaml``. Any other value prints a message and
                exits the process.
            split: split index forwarded to ``Text2ImageDataset``.
            lr: Adam learning rate for both networks (betas ``(0.5, 0.999)``).
            save_path: stored as ``self.save_path``.
            l1_coef: stored as ``self.l1_coef`` for the training loop.
            l2_coef: stored as ``self.l2_coef`` for the training loop.
            pre_trained_gen: optional generator checkpoint path; when falsy
                the generator is initialised with ``Utils.weights_init``.
            pre_trained_disc: same as ``pre_trained_gen``, for the
                discriminator.
            batch_size: ``DataLoader`` batch size.
            num_workers: ``DataLoader`` worker processes.
            epochs: stored as ``self.num_epochs``.
        """
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError in PyYAML 6; safe_load is all a config file needs.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory(type).cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory(type).cuda())

        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=split)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        # NOTE(review): shuffle=False means every training epoch sees the
        # batches in the same order — confirm this is intended.
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.logger = Logger()
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
        self.type = type
Exemple #2
0
def build_vocab(data_path, threshold):
    """Build a vocabulary wrapper from the captions of all three dataset splits.

    Captions from splits 0, 1 and 2 of the dataset at ``data_path`` are
    lower-cased and tokenized with NLTK, and token counts are accumulated
    across all splits. Words seen at least ``threshold`` times are kept.

    Args:
        data_path: dataset file opened by ``Text2ImageDataset``
            (e.g. ``"data/cub.h5"``).
        threshold: minimum total count for a word to enter the vocabulary.

    Returns:
        A ``Vocabulary`` containing ``<pad>``, ``<start>``, ``<end>``,
        ``<unk>`` followed by the retained words.
    """
    # A single counter for every split. Re-creating it inside the split loop
    # discarded the counts of all splits except the last one.
    counter = Counter()
    for split in range(3):
        dataset = Text2ImageDataset(data_path, split=split)
        keys = dataset.dataset_keys
        num_keys = len(keys)

        for i, key in enumerate(keys):
            example = dataset.dataset[dataset.split][key]
            caption = str(np.array(example['txt']).astype(str))
            counter.update(nltk.tokenize.word_tokenize(caption.lower()))

            if i % 1000 == 0:
                print("[%d/%d] Tokenized the captions." % (i, num_keys))

    # Words whose frequency is below 'threshold' are discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Special tokens come first so they get the lowest indices.
    vocab = Vocabulary()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    for word in words:
        vocab.add_word(word)
    return vocab
Exemple #3
0
def main():
    """Load split 0 of the dataset, build the model, then train or evaluate it.

    Everything is driven by the module-level ``FLAGS``. With ``FLAGS.train``
    set, the optimizer is created, the model is trained and then saved.
    Otherwise the model is loaded and evaluated with a batch size of 64.
    """
    device = torch.device("cuda:0" if FLAGS.cuda else "cpu")

    print('Loading data...\n')
    hdf5_file = os.path.join(FLAGS.data_dir, '{}.hdf5'.format(FLAGS.dataset))
    train_set = Text2ImageDataset(hdf5_file, split=0)
    dataloader = DataLoader(train_set,
                            batch_size=FLAGS.batch_size,
                            shuffle=True,
                            num_workers=8)

    print('Creating model...\n')
    model = Model(FLAGS.model, device, dataloader, FLAGS.channels,
                  FLAGS.l1_coef, FLAGS.l2_coef)

    if not FLAGS.train:
        # Evaluation-only run: restore saved weights first.
        model.load_from('')
        print('Evaluating...\n')
        model.eval(batch_size=64)
        return

    model.create_optim(FLAGS.lr)
    print('Training...\n')
    model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)
    model.save_to('')
    def __init__(self, dataset, split, lr, l1_coef, l2_coef,
                 batch_size, num_workers, epochs, optimization):
        """Create a DCGAN pair, its shuffled data loader and optimizers.

        Args:
            dataset: ``'birds'`` or ``'flowers'``; selects the file under
                ``Datasets/``. Any other value prints a message and exits.
            split: split index forwarded to ``Text2ImageDataset``.
            lr: Adam learning rate. The LBFGS branch ignores it and uses 1.
            l1_coef: stored as ``self.l1_coef``.
            l2_coef: stored as ``self.l2_coef``.
            batch_size: ``DataLoader`` batch size.
            num_workers: ``DataLoader`` worker processes.
            epochs: stored as ``self.num_epochs``.
            optimization: ``'adam'`` selects Adam; any other value selects
                LBFGS for both networks.
        """
        self.generator = torch.nn.DataParallel(dcgan.generator().cuda())
        self.discriminator = torch.nn.DataParallel(
            dcgan.discriminator().cuda())
        # Generator is initialised before the discriminator (RNG order).
        for network in (self.generator, self.discriminator):
            network.apply(Utils.weights_init)
        self.filename = dataset

        if dataset == 'birds':
            hdf5_path = 'Datasets/birds.hdf5'
        elif dataset == 'flowers':
            hdf5_path = 'Datasets/flowers.hdf5'
        else:
            print('Dataset not available, select either birds or flowers.')
            exit()
        self.dataset = Text2ImageDataset(hdf5_path, split=split)

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        if optimization == 'adam':
            optimizer_cls = torch.optim.Adam
            options = dict(lr=self.lr, betas=(self.beta1, 0.999))
        else:
            optimizer_cls = torch.optim.LBFGS
            options = dict(lr=1, max_iter=20, max_eval=None,
                           tolerance_grad=1e-05, tolerance_change=1e-09,
                           history_size=100, line_search_fn=None)
        self.optimG = optimizer_cls(self.generator.parameters(), **options)
        self.optimD = optimizer_cls(self.discriminator.parameters(), **options)

        self.logger = Logger()
        self.checkpoints_path = 'checkpoints'
Exemple #5
0
    def predict(self, trained=False, max_batches=1000):
        """Generate images from two generator checkpoints for comparison.

        For each step ``i`` in (0, 20, 50, 100, 200, 400, 790):

        * ``gen_<i>.pth`` from ``./Log/checkpoints/800_gan_cls`` is loaded
          into ``gan_cls_original.generator()``.
        * ``gen_<i + 200>.pth`` from ``./Log/checkpoints/800_gan_cls_new`` is
          loaded into ``gan_cls_new.generator()``.
        * Both are run on split 0 of ``./data/flowers.hdf5`` with the same
          noise, and one JPEG per caption is written under
          ``./results/<i>/original`` and ``./results/<i>/our``.

        Args:
            trained: must be True. The generators only exist once the saved
                checkpoints are loaded.
            max_batches: maximum number of batches processed per checkpoint
                step (default 1000, the previous hard-coded limit).

        Raises:
            ValueError: if ``trained`` is False. Previously this path crashed
                later with an UnboundLocalError on the missing generators.
        """
        if not trained:
            raise ValueError(
                'predict() requires trained=True: the generators are only '
                'built from the saved checkpoints.')

        def to_pil(fake_image):
            # Scale generator output to 0..255 (assumes tanh range [-1, 1])
            # and convert CHW -> HWC uint8 for PIL.
            return Image.fromarray(
                fake_image.data.mul_(127.5).add_(127.5).byte().permute(
                    1, 2, 0).cpu().numpy())

        test_data = Text2ImageDataset('./data/flowers.hdf5', split=0)
        test_data_loader = DataLoader(test_data, batch_size=40, shuffle=False)
        for i in (0, 20, 50, 100, 200, 400, 790):
            our_model_path = os.path.join('./Log/checkpoints/800_gan_cls_new',
                                          'gen_' + str(i + 200) + '.pth')
            original_model_path = os.path.join('./Log/checkpoints/800_gan_cls',
                                               'gen_' + str(i) + '.pth')
            our_save_path = os.path.join('./results/', str(i), 'our')
            original_save_path = os.path.join('./results/', str(i), 'original')
            # Image.save() does not create missing directories.
            os.makedirs(our_save_path, exist_ok=True)
            os.makedirs(original_save_path, exist_ok=True)

            # map_location lets GPU-saved checkpoints load when DEVICE is CPU.
            original_generator = gan_cls_original.generator().to(DEVICE)
            original_generator.load_state_dict(
                torch.load(original_model_path, map_location=DEVICE))
            our_generator = gan_cls_new.generator().to(DEVICE)
            our_generator.load_state_dict(
                torch.load(our_model_path, map_location=DEVICE))
            original_generator.eval()
            our_generator.eval()

            # Inference only: no autograd graphs need to be kept.
            with torch.no_grad():
                for count, sample in enumerate(test_data_loader, start=1):
                    if count > max_batches:
                        break
                    print(count)
                    right_embed = sample['right_embed'].float().to(DEVICE)
                    txt = sample['txt']

                    # Noise is drawn on the CPU (as before) so the random
                    # stream does not depend on the device.
                    noise = torch.randn(sample['right_images'].size(0), 100)
                    noise = noise.to(DEVICE).view(noise.size(0), 100, 1, 1)
                    original_fake_images = original_generator(right_embed,
                                                              noise)
                    our_fake_images = our_generator(right_embed, noise)

                    for original_fake_image, our_fake_image, t in zip(
                            original_fake_images, our_fake_images, txt):
                        # Caption becomes the file name: drop path
                        # separators and newlines, cap the length at 200.
                        name = t.replace("/", "").replace("\n", "")[:200]
                        to_pil(original_fake_image).save(
                            '{0}/{1}.jpg'.format(original_save_path, name))
                        to_pil(our_fake_image).save(
                            '{0}/{1}.jpg'.format(our_save_path, name))
Exemple #6
0
    def __init__(self, type, dataset, vis_screen, pre_trained_gen, batch_size,
                 num_workers, epochs):
        """Set up a generator and a non-shuffled loader over dataset split 1.

        Args:
            type: model variant name passed to
                ``gan_factory.generator_factory``; stored as ``self.type``.
            dataset: ``'birds'`` or ``'flowers'``; selects the dataset path
                from ``config.yaml``. Any other value prints a message and
                exits the process.
            vis_screen: forwarded to ``Logger``.
            pre_trained_gen: optional generator checkpoint path; when falsy
                the generator is initialised with ``Utils.weights_init``.
            batch_size: ``DataLoader`` batch size.
            num_workers: ``DataLoader`` worker processes.
            epochs: stored as ``self.num_epochs``.
        """
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError in PyYAML 6; safe_load is all a config file needs.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory(type).cuda())

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        # NOTE(review): birds reads 'birds_dataset_path' while flowers reads
        # 'flowers_val_dataset_path' — confirm the asymmetry is intended.
        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=1)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(
                config['flowers_val_dataset_path'], split=1)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_epochs = epochs

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers)

        self.logger = Logger(vis_screen)
        self.type = type
Exemple #7
0
def _progress_suffix(batch_idx, num_batches, iteration, bar, meters):
    """Return the progress-bar suffix shared by the training and validation loops.

    The header line shows the batch position, trainer iteration and the bar's
    elapsed/ETA times, wrapped in ``bcolors.HEADER``/``bcolors.ENDC``. It is
    followed by `` | name: avg`` for every meter in ``meters``, sorted by
    name, with a line break after every fifth entry.
    """
    parts = [
        bcolors.HEADER,
        '({batch}/{size}) Iter: {bt:} | Time: {total:}-{eta:}\n'.format(
            batch=batch_idx,
            size=num_batches,
            bt=iteration,
            total=bar.elapsed_td,
            eta=bar.eta_td,
        ),
        bcolors.ENDC,
    ]
    ordered = sorted(meters.items(), key=lambda x: x[0])
    for cnt, (l_name, l_value) in enumerate(ordered, start=1):
        parts.append(' | {name}: {val:.3f}'.format(name=l_name,
                                                   val=l_value.avg))
        if not cnt % 5:
            parts.append("\n")
    return "".join(parts)


def main():
    """Train, and optionally validate, the model selected by ``args.method``.

    The function proceeds in these steps:

    1. Parse the command line with the module-level ``parser``.
    2. Select the trainer class and build the image transform.
    3. Load the vocabulary and word embeddings.
    4. Build training/validation loaders (COCO, or ``data/cub.h5`` otherwise).
    5. Run ``args.num_epochs`` epochs of training, each followed by a
       validation pass when ``args.validate == "true"``.

    The models are saved with epoch ``-1`` at the end.

    Raises:
        ValueError: for an unknown method, an unknown loss function, or a
            ``common_emb_ratio`` outside [0, 1]. These checks were
            ``assert``s before, which ``python -O`` strips.
    """
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "NONE":
        args.comment = args.method

    validate = args.validate == "true"

    if args.method == "coupled_vae_gan":
        trainer = coupled_vae_gan_trainer.coupled_vae_gan_trainer
    elif args.method == "coupled_vae":
        trainer = coupled_vae_trainer.coupled_vae_trainer
    elif args.method == "wgan":
        trainer = wgan_trainer.wgan_trainer
    elif args.method == "seq_wgan":
        trainer = seq_wgan_trainer.wgan_trainer
    elif args.method == "skip_thoughts":
        trainer = skipthoughts_vae_gan_trainer.coupled_vae_gan_trainer
    else:
        raise ValueError("Invalid method")

    if args.text_criterion not in ("MSE", "Cosine", "Hinge", "NLLLoss"):
        raise ValueError('Invalid Loss Function')
    if args.cm_criterion not in ("MSE", "Cosine", "Hinge"):
        raise ValueError('Invalid Loss Function')
    if not 0 <= args.common_emb_ratio <= 1.0:
        raise ValueError('common_emb_ratio must be within [0, 1]')
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # Random crop + horizontal flip for augmentation, then per-channel
    # Normalize(0.5, 0.5), which maps ToTensor's [0, 1] output to [-1, 1].
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])
    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    if args.dataset != "coco":
        args.vocab_path = "./data/cub_vocab.pkl"

    print("Loading Vocabulary...")
    # NOTE: pickle can execute arbitrary code; only load trusted vocab files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    # A directory path gets the matching GloVe file name appended.
    # (endswith also copes with an empty path, unlike indexing [-1].)
    if emb_path.endswith('/'):
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")

    use_glove = args.use_glove == "true"
    if use_glove:
        emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
        word_emb = nn.Embedding(emb.size(0), emb.size(1))
        word_emb.weight = nn.Parameter(emb)
    else:
        word_emb = nn.Embedding(len(vocab), emb_size)

    # Freeze the embedding weights. This used to set requires_grad to True,
    # which left the embeddings trainable despite the flag.
    if args.fixed_embeddings == "true":
        word_emb.weight.requires_grad = False
    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    print("Building Data Loader For Training Set...")
    if args.dataset == 'coco':
        data_loader = get_loader(args.image_dir,
                                 args.caption_path,
                                 vocab,
                                 transform,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

        print("Building Data Loader For Validation Set...")
        val_loader = get_loader(args.valid_dir,
                                args.valid_caption_path,
                                vocab,
                                transform,
                                args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)

    else:
        data_path = "data/cub.h5"
        dataset = Text2ImageDataset(data_path,
                                    split=0,
                                    vocab=vocab,
                                    transform=transform)
        data_loader = DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

        dataset_val = Text2ImageDataset(data_path,
                                        split=1,
                                        vocab=vocab,
                                        transform=transform)
        val_loader = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers,
                                collate_fn=collate_fn)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the trainer...")
    model_trainer = trainer(args, word_emb, vocab)
    # </editor-fold>

    for epoch in range(args.num_epochs):

        # <editor-fold desc="Training">
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        end = time.time()

        bar = Bar(args.method if args.comment == "NONE" else args.method +
                  "/" + args.comment,
                  max=len(data_loader))

        model_trainer.set_train_models()
        model_trainer.create_losses_meter(model_trainer.losses)

        for i, (images, captions, lengths) in enumerate(data_loader):
            if model_trainer.load_models(epoch):
                break

            # The final batch is always skipped (presumably because it can
            # be smaller than batch_size).
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(torch.LongTensor(lengths))

            # Last argument is True every image_save_interval batches.
            model_trainer.forward(epoch, images, captions, lengths,
                                  not i % args.image_save_interval)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if not model_trainer.iteration % args.log_step:
                bar.suffix = _progress_suffix(i, len(data_loader),
                                              model_trainer.iteration, bar,
                                              model_trainer.losses)
                bar.next()

        bar.finish()
        # </editor-fold>

        if validate:
            print('EPOCH ::: VALIDATION ::: ' + str(epoch + 1))
            batch_time = AverageMeter()
            end = time.time()
            barName = args.method if args.comment == "NONE" else args.method + "/" + args.comment
            bar = Bar("VAL:" + barName, max=len(val_loader))

            model_trainer.set_eval_models()
            model_trainer.create_metrics_meter(model_trainer.metrics)

            for i, (images, captions, lengths) in enumerate(val_loader):
                if i == len(val_loader) - 1:
                    break

                images = to_var(images)
                # NOTE(review): unlike training, validation drops the first
                # caption token and passes `lengths` unconverted and
                # unadjusted — confirm this is what evaluate() expects.
                captions = to_var(captions[:, 1:])

                model_trainer.evaluate(epoch, images, captions, lengths,
                                       i == 0)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                bar.suffix = _progress_suffix(i, len(val_loader),
                                              model_trainer.iteration, bar,
                                              model_trainer.metrics)
                bar.next()

            bar.finish()

    model_trainer.save_models(-1)
Exemple #8
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Build the two-stage GAN plus caption generator/discriminator.

        Stage-1 (``'gan'``) and stage-2 (``'stage2_gan'``) networks are wrapped
        in ``DataParallel`` and moved to the GPU when the module-level
        ``is_cuda`` is set. Each network is loaded from its checkpoint when a
        path is given and initialised with ``Utils.weights_init`` otherwise.

        Args:
            type: stored as ``self.type``.
            dataset: ``'birds'`` or ``'flowers'``; selects the vocabulary
                pickle and the ``config.yaml`` dataset path. Any other value
                prints a message and exits.
            split: split index forwarded to ``Text2ImageDataset``.
            lr: Adam learning rate for all four image networks.
            diter: stored as ``self.DITER``.
            vis_screen: accepted for interface compatibility; unused here.
            save_path: stored as ``self.save_path``.
            l1_coef: stored as ``self.l1_coef``.
            l2_coef: stored as ``self.l2_coef``.
            pre_trained_gen: optional checkpoint path for the stage-1
                generator.
            pre_trained_disc: optional checkpoint path for the stage-1
                discriminator.
            batch_size: ``DataLoader`` batch size.
            num_workers: ``DataLoader`` worker processes.
            epochs: stored as ``self.num_epochs``.
            pre_trained_disc_B: optional checkpoint path for the stage-2
                discriminator.
            pre_trained_gen_B: optional checkpoint path for the stage-2
                generator.
        """
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError in PyYAML 6; safe_load is all a config file needs.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        def on_device(module):
            # Move to the GPU only when the module-level is_cuda flag is set.
            return module.cuda() if is_cuda else module

        def init_or_load(network, checkpoint):
            # Load the checkpoint when a path is given, else random-init.
            if checkpoint:
                network.load_state_dict(torch.load(checkpoint))
            else:
                network.apply(Utils.weights_init)

        # forward gan (construction order kept: G, D, G2, D2)
        self.generator = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('gan')))
        self.discriminator = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('gan')))
        self.generator2 = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('stage2_gan')))
        self.discriminator2 = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('stage2_gan')))

        init_or_load(self.discriminator, pre_trained_disc)
        init_or_load(self.generator, pre_trained_gen)
        init_or_load(self.discriminator2, pre_trained_disc_B)
        init_or_load(self.generator2, pre_trained_gen_B)

        # NOTE: pickle can execute arbitrary code; vocab files must be trusted.
        if dataset == 'birds':
            with open('./data/birds_vocab.pkl', 'rb') as f:
                self.vocab = pickle.load(f)
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             dataset_type='birds',
                                             vocab=self.vocab,
                                             split=split)
        elif dataset == 'flowers':
            with open('./data/flowers_vocab.pkl', 'rb') as f:
                self.vocab = pickle.load(f)
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             dataset_type='flowers',
                                             vocab=self.vocab,
                                             split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.num_workers = num_workers

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers,
                                      collate_fn=collate_fn)

        def make_adam(network):
            # Adam with the shared learning rate and DCGAN-style beta1.
            return torch.optim.Adam(network.parameters(),
                                    lr=self.lr,
                                    betas=(self.beta1, 0.999))

        self.optimD = make_adam(self.discriminator)
        self.optimG = make_adam(self.generator)
        self.optimD2 = make_adam(self.discriminator2)
        self.optimG2 = make_adam(self.generator2)

        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type

        # TODO: put these as runtime.py params
        self.embed_size = 256
        self.hidden_size = 512
        self.num_layers = 1

        self.gen_pretrain_num_epochs = 100
        self.disc_pretrain_num_epochs = 20

        self.figure_path = './figures/'
        self.caption_generator = on_device(
            CaptionGenerator(self.embed_size, self.hidden_size,
                             len(self.vocab), self.num_layers))
        self.caption_discriminator = on_device(
            CaptionDiscriminator(self.embed_size, self.hidden_size,
                                 len(self.vocab), self.num_layers))

        pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
        pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'

        # Caption models are warm-started only if their checkpoints exist.
        if os.path.exists(pretrained_caption_gen):
            print('loaded pretrained caption generator')
            self.caption_generator.load_state_dict(
                torch.load(pretrained_caption_gen))

        if os.path.exists(pretrained_caption_disc):
            print('loaded pretrained caption discriminator')
            self.caption_discriminator.load_state_dict(
                torch.load(pretrained_caption_disc))

        # Adam with PyTorch's default hyper-parameters.
        self.optim_captionG = torch.optim.Adam(
            list(self.caption_generator.parameters()))
        self.optim_captionD = torch.optim.Adam(
            list(self.caption_discriminator.parameters()))
Exemple #9
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Build the two-stage GAN, its shuffled data loader and optimizers.

        Stage-1 (``'gan'``) and stage-2 (``'stage2_gan'``) networks are wrapped
        in ``DataParallel`` and moved to the GPU when the module-level
        ``is_cuda`` is set. Each network is loaded from its checkpoint when a
        path is given and initialised with ``Utils.weights_init`` otherwise.

        Args:
            type: stored as ``self.type``.
            dataset: ``'birds'`` or ``'flowers'``; selects the ``config.yaml``
                dataset path. Any other value prints a message and exits.
            split: split index forwarded to ``Text2ImageDataset``.
            lr: Adam learning rate for all four networks.
            diter: stored as ``self.DITER``.
            vis_screen: forwarded to ``Logger``.
            save_path: stored as ``self.save_path``.
            l1_coef: stored as ``self.l1_coef``.
            l2_coef: stored as ``self.l2_coef``.
            pre_trained_gen: optional checkpoint path for the stage-1
                generator.
            pre_trained_disc: optional checkpoint path for the stage-1
                discriminator.
            batch_size: ``DataLoader`` batch size.
            num_workers: ``DataLoader`` worker processes.
            epochs: stored as ``self.num_epochs``.
            pre_trained_disc_B: optional checkpoint path for the stage-2
                discriminator.
            pre_trained_gen_B: optional checkpoint path for the stage-2
                generator.
        """
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError in PyYAML 6; safe_load is all a config file needs.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        def on_device(module):
            # Move to the GPU only when the module-level is_cuda flag is set.
            return module.cuda() if is_cuda else module

        def init_or_load(network, checkpoint):
            # Load the checkpoint when a path is given, else random-init.
            if checkpoint:
                network.load_state_dict(torch.load(checkpoint))
            else:
                network.apply(Utils.weights_init)

        # forward gan (construction order kept: G, D, G2, D2)
        self.generator = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('gan')))
        self.discriminator = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('gan')))
        self.generator2 = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('stage2_gan')))
        self.discriminator2 = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('stage2_gan')))

        # TODO: add the inverse GAN ('inverse_gan' generator/discriminator)
        # with its type passed in from runtime.

        init_or_load(self.discriminator, pre_trained_disc)
        init_or_load(self.generator, pre_trained_gen)
        init_or_load(self.discriminator2, pre_trained_disc_B)
        init_or_load(self.generator2, pre_trained_gen_B)

        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             dataset_type='birds',
                                             split=split)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             dataset_type='flowers',
                                             split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        def make_adam(network):
            # Adam with the shared learning rate and DCGAN-style beta1.
            return torch.optim.Adam(network.parameters(),
                                    lr=self.lr,
                                    betas=(self.beta1, 0.999))

        self.optimD = make_adam(self.discriminator)
        self.optimG = make_adam(self.generator)
        self.optimD2 = make_adam(self.discriminator2)
        self.optimG2 = make_adam(self.generator2)

        self.logger = Logger(vis_screen)
        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type
Exemple #10
0
    def __init__(self, type, dataset, split, lr, lr_lower_boundary,
                 lr_update_type, lr_update_step, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, h, scale_size, num_channels,
                 k, lambda_k, gamma, project, concat):
        """Build the networks, dataset, data loader, optimizers and logger.

        Args:
            type: Model variant name forwarded to ``gan_factory``.
            dataset: One of ``'birds'``, ``'flowers'`` or ``'youtubers'``;
                selects which dataset path from ``config.yaml`` is loaded.
            split: Dataset split forwarded to the dataset class.
            lr: Stored as ``self.lr``. The optimizers created here use their
                own fixed learning rates, not this value.
            lr_lower_boundary, lr_update_type, lr_update_step: Stored as-is.
            diter: Stored as ``self.DITER``.
            vis_screen: Forwarded to ``Logger``.
            save_path: Forwarded to ``Logger`` and stored as ``self.save_path``.
            l1_coef, l2_coef: Stored as-is.
            pre_trained_gen, pre_trained_disc: Optional checkpoint paths. When
                falsy, the network is initialized with ``Utils.weights_init``.
            batch_size, num_workers: ``DataLoader`` settings (also stored).
            epochs: Stored as ``self.num_epochs``.
            h, scale_size, num_channels: Forwarded to ``gan_factory`` and stored.
            k, lambda_k, gamma: Stored as-is.
            project, concat: Stored as ``self.apply_projection`` and
                ``self.apply_concat``.

        Exits the process with a non-zero status if ``dataset`` is not one of
        the supported names.
        """
        import sys

        # safe_load: yaml.load without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        self.generator = gan_factory.generator_factory(type, dataset,
                                                       batch_size, h,
                                                       scale_size,
                                                       num_channels).cuda()
        self.discriminator = gan_factory.discriminator_factory(
            type, batch_size, h, scale_size, num_channels).cuda()
        print(self.discriminator)

        # Either resume from a checkpoint or start from fresh initial weights.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        self.dataset_name = dataset
        if self.dataset_name == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=split)
        elif self.dataset_name == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             split=split)
        elif self.dataset_name == 'youtubers':
            self.dataset = OneHot2YoutubersDataset(
                config['youtubers_dataset_path'],
                transform=Rescale(64),
                split=split)
        else:
            # sys.exit(msg) prints to stderr and exits with status 1; a bare
            # exit() would report success (status 0) on this error path.
            sys.exit('Dataset not supported, please select either birds, '
                     'flowers or youtubers.')

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.lr_lower_boundary = lr_lower_boundary
        self.lr_update_type = lr_update_type
        self.lr_update_step = lr_update_step
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.apply_projection = project
        self.apply_concat = concat

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.h = h
        self.scale_size = scale_size
        self.num_channels = num_channels
        self.k = k
        self.lambda_k = lambda_k
        self.gamma = gamma

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        # Only parameters with requires_grad set are handed to the optimizers.
        # NOTE(review): the learning rates are fixed here (D: 4e-4, G: 1e-4)
        # rather than taken from ``lr``; presumably intentional, so confirm.
        self.optimD = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.discriminator.parameters()),
                                       lr=0.0004,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.generator.parameters()),
                                       lr=0.0001,
                                       betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen, save_path)
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
        self.type = type
Exemple #11
0
    def __init__(self, type, dataset, split, lr, diter, mode, vis_screen,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs):
        """Build the networks, dataset, data loader, optimizers and logger.

        Args:
            type: Model variant name forwarded to ``gan_factory``.
            dataset: Must be ``'oxford'``; any other value exits the process.
            split: Dataset split forwarded to ``Text2ImageDataset``.
            lr: Learning rate for both Adam optimizers.
            diter: Stored as ``self.DITER``.
            mode: ``False`` selects the ``oxford_dataset_path_train`` entry of
                ``config.yaml``; any other value selects
                ``oxford_dataset_path_test``.
            vis_screen: Forwarded to ``Logger``.
            l1_coef, l2_coef: Stored as-is.
            pre_trained_gen, pre_trained_disc: Optional checkpoint paths. When
                falsy, the network is initialized with ``Utils.weights_init``.
            batch_size, num_workers: ``DataLoader`` settings (also stored).
            epochs: Stored as ``self.num_epochs``.

        Outputs are rooted at ``output/<dataset>_<local timestamp>``, with
        ``checkpoints`` and ``images`` subdirectory paths stored on ``self``.
        """
        import sys

        # safe_load: yaml.load without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6.
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory(type).cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory(type).cuda())

        print(self.generator)

        # Either resume from a checkpoint or start from fresh initial weights.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if dataset == 'oxford':
            # Identity check kept deliberately: only the literal False
            # selects the training file.
            if mode is False:
                path_key = 'oxford_dataset_path_train'
            else:
                path_key = 'oxford_dataset_path_test'
            self.dataset = Text2ImageDataset(config[path_key], split=split)
        else:
            # sys.exit(msg) prints to stderr and exits with status 1; a bare
            # exit() would report success (status 0) on this error path.
            sys.exit('Dataset not supported, please select oxford.')

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen)
        # Each run gets its own output directory keyed by local time.
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
        self.save_path = 'output/%s_%s' % (dataset, timestamp)
        self.checkpoints_path = os.path.join(self.save_path, 'checkpoints')
        self.image_dir = os.path.join(self.save_path, 'images')
        self.type = type
Exemple #12
0
    def __init__(self, dataset, split, lr, save_path, l1_coef, l2_coef,
                 pre_trained_gen, pre_trained_disc, val_pre_trained_gen,
                 val_pre_trained_disc, batch_size, num_workers, epochs,
                 dataset_paths, arrangement, sampling):
        """Build the 'gan' networks, dataset, data loader and optimizers.

        Args:
            dataset: ``'birds'`` or ``'flowers'`` (legacy ``Text2ImageDataset``
                with paths from ``config.yaml``) or ``'live'``
                (``Text2ImageDataset2`` with paths from ``dataset_paths``).
            split: Dataset split forwarded to the dataset class.
            lr: Learning rate for both Adam optimizers.
            save_path: Stored as ``self.save_path``.
            l1_coef, l2_coef: Stored as-is.
            pre_trained_gen, pre_trained_disc: Optional checkpoint paths. When
                falsy, the network is initialized with ``Utils.weights_init``.
            val_pre_trained_gen, val_pre_trained_disc: Stored as-is.
            batch_size, num_workers: ``DataLoader`` settings (also stored).
            epochs: Stored as ``self.num_epochs``.
            dataset_paths: Mapping with ``datasetFile``, ``imagesDir`` and
                ``textDir`` keys; only read for ``dataset == 'live'``.
            arrangement, sampling: Stored and forwarded to
                ``Text2ImageDataset2``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        # config.yaml is read unconditionally but only used by the legacy
        # 'birds'/'flowers' branches (backward compatibility with
        # Text2ImageDataset).
        with open('config.yaml', 'r') as f:
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan').cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan').cuda())

        # Either resume from a checkpoint or start from fresh initial weights.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        self.val_pre_trained_gen = val_pre_trained_gen
        self.val_pre_trained_disc = val_pre_trained_disc
        self.arrangement = arrangement
        self.sampling = sampling

        if dataset == 'birds':
            # Legacy dataset, e.g.
            # '...\Text2Image\datasets\ee285f-public\caltech_ucsd_birds\birds.hdf5'
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=split)
        elif dataset == 'flowers':
            # Legacy dataset, e.g.
            # '...\Text2Image\datasets\ee285f-public\oxford_flowers\flowers.hdf5'
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             split=split)
        elif dataset == 'live':
            self.dataset_dict = easydict.EasyDict(dataset_paths)
            self.dataset = Text2ImageDataset2(
                datasetFile=self.dataset_dict.datasetFile,
                imagesDir=self.dataset_dict.imagesDir,
                textDir=self.dataset_dict.textDir,
                split=split,
                arrangement=arrangement,
                sampling=sampling)
        else:
            # Fail here: continuing without self.dataset would only crash a
            # few lines below with an unrelated AttributeError.
            raise ValueError('Dataset not supported: %r' % (dataset,))

        print('Images =', len(self.dataset))
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        # shuffle=True reshuffles the dataset at every epoch.
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path

        # CSV file names for training ('_t') and validation ('_v') losses.
        self.loss_filename_t = 'gd_loss_t.csv'
        self.loss_filename_v = 'gd_loss_v.csv'
Exemple #13
0
    def __init__(self,
                 dataset='flowers',
                 split=0,
                 lr=2e-4,
                 diter=5,
                 save_path='./Log',
                 l1_coef=90,
                 l2_coef=100,
                 pre_trained_gen=False,
                 pre_trained_disc=False,
                 batch_size=64,
                 num_workers=16,
                 epochs=800):
        self.generator = gan_cls_new.generator().to(DEVICE)
        self.discriminator = gan_cls_new.discriminator().to(DEVICE)
        if pre_trained_disc:
            self.discriminator.load_state_dict(
                torch.load('./Log/checkpoints/disc_190.pth'))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(
                torch.load('./Log/checkpoints/gen_190.pth'))
        else:
            self.generator.apply(Utils.weights_init)

        # choose smaller flower data set
        if dataset == 'flowers':
            self.dataset = Text2ImageDataset('./data/flowers.hdf5',
                                             split=split)
        elif dataset == 'birds':
            self.dataset = Text2ImageDataset('./data/birds.hdf5', split=split)
        else:
            print(
                'Data not supported, please select either birds.hdf5 or flowers.hdf5'
            )
            exit()
        # print(self.dataset.__len__()) # 29390 training samples
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.checkpoints_path = 'checkpoints/2gen_800epochs'
        self.save_path = save_path