def __init__(self, args):
        """Build networks, TensorBoard writer, optimizers and checkpoint handlers.

        Args:
            args: parsed options. Always reads ``mode`` and
                ``checkpoint_dir``; in train mode also ``lr``, ``f_lr``,
                ``beta1``, ``beta2`` and ``weight_decay``; otherwise
                ``resume_iter``.
        """
        super().__init__()
        self.args = args
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_gpu else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        self.writer = SummaryWriter(
            'log/test_vox_256_smalldata_id_1_20_20_retrain_alldata_id_embedder_vggface_add_noise'
        )

        # Attach each network to the solver so it is registered as a child
        # module (picked up by self.to(...) and self.named_children()).
        for key, net in self.nets.items():
            utils.print_network(net, key)
            setattr(self, key, net)
        for key, net in self.nets_ema.items():
            setattr(self, key + '_ema', net)

        ckpt_dir = args.checkpoint_dir
        if args.mode != 'train':
            # Only the EMA weights are handled outside training; the step is
            # baked into the filename here instead of left as a placeholder.
            ema_file = '{}_nets_ema.ckpt'.format(args.resume_iter)
            self.ckptios = [
                CheckpointIO(ospj(ckpt_dir, ema_file), **self.nets_ema)
            ]
        else:
            self.optims = Munch()
            for key, net in self.nets.items():
                if key == 'fan':
                    continue  # FAN gets no optimizer
                self.optims[key] = torch.optim.Adam(
                    params=net.parameters(),
                    lr=args.f_lr if key == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)
            # '{}' is left unformatted; presumably CheckpointIO fills in the
            # step at save/load time.
            self.ckptios = [
                CheckpointIO(ospj(ckpt_dir, '{}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(ckpt_dir, '{}_nets_ema.ckpt'),
                             **self.nets_ema),
                CheckpointIO(ospj(ckpt_dir, '{}_optims.ckpt'), **self.optims),
            ]

        self.to(self.device)
        # He-initialize only the trainable children; EMA copies and FAN are
        # skipped.
        for child_name, child in self.named_children():
            if 'ema' in child_name or 'fan' in child_name:
                continue
            print('Initializing %s...' % child_name)
            child.apply(utils.he_init)
Esempio n. 2
0
    def __init__(self, args):
        """Build networks, optimizers and checkpoint handlers.

        When ``args.resume_ckpt`` is set, the (non-EMA) networks are then
        loaded from that checkpoint for transfer learning.

        Args:
            args: parsed options. Reads ``mode``, ``checkpoint_dir`` and
                ``resume_ckpt``; in train mode also ``lr``, ``f_lr``,
                ``beta1``, ``beta2`` and ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.nets, self.nets_ema = build_model(args)

        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + "_ema", module)

        # Checkpoints hold the bare modules, not their nn.DataParallel wrappers.
        bare_nets = {
            k: (n.module if isinstance(n, nn.DataParallel) else n)
            for k, n in self.nets.items()
        }

        if args.mode == "train":
            self.optims = Munch()
            for net in self.nets.keys():
                if net == "fan":
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == "mapping_network" else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay,
                )

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets.ckpt"),
                             **bare_nets),
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"),
                             **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_optims.ckpt"),
                             **self.optims),
            ]
        else:
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"),
                             **self.nets_ema)
            ]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ("ema" not in name) and ("fan" not in name):
                print("Initializing %s..." % name)
                network.apply(utils.he_init)

        # Transfer learning has to run after the he_init pass above. If the
        # weights were loaded first, that pass would overwrite them.
        if args.resume_ckpt is not None:
            print("Transfer learning from", args.resume_ckpt)
            CheckpointIO(args.resume_ckpt, **bare_nets).load(
                args.resume_ckpt.split("_")[0],
                restore_D=False)  # no discriminator included in EMA ckpts
Esempio n. 3
0
    def __init__(self, args):
        """Build the networks, optimizers and checkpoint handlers.

        This variant keeps a single set of networks (no EMA copies), so the
        checkpoints store ``self.nets`` directly.

        Args:
            args: parsed options. Reads ``mode`` and ``checkpoint_dir``; in
                train mode also ``lr``, ``beta1``, ``beta2`` and
                ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.summary_writer = SummaryWriter('./runs/experiment_1')

        self.nets = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_nets.ckpt'),
                             **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_optims.ckpt'),
                             **self.optims)
            ]
        else:
            # Every non-train mode needs a handler to restore the networks.
            # Restricting this to mode == 'eval' left self.ckptios undefined
            # for any other mode.
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{0}_nets.ckpt'),
                             **self.nets),
            ]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
Esempio n. 4
0
    def __init__(self, args):
        """Create networks, optimizers and checkpoint I/O for the solver.

        Args:
            args: parsed options. Reads ``mode`` and ``checkpoint_dir``; in
                train mode also ``lr``, ``f_lr``, ``beta1``, ``beta2`` and
                ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # Registering the networks as attributes makes them child modules,
        # so self.to(...) and self.named_children() see them.
        for key in self.nets:
            utils.print_network(self.nets[key], key)
            setattr(self, key, self.nets[key])
        for key in self.nets_ema:
            setattr(self, key + '_ema', self.nets_ema[key])

        def make_ckptio(template, modules):
            # One checkpoint file per group, named from a step template.
            return CheckpointIO(ospj(args.checkpoint_dir, template), **modules)

        if args.mode == 'train':
            self.optims = Munch()
            for key in self.nets:
                if key != 'fan':
                    self.optims[key] = torch.optim.Adam(
                        params=self.nets[key].parameters(),
                        lr=args.f_lr if key == 'mapping_network' else args.lr,
                        betas=[args.beta1, args.beta2],
                        weight_decay=args.weight_decay)
            self.ckptios = [
                make_ckptio('{:06d}_nets.ckpt', self.nets),
                make_ckptio('{:06d}_nets_ema.ckpt', self.nets_ema),
                make_ckptio('{:06d}_optims.ckpt', self.optims),
            ]
        else:
            # Outside training only the EMA weights are handled.
            self.ckptios = [make_ckptio('{:06d}_nets_ema.ckpt', self.nets_ema)]

        self.to(self.device)
        for key, child in self.named_children():
            skip = 'ema' in key or 'fan' in key  # leave EMA copies and FAN alone
            if not skip:
                print('Initializing %s...' % key)
                child.apply(utils.he_init)
Esempio n. 5
0
    def __init__(self, args):
        """Set up networks, optimizers and checkpoint handlers.

        Args:
            args: parsed options. Reads ``mode`` and ``checkpoint_dir``; in
                train mode also ``lr``, ``f_lr``, ``beta1``, ``beta2`` and
                ``weight_decay``.
        """
        super().__init__()
        self.args = args
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.nets, self.nets_ema = build_model(args)
        # Setting the networks as attributes makes them child modules of the
        # solver (used by self.to(...) and self.named_children()).
        for label, model in self.nets.items():
            utils.print_network(model, label)
            setattr(self, label, model)
        for label, model in self.nets_ema.items():
            setattr(self, label + "_ema", model)

        if args.mode == "train":
            self.optims = Munch()
            trainable = (label for label in self.nets.keys() if label != "fan")
            for label in trainable:
                step_lr = args.f_lr if label == "mapping_network" else args.lr
                self.optims[label] = torch.optim.Adam(
                    params=self.nets[label].parameters(),
                    lr=step_lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay,
                )
            ckpt_specs = [
                ("{:06d}_nets.ckpt", self.nets),
                ("{:06d}_nets_ema.ckpt", self.nets_ema),
                ("{:06d}_optims.ckpt", self.optims),
            ]
        else:
            ckpt_specs = [("{:06d}_nets_ema.ckpt", self.nets_ema)]
        self.ckptios = [
            CheckpointIO(ospj(args.checkpoint_dir, template), **modules)
            for template, modules in ckpt_specs
        ]

        self.to(self.device)
        # Only the non-EMA, non-FAN children receive He initialization.
        to_init = [
            (label, child) for label, child in self.named_children()
            if "ema" not in label and "fan" not in label
        ]
        for label, child in to_init:
            print("Initializing %s..." % label)
            child.apply(utils.he_init)
Esempio n. 6
0
    def __init__(self, args):
        """Build the Paddle networks, optimizers and checkpoint handlers.

        Args:
            args: parsed options. Reads ``mode`` and ``checkpoint_dir``; in
                train mode also ``whichgpu``, ``lr``, ``f_lr``, ``beta1``,
                ``beta2`` and ``weight_decay``.
        """
        super().__init__()
        self.args = args

        self.nets, self.nets_ema = build_model(args)
        # Expose every network (and its EMA copy) as a solver attribute.
        for key, net in self.nets.items():
            setattr(self, key, net)
        for key, net in self.nets_ema.items():
            setattr(self, key + '_ema', net)

        ckpt_dir = args.checkpoint_dir
        if args.mode == 'train':
            if paddle.fluid.is_compiled_with_cuda():
                place = paddle.fluid.CUDAPlace(self.args.whichgpu)
            else:
                place = paddle.fluid.CPUPlace()
            # Optimizers are created inside the dygraph guard for `place`.
            with fluid.dygraph.guard(place):
                self.optims = Munch()
                self.ckptios = Munch()
                for key in self.nets.keys():
                    if key == 'fan':
                        continue  # FAN gets neither optimizer nor checkpoint
                    base_lr = args.f_lr if key == 'mapping_network' else args.lr
                    self.optims[key] = fluid.optimizer.AdamOptimizer(
                        learning_rate=base_lr,
                        beta1=args.beta1,
                        beta2=args.beta2,
                        parameter_list=self.nets[key].parameters(),
                        regularization=fluid.regularizer.L2Decay(
                            regularization_coeff=args.weight_decay))
                    # Per-network pair: [EMA checkpoint, raw checkpoint].
                    self.ckptios[key] = [
                        CheckpointIO(ospj(ckpt_dir, '{:06d}_nets_ema_' + key)),
                        CheckpointIO(ospj(ckpt_dir, '{:06d}_nets_' + key)),
                    ]
        else:
            # NOTE(review): unlike train mode, this also registers a 'fan'
            # entry; confirm a '{:06d}_nets_ema_fan' checkpoint is expected.
            self.ckptios = Munch()
            for key in self.nets.keys():
                self.ckptios[key] = [
                    CheckpointIO(ospj(ckpt_dir, '{:06d}_nets_ema_' + key))
                ]
Esempio n. 7
0
def load_model(args):
    """Build the networks and restore the EMA weights from a checkpoint.

    Args:
        args: parsed options; reads ``checkpoint_dir`` and ``resume_iter``
            (plus whatever ``build_model`` reads).

    Returns:
        The EMA networks from ``build_model``, after loading the checkpoint
        for step ``args.resume_iter``.
    """
    _unused_nets, nets_ema = build_model(args)

    # The explicit positional "{0:06d}" form was noted by the original author
    # as compatible with Windows.
    template = ospj(args.checkpoint_dir, '{0:06d}_nets_ema.ckpt')
    CheckpointIO(template, **nets_ema).load(args.resume_iter)

    return nets_ema
Esempio n. 8
0
    def __init__(self, args):
        """Build the porch networks, optimizers and checkpoint handlers.

        Args:
            args: parsed options. Reads ``mode`` and ``checkpoint_dir``; in
                train mode also ``lr``, ``f_lr``, ``beta1``, ``beta2`` and
                ``weight_decay``.
        """

        super().__init__()
        self.args = args
        # self.device = porch.device('cuda' if porch.cuda.is_available() else 'cpu')
        # NOTE: no device is selected and no self.to(...) call is made here.
        print("Solver init....")
        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = porch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                # The raw networks need their own file. Sharing the EMA
                # template made both handlers read/write the same checkpoint.
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'),
                             **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'),
                             **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'),
                             **self.optims)
            ]
        else:
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'),
                             **self.nets_ema)
            ]
Esempio n. 9
0
    def __init__(self, args):
        """Build only the EMA networks and restore them immediately.

        Args:
            args: parsed options; reads ``checkpoint_dir`` and
                ``resume_iter`` (plus whatever ``build_model`` reads).
        """
        super().__init__()
        self.args = args
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        _unused_nets, self.nets_ema = build_model(args)
        # Register the EMA networks as children so self.to(...) moves them.
        for key, net in self.nets_ema.items():
            setattr(self, key + '_ema', net)

        # The step is formatted into the filename up front.
        ema_path = ospj(args.checkpoint_dir,
                        '{}_nets_ema.ckpt'.format(args.resume_iter))
        self.ckptios = [CheckpointIO(ema_path, **self.nets_ema)]

        self.to(self.device)
        self._load_checkpoint(args.resume_iter)
Esempio n. 10
0
    def __init__(self):
        """Build the solver from the options returned by ``resolver_args()``.

        Only an EMA checkpoint handler is created; ``checkpoint_dir`` is read
        from the resolved options.
        """
        super().__init__()
        args = resolver_args()
        self.args = args
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_gpu else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # Attributes make the networks child modules of the solver.
        for key, net in self.nets.items():
            utils.print_network(net, key)
            setattr(self, key, net)
        for key, net in self.nets_ema.items():
            setattr(self, key + '_ema', net)

        ema_template = ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt')
        self.ckptios = [CheckpointIO(ema_template, **self.nets_ema)]

        self.to(self.device)
        # He-initialize every child except EMA copies and FAN.
        for key, child in self.named_children():
            if 'ema' in key or 'fan' in key:
                continue
            print('Initializing %s...' % key)
            child.apply(utils.he_init)