Example #1
    def __init__(self, args):
        super(LossWrapper, self).__init__()
        self.apply_to = [
            rn_utils.parse_str_to_list(s, sep=',')
            for s in rn_utils.parse_str_to_list(args.pix_loss_apply_to,
                                                sep=';')
        ]

        # Supported loss functions
        losses = {'mse': F.mse_loss, 'l1': F.l1_loss, 'ce': F.cross_entropy}

        self.loss = losses[args.pix_loss_type]

        # Weights for each feature extractor
        self.weights = rn_utils.parse_str_to_list(args.pix_loss_weights,
                                                  value_type=float)
        self.names = rn_utils.parse_str_to_list(args.pix_loss_names)
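
Note: every example on this page relies on parse_str_to_list (imported as rn_utils.parse_str_to_list or utils.parse_str_to_list), whose implementation is not shown. Judging only from the call sites (a sep argument and an optional value_type cast), a hypothetical minimal sketch could look like this:

    def parse_str_to_list(string, value_type=str, sep=','):
        # Hypothetical reconstruction based solely on the call sites on this page:
        # split a delimited option string, strip whitespace, drop empty items,
        # and cast each item with value_type (e.g. float or int).
        return [value_type(item.strip()) for item in string.split(sep) if item.strip()]

Under this sketch, parse_str_to_list('2e-4, 1e-4', value_type=float) would yield [0.0002, 0.0001].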
Example #2
    def __init__(self, args):
        super(LossWrapper, self).__init__()
        self.apply_to = rn_utils.parse_str_to_list(args.wpr_loss_apply_to)
        self.eps = args.eps

        self.reg_type = args.wpr_loss_type
        self.weight = args.wpr_loss_weight

        self.weight_decay = args.wpr_loss_weight_decay
        self.decay_schedule = args.wpr_loss_decay_schedule
        self.num_iters = 0
Example #3
    def __init__(self, args_dict):
        super(InferenceWrapper, self).__init__()
        # Get a config for the network
        self.args = self.get_args(args_dict)
        self.to_tensor = transforms.ToTensor()

        # Load the model
        self.runner = importlib.import_module(
            f'runners.{self.args.runner_name}').RunnerWrapper(self.args,
                                                              training=False)
        self.runner.eval()

        # Directory with the checkpoints of the current experiment
        checkpoints_dir = pathlib.Path(
            self.args.project_dir
        ) / 'runs' / self.args.experiment_name / 'checkpoints'

        # Load pre-trained weights
        init_networks = (rn_utils.parse_str_to_list(self.args.init_networks)
                         if self.args.init_networks else [])
        networks_to_train = self.runner.nets_names_to_train

        if self.args.init_which_epoch != 'none' and self.args.init_experiment_dir:
            for net_name in init_networks:
                self.runner.nets[net_name].load_state_dict(
                    torch.load(pathlib.Path(self.args.init_experiment_dir) /
                               'checkpoints' /
                               f'{self.args.init_which_epoch}_{net_name}.pth',
                               map_location='cpu'))

        for net_name in networks_to_train:
            if net_name not in init_networks and net_name in self.runner.nets:
                self.runner.nets[net_name].load_state_dict(
                    torch.load(checkpoints_dir /
                               f'{self.args.which_epoch}_{net_name}.pth',
                               map_location='cpu'))

        # Stickman/facemasks drawer
        self.fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType._2D, flip_input=True)

        self.net_seg = wrapper.SegmentationWrapper(self.args)

        # Remove spectral norm to improve the performance
        self.runner.apply(rn_utils.remove_spectral_norm)

        # self.runner.apply(rn_utils.prepare_for_mobile_inference)

        if self.args.num_gpus > 0:
            self.cuda()
Example #4
    def __init__(self, args):
        super(LossWrapper, self).__init__()
        self.apply_to = [
            rn_utils.parse_str_to_list(s, sep=',')
            for s in rn_utils.parse_str_to_list(args.seg_loss_apply_to,
                                                sep=';')
        ]
        self.names = rn_utils.parse_str_to_list(args.seg_loss_names, sep=',')

        # Supported loss functions
        losses = {
            'bce': F.binary_cross_entropy_with_logits,
            # A log-ratio form of the Dice overlap between the two masks
            'dice': lambda fake_seg, real_seg: (
                torch.log((fake_seg**2).sum() + (real_seg**2).sum()) -
                torch.log((2 * fake_seg * real_seg).sum()))
        }

        self.loss = losses[args.seg_loss_type]

        self.weights = args.seg_loss_weights

        self.eps = args.eps
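
As a quick sanity check of the 'dice' term above, here is a toy evaluation with made-up tensors; the numbers are illustrative only:

    import torch

    fake_seg = torch.tensor([0.9, 0.1, 0.8])
    real_seg = torch.tensor([1.0, 0.0, 1.0])

    # log(|fake|^2 + |real|^2) - log(2 * <fake, real>)
    loss = (torch.log((fake_seg**2).sum() + (real_seg**2).sum()) -
            torch.log((2 * fake_seg * real_seg).sum()))
    # log(1.46 + 2.0) - log(3.4) ≈ 0.0175; the value tends to 0 as fake_seg approaches real_seg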
Example #5
    def load_names(self, args):
        # Initialize utility lists and dicts for the networks
        self.nets_names_to_train = utils.parse_str_to_list(args.networks_to_train)
        self.nets_names_train = utils.parse_str_to_list(args.networks_train)
        self.nets_names_test = utils.parse_str_to_list(args.networks_test)
        self.nets_names_calc_stats = utils.parse_str_to_list(args.networks_calc_stats)

        # Initialize utility lists and dicts for the losses
        self.losses_names_train = utils.parse_str_to_list(args.losses_train)
        self.losses_names_test = utils.parse_str_to_list(args.losses_test)
Example #6
    def get_args(parser):
        # Networks used in train and test
        parser.add(
            '--networks_train',
            default='identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
            help='order of forward passes during the training of gen (or gen and dis for sim sgd)')

        parser.add(
            '--networks_test',
            default='identity_embedder, texture_generator, keypoints_embedder, inference_generator',
            help='order of forward passes during testing')

        parser.add(
            '--networks_calc_stats',
            default='identity_embedder, texture_generator, keypoints_embedder, inference_generator',
            help='order of forward passes during stats calculation')

        parser.add(
            '--networks_to_train',
            default='identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
            help='names of networks that are being trained')

        # Losses used in train and test
        parser.add(
            '--losses_train',
            default='adversarial, feature_matching, perceptual, pixelwise, segmentation, warping_regularizer',
            help='losses evaluated during training')

        parser.add('--losses_test',
                   default='lpips, csim',
                   help='losses evaluated during testing')

        # Spectral norm options
        parser.add(
            '--spn_networks',
            default='identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
            help='networks to apply spectral norm to')

        parser.add(
            '--spn_exceptions',
            default='',
            help='a list of exceptional submodules that have spectral norm removed')

        parser.add('--spn_layers',
                   default='conv2d, linear',
                   help='layers to apply spectral norm to')

        # Weight averaging options
        parser.add(
            '--wgv_mode',
            default='none',
            help='none|running_average|average -- exponential moving averaging or weight averaging')

        parser.add('--wgv_momentum',
                   default=0.999,
                   type=float,
                   help='momentum value in EMA weight averaging')

        # Training options
        parser.add('--eps', default=1e-7, type=float)

        parser.add(
            '--optims',
            default='identity_embedder: adam, texture_generator: adam, keypoints_embedder: adam, inference_generator: adam, discriminator: adam',
            help='network_name: optimizer')

        parser.add(
            '--lrs',
            default='identity_embedder: 2e-4, texture_generator: 2e-4, keypoints_embedder: 2e-4, inference_generator: 2e-4, discriminator: 2e-4',
            help='learning rates for each network')

        parser.add(
            '--stats_calc_iters',
            default=100,
            type=int,
            help='number of iterations used to calculate standing statistics')

        parser.add('--num_visuals',
                   default=32,
                   type=int,
                   help='the total number of output visuals')

        parser.add('--bn_momentum',
                   default=1.0,
                   type=float,
                   help='momentum of the batchnorm layers')

        parser.add('--adam_beta1',
                   default=0.5,
                   type=float,
                   help='beta1 (momentum of the gradient) parameter for Adam')

        args, _ = parser.parse_known_args()

        # Add args from the required networks
        networks_names = list(
            set(
                utils.parse_str_to_list(args.networks_train, sep=',') +
                utils.parse_str_to_list(args.networks_test, sep=',')))
        for network_name in networks_names:
            importlib.import_module(
                f'networks.{network_name}').NetworkWrapper.get_args(parser)

        # Add args from the losses
        losses_names = list(
            set(
                utils.parse_str_to_list(args.losses_train, sep=',') +
                utils.parse_str_to_list(args.losses_test, sep=',')))
        for loss_name in losses_names:
            importlib.import_module(
                f'losses.{loss_name}').LossWrapper.get_args(parser)

        return parser
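
Note that get_args finishes by delegating to the get_args of every network and loss module listed in the options, so each component registers its own flags on the shared parser. None of those modules appear on this page; a hypothetical stub for networks/identity_embedder.py following the same pattern (the flag name and default below are invented for illustration) might be:

    from torch import nn

    class NetworkWrapper(nn.Module):
        @staticmethod
        def get_args(parser):
            # Illustrative option only; the real module defines its own flags.
            parser.add('--emb_num_channels', default=64, type=int,
                       help='base number of channels in the embedder')
            return parser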
Example #7
    def __init__(self, args, training=True):
        super(RunnerWrapper, self).__init__()
        # Store general options
        self.args = args
        self.training = training

        # Read names lists from the args
        self.load_names(args)

        # Initialize classes for the networks
        nets_names = list(self.nets_names_test)  # copy so the += below does not mutate the test list
        if self.training:
            nets_names += self.nets_names_train
        nets_names = list(set(nets_names))

        self.nets = nn.ModuleDict()

        for net_name in sorted(nets_names):
            self.nets[net_name] = importlib.import_module(
                f'networks.{net_name}').NetworkWrapper(args)

            if args.num_gpus > 1:
                # Apex is only needed for multi-gpu training
                from apex import parallel

                self.nets[net_name] = parallel.convert_syncbn_model(
                    self.nets[net_name])

        # Set nets that are not training into eval mode
        for net_name in self.nets.keys():
            if net_name not in self.nets_names_to_train:
                self.nets[net_name].eval()

        # Initialize classes for the losses
        if self.training:
            losses_names = list(
                set(self.losses_names_train + self.losses_names_test))
            self.losses = nn.ModuleDict()

            for loss_name in sorted(losses_names):
                self.losses[loss_name] = importlib.import_module(
                    f'losses.{loss_name}').LossWrapper(args)

        # Spectral norm
        if args.spn_layers:
            spn_layers = utils.parse_str_to_list(args.spn_layers, sep=',')
            spn_nets_names = utils.parse_str_to_list(args.spn_networks,
                                                     sep=',')

            for net_name in spn_nets_names:
                self.nets[net_name].apply(lambda module: utils.spectral_norm(
                    module, apply_to=spn_layers, eps=args.eps))

            # Remove spectral norm in modules in exceptions
            spn_exceptions = utils.parse_str_to_list(args.spn_exceptions,
                                                     sep=',')

            for full_module_name in spn_exceptions:
                if not full_module_name:
                    continue

                parts = full_module_name.split('.')

                # Get the module that needs to be changed
                module = self.nets[parts[0]]
                for part in parts[1:]:
                    module = getattr(module, part)

                module.apply(utils.remove_spectral_norm)

        # Weight averaging
        if args.wgv_mode != 'none':
            # Apply weight averaging only for networks that are being trained
            for net_name in self.nets_names_to_train:
                self.nets[net_name].apply(
                    lambda module: utils.weight_averaging(
                        module, mode=args.wgv_mode, momentum=args.wgv_momentum)
                )

        # Check which networks are being trained and put the rest into the eval mode
        for net_name in self.nets.keys():
            if net_name not in self.nets_names_to_train:
                self.nets[net_name].eval()

        # Set the same batchnorm momentum across all modules
        if self.training:
            self.apply(lambda module: utils.set_batchnorm_momentum(
                module, args.bn_momentum))

        # Store a history of losses and images for visualization
        self.losses_history = {
            True: {},  # self.training = True
            False: {}
        }
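
The utils.spectral_norm helper used above is not shown on this page. Given how it is called (through nn.Module.apply, with an apply_to list such as ['conv2d', 'linear'] and an eps value), a plausible sketch built on PyTorch's torch.nn.utils.spectral_norm could look like the following; the mapping from layer names to classes is an assumption:

    import torch
    from torch import nn

    def spectral_norm(module, apply_to=('conv2d', 'linear'), eps=1e-7):
        # Hypothetical helper: wrap spectral normalization around layers whose
        # type matches one of the names in apply_to. Meant to be used with
        # nn.Module.apply, which visits every submodule in place.
        layer_types = {'conv2d': nn.Conv2d, 'linear': nn.Linear}
        for name in apply_to:
            layer_type = layer_types.get(name)
            if layer_type is not None and isinstance(module, layer_type):
                torch.nn.utils.spectral_norm(module, eps=eps)
        return module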
Example #8
    def __init__(self, args, runner=None):
        super(TrainingWrapper, self).__init__()
        # Initialize and apply general options
        ssl._create_default_https_context = ssl._create_unverified_context
        torch.backends.cudnn.benchmark = True
        torch.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)

        # Set distributed training options
        if args.num_gpus > 1 and args.num_gpus <= 8:
            args.rank = args.local_rank
            args.world_size = args.num_gpus
            torch.cuda.set_device(args.local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')

        elif args.num_gpus > 8:
            raise NotImplementedError('training with more than 8 GPUs is not supported')

        # Prepare experiment directories and save options
        project_dir = pathlib.Path(args.project_dir)
        self.checkpoints_dir = project_dir / 'runs' / args.experiment_name / 'checkpoints'

        # Store options
        if not args.no_disk_write_ops:
            os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.experiment_dir = project_dir / 'runs' / args.experiment_name

        if not args.no_disk_write_ops:
            # Redirect stdout
            if args.redirect_print_to_file:
                logs_dir = self.experiment_dir / 'logs'
                os.makedirs(logs_dir, exist_ok=True)
                sys.stdout = open(os.path.join(logs_dir, f'stdout_{args.rank}.txt'), 'w')
                sys.stderr = open(os.path.join(logs_dir, f'stderr_{args.rank}.txt'), 'w')

            if args.rank == 0:
                print(args)
                with open(self.experiment_dir / 'args.txt', 'wt') as args_file:
                    for k, v in sorted(vars(args).items()):
                        args_file.write('%s: %s\n' % (str(k), str(v)))

        # Initialize model
        self.runner = runner

        if self.runner is None:
            self.runner = importlib.import_module(f'runners.{args.runner_name}').RunnerWrapper(args)

        # Load pre-trained weights (if needed)
        init_networks = rn_utils.parse_str_to_list(args.init_networks) if args.init_networks else []
        networks_to_train = self.runner.nets_names_to_train

        if args.init_which_epoch != 'none' and args.init_experiment_dir:
            for net_name in init_networks:
                self.runner.nets[net_name].load_state_dict(torch.load(
                    pathlib.Path(args.init_experiment_dir) / 'checkpoints' /
                    f'{args.init_which_epoch}_{net_name}.pth',
                    map_location='cpu'))

        if args.which_epoch != 'none':
            for net_name in networks_to_train:
                if net_name not in init_networks:
                    self.runner.nets[net_name].load_state_dict(torch.load(
                        self.checkpoints_dir / f'{args.which_epoch}_{net_name}.pth',
                        map_location='cpu'))

        if args.num_gpus > 0:
            self.runner.cuda()

        if args.rank == 0:
            print(self.runner)
Example #9
    def __init__(self, args):
        super(LossWrapper, self).__init__()
        ### Define losses ###
        losses = {
            'mse': F.mse_loss,
            'l1': F.l1_loss}

        self.loss = losses[args.per_loss_type]

        # Weights for each feature extractor
        self.weights = rn_utils.parse_str_to_list(args.per_loss_weights, value_type=float, sep=',')
        self.layer_weights = rn_utils.parse_str_to_list(args.per_layer_weights, value_type=float, sep=',')
        self.names = [rn_utils.parse_str_to_list(s, sep=',') for s in rn_utils.parse_str_to_list(args.per_loss_names, sep=';')]

        ### Define extractors ###
        self.apply_to = [rn_utils.parse_str_to_list(s, sep=',') for s in rn_utils.parse_str_to_list(args.per_loss_apply_to, sep=';')]
        weights_dir = pathlib.Path(args.project_dir) / 'pretrained_weights' / 'perceptual'

        # Architectures for the supported networks 
        networks = {
            'vgg16': models.vgg16,
            'vgg19': models.vgg19}

        # Build a list of used networks
        self.nets = nn.ModuleList()
        self.full_net_names = rn_utils.parse_str_to_list(args.per_full_net_names, sep=',')

        for full_net_name in self.full_net_names:
            net_name, dataset_name, framework_name = full_net_name.split('_')

            if dataset_name == 'imagenet' and framework_name == 'pytorch':
                self.nets.append(networks[net_name](pretrained=True))
                mean = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] * 2 - 1
                std  = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] * 2
            
            elif framework_name == 'caffe':
                self.nets.append(networks[net_name]())
                self.nets[-1].load_state_dict(torch.load(weights_dir / f'{full_net_name}.pth'))
                mean = torch.FloatTensor([103.939, 116.779, 123.680])[None, :, None, None] / 127.5 - 1
                std  = torch.FloatTensor([     1.,      1.,      1.])[None, :, None, None] / 127.5
            
            # Register means and stds as buffers
            self.register_buffer(f'{full_net_name}_mean', mean)
            self.register_buffer(f'{full_net_name}_std', std)

        # Perform the slicing according to the required layers
        for n, (net, block_idx) in enumerate(zip(self.nets, rn_utils.parse_str_to_list(args.per_net_layers, sep=';'))):
            net_blocks = nn.ModuleList()

            # Parse indices of slices (the last entry may be 'fc')
            block_idx = rn_utils.parse_str_to_list(block_idx, sep=',')
            for i, idx in enumerate(block_idx):
                if idx != 'fc':
                    block_idx[i] = int(idx)

            # Slice conv blocks
            layers = []
            for i, layer in enumerate(net.features):
                if layer.__class__.__name__ == 'MaxPool2d' and args.per_pooling == 'avgpool':
                    layer = nn.AvgPool2d(2)
                layers.append(layer)
                if i in block_idx:
                    net_blocks.append(nn.Sequential(*layers))
                    layers = []

            # Add layers for prediction of the scores (if needed)
            if block_idx[-1] == 'fc':
                layers.extend([
                    nn.AdaptiveAvgPool2d(7),
                    utils.Flatten(1)])
                for layer in net.classifier:
                    layers.append(layer)
                net_blocks.append(nn.Sequential(*layers))

            # Store sliced net
            self.nets[n] = net_blocks
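
The forward pass of this perceptual loss is not part of the example. A minimal sketch of how the sliced blocks and normalization buffers prepared above might be consumed (assuming inputs in [-1, 1] and ignoring the per-layer weights and the apply_to routing) could be:

    def forward(self, fake_imgs, real_imgs):
        # Hypothetical sketch: normalize both images with the stored buffers,
        # run them through each sliced block and penalize feature differences.
        total = 0
        for net, full_net_name, weight in zip(self.nets, self.full_net_names, self.weights):
            mean = getattr(self, f'{full_net_name}_mean')
            std = getattr(self, f'{full_net_name}_std')
            fake_feats = (fake_imgs - mean) / std
            real_feats = (real_imgs - mean) / std
            for block in net:
                fake_feats = block(fake_feats)
                real_feats = block(real_feats).detach()
                total = total + weight * self.loss(fake_feats, real_feats)
        return total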