示例#1
0
    def load_checkpoint(self):
        """Restore the full training state from ``self.checkpoint_dir``.

        Restores the transition settings, the tracking statistics, the
        discriminator / generator / running-average generator weights, and
        both optimizer states.

        Returns:
            True if a checkpoint was loaded and applied; False if the loader
            raised ``FileNotFoundError`` (no checkpoint available).
        """
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Keep the try narrow: only a missing checkpoint should be reported
        # as "No checkpoint!". A FileNotFoundError raised while applying a
        # checkpoint that *was* found must not be silently turned into False.
        try:
            # Module-level loader function (not this method: no ``self.``).
            ckpt = load_checkpoint(self.checkpoint_dir,
                                   map_location=map_location)
        except FileNotFoundError as e:
            print(e)
            print(' [*] No checkpoint!')
            return False

        # Transition settings
        self.is_transitioning = ckpt["is_transitioning"]
        self.transition_step = ckpt["transition_step"]
        self.current_imsize = ckpt["current_imsize"]
        self.latest_switch = ckpt["latest_switch"]

        # Tracking stats
        self.global_step = ckpt["global_step"]
        # total_time is stored in minutes (converted to seconds here);
        # shifting start_time back keeps elapsed-time tracking continuous
        # across restarts.
        self.start_time = time.time() - ckpt["total_time"] * 60
        self.num_skipped_steps = ckpt["num_skipped_steps"]

        # Models
        self.discriminator.load_state_dict(ckpt['D'])
        self.generator.load_state_dict(ckpt['G'])
        self.running_average_generator.load_state_dict(
            ckpt["running_average_generator"])
        to_cuda([self.generator, self.discriminator,
                 self.running_average_generator])
        # Wrap the running-average generator with amp at self.opt_level; no
        # optimizer is passed (presumably because it is not trained directly).
        self.running_average_generator = amp.initialize(
            self.running_average_generator, None, opt_level=self.opt_level)

        # Optimizers are (re)built after the models are moved, then their
        # saved state is restored.
        self.init_optimizers()
        self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        return True
示例#2
0
def setup(opts):
    """Build a DeepPrivacy anonymizer from runtime options.

    Moves the face-detector weights file given by ``opts['face_detector']``
    into ``deep_privacy/detection/dsfd/weights/`` (presumably where the DSFD
    detector loads them from). It then loads ``models/default/config.yml``
    and the checkpoint found at ``opts['checkpoint_dir']``.

    Args:
        opts: mapping with ``'face_detector'`` (path to a weights file) and
            ``'checkpoint_dir'`` keys.

    Returns:
        A ``DeepPrivacyAnonymizer`` wrapping the initialized generator.
    """
    import os

    detector_weights = (
        'deep_privacy/detection/dsfd/weights/WIDERFace_DSFD_RES152.pth')
    # shutil.move cannot create missing parent directories; make sure the
    # weights folder exists so the move does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(detector_weights), exist_ok=True)
    shutil.move(opts['face_detector'], detector_weights)

    config = config_parser.load_config('models/default/config.yml')
    ckpt = utils.load_checkpoint(opts['checkpoint_dir'])
    generator = infer.init_generator(config, ckpt)
    # NOTE(review): 128 is passed positionally to DeepPrivacyAnonymizer;
    # its meaning is defined there — confirm before changing.
    anonymizer = deep_privacy_anonymizer.DeepPrivacyAnonymizer(
        generator, 128, use_static_z=True)
    return anonymizer
示例#3
0
 def get_weights(self):
     """Return the ``state_dict`` of ``self.generator``.

     This only reads the generator currently held by the instance; it does
     not read ``self.checkpoint_dir``. Load checkpoint weights into
     ``self.generator`` beforehand if they are needed.
     """
     # An unused checkpoint load (hard-wired to map_location="cuda:0", which
     # fails on CPU-only hosts) and dead commented-out code were removed:
     # nothing from the checkpoint ever reached the returned value.
     return self.generator.state_dict()
示例#4
0
def read_args(additional_args=None):
    """Parse options, load the checkpoint and initialize the generator.

    Args:
        additional_args: optional sequence of extra argument specs (dicts
            with ``"name"`` and ``"default"`` keys) appended to the built-in
            ``source_path`` / ``target_path`` specs. When non-empty, the
            parsed ``config`` is returned as an extra sixth element.

    Returns:
        ``(generator, imsize, source_path, image_paths, target_path)``, with
        ``config`` appended when ``additional_args`` is non-empty.
    """
    # None default instead of a shared mutable ``[]``; list() also lets
    # callers pass a tuple or other sequence.
    additional_args = [] if additional_args is None else list(additional_args)
    config = config_parser.initialize_and_validate_config([
        {"name": "source_path", "default": "test_examples/source"},
        {"name": "target_path", "default": ""}
    ] + additional_args)
    target_path = get_default_target_path(config.source_path,
                                          config.target_path,
                                          config.config_path)
    ckpt = utils.load_checkpoint(config.checkpoint_dir)
    generator = init_generator(config, ckpt)

    imsize = ckpt["current_imsize"]
    source_path = config.source_path
    image_paths = get_images_recursive(source_path)
    if additional_args:
        return generator, imsize, source_path, image_paths, target_path, config
    return generator, imsize, source_path, image_paths, target_path
示例#5
0
def read_args():
    """Parse options and load the generator from its validation checkpoint.

    The model name is taken from the directory containing the config file,
    and the checkpoint is loaded from ``validation_checkpoints/<model_name>``.
    If no ``target_path`` is given, images go to ``fid_images`` next to the
    config file.

    Returns:
        ``(generator, imsize, save_path, pose_size)``.

    Raises:
        ValueError: if the config path has no parent directory to derive the
            model name from.
    """
    config = config_parser.initialize_and_validate_config([
        {"name": "target_path", "default": ""}
    ])
    save_path = config.target_path
    if not save_path:  # covers both "" and None
        default_path = os.path.join(
            os.path.dirname(config.config_path),
            "fid_images"
        )
        print("Setting target path to default:", default_path)
        save_path = default_path
    # Model name = name of the directory holding the config file,
    # e.g. "<root>/<model>/config.yml" -> "<model>". os.path (not split("/"))
    # keeps this working with platform-specific separators.
    model_name = os.path.basename(
        os.path.dirname(os.path.normpath(config.config_path)))
    if not model_name:
        raise ValueError(
            "Cannot infer model name from config path: {!r}".format(
                config.config_path))
    ckpt = load_checkpoint(os.path.join("validation_checkpoints", model_name))
    generator = init_generator(config, ckpt)
    imsize = ckpt["current_imsize"]
    pose_size = config.models.pose_size
    return generator, imsize, save_path, pose_size
示例#6
0
文件: w01.py 项目: garima0106/GAN
import os
import time
import numpy as np
import torch
import tqdm
from deep_privacy.utils import load_checkpoint, save_checkpoint
from deep_privacy.models.base_model import ProgressiveBaseModel

def _describe_parameters(state_dict):
    """Return a printable summary of a checkpointed module's parameters.

    If the entry has an explicit ``"parameters"`` key, that value is
    returned unchanged, which is the original behaviour. Otherwise the total
    number of tensor elements is returned. DeepPrivacy checkpoints store
    module ``state_dict``s under "D"/"G", and those map parameter names to
    tensors without a "parameters" key, so the plain lookup raised KeyError.
    """
    if "parameters" in state_dict:
        return state_dict["parameters"]
    return sum(value.numel() for value in state_dict.values()
               if torch.is_tensor(value))


torch.manual_seed(0)
# strip() tolerates stray leading/trailing whitespace in a pasted path.
checkpointFile = input(
    "Please enter the path of the checkpoint file to be loaded\n").strip()
loadedCkpt = load_checkpoint(checkpointFile,
                             load_best=False,
                             map_location=None)
print("Discriminator Parameters: "
      + str(_describe_parameters(loadedCkpt["D"])))
print("Generator Parameters: "
      + str(_describe_parameters(loadedCkpt["G"])))