def __init__(self, type, dataset, split, lr, save_path, l1_coef, l2_coef,
             pre_trained_gen, pre_trained_disc, batch_size, num_workers,
             epochs):
    """Set up the GAN trainer: networks, dataset, data loader, optimizers.

    `type` selects the GAN variant from gan_factory, `dataset` is
    'birds' or 'flowers', `split` picks the hdf5 split; the remaining
    arguments are hyper-parameters and optional checkpoint paths.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory(type).cuda())
    self.discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory(type).cuda())

    # Resume from checkpoints when provided, otherwise fresh weight init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5  # Adam beta1, per the DCGAN recipe
    self.num_epochs = epochs
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    # NOTE(review): shuffle=False here while sibling trainers shuffle —
    # confirm deterministic ordering is intended for training.
    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)

    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))

    self.logger = Logger()
    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.type = type
def build_vocab(data_path, threshold):
    """Build a simple vocabulary wrapper over every caption in *data_path*.

    Word frequencies are accumulated across all three dataset splits
    (train/val/test); words occurring fewer than *threshold* times are
    discarded.

    Bug fixed: the Counter was previously re-created inside the split
    loop, so only the final split's captions contributed to the
    vocabulary.
    """
    counter = Counter()
    for split in range(3):
        dataset = Text2ImageDataset(data_path, split=split)
        keys = dataset.dataset_keys
        for i, key in enumerate(keys):
            example = dataset.dataset[dataset.split][key]
            caption = str(np.array(example['txt']).astype(str))
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)
            if i % 1000 == 0:
                print("[%d/%d] Tokenized the captions." % (i, len(keys)))

    # If the word frequency is less than 'threshold', the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Special tokens go in first so they get fixed, well-known indices.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')
    for word in words:
        vocab.add_word(word)
    return vocab
def main():
    """Entry point: build the data loader and model, then train or evaluate."""
    device = torch.device("cuda:0" if FLAGS.cuda else "cpu")

    print('Loading data...\n')
    h5_path = os.path.join(FLAGS.data_dir, '{}.hdf5'.format(FLAGS.dataset))
    dataloader = DataLoader(Text2ImageDataset(h5_path, split=0),
                            batch_size=FLAGS.batch_size,
                            shuffle=True,
                            num_workers=8)

    print('Creating model...\n')
    model = Model(FLAGS.model, device, dataloader, FLAGS.channels,
                  FLAGS.l1_coef, FLAGS.l2_coef)

    # Evaluation path: load a saved model and score it.
    if not FLAGS.train:
        model.load_from('')
        print('Evaluating...\n')
        model.eval(batch_size=64)
        return

    # Training path: set up the optimizer, train, then persist.
    model.create_optim(FLAGS.lr)
    print('Training...\n')
    model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)
    model.save_to('')
def __init__(self, dataset, split, lr, l1_coef, l2_coef, batch_size,
             num_workers, epochs, optimization):
    """Trainer setup: a DCGAN pair, the named dataset, and the chosen optimizer.

    `optimization` is 'adam' for the usual Adam pair; anything else
    selects LBFGS with its default-style settings.
    """
    self.generator = torch.nn.DataParallel(dcgan.generator().cuda())
    self.discriminator = torch.nn.DataParallel(dcgan.discriminator().cuda())
    self.generator.apply(Utils.weights_init)
    self.discriminator.apply(Utils.weights_init)

    self.filename = dataset

    # Resolve the dataset name to its hdf5 file.
    dataset_files = {
        'birds': 'Datasets/birds.hdf5',
        'flowers': 'Datasets/flowers.hdf5',
    }
    if dataset not in dataset_files:
        print('Dataset not available, select either birds or flowers.')
        exit()
    self.dataset = Text2ImageDataset(dataset_files[dataset], split=split)

    # Hyper-parameters.
    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)

    if optimization == 'adam':
        adam_kwargs = dict(lr=self.lr, betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       **adam_kwargs)
        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       **adam_kwargs)
    else:
        lbfgs_kwargs = dict(lr=1, max_iter=20, max_eval=None,
                            tolerance_grad=1e-05, tolerance_change=1e-09,
                            history_size=100, line_search_fn=None)
        self.optimG = torch.optim.LBFGS(self.generator.parameters(),
                                        **lbfgs_kwargs)
        self.optimD = torch.optim.LBFGS(self.discriminator.parameters(),
                                        **lbfgs_kwargs)

    self.logger = Logger()
    self.checkpoints_path = 'checkpoints'
def predict(self, trained=False):
    """Generate sample images from saved generator checkpoints.

    For each checkpoint index, loads both the 'original' and 'our'
    generator weights, runs them over the flowers test split, and saves
    the generated images named by their caption text.

    NOTE(review): the whole generation body sits under `if trained:`;
    with trained=False this method only builds the loader and paths.
    """
    # making prediction
    test_data = Text2ImageDataset('./data/flowers.hdf5', split=0)
    test_data_loader = DataLoader(test_data, batch_size=40, shuffle=False)
    for i in (0, 20, 50, 100, 200, 400, 790):
        # 'our' checkpoints are offset by 200 epochs relative to 'original'.
        ourmodel = 'gen_' + str(i + 200) + '.pth'
        our_model_path = os.path.join('./Log/checkpoints/800_gan_cls_new',
                                      ourmodel)
        originalmodel = 'gen_' + str(i) + '.pth'
        original_model_path = os.path.join('./Log/checkpoints/800_gan_cls',
                                           originalmodel)
        our_save_path = os.path.join('./results/', str(i), 'our')
        original_save_path = os.path.join('./results/', str(i), 'original')
        # loading trained model
        if trained:
            # loading trained model for prediction
            # construct model
            original_generator = gan_cls_original.generator().to(DEVICE)
            original_generator.load_state_dict(
                torch.load(original_model_path))
            our_generator = gan_cls_new.generator().to(DEVICE)
            our_generator.load_state_dict(torch.load(our_model_path))
            original_generator.eval()
            our_generator.eval()
            # self.generator.eval()
            count = 0
            for sample in test_data_loader:
                # NOTE(review): stops after 1000 batches; an earlier
                # comment said "100 batches" — confirm the intended cap.
                count += 1
                if count > 1000:
                    break
                print(count)
                right_images = sample['right_images']
                right_embed = sample['right_embed']
                txt = sample['txt']
                right_images = Variable(right_images.float()).cuda()
                right_embed = Variable(right_embed.float()).cuda()
                # Sample a fresh noise vector per image in the batch.
                noise = Variable(torch.randn(right_images.size(0),
                                             100)).cuda()
                noise = noise.view(noise.size(0), 100, 1, 1)
                original_fake_images = original_generator(right_embed, noise)
                our_fake_images = our_generator(right_embed, noise)
                # Denormalize from [-1, 1] to [0, 255] and save each image,
                # using the (sanitized, truncated) caption as the filename.
                for original_fake_image, our_fake_image, t in zip(
                        original_fake_images, our_fake_images, txt):
                    original_im = Image.fromarray(
                        original_fake_image.data.mul_(127.5).add_(
                            127.5).byte().permute(1, 2, 0).cpu().numpy())
                    our_im = Image.fromarray(
                        our_fake_image.data.mul_(127.5).add_(
                            127.5).byte().permute(1, 2, 0).cpu().numpy())
                    t = t.replace("/", "")
                    original_im.save('{0}/{1}.jpg'.format(
                        original_save_path, t.replace("\n", "")[:200]))
                    our_im.save('{0}/{1}.jpg'.format(
                        our_save_path, t.replace("\n", "")[:200]))
def __init__(self, type, dataset, vis_screen, pre_trained_gen, batch_size,
             num_workers, epochs):
    """Evaluation-side setup: generator only, on the validation split.

    `split=1` is hard-wired (validation); no discriminator or optimizers
    are created here.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory(type).cuda())
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=1)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset(
            config['flowers_val_dataset_path'], split=1)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.num_epochs = epochs
    # No shuffling: evaluation order is kept deterministic.
    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)
    self.logger = Logger(vis_screen)
    self.type = type
def main():
    """Train (and optionally validate) a coupled VAE/GAN model.

    Parses CLI args, builds the vocabulary/embeddings, data loaders and
    the method-specific trainer, then runs the epoch loop.

    Fixes applied:
      * removed a stray `txt_rc_loss = self.networks[...]` statement —
        `self`, `captions` and `txt2txt_out` are undefined at module
        level, so the line was a guaranteed NameError;
      * the "freeze embeddings" branch set requires_grad = True, which
        left the weights trainable; it now sets False as the flag and
        comment state.
    """
    args = parser.parse_args()

    # --- Initialization -------------------------------------------------
    if args.comment == "NONE":
        args.comment = args.method
    validate = args.validate == "true"

    # Pick the trainer class for the requested method.
    if args.method == "coupled_vae_gan":
        trainer = coupled_vae_gan_trainer.coupled_vae_gan_trainer
    elif args.method == "coupled_vae":
        trainer = coupled_vae_trainer.coupled_vae_trainer
    elif args.method == "wgan":
        trainer = wgan_trainer.wgan_trainer
    elif args.method == "seq_wgan":
        trainer = seq_wgan_trainer.wgan_trainer
    elif args.method == "skip_thoughts":
        trainer = skipthoughts_vae_gan_trainer.coupled_vae_gan_trainer
    else:
        assert False, "Invalid method"

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), \
        'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), \
        'Invalid Loss Function'
    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0

    # --- Image preprocessing --------------------------------------------
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])

    # --- Vocabulary and embeddings --------------------------------------
    if args.dataset != "coco":
        args.vocab_path = "./data/cub_vocab.pkl"
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        # A directory was given: append the standard GloVe filename.
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    use_glove = args.use_glove == "true"
    if use_glove:
        emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
        word_emb = nn.Embedding(emb.size(0), emb.size(1))
        word_emb.weight = nn.Parameter(emb)
    else:
        word_emb = nn.Embedding(len(vocab), emb_size)

    # Freeze weights (bug fix: was requires_grad = True, which kept the
    # embeddings trainable despite the flag).
    if args.fixed_embeddings == "true":
        word_emb.weight.requires_grad = False

    # --- Data loaders ----------------------------------------------------
    print("Building Data Loader For Test Set...")
    if args.dataset == 'coco':
        data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                                 transform, args.batch_size, shuffle=True,
                                 num_workers=args.num_workers)
        print("Building Data Loader For Validation Set...")
        val_loader = get_loader(args.valid_dir, args.valid_caption_path,
                                vocab, transform, args.batch_size,
                                shuffle=True, num_workers=args.num_workers)
    else:
        data_path = "data/cub.h5"
        dataset = Text2ImageDataset(data_path, split=0, vocab=vocab,
                                    transform=transform)
        data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)
        dataset_val = Text2ImageDataset(data_path, split=1, vocab=vocab,
                                        transform=transform)
        val_loader = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers,
                                collate_fn=collate_fn)

    # --- Trainer ---------------------------------------------------------
    print("Setting up the trainer...")
    model_trainer = trainer(args, word_emb, vocab)

    for epoch in range(args.num_epochs):
        # TRAINING
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        end = time.time()
        bar = Bar(args.method if args.comment == "NONE"
                  else args.method + "/" + args.comment,
                  max=len(data_loader))
        model_trainer.set_train_models()
        model_trainer.create_losses_meter(model_trainer.losses)

        for i, (images, captions, lengths) in enumerate(data_loader):
            if model_trainer.load_models(epoch):
                break
            # Skip the final (possibly ragged) batch.
            if i == len(data_loader) - 1:
                break
            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(torch.LongTensor(lengths))

            model_trainer.forward(epoch, images, captions, lengths,
                                  not i % args.image_save_interval)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if not model_trainer.iteration % args.log_step:
                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += ('({batch}/{size}) Iter: {bt:} | '
                               'Time: {total:}-{eta:}\n').format(
                    batch=i,
                    size=len(data_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC
                cnt = 0
                for l_name, l_value in sorted(model_trainer.losses.items(),
                                              key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name, val=l_value.avg)
                    if not cnt % 5:
                        bar.suffix += "\n"
            bar.next()
        bar.finish()

        # VALIDATION
        if validate:
            print('EPOCH ::: VALIDATION ::: ' + str(epoch + 1))
            batch_time = AverageMeter()
            end = time.time()
            barName = (args.method if args.comment == "NONE"
                       else args.method + "/" + args.comment)
            barName = "VAL:" + barName
            bar = Bar(barName, max=len(val_loader))
            model_trainer.set_eval_models()
            model_trainer.create_metrics_meter(model_trainer.metrics)

            for i, (images, captions, lengths) in enumerate(val_loader):
                if i == len(val_loader) - 1:
                    break
                images = to_var(images)
                # Drop the <start> token for evaluation.
                captions = to_var(captions[:, 1:])

                model_trainer.evaluate(epoch, images, captions, lengths,
                                       i == 0)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += ('({batch}/{size}) Iter: {bt:} | '
                               'Time: {total:}-{eta:}\n').format(
                    batch=i,
                    size=len(val_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC
                cnt = 0
                for l_name, l_value in sorted(
                        model_trainer.metrics.items(), key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name, val=l_value.avg)
                    if not cnt % 5:
                        bar.suffix += "\n"
                bar.next()
            bar.finish()

    model_trainer.save_models(-1)
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
             batch_size, num_workers, epochs, pre_trained_disc_B,
             pre_trained_gen_B):
    """Set up a two-stage GAN plus a caption generator/discriminator pair.

    Stage-1 ('gan') and stage-2 ('stage2_gan') networks each get their
    own optimizer; the `_B` checkpoint arguments belong to stage 2.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # forward gan (both stages), on GPU when available
    if is_cuda:
        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan').cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan').cuda())
        self.generator2 = torch.nn.DataParallel(
            gan_factory.generator_factory('stage2_gan').cuda())
        self.discriminator2 = torch.nn.DataParallel(
            gan_factory.discriminator_factory('stage2_gan').cuda())
    else:
        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan'))
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan'))
        self.generator2 = torch.nn.DataParallel(
            gan_factory.generator_factory('stage2_gan'))
        self.discriminator2 = torch.nn.DataParallel(
            gan_factory.discriminator_factory('stage2_gan'))

    # Resume each network from its checkpoint when given, else re-init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)
    if pre_trained_disc_B:
        self.discriminator2.load_state_dict(torch.load(pre_trained_disc_B))
    else:
        self.discriminator2.apply(Utils.weights_init)
    if pre_trained_gen_B:
        self.generator2.load_state_dict(torch.load(pre_trained_gen_B))
    else:
        self.generator2.apply(Utils.weights_init)

    # Dataset plus its pickled vocabulary.
    if dataset == 'birds':
        with open('./data/birds_vocab.pkl', 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         dataset_type='birds',
                                         vocab=self.vocab, split=split)
    elif dataset == 'flowers':
        with open('./data/flowers_vocab.pkl', 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         dataset_type='flowers',
                                         vocab=self.vocab, split=split)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter  # discriminator iterations per generator step
    self.num_workers = num_workers
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers,
                                  collate_fn=collate_fn)

    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                    lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                    lr=self.lr, betas=(self.beta1, 0.999))

    self.checkpoints_path = './checkpoints/'
    self.save_path = save_path
    self.type = type

    # TODO: put these as runtime.py params
    self.embed_size = 256
    self.hidden_size = 512
    self.num_layers = 1
    self.gen_pretrain_num_epochs = 100
    self.disc_pretrain_num_epochs = 20
    self.figure_path = './figures/'

    # Caption GAN pair (generator + discriminator over captions).
    if is_cuda:
        self.caption_generator = CaptionGenerator(
            self.embed_size, self.hidden_size, len(self.vocab),
            self.num_layers).cuda()
        self.caption_discriminator = CaptionDiscriminator(
            self.embed_size, self.hidden_size, len(self.vocab),
            self.num_layers).cuda()
    else:
        self.caption_generator = CaptionGenerator(
            self.embed_size, self.hidden_size, len(self.vocab),
            self.num_layers)
        self.caption_discriminator = CaptionDiscriminator(
            self.embed_size, self.hidden_size, len(self.vocab),
            self.num_layers)

    # Warm-start the caption networks when pretrained weights exist.
    pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
    pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'
    if os.path.exists(pretrained_caption_gen):
        print('loaded pretrained caption generator')
        self.caption_generator.load_state_dict(
            torch.load(pretrained_caption_gen))
    if os.path.exists(pretrained_caption_disc):
        print('loaded pretrained caption discriminator')
        self.caption_discriminator.load_state_dict(
            torch.load(pretrained_caption_disc))

    self.optim_captionG = torch.optim.Adam(
        list(self.caption_generator.parameters()))
    self.optim_captionD = torch.optim.Adam(
        list(self.caption_discriminator.parameters()))
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
             batch_size, num_workers, epochs, pre_trained_disc_B,
             pre_trained_gen_B):
    """Set up a two-stage GAN trainer (stage-1 'gan' + stage-2 'stage2_gan').

    The `_B` checkpoint arguments belong to the stage-2 networks.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # forward gan (both stages), on GPU when available
    if is_cuda:
        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan').cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan').cuda())
        self.generator2 = torch.nn.DataParallel(
            gan_factory.generator_factory('stage2_gan').cuda())
        self.discriminator2 = torch.nn.DataParallel(
            gan_factory.discriminator_factory('stage2_gan').cuda())
    else:
        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan'))
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan'))
        self.generator2 = torch.nn.DataParallel(
            gan_factory.generator_factory('stage2_gan'))
        self.discriminator2 = torch.nn.DataParallel(
            gan_factory.discriminator_factory('stage2_gan'))

    # TODO: inverse gan ('inverse_gan' factory type) — pass as runtime
    # parameters in the future.

    # Resume each network from its checkpoint when given, else re-init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)
    if pre_trained_disc_B:
        self.discriminator2.load_state_dict(torch.load(pre_trained_disc_B))
    else:
        self.discriminator2.apply(Utils.weights_init)
    if pre_trained_gen_B:
        self.generator2.load_state_dict(torch.load(pre_trained_gen_B))
    else:
        self.generator2.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         dataset_type='birds', split=split)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         dataset_type='flowers',
                                         split=split)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter  # discriminator iterations per generator step
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)

    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                    lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                    lr=self.lr, betas=(self.beta1, 0.999))

    self.logger = Logger(vis_screen)
    self.checkpoints_path = './checkpoints/'
    self.save_path = save_path
    self.type = type
def __init__(self, type, dataset, split, lr, lr_lower_boundary,
             lr_update_type, lr_update_step, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
             batch_size, num_workers, epochs, h, scale_size, num_channels,
             k, lambda_k, gamma, project, concat):
    """Set up a BEGAN-style trainer (k / lambda_k / gamma balancing terms).

    Supports the birds/flowers hdf5 datasets plus a 'youtubers' one-hot
    dataset rescaled to 64px.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = gan_factory.generator_factory(
        type, dataset, batch_size, h, scale_size, num_channels).cuda()
    self.discriminator = gan_factory.discriminator_factory(
        type, batch_size, h, scale_size, num_channels).cuda()
    print(self.discriminator)

    # Resume from checkpoints when provided, otherwise fresh weight init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    self.dataset_name = dataset
    if self.dataset_name == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif self.dataset_name == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    elif self.dataset_name == 'youtubers':
        self.dataset = OneHot2YoutubersDataset(
            config['youtubers_dataset_path'],
            transform=Rescale(64), split=split)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.lr_lower_boundary = lr_lower_boundary
    self.lr_update_type = lr_update_type
    self.lr_update_step = lr_update_step
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.apply_projection = project
    self.apply_concat = concat
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef
    self.h = h
    self.scale_size = scale_size
    self.num_channels = num_channels
    # BEGAN balancing terms.
    self.k = k
    self.lambda_k = lambda_k
    self.gamma = gamma

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)

    # NOTE(review): the optimizer learning rates are hard-coded
    # (D: 4e-4, G: 1e-4, a TTUR-style split) and ignore the `lr`
    # argument, which is only stored on self — confirm this is intended.
    self.optimD = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.discriminator.parameters()),
        lr=0.0004, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.generator.parameters()),
        lr=0.0001, betas=(self.beta1, 0.999))

    self.logger = Logger(vis_screen, save_path)
    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.type = type
def __init__(self, type, dataset, split, lr, diter, mode, vis_screen,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
             batch_size, num_workers, epochs):
    """Set up the oxford-dataset GAN trainer.

    `mode` selects the dataset file: False means training, anything else
    the test set. Outputs go to a timestamped directory under ./output.
    """
    # safe_load instead of bare yaml.load: bare load without a Loader is
    # deprecated in PyYAML >= 5.1 and can construct arbitrary objects.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory(type).cuda())
    self.discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory(type).cuda())
    print(self.generator)

    # Resume from checkpoints when provided, otherwise fresh weight init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    if dataset == 'oxford':
        if mode is False:
            # training mode
            self.dataset = Text2ImageDataset(
                config['oxford_dataset_path_train'], split=split)
        else:
            # testing mode
            self.dataset = Text2ImageDataset(
                config['oxford_dataset_path_test'], split=split)
    else:
        print(
            'Dataset not supported, please select either birds or flowers.'
        )
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)

    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))

    self.logger = Logger(vis_screen)

    # Timestamped output directory keeps runs from overwriting each other.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    self.save_path = 'output/%s_%s' % \
        (dataset, timestamp)
    self.checkpoints_path = os.path.join(self.save_path, 'checkpoints')
    self.image_dir = os.path.join(self.save_path, 'images')
    self.type = type
def __init__(self, dataset, split, lr, save_path, l1_coef, l2_coef,
             pre_trained_gen, pre_trained_disc, val_pre_trained_gen,
             val_pre_trained_disc, batch_size, num_workers, epochs,
             dataset_paths, arrangement, sampling):
    """GAN trainer supporting birds/flowers hdf5 sets and a 'live' dataset.

    The 'live' dataset is described by `dataset_paths` (a dict with
    datasetFile/imagesDir/textDir) and the `arrangement`/`sampling`
    options.

    Fix: the unsupported-dataset branch now exits (it previously fell
    through and crashed on `len(self.dataset)` with an AttributeError).
    """
    # Backward compatibility for Text2ImageDataset.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory('gan').cuda())
    self.discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory('gan').cuda())

    # Resume from checkpoints when provided, otherwise fresh weight init.
    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    self.val_pre_trained_gen = val_pre_trained_gen
    self.val_pre_trained_disc = val_pre_trained_disc
    self.arrangement = arrangement
    self.sampling = sampling

    if dataset == 'birds':
        # e.g. '...\Text2Image\datasets\ee285f-public\caltech_ucsd_birds\birds.hdf5'
        self.dataset = Text2ImageDataset(
            config['birds_dataset_path'], split=split
        )
    elif dataset == 'flowers':
        # e.g. '...\Text2Image\datasets\ee285f-public\oxford_flowers\flowers.hdf5'
        self.dataset = Text2ImageDataset(
            config['flowers_dataset_path'], split=split
        )
    elif dataset == 'live':
        self.dataset_dict = easydict.EasyDict(dataset_paths)
        self.dataset = Text2ImageDataset2(
            datasetFile=self.dataset_dict.datasetFile,
            imagesDir=self.dataset_dict.imagesDir,
            textDir=self.dataset_dict.textDir,
            split=split,
            arrangement=arrangement,
            sampling=sampling)
    else:
        print('Dataset not supported.')
        exit()  # was missing: self.dataset is unset past this point

    print('Images =', len(self.dataset))

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    # shuffle=True reshuffles the dataset every epoch.
    self.data_loader = DataLoader(
        self.dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers
    )

    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=(self.beta1, 0.999))

    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.loss_filename_t = 'gd_loss_t.csv'
    self.loss_filename_v = 'gd_loss_v.csv'
def __init__(self, dataset='flowers', split=0, lr=2e-4, diter=5,
             save_path='./Log', l1_coef=90, l2_coef=100,
             pre_trained_gen=False, pre_trained_disc=False, batch_size=64,
             num_workers=16, epochs=800):
    """Trainer for the gan_cls_new model on the flowers/birds hdf5 sets.

    NOTE(review): `pre_trained_gen`/`pre_trained_disc` are booleans and
    the checkpoint paths below are hard-coded to epoch 190 — confirm
    that is intended.
    """
    self.generator = gan_cls_new.generator().to(DEVICE)
    self.discriminator = gan_cls_new.discriminator().to(DEVICE)

    # Resume from the fixed epoch-190 checkpoints, or start fresh.
    if pre_trained_disc:
        self.discriminator.load_state_dict(
            torch.load('./Log/checkpoints/disc_190.pth'))
    else:
        self.discriminator.apply(Utils.weights_init)
    if pre_trained_gen:
        self.generator.load_state_dict(
            torch.load('./Log/checkpoints/gen_190.pth'))
    else:
        self.generator.apply(Utils.weights_init)

    # Resolve the dataset name to its hdf5 file (flowers is the smaller set).
    hdf5_by_name = {
        'flowers': './data/flowers.hdf5',
        'birds': './data/birds.hdf5',
    }
    if dataset not in hdf5_by_name:
        print(
            'Data not supported, please select either birds.hdf5 or flowers.hdf5'
        )
        exit()
    self.dataset = Text2ImageDataset(hdf5_by_name[dataset], split=split)

    # Hyper-parameters.
    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)

    adam_betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=adam_betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=adam_betas)

    self.checkpoints_path = 'checkpoints/2gen_800epochs'
    self.save_path = save_path