def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
    """Compute the generator's FID score and log it to the summary writer.

    NOTE(review): another ``validate`` is defined later in this file and rebinds
    the name when the module is imported, so this FID-only version is shadowed.
    Confirm whether it is still needed.

    Args:
        args: run configuration. Reads ``path_helper['sample_path']``,
            ``num_eval_imgs`` and ``gen_batch_size``, and is passed through to
            ``get_fid``.
        fixed_z: accepted for signature compatibility; unused in this version.
        fid_stat: reference statistics forwarded to ``get_fid``.
        epoch: forwarded to ``get_fid``.
        gen_net: generator module. It is switched to eval mode and left there.
        writer_dict: dict with keys ``'writer'`` (object exposing
            ``add_scalar``) and ``'valid_global_steps'`` (int step counter).
            The counter is incremented in place.
        clean_dir: accepted for signature compatibility; unused in this version.

    Returns:
        The FID score returned by ``get_fid``.
    """
    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # Switch to eval mode; this function does not restore train mode afterwards.
    gen_net = gen_net.eval()

    # The buffer directory is still created so the on-disk side effect is unchanged;
    # nothing in this function writes to it.
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
    os.makedirs(fid_buffer_dir, exist_ok=True)

    logger.info('=> calculate fid score')
    # get_fid draws its own samples from gen_net; the batch size is twice the training batch.
    fid_score = get_fid(args, fid_stat, epoch, gen_net, args.num_eval_imgs,
                        args.gen_batch_size * 2, writer_dict=writer_dict, cls_idx=None)
    print(f"FID score: {fid_score}")

    writer.add_scalar('FID_score', fid_score, global_steps)
    writer_dict['valid_global_steps'] = global_steps + 1
    return fid_score
def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
    """Evaluate the generator: Inception score, FID, and a sample-image grid.

    Samples ``args.num_eval_imgs // args.eval_batch_size`` batches from
    ``gen_net`` and writes each image as a PNG into
    ``<sample_path>/fid_buffer``. It computes the Inception score over those
    images and the FID via ``get_fid``, then logs the results to the
    TensorBoard-style writer.

    Args:
        args: run configuration. Reads ``path_helper['sample_path']``,
            ``num_eval_imgs``, ``eval_batch_size``, ``latent_dim`` and
            ``gen_batch_size``, and is passed through to ``get_fid``.
        fixed_z: latent input used to render the logged sample grid. Presumably
            a CUDA tensor of shape (batch, latent_dim); confirm against the caller.
        fid_stat: reference statistics forwarded to ``get_fid``.
        epoch: forwarded to every ``gen_net`` call and to ``get_fid``.
        gen_net: generator module. It is switched to eval mode and left there.
        writer_dict: dict with keys ``'writer'`` (exposing ``add_image`` /
            ``add_scalar``) and ``'valid_global_steps'`` (int step counter).
            The counter is incremented in place.
        clean_dir: if True, delete the PNG buffer directory afterwards;
            otherwise keep it and log its location.

    Returns:
        Tuple ``(inception_score_mean, fid_score)``.
    """
    import shutil  # local import: the file-level import block is not visible here

    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # Switch to eval mode; this function does not restore train mode afterwards.
    gen_net = gen_net.eval()

    # Pure sampling: no gradients are needed, so skip building autograd graphs
    # (this saves GPU memory across the eval_iter batches).
    with torch.no_grad():
        sample_imgs = gen_net(fixed_z, epoch)
        img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

        fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
        os.makedirs(fid_buffer_dir, exist_ok=True)

        eval_iter = args.num_eval_imgs // args.eval_batch_size
        img_list = []
        for iter_idx in tqdm(range(eval_iter), desc='sample images'):
            z = torch.as_tensor(
                np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)),
                dtype=torch.float32, device='cuda')
            # Map generator output (presumably in [-1, 1]) to uint8 [0, 255] in NHWC layout.
            gen_imgs = gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(
                0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
            for img_idx, img in enumerate(gen_imgs):
                file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png')
                imsave(file_name, img)
            img_list.extend(gen_imgs)

    logger.info('=> calculate inception score')
    mean, std = get_inception_score(img_list)
    print(f"Inception score: {mean}")

    logger.info('=> calculate fid score')
    # NOTE(review): get_fid receives gen_net rather than fid_buffer_dir, so it appears
    # to draw its own samples; the PNGs above seem to serve only for inspection.
    fid_score = get_fid(args, fid_stat, epoch, gen_net, args.num_eval_imgs,
                        args.gen_batch_size * 2, writer_dict=writer_dict, cls_idx=None)
    print(f"FID score: {fid_score}")

    if clean_dir:
        # Remove in-process instead of via a shell command (which breaks on spaces and
        # shell metacharacters); failures stay non-fatal, as before, but are reported.
        try:
            shutil.rmtree(fid_buffer_dir)
        except OSError as err:
            logger.warning(f'=> failed to remove {fid_buffer_dir}: {err}')
    else:
        logger.info(f'=> sampled images are saved to {fid_buffer_dir}')

    writer.add_image('sampled_images', img_grid, global_steps)
    writer.add_scalar('Inception_score/mean', mean, global_steps)
    writer.add_scalar('Inception_score/std', std, global_steps)
    writer.add_scalar('FID_score', fid_score, global_steps)
    writer_dict['valid_global_steps'] = global_steps + 1
    return mean, fid_score