def generate_story(self, netG, dataloader):
    """Generate a story for every sample in *dataloader* and save all frames.

    For each story, the ground-truth frames are written to
    ``<self.save_dir>/original/<story_id>/<frame_idx>.png`` and the generated
    frames to ``<self.save_dir>/generate/<story_id>/<frame_idx>.png``.
    ``story_id`` counts stories across all batches, starting at 0.

    Args:
        netG: Generator exposing ``sample_videos(motion_input, content_input)``.
            The second element of its 7-item return value is taken as the
            batch of generated stories. Its train/eval mode is left exactly
            as the caller set it.
        dataloader: Iterable of dict batches providing the keys ``'images'``,
            ``'description'`` and ``'labels'``.
    """
    from miscc.utils import images_to_numpy
    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
    # ``PIL.Image`` only resolves if some other module imported it first.
    from PIL import Image

    origin_img_path = os.path.join(self.save_dir, 'original')
    generated_img_path = os.path.join(self.save_dir, 'generate')
    os.makedirs(origin_img_path, exist_ok=True)
    os.makedirs(generated_img_path, exist_ok=True)

    print('Generating Test Samples...')
    story_id = 0
    for batch in tqdm(dataloader):
        real_cpu = batch['images']
        # Keep only the first cfg.TEXT.DIMENSION features of each description.
        description = batch['description'][:, :, :cfg.TEXT.DIMENSION]
        catelabel = batch['labels']
        if cfg.CUDA:
            description = description.cuda()
            catelabel = catelabel.cuda()

        # The motion input is the description with the labels appended along
        # dim 2; the content input is the description on its own.
        motion_input = torch.cat((description, catelabel), 2)
        content_input = description

        # Pure inference: no autograd graph is needed for sampling.
        with torch.no_grad():
            _, fake_stories, _, _, _, _, _ = netG.sample_videos(
                motion_input, content_input)

        # Swap dims 1 and 2 so that iterating a story yields one frame at a
        # time (presumably (B, C, T, H, W) -> (B, T, C, H, W) -- TODO confirm).
        real_stories = real_cpu.transpose(1, 2)
        fake_stories = fake_stories.transpose(1, 2)

        for fake_story, real_story in zip(fake_stories, real_stories):
            origin_story_path = os.path.join(origin_img_path, str(story_id))
            generated_story_path = os.path.join(
                generated_img_path, str(story_id))
            os.makedirs(origin_story_path, exist_ok=True)
            os.makedirs(generated_story_path, exist_ok=True)

            for idx, (fake, real) in enumerate(zip(fake_story, real_story)):
                frame_name = str(idx) + '.png'
                Image.fromarray(images_to_numpy(fake)).save(
                    os.path.join(generated_story_path, frame_name))
                Image.fromarray(images_to_numpy(real)).save(
                    os.path.join(origin_story_path, frame_name))

            story_id += 1
def _save_story_images(self, lr_fake, st_fake, num, output_dir, epoch=22):
    """Save every frame of every generated story in a batch as a PNG file.

    Files are named
    ``<output_dir>/test_epoch_<epoch>_batch_<num>_story_<i>_sentence_<j>.png``
    where ``i`` indexes the story within the batch and ``j`` the frame
    (sentence) within the story.

    Args:
        lr_fake: Batch of generated stories; used whenever it is not ``None``.
        st_fake: Fallback batch, used only when ``lr_fake`` is ``None``.
        num: Batch number embedded in the file name.
        output_dir: Directory the images are written to (not created here).
        epoch: Epoch number embedded in the file name. Defaults to 22, the
            value that used to be hard-coded, so existing callers produce
            the same file names as before.
    """
    fake_imgs = lr_fake if lr_fake is not None else st_fake  # e.g. 24, 3, 5, 64, 64
    for story_idx, story in enumerate(fake_imgs):
        # Swap the first two dims so that iterating yields one frame per
        # sentence (e.g. 3, 5, 64, 64 -> 5, 3, 64, 64).
        story_imgs = story.squeeze(0).transpose(0, 1)
        for sentence_idx, sentence_img in enumerate(story_imgs):
            image = PIL.Image.fromarray(images_to_numpy(sentence_img))
            image.save(
                '%s/test_epoch_%03d_batch_%d_story_%d_sentence_%d.png' % (
                    output_dir, epoch, num, story_idx, sentence_idx))
def evaluate(self, weight_path, testloader, output_path):
    """Generate stories for *testloader* with saved weights and write JPEGs.

    The generator object is read from ``<self.model_dir>/barebone.pth``
    (under the ``'netG'`` key), its parameters are replaced by the state
    dict stored at *weight_path*, and every generated frame is saved as
    ``<output_path>/<story_id>_<frame_idx>.jpg``. ``story_id`` counts
    stories across all batches, starting at 0.

    Args:
        weight_path: Path to a state dict compatible with the generator.
        testloader: Iterable of dict batches providing the key
            ``'description'``; also stored on ``self.testloader``.
        output_path: Directory for the generated images; created if missing.
    """
    self.testloader = testloader

    # When CUDA is disabled, remap tensors saved from a GPU onto the CPU so
    # the checkpoints can still be loaded.
    map_location = None if cfg.CUDA else 'cpu'
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    model_structure = torch.load(
        os.path.join(self.model_dir, 'barebone.pth'),
        map_location=map_location)
    netG = model_structure['netG']
    netG.load_state_dict(torch.load(weight_path, map_location=map_location))

    # Move the model to the GPU under the same flag that moves the inputs,
    # so model and inputs always live on the same device.
    if cfg.CUDA:
        netG = netG.cuda()
    netG.eval()

    os.makedirs(output_path, exist_ok=True)

    st_id = 0
    with torch.no_grad():
        for st_batch in tqdm(self.testloader):
            # The same description tensor serves as motion and content input.
            st_description = st_batch['description']
            if cfg.CUDA:
                st_description = st_description.cuda()

            _, st_fake, _, _, _, _ = netG.sample_videos(
                st_description, st_description)

            for story_imgs in st_fake:
                # convert C x T x W x H -> T x C x W x H
                story_imgs = story_imgs.transpose(1, 0)
                for idx, img in enumerate(story_imgs):
                    output = Image.fromarray(images_to_numpy(img))
                    output.save(os.path.join(
                        output_path, str(st_id) + '_' + str(idx) + '.jpg'))
                st_id += 1
st.info("Encoding your story...") tokenized_descriptions = torch.cat( [clip.tokenize(s) for s in sentences]).to(device) with torch.no_grad(): encoded_descriptions = clip_model.encode_text( tokenized_descriptions).float() # Run the encoded story text through StoryGAN fake_imgs = single_inference(gan_args, encoded_descriptions) # Show the credits for each photo in an expandable sidebar st.markdown(f"## Your Storybook: \n") col_generator = st.beta_columns(video_len) for i, col in enumerate(col_generator): images = fake_imgs.squeeze(0).transpose(0, 1)[i].squeeze(0) images = images_to_numpy(images) image = PIL.Image.fromarray(images) with col: st.image(image, use_column_width='always') st.text(sentences[i]) # Format st.sidebar.title("Storybook Illustrator") st.sidebar.markdown("-----------------------------------") st.sidebar.markdown( f"[{app_formal_name}](https://github.com/eunjeeSung/StoryGAN) " f"creates storybook illustrations from stories." f" The model was trained on a GANILLA-fied subset of the VIST dataset from Microsoft." ) st.sidebar.markdown(