Exemplo n.º 1
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoint handlers for the solver.

        Args:
            args: run configuration. Fields read here: ``mode``,
                ``resume_ckpt``, ``checkpoint_dir``, ``lr``, ``f_lr``,
                ``beta1``, ``beta2`` and ``weight_decay``.

        Networks from ``build_model`` are registered as child modules (EMA
        copies with an ``_ema`` suffix) so ``self.to`` / ``named_children``
        reach them. When ``args.resume_ckpt`` is set, its weights are loaded
        into ``self.nets`` AFTER the He initialization pass, so the
        initialization cannot overwrite the transferred weights.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.nets, self.nets_ema = build_model(args)

        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + "_ema", module)

        # Checkpoints store the bare modules, so strip any DataParallel wrapper.
        unwrapped_nets = {
            k: (n.module if isinstance(n, nn.DataParallel) else n)
            for k, n in self.nets.items()
        }

        if args.mode == "train":
            self.optims = Munch()
            for net in self.nets.keys():
                if net == "fan":
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == "mapping_network" else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay,
                )

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets.ckpt"),
                             **unwrapped_nets),
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"),
                             **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_optims.ckpt"),
                             **self.optims),
            ]
        else:
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"),
                             **self.nets_ema)
            ]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ("ema" not in name) and ("fan" not in name):
                print("Initializing %s..." % name)
                network.apply(utils.he_init)

        # Transfer learning must run after he_init: loading earlier (as this
        # code originally did) let the loop above overwrite the loaded weights
        # of every non-EMA, non-FAN network.
        if args.resume_ckpt is not None:
            print("Transfer learning from", args.resume_ckpt)
            # NOTE(review): the step passed to load() is the prefix of the
            # path before its first "_"; a directory containing "_" would
            # break this — confirm against CheckpointIO.load.
            CheckpointIO(args.resume_ckpt, **unwrapped_nets).load(
                args.resume_ckpt.split("_")[0],
                restore_D=False,  # no discriminator included in EMA ckpts
            )
Exemplo n.º 2
0
    def __init__(self, args):
        """Build networks, ArcFace model, optimizers and checkpoint handlers.

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2`` and ``weight_decay``.

        The module returned by ``load_arcface()`` is excluded from the He
        initialization pass (like FAN). It is presumably a pretrained model —
        TODO confirm — and re-initializing it would discard those weights.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        self.arcface, self.conf = load_arcface()
        # NOTE(review): TensorBoard log directory is hard-coded.
        self.writer = SummaryWriter('log/test11')
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            # Training checkpoints use an unpadded '{}' step prefix.
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{}_nets.ckpt'),
                             **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{}_nets_ema.ckpt'),
                             **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{}_optims.ckpt'),
                             **self.optims)
            ]
        else:
            # The template is formatted here, so the path always names step
            # 100000 regardless of any step passed later (assuming
            # CheckpointIO formats the template with the step — confirm).
            self.ckptios = [
                CheckpointIO(
                    ospj(args.checkpoint_dir,
                         '{:06d}_nets_ema.ckpt'.format(100000)),
                    **self.nets_ema)
            ]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize EMA copies, FAN, or the ArcFace model.
            if ('ema' in name) or ('fan' in name) or (name == 'arcface'):
                continue
            print('Initializing %s...' % name)
            network.apply(utils.he_init)
Exemplo n.º 3
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoint handlers (no EMA copies).

        In this variant ``build_model`` returns a single dict of networks.

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``beta1``, ``beta2`` and ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.summary_writer = SummaryWriter('./runs/experiment_1')

        self.nets = build_model(args)

        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_nets.ckpt'),
                             **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_optims.ckpt'),
                             **self.optims)
            ]
        else:
            # Every non-training mode gets the network checkpoint handler.
            # Previously only mode 'eval' did, so other modes left
            # self.ckptios undefined and any later access raised AttributeError.
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_nets.ckpt'),
                             **self.nets),
            ]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
Exemplo n.º 4
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoint handlers for the solver.

        Args:
            args: run configuration; reads ``mode``, ``gpus`` (comma-separated
                device ids, ``"0"`` meaning single-device), ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2`` and ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema)]

        # Multi-gpu Training: parse the device list from args.gpus.
        # The previous code read self.gpus before assigning it (AttributeError).
        # It also rebound the local name `self` to a DataParallel wrapper,
        # which cannot change the object the caller receives and made the init
        # loop below re-initialize the wrapper's single `module` child
        # (including EMA and FAN networks).
        if self.args.gpus != "0" and torch.cuda.is_available():
            self.gpus = [int(i) for i in self.args.gpus.split(',')]
            # TODO: wrap individual networks in
            # nn.DataParallel(..., device_ids=self.gpus) if multi-GPU is needed.

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
Exemplo n.º 5
0
    def __init__(self, args):
        """Set up the solver's networks, optimizers and checkpoint handlers.

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2`` and ``weight_decay``.

        Every network is attached as a child module (EMA copies get an
        ``_ema`` suffix). All non-EMA, non-FAN children are then He-initialized.
        """
        super().__init__()
        self.args = args
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")

        self.nets, self.nets_ema = build_model(args)
        # Attach networks as attributes so nn.Module tracks them as children
        # (this is what lets self.to(self.device) move them).
        for net_name, net in self.nets.items():
            utils.print_network(net, net_name)
            setattr(self, net_name, net)
        for net_name, net in self.nets_ema.items():
            setattr(self, f"{net_name}_ema", net)

        ckpt_dir = args.checkpoint_dir
        if args.mode != "train":
            self.ckptios = [
                CheckpointIO(ospj(ckpt_dir, "{:06d}_nets_ema.ckpt"),
                             **self.nets_ema)
            ]
        else:
            self.optims = Munch()
            adam_betas = [args.beta1, args.beta2]
            for net_name, net in self.nets.items():
                if net_name == "fan":
                    continue
                step_size = args.f_lr if net_name == "mapping_network" else args.lr
                self.optims[net_name] = torch.optim.Adam(
                    params=net.parameters(),
                    lr=step_size,
                    betas=adam_betas,
                    weight_decay=args.weight_decay,
                )

            nets_io = CheckpointIO(ospj(ckpt_dir, "{:06d}_nets.ckpt"), **self.nets)
            ema_io = CheckpointIO(ospj(ckpt_dir, "{:06d}_nets_ema.ckpt"),
                                  **self.nets_ema)
            optims_io = CheckpointIO(ospj(ckpt_dir, "{:06d}_optims.ckpt"),
                                     **self.optims)
            self.ckptios = [nets_io, ema_io, optims_io]

        self.to(self.device)
        for child_name, child in self.named_children():
            # EMA copies and FAN networks keep their current weights.
            if "ema" in child_name or "fan" in child_name:
                continue
            print("Initializing %s..." % child_name)
            child.apply(utils.he_init)
Exemplo n.º 6
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoints, then load ``resume_iter``.

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2``, ``weight_decay`` and
                ``resume_iter``.

        The device falls back to CPU when CUDA is unavailable. The hard-coded
        ``'cuda'`` device previously failed on CPU-only hosts, unlike the
        other solver variants.
        """
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema)]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)

        # Load the requested checkpoint after initialization so the init pass
        # above does not overwrite the restored weights.
        self._load_checkpoint(args.resume_iter)
Exemplo n.º 7
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoint handlers (``porch`` backend).

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2`` and ``weight_decay``.

        The non-EMA networks are checkpointed to ``{:06d}_nets.ckpt``.
        Previously they used the same ``_nets_ema.ckpt`` template as the EMA
        networks, so the EMA save overwrote them.
        """
        super().__init__()
        self.args = args
        # NOTE(review): unlike the other variants, no device is selected and
        # no self.to(...) / he_init pass happens here.
        print("Solver init....")
        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = porch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'),
                             **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'),
                             **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'),
                             **self.optims)
            ]
        else:
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'),
                             **self.nets_ema)
            ]
Exemplo n.º 8
0
    def __init__(self, args):
        """Build networks, VGG helpers, losses, optimizers and checkpoints.

        Args:
            args: run configuration; reads ``mode``, ``checkpoint_dir``,
                ``lr``, ``f_lr``, ``beta1``, ``beta2`` and ``weight_decay``.

        The ``vgg`` / ``VggExtract`` modules returned by ``build_model`` are
        excluded from He initialization. They are presumably pretrained
        feature extractors (TODO confirm), and re-initializing them would
        discard those weights.
        """
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema, self.vgg, self.VggExtract = build_model(args)
        # Instance norm over 512 channels with no learnable affine parameters.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.L1Loss = nn.L1Loss()
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        # NOTE(review): checkpoint names are fixed at '100000_*' with no step
        # placeholder. Assuming CheckpointIO formats the template with the
        # step, every save/load uses the same files regardless of step.
        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '100000_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '100000_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '100000_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '100000_nets_ema.ckpt'), **self.nets_ema)]

        self.to(self.device)
        for name, network in self.named_children():
            # Skip EMA copies, FAN, and the VGG modules (vgg / VggExtract).
            if ('ema' in name) or ('fan' in name) or ('vgg' in name.lower()):
                continue
            print('Initializing %s...' % name)
            network.apply(utils.he_init)
Exemplo n.º 9
0
    def __init__(self):
        """Create an inference-only solver from ``resolver_args()``.

        Arguments come from ``resolver_args()`` rather than a parameter. Only
        the EMA checkpoint handler is created; no optimizers are built. All
        non-EMA, non-FAN children are He-initialized.
        """
        super().__init__()
        self.args = args = resolver_args()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # Attach each network as an attribute so nn.Module registers it as a
        # child (required for self.to(self.device) to move it).
        for label, net in self.nets.items():
            utils.print_network(net, label)
            setattr(self, label, net)
        for label, net in self.nets_ema.items():
            setattr(self, f'{label}_ema', net)

        ema_template = ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt')
        self.ckptios = [CheckpointIO(ema_template, **self.nets_ema)]

        self.to(self.device)
        for label, child in self.named_children():
            # EMA copies and FAN networks are left as they are.
            if 'ema' in label or 'fan' in label:
                continue
            print('Initializing %s...' % label)
            child.apply(utils.he_init)