def test_translation_model_cuda(self):
    """Translate the test image with each model after moving it to CUDA.

    Each model is fetched from ``self`` in turn and moved with ``.cuda()``.
    It is then run through ``sample_img2img_model`` with
    ``target_domain='photo'``. The output must have shape (1, 3, 256, 256).
    """
    # Fetch each model lazily, in order, so pix2pix is exercised before
    # cyclegan is even looked up.
    for model_attr in ('pix2pix', 'cyclegan'):
        # NOTE(review): ``.cuda()`` on a torch module moves it in place, so the
        # fixture stays on the GPU afterwards; confirm other tests tolerate it.
        gpu_model = getattr(self, model_attr).cuda()
        translated = sample_img2img_model(
            gpu_model, self.img_path, target_domain='photo')
        assert translated.shape == (1, 3, 256, 256)
def main():
    """Translate one image with an image-to-image model and save the result.

    Options come from ``parse_args()``. The model is built with ``init_model``
    from ``args.config`` and ``args.checkpoint`` on ``args.device``.
    ``sample_img2img_model`` then translates ``args.image_path`` towards
    ``args.target_domain``, forwarding ``args.sample_cfg`` (if given) as
    extra keyword arguments. The result is written to ``args.save_path``.
    """
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    # Treat a missing sample config as "no extra keyword arguments".
    extra_kwargs = dict() if args.sample_cfg is None else args.sample_cfg
    translated = sample_img2img_model(
        model, args.image_path, args.target_domain, **extra_kwargs)

    # Reverse the order of channels 0..2 along dim 1 (presumably BGR -> RGB;
    # confirm against the model's output convention). Then apply (x + 1) / 2,
    # which maps [-1, 1] onto [0, 1] (assumes the model emits values in
    # [-1, 1] -- TODO confirm).
    reversed_channels = [2, 1, 0]
    translated = (translated[:, reversed_channels] + 1.) / 2.

    # Make sure the destination directory exists, then write the image.
    out_dir = os.path.dirname(args.save_path)
    mmcv.mkdir_or_exist(out_dir)
    utils.save_image(translated, args.save_path)
def test_translation_model_cpu(self):
    """Translate the test image with each model as-is (no device move).

    ``target_domain`` is not passed, so ``sample_img2img_model`` falls back to
    its own default. The output must have shape (1, 3, 256, 256).
    """
    expected_shape = (1, 3, 256, 256)
    # Look each model up only when it is about to be used, preserving the
    # pix2pix-then-cyclegan order.
    for model_attr in ('pix2pix', 'cyclegan'):
        translated = sample_img2img_model(
            getattr(self, model_attr), self.img_path)
        assert translated.shape == expected_shape