class Imagine(nn.Module):
    def __init__(
            self,
            *,
            text=None,
            img=None,
            clip_encoding=None,
            lr=1e-5,
            batch_size=4,
            gradient_accumulate_every=4,
            save_every=100,
            image_width=512,
            num_layers=16,
            epochs=20,
            iterations=1050,
            save_progress=False,
            seed=None,
            open_folder=True,
            save_date_time=False,
            start_image_path=None,
            start_image_train_iters=10,
            start_image_lr=3e-4,
            theta_initial=None,
            theta_hidden=None,
    ):
        super().__init__()

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        self.epochs = epochs
        self.iterations = iterations
        self.image_width = image_width

        total_batches = epochs * iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            total_batches=total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden
        ).cuda()
        self.model = model

        self.scaler = GradScaler()
        self.optimizer = AdamP(model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(text=text, img=img, encoding=clip_encoding)
        self.filename = self.image_output_path()

        # create encoding to optimize for
        self.clip_img_transform = create_clip_img_transform(perceptor.input_resolution.item())
        self.clip_encoding = self.create_clip_encoding(text=text, img=img, encoding=clip_encoding)

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            assert file.exists(), f'file does not exist at given starting image path {start_image_path}'
            image = Image.open(str(file))
            transform = T.Compose([
                T.Resize(image_width),
                T.CenterCrop((image_width, image_width)),
                T.ToTensor(),
                T.Normalize(0.5, 0.5)
            ])
            image_tensor = transform(image)[None, ...].cuda()
            self.start_image = image_tensor

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        self.text = text
        self.img = img
        if encoding is not None:
            return encoding.cuda()
        elif text is not None:
            return self.create_text_encoding(text)
        elif img is not None:
            return self.create_img_encoding(img)

    @staticmethod
    def create_text_encoding(text):
        tokenized_text = tokenize(text).cuda()
        text_encoding = perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        if isinstance(img, str):
            img = Image.open(img)
        normed_img = self.clip_img_transform(img).unsqueeze(0).cuda()
        img_encoding = perceptor.encode_image(normed_img).detach()
        return img_encoding

    def set_clip_encoding(self, text=None, img=None, encoding=None):
        encoding = self.create_clip_encoding(text=text, img=img, encoding=encoding).cuda()
        self.clip_encoding = encoding

    def image_output_path(self, sequence_number=None):
        """
        Returns underscore-separated Path.
        A current timestamp is prepended if `self.save_date_time` is set.
        Sequence number left-padded with 6 zeroes is appended if `save_every` is set.
        :rtype: Path
        """
        output_path = self.textpath
        if sequence_number:
            sequence_number_left_padded = str(sequence_number).zfill(6)
            output_path = f"{output_path}.{sequence_number_left_padded}"
        if self.save_date_time:
            current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f")
            output_path = f"{current_time}_{output_path}"
        return Path(f"{output_path}.png")

    def train_step(self, epoch, iteration):
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            with autocast():
                loss = self.model(self.clip_encoding)
            loss = loss / self.gradient_accumulate_every
            total_loss += loss
            self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if (iteration % self.save_every == 0) and self.save_progress:
            self.save_image(epoch, iteration)

        return total_loss

    @torch.no_grad()
    def save_image(self, epoch, iteration):
        current_total_iterations = epoch * self.iterations + iteration
        sequence_number = current_total_iterations // self.save_every

        img = normalize_image(self.model(self.clip_encoding, return_loss=False).cpu())
        img.clamp_(0., 1.)

        self.filename = self.image_output_path(sequence_number=sequence_number)
        save_image(img, self.filename)
        save_image(img, f"{self.textpath}.png")

        tqdm.write(f'image updated at "./{str(self.filename)}"')

    def forward(self):
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            optim = DiffGrad(self.model.parameters(), lr=self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            for _ in pbar:
                loss = self.model.model(self.start_image)
                loss.backward()
                pbar.set_description(f'loss: {loss.item():.2f}')

                optim.step()
                optim.zero_grad()

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return sys.exit()

            del self.start_image
            del optim

        tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...')

        self.model(self.clip_encoding, dry_run=True)  # do one warmup step due to potential issue with CLIP and CUDA

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        for epoch in trange(self.epochs, desc='epochs'):
            pbar = trange(self.iterations, desc='iteration')
            for i in pbar:
                loss = self.train_step(epoch, i)
                pbar.set_description(f'loss: {loss.item():.2f}')

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return

        self.save_image(self.epochs, self.iterations)  # one final save at end
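
# Usage sketch (illustrative, not part of the original source). The class is an
# nn.Module whose forward() runs the entire optimization loop, so training is
# construct-and-call. It assumes the surrounding module's imports and globals
# (DeepDaze, perceptor, the helper functions above) and an available CUDA
# device; the prompt string is hypothetical.
imagine = Imagine(
    text='a pineapple floating through a nebula',
    num_layers=16,
    save_progress=True,  # write an intermediate frame every `save_every` iterations
)
imagine()
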
class Imagine(nn.Module):
    def __init__(
            self,
            *,
            text=None,
            img=None,
            clip_encoding=None,
            lr=1e-5,
            batch_size=4,
            gradient_accumulate_every=4,
            save_every=100,
            image_width=512,
            num_layers=16,
            epochs=20,
            iterations=1050,
            save_progress=True,
            seed=None,
            open_folder=True,
            save_date_time=False,
            start_image_path=None,
            start_image_train_iters=10,
            start_image_lr=3e-4,
            theta_initial=None,
            theta_hidden=None,
            lower_bound_cutout=0.1,  # should be smaller than 0.8
            upper_bound_cutout=1.0,
            saturate_bound=False,
            create_story=False,
            story_start_words=5,
            story_words_per_epoch=5,
            save_gif=False
    ):
        super().__init__()

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        # fields for story creation:
        self.create_story = create_story
        self.words = None
        self.all_words = text.split(" ") if text is not None else None
        self.num_start_words = story_start_words
        self.words_per_epoch = story_words_per_epoch
        if create_story:
            assert text is not None, "We need text input to create a story..."
            # overwrite epochs to match story length
            num_words = len(self.all_words)
            self.epochs = 1 + (num_words - self.num_start_words) / self.words_per_epoch
            # round up: add one epoch if the word count is not evenly divisible
            self.epochs = int(self.epochs) if int(self.epochs) == self.epochs else int(self.epochs) + 1
            print("Running for", self.epochs, "epochs")
        else:
            self.epochs = epochs

        self.iterations = iterations
        self.image_width = image_width

        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            total_batches=total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            saturate_bound=saturate_bound,
        ).cuda()
        self.model = model

        self.scaler = GradScaler()
        self.optimizer = AdamP(model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(text=text, img=img, encoding=clip_encoding)
        self.filename = self.image_output_path()

        # create encoding to optimize for
        self.clip_img_transform = create_clip_img_transform(224)
        self.clip_encoding = self.create_clip_encoding(text=text, img=img, encoding=clip_encoding)

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            assert file.exists(), f'file does not exist at given starting image path {start_image_path}'
            image = Image.open(str(file))
            image_tensor = self.clip_img_transform(image)[None, ...].cuda()
            self.start_image = image_tensor

        self.save_gif = save_gif

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        self.text = text
        self.img = img
        if encoding is not None:
            encoding = encoding.cuda()
        elif self.create_story:
            encoding = self.update_story_encoding(epoch=0, iteration=1)
        elif text is not None and img is not None:
            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
        elif text is not None:
            encoding = self.create_text_encoding(text)
        elif img is not None:
            encoding = self.create_img_encoding(img)
        return encoding

    def create_text_encoding(self, text):
        tokenized_text = tokenize(text).cuda()
        with torch.no_grad():
            text_encoding = perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        if isinstance(img, str):
            img = Image.open(img)
        normed_img = self.clip_img_transform(img).unsqueeze(0).cuda()
        with torch.no_grad():
            img_encoding = perceptor.encode_image(normed_img).detach()
        return img_encoding

    def set_clip_encoding(self, text=None, img=None, encoding=None):
        encoding = self.create_clip_encoding(text=text, img=img, encoding=encoding)
        self.clip_encoding = encoding.cuda()

    def update_story_encoding(self, epoch, iteration):
        if self.words is None:
            self.words = " ".join(self.all_words[:self.num_start_words])
            self.all_words = self.all_words[self.num_start_words:]
        else:
            # add words_per_epoch new words
            count = 0
            while count < self.words_per_epoch and len(self.all_words) > 0:
                new_word = self.all_words[0]
                self.words = " ".join(self.words.split(" ") + [new_word])
                self.all_words = self.all_words[1:]
                count += 1
                # TODO: possibly do not increase count for stop-words and break if a "." is encountered.

            # remove words until it fits in context length
            # (note: compares character count against CLIP's token context length; a rough safeguard, not true token counting)
            while len(self.words) > perceptor.context_length:
                # remove first word
                self.words = " ".join(self.words.split(" ")[1:])

        # get new encoding
        print(f'Now thinking of: "{self.words}"')
        sequence_number = self.get_img_sequence_number(epoch, iteration)
        # save new words to disk
        with open("story_transitions.txt", "a") as f:
            f.write(f"{epoch}, {sequence_number}, {self.words}\n")

        encoding = self.create_text_encoding(self.words)
        return encoding

    def image_output_path(self, sequence_number=None):
        """
        Returns underscore-separated Path.
        A current timestamp is prepended if `self.save_date_time` is set.
        Sequence number left-padded with 6 zeroes is appended if `save_every` is set.
        :rtype: Path
        """
        output_path = self.textpath
        if sequence_number:
            sequence_number_left_padded = str(sequence_number).zfill(6)
            output_path = f"{output_path}.{sequence_number_left_padded}"
        if self.save_date_time:
            current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f")
            output_path = f"{current_time}_{output_path}"
        return Path(f"{output_path}.jpg")

    def train_step(self, epoch, iteration):
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            with autocast():
                out, loss = self.model(self.clip_encoding)
            loss = loss / self.gradient_accumulate_every
            total_loss += loss
            self.scaler.scale(loss).backward()
        out = out.cpu().float().clamp(0., 1.)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if (iteration % self.save_every == 0) and self.save_progress:
            self.save_image(epoch, iteration, img=out)

        return out, total_loss

    def get_img_sequence_number(self, epoch, iteration):
        current_total_iterations = epoch * self.iterations + iteration
        sequence_number = current_total_iterations // self.save_every
        return sequence_number

    @torch.no_grad()
    def save_image(self, epoch, iteration, img=None):
        sequence_number = self.get_img_sequence_number(epoch, iteration)

        if img is None:
            img = self.model(self.clip_encoding, return_loss=False).cpu().float().clamp(0., 1.)

        self.filename = self.image_output_path(sequence_number=sequence_number)

        pil_img = T.ToPILImage()(img.squeeze())
        pil_img.save(self.filename, quality=95, subsampling=0)
        pil_img.save(f"{self.textpath}.jpg", quality=95, subsampling=0)

        tqdm.write(f'image updated at "./{str(self.filename)}"')

    def generate_gif(self):
        images = []
        for file_name in sorted(os.listdir('./')):
            if file_name.startswith(self.textpath) and file_name != f'{self.textpath}.jpg':
                images.append(imread(os.path.join('./', file_name)))
        mimsave(f'{self.textpath}.gif', images)
        print(f'Generated image generation animation at ./{self.textpath}.gif')

    def forward(self):
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            optim = DiffGrad(self.model.parameters(), lr=self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            for _ in pbar:
                loss = self.model.model(self.start_image)
                loss.backward()
                pbar.set_description(f'loss: {loss.item():.2f}')

                optim.step()
                optim.zero_grad()

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return exit()

            del self.start_image
            del optim

        tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...')

        with torch.no_grad():
            self.model(self.clip_encoding, dry_run=True)  # do one warmup step due to potential issue with CLIP and CUDA

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        for epoch in trange(self.epochs, desc='epochs'):
            pbar = trange(self.iterations, desc='iteration')
            for i in pbar:
                _, loss = self.train_step(epoch, i)
                pbar.set_description(f'loss: {loss.item():.2f}')

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return

            # update clip_encoding per epoch if we are creating a story
            if self.create_story:
                self.clip_encoding = self.update_story_encoding(epoch, i)

        self.save_image(epoch, i)  # one final save at end

        if self.save_gif and self.save_progress:
            self.generate_gif()
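
# Usage sketch (illustrative, not part of the original source). This revision
# accepts text, an image path, or a raw CLIP encoding (averaging the text and
# image encodings when both are given), and `create_story=True` re-encodes a
# sliding window of the prompt's words each epoch. Prompt and settings below
# are hypothetical.
imagine = Imagine(
    text='the quick brown fox jumps over the lazy dog',
    create_story=True,        # epochs are derived from the word count, not the `epochs` argument
    story_start_words=5,
    story_words_per_epoch=2,
    save_progress=True,
    save_gif=True,            # stitch the saved frames into <textpath>.gif at the end
)
imagine()
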
class Imagine(nn.Module):
    def __init__(
            self,
            text,
            *,
            lr=1e-5,
            batch_size=4,
            gradient_accumulate_every=4,
            save_every=100,
            image_width=512,
            num_layers=16,
            epochs=20,
            iterations=1050,
            save_progress=False,
            seed=None,
            open_folder=True,
            save_date_time=False,
            start_image_path=None,
            start_image_train_iters=10,
            start_image_lr=3e-4,
            theta_initial=None,
            theta_hidden=None,
    ):
        super().__init__()

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        self.epochs = epochs
        self.iterations = iterations

        total_batches = epochs * iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            total_batches=total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden
        ).cuda()
        self.model = model

        self.scaler = GradScaler()
        self.optimizer = AdamP(model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.textpath = text.replace(" ", "_")
        self.filename = self.image_output_path()
        self.encoded_text = tokenize(text).cuda()

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            assert file.exists(), f'file does not exist at given starting image path {start_image_path}'
            image = Image.open(str(file))
            transform = T.Compose([
                T.Resize(image_width),
                T.CenterCrop((image_width, image_width)),
                T.ToTensor(),
                T.Normalize(0.5, 0.5)
            ])
            image_tensor = transform(image)[None, ...].cuda()
            self.start_image = image_tensor

    def image_output_path(self, sequence_number=None):
        """
        Returns underscore-separated Path.
        A current timestamp is prepended if `self.save_date_time` is set.
        Sequence number left-padded with 6 zeroes is appended if `save_every` is set.
        :rtype: Path
        """
        output_path = self.textpath
        if sequence_number:
            sequence_number_left_padded = str(sequence_number).zfill(6)
            output_path = f"{output_path}.{sequence_number_left_padded}"
        if self.save_date_time:
            current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f")
            output_path = f"{current_time}_{output_path}"
        return Path(f"{output_path}.png")

    def generate_and_save_image(self, sequence_number=None):
        """
        :param sequence_number: Optional sequence number appended to the saved filename.
        """
        with torch.no_grad():
            img = normalize_image(self.model(self.encoded_text, return_loss=False).cpu())
            img.clamp_(0., 1.)
            self.filename = self.image_output_path(sequence_number=sequence_number)
            save_image(img, self.filename)
            save_image(img, f"{self.textpath}.png")
            tqdm.write(f'image updated at "./{str(self.filename)}"')

    def train_step(self, epoch, iteration) -> torch.Tensor:
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            with autocast():
                loss = self.model(self.encoded_text)
            loss = loss / self.gradient_accumulate_every
            total_loss += loss
            self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if (iteration % self.save_every == 0) and self.save_progress:
            current_total_iterations = epoch * self.iterations + iteration
            sequence_number = current_total_iterations // self.save_every
            self.generate_and_save_image(sequence_number=sequence_number)

        return total_loss

    def forward(self):
        if exists(self.start_image):
            tqdm.write('Preparing with initial image...')
            optim = DiffGrad(self.model.parameters(), lr=self.start_image_lr)
            pbar = trange(self.start_image_train_iters, desc='iteration')
            for _ in pbar:
                loss = self.model.model(self.start_image)
                loss.backward()
                pbar.set_description(f'loss: {loss.item():.2f}')

                optim.step()
                optim.zero_grad()

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return exit()

            del self.start_image
            del optim

        tqdm.write(f'Imagining "{self.text}" from the depths of my weights...')

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        for epoch in trange(self.epochs, desc='epochs'):
            pbar = trange(self.iterations, desc='iteration')
            for i in pbar:
                loss = self.train_step(epoch, i)
                pbar.set_description(f'loss: {loss.item():.2f}')

                if terminate:
                    print('interrupted by keyboard, gracefully exiting')
                    return
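
# Usage sketch (illustrative, not part of the original source). This variant
# takes the prompt as a required positional argument and tokenizes it once up
# front; there is no image or raw-encoding input. The prompt is hypothetical.
imagine = Imagine(
    'a lighthouse on the edge of a dream',
    epochs=20,
    iterations=1050,
)
imagine()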