def __init__(self, args):
    """Pixelwise loss: compares pairs of tensors with MSE, L1 or cross-entropy.

    Reads its configuration from ``args`` (``pix_loss_*`` options).
    """
    super(LossWrapper, self).__init__()

    # Each ';'-separated group in pix_loss_apply_to is itself a
    # ','-separated list of tensor names.
    groups = rn_utils.parse_str_to_list(args.pix_loss_apply_to, sep=';')
    self.apply_to = [rn_utils.parse_str_to_list(group, sep=',') for group in groups]

    # Supported loss functions
    supported_losses = {
        'mse': F.mse_loss,
        'l1': F.l1_loss,
        'ce': F.cross_entropy,
    }
    self.loss = supported_losses[args.pix_loss_type]

    # Weights for each feature extractor, and display names for logging
    self.weights = rn_utils.parse_str_to_list(args.pix_loss_weights, value_type=float)
    self.names = rn_utils.parse_str_to_list(args.pix_loss_names)
def __init__(self, args):
    """Warping-regularizer loss: stores its configuration from ``args``.

    Reads the ``wpr_loss_*`` options; the actual regularization is applied
    elsewhere using these attributes.
    """
    super(LossWrapper, self).__init__()

    # Tensor names the regularizer is applied to
    self.apply_to = rn_utils.parse_str_to_list(args.wpr_loss_apply_to)

    # Numerical epsilon and the kind of regularization to use
    self.eps = args.eps
    self.reg_type = args.wpr_loss_type

    # Loss weight with an optional decay schedule over iterations
    self.weight = args.wpr_loss_weight
    self.weight_decay = args.wpr_loss_weight_decay
    self.decay_schedule = args.wpr_loss_decay_schedule

    # Iteration counter consumed by the decay schedule
    self.num_iters = 0
def __init__(self, args_dict): super(InferenceWrapper, self).__init__() # Get a config for the network self.args = self.get_args(args_dict) self.to_tensor = transforms.ToTensor() # Load the model self.runner = importlib.import_module( f'runners.{self.args.runner_name}').RunnerWrapper(self.args, training=False) self.runner.eval() # Load pretrained weights checkpoints_dir = pathlib.Path( self.args.project_dir ) / 'runs' / self.args.experiment_name / 'checkpoints' # Load pre-trained weights init_networks = rn_utils.parse_str_to_list( self.args.init_networks) if self.args.init_networks else {} networks_to_train = self.runner.nets_names_to_train if self.args.init_which_epoch != 'none' and self.args.init_experiment_dir: for net_name in init_networks: self.runner.nets[net_name].load_state_dict( torch.load(pathlib.Path(self.args.init_experiment_dir) / 'checkpoints' / f'{self.args.init_which_epoch}_{net_name}.pth', map_location='cpu')) for net_name in networks_to_train: if net_name not in init_networks and net_name in self.runner.nets.keys( ): self.runner.nets[net_name].load_state_dict( torch.load(checkpoints_dir / f'{self.args.which_epoch}_{net_name}.pth', map_location='cpu')) # Stickman/facemasks drawer self.fa = face_alignment.FaceAlignment( face_alignment.LandmarksType._2D, flip_input=True) self.net_seg = wrapper.SegmentationWrapper(self.args) # Remove spectral norm to improve the performance self.runner.apply(rn_utils.remove_spectral_norm) # self.runner.apply(rn_utils.prepare_for_mobile_inference) if self.args.num_gpus > 0: self.cuda()
def __init__(self, args):
    """Segmentation loss: BCE-with-logits or a log-dice objective on masks.

    Reads its configuration from ``args`` (``seg_loss_*`` options).
    """
    super(LossWrapper, self).__init__()

    # Each ';'-separated group in seg_loss_apply_to is a ','-separated
    # list of tensor names.
    self.apply_to = [
        rn_utils.parse_str_to_list(group, sep=',')
        for group in rn_utils.parse_str_to_list(args.seg_loss_apply_to, sep=';')
    ]
    self.names = rn_utils.parse_str_to_list(args.seg_loss_names, sep=',')

    def _log_dice(fake_seg, real_seg):
        # log((|fake|^2 + |real|^2) / (2 * <fake, real>)): decreases as the
        # two masks overlap more.
        sum_of_squares = (fake_seg ** 2).sum() + (real_seg ** 2).sum()
        overlap = (2 * fake_seg * real_seg).sum()
        return torch.log(sum_of_squares) - torch.log(overlap)

    # Supported loss functions
    supported_losses = {
        'bce': F.binary_cross_entropy_with_logits,
        'dice': _log_dice,
    }
    self.loss = supported_losses[args.seg_loss_type]

    self.weights = args.seg_loss_weights
    self.eps = args.eps
def load_names(self, args):
    """Parse the comma-separated network and loss name lists from ``args``.

    Stores each parsed list as an attribute for later lookups.
    """
    parse = utils.parse_str_to_list

    # Network name lists for the different phases
    self.nets_names_to_train = parse(args.networks_to_train)
    self.nets_names_train = parse(args.networks_train)
    self.nets_names_test = parse(args.networks_test)
    self.nets_names_calc_stats = parse(args.networks_calc_stats)

    # Loss name lists for training and testing
    self.losses_names_train = parse(args.losses_train)
    self.losses_names_test = parse(args.losses_test)
def get_args(parser):
    # Registers the runner-level command-line options on `parser`, then
    # imports every network and loss module named in the (partially parsed)
    # args so each can register its own options. Returns the parser.

    # Networks used in train and test
    parser.add(
        '--networks_train',
        default=
        'identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
        help=
        'order of forward passes during the training of gen (or gen and dis for sim sgd)'
    )
    parser.add(
        '--networks_test',
        default=
        'identity_embedder, texture_generator, keypoints_embedder, inference_generator',
        help='order of forward passes during testing')
    parser.add(
        '--networks_calc_stats',
        default=
        'identity_embedder, texture_generator, keypoints_embedder, inference_generator',
        help='order of forward passes during stats calculation')
    parser.add(
        '--networks_to_train',
        default=
        'identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
        help='names of networks that are being trained')

    # Losses used in train and test
    parser.add(
        '--losses_train',
        default=
        'adversarial, feature_matching, perceptual, pixelwise, segmentation, warping_regularizer',
        help='losses evaluated during training')
    parser.add('--losses_test',
               default='lpips, csim',
               help='losses evaluated during testing')

    # Spectral norm options
    parser.add(
        '--spn_networks',
        default=
        'identity_embedder, texture_generator, keypoints_embedder, inference_generator, discriminator',
        help='networks to apply spectral norm to')
    parser.add(
        '--spn_exceptions',
        default='',
        help=
        'a list of exceptional submodules that have spectral norm removed')
    parser.add('--spn_layers',
               default='conv2d, linear',
               help='layers to apply spectral norm to')

    # Weight averaging options
    parser.add(
        '--wgv_mode',
        default='none',
        help=
        'none|running_average|average -- exponential moving averaging or weight averaging'
    )
    parser.add('--wgv_momentum',
               default=0.999,
               type=float,
               help='momentum value in EMA weight averaging')

    # Training options
    parser.add('--eps', default=1e-7, type=float)
    parser.add(
        '--optims',
        default=
        'identity_embedder: adam, texture_generator: adam, keypoints_embedder: adam, inference_generator: adam, discriminator: adam',
        help='network_name: optimizer')
    parser.add(
        '--lrs',
        default=
        'identity_embedder: 2e-4, texture_generator: 2e-4, keypoints_embedder: 2e-4, inference_generator: 2e-4, discriminator: 2e-4',
        help='learning rates for each network')
    parser.add(
        '--stats_calc_iters',
        default=100,
        type=int,
        help='number of iterations used to calculate standing statistics')
    parser.add('--num_visuals',
               default=32,
               type=int,
               help='the total number of output visuals')
    parser.add('--bn_momentum',
               default=1.0,
               type=float,
               help='momentum of the batchnorm layers')
    parser.add('--adam_beta1',
               default=0.5,
               type=float,
               help='beta1 (momentum of the gradient) parameter for Adam')

    # Parse what is known so far to discover which modules to pull in
    args, _ = parser.parse_known_args()

    # Add args from the required networks
    networks_names = list(
        set(
            utils.parse_str_to_list(args.networks_train, sep=',') +
            utils.parse_str_to_list(args.networks_test, sep=',')))
    for network_name in networks_names:
        importlib.import_module(
            f'networks.{network_name}').NetworkWrapper.get_args(parser)

    # Add args from the losses
    losses_names = list(
        set(
            utils.parse_str_to_list(args.losses_train, sep=',') +
            utils.parse_str_to_list(args.losses_test, sep=',')))
    for loss_name in losses_names:
        importlib.import_module(
            f'losses.{loss_name}').LossWrapper.get_args(parser)

    return parser
def __init__(self, args, training=True):
    """Build all networks and (when training) losses for a run.

    Constructs the required NetworkWrapper/LossWrapper modules by name,
    applies spectral norm and optional weight averaging, and puts
    non-trained networks into eval mode.
    """
    super(RunnerWrapper, self).__init__()

    # Store general options
    self.args = args
    self.training = training

    # Read names lists from the args
    self.load_names(args)

    # Initialize classes for the networks.
    # BUG FIX: the original aliased `nets_names = self.nets_names_test` and
    # then used `+=`, which extended self.nets_names_test in place; copy
    # the list instead.
    nets_names = list(self.nets_names_test)
    if self.training:
        nets_names += self.nets_names_train
    nets_names = list(set(nets_names))

    self.nets = nn.ModuleDict()
    for net_name in sorted(nets_names):
        self.nets[net_name] = importlib.import_module(
            f'networks.{net_name}').NetworkWrapper(args)

        if args.num_gpus > 1:
            # Apex is only needed for multi-gpu training
            from apex import parallel
            self.nets[net_name] = parallel.convert_syncbn_model(
                self.nets[net_name])

    # Set nets that are not training into eval mode
    for net_name in self.nets.keys():
        if net_name not in self.nets_names_to_train:
            self.nets[net_name].eval()

    # Initialize classes for the losses (training only)
    if self.training:
        losses_names = list(
            set(self.losses_names_train + self.losses_names_test))
        self.losses = nn.ModuleDict()
        for loss_name in sorted(losses_names):
            self.losses[loss_name] = importlib.import_module(
                f'losses.{loss_name}').LossWrapper(args)

    # Spectral norm
    if args.spn_layers:
        spn_layers = utils.parse_str_to_list(args.spn_layers, sep=',')
        spn_nets_names = utils.parse_str_to_list(args.spn_networks, sep=',')

        for net_name in spn_nets_names:
            self.nets[net_name].apply(lambda module: utils.spectral_norm(
                module, apply_to=spn_layers, eps=args.eps))

        # Remove spectral norm in modules in exceptions
        spn_exceptions = utils.parse_str_to_list(args.spn_exceptions, sep=',')
        for full_module_name in spn_exceptions:
            if not full_module_name:
                continue
            parts = full_module_name.split('.')
            # Walk the attribute path to the module that needs to be changed
            module = self.nets[parts[0]]
            for part in parts[1:]:
                module = getattr(module, part)
            module.apply(utils.remove_spectral_norm)

    # Weight averaging
    if args.wgv_mode != 'none':
        # Apply weight averaging only for networks that are being trained.
        # BUG FIX: the original iterated `for net_name, _ in
        # self.nets_names_to_train`, tuple-unpacking each name string into
        # its characters -- it raises for any name whose length != 2.
        for net_name in self.nets_names_to_train:
            self.nets[net_name].apply(
                lambda module: utils.weight_averaging(
                    module, mode=args.wgv_mode, momentum=args.wgv_momentum))

    # Check which networks are being trained and put the rest into the eval mode
    for net_name in self.nets.keys():
        if net_name not in self.nets_names_to_train:
            self.nets[net_name].eval()

    # Set the same batchnorm momentum accross all modules
    if self.training:
        self.apply(lambda module: utils.set_batchnorm_momentum(
            module, args.bn_momentum))

    # Store a history of losses for visualization, keyed by training mode
    self.losses_history = {
        True: {},   # self.training = True
        False: {},  # evaluation
    }
def __init__(self, args, runner=None):
    """Set up the training environment: seeds, distributed options,
    experiment directories, logging redirection, and the runner with its
    pretrained weights.
    """
    super(TrainingWrapper, self).__init__()

    # Initialize and apply general options
    ssl._create_default_https_context = ssl._create_unverified_context
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)

    # Set distributed training options
    if args.num_gpus > 1 and args.num_gpus <= 8:
        args.rank = args.local_rank
        args.world_size = args.num_gpus
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    elif args.num_gpus > 8:
        # BUG FIX: was a bare `raise` with no active exception, which
        # produces RuntimeError('No active exception to re-raise') instead
        # of a meaningful error.
        raise NotImplementedError(
            f'Training on more than 8 GPUs is not supported '
            f'(got num_gpus={args.num_gpus})')
    # NOTE(review): for num_gpus <= 1 nothing sets args.rank here, yet it is
    # read below -- assumed to have a default in the parser; confirm.

    # Prepare experiment directories and save options
    project_dir = pathlib.Path(args.project_dir)
    self.checkpoints_dir = project_dir / 'runs' / args.experiment_name / 'checkpoints'

    # Store options
    if not args.no_disk_write_ops:
        os.makedirs(self.checkpoints_dir, exist_ok=True)

    self.experiment_dir = project_dir / 'runs' / args.experiment_name

    if not args.no_disk_write_ops:
        # Redirect stdout/stderr into per-rank log files
        if args.redirect_print_to_file:
            logs_dir = self.experiment_dir / 'logs'
            os.makedirs(logs_dir, exist_ok=True)
            sys.stdout = open(os.path.join(logs_dir, f'stdout_{args.rank}.txt'), 'w')
            sys.stderr = open(os.path.join(logs_dir, f'stderr_{args.rank}.txt'), 'w')

        # Only the main process dumps the full options
        if args.rank == 0:
            print(args)
            with open(self.experiment_dir / 'args.txt', 'wt') as args_file:
                for k, v in sorted(vars(args).items()):
                    args_file.write('%s: %s\n' % (str(k), str(v)))

    # Initialize model (construct a runner unless one was injected)
    self.runner = runner
    if self.runner is None:
        self.runner = importlib.import_module(
            f'runners.{args.runner_name}').RunnerWrapper(args)

    # Load pre-trained weights (if needed)
    init_networks = rn_utils.parse_str_to_list(
        args.init_networks) if args.init_networks else {}
    networks_to_train = self.runner.nets_names_to_train

    # Networks explicitly initialized from another experiment
    if args.init_which_epoch != 'none' and args.init_experiment_dir:
        for net_name in init_networks:
            self.runner.nets[net_name].load_state_dict(
                torch.load(pathlib.Path(args.init_experiment_dir) /
                           'checkpoints' /
                           f'{args.init_which_epoch}_{net_name}.pth',
                           map_location='cpu'))

    # Remaining trainable networks resume from this experiment's checkpoints
    if args.which_epoch != 'none':
        for net_name in networks_to_train:
            if net_name not in init_networks:
                self.runner.nets[net_name].load_state_dict(
                    torch.load(self.checkpoints_dir /
                               f'{args.which_epoch}_{net_name}.pth',
                               map_location='cpu'))

    if args.num_gpus > 0:
        self.runner.cuda()

    if args.rank == 0:
        print(self.runner)
def __init__(self, args):
    """Perceptual loss over features of pretrained VGG-style extractors.

    Builds the configured feature networks, registers their input
    normalization stats as buffers, and slices each network's feature stack
    into blocks at the configured layer indices.
    """
    super(LossWrapper, self).__init__()

    ### Define losses ###
    losses = {
        'mse': F.mse_loss,
        'l1': F.l1_loss}

    self.loss = losses[args.per_loss_type]

    # Weights for each feature extractor and for each sliced layer
    self.weights = rn_utils.parse_str_to_list(args.per_loss_weights, value_type=float, sep=',')
    self.layer_weights = rn_utils.parse_str_to_list(args.per_layer_weights, value_type=float, sep=',')
    self.names = [rn_utils.parse_str_to_list(s, sep=',') for s in rn_utils.parse_str_to_list(args.per_loss_names, sep=';')]

    ### Define extractors ###
    self.apply_to = [rn_utils.parse_str_to_list(s, sep=',') for s in rn_utils.parse_str_to_list(args.per_loss_apply_to, sep=';')]
    weights_dir = pathlib.Path(args.project_dir) / 'pretrained_weights' / 'perceptual'

    # Architectures for the supported networks
    networks = {
        'vgg16': models.vgg16,
        'vgg19': models.vgg19}

    # Build a list of used networks
    self.nets = nn.ModuleList()
    self.full_net_names = rn_utils.parse_str_to_list(args.per_full_net_names, sep=',')

    for full_net_name in self.full_net_names:
        net_name, dataset_name, framework_name = full_net_name.split('_')

        if dataset_name == 'imagenet' and framework_name == 'pytorch':
            self.nets.append(networks[net_name](pretrained=True))
            # ImageNet stats remapped from [0, 1] to the [-1, 1] input range
            mean = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] * 2 - 1
            std = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] * 2
        elif framework_name == 'caffe':
            self.nets.append(networks[net_name]())
            self.nets[-1].load_state_dict(torch.load(weights_dir / f'{full_net_name}.pth'))
            # CLEANUP: removed a no-op `self.nets[-1] = self.nets[-1]`.
            # Caffe stats ([0, 255] range) remapped to the [-1, 1] input range
            mean = torch.FloatTensor([103.939, 116.779, 123.680])[None, :, None, None] / 127.5 - 1
            std = torch.FloatTensor([1., 1., 1.])[None, :, None, None] / 127.5

        # Register means and stds as buffers
        self.register_buffer(f'{full_net_name}_mean', mean)
        self.register_buffer(f'{full_net_name}_std', std)

    # Perform the slicing according to the required layers
    for n, (net, block_idx) in enumerate(zip(
            self.nets,
            rn_utils.parse_str_to_list(args.per_net_layers, sep=';'))):
        net_blocks = nn.ModuleList()

        # Parse indices of slices.
        # CLEANUP: removed a no-op loop that reassigned each element of
        # block_idx to itself.
        block_idx = rn_utils.parse_str_to_list(block_idx, value_type=int, sep=',')

        # Slice conv blocks at the configured feature indices
        layers = []
        for i, layer in enumerate(net.features):
            if layer.__class__.__name__ == 'MaxPool2d' and args.per_pooling == 'avgpool':
                layer = nn.AvgPool2d(2)
            layers.append(layer)
            if i in block_idx:
                net_blocks.append(nn.Sequential(*layers))
                layers = []

        # Add layers for prediction of the scores (if needed).
        # NOTE(review): block_idx was parsed with value_type=int, so this
        # 'fc' comparison can only hold if parse_str_to_list passes
        # non-numeric tokens through unconverted -- confirm.
        if block_idx[-1] == 'fc':
            layers.extend([
                nn.AdaptiveAvgPool2d(7),
                utils.Flatten(1)])
            for layer in net.classifier:
                layers.append(layer)
            net_blocks.append(nn.Sequential(*layers))

        # Store sliced net
        self.nets[n] = net_blocks