Пример #1
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path, l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, pre_trained_encod, batch_size, num_workers, epochs, visualize):
        """Set up the GAN trainer: networks, dataset, data loader and optimizers.

        Args:
            type: GAN variant key understood by ``gan_factory`` (shadows the
                ``type`` builtin; name kept for caller compatibility).
            dataset: dataset name, either ``'birds'`` or ``'flowers'``.
            split: dataset split selector forwarded to ``Text2ImageDataset``.
            lr: learning rate shared by all three Adam optimizers.
            diter: number of discriminator iterations per generator step.
            vis_screen: visdom screen/environment name for the Logger.
            save_path: output directory kept for later sample saving.
            l1_coef, l2_coef: loss-term coefficients stored for training.
            pre_trained_gen, pre_trained_disc, pre_trained_encod: optional
                checkpoint paths; a falsy value triggers fresh weight init.
            batch_size, num_workers, epochs: data-loading/training settings.
            visualize: when truthy, create a visdom ``Logger``.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        # Build each network, move to GPU first when available, then wrap in
        # DataParallel (`cuda` is a module-level flag defined elsewhere).
        if cuda:
            self.generator = torch.nn.DataParallel(gan_factory.generator_factory(type).cuda())
            self.discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(type).cuda())
            self.encoder = torch.nn.DataParallel(gan_factory.encoder_factory(type).cuda())
        else:
            self.generator = torch.nn.DataParallel(gan_factory.generator_factory(type))
            self.discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(type))
            self.encoder = torch.nn.DataParallel(gan_factory.encoder_factory(type))

        # Restore checkpoints where given, otherwise apply fresh weight init.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if pre_trained_encod:
            self.encoder.load_state_dict(torch.load(pre_trained_encod))
        else:
            self.encoder.apply(Utils.weights_init)

        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'], split=split)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'], split=split)
        else:
            print('Dataset not supported, please select either birds or flowers.')
            exit()

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5  # Adam beta1, the common choice for GAN training
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                                num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimE = torch.optim.Adam(self.encoder.parameters(), lr=self.lr, betas=(self.beta1, 0.999))

        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
        self.type = type
        self.visualize = visualize
        if self.visualize:
            self.logger = Logger(vis_screen)
Пример #2
0
    def __init__(self,
                 datasetFile,
                 textDir,
                 checking_folder,
                 lang,
                 client_txt,
                 pre_trained_gen,
                 pre_trained_disc,
                 ID,
                 batch_size=1):
        """Load a pretrained text-to-image GAN and fit the text vocabulary."""
        # Build each network on the GPU, wrap it for multi-GPU execution and
        # restore its pretrained weights (generator first, then discriminator).
        for attr, factory, checkpoint in (
                ('generator', gan_factory.generator_factory, pre_trained_gen),
                ('discriminator', gan_factory.discriminator_factory,
                 pre_trained_disc)):
            net = torch.nn.DataParallel(factory('gan').cuda())
            net.load_state_dict(torch.load(checkpoint))
            setattr(self, attr, net)

        # Plain runtime configuration.
        self.checking_folder = checking_folder
        self.lang = lang
        self.client_txt = client_txt
        self.filename = ID
        self.batch_size = batch_size

        # Train the vocabulary/vectorizer from the text corpus.
        corpus = CorpusLoader(datasetFile=datasetFile, textDir=textDir)
        self.vectorizer = corpus.TrainVocab()
Пример #3
0
    def __init__(self, type, dataset, vis_screen, pre_trained_gen, batch_size,
                 num_workers, epochs):
        """Set up a generator-only (evaluation/sampling) runner.

        Args:
            type: GAN variant key for ``gan_factory`` (shadows the ``type``
                builtin; name kept for caller compatibility).
            dataset: dataset name, either ``'birds'`` or ``'flowers'``.
            vis_screen: visdom screen/environment name for the Logger.
            pre_trained_gen: optional generator checkpoint path; a falsy value
                triggers fresh weight init.
            batch_size, num_workers, epochs: data-loading/run settings.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory(type).cuda())

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        # NOTE(review): birds uses the training-path key while flowers uses
        # the *_val_* key — looks intentional for this runner, but confirm.
        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=1)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(
                config['flowers_val_dataset_path'], split=1)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_epochs = epochs

        # No shuffling: deterministic iteration order for evaluation.
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers)

        self.logger = Logger(vis_screen)
        self.type = type
Пример #4
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Set up a two-stage GAN plus a caption GAN for cycle training.

        Builds stage-I/stage-II image GANs (A/B), the text-image dataset with
        its pickled vocabulary, Adam optimizers for all four image networks,
        and a caption generator/discriminator pair with their optimizers.

        Args:
            type: GAN variant key stored on the instance (shadows the ``type``
                builtin; name kept for caller compatibility).
            dataset: dataset name, either ``'birds'`` or ``'flowers'``.
            split: dataset split selector forwarded to ``Text2ImageDataset``.
            lr: learning rate shared by the four image-GAN Adam optimizers.
            diter: number of discriminator iterations per generator step.
            vis_screen: visdom screen name (accepted; unused in this block).
            save_path: output directory kept for later sample saving.
            l1_coef, l2_coef: loss-term coefficients stored for training.
            pre_trained_gen/_disc: stage-I checkpoint paths (falsy => init).
            pre_trained_gen_B/_disc_B: stage-II checkpoint paths.
            batch_size, num_workers, epochs: data-loading/training settings.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        # forward gan
        if is_cuda:
            self.generator = torch.nn.DataParallel(
                gan_factory.generator_factory('gan').cuda())
            self.discriminator = torch.nn.DataParallel(
                gan_factory.discriminator_factory('gan').cuda())
            self.generator2 = torch.nn.DataParallel(
                gan_factory.generator_factory('stage2_gan').cuda())
            self.discriminator2 = torch.nn.DataParallel(
                gan_factory.discriminator_factory('stage2_gan').cuda())
        else:
            self.generator = torch.nn.DataParallel(
                gan_factory.generator_factory('gan'))
            self.discriminator = torch.nn.DataParallel(
                gan_factory.discriminator_factory('gan'))
            self.generator2 = torch.nn.DataParallel(
                gan_factory.generator_factory('stage2_gan'))
            self.discriminator2 = torch.nn.DataParallel(
                gan_factory.discriminator_factory('stage2_gan'))

        # Restore checkpoints where given, otherwise apply fresh weight init.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if pre_trained_disc_B:
            self.discriminator2.load_state_dict(torch.load(pre_trained_disc_B))
        else:
            self.discriminator2.apply(Utils.weights_init)

        if pre_trained_gen_B:
            self.generator2.load_state_dict(torch.load(pre_trained_gen_B))
        else:
            self.generator2.apply(Utils.weights_init)

        # Vocabulary pickles are local project artifacts (trusted input).
        if dataset == 'birds':
            with open('./data/birds_vocab.pkl', 'rb') as f:
                self.vocab = pickle.load(f)
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             dataset_type='birds',
                                             vocab=self.vocab,
                                             split=split)
        elif dataset == 'flowers':
            with open('./data/flowers_vocab.pkl', 'rb') as f:
                self.vocab = pickle.load(f)
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             dataset_type='flowers',
                                             vocab=self.vocab,
                                             split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = 0.5  # Adam beta1, the common choice for GAN training
        self.num_epochs = epochs
        self.DITER = diter
        self.num_workers = num_workers

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers,
                                      collate_fn=collate_fn)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, 0.999))
        self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, 0.999))

        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type

        # TODO: put these as runtime.py params
        self.embed_size = 256
        self.hidden_size = 512
        self.num_layers = 1

        self.gen_pretrain_num_epochs = 100
        self.disc_pretrain_num_epochs = 20

        self.figure_path = './figures/'
        if is_cuda:
            self.caption_generator = CaptionGenerator(self.embed_size,
                                                      self.hidden_size,
                                                      len(self.vocab),
                                                      self.num_layers).cuda()
            self.caption_discriminator = CaptionDiscriminator(
                self.embed_size, self.hidden_size, len(self.vocab),
                self.num_layers).cuda()
        else:
            self.caption_generator = CaptionGenerator(self.embed_size,
                                                      self.hidden_size,
                                                      len(self.vocab),
                                                      self.num_layers)
            self.caption_discriminator = CaptionDiscriminator(
                self.embed_size, self.hidden_size, len(self.vocab),
                self.num_layers)

        # Hard-coded caption-GAN checkpoints; loaded only when present.
        pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
        pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'

        if os.path.exists(pretrained_caption_gen):
            print('loaded pretrained caption generator')
            self.caption_generator.load_state_dict(
                torch.load(pretrained_caption_gen))

        if os.path.exists(pretrained_caption_disc):
            print('loaded pretrained caption discriminator')
            self.caption_discriminator.load_state_dict(
                torch.load(pretrained_caption_disc))

        # Caption optimizers use Adam defaults (no lr/betas override).
        self.optim_captionG = torch.optim.Adam(
            list(self.caption_generator.parameters()))
        self.optim_captionD = torch.optim.Adam(
            list(self.caption_discriminator.parameters()))
Пример #5
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Set up a two-stage GAN trainer (stage-I 'gan' + 'stage2_gan').

        Args:
            type: GAN variant key stored on the instance (shadows the ``type``
                builtin; name kept for caller compatibility).
            dataset: dataset name, either ``'birds'`` or ``'flowers'``.
            split: dataset split selector forwarded to ``Text2ImageDataset``.
            lr: learning rate shared by the four Adam optimizers.
            diter: number of discriminator iterations per generator step.
            vis_screen: visdom screen/environment name for the Logger.
            save_path: output directory kept for later sample saving.
            l1_coef, l2_coef: loss-term coefficients stored for training.
            pre_trained_gen/_disc: stage-I checkpoint paths (falsy => init).
            pre_trained_gen_B/_disc_B: stage-II checkpoint paths.
            batch_size, num_workers, epochs: data-loading/training settings.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        # forward gan
        if is_cuda:
            self.generator = torch.nn.DataParallel(
                gan_factory.generator_factory('gan').cuda())
            self.discriminator = torch.nn.DataParallel(
                gan_factory.discriminator_factory('gan').cuda())
            self.generator2 = torch.nn.DataParallel(
                gan_factory.generator_factory('stage2_gan').cuda())
            self.discriminator2 = torch.nn.DataParallel(
                gan_factory.discriminator_factory('stage2_gan').cuda())
        else:
            self.generator = torch.nn.DataParallel(
                gan_factory.generator_factory('gan'))
            self.discriminator = torch.nn.DataParallel(
                gan_factory.discriminator_factory('gan'))
            self.generator2 = torch.nn.DataParallel(
                gan_factory.generator_factory('stage2_gan'))
            self.discriminator2 = torch.nn.DataParallel(
                gan_factory.discriminator_factory('stage2_gan'))

        # inverse gan
        # TODO: pass these as parameters from runtime in the future
        # inverse_type = 'inverse_gan'
        # if is_cuda:
        #     self.inv_generator = torch.nn.DataParallel(gan_factory.generator_factory(inverse_type).cuda())
        #     self.inv_discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(inverse_type).cuda())
        # else:
        #     self.inv_generator = torch.nn.DataParallel(gan_factory.generator_factory(inverse_type))
        #     self.inv_discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(inverse_type))

        # Restore checkpoints where given, otherwise apply fresh weight init.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if pre_trained_disc_B:
            self.discriminator2.load_state_dict(torch.load(pre_trained_disc_B))
        else:
            self.discriminator2.apply(Utils.weights_init)

        if pre_trained_gen_B:
            self.generator2.load_state_dict(torch.load(pre_trained_gen_B))
        else:
            self.generator2.apply(Utils.weights_init)

        if dataset == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=split)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5  # Adam beta1, the common choice for GAN training
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, 0.999))
        self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen)
        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type
Пример #6
0
    def __init__(self, type, dataset, split, lr, lr_lower_boundary,
                 lr_update_type, lr_update_step, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, h, scale_size, num_channels,
                 k, lambda_k, gamma, project, concat):
        """Set up a BEGAN-style trainer (k / lambda_k / gamma balancing terms).

        Args:
            type: GAN variant key for ``gan_factory`` (shadows the ``type``
                builtin; name kept for caller compatibility).
            dataset: ``'birds'``, ``'flowers'`` or ``'youtubers'``.
            split: dataset split selector.
            lr, lr_lower_boundary, lr_update_type, lr_update_step: learning-
                rate schedule settings stored for the training loop.
            diter: number of discriminator iterations per generator step.
            vis_screen: visdom screen name for the Logger.
            save_path: output directory, also passed to the Logger.
            l1_coef, l2_coef: loss-term coefficients.
            pre_trained_gen, pre_trained_disc: optional checkpoint paths.
            batch_size, num_workers, epochs: data/training settings.
            h, scale_size, num_channels: image/architecture dimensions.
            k, lambda_k, gamma: BEGAN-style balance hyper-parameters.
            project, concat: embedding-combination flags stored on self.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        # These factories take architecture arguments and are not wrapped in
        # DataParallel, unlike the other trainers in this project.
        self.generator = gan_factory.generator_factory(type, dataset,
                                                       batch_size, h,
                                                       scale_size,
                                                       num_channels).cuda()
        self.discriminator = gan_factory.discriminator_factory(
            type, batch_size, h, scale_size, num_channels).cuda()
        print(self.discriminator)
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)
        self.dataset_name = dataset
        if self.dataset_name == 'birds':
            self.dataset = Text2ImageDataset(config['birds_dataset_path'],
                                             split=split)
        elif self.dataset_name == 'flowers':
            self.dataset = Text2ImageDataset(config['flowers_dataset_path'],
                                             split=split)
        elif self.dataset_name == 'youtubers':
            self.dataset = OneHot2YoutubersDataset(
                config['youtubers_dataset_path'],
                transform=Rescale(64),
                split=split)
        else:
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.lr_lower_boundary = lr_lower_boundary
        self.lr_update_type = lr_update_type
        self.lr_update_step = lr_update_step
        self.beta1 = 0.5  # Adam beta1, the common choice for GAN training
        self.num_epochs = epochs
        self.DITER = diter
        self.apply_projection = project
        self.apply_concat = concat

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.h = h
        self.scale_size = scale_size
        self.num_channels = num_channels
        self.k = k
        self.lambda_k = lambda_k
        self.gamma = gamma

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        # NOTE(review): the optimizers use hard-coded 4e-4 / 1e-4 learning
        # rates and ignore the `lr` argument (self.lr is presumably consumed
        # by the lr-update schedule instead) — confirm this is intentional.
        # Only parameters with requires_grad are optimized (frozen layers
        # are skipped).
        self.optimD = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.discriminator.parameters()),
                                       lr=0.0004,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.generator.parameters()),
                                       lr=0.0001,
                                       betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen, save_path)
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
        self.type = type
Пример #7
0
    def __init__(self,
                 type,
                 dataset,
                 split,
                 lr,
                 diter,
                 vis_screen,
                 save_path,
                 l1_coef,
                 l2_coef,
                 pre_trained_gen,
                 pre_trained_disc,
                 batch_size,
                 num_workers,
                 epochs,
                 pre_trained_disc_B,
                 pre_trained_gen_B,
                 pre_trained_caption_gen,
                 pre_trained_caption_disc,
                 caption_embed_size,
                 caption_hidden_size,
                 caption_num_layers,
                 caption_gen_pretrain_num_epochs,
                 caption_disc_pretrain_num_epochs,
                 caption_initial_noise=False):
        """Set up the two-stage image GANs plus the caption GAN.

        Builds stage-I/stage-II generator and discriminator pairs, restores
        any given checkpoints, loads the dataset with its pickled vocabulary,
        and creates the optimizers for both the image and caption networks.
        """
        config = utils.load_config()

        # forward gan: stage-I ('gan') and stage-II ('stage2_gan') networks,
        # moved to GPU when available and wrapped in DataParallel.
        def _build(kind, factory):
            net = factory(kind)
            if is_cuda:
                net = net.cuda()
            return torch.nn.DataParallel(net)

        self.generator = _build('gan', gan_factory.generator_factory)
        self.discriminator = _build('gan', gan_factory.discriminator_factory)
        self.generator2 = _build('stage2_gan', gan_factory.generator_factory)
        self.discriminator2 = _build('stage2_gan',
                                     gan_factory.discriminator_factory)

        # Restore checkpoints where given, otherwise apply fresh weight init.
        for net, ckpt, label in (
                (self.discriminator, pre_trained_disc, 'pre_trained_disc A'),
                (self.generator, pre_trained_gen, 'pre_trained_gen A'),
                (self.discriminator2, pre_trained_disc_B,
                 'pre_trained_disc B'),
                (self.generator2, pre_trained_gen_B, 'pre_trained_gen B')):
            if ckpt:
                print('Loading %s from: %s' % (label, os.path.abspath(ckpt)))
                net.load_state_dict(torch.load(ckpt))
            else:
                net.apply(Utils.weights_init)

        # Dataset + pickled vocabulary; only birds/flowers are supported.
        if dataset not in ('birds', 'flowers'):
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()
        with open(config['%s_vocab_path' % dataset], 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['%s_dataset_path' % dataset],
                                         dataset_type=dataset,
                                         vocab=self.vocab,
                                         split=split)

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.num_workers = num_workers

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers,
                                      collate_fn=collate_fn)

        adam_kwargs = dict(lr=self.lr, betas=(self.beta1, 0.999))
        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       **adam_kwargs)
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       **adam_kwargs)

        self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                        **adam_kwargs)
        self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                        **adam_kwargs)

        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type

        # settings for caption GAN
        self.embed_size = caption_embed_size
        self.hidden_size = caption_hidden_size
        self.num_layers = caption_num_layers
        self.caption_initial_noise = caption_initial_noise

        self.gen_pretrain_num_epochs = caption_gen_pretrain_num_epochs
        self.disc_pretrain_num_epochs = caption_disc_pretrain_num_epochs

        self.figure_path = './figures/'

        # Caption networks share their core constructor arguments.
        caption_args = (self.embed_size, self.hidden_size, len(self.vocab),
                        self.num_layers)
        self.caption_generator = Image2TextGenerator(
            *caption_args, initial_noise=self.caption_initial_noise)
        self.caption_discriminator = Image2TextDiscriminator(*caption_args)
        if is_cuda:
            self.caption_generator = self.caption_generator.cuda()
            self.caption_discriminator = self.caption_discriminator.cuda()

        if pre_trained_caption_gen and os.path.exists(pre_trained_caption_gen):
            print('loaded pretrained caption generator')
            self.caption_generator.load_state_dict(
                torch.load(pre_trained_caption_gen))

        if pre_trained_caption_disc and os.path.exists(
                pre_trained_caption_disc):
            print('loaded pretrained caption discriminator')
            self.caption_discriminator.load_state_dict(
                torch.load(pre_trained_caption_disc))

        # Caption optimizers use Adam defaults (no lr/betas override).
        self.optim_captionG = torch.optim.Adam(
            list(self.caption_generator.parameters()))
        self.optim_captionD = torch.optim.Adam(
            list(self.caption_discriminator.parameters()))
Пример #8
0
    def __init__(self, type, dataset, split, lr, diter, mode, vis_screen,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs):
        """Set up an Oxford-dataset GAN trainer with a timestamped output dir.

        Args:
            type: GAN variant key for ``gan_factory`` (shadows the ``type``
                builtin; name kept for caller compatibility).
            dataset: only ``'oxford'`` is supported here.
            split: dataset split selector.
            lr: learning rate for both Adam optimizers.
            diter: number of discriminator iterations per generator step.
            mode: exactly ``False`` selects the training set; anything else
                (including 0/None) selects the test set — note the strict
                ``is False`` identity check below.
            vis_screen: visdom screen name for the Logger.
            l1_coef, l2_coef: loss-term coefficients.
            pre_trained_gen, pre_trained_disc: optional checkpoint paths;
                falsy values trigger fresh weight init.
            batch_size, num_workers, epochs: data/training settings.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe on
            # untrusted input and raises TypeError on PyYAML >= 6.
            config = yaml.safe_load(f)

        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory(type).cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory(type).cuda())

        print(self.generator)

        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if dataset == 'oxford':
            if mode is False:  #if we are in training mode:
                self.dataset = Text2ImageDataset(
                    config['oxford_dataset_path_train'], split=split)
            else:  # testing mode:
                self.dataset = Text2ImageDataset(
                    config['oxford_dataset_path_test'], split=split)
        else:
            # NOTE(review): message mentions birds/flowers but this trainer
            # only accepts 'oxford' — the text looks copied from a sibling.
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            exit()

        # Training hyper-parameters.
        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5  # Adam beta1, the common choice for GAN training
        self.num_epochs = epochs
        self.DITER = diter

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen)
        # Timestamped run directory: output/<dataset>_<YYYY_MM_DD_HH_MM_SS>.
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
        self.save_path = 'output/%s_%s' % \
                    (dataset, timestamp)
        self.checkpoints_path = os.path.join(self.save_path, 'checkpoints')
        self.image_dir = os.path.join(self.save_path, 'images')
        self.type = type
Пример #9
0
    def validate(self, cls, val_dataset, val_data_loader, epoch):
        criterion = nn.BCELoss()
        l2_loss = nn.MSELoss()
        l1_loss = nn.L1Loss()

        # NOTE(review): interior of a per-epoch validation routine — the enclosing
        # `def` (and the origin of val_data_loader, val_dataset, criterion, cls,
        # l1_loss, l2_loss) is above this chunk. Loads the epoch's GAN checkpoints,
        # runs one pass over the validation loader, and appends mean D/G losses
        # to a CSV file.

        # Rebuild the generator and load its checkpoint for this epoch; the
        # literal "XXX" placeholder in the path template is replaced with the
        # epoch number.
        generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan').cuda())
        generator.load_state_dict(
            torch.load(self.val_pre_trained_gen.replace("XXX", str(epoch))))

        # Same for the discriminator checkpoint.
        discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan').cuda())
        discriminator.load_state_dict(
            torch.load(self.val_pre_trained_disc.replace("XXX", str(epoch))))

        # Running sums of (batch_size * loss) so the division by
        # len(val_dataset) at the end yields a per-sample mean.
        d_epoch_loss = g_epoch_loss = 0.0
        iteration = 0

        dt = datetime.datetime.now()
        print('Validating... Started on', dt.date(), 'at',
              dt.time().replace(microsecond=0))

        for sample in val_data_loader:
            iteration += 1

            # Batch contents: real image, its text embedding, and an image
            # whose caption does not match (used for the CLS term below).
            right_images = sample['right_images']
            right_embed = sample['right_embed']
            wrong_images = sample['wrong_images']

            right_images = Variable(right_images.float()).cuda()
            right_embed = Variable(right_embed.float()).cuda()
            wrong_images = Variable(wrong_images.float()).cuda()

            # Real/fake BCE targets sized to the current batch.
            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))

            # One-sided label smoothing of the real targets
            # (presumably 1.0 -> 0.9 via the -0.1 offset — confirm against
            # Utils.smooth_label).
            smoothed_real_labels = torch.FloatTensor(
                Utils.smooth_label(real_labels.numpy(), -0.1))

            real_labels = Variable(real_labels).cuda()
            smoothed_real_labels = Variable(smoothed_real_labels).cuda()
            fake_labels = Variable(fake_labels).cuda()

            # Discriminator on real images with matching embeddings.
            outputs, activation_real = discriminator(right_images, right_embed)
            real_loss = criterion(outputs, smoothed_real_labels)
            real_score = outputs

            # Optional CLS term: a real image paired with a mismatched
            # embedding should be scored as fake.
            wrong_loss = wrong_score = 0
            if cls:
                outputs, _ = discriminator(wrong_images, right_embed)
                wrong_loss = criterion(outputs, fake_labels)
                wrong_score = outputs

            # Discriminator on generated images (noise_dim hard-coded to 100,
            # reshaped to NCHW for the generator's conv input).
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, _ = discriminator(fake_images, right_embed)
            fake_loss = criterion(outputs, fake_labels)
            fake_score = outputs

            d_loss = real_loss + fake_loss

            if cls:
                d_loss = d_loss + wrong_loss

            # outputs.shape[0] == batch size; weight the batch loss accordingly.
            d_epoch_loss += outputs.shape[0] * d_loss.item()

            # Generator evaluation with a fresh noise batch.
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, activation_fake = discriminator(fake_images, right_embed)
            _, activation_real = discriminator(right_images, right_embed)

            # Mean activations across the batch (dim 0).
            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)

            # g_loss = BCE (fool the discriminator)
            #        + l2_coef * MSE between the discriminator's mean activations
            #          for fake vs. real images
            #        + l1_coef * MAE between fake and real images.
            g_loss = criterion(outputs, real_labels) + self.l2_coef * l2_loss(
                activation_fake[0], activation_real[0].detach()
            ) + self.l1_coef * l1_loss(fake_images, right_images)

            g_epoch_loss += outputs.shape[0] * g_loss.item()

            # Progress report every 500 batches; len(val_data_loader) is the
            # number of batches, so this is percent of batches processed.
            iters_cnt = 500
            if iteration % iters_cnt == 0:
                percentage = (iteration * 100.0) / len(val_data_loader)
                print('Samples (iterations):', round(percentage, 2), '%')

        dt = datetime.datetime.now()
        print('Validation completed on', dt.date(), 'at',
              dt.time().replace(microsecond=0))

        # Per-sample mean losses over the whole validation set.
        d_err = d_epoch_loss / len(val_dataset)
        g_err = g_epoch_loss / len(val_dataset)
        print('Discriminator loss:', d_err, '| generator loss:', g_err)

        # Append one "epoch ; d_loss ; g_loss" row to the validation-loss CSV.
        f = open(self.loss_filename_v, 'a')
        f.write(str(epoch) + " ; " + str(d_err) + " ; " + str(g_err) + "\n")
        f.close()

        # Free the throwaway networks' memory before returning to training.
        gc.collect()
        torch.cuda.empty_cache()
Пример #10
0
    def __init__(self, dataset, split, lr, save_path, l1_coef, l2_coef,
                 pre_trained_gen, pre_trained_disc, val_pre_trained_gen,
                 val_pre_trained_disc, batch_size, num_workers, epochs,
                 dataset_paths, arrangement, sampling):
        """Set up the GAN trainer: networks, dataset, data loader, optimizers.

        Args:
            dataset: 'birds' or 'flowers' (legacy HDF5 Text2ImageDataset)
                or 'live' (Text2ImageDataset2 built from ``dataset_paths``).
            split: split selector forwarded to the dataset class.
            lr: learning rate shared by both Adam optimizers.
            save_path: output path stored on the instance.
            l1_coef, l2_coef: weights of the generator's L1/L2 loss terms.
            pre_trained_gen, pre_trained_disc: optional checkpoint paths;
                when falsy the networks get Utils.weights_init instead.
            val_pre_trained_gen, val_pre_trained_disc: checkpoint path
                templates kept for later use during validation.
            batch_size, num_workers, epochs: loader / training settings.
            dataset_paths: dict with datasetFile/imagesDir/textDir keys
                (only used when dataset == 'live').
            arrangement, sampling: options forwarded to Text2ImageDataset2.
        """
        with open('config.yaml',
                  'r') as f:  # backward compatibility for Text2ImageDataset
            config = yaml.safe_load(f)

        # Wrap both networks in DataParallel and move them to the GPU
        # (CUDA is assumed unconditionally here, unlike the sibling trainer).
        self.generator = torch.nn.DataParallel(
            gan_factory.generator_factory('gan').cuda())
        self.discriminator = torch.nn.DataParallel(
            gan_factory.discriminator_factory('gan').cuda())

        # Resume from checkpoints when provided, otherwise apply the custom
        # weight initialization.
        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        # Kept for validation runs, which reload per-epoch checkpoints.
        self.val_pre_trained_gen = val_pre_trained_gen
        self.val_pre_trained_disc = val_pre_trained_disc
        self.arrangement = arrangement
        self.sampling = sampling

        if dataset == 'birds':  # backward compatibility for Text2ImageDataset
            self.dataset = Text2ImageDataset(
                config['birds_dataset_path'], split=split
            )  # e.g. '...\Text2Image\datasets\ee285f-public\caltech_ucsd_birds\birds.hdf5'
        elif dataset == 'flowers':  # backward compatibility for Text2ImageDataset
            self.dataset = Text2ImageDataset(
                config['flowers_dataset_path'], split=split
            )  # e.g. '...\Text2Image\datasets\ee285f-public\oxford_flowers\flowers.hdf5'
        elif dataset == 'live':
            self.dataset_dict = easydict.EasyDict(dataset_paths)
            self.dataset = Text2ImageDataset2(
                datasetFile=self.dataset_dict.datasetFile,
                imagesDir=self.dataset_dict.imagesDir,
                textDir=self.dataset_dict.textDir,
                split=split,
                arrangement=arrangement,
                sampling=sampling)
        else:
            # NOTE(review): unlike the sibling trainer in this file, this
            # branch does not exit(); execution falls through and
            # len(self.dataset) below raises AttributeError.
            print('Dataset not supported.')

        print('Images =', len(self.dataset))
        self.noise_dim = 100  # size of the generator's latent noise vector
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5  # Adam beta1 for both optimizers
        self.num_epochs = epochs

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )  # shuffle=True reshuffles the dataset every epoch

        # Separate Adam optimizers for discriminator and generator.
        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=(self.beta1, 0.999))

        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path

        # CSV log filenames: training losses vs. validation losses.
        self.loss_filename_t = 'gd_loss_t.csv'
        self.loss_filename_v = 'gd_loss_v.csv'