コード例 #1
0
    def generate_story(self, netG, dataloader):
        """Sample stories from ``netG`` and save real and generated frames as PNGs.

        Stories are numbered sequentially from 0 across all batches. For each
        story, the ground-truth frames go to
        ``<save_dir>/original/<story_id>/<frame_idx>.png`` and the generated
        frames go to ``<save_dir>/generate/<story_id>/<frame_idx>.png``.

        Args:
            netG: generator exposing ``sample_videos(motion_input, content_input)``;
                the second of its seven return values is taken as the stories.
            dataloader: iterable of batch dicts with ``'images'``,
                ``'description'`` and ``'labels'`` entries.
        """
        from miscc.utils import images_to_numpy
        # ``import PIL`` does not load the ``PIL.Image`` submodule by itself,
        # so ``PIL.Image.fromarray`` only worked when some other module had
        # already imported it. Import the submodule explicitly.
        from PIL import Image

        def save_frame(frame, path):
            # images_to_numpy (miscc.utils) turns one frame tensor into an
            # array that Image.fromarray accepts.
            Image.fromarray(images_to_numpy(frame)).save(path)

        origin_img_path = os.path.join(self.save_dir, 'original')
        generated_img_path = os.path.join(self.save_dir, 'generate')
        os.makedirs(origin_img_path, exist_ok=True)
        os.makedirs(generated_img_path, exist_ok=True)

        print('Generating Test Samples...')
        story_id = 0
        # Sampling only: no autograd graph is needed, and building one would
        # keep each batch's activations alive for nothing.
        with torch.no_grad():
            for batch in tqdm(dataloader):
                real_cpu = batch['images']
                motion_input = batch['description'][:, :, :cfg.TEXT.DIMENSION]
                content_input = batch['description'][:, :, :cfg.TEXT.DIMENSION]
                catelabel = batch['labels']
                if cfg.CUDA:
                    motion_input = motion_input.cuda()
                    content_input = content_input.cuda()
                    catelabel = catelabel.cuda()
                # The category labels are appended to the motion input along
                # dim 2; the content input is passed without them.
                motion_input = torch.cat((motion_input, catelabel), 2)
                _, fake_stories, _, _, _, _, _ = netG.sample_videos(
                    motion_input, content_input)

                # Swap dims 1 and 2 so iterating a story yields one frame at a
                # time (presumably (B, C, T, H, W) -> (B, T, C, H, W); TODO
                # confirm against the dataset).
                real_cpu = real_cpu.transpose(1, 2)
                fake_stories = fake_stories.transpose(1, 2)

                for fake_story, real_story in zip(fake_stories, real_cpu):
                    origin_story_path = os.path.join(origin_img_path, str(story_id))
                    generated_story_path = os.path.join(generated_img_path, str(story_id))
                    os.makedirs(origin_story_path, exist_ok=True)
                    os.makedirs(generated_story_path, exist_ok=True)

                    for idx, (fake, real) in enumerate(zip(fake_story, real_story)):
                        file_name = str(idx) + '.png'
                        save_frame(fake, os.path.join(generated_story_path, file_name))
                        save_frame(real, os.path.join(origin_story_path, file_name))

                    story_id += 1
コード例 #2
0
 def _save_story_images(self, lr_fake, st_fake, num, output_dir, epoch=22):
     """Save each frame of each generated story as a separate PNG file.

     Files are named
     ``<output_dir>/test_epoch_<epoch>_batch_<num>_story_<i>_sentence_<j>.png``
     (epoch zero-padded to three digits). Here ``i`` indexes the story and
     ``j`` the frame within it.

     Args:
         lr_fake: generated stories; when None, ``st_fake`` is used instead.
         st_fake: fallback stories used when ``lr_fake`` is None.
         num: batch number embedded in the file names.
         output_dir: existing directory the images are written into.
         epoch: epoch number embedded in the file names. Defaults to 22, the
             value this method previously hard-coded.

     Raises:
         ValueError: if both ``lr_fake`` and ``st_fake`` are None.
     """
     fake_imgs = lr_fake if lr_fake is not None else st_fake
     if fake_imgs is None:
         raise ValueError('_save_story_images needs lr_fake or st_fake')
     # e.g. (24, 3, 5, 64, 64); presumably stories x C x T x H x W -- TODO confirm.
     for i, story in enumerate(fake_imgs):
         # Swap dims 0 and 1 so iteration yields one frame per sentence,
         # e.g. (3, 5, 64, 64) -> (5, 3, 64, 64).
         story_imgs = story.squeeze(0).transpose(0, 1)
         for j, sentence_img in enumerate(story_imgs):
             image = PIL.Image.fromarray(images_to_numpy(sentence_img))
             image.save(
                 '%s/test_epoch_%03d_batch_%d_story_%d_sentence_%d.png' %
                 (output_dir, epoch, num, i, j))
コード例 #3
0
ファイル: trainer.py プロジェクト: theblackcat102/StoryGAN
	def evaluate(self, weight_path, testloader, output_path):
		"""Restore the generator and save every frame it generates for ``testloader``.

		The model skeleton is unpickled from ``<model_dir>/barebone.pth``, and its
		``'netG'`` entry receives the weights stored at ``weight_path``. Each
		generated frame is written to ``<output_path>/<story_id>_<frame_idx>.jpg``.
		Story ids are numbered sequentially from 0 across all batches.

		Args:
			weight_path: path to a generator ``state_dict``.
			testloader: iterable of batch dicts with a ``'description'`` entry;
				also stored on ``self.testloader``.
			output_path: directory for the JPEGs; created if it does not exist.
		"""
		self.testloader = testloader
		# NOTE(review): torch.load unpickles arbitrary objects -- only point this
		# at trusted checkpoint files.
		# map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
		# the model is moved to the GPU below when cfg.CUDA is set.
		model_structure = torch.load(
			os.path.join(self.model_dir, 'barebone.pth'), map_location='cpu')
		netG = model_structure['netG']
		netG.load_state_dict(torch.load(weight_path, map_location='cpu'))

		# Keep the model on the same device as the inputs: inputs only move to
		# the GPU when cfg.CUDA is set, so the model must follow the same flag.
		if cfg.CUDA:
			netG = netG.cuda()
		netG.eval()
		os.makedirs(output_path, exist_ok=True)

		st_id = 0
		with torch.no_grad():
			for st_batch in tqdm(self.testloader):
				st_motion_input = st_batch['description']
				st_content_input = st_batch['description']
				if cfg.CUDA:
					st_motion_input = st_motion_input.cuda()
					st_content_input = st_content_input.cuda()
				_, st_fake, _, _, _, _ = netG.sample_videos(
					st_motion_input, st_content_input)
				for story_imgs in st_fake:
					# Swap dims 0 and 1 so iteration yields one frame at a time
					# (presumably C x T x H x W -> T x C x H x W).
					story_imgs = story_imgs.transpose(1, 0)
					for idx, img in enumerate(story_imgs):
						output = Image.fromarray(images_to_numpy(img))
						output.save(os.path.join(
							output_path, str(st_id) + '_' + str(idx) + '.jpg'))
					st_id += 1
コード例 #4
0
        st.info("Encoding your story...")
        tokenized_descriptions = torch.cat(
            [clip.tokenize(s) for s in sentences]).to(device)
        with torch.no_grad():
            encoded_descriptions = clip_model.encode_text(
                tokenized_descriptions).float()

        # Run the encoded story text through StoryGAN
        fake_imgs = single_inference(gan_args, encoded_descriptions)

        # Show the credits for each photo in an expandable sidebar
        st.markdown(f"## Your Storybook: \n")
        col_generator = st.beta_columns(video_len)
        for i, col in enumerate(col_generator):
            images = fake_imgs.squeeze(0).transpose(0, 1)[i].squeeze(0)
            images = images_to_numpy(images)
            image = PIL.Image.fromarray(images)

            with col:
                st.image(image, use_column_width='always')
                st.text(sentences[i])

        # Format
        st.sidebar.title("Storybook Illustrator")
        st.sidebar.markdown("-----------------------------------")
        st.sidebar.markdown(
            f"[{app_formal_name}](https://github.com/eunjeeSung/StoryGAN) "
            f"creates storybook illustrations from stories."
            f" The model was trained on a GANILLA-fied subset of the VIST dataset from Microsoft."
        )
        st.sidebar.markdown(