示例#1
0
    def test_translation_model_cuda(self):
        """Check the translated output shape for models moved with ``.cuda()``.

        Both ``self.pix2pix`` and ``self.cyclegan`` are sampled on
        ``self.img_path`` with ``target_domain='photo'``; each result must
        have shape ``(1, 3, 256, 256)``.
        """
        expected_shape = (1, 3, 256, 256)

        # pix2pix first, then cyclegan; each model is moved with ``.cuda()``
        # immediately before it is sampled.
        for model_name in ('pix2pix', 'cyclegan'):
            gpu_model = getattr(self, model_name).cuda()
            output = sample_img2img_model(gpu_model,
                                          self.img_path,
                                          target_domain='photo')
            assert output.shape == expected_shape
示例#2
0
def main():
    """Run image-to-image sampling on one image and save the result.

    Reads the options returned by ``parse_args()``, builds the model with
    ``init_model``, samples it on ``args.image_path`` for
    ``args.target_domain`` and writes the post-processed output to
    ``args.save_path``.
    """
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    # A missing sample config becomes an empty dict so that it can be
    # unpacked as keyword arguments below.
    if args.sample_cfg is None:
        args.sample_cfg = {}

    translated = sample_img2img_model(
        model, args.image_path, args.target_domain, **args.sample_cfg)

    # Reverse the three entries along dim 1, then map x -> (x + 1) / 2.
    # NOTE(review): presumably BGR -> RGB and [-1, 1] -> [0, 1]; confirm
    # against the sampler's output convention.
    channel_swapped = translated[:, [2, 1, 0]]
    rescaled = (channel_swapped + 1.) / 2.

    # save images
    output_dir = os.path.dirname(args.save_path)
    mmcv.mkdir_or_exist(output_dir)
    utils.save_image(rescaled, args.save_path)
示例#3
0
    def test_translation_model_cpu(self):
        """Check the translated output shape without moving the models.

        ``self.pix2pix`` and ``self.cyclegan`` are sampled on
        ``self.img_path`` exactly as stored on the instance, with no explicit
        target domain; each result must have shape ``(1, 3, 256, 256)``.
        """
        expected_shape = (1, 3, 256, 256)

        # pix2pix first, then cyclegan.
        for model_name in ('pix2pix', 'cyclegan'):
            model = getattr(self, model_name)
            output = sample_img2img_model(model, self.img_path)
            assert output.shape == expected_shape