Example #1
0
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and restore their weights.

    Args:
        config_path: Path to the YAML config containing ``model_params``.
        checkpoint_path: Path to a torch checkpoint holding ``generator``
            and ``kp_detector`` state dicts.
        cpu: When True, keep both modules on CPU and skip the
            DataParallel wrapping.

    Returns:
        Tuple ``(generator, kp_detector)``, both in eval mode.
    """
    with open(config_path) as f:
        # yaml.load without an explicit Loader is deprecated and unsafe on
        # untrusted input; FullLoader preserves the previous behavior safely.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()
        kp_detector.cuda()

    # Remap storages to CPU when requested so GPU-saved checkpoints load.
    map_location = torch.device('cpu') if cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Example #2
0
def load_checkpoints(config_path, checkpoint_path, device="cuda"):
    """Instantiate the generator and keypoint detector from the YAML config,
    restore their weights from ``checkpoint_path``, and return both wrapped
    in DataParallelWithCallback and switched to eval mode."""
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    common = model_params["common_params"]

    generator = OcclusionAwareGenerator(
        **model_params["generator_params"], **common)
    kp_detector = KPDetector(
        **model_params["kp_detector_params"], **common)

    generator.to(device)
    kp_detector.to(device)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    for module, key in ((generator, "generator"), (kp_detector, "kp_detector")):
        module.load_state_dict(checkpoint[key])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Example #3
0
def load_checkpoints(config_path, checkpoint_path, device='cuda'):
    """Build the generator and keypoint detector on ``device`` and load
    their weights.

    Args:
        config_path: Path to the YAML config containing ``model_params``.
        checkpoint_path: Path to a torch checkpoint with ``generator`` and
            ``kp_detector`` state dicts.
        device: Target device string/object, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        Tuple ``(generator, kp_detector)``, both wrapped in
        DataParallelWithCallback and set to eval mode.
    """
    with open(config_path) as f:
        # yaml.load without an explicit Loader is deprecated and unsafe on
        # untrusted input; FullLoader preserves the previous behavior safely.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    generator.to(device)

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.to(device)

    # map_location keeps checkpoints loadable regardless of the device they
    # were saved from.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Example #4
0
def load_checkpoints(config_path):
    """Build the generator and keypoint detector and restore weights via
    the project's ``load_ckpt`` helper.

    Args:
        config_path: Path to the YAML config; must contain ``model_params``
            and a ``ckpt_model`` entry pointing at the pretrained weights.

    Returns:
        Tuple ``(generator, kp_detector)``, both in eval mode.
    """
    with open(config_path) as f:
        # yaml.load without an explicit Loader is deprecated and unsafe on
        # untrusted input; FullLoader preserves the previous behavior safely.
        config = yaml.load(f, Loader=yaml.FullLoader)
    pretrain_model = config['ckpt_model']
    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    load_ckpt(pretrain_model, generator=generator, kp_detector=kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
    def load_generator_and_keypoint_detector(self):
        """Construct the generator and keypoint detector from the loaded
        config, restore their weights from the checkpoint, and return both
        in eval mode."""
        config = self.load_config()
        params = config['model_params']

        generator = OcclusionAwareGenerator(
            **params['generator_params'],
            **params['common_params'])
        kp_detector = KPDetector(
            **params['kp_detector_params'],
            **params['common_params'])

        generator.to(self.device)
        kp_detector.to(self.device)

        checkpoints = self.load_checkpoints()
        generator.load_state_dict(checkpoints['generator'])
        kp_detector.load_state_dict(checkpoints['kp_detector'])

        generator.eval()
        kp_detector.eval()

        return generator, kp_detector
Example #6
0
 def load_checkpoints(self):
     """Build the generator and keypoint detector on ``self.device`` and
     load their weights from ``self.checkpoint_path``.

     Returns:
         Tuple ``(generator, kp_detector)``, both in eval mode.
     """
     with open(self.config_path) as f:
         # yaml.load without an explicit Loader is deprecated and unsafe on
         # untrusted input; FullLoader preserves the previous behavior.
         config = yaml.load(f, Loader=yaml.FullLoader)

     generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                         **config['model_params']['common_params'])
     generator.to(self.device)

     kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                              **config['model_params']['common_params'])
     kp_detector.to(self.device)

     # map_location keeps checkpoints loadable regardless of the device
     # they were saved from.
     checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
     generator.load_state_dict(checkpoint['generator'])
     kp_detector.load_state_dict(checkpoint['kp_detector'])

     generator.eval()
     kp_detector.eval()

     return generator, kp_detector
Example #7
0
def load_checkpoints(config_path):
    """Build the generator and keypoint detector (Paddle) and restore
    pretrained weights.

    Weight paths come from the config's ``ckpt_model`` entry. Each path may
    be a numpy ``.npz`` archive (a pickled state dict under ``'arr_0'``) or
    a Paddle dygraph checkpoint loaded with ``fluid.load_dygraph``.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        Tuple ``(generator, kp_detector)``, both in eval mode.
    """
    with open(config_path) as f:
        # NOTE(review): yaml.load without an explicit Loader is deprecated
        # and unsafe on untrusted input; consider yaml.safe_load.
        config = yaml.load(f)
    pretrain_model = config['ckpt_model']
    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if pretrain_model['generator'] is not None:
        if pretrain_model['generator'][-3:] == 'npz':
            # .npz archive: the weights are a pickled dict under 'arr_0'.
            G_param = np.load(pretrain_model['generator'],
                              allow_pickle=True)['arr_0'].item()
            # Drop BatchNorm 'num_batches_tracked' bookkeeping entries that
            # have no counterpart in the Paddle parameter list.
            G_param_clean = [(i, G_param[i]) for i in G_param
                             if 'num_batches_tracked' not in i]
            parameter_clean = generator.parameters()
            del (
                parameter_clean[65]
            )  # The AntiAliasInterpolation2d parameter is absent from the
            # saved dict; remove it so the positional zip below stays aligned.
            # NOTE(review): index 65 is architecture-specific — verify if the
            # model definition changes.
            for p, v in zip(parameter_clean, G_param_clean):
                p.set_value(v[1])
        else:
            # Paddle dygraph checkpoint; second return value is unused here.
            a, b = fluid.load_dygraph(pretrain_model['generator'])
            generator.set_dict(a)
        print('Restore Pre-trained Generator')
    if pretrain_model['kp'] is not None:
        if pretrain_model['kp'][-3:] == 'npz':
            KD_param = np.load(pretrain_model['kp'],
                               allow_pickle=True)['arr_0'].item()
            KD_param_clean = [(i, KD_param[i]) for i in KD_param
                              if 'num_batches_tracked' not in i]
            parameter_clean = kp_detector.parameters()
            # Positional copy: relies on identical parameter ordering between
            # the saved dict and kp_detector.parameters().
            for p, v in zip(parameter_clean, KD_param_clean):
                p.set_value(v[1])
        else:
            a, b = fluid.load_dygraph(pretrain_model['kp'])
            kp_detector.set_dict(a)
        print('Restore Pre-trained KD')
    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Example #8
0
    if torch.cuda.is_available():
        generator.to(opt.device_ids[0])
    if opt.verbose:
        print(generator)

    # Declare a discriminator
    discriminator = MultiScaleDiscriminator(
        **config['model_params']['discriminator_params'],
        **config['model_params']['common_params'])
    if torch.cuda.is_available():
        discriminator.to(opt.device_ids[0])
    if opt.verbose:
        print(discriminator)

    # Declare a key point detector
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])

    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])

    # Print network details if using --verbose flag
    if opt.verbose:
        print(kp_detector)

    # Read in dataset details, defined in *.yaml config file, "dataset_params" section
    # Refer to ./config/vox-256.yaml for details
    # 数据预处理在此步骤完成,并读取进 dataset 变量中
    dataset = FramesDataset(is_train=(opt.mode == 'train'),
                            **config['dataset_params'])
    print("Dataset size: {}, repeat number: {}".format(
        len(dataset), config['train_params']['num_repeats']))
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and restore their weights.

    Args:
        config_path: Path to the YAML config containing ``model_params``.
        checkpoint_path: Path to a torch checkpoint holding ``generator``
            and ``kp_detector`` state dicts.
        cpu: When True, keep both modules on CPU; otherwise move to CUDA.

    Returns:
        Tuple ``(generator, kp_detector)``, both wrapped in
        DataParallelWithCallback and set to eval mode.
    """
    with open(config_path) as f:
        # yaml.load without an explicit Loader is deprecated and unsafe on
        # untrusted input; FullLoader preserves the previous behavior safely.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"],
        **config["model_params"]["common_params"],
    )
    kp_detector = KPDetector(
        **config["model_params"]["kp_detector_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        generator.cpu()
        kp_detector.cpu()
    else:
        generator.cuda()
        kp_detector.cuda()

    # Remap storages to CPU when requested so GPU-saved checkpoints load.
    checkpoint = torch.load(checkpoint_path, map_location="cpu" if cpu else None)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
                        help="path to save in")
    parser.add_argument("--preload",
                        action='store_true',
                        help="preload dataset to RAM")
    parser.set_defaults(verbose=False)
    opt = parser.parse_args()
    with open(opt.config) as f:
        config = yaml.load(f)

    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    discriminator = MultiScaleDiscriminator(
        **config['model_params']['discriminator_params'],
        **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])

    dataset = FramesDataset(is_train=(opt.mode == 'train'),
                            **config['dataset_params'])
    if opt.preload:
        logging.info('PreLoad Dataset: Start')
        pre_list = list(range(len(dataset)))
        import multiprocessing.pool as pool
        with pool.Pool(4) as pl:
            buf = pl.map(dataset.preload, pre_list)
        for idx, (i, v) in enumerate(zip(pre_list, buf)):
            dataset.buffed[i] = v.copy()
            buf[idx] = None
        logging.info('PreLoad Dataset: End')

    if opt.mode == 'train':
Example #11
0
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    else:
        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
        log_dir += ' ' + strftime('%d_%m_%y_%H.%M.%S', gmtime())

    generator = Generator(**config['model_params']['generator_params'],
                          **config['model_params']['common_params'])

    if torch.cuda.is_available():
        generator.to(opt.device_ids[0])
    if opt.verbose:
        print(generator)

    checkpoint_with_kp = torch.load(opt.checkpoint_with_kp, map_location='cpu' if opt.cpu else None)

    kp_detector = KPDetector(checkpoint_with_kp, **config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])

    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])

    if opt.verbose:
        print(kp_detector)

    dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == 'train':
Example #12
0
    parser.set_defaults(cpu=False)

    opt = parser.parse_args()

    with open(opt.config) as f:
        config = yaml.load(f)
        blocks_discriminator = config['model_params']['discriminator_params']['num_blocks']
        assert len(config['train_params']['loss_weights']['reconstruction']) == blocks_discriminator + 1

    generator = MotionTransferGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not opt.cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not opt.cpu:
        kp_detector = kp_detector.cuda()

    Logger.load_cpk(opt.checkpoint, generator=generator, kp_detector=kp_detector, use_cpu=True)

    vis = Visualizer()

    if not opt.cpu: 
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    '''