def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
             pre_trained_encod, batch_size, num_workers, epochs, visualize):
    """Build generator, discriminator and encoder plus the training pipeline.

    Args:
        type: Model-type key passed to ``gan_factory`` for all three networks.
        dataset: ``'birds'`` or ``'flowers'``; selects the dataset path read
            from ``config.yaml``. Any other value prints a message and exits.
        split: Split forwarded to ``Text2ImageDataset``.
        lr: Learning rate shared by the three Adam optimizers.
        diter: Stored as ``self.DITER``.
        vis_screen: Forwarded to ``Logger`` when ``visualize`` is true.
        save_path: Stored as ``self.save_path``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc, pre_trained_encod: Optional
            checkpoint paths; when falsy the corresponding network is
            initialised with ``Utils.weights_init`` instead.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader).
        epochs: Stored as ``self.num_epochs``.
        visualize: When true, ``self.logger`` is created.
    """
    with open('config.yaml', 'r') as f:
        # safe_load: bare ``yaml.load(f)`` (no Loader) is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    # NOTE(review): ``cuda`` comes from a scope not visible here — confirm it
    # is defined where this class lives.
    def _wrap(net):
        # Move to GPU only on the CUDA path, always wrap in DataParallel.
        return torch.nn.DataParallel(net.cuda() if cuda else net)

    self.generator = _wrap(gan_factory.generator_factory(type))
    self.discriminator = _wrap(gan_factory.discriminator_factory(type))
    self.encoder = _wrap(gan_factory.encoder_factory(type))

    # GPU-saved checkpoints cannot be deserialised on a CPU-only host unless
    # remapped; on the CUDA path torch.load keeps its default behaviour.
    map_location = None if cuda else 'cpu'
    for net, ckpt in ((self.discriminator, pre_trained_disc),
                      (self.generator, pre_trained_gen),
                      (self.encoder, pre_trained_encod)):
        if ckpt:
            net.load_state_dict(torch.load(ckpt, map_location=map_location))
        else:
            net.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    else:
        print('Dataset not supported, please select either birds or flowers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimE = torch.optim.Adam(self.encoder.parameters(),
                                   lr=self.lr, betas=betas)

    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.type = type
    self.visualize = visualize

    if self.visualize:
        self.logger = Logger(vis_screen)
def __init__(self, datasetFile, textDir, checking_folder, lang, client_txt,
             pre_trained_gen, pre_trained_disc, ID, batch_size=1):
    """Load a pre-trained 'gan' generator/discriminator pair and a vocabulary.

    Both networks are built on the GPU, wrapped in DataParallel and filled
    from the given checkpoint paths (which are required here). The
    vectorizer is whatever ``CorpusLoader.TrainVocab()`` returns for
    ``datasetFile`` / ``textDir``.

    Args:
        datasetFile, textDir: Forwarded to ``CorpusLoader``.
        checking_folder, lang, client_txt: Stored unchanged on the instance.
        pre_trained_gen, pre_trained_disc: Checkpoint paths to load.
        ID: Stored as ``self.filename``.
        batch_size: Stored as ``self.batch_size`` (default 1).
    """
    gen_net = gan_factory.generator_factory('gan').cuda()
    self.generator = torch.nn.DataParallel(gen_net)
    self.generator.load_state_dict(torch.load(pre_trained_gen))

    disc_net = gan_factory.discriminator_factory('gan').cuda()
    self.discriminator = torch.nn.DataParallel(disc_net)
    self.discriminator.load_state_dict(torch.load(pre_trained_disc))

    # Plain settings carried through for later use.
    self.checking_folder, self.lang, self.client_txt = (
        checking_folder, lang, client_txt)
    self.filename = ID
    self.batch_size = batch_size

    corpus = CorpusLoader(datasetFile=datasetFile, textDir=textDir)
    self.vectorizer = corpus.TrainVocab()
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size,
             num_workers, epochs, pre_trained_disc_B, pre_trained_gen_B):
    """Build the two-stage image GAN, the caption GAN and the data pipeline.

    Stage 1 uses the 'gan' factory type and stage 2 the 'stage2_gan' type.
    The caption generator/discriminator use fixed sizes (see the TODO below)
    and are loaded from fixed checkpoint paths when those files exist.

    Args:
        type: Stored as ``self.type``.
        dataset: ``'birds'`` or ``'flowers'``; selects the vocabulary pickle
            and the dataset path from ``config.yaml``. Any other value
            prints a message and exits.
        split: Split forwarded to ``Text2ImageDataset``.
        lr: Learning rate for the four image-GAN Adam optimizers.
        diter: Stored as ``self.DITER``.
        vis_screen: Accepted for interface compatibility; unused here.
        save_path: Stored as ``self.save_path``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional stage-1 checkpoints.
        pre_trained_gen_B, pre_trained_disc_B: Optional stage-2 checkpoints.
            Any falsy checkpoint path means ``Utils.weights_init`` is applied.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader
            using ``collate_fn``).
        epochs: Stored as ``self.num_epochs``.
    """
    with open('config.yaml', 'r') as f:
        # safe_load: bare ``yaml.load(f)`` (no Loader) is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    def _wrap(net):
        # Move to GPU only on the CUDA path, always wrap in DataParallel.
        return torch.nn.DataParallel(net.cuda() if is_cuda else net)

    # forward gan
    self.generator = _wrap(gan_factory.generator_factory('gan'))
    self.discriminator = _wrap(gan_factory.discriminator_factory('gan'))
    self.generator2 = _wrap(gan_factory.generator_factory('stage2_gan'))
    self.discriminator2 = _wrap(gan_factory.discriminator_factory('stage2_gan'))

    # GPU-saved checkpoints cannot be deserialised on a CPU-only host unless
    # remapped; on the CUDA path torch.load keeps its default behaviour.
    map_location = None if is_cuda else 'cpu'
    for net, ckpt in ((self.discriminator, pre_trained_disc),
                      (self.generator, pre_trained_gen),
                      (self.discriminator2, pre_trained_disc_B),
                      (self.generator2, pre_trained_gen_B)):
        if ckpt:
            net.load_state_dict(torch.load(ckpt, map_location=map_location))
        else:
            net.apply(Utils.weights_init)

    # NOTE: pickle.load can execute arbitrary code; only use trusted vocab files.
    if dataset == 'birds':
        with open('./data/birds_vocab.pkl', 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         dataset_type='birds',
                                         vocab=self.vocab, split=split)
    elif dataset == 'flowers':
        with open('./data/flowers_vocab.pkl', 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         dataset_type='flowers',
                                         vocab=self.vocab, split=split)
    else:
        print('Dataset not supported, please select either birds or flowers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.num_workers = num_workers

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers,
                                  collate_fn=collate_fn)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                    lr=self.lr, betas=betas)
    self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                    lr=self.lr, betas=betas)

    self.checkpoints_path = './checkpoints/'
    self.save_path = save_path
    self.type = type

    # TODO: put these as runtime.py params
    self.embed_size = 256
    self.hidden_size = 512
    self.num_layers = 1
    self.gen_pretrain_num_epochs = 100
    self.disc_pretrain_num_epochs = 20
    self.figure_path = './figures/'

    self.caption_generator = CaptionGenerator(
        self.embed_size, self.hidden_size, len(self.vocab), self.num_layers)
    self.caption_discriminator = CaptionDiscriminator(
        self.embed_size, self.hidden_size, len(self.vocab), self.num_layers)
    if is_cuda:
        self.caption_generator = self.caption_generator.cuda()
        self.caption_discriminator = self.caption_discriminator.cuda()

    pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
    pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'
    if os.path.exists(pretrained_caption_gen):
        print('loaded pretrained caption generator')
        self.caption_generator.load_state_dict(
            torch.load(pretrained_caption_gen, map_location=map_location))

    if os.path.exists(pretrained_caption_disc):
        print('loaded pretrained caption discriminator')
        self.caption_discriminator.load_state_dict(
            torch.load(pretrained_caption_disc, map_location=map_location))

    # Caption optimizers use Adam's default hyper-parameters.
    self.optim_captionG = torch.optim.Adam(
        list(self.caption_generator.parameters()))
    self.optim_captionD = torch.optim.Adam(
        list(self.caption_discriminator.parameters()))
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size,
             num_workers, epochs, pre_trained_disc_B, pre_trained_gen_B):
    """Build the two-stage GAN (stage 1 'gan', stage 2 'stage2_gan') trainer.

    Args:
        type: Stored as ``self.type``.
        dataset: ``'birds'`` or ``'flowers'``; selects the dataset path read
            from ``config.yaml``. Any other value prints a message and exits.
        split: Split forwarded to ``Text2ImageDataset``.
        lr: Learning rate shared by the four Adam optimizers.
        diter: Stored as ``self.DITER``.
        vis_screen: Forwarded to ``Logger``.
        save_path: Stored as ``self.save_path``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional stage-1 checkpoints.
        pre_trained_gen_B, pre_trained_disc_B: Optional stage-2 checkpoints.
            Any falsy checkpoint path means ``Utils.weights_init`` is applied.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader).
        epochs: Stored as ``self.num_epochs``.
    """
    with open('config.yaml', 'r') as f:
        # safe_load: bare ``yaml.load(f)`` (no Loader) is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    def _wrap(net):
        # Move to GPU only on the CUDA path, always wrap in DataParallel.
        return torch.nn.DataParallel(net.cuda() if is_cuda else net)

    # forward gan
    self.generator = _wrap(gan_factory.generator_factory('gan'))
    self.discriminator = _wrap(gan_factory.discriminator_factory('gan'))
    self.generator2 = _wrap(gan_factory.generator_factory('stage2_gan'))
    self.discriminator2 = _wrap(gan_factory.discriminator_factory('stage2_gan'))

    # TODO: an inverse GAN ('inverse_gan' factory type) was planned here; its
    # construction should be passed in as parameters from runtime.

    # GPU-saved checkpoints cannot be deserialised on a CPU-only host unless
    # remapped; on the CUDA path torch.load keeps its default behaviour.
    map_location = None if is_cuda else 'cpu'
    for net, ckpt in ((self.discriminator, pre_trained_disc),
                      (self.generator, pre_trained_gen),
                      (self.discriminator2, pre_trained_disc_B),
                      (self.generator2, pre_trained_gen_B)):
        if ckpt:
            net.load_state_dict(torch.load(ckpt, map_location=map_location))
        else:
            net.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    else:
        print('Dataset not supported, please select either birds or flowers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                    lr=self.lr, betas=betas)
    self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                    lr=self.lr, betas=betas)

    self.logger = Logger(vis_screen)
    self.checkpoints_path = './checkpoints/'
    self.save_path = save_path
    self.type = type
def __init__(self, type, dataset, split, lr, lr_lower_boundary, lr_update_type,
             lr_update_step, diter, vis_screen, save_path, l1_coef, l2_coef,
             pre_trained_gen, pre_trained_disc, batch_size, num_workers, epochs,
             h, scale_size, num_channels, k, lambda_k, gamma, project, concat):
    """Build a single-stage GAN trainer parameterised by ``h``/``scale_size``.

    Args:
        type: Model-type key passed to ``gan_factory``; also stored.
        dataset: ``'birds'``, ``'flowers'`` or ``'youtubers'``; selects the
            dataset class and its path from ``config.yaml``. Any other value
            prints a message and exits.
        split: Split forwarded to the dataset constructor.
        lr, lr_lower_boundary, lr_update_type, lr_update_step: Stored on the
            instance (see the optimizer note below).
        diter: Stored as ``self.DITER``.
        vis_screen, save_path: Forwarded to ``Logger``; ``save_path`` is also
            stored.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional checkpoint paths; when
            falsy ``Utils.weights_init`` is applied instead.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader);
            ``batch_size`` is also passed to both factories.
        epochs: Stored as ``self.num_epochs``.
        h, scale_size, num_channels: Passed to both factories and stored.
        k, lambda_k, gamma: Stored on the instance.
        project, concat: Stored as ``apply_projection`` / ``apply_concat``.
    """
    with open('config.yaml', 'r') as f:
        # safe_load: bare ``yaml.load(f)`` (no Loader) is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    # Networks are placed on the GPU but not wrapped in DataParallel.
    self.generator = gan_factory.generator_factory(
        type, dataset, batch_size, h, scale_size, num_channels).cuda()
    self.discriminator = gan_factory.discriminator_factory(
        type, batch_size, h, scale_size, num_channels).cuda()
    print(self.discriminator)

    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)

    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    self.dataset_name = dataset
    if self.dataset_name == 'birds':
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif self.dataset_name == 'flowers':
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    elif self.dataset_name == 'youtubers':
        self.dataset = OneHot2YoutubersDataset(
            config['youtubers_dataset_path'], transform=Rescale(64),
            split=split)
    else:
        # The message lists every dataset this trainer actually accepts.
        print('Dataset not supported, please select either birds, flowers '
              'or youtubers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.lr_lower_boundary = lr_lower_boundary
    self.lr_update_type = lr_update_type
    self.lr_update_step = lr_update_step
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.apply_projection = project
    self.apply_concat = concat

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef
    self.h = h
    self.scale_size = scale_size
    self.num_channels = num_channels
    self.k = k
    self.lambda_k = lambda_k
    self.gamma = gamma

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)

    # NOTE(review): both optimizers use hard-coded learning rates (4e-4 for
    # D, 1e-4 for G); ``self.lr`` is stored but not used here. Only
    # parameters with requires_grad=True are optimised.
    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(
        [p for p in self.discriminator.parameters() if p.requires_grad],
        lr=0.0004, betas=betas)
    self.optimG = torch.optim.Adam(
        [p for p in self.generator.parameters() if p.requires_grad],
        lr=0.0001, betas=betas)

    self.logger = Logger(vis_screen, save_path)
    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.type = type
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size,
             num_workers, epochs, pre_trained_disc_B, pre_trained_gen_B,
             pre_trained_caption_gen, pre_trained_caption_disc,
             caption_embed_size, caption_hidden_size, caption_num_layers,
             caption_gen_pretrain_num_epochs, caption_disc_pretrain_num_epochs,
             caption_initial_noise=False):
    """Build the two-stage image GAN, the caption GAN and the data pipeline.

    Stage 1 uses the 'gan' factory type and stage 2 the 'stage2_gan' type.
    The caption networks are ``Image2TextGenerator`` /
    ``Image2TextDiscriminator`` sized by the ``caption_*`` arguments.

    Args:
        type: Stored as ``self.type``.
        dataset: ``'birds'`` or ``'flowers'``; selects the vocabulary pickle
            and dataset paths from ``utils.load_config()``. Any other value
            prints a message and exits.
        split: Split forwarded to ``Text2ImageDataset``.
        lr: Learning rate for the four image-GAN Adam optimizers.
        diter: Stored as ``self.DITER``.
        vis_screen: Accepted for interface compatibility; unused here.
        save_path: Stored as ``self.save_path``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional stage-1 checkpoints.
        pre_trained_gen_B, pre_trained_disc_B: Optional stage-2 checkpoints.
            A falsy path means ``Utils.weights_init`` is applied instead.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader
            using ``collate_fn``).
        epochs: Stored as ``self.num_epochs``.
        pre_trained_caption_gen, pre_trained_caption_disc: Optional caption
            checkpoints, loaded only when the path is truthy and exists.
        caption_embed_size, caption_hidden_size, caption_num_layers: Caption
            network sizes.
        caption_gen_pretrain_num_epochs, caption_disc_pretrain_num_epochs:
            Stored on the instance.
        caption_initial_noise: Forwarded to ``Image2TextGenerator`` as
            ``initial_noise``.
    """
    config = utils.load_config()

    def _wrap(net):
        # Move to GPU only on the CUDA path, always wrap in DataParallel.
        return torch.nn.DataParallel(net.cuda() if is_cuda else net)

    # forward gan
    self.generator = _wrap(gan_factory.generator_factory('gan'))
    self.discriminator = _wrap(gan_factory.discriminator_factory('gan'))
    self.generator2 = _wrap(gan_factory.generator_factory('stage2_gan'))
    self.discriminator2 = _wrap(gan_factory.discriminator_factory('stage2_gan'))

    # GPU-saved checkpoints cannot be deserialised on a CPU-only host unless
    # remapped; on the CUDA path torch.load keeps its default behaviour.
    map_location = None if is_cuda else 'cpu'
    for label, net, ckpt in (
            ('pre_trained_disc A', self.discriminator, pre_trained_disc),
            ('pre_trained_gen A', self.generator, pre_trained_gen),
            ('pre_trained_disc B', self.discriminator2, pre_trained_disc_B),
            ('pre_trained_gen B', self.generator2, pre_trained_gen_B)):
        if ckpt:
            print('Loading %s from: %s' % (label, os.path.abspath(ckpt)))
            net.load_state_dict(torch.load(ckpt, map_location=map_location))
        else:
            net.apply(Utils.weights_init)

    # NOTE: pickle.load can execute arbitrary code; only use trusted vocab files.
    if dataset == 'birds':
        with open(config['birds_vocab_path'], 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         dataset_type='birds',
                                         vocab=self.vocab, split=split)
    elif dataset == 'flowers':
        with open(config['flowers_vocab_path'], 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         dataset_type='flowers',
                                         vocab=self.vocab, split=split)
    else:
        print('Dataset not supported, please select either birds or flowers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter
    self.num_workers = num_workers

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers,
                                  collate_fn=collate_fn)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                    lr=self.lr, betas=betas)
    self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                    lr=self.lr, betas=betas)

    self.checkpoints_path = './checkpoints/'
    self.save_path = save_path
    self.type = type

    # settings for caption GAN
    self.embed_size = caption_embed_size
    self.hidden_size = caption_hidden_size
    self.num_layers = caption_num_layers
    self.caption_initial_noise = caption_initial_noise
    self.gen_pretrain_num_epochs = caption_gen_pretrain_num_epochs
    self.disc_pretrain_num_epochs = caption_disc_pretrain_num_epochs
    self.figure_path = './figures/'

    self.caption_generator = Image2TextGenerator(
        self.embed_size, self.hidden_size, len(self.vocab), self.num_layers,
        initial_noise=self.caption_initial_noise)
    self.caption_discriminator = Image2TextDiscriminator(
        self.embed_size, self.hidden_size, len(self.vocab), self.num_layers)
    if is_cuda:
        self.caption_generator = self.caption_generator.cuda()
        self.caption_discriminator = self.caption_discriminator.cuda()

    if pre_trained_caption_gen and os.path.exists(pre_trained_caption_gen):
        print('loaded pretrained caption generator')
        self.caption_generator.load_state_dict(
            torch.load(pre_trained_caption_gen, map_location=map_location))

    if pre_trained_caption_disc and os.path.exists(pre_trained_caption_disc):
        print('loaded pretrained caption discriminator')
        self.caption_discriminator.load_state_dict(
            torch.load(pre_trained_caption_disc, map_location=map_location))

    # Caption optimizers use Adam's default hyper-parameters.
    self.optim_captionG = torch.optim.Adam(
        list(self.caption_generator.parameters()))
    self.optim_captionD = torch.optim.Adam(
        list(self.caption_discriminator.parameters()))
def __init__(self, type, dataset, split, lr, diter, mode, vis_screen,
             l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size,
             num_workers, epochs):
    """Build a GAN trainer for the 'oxford' dataset with timestamped output.

    Args:
        type: Model-type key passed to ``gan_factory``; also stored.
        dataset: Must be ``'oxford'``; any other value prints a message and
            exits.
        split: Split forwarded to ``Text2ImageDataset``.
        lr: Learning rate shared by both Adam optimizers.
        diter: Stored as ``self.DITER``.
        mode: Exactly ``False`` selects the training dataset path; any other
            value selects the test path.
        vis_screen: Forwarded to ``Logger``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional checkpoint paths; when
            falsy ``Utils.weights_init`` is applied instead.
        batch_size, num_workers: ``DataLoader`` settings (shuffled loader).
        epochs: Stored as ``self.num_epochs``.

    Output goes under ``output/<dataset>_<local timestamp>``, with
    ``checkpoints`` and ``images`` sub-paths (paths only; nothing is created
    here).
    """
    with open('config.yaml', 'r') as f:
        # safe_load: bare ``yaml.load(f)`` (no Loader) is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory(type).cuda())
    self.discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory(type).cuda())
    print(self.generator)

    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)

    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    if dataset == 'oxford':
        # Identity check kept on purpose: only the literal False means
        # training mode; e.g. None or 0 still select the test split.
        if mode is False:
            dataset_key = 'oxford_dataset_path_train'
        else:
            dataset_key = 'oxford_dataset_path_test'
        self.dataset = Text2ImageDataset(config[dataset_key], split=split)
    else:
        # Only 'oxford' is accepted by this trainer.
        print('Dataset not supported, please select oxford.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)

    self.logger = Logger(vis_screen)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    self.save_path = 'output/%s_%s' % (dataset, timestamp)
    self.checkpoints_path = os.path.join(self.save_path, 'checkpoints')
    self.image_dir = os.path.join(self.save_path, 'images')
    self.type = type
def validate(self, cls, val_dataset, val_data_loader, epoch):
    """Evaluate the checkpoint pair saved for ``epoch`` on a validation set.

    Builds fresh 'gan' generator/discriminator networks on the GPU and loads
    them from ``self.val_pre_trained_gen`` / ``self.val_pre_trained_disc``,
    with ``"XXX"`` in those paths replaced by ``str(epoch)``. It accumulates
    batch-size-weighted discriminator and generator losses, prints their
    per-sample averages and appends ``"<epoch> ; <d_loss> ; <g_loss>"`` to
    ``self.loss_filename_v``.

    Args:
        cls: When true, the discriminator loss also includes the
            wrong-image / right-embedding term.
        val_dataset: Used only via ``len()`` to average the summed losses.
        val_data_loader: Iterable of dicts with ``'right_images'``,
            ``'right_embed'`` and ``'wrong_images'`` tensors.
        epoch: Substituted into the checkpoint path templates and written
            to the loss file.
    """
    criterion = nn.BCELoss()
    l2_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    iters_cnt = 500  # progress-report interval, in batches

    generator = torch.nn.DataParallel(
        gan_factory.generator_factory('gan').cuda())
    generator.load_state_dict(
        torch.load(self.val_pre_trained_gen.replace("XXX", str(epoch))))
    discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory('gan').cuda())
    discriminator.load_state_dict(
        torch.load(self.val_pre_trained_disc.replace("XXX", str(epoch))))
    # NOTE(review): the networks are never switched to eval(), so any
    # BatchNorm/Dropout layers behave as in training — confirm intended.

    d_epoch_loss = g_epoch_loss = 0.0
    dt = datetime.datetime.now()
    print('Validating... Started on', dt.date(), 'at',
          dt.time().replace(microsecond=0))

    # Nothing is back-propagated here, so skip building autograd graphs. This
    # saves memory and time without changing any computed loss value.
    with torch.no_grad():
        for iteration, sample in enumerate(val_data_loader, start=1):
            right_images = sample['right_images'].float().cuda()
            right_embed = sample['right_embed'].float().cuda()
            wrong_images = sample['wrong_images'].float().cuda()

            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))
            smoothed_real_labels = torch.FloatTensor(
                Utils.smooth_label(real_labels.numpy(), -0.1))
            real_labels = real_labels.cuda()
            smoothed_real_labels = smoothed_real_labels.cuda()
            fake_labels = fake_labels.cuda()

            # --- discriminator loss (same forward order as training) ---
            outputs, _ = discriminator(right_images, right_embed)
            real_loss = criterion(outputs, smoothed_real_labels)

            if cls:
                outputs, _ = discriminator(wrong_images, right_embed)
                wrong_loss = criterion(outputs, fake_labels)

            noise = torch.randn(right_images.size(0), 100).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, _ = discriminator(fake_images, right_embed)
            fake_loss = criterion(outputs, fake_labels)

            d_loss = real_loss + fake_loss
            if cls:
                d_loss = d_loss + wrong_loss
            d_epoch_loss += outputs.shape[0] * d_loss.item()

            # --- generator loss ---
            noise = torch.randn(right_images.size(0), 100).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, activation_fake = discriminator(fake_images, right_embed)
            _, activation_real = discriminator(right_images, right_embed)

            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)

            # g_loss = BCE adversarial term
            #        + l2_coef * MSE between the discriminator's batch-mean
            #          activations for fake and for real images
            #        + l1_coef * L1 distance between fake and real images
            g_loss = criterion(outputs, real_labels) \
                + self.l2_coef * l2_loss(activation_fake[0],
                                         activation_real[0].detach()) \
                + self.l1_coef * l1_loss(fake_images, right_images)
            g_epoch_loss += outputs.shape[0] * g_loss.item()

            if iteration % iters_cnt == 0:
                percentage = (iteration * 100.0) / len(val_data_loader)
                print('Samples (iterations):', round(percentage, 2), '%')

    dt = datetime.datetime.now()
    print('Validation completed on', dt.date(), 'at',
          dt.time().replace(microsecond=0))
    d_err = d_epoch_loss / len(val_dataset)
    g_err = g_epoch_loss / len(val_dataset)
    print('Discriminator loss:', d_err, '| generator loss:', g_err)

    with open(self.loss_filename_v, 'a') as f:
        f.write(f"{epoch} ; {d_err} ; {g_err}\n")

    # Drop the last references to the temporary networks so that
    # empty_cache() can actually release their GPU memory.
    del generator, discriminator
    gc.collect()
    torch.cuda.empty_cache()
def __init__(self, dataset, split, lr, save_path, l1_coef, l2_coef,
             pre_trained_gen, pre_trained_disc, val_pre_trained_gen,
             val_pre_trained_disc, batch_size, num_workers, epochs,
             dataset_paths, arrangement, sampling):
    """Build a 'gan' trainer for the birds, flowers or live dataset.

    Args:
        dataset: ``'birds'`` or ``'flowers'`` (HDF5 paths from
            ``config.yaml`` via ``Text2ImageDataset``) or ``'live'``
            (``Text2ImageDataset2`` built from ``dataset_paths``).
        split: Split forwarded to the dataset constructor.
        lr: Learning rate shared by both Adam optimizers.
        save_path: Stored as ``self.save_path``.
        l1_coef, l2_coef: Loss weights stored on the instance.
        pre_trained_gen, pre_trained_disc: Optional checkpoint paths; when
            falsy ``Utils.weights_init`` is applied instead.
        val_pre_trained_gen, val_pre_trained_disc: Stored checkpoint path
            templates for validation.
        batch_size, num_workers: ``DataLoader`` settings.
        epochs: Stored as ``self.num_epochs``.
        dataset_paths: Mapping with ``datasetFile``, ``imagesDir`` and
            ``textDir`` keys, used for ``'live'``.
        arrangement, sampling: Stored and forwarded to ``Text2ImageDataset2``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # config.yaml is kept for backward compatibility with Text2ImageDataset.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    self.generator = torch.nn.DataParallel(
        gan_factory.generator_factory('gan').cuda())
    self.discriminator = torch.nn.DataParallel(
        gan_factory.discriminator_factory('gan').cuda())

    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)

    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    self.val_pre_trained_gen = val_pre_trained_gen
    self.val_pre_trained_disc = val_pre_trained_disc
    self.arrangement = arrangement
    self.sampling = sampling

    if dataset == 'birds':
        # Backward compatibility with Text2ImageDataset, e.g.
        # '...\Text2Image\datasets\ee285f-public\caltech_ucsd_birds\birds.hdf5'
        self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                         split=split)
    elif dataset == 'flowers':
        # Backward compatibility with Text2ImageDataset, e.g.
        # '...\Text2Image\datasets\ee285f-public\oxford_flowers\flowers.hdf5'
        self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                         split=split)
    elif dataset == 'live':
        self.dataset_dict = easydict.EasyDict(dataset_paths)
        self.dataset = Text2ImageDataset2(
            datasetFile=self.dataset_dict.datasetFile,
            imagesDir=self.dataset_dict.imagesDir,
            textDir=self.dataset_dict.textDir,
            split=split,
            arrangement=arrangement,
            sampling=sampling)
    else:
        # Previously this only printed and then crashed with AttributeError
        # on len(self.dataset); fail early with a clear message instead.
        raise ValueError('Dataset not supported: %r' % (dataset,))

    print('Images =', len(self.dataset))

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    # shuffle=True reshuffles the dataset at every epoch.
    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)

    betas = (self.beta1, 0.999)
    self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                   lr=self.lr, betas=betas)
    self.optimG = torch.optim.Adam(self.generator.parameters(),
                                   lr=self.lr, betas=betas)

    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.loss_filename_t = 'gd_loss_t.csv'
    self.loss_filename_v = 'gd_loss_v.csv'