示例#1
0
    def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, steps = 1, lr = 1e-4):
        """Build the mapping network, generator, discriminator and shadow copies.

        ``S``/``G``/``D`` are trainable. ``SE``/``GE`` are built with the same
        configuration and have gradients disabled (presumably the averaged
        copies driven by ``self.ema_updater`` — confirm against the trainer).
        ``G_opt`` covers generator + mapping network, ``D_opt`` the
        discriminator; both are DiffGrad with betas (0.5, 0.9). Everything is
        moved to CUDA and, if ``fp16`` is set, passed through
        ``amp.initialize`` at opt level 'O2'.
        """
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        mapper_args = (latent_dim, style_depth)
        generator_args = (image_size, latent_dim, network_capacity)

        # Trainable networks.
        self.S = StyleVectorizer(*mapper_args)
        self.G = Generator(*generator_args, transparent=transparent)
        self.D = Discriminator(image_size, network_capacity, transparent=transparent)

        # Shadow copies that never receive gradients.
        self.SE = StyleVectorizer(*mapper_args)
        self.GE = Generator(*generator_args, transparent=transparent)
        for frozen in (self.SE, self.GE):
            set_requires_grad(frozen, False)

        optim_kwargs = dict(lr=self.lr, betas=(0.5, 0.9))
        self.G_opt = DiffGrad([*self.G.parameters(), *self.S.parameters()], **optim_kwargs)
        self.D_opt = DiffGrad(self.D.parameters(), **optim_kwargs)

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            nets, optimizers = amp.initialize(
                [self.S, self.G, self.D, self.SE, self.GE],
                [self.G_opt, self.D_opt],
                opt_level='O2')
            self.S, self.G, self.D, self.SE, self.GE = nets
            self.G_opt, self.D_opt = optimizers
    def __init__(self,
                 image_size,
                 latent_dim=512,
                 style_depth=8,
                 network_capacity=16,
                 steps=1,
                 lr=1e-4):
        """Build the mapping network, generator, discriminator and shadow copies.

        ``SE``/``GE`` mirror ``S``/``G`` with gradients disabled (presumably
        averaged via ``self.ema_updater`` elsewhere — confirm). ``G_opt``
        optimises generator + mapping network together, ``D_opt`` the
        discriminator; both DiffGrad with betas (0.5, 0.9).
        """
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.99)

        mapper_args = (latent_dim, style_depth)
        generator_args = (image_size, latent_dim, network_capacity)

        self.S = StyleVectorizer(*mapper_args)
        self.G = Generator(*generator_args)
        self.D = Discriminator(image_size, network_capacity)

        self.SE = StyleVectorizer(*mapper_args)
        self.GE = Generator(*generator_args)

        for frozen in (self.SE, self.GE):
            set_requires_grad(frozen, False)

        optim_kwargs = dict(lr=self.lr, betas=(0.5, 0.9))
        self.G_opt = DiffGrad([*self.G.parameters(), *self.S.parameters()], **optim_kwargs)
        self.D_opt = DiffGrad(self.D.parameters(), **optim_kwargs)

        self._init_weights()
        self.reset_parameter_averaging()
    def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = None, fq_dict_size = 256, attn_layers = None):
        """Build the StyleGAN2 networks, shadow copies, optional contrastive head and optimizers.

        Args:
            fq_layers: forwarded to the Discriminator; ``None`` (default)
                means an empty list.
            fq_dict_size: forwarded to the Discriminator.
            attn_layers: forwarded to the Discriminator; ``None`` (default)
                means an empty list.
            cl_reg: when set, ``self.D_cl`` wraps ``D`` in a
                ContrastiveLearner (experimental regulariser); otherwise None.
            fp16: pass all networks and both optimizers through
                ``amp.initialize`` at opt level 'O2'.

        ``SE``/``GE`` mirror ``S``/``G`` with gradients disabled. All modules
        are moved to CUDA.
        """
        super().__init__()
        # Use a fresh list per call: a literal `[]` default is created once and
        # shared by every instance, so any downstream mutation would leak.
        fq_layers = [] if fq_layers is None else fq_layers
        attn_layers = [] if attn_layers is None else attn_layers

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent)
        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent)

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent)

        # experimental contrastive loss discriminator regularization
        self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O2')
示例#4
0
    def __init__(self, image_size, label_dim, latent_dim=LATENT_DIM, style_depth=STYLE_DEPTH,
                 network_capacity=NETWORK_CAPACITY, steps=1, lr=LEARNING_RATE, channels=CHANNELS,
                 condition_on_mapper=CONDITION_ON_MAPPER, use_biases=USE_BIASES, label_epsilon=LABEL_EPSILON):
        """Build the label-conditioned networks, their shadow copies and optimizers.

        ``label_dim`` is passed to every network. ``condition_on_mapper`` and
        ``use_biases`` configure both the mapping network and the generator.
        ``SE``/``GE`` mirror ``S``/``G`` with gradients disabled. ``G_opt``
        covers generator + mapping network, ``D_opt`` the discriminator
        (DiffGrad, betas (0.5, 0.9)).
        """
        super().__init__()
        self.condition_on_mapper = condition_on_mapper
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.99)

        shared = dict(condition_on_mapper=self.condition_on_mapper, use_biases=use_biases)

        self.S = StyleVectorizer(latent_dim, label_dim, style_depth, **shared)
        self.G = Generator(image_size, latent_dim, label_dim, network_capacity, channels=channels, **shared)
        self.D = Discriminator(image_size, label_dim, network_capacity=network_capacity, channels=channels,
                               label_epsilon=label_epsilon)

        self.SE = StyleVectorizer(latent_dim, label_dim, style_depth, **shared)
        self.GE = Generator(image_size, latent_dim, label_dim, network_capacity, channels=channels, **shared)

        for frozen in (self.SE, self.GE):
            set_requires_grad(frozen, False)

        optim_kwargs = dict(lr=self.lr, betas=(0.5, 0.9))
        self.G_opt = DiffGrad([*self.G.parameters(), *self.S.parameters()], **optim_kwargs)
        self.D_opt = DiffGrad(self.D.parameters(), **optim_kwargs)

        # Stored before _init_weights(), which may depend on it — keep this order.
        self.use_biases = use_biases
        self._init_weights()
        self.reset_parameter_averaging()
示例#5
0
    def __init__(self,
                 image_size,
                 latent_dim=512,
                 noise_dim=100,
                 style_depth=8,
                 network_capacity=16,
                 transparent=False,
                 steps=1,
                 lr=2e-4):
        """Build style/noise mappers, generator, discriminator, extractor and shadow copies.

        ``SE``/``NE``/``GE`` mirror ``S``/``N``/``G`` with gradients disabled.
        Optimizers (all DiffGrad, betas (0.5, 0.9)):
          * ``G_opt`` — generator, style mapper and noise mapper,
          * ``D_opt`` — discriminator,
          * ``E_opt`` — the ``ExtractNetSimilarE`` network,
          * ``N_opt`` — noise mapper.
        """
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        mapper_args = (latent_dim, style_depth)
        generator_args = (image_size, latent_dim, network_capacity)

        self.S = StyleVectorizer(*mapper_args)
        self.N = NoiseVectorizer(noise_dim)
        self.G = Generator(*generator_args, transparent=transparent)
        self.D = Discriminator(image_size, network_capacity, transparent=transparent)
        self.E = ExtractNetSimilarE(64)

        self.SE = StyleVectorizer(*mapper_args)
        self.NE = NoiseVectorizer(noise_dim)
        self.GE = Generator(*generator_args, transparent=transparent)

        for frozen in (self.SE, self.NE, self.GE):
            set_requires_grad(frozen, False)

        optim_kwargs = dict(lr=self.lr, betas=(0.5, 0.9))
        # NOTE(review): N's parameters are registered with both G_opt and
        # N_opt below; confirm the training loop intends them to be stepped
        # by both optimizers.
        self.G_opt = DiffGrad(
            [*self.G.parameters(), *self.S.parameters(), *self.N.parameters()],
            **optim_kwargs)
        self.D_opt = DiffGrad(self.D.parameters(), **optim_kwargs)
        self.E_opt = DiffGrad(list(self.E.parameters()), **optim_kwargs)
        self.N_opt = DiffGrad(list(self.N.parameters()), **optim_kwargs)

        self._init_weights()
        self.reset_parameter_averaging()
示例#6
0
    def __init__(self,
                 image_size,
                 latent_dim=512,
                 style_depth=8,
                 network_capacity=16,
                 transparent=False,
                 fp16=False,
                 cl_reg=False,
                 steps=1,
                 lr=1e-4,
                 fq_layers=None,
                 fq_dict_size=256,
                 freeze_g=False,
                 freeze_d=False):
        """Build generator, discriminator, generator shadow copy and per-network optimizers.

        Args:
            fq_layers: forwarded to the Discriminator; ``None`` (default)
                means an empty list.
            freeze_g / freeze_d: call ``freeze_()`` on ``G`` / ``D`` before
                their optimizers are created.

        ``style_depth``, ``fp16`` and ``cl_reg`` are accepted but not used in
        this constructor. Optimizers are attached as ``self.G.opt`` and
        ``self.D.opt`` (DiffGrad, betas (0.5, 0.9)); ``self.ema_decay`` is
        0.995.
        """
        super().__init__()

        # Use a fresh list per call: a literal `[]` default is created once and
        # shared by every instance, so any downstream mutation would leak.
        fq_layers = [] if fq_layers is None else fq_layers

        self.lr = lr
        self.steps = steps
        self.ema_decay = 0.995

        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent)
        self.D = Discriminator(image_size,
                               network_capacity,
                               fq_layers=fq_layers,
                               fq_dict_size=fq_dict_size,
                               transparent=transparent)

        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent)
        set_requires_grad(self.GE, False)

        if freeze_g:
            self.G.freeze_()
        if freeze_d:
            self.D.freeze_()

        self.G.opt = DiffGrad(self.G.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))
        self.D.opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()
示例#7
0
    def forward(self):
        """Optionally fit a start image, then run the epoch/iteration training loop.

        The global ``terminate`` flag is polled after every step: during the
        start-image phase it exits the process, during training it returns.
        """
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            warmup_optim = DiffGrad(self.model.parameters(), lr=self.start_image_lr)
            progress = trange(self.start_image_train_iters, desc='iteration')
            for _ in progress:
                warmup_loss = self.model.model(self.start_image)
                warmup_loss.backward()
                progress.set_description(f'loss: {warmup_loss.item():.2f}')

                warmup_optim.step()
                warmup_optim.zero_grad()

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return exit()

            # Drop the start image and its optimizer once fitted.
            del self.start_image
            del warmup_optim

        tqdm.write(f'Imagining "{self.text}" from the depths of my weights...')

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        for epoch in trange(self.epochs, desc='epochs'):
            progress = trange(self.iterations, desc='iteration')
            for step in progress:
                step_loss = self.train_step(epoch, step)
                progress.set_description(f'loss: {step_loss.item():.2f}')

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return
示例#8
0
    def forward(self):
        """Fit the SIREN to the CLIP target for ``self.epochs`` x ``self.iterations`` steps.

        If a start image was loaded, the SIREN is first trained for
        ``start_image_train_iters`` steps on it. A no-grad warm-up pass is
        made before training. In story mode ``self.clip_encoding`` is advanced
        after every epoch. Afterwards a final image is saved and, when
        ``save_gif``/``save_video`` and ``save_progress`` are set,
        ``generate_gif()`` is called.

        Ctrl-C during the start-image phase exits the process; during training
        it returns early without the final save.
        """
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            optim = DiffGrad(self.model.model.parameters(),
                             lr=self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            try:
                for _ in pbar:
                    loss = self.model.model(self.start_image)
                    loss.backward()
                    pbar.set_description(f'loss: {loss.item():.2f}')

                    optim.step()
                    optim.zero_grad()
            except KeyboardInterrupt:
                print('interrupted by keyboard, gracefully exiting')
                return exit()

            del self.start_image
            del optim

        tqdm.write(
            f'Imagining "{self.textpath}" from the depths of my weights...')

        with torch.no_grad():
            self.model(
                self.clip_encoding, dry_run=True
            )  # do one warmup step due to potential issue with CLIP and CUDA

        if self.open_folder:
            open_folder(self.output_folder or './')
            self.open_folder = False

        # Pre-bind the loop indices: with epochs == 0 (or iterations == 0) the
        # loops never assign them, and the story update / final save below
        # would raise UnboundLocalError.
        epoch, i = 0, 0
        try:
            for epoch in trange(self.epochs, desc='epochs'):
                pbar = trange(self.iterations, desc='iteration')
                for i in pbar:
                    _, loss = self.train_step(epoch, i)
                    pbar.set_description(f'loss: {loss.item():.2f}')

                # Update clip_encoding per epoch if we are creating a story
                if self.create_story:
                    self.clip_encoding = self.update_story_encoding(epoch, i)
        except KeyboardInterrupt:
            print('interrupted by keyboard, gracefully exiting')
            return

        self.save_image(epoch, i)  # one final save at end

        if (self.save_gif or self.save_video) and self.save_progress:
            self.generate_gif()
示例#9
0
    def forward(self):
        """Optionally fit a start image, warm up, then run the training loop.

        The global ``terminate`` flag is polled after every step: during the
        start-image phase it exits via ``sys.exit()``, during training it
        returns without the final save. On normal completion one final image
        is saved with ``(self.epochs, self.iterations)`` as its indices.
        """
        import torch  # local import: needed for the no-grad warm-up below

        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            optim = DiffGrad(self.model.parameters(), lr = self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            for _ in pbar:
                loss = self.model.model(self.start_image)
                loss.backward()
                pbar.set_description(f'loss: {loss.item():.2f}')

                optim.step()
                optim.zero_grad()

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return sys.exit()

            del self.start_image
            del optim

        tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...')

        # One warm-up pass due to a potential issue with CLIP and CUDA. Its
        # result is discarded and nothing is backpropagated, so do not build an
        # autograd graph for it.
        with torch.no_grad():
            self.model(self.clip_encoding, dry_run = True)

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        for epoch in trange(self.epochs, desc='epochs'):
            pbar = trange(self.iterations, desc='iteration')
            for i in pbar:
                loss = self.train_step(epoch, i)
                pbar.set_description(f'loss: {loss.item():.2f}')

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return

        self.save_image(self.epochs, self.iterations) # one final save at end
示例#10
0
    def __init__(
        self,
        *,
        text=None,
        img=None,
        clip_encoding=None,
        lr=1e-5,
        batch_size=4,
        gradient_accumulate_every=4,
        save_every=100,
        image_width=512,
        num_layers=16,
        epochs=20,
        iterations=1050,
        save_progress=True,
        seed=None,
        open_folder=True,
        save_date_time=False,
        start_image_path=None,
        start_image_train_iters=10,
        start_image_lr=3e-4,
        theta_initial=None,
        theta_hidden=None,
        model_name="ViT-B/32",
        lower_bound_cutout=0.1,  # should be smaller than 0.8
        upper_bound_cutout=1.0,
        saturate_bound=False,
        averaging_weight=0.3,
        create_story=False,
        story_start_words=5,
        story_words_per_epoch=5,
        story_separator=None,
        gauss_sampling=False,
        gauss_mean=0.6,
        gauss_std=0.2,
        do_cutout=True,
        center_bias=False,
        center_focus=2,
        optimizer="AdamP",
        jit=True,
        hidden_size=256,
        save_gif=False,
        save_video=False,
    ):
        """Set up CLIP, the DeepDaze SIREN model, its optimizer and the CLIP target.

        Selected arguments:
            text / img / clip_encoding: what to optimise towards. An explicit
                ``clip_encoding`` wins; text and image encodings are averaged
                when both are given. ``img`` may be a file path.
            create_story: consume ``text`` progressively across epochs. The
                epoch count is then derived from the word count (or from the
                number of ``story_separator``-delimited phrases) and the
                ``epochs`` argument is ignored.
            story_separator: if the text consists only of separators the
                process exits via ``exit()``.
            optimizer: ``"AdamP"``, ``"Adam"`` or ``"DiffGrad"``.
            jit: forced to ``False`` unless the installed torch is 1.7.1.
            start_image_path: optional image, resized and centre-cropped to
                ``image_width``, that the SIREN is first fitted to.

        Raises:
            ValueError: ``optimizer`` is not one of the supported names.
            AssertionError: ``create_story`` without ``text``, or a
                ``start_image_path`` that does not exist.
        """
        super().__init__()

        # Validate before the slow CLIP load so a typo fails fast instead of
        # silently leaving ``self.optimizer`` unset.
        optimizer_classes = {
            "AdamP": AdamP,
            "Adam": torch.optim.Adam,
            "DiffGrad": DiffGrad,
        }
        if optimizer not in optimizer_classes:
            raise ValueError(
                f"unknown optimizer {optimizer!r}; expected one of {sorted(optimizer_classes)}")

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        # fields for story creation:
        self.create_story = create_story
        self.words = None
        self.separator = str(
            story_separator) if story_separator is not None else None
        if self.separator is not None and text is not None:
            # exit if text is just the separator
            if str(text).replace(' ', '').replace(self.separator, '') == '':
                print(
                    'Exiting because the text only consists of the separator! Needs words or phrases that are separated by the separator.'
                )
                exit()
            # add a space after each separator and collapse the double spaces this may create
            text = text.replace(self.separator,
                                self.separator + ' ').replace('  ',
                                                              ' ').strip()
        self.all_words = text.split(" ") if text is not None else None
        self.num_start_words = story_start_words
        self.words_per_epoch = story_words_per_epoch
        if create_story:
            assert text is not None, "We need text input to create a story..."
            # overwrite epochs to match story length
            num_words = len(self.all_words)
            self.epochs = 1 + (num_words -
                               self.num_start_words) / self.words_per_epoch
            # add one epoch if not divisible
            self.epochs = int(self.epochs) if int(
                self.epochs) == self.epochs else int(self.epochs) + 1
            if self.separator is not None:
                if self.separator not in text:
                    print("Separator '" + self.separator +
                          "' will be ignored since not in text!")
                    self.separator = None
                else:
                    self.epochs = len(
                        list(filter(None, text.split(self.separator))))
            print(
                "Running for", self.epochs, "epochs" +
                (" (split with '" + self.separator +
                 "' as the separator)" if self.separator is not None else ""))
        else:
            self.epochs = epochs

        # jit models only compatible with version 1.7.1
        if "1.7.1" not in torch.__version__:
            if jit:
                print(
                    "Setting jit to False because torch version is not 1.7.1.")
            jit = False

        # Load CLIP (frozen: it only scores images)
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        clip_perceptor, norm = load(model_name, jit=jit, device=self.device)
        self.perceptor = clip_perceptor.eval()
        for param in self.perceptor.parameters():
            param.requires_grad = False
        if not jit:
            input_res = clip_perceptor.visual.input_resolution
        else:
            input_res = clip_perceptor.input_resolution.item()
        self.clip_transform = create_clip_img_transform(input_res)

        self.iterations = iterations
        self.image_width = image_width
        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            self.perceptor,
            norm,
            input_res,
            total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            saturate_bound=saturate_bound,
            gauss_sampling=gauss_sampling,
            gauss_mean=gauss_mean,
            gauss_std=gauss_std,
            do_cutout=do_cutout,
            center_bias=center_bias,
            center_focus=center_focus,
            hidden_size=hidden_size,
            averaging_weight=averaging_weight,
        ).to(self.device)
        self.model = model
        self.scaler = GradScaler()
        # Only the SIREN (model.model) is optimised.
        siren_params = model.model.parameters()
        self.optimizer = optimizer_classes[optimizer](siren_params, lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(self.perceptor.context_length,
                                         text=text,
                                         img=img,
                                         encoding=clip_encoding,
                                         separator=story_separator)
        self.filename = self.image_output_path()

        # create coding to optimize for
        self.clip_encoding = self.create_clip_encoding(text=text,
                                                       img=img,
                                                       encoding=clip_encoding)

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            # Report the argument: ``self.start_image_path`` is never assigned,
            # so referencing it turned a missing file into an AttributeError.
            assert file.exists(
            ), f'file does not exist at given starting image path {start_image_path}'
            start_img_transform = T.Compose([
                T.Resize(image_width),
                T.CenterCrop((image_width, image_width)),
                T.ToTensor()
            ])
            # Context manager releases the file handle PIL keeps open.
            with Image.open(str(file)) as image:
                image_tensor = start_img_transform(image).unsqueeze(0).to(
                    self.device)
            self.start_image = image_tensor

        self.save_gif = save_gif
        self.save_video = save_video
示例#11
0
class Imagine(nn.Module):
    def __init__(
        self,
        *,
        text=None,
        img=None,
        clip_encoding=None,
        lr=1e-5,
        batch_size=4,
        gradient_accumulate_every=4,
        save_every=100,
        image_width=512,
        num_layers=16,
        epochs=20,
        iterations=1050,
        save_progress=True,
        seed=None,
        open_folder=True,
        save_date_time=False,
        start_image_path=None,
        start_image_train_iters=10,
        start_image_lr=3e-4,
        theta_initial=None,
        theta_hidden=None,
        model_name="ViT-B/32",
        lower_bound_cutout=0.1,  # should be smaller than 0.8
        upper_bound_cutout=1.0,
        saturate_bound=False,
        averaging_weight=0.3,
        create_story=False,
        story_start_words=5,
        story_words_per_epoch=5,
        story_separator=None,
        gauss_sampling=False,
        gauss_mean=0.6,
        gauss_std=0.2,
        do_cutout=True,
        center_bias=False,
        center_focus=2,
        optimizer="AdamP",
        jit=True,
        hidden_size=256,
        save_gif=False,
        save_video=False,
    ):
        """Set up CLIP, the DeepDaze SIREN model, its optimizer and the CLIP target.

        Selected arguments:
            text / img / clip_encoding: what to optimise towards. An explicit
                ``clip_encoding`` wins; text and image encodings are averaged
                when both are given. ``img`` may be a file path.
            create_story: consume ``text`` progressively across epochs. The
                epoch count is then derived from the word count (or from the
                number of ``story_separator``-delimited phrases) and the
                ``epochs`` argument is ignored.
            story_separator: if the text consists only of separators the
                process exits via ``exit()``.
            optimizer: ``"AdamP"``, ``"Adam"`` or ``"DiffGrad"``.
            jit: forced to ``False`` unless the installed torch is 1.7.1.
            start_image_path: optional image, resized and centre-cropped to
                ``image_width``, that the SIREN is first fitted to.

        Raises:
            ValueError: ``optimizer`` is not one of the supported names.
            AssertionError: ``create_story`` without ``text``, or a
                ``start_image_path`` that does not exist.
        """
        super().__init__()

        # Validate before the slow CLIP load so a typo fails fast instead of
        # leaving ``self.optimizer`` unset (train_step would then crash).
        optimizer_classes = {
            "AdamP": AdamP,
            "Adam": torch.optim.Adam,
            "DiffGrad": DiffGrad,
        }
        if optimizer not in optimizer_classes:
            raise ValueError(
                f"unknown optimizer {optimizer!r}; expected one of {sorted(optimizer_classes)}")

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        # fields for story creation:
        self.create_story = create_story
        self.words = None
        self.separator = str(
            story_separator) if story_separator is not None else None
        if self.separator is not None and text is not None:
            # exit if text is just the separator
            if str(text).replace(' ', '').replace(self.separator, '') == '':
                print(
                    'Exiting because the text only consists of the separator! Needs words or phrases that are separated by the separator.'
                )
                exit()
            # add a space after each separator and collapse the double spaces this may create
            text = text.replace(self.separator,
                                self.separator + ' ').replace('  ',
                                                              ' ').strip()
        self.all_words = text.split(" ") if text is not None else None
        self.num_start_words = story_start_words
        self.words_per_epoch = story_words_per_epoch
        if create_story:
            assert text is not None, "We need text input to create a story..."
            # overwrite epochs to match story length
            num_words = len(self.all_words)
            self.epochs = 1 + (num_words -
                               self.num_start_words) / self.words_per_epoch
            # add one epoch if not divisible
            self.epochs = int(self.epochs) if int(
                self.epochs) == self.epochs else int(self.epochs) + 1
            if self.separator is not None:
                if self.separator not in text:
                    print("Separator '" + self.separator +
                          "' will be ignored since not in text!")
                    self.separator = None
                else:
                    self.epochs = len(
                        list(filter(None, text.split(self.separator))))
            print(
                "Running for", self.epochs, "epochs" +
                (" (split with '" + self.separator +
                 "' as the separator)" if self.separator is not None else ""))
        else:
            self.epochs = epochs

        # jit models only compatible with version 1.7.1
        if "1.7.1" not in torch.__version__:
            if jit:
                print(
                    "Setting jit to False because torch version is not 1.7.1.")
            jit = False

        # Load CLIP (frozen: it only scores images)
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        clip_perceptor, norm = load(model_name, jit=jit, device=self.device)
        self.perceptor = clip_perceptor.eval()
        for param in self.perceptor.parameters():
            param.requires_grad = False
        if not jit:
            input_res = clip_perceptor.visual.input_resolution
        else:
            input_res = clip_perceptor.input_resolution.item()
        self.clip_transform = create_clip_img_transform(input_res)

        self.iterations = iterations
        self.image_width = image_width
        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            self.perceptor,
            norm,
            input_res,
            total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            saturate_bound=saturate_bound,
            gauss_sampling=gauss_sampling,
            gauss_mean=gauss_mean,
            gauss_std=gauss_std,
            do_cutout=do_cutout,
            center_bias=center_bias,
            center_focus=center_focus,
            hidden_size=hidden_size,
            averaging_weight=averaging_weight,
        ).to(self.device)
        self.model = model
        self.scaler = GradScaler()
        # Only the SIREN (model.model) is optimised.
        siren_params = model.model.parameters()
        self.optimizer = optimizer_classes[optimizer](siren_params, lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(self.perceptor.context_length,
                                         text=text,
                                         img=img,
                                         encoding=clip_encoding,
                                         separator=story_separator)
        self.filename = self.image_output_path()

        # create coding to optimize for
        self.clip_encoding = self.create_clip_encoding(text=text,
                                                       img=img,
                                                       encoding=clip_encoding)

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            # Report the argument: ``self.start_image_path`` is never assigned,
            # so referencing it turned a missing file into an AttributeError.
            assert file.exists(
            ), f'file does not exist at given starting image path {start_image_path}'
            start_img_transform = T.Compose([
                T.Resize(image_width),
                T.CenterCrop((image_width, image_width)),
                T.ToTensor()
            ])
            # Context manager releases the file handle PIL keeps open.
            with Image.open(str(file)) as image:
                image_tensor = start_img_transform(image).unsqueeze(0).to(
                    self.device)
            self.start_image = image_tensor

        self.save_gif = save_gif
        self.save_video = save_video

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        self.text = text
        self.img = img
        if encoding is not None:
            encoding = encoding.to(self.device)
        elif self.create_story:
            encoding = self.update_story_encoding(epoch=0, iteration=1)
        elif text is not None and img is not None:
            encoding = (self.create_text_encoding(text) +
                        self.create_img_encoding(img)) / 2
        elif text is not None:
            encoding = self.create_text_encoding(text)
        elif img is not None:
            encoding = self.create_img_encoding(img)
        return encoding

    def create_text_encoding(self, text):
        """Tokenise ``text`` and embed it with the frozen CLIP text encoder (no autograd)."""
        tokens = tokenize(text).to(self.device)
        with torch.no_grad():
            return self.perceptor.encode_text(tokens).detach()

    def create_img_encoding(self, img):
        """Embed an image with the frozen CLIP image encoder (no autograd).

        ``img`` may be a file path (opened with PIL) or an object accepted by
        ``self.clip_transform``; a batch dimension is added before encoding.
        """
        source = Image.open(img) if isinstance(img, str) else img
        batch = self.clip_transform(source).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.perceptor.encode_image(batch).detach()

    def set_clip_encoding(self, text=None, img=None, encoding=None):
        encoding = self.create_clip_encoding(text=text,
                                             img=img,
                                             encoding=encoding)
        self.clip_encoding = encoding.to(self.device)

    def index_of_first_separator(self) -> int:
        for c, word in enumerate(self.all_words):
            if self.separator in str(word):
                return c + 1

    def update_story_encoding(self, epoch, iteration):
        """Advance the story text and return the CLIP text encoding of the new words.

        With a separator, the next separator-terminated phrase (separator
        removed) becomes ``self.words`` and is consumed from
        ``self.all_words``; if no separator remains, all remaining words are
        used and ``self.all_words`` is left as is. Without a separator, the
        first call takes ``num_start_words`` words and later calls append up
        to ``words_per_epoch`` more, then drop leading words while the text is
        longer (in characters) than ``self.perceptor.context_length``.

        Each transition is printed and appended to ``story_transitions.txt``
        in the current working directory.
        """
        if self.separator is not None:
            # Look the phrase boundary up once (None => take every word left).
            cut = self.index_of_first_separator()
            phrase = " ".join(self.all_words[:cut])
            # removes separator from epoch-text
            self.words = phrase.replace(self.separator, '')
            self.all_words = self.all_words[cut:]
        elif self.words is None:
            self.words = " ".join(self.all_words[:self.num_start_words])
            self.all_words = self.all_words[self.num_start_words:]
        else:
            # add words_per_epoch new words
            count = 0
            while count < self.words_per_epoch and len(self.all_words) > 0:
                new_word = self.all_words[0]
                self.words = " ".join(self.words.split(" ") + [new_word])
                self.all_words = self.all_words[1:]
                count += 1
            # remove words until it fits in context length
            # NOTE(review): this compares characters, not CLIP tokens, against
            # the token context length — a rough heuristic.
            while len(self.words) > self.perceptor.context_length:
                # remove first word
                self.words = " ".join(self.words.split(" ")[1:])
        # get new encoding
        print("Now thinking of: ", '"', self.words, '"')
        sequence_number = self.get_img_sequence_number(epoch, iteration)
        # save new words to disc; explicit UTF-8 so non-ASCII prompts do not
        # fail on platforms whose default encoding cannot represent them
        with open("story_transitions.txt", "a", encoding="utf-8") as f:
            f.write(f"{epoch}, {sequence_number}, {self.words}\n")

        return self.create_text_encoding(self.words)

    def image_output_path(self, sequence_number=None):
        """
        Returns underscore separated Path.
        A current timestamp is prepended if `self.save_date_time` is set.
        Sequence number left padded with 6 zeroes is appended if
        `sequence_number` is given (including 0).
        :rtype: Path
        """
        output_path = self.textpath
        # Compare with None rather than truthiness: sequence number 0 (the first
        # save) must get its own numbered file instead of colliding with the
        # un-numbered "{textpath}.jpg" latest-image path.
        if sequence_number is not None:
            sequence_number_left_padded = str(sequence_number).zfill(6)
            output_path = f"{output_path}.{sequence_number_left_padded}"
        if self.save_date_time:
            current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f")
            output_path = f"{current_time}_{output_path}"
        return Path(f"{output_path}.jpg")

    def train_step(self, epoch, iteration):
        """Run one optimizer step with gradient accumulation under autocast.

        The model is called ``self.gradient_accumulate_every`` times on
        ``self.clip_encoding``. Each loss is divided by that count before its
        scaled backward pass, so the accumulated gradient is an average. The
        image from the last forward pass is moved to the CPU, cast to float and
        clamped to ``[0, 1]``. It is saved when ``iteration`` is a multiple of
        ``self.save_every`` and ``self.save_progress`` is set.

        :return: ``(out, total_loss)`` — the last generated image and the sum of
            the scaled losses, detached from the autograd graph (for reporting).
        """
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            with autocast(enabled=True):
                out, loss = self.model(self.clip_encoding)
            loss = loss / self.gradient_accumulate_every
            # Accumulate a detached copy: the total is only reported, and adding
            # the attached tensors would keep every accumulation step's autograd
            # graph nodes alive for as long as total_loss is referenced.
            total_loss += loss.detach()
            self.scaler.scale(loss).backward()
        out = out.cpu().float().clamp(0., 1.)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if (iteration % self.save_every == 0) and self.save_progress:
            self.save_image(epoch, iteration, img=out)

        return out, total_loss

    def get_img_sequence_number(self, epoch, iteration):
        """Map an (epoch, iteration) pair to the index of the saved image.

        Iterations are counted globally across epochs and then bucketed into
        groups of ``self.save_every``.
        """
        global_step = epoch * self.iterations + iteration
        return global_step // self.save_every

    @torch.no_grad()
    def save_image(self, epoch, iteration, img=None):
        """Write the current image to its numbered path and to ``{textpath}.jpg``.

        When ``img`` is not given, it is rendered from ``self.model`` with
        ``return_loss=False`` and converted to a CPU float tensor clamped to
        ``[0, 1]``. ``self.filename`` is updated to the numbered path.
        """
        seq_no = self.get_img_sequence_number(epoch, iteration)

        if img is None:
            rendered = self.model(self.clip_encoding, return_loss=False)
            img = rendered.cpu().float().clamp(0., 1.)
        self.filename = self.image_output_path(sequence_number=seq_no)

        picture = T.ToPILImage()(img.squeeze())
        # Same JPEG settings for the numbered frame and the "latest" copy.
        for destination in (self.filename, f"{self.textpath}.jpg"):
            picture.save(destination, quality=95, subsampling=0)

        tqdm.write(f'image updated at "./{str(self.filename)}"')

    def generate_gif(self):
        """Assemble this run's numbered progress images into an animation.

        Frames are the files in the working directory named
        ``{self.textpath}.NNNNNN.jpg``, read with ``imread`` in sorted filename
        order. They are written with ``mimsave`` to ``{self.textpath}.mp4``
        when ``self.save_video`` is set and to ``{self.textpath}.gif`` when
        ``self.save_gif`` is set. If no frames exist, a message is printed and
        nothing is written.
        """
        prefix = f'{self.textpath}.'
        suffix = '.jpg'

        def is_progress_frame(name):
            # Only "{textpath}.<digits>.jpg": a bare startswith(textpath) would
            # also pick up the un-numbered latest image, earlier .gif/.mp4
            # outputs and files of other prompts sharing the same prefix.
            if not (name.startswith(prefix) and name.endswith(suffix)):
                return False
            return name[len(prefix):-len(suffix)].isdigit()

        images = [
            imread(os.path.join('./', file_name))
            for file_name in sorted(os.listdir('./'))
            if is_progress_frame(file_name)
        ]
        if not images:
            print(f'No progress images found for "{self.textpath}", skipping animation')
            return

        if self.save_video:
            mimsave(f'{self.textpath}.mp4', images)
            print(
                f'Generated image generation animation at ./{self.textpath}.mp4'
            )
        if self.save_gif:
            mimsave(f'{self.textpath}.gif', images)
            print(
                f'Generated image generation animation at ./{self.textpath}.gif'
            )

    def forward(self):
        """Run the whole generation: optional start-image pre-training, then training.

        1. If ``self.start_image`` exists, the inner SIREN (``self.model.model``)
           is first trained on it for ``self.start_image_train_iters`` steps
           with its own DiffGrad optimizer.
        2. One no-grad warm-up call of the model is made.
        3. The output folder is opened once if ``self.open_folder`` is set.
        4. ``self.epochs`` x ``self.iterations`` calls to ``train_step`` run.
           When ``self.create_story`` is set, the CLIP target is advanced via
           ``update_story_encoding`` after each epoch.
        5. A final image is saved and, if requested, the animation is generated.

        Ctrl-C during step 1 exits the process via ``exit()``. Ctrl-C during
        step 4 only returns, skipping the final save and the animation.
        """
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            # Separate optimizer used only for fitting the start image.
            optim = DiffGrad(self.model.model.parameters(),
                             lr=self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            try:
                for _ in pbar:
                    # Here the inner model call returns the loss directly.
                    loss = self.model.model(self.start_image)
                    loss.backward()
                    pbar.set_description(f'loss: {loss.item():.2f}')

                    optim.step()
                    optim.zero_grad()
            except KeyboardInterrupt:
                print('interrupted by keyboard, gracefully exiting')
                # Terminates the interpreter (unlike the main-loop handler below).
                return exit()

            # Free the start image and its optimizer once pre-training is done.
            # NOTE(review): after this `del`, a second forward() call would fail
            # on `self.start_image` unless the attribute is set again.
            del self.start_image
            del optim

        tqdm.write(
            f'Imagining "{self.textpath}" from the depths of my weights...')

        with torch.no_grad():
            self.model(
                self.clip_encoding, dry_run=True
            )  # do one warmup step due to potential issue with CLIP and CUDA

        # Open the output folder only on the first call.
        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        try:
            for epoch in trange(self.epochs, desc='epochs'):
                pbar = trange(self.iterations, desc='iteration')
                for i in pbar:
                    _, loss = self.train_step(epoch, i)
                    pbar.set_description(f'loss: {loss.item():.2f}')

                # Update clip_encoding per epoch if we are creating a story
                if self.create_story:
                    self.clip_encoding = self.update_story_encoding(epoch, i)
        except KeyboardInterrupt:
            print('interrupted by keyboard, gracefully exiting')
            return

        # NOTE: relies on the last values of the loop variables `epoch` and `i`;
        # with epochs == 0 or iterations == 0 they are unbound (NameError).
        self.save_image(epoch, i)  # one final save at end

        if (self.save_gif or self.save_video) and self.save_progress:
            self.generate_gif()
示例#12
0
    def __init__(
            self,
            *,
            text=None,  # text prompt
            img=None,  # image prompt
            lr=1e-5,  # learning rate
            batch_size=4,
            gradient_accumulate_every=4,  # forward/backward passes per optimizer step
            save_every=100,  # save a progress image every `save_every` iterations
            image_width=200,  # original note: at most 400, with num_layers at most 14
            num_layers=8,
            epochs=3,
            iterations=1050,
            save_progress=True,
            open_folder=True,
            theta_initial=None,  # original note: colour space of the SIREN's initial layer
            theta_hidden=None,  # original note: colour space of the SIREN's hidden layers
            model_name="ViT-B/32",  # CLIP model name (ViT-B/32 is the small ViT)
            lower_bound_cutout=0.1,  # should be smaller than 0.8
            upper_bound_cutout=1.0,
            averaging_weight=0.3,
            do_cutout=True,
            center_bias=False,
            center_focus=2,
            optimizer="AdamP",  # one of "AdamP", "Adam", "DiffGrad"
            jit=True,
            hidden_size=256,
            save_gif=True,
            save_video=True,
    ):
        """Load a frozen CLIP perceptor and build the ShallowDaze model and its optimizer.

        CLIP is loaded on CUDA when available, otherwise on the CPU. It is put in
        eval mode with all parameters frozen. The SIREN parameters
        (``model.model``) are optimised with the chosen optimizer under a
        ``GradScaler``. The CLIP target is built from ``text`` and/or ``img``.

        :raises ValueError: if ``optimizer`` is not "AdamP", "Adam" or "DiffGrad".
        """
        super().__init__()

        # Validate before the (slow) CLIP load. An unknown name previously left
        # self.optimizer unset, so the error only surfaced later at training time.
        supported_optimizers = ("AdamP", "Adam", "DiffGrad")
        if optimizer not in supported_optimizers:
            raise ValueError(
                f"Unsupported optimizer {optimizer!r}; expected one of {supported_optimizers}")

        self.epochs = epochs

        # jit models only compatible with version 1.7.1
        if "1.7.1" not in torch.__version__:
            if jit:
                print("Setting jit to False because torch version is not 1.7.1.")
            jit = False

        # Load CLIP on the GPU when one is available, otherwise on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `load` returns the CLIP module and a second value passed on to
        # ShallowDaze as `norm` (presumably CLIP's image normalisation — confirm).
        clip_perceptor, norm = load(model_name, jit=jit, device=self.device)
        # CLIP is only a fixed perceptor: eval mode (no dropout, frozen norm
        # statistics) and no gradients for its weights.
        self.perceptor = clip_perceptor.eval()
        for param in self.perceptor.parameters():
            param.requires_grad = False

        # The input resolution is exposed differently by the jit and non-jit models.
        if not jit:
            input_res = clip_perceptor.visual.input_resolution
        else:
            input_res = clip_perceptor.input_resolution.item()
        # Transform used to turn image prompts into CLIP inputs.
        self.clip_transform = create_clip_img_transform(input_res)
        self.iterations = iterations
        self.image_width = image_width
        # epochs * iterations * batch_size * gradient_accumulate_every, passed to ShallowDaze.
        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        # Build the ShallowDaze model around the frozen CLIP and move it to the device.
        model = ShallowDaze(
            self.perceptor,
            norm,
            input_res,
            total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            do_cutout=do_cutout,
            center_bias=center_bias,
            center_focus=center_focus,
            hidden_size=hidden_size,
            averaging_weight=averaging_weight,
        ).to(self.device)
        self.model = model
        # Loss scaling for mixed-precision training (avoids fp16 gradient underflow).
        self.scaler = GradScaler()
        # Only the inner SIREN network (`model.model`) is optimised.
        siren_params = model.model.parameters()
        if optimizer == "AdamP":
            self.optimizer = AdamP(siren_params, lr)
        elif optimizer == "Adam":
            self.optimizer = torch.optim.Adam(siren_params, lr)
        elif optimizer == "DiffGrad":
            self.optimizer = DiffGrad(siren_params, lr)

        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(self.perceptor.context_length, text=text, img=img)
        self.filename = self.image_output_path()

        # CLIP target encoding the image is optimised towards.
        self.clip_encoding = self.create_clip_encoding(text=text, img=img)

        self.save_gif = save_gif
        self.save_video = save_video
示例#13
0
class Imagine(nn.Module):
    """Optimise a ShallowDaze (SIREN) image so that CLIP matches it to a prompt.

    A frozen CLIP model is used as the perceptor. The target encoding comes from
    a text prompt, an image prompt, their average, or a precomputed encoding.
    Progress images are written to the working directory and can be assembled
    into a GIF / MP4 at the end.
    """

    def __init__(
            self,
            *,
            text=None,  # text prompt
            img=None,  # image prompt
            lr=1e-5,  # learning rate
            batch_size=4,
            gradient_accumulate_every=4,  # forward/backward passes per optimizer step
            save_every=100,  # save a progress image every `save_every` iterations
            image_width=200,  # original note: at most 400, with num_layers at most 14
            num_layers=8,
            epochs=3,
            iterations=1050,
            save_progress=True,
            open_folder=True,
            theta_initial=None,  # original note: colour space of the SIREN's initial layer
            theta_hidden=None,  # original note: colour space of the SIREN's hidden layers
            model_name="ViT-B/32",  # CLIP model name (ViT-B/32 is the small ViT)
            lower_bound_cutout=0.1,  # should be smaller than 0.8
            upper_bound_cutout=1.0,
            averaging_weight=0.3,
            do_cutout=True,
            center_bias=False,
            center_focus=2,
            optimizer="AdamP",  # one of "AdamP", "Adam", "DiffGrad"
            jit=True,
            hidden_size=256,
            save_gif=True,
            save_video=True,
    ):
        """Load the frozen CLIP perceptor and build the ShallowDaze model and optimizer.

        :raises ValueError: if ``optimizer`` is not "AdamP", "Adam" or "DiffGrad",
            or if neither ``text`` nor ``img`` is given.
        """
        super().__init__()

        # Validate before the (slow) CLIP load. An unknown name previously left
        # self.optimizer unset, so the error only surfaced inside train_step.
        supported_optimizers = ("AdamP", "Adam", "DiffGrad")
        if optimizer not in supported_optimizers:
            raise ValueError(
                f"Unsupported optimizer {optimizer!r}; expected one of {supported_optimizers}")

        self.epochs = epochs

        # jit models only compatible with version 1.7.1
        if "1.7.1" not in torch.__version__:
            if jit:
                print("Setting jit to False because torch version is not 1.7.1.")
            jit = False

        # Load CLIP on the GPU when one is available, otherwise on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `load` returns the CLIP module and a second value passed on to
        # ShallowDaze as `norm` (presumably CLIP's image normalisation — confirm).
        clip_perceptor, norm = load(model_name, jit=jit, device=self.device)
        # CLIP is only a fixed perceptor: eval mode (no dropout, frozen norm
        # statistics) and no gradients for its weights.
        self.perceptor = clip_perceptor.eval()
        for param in self.perceptor.parameters():
            param.requires_grad = False

        # The input resolution is exposed differently by the jit and non-jit models.
        if not jit:
            input_res = clip_perceptor.visual.input_resolution
        else:
            input_res = clip_perceptor.input_resolution.item()
        # Transform used to turn image prompts into CLIP inputs.
        self.clip_transform = create_clip_img_transform(input_res)
        self.iterations = iterations
        self.image_width = image_width
        # epochs * iterations * batch_size * gradient_accumulate_every, passed to ShallowDaze.
        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        # Build the ShallowDaze model around the frozen CLIP and move it to the device.
        model = ShallowDaze(
            self.perceptor,
            norm,
            input_res,
            total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            do_cutout=do_cutout,
            center_bias=center_bias,
            center_focus=center_focus,
            hidden_size=hidden_size,
            averaging_weight=averaging_weight,
        ).to(self.device)
        self.model = model
        # Loss scaling for mixed-precision training (avoids fp16 gradient underflow).
        self.scaler = GradScaler()
        # Only the inner SIREN network (`model.model`) is optimised.
        siren_params = model.model.parameters()
        if optimizer == "AdamP":
            self.optimizer = AdamP(siren_params, lr)
        elif optimizer == "Adam":
            self.optimizer = torch.optim.Adam(siren_params, lr)
        elif optimizer == "DiffGrad":
            self.optimizer = DiffGrad(siren_params, lr)

        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(self.perceptor.context_length, text=text, img=img)
        self.filename = self.image_output_path()

        # CLIP target encoding the image is optimised towards.
        self.clip_encoding = self.create_clip_encoding(text=text, img=img)

        self.save_gif = save_gif
        self.save_video = save_video

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        """Build the CLIP target encoding.

        A precomputed ``encoding`` takes precedence and is moved to
        ``self.device``. Otherwise the text encoding, the image encoding, or
        their average (when both prompts are given) is returned. ``text`` and
        ``img`` are recorded on ``self``.

        :raises ValueError: if none of ``text``, ``img`` or ``encoding`` is given
            (previously this surfaced as an UnboundLocalError).
        """
        self.text = text
        self.img = img
        if encoding is not None:
            return encoding.to(self.device)
        if text is not None and img is not None:
            return (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
        if text is not None:
            return self.create_text_encoding(text)
        if img is not None:
            return self.create_img_encoding(img)
        raise ValueError("create_clip_encoding needs `text`, `img` or `encoding`")

    def create_text_encoding(self, text):
        """Tokenize ``text`` and encode it with the frozen CLIP text encoder."""
        tokenized_text = tokenize(text).to(self.device)
        with torch.no_grad():
            text_encoding = self.perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        """Transform ``img`` for CLIP and encode it with the frozen image encoder."""
        normed_img = self.clip_transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            img_encoding = self.perceptor.encode_image(normed_img).detach()
        return img_encoding

    def set_clip_encoding(self, text=None, img=None, encoding=None):
        """Replace the optimisation target with an encoding built from the arguments."""
        encoding = self.create_clip_encoding(text=text, img=img, encoding=encoding)
        self.clip_encoding = encoding.to(self.device)

    def image_output_path(self, sequence_number=None):
        """
        Return the output Path ``{textpath}.jpg``.

        When ``sequence_number`` is given (including 0), it is left-padded to
        6 digits and inserted before the extension: ``{textpath}.000042.jpg``.
        :rtype: Path
        """
        output_path = self.textpath
        # Compare with None rather than truthiness: sequence number 0 (the first
        # save) must not collide with the un-numbered latest-image path.
        if sequence_number is not None:
            sequence_number_left_padded = str(sequence_number).zfill(6)
            output_path = f"{output_path}.{sequence_number_left_padded}"
        return Path(f"{output_path}.jpg")

    def train_step(self, epoch, iteration):
        """Run one optimizer step with gradient accumulation under autocast.

        Each of the ``gradient_accumulate_every`` losses is divided by that count
        before its scaled backward pass, so the accumulated gradient is an average.

        :return: ``(out, total_loss)`` — the last generated image as a CPU float
            tensor clamped to ``[0, 1]``, and the detached summed loss.
        """
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            # Mixed-precision region for the forward pass.
            with autocast(enabled=True):
                out, loss = self.model(self.clip_encoding)
            loss = loss / self.gradient_accumulate_every
            # Detached: the total is only reported and must not keep each step's
            # autograd graph alive.
            total_loss += loss.detach()

            self.scaler.scale(loss).backward()

        # .cpu() (not .cuda()): the model may run on the CPU fallback device, and
        # the image is only used for saving/reporting.
        out = out.cpu().float().clamp(0., 1.)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        # Clear the accumulated gradients before the next step.
        self.optimizer.zero_grad()

        if (iteration % self.save_every == 0) and self.save_progress:
            self.save_image(epoch, iteration, img=out)

        return out, total_loss

    def get_img_sequence_number(self, epoch, iteration):
        """Map (epoch, iteration) to the index of the saved image."""
        current_total_iterations = epoch * self.iterations + iteration
        sequence_number = current_total_iterations // self.save_every
        return sequence_number

    @torch.no_grad()
    def save_image(self, epoch, iteration, img=None):
        """Save ``img`` (or a fresh render) to its numbered path and to ``{textpath}.jpg``."""
        sequence_number = self.get_img_sequence_number(epoch, iteration)

        if img is None:
            img = self.model(self.clip_encoding, return_loss=False).cpu().float().clamp(0., 1.)
        self.filename = self.image_output_path(sequence_number=sequence_number)

        pil_img = T.ToPILImage()(img.squeeze())
        pil_img.save(self.filename, quality=95, subsampling=0)
        pil_img.save(f"{self.textpath}.jpg", quality=95, subsampling=0)

        tqdm.write(f'image updated at "./{str(self.filename)}"')

    def generate_gif(self):
        """Assemble the numbered ``{textpath}.NNNNNN.jpg`` frames into an MP4 and/or GIF."""
        prefix = f'{self.textpath}.'
        suffix = '.jpg'

        def is_progress_frame(name):
            # Only "{textpath}.<digits>.jpg": rejects the un-numbered latest
            # image, earlier .gif/.mp4 outputs and other prompts' files.
            if not (name.startswith(prefix) and name.endswith(suffix)):
                return False
            return name[len(prefix):-len(suffix)].isdigit()

        images = [
            imread(os.path.join('./', file_name))
            for file_name in sorted(os.listdir('./'))
            if is_progress_frame(file_name)
        ]
        if not images:
            print(f'No progress images found for "{self.textpath}", skipping animation')
            return

        if self.save_video:
            mimsave(f'{self.textpath}.mp4', images)
            print(f'Generated image generation animation at ./{self.textpath}.mp4')
        if self.save_gif:
            mimsave(f'{self.textpath}.gif', images)
            print(f'Generated image generation animation at ./{self.textpath}.gif')

    def forward(self):
        """Warm up, train for ``epochs`` x ``iterations`` steps, save, and animate.

        Ctrl-C during training returns early without the final save.
        """
        tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...')

        with torch.no_grad():
            self.model(self.clip_encoding, dry_run=True)  # do one warmup step due to potential issue with CLIP and CUDA
        # Open the output folder only on the first call.
        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        try:
            for epoch in trange(self.epochs, desc='epochs'):
                pbar = trange(self.iterations, desc='iteration')
                for i in pbar:
                    _, loss = self.train_step(epoch, i)
                    pbar.set_description(f'loss: {loss.item():.2f}')

        except KeyboardInterrupt:
            print('interrupted by keyboard, gracefully exiting')
            return

        # NOTE: relies on the last loop values of `epoch` and `i`.
        self.save_image(epoch, i)  # one final save at end

        if (self.save_gif or self.save_video) and self.save_progress:
            self.generate_gif()
def get_optimizer(optimizer_name: str,
                  parameters,
                  learning_rate: float,
                  weight_decay=1e-5,
                  eps=1e-5,
                  **kwargs) -> Optimizer:
    """Create an optimizer from its case-insensitive name.

    Recognised names:

    * torch.optim: ``sgd`` (Nesterov, momentum 0.9), ``adam``, ``rms``
      (RMSprop), ``adamw``;
    * torch-optimizer: ``radam``, ``ranger``, ``lamb``, ``diffgrad``,
      ``novograd``;
    * NVIDIA Apex, imported only when selected: ``fused_lamb``,
      ``fused_sgd`` (Nesterov, momentum 0.9), ``fused_adam`` (AdamW mode).

    ``weight_decay`` is passed to every optimizer. ``eps`` is passed to all
    except ``sgd``, ``rms`` and ``fused_sgd``. Extra ``kwargs`` are forwarded
    unchanged.

    :param optimizer_name: name of the optimizer (see above).
    :param parameters: parameters / parameter groups to optimize.
    :param learning_rate: learning rate, passed positionally.
    :return: the constructed optimizer.
    :raises ValueError: if the name is not recognised.
    """
    from torch.optim import SGD, Adam, RMSprop, AdamW
    from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger

    key = optimizer_name.lower()
    nesterov_momentum = {"momentum": 0.9, "nesterov": True}

    if key == "sgd":
        return SGD(parameters, learning_rate, weight_decay=weight_decay,
                   **nesterov_momentum, **kwargs)
    if key == "rms":
        return RMSprop(parameters, learning_rate, weight_decay=weight_decay,
                       **kwargs)

    # The remaining non-Apex optimizers all take the same (weight_decay, eps) pair.
    eps_optimizers = {
        "adam": Adam,  # As Jeremy suggests
        "adamw": AdamW,
        "radam": RAdam,  # As Jeremy suggests
        "ranger": Ranger,
        "lamb": Lamb,
        "diffgrad": DiffGrad,
        "novograd": NovoGrad,
    }
    optimizer_cls = eps_optimizers.get(key)
    if optimizer_cls is not None:
        return optimizer_cls(parameters, learning_rate,
                             weight_decay=weight_decay, eps=eps, **kwargs)

    # Apex fused optimizers (faster on GPUs with tensor cores), imported lazily
    # so Apex is only required when one of them is requested.
    if key == "fused_lamb":
        from apex.optimizers import FusedLAMB
        return FusedLAMB(parameters, learning_rate,
                         eps=eps, weight_decay=weight_decay, **kwargs)
    if key == "fused_sgd":
        from apex.optimizers import FusedSGD
        return FusedSGD(parameters, learning_rate, weight_decay=weight_decay,
                        **nesterov_momentum, **kwargs)
    if key == "fused_adam":
        from apex.optimizers import FusedAdam
        return FusedAdam(parameters, learning_rate,
                         eps=eps, weight_decay=weight_decay,
                         adam_w_mode=True, **kwargs)

    raise ValueError("Unsupported optimizer name " + optimizer_name)
示例#15
0
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
    ):
        """Build the StyleGAN2 networks, their EMA copies, optimizers and regulariser.

        Creates a style vectorizer ``S``, a generator ``G`` and a discriminator
        ``D``. It also creates copies ``SE``/``GE`` with gradients disabled,
        which are initialised through ``reset_parameter_averaging()``, and
        DiffGrad optimizers for G+S and for D. All modules are moved to CUDA;
        with ``fp16`` the modules and optimizers go through Apex AMP (O2).
        With ``cl_reg``, a ``ContrastiveLearner`` wraps ``D`` using
        ``augment_fn`` or a default augmentation pipeline.

        :param fq_layers: forwarded to ``Discriminator``; ``None`` means ``[]``.
        :param attn_layers: forwarded to ``Generator``/``Discriminator``;
            ``None`` means ``[]``.
        """
        super().__init__()
        # Fresh lists per instance: a `[]` default is created once and shared by
        # every call, so any mutation by a sub-module would leak between models.
        fq_layers = [] if fq_layers is None else fq_layers
        attn_layers = [] if attn_layers is None else attn_layers

        self.lr = lr
        self.steps = steps
        # EMA helper constructed with 0.995 (presumably the averaging decay for SE/GE).
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        # Averaged copies; never trained directly.
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        # The style vectorizer is trained together with the generator.
        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)
    def __init__(self,
                 image_size,
                 latent_dim=512,
                 noise_dim=100,
                 style_depth=8,
                 network_capacity=16,
                 transparent=False,
                 steps=1,
                 lr=2e-4):
        """Build the GAN networks, their averaged copies, an extractor E and optimizers.

        Creates a style vectorizer ``S``, a noise vectorizer ``N``, a generator
        ``G``, a discriminator ``D`` and an extractor ``E``. It also creates
        gradient-free copies ``SE``/``NE``/``GE``. Optimizers (all DiffGrad,
        betas (0.5, 0.9), learning rate ``lr``):

        * ``G_opt`` — parameters of ``G``;
        * ``D_opt`` — parameters of ``D``;
        * ``E_opt`` — parameters of ``E``, ``N`` and ``G``, with a ``StepLR``
          scheduler that multiplies its learning rate by 0.1 every 200
          scheduler steps;
        * ``N_opt`` — parameters of ``N``.

        NOTE(review): ``G`` parameters are held by both ``G_opt`` and ``E_opt``,
        and ``N`` parameters by both ``E_opt`` and ``N_opt``. ``S`` parameters
        are not given to any optimizer created here — confirm this is intended.
        """
        super().__init__()
        self.lr = lr
        self.steps = steps
        # EMA helper constructed with 0.995 (presumably the averaging decay).
        self.ema_updater = EMA(0.995)

        # NOTE: construction order determines random initialisation; keep it.
        self.S = StyleVectorizer(latent_dim, style_depth)
        self.N = NoiseVectorizer(noise_dim)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent)
        self.D = Discriminator(image_size,
                               network_capacity,
                               transparent=transparent)
        ###########################################
        # Extractor network (constructed with argument 64).
        self.E = ExtractNetSimilarE(64)

        # Averaged copies; gradients are disabled below.
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.NE = NoiseVectorizer(noise_dim)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.NE, False)
        set_requires_grad(self.GE, False)

        # Only G's parameters here (S is not included).
        generator_params = list(self.G.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))
        ###############################################
        # Alternative (disabled) E_opt parameter set:
        # E_params = list(self.E.parameters())+list(self.G.downsample.parameters())
        E_params = list(self.E.parameters()) + list(
            self.N.parameters()) + list(self.G.parameters())

        # Further disabled alternatives for the E_opt parameter groups:
        # E_params = list(self.E.to_logit.parameters())
        # base_param_ids = set(map(id, self.E.to_logit.parameters()))
        # new_params = [p for p in self.E.parameters() if id(p) not in base_param_ids]
        # E_param_groups = [{'params': self.E.parameters(), 'lr': self.lr},
        #                 #   {'params': self.G.parameters(), 'lr': self.lr},
        #                 #   {'params': new_params, 'lr': self.lr},  # other E layers
        #                   {'params': self.N.parameters(), 'lr': self.lr}
        #                   ]

        self.E_opt = DiffGrad(E_params, lr=self.lr, betas=(0.5, 0.9))

        # Multiply E_opt's learning rate by 0.1 every 200 scheduler.step() calls.
        self.E_opt_scheduler = torch.optim.lr_scheduler.StepLR(self.E_opt,
                                                               step_size=200,
                                                               gamma=0.1)

        N_params = list(self.N.parameters())
        self.N_opt = DiffGrad(N_params, lr=self.lr, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()
示例#17
0
def _split_params_by_bias(model: nn.Module) -> tuple:
    """Split the trainable parameters of *model* into ``(non_bias, bias)`` lists.

    A parameter counts as a bias when the last component of its dotted name is
    exactly ``bias``. This covers nested modules (``"layer.0.bias"``) as well as
    a bias registered directly on *model* itself (plain ``"bias"``), which a
    ``.endswith(".bias")`` test would miss. Parameters with
    ``requires_grad=False`` are skipped; order follows ``model.named_parameters()``.
    """
    default_pg, biases_pg = [], []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param_name.rsplit(".", 1)[-1] == "bias":
            biases_pg.append(param)
        else:
            default_pg.append(param)
    return default_pg, biases_pg


def _get_optimizer_class_and_defaults(optimizer_name: str, eps: float) -> tuple:
    """Return ``(optimizer_class, default_kwargs)`` for a case-insensitive name.

    ``default_kwargs`` holds the extra constructor arguments this module applies
    on top of ``lr``/``weight_decay`` (e.g. Nesterov momentum for SGD, ``eps``
    for Adam-like optimizers). Third-party packages (torch_optimizer, apex) are
    imported only in the branch that needs them, so torch-native optimizers work
    without those packages installed.

    Raises:
        KeyError: if *optimizer_name* is not a recognised optimizer.
        ImportError: if the package backing the chosen optimizer is missing.
    """
    name = optimizer_name.lower()

    if name in ("sgd", "asgd", "adam", "rms", "adamw"):
        from torch.optim import ASGD, SGD, Adam, AdamW, RMSprop

        return {
            "sgd": (SGD, {"momentum": 0.9, "nesterov": True}),
            "asgd": (ASGD, {}),
            "adam": (Adam, {"eps": eps}),
            # RMSprop keeps its own eps default, as before.
            "rms": (RMSprop, {}),
            "adamw": (AdamW, {"eps": eps}),
        }[name]

    if name in ("radam", "ranger", "lamb", "diffgrad", "novograd"):
        from torch_optimizer import DiffGrad, Lamb, NovoGrad, RAdam, Ranger

        classes = {
            "radam": RAdam,
            "ranger": Ranger,
            "lamb": Lamb,
            "diffgrad": DiffGrad,
            "novograd": NovoGrad,
        }
        return classes[name], {"eps": eps}

    if name == "fused_lamb":
        from apex.optimizers import FusedLAMB

        return FusedLAMB, {"eps": eps}
    if name == "fused_sgd":
        from apex.optimizers import FusedSGD

        return FusedSGD, {"momentum": 0.9, "nesterov": True}
    if name == "fused_adam":
        from apex.optimizers import FusedAdam

        return FusedAdam, {"eps": eps, "adam_w_mode": True}

    raise KeyError(f"Cannot get optimizer by name {optimizer_name}")


def get_optimizer(
    model: nn.Module,
    optimizer_name: str,
    learning_rate: float,
    weight_decay: float = 1e-5,
    no_weight_decay_on_bias: bool = False,
    eps: float = 1e-5,
    **kwargs,
) -> Optimizer:
    """
    Construct an Optimizer for given model
    Args:
        model: Model to optimize. Only parameters that require_grad will be used
        optimizer_name: Name of the optimizer. Case-insensitive. One of
            sgd, asgd, adam, rms, adamw, radam, ranger, lamb, diffgrad,
            novograd, fused_lamb, fused_sgd, fused_adam.
        learning_rate: Target learning rate (regardless of the scheduler)
        weight_decay: Target weight decay
        no_weight_decay_on_bias: Whether to disable weight decay on bias parameters.
            When True the optimizer has two param groups: non-bias params
            (with ``weight_decay``) first, then biases (with weight decay 0).
        eps: Epsilon for Adam-like optimizers (not passed to sgd, asgd, rms,
            fused_sgd).
        **kwargs: Additional parameters for the optimizer constructor. They
            override this function's per-optimizer defaults (e.g. ``momentum``
            or ``nesterov`` for SGD).

    Returns:
        The constructed optimizer instance.

    Raises:
        KeyError: if ``optimizer_name`` is not recognised.
        ValueError: if the model has no trainable parameters.
        ImportError: if torch_optimizer / apex is required but not installed.
    """
    optimizer_cls, defaults = _get_optimizer_class_and_defaults(optimizer_name, eps)

    default_pg, biases_pg = _split_params_by_bias(model)

    if no_weight_decay_on_bias:
        if not default_pg and not biases_pg:
            # Same failure torch raises for an empty flat parameter list.
            raise ValueError("optimizer got an empty parameter list")
        # Groups are passed to the constructor (instead of add_param_group
        # afterwards) so a model whose only trainable params are biases works.
        parameters = [
            {"params": default_pg},
            {"params": biases_pg, "weight_decay": 0},
        ]
    else:
        parameters = default_pg + biases_pg

    # Caller kwargs win over defaults instead of raising
    # "got multiple values for keyword argument".
    options = {**defaults, **kwargs}
    return optimizer_cls(
        parameters,
        lr=learning_rate,
        weight_decay=weight_decay,
        **options,
    )