def evaluate(self, model_dir, config):
    """Compute the Frechet Inception Distance (FID) between real and fake images.

    ``self.real_images`` and ``self.fake_images`` are each preprocessed with
    ``tf.contrib.gan.eval.preprocess_image`` and embedded with the Inception
    ``pool_3:0`` output via ``tf.contrib.gan.eval.run_inception``. The feature
    batches are evaluated in a ``SingularMonitoredSession`` restored from
    ``model_dir`` until the session asks to stop or the input raises
    ``OutOfRangeError``. They are then concatenated per side and scored with
    ``metrics.frechet_inception_distance``.

    Args:
        model_dir: checkpoint directory the monitored session restores from.
        config: session config forwarded to ``SingularMonitoredSession``.

    Returns:
        dict with a single ``frechet_inception_distance`` entry. The value is
        also logged with ``tf.logging.info``.
    """
    real_features = tf.contrib.gan.eval.run_inception(
        images=tf.contrib.gan.eval.preprocess_image(self.real_images),
        output_tensor="pool_3:0")
    fake_features = tf.contrib.gan.eval.run_inception(
        images=tf.contrib.gan.eval.preprocess_image(self.fake_images),
        output_tensor="pool_3:0")

    scaffold = tf.train.Scaffold(
        init_op=tf.global_variables_initializer(),
        local_init_op=tf.group(
            tf.local_variables_initializer(),
            tf.tables_initializer()))

    with tf.train.SingularMonitoredSession(
            scaffold=scaffold,
            checkpoint_dir=model_dir,
            config=config) as session:

        def generator():
            # should_stop() is the documented MonitoredSession loop condition;
            # OutOfRangeError marks the end of the input pipeline.
            while not session.should_stop():
                try:
                    yield session.run([real_features, fake_features])
                except tf.errors.OutOfRangeError:
                    break

        # Consume the generator while the session is still open: once the
        # ``with`` exits, the session is closed and can no longer run.
        frechet_inception_distance = metrics.frechet_inception_distance(
            *map(np.concatenate, zip(*generator())))

        tf.logging.info("frechet_inception_distance: {}".format(
            frechet_inception_distance))

        # Returned (not just logged) for parity with the classifier-graph
        # evaluator; callers that ignored the old ``None`` are unaffected.
        return dict(frechet_inception_distance=frechet_inception_distance)
def compute_reconstruction_metrics(self):
    """Build the report context of reconstruction metrics for the latest epoch.

    The epoch comes from ``self.get_last_epoch()``. Its data and
    reconstructions are loaded with ``self.load_data`` / ``self.load_recon``.
    Mutual information and FID are computed on the values exactly as loaded.
    BCE, MSE and L1 norm are computed after wrapping both in ``torch.Tensor``.

    Returns:
        dict with keys 'title', 'bce', 'mse', 'l1_norm',
        'mutual_information' and 'fid', in that order.
    """
    last_epoch = self.get_last_epoch()
    originals = self.load_data(last_epoch)
    reconstructions = self.load_recon(last_epoch)

    # These two run on the loaded values before any torch conversion.
    # TODO(nina): Rewrite mi and fid in pytorch
    mutual_information = metrics.mutual_information(reconstructions, originals)
    fid = metrics.frechet_inception_distance(reconstructions, originals)

    originals_t = torch.Tensor(originals)
    reconstructions_t = torch.Tensor(reconstructions)
    tensor_scores = {
        key: score_fn(reconstructions_t, originals_t)
        for key, score_fn in (
            ('bce', metrics.binary_cross_entropy),
            ('mse', metrics.mse),
            ('l1_norm', metrics.l1_norm),
        )
    }

    context = {'title': 'Vaetree Report'}
    context.update(tensor_scores)
    context['mutual_information'] = mutual_information
    context['fid'] = fid
    # Placeholder (original marker; presumably the report context is still
    # provisional -- confirm with the owner)
    return context
def validate(self):
    """Compute the validation FID between real and generated images and log it.

    For each batch of ``self.val_data_loader``, ``self.generator`` produces one
    fake image per real image from Gaussian latents. ``self.classifier`` embeds
    both sets. The activations are all-gathered across the ``distributed``
    process group, concatenated, and scored with
    ``metrics.frechet_inception_distance``. The result is logged via
    ``self.log_scalar`` under 'frechet_inception_distance'.
    """
    world_size = distributed.get_world_size()

    def all_gather(tensor):
        # all_gather writes each rank's tensor into its own slot, so the
        # output list needs distinct buffers. ``[tensor] * world_size`` would
        # alias one tensor, leaving every slot (and the local input) holding
        # the same rank's data.
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        distributed.all_gather(gathered, tensor)
        return gathered

    def activation_generator():
        # Yields (real, fake) activation pairs, one pair per rank per batch.
        self.generator.eval()
        self.classifier.eval()
        for real_images, _ in self.val_data_loader:
            real_images = real_images.cuda(non_blocking=True)
            # Size the latents by the actual batch: the last batch can be
            # smaller than val_batch_size, which would skew real/fake counts.
            latents = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            latents = latents.cuda(non_blocking=True)
            fake_images = self.generator(latents)
            real_activations = self.classifier(real_images)
            fake_activations = self.classifier(fake_images)
            # Collective order (real, then fake) is the same on every rank.
            yield from zip(all_gather(real_activations),
                           all_gather(fake_activations))

    # The activations only feed a NumPy metric. Without no_grad, every
    # batch's autograd graph stays alive until torch.cat, and .numpy() raises
    # if the tensors require grad.
    with torch.no_grad():
        real_activations, fake_activations = map(
            torch.cat, zip(*activation_generator()))

    frechet_inception_distance = metrics.frechet_inception_distance(
        real_activations.cpu().numpy(), fake_activations.cpu().numpy())
    self.log_scalar('frechet_inception_distance', frechet_inception_distance)
def evaluate(self, model_dir, config, classifier, input_name, output_names):
    """Score real vs. fake images with FID using an imported classifier graph.

    A copy of ``classifier`` (a GraphDef) is imported twice. One copy has
    ``input_name`` mapped to ``self.real_images``, the other to
    ``self.fake_images``. ``output_names`` must name exactly two elements;
    only the first is used as the feature tensor. Features are evaluated in a
    ``SingularMonitoredSession`` restored from ``model_dir`` until it asks to
    stop or the input raises ``OutOfRangeError``. They are then concatenated
    per side and scored with ``metrics.frechet_inception_distance``.

    Returns:
        dict with a single ``frechet_inception_distance`` entry.
    """
    def import_classifier(images):
        # Splice a fresh copy of the classifier graph onto ``images``.
        return tf.import_graph_def(
            graph_def=classifier,
            input_map={input_name: images},
            return_elements=output_names)

    # The second returned element is not needed for FID.
    real_features, _ = import_classifier(self.real_images)
    fake_features, _ = import_classifier(self.fake_images)

    scaffold = tf.train.Scaffold(
        init_op=tf.global_variables_initializer(),
        local_init_op=tf.group(
            tf.local_variables_initializer(),
            tf.tables_initializer()))

    with tf.train.SingularMonitoredSession(
            scaffold=scaffold,
            checkpoint_dir=model_dir,
            config=config) as session:

        def fetch_batches():
            while not session.should_stop():
                try:
                    batch = session.run([real_features, fake_features])
                except tf.errors.OutOfRangeError:
                    return
                yield batch

        # Regroup per side (all real batches, all fake batches), then stack.
        stacked = [np.concatenate(column) for column in zip(*fetch_batches())]
        score = metrics.frechet_inception_distance(*stacked)
        return {'frechet_inception_distance': score}
def validate(self):
    """Compute the validation FID on ``self.val_data_loader`` and log it.

    For each real batch, ``self.generator`` produces the same number of fake
    images from Gaussian latents. Both sets are bilinearly resized to 299x299
    and embedded with ``self.inception``. Activations are all-gathered over
    the ``distributed`` process group (``self.world_size`` ranks),
    concatenated, and scored with ``metrics.frechet_inception_distance``. The
    score is logged with ``self.log_scalars`` under the 'validation' tag.
    """
    def all_gather(tensor):
        # One distinct receive buffer per rank. ``[tensor] * world_size``
        # would alias a single tensor, so every slot (and the local input)
        # would end up holding the same rank's data.
        gathered = [torch.empty_like(tensor) for _ in range(self.world_size)]
        distributed.all_gather(gathered, tensor)
        return gathered

    def activations(data_loader):
        # Yields (real, fake) activation pairs, one pair per rank per batch.
        self.inception.eval()
        for real_images, _ in data_loader:
            batch_size = real_images.size(0)
            real_images = real_images.cuda(non_blocking=True)
            latents = torch.randn(batch_size, self.latent_size, 1, 1)
            latents = latents.cuda(non_blocking=True)
            fake_images = self.generator(latents)
            # align_corners=False is what the unset default resolves to for
            # bilinear mode. Passing it explicitly keeps the same resampling
            # and avoids the default-behaviour warning.
            real_images = nn.functional.interpolate(
                real_images, size=(299, 299), mode="bilinear",
                align_corners=False)
            fake_images = nn.functional.interpolate(
                fake_images, size=(299, 299), mode="bilinear",
                align_corners=False)
            real_activations = self.inception(real_images)
            fake_activations = self.inception(fake_images)
            # Collective order (real, then fake) is the same on every rank.
            yield from zip(all_gather(real_activations),
                           all_gather(fake_activations))

    self.generator.eval()
    self.discriminator.eval()
    # The activations only feed a NumPy metric. Without no_grad, every
    # batch's autograd graph stays alive until torch.cat, and .numpy() raises
    # if the tensors require grad.
    with torch.no_grad():
        real_activations, fake_activations = map(
            torch.cat, zip(*activations(self.val_data_loader)))

    frechet_inception_distance = metrics.frechet_inception_distance(
        real_activations.cpu().numpy(), fake_activations.cpu().numpy())
    self.log_scalars(
        {'frechet_inception_distance': frechet_inception_distance},
        'validation')
summary_writer.add_scalars(main_tag="loss", tag_scalar_dict=dict( generator=generator_loss, discriminator=discriminator_loss), global_step=global_step) global_step += 1 torch.save(generator.state_dict(), f"{args.checkpoint_directory}/generator/epoch_{epoch}.pth") torch.save(discriminator.state_dict(), f"{args.checkpoint_directory}/discriminator/epoch_{epoch}.pth") real_activations, fake_activations = map( torch.cat, zip(*create_activation_generator(test_data_loader)())) frechet_inception_distance = metrics.frechet_inception_distance( real_activations.numpy(), fake_activations.numpy()) summary_writer.add_scalars( main_tag="metrics", tag_scalar_dict=dict( frechet_inception_distance=frechet_inception_distance), global_step=global_step) summary_writer.export_scalars_to_json(f"events/scalars.json") summary_writer.close() print("----------------------------------------------------------------") print(f"frechet_inception_distance: {frechet_inception_distance}") print("----------------------------------------------------------------")