def __init__(self, dis_cfg, data_cfg):
    r"""Set up the image, additional and temporal patch discriminators.

    Args:
        dis_cfg (obj): Discriminator config. ``image`` configures the main
            discriminator, the optional ``additional_discriminators``
            mapping configures one extra discriminator per entry, and
            ``temporal`` (read only when ``temporal.num_scales`` > 0)
            configures the temporal discriminators.
        data_cfg (obj): Data config. ``num_frames_D`` and ``type`` are
            required; ``label_channels`` and ``has_foreground`` are
            optional.
    """
    super().__init__()
    self.data_cfg = data_cfg

    # Fall back to the configured label channel count (default 1) when the
    # paired-input helper reports no label channels.
    label_channels = get_paired_input_label_channel_number(data_cfg)
    if label_channels == 0:
        label_channels = getattr(data_cfg, 'label_channels', 1)
    image_channels = get_paired_input_image_channel_number(data_cfg)

    self.num_frames_D = data_cfg.num_frames_D
    self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)

    # Few-shot datasets double every per-frame channel count.
    # NOTE(review): presumably because a reference sample is concatenated
    # to the input - confirm against the forward pass.
    self.use_few_shot = 'few_shot' in data_cfg.type
    channel_multiplier = 2 if self.use_few_shot else 1

    # Main image discriminator: sees labels and image together.
    self.net_D = MultiPatchDiscriminator(
        dis_cfg.image,
        (label_channels + image_channels) * channel_multiplier)

    # Additional discriminators: image channels only, one per config entry,
    # registered as ``net_D_<name>``.
    self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
    if self.add_dis_cfg is not None:
        extra_channels = image_channels * channel_multiplier
        for name in self.add_dis_cfg:
            extra_net = MultiPatchDiscriminator(self.add_dis_cfg[name],
                                                extra_channels)
            setattr(self, 'net_D_' + name, extra_net)

    # Temporal discriminator: one per scale, registered as ``net_DT<n>``,
    # each fed ``num_frames_D`` frames' worth of image channels.
    self.num_netDT_input_channels = image_channels * self.num_frames_D
    for scale in range(self.num_scales):
        temporal_net = MultiPatchDiscriminator(
            dis_cfg.temporal, self.num_netDT_input_channels)
        setattr(self, 'net_DT%d' % scale, temporal_net)

    self.has_fg = getattr(data_cfg, 'has_foreground', False)
def __init__(self, dis_cfg, data_cfg):
    r"""Build the multi-resolution patch discriminator.

    Args:
        dis_cfg (obj): Discriminator config. All fields are optional:
            ``kernel_size`` (3), ``num_filters`` (128),
            ``max_num_filters`` (512), ``num_discriminators`` (2),
            ``num_layers`` (5), ``activation_norm_type`` ('none') and
            ``weight_norm_type`` ('spectral').
        data_cfg (obj): Data config, used to count the image and label
            channels of the paired input.
    """
    super(Discriminator, self).__init__()
    print('Multi-resolution patch discriminator initialization.')

    def option(name, default):
        # Read an optional hyper-parameter from the discriminator config.
        return getattr(dis_cfg, name, default)

    # The discriminator sees the image and its label map together, so its
    # input width is the sum of both channel counts.
    num_input_channels = (
        get_paired_input_image_channel_number(data_cfg)
        + get_paired_input_label_channel_number(data_cfg))

    # Hyper-parameters, each with its default.
    kernel_size = option('kernel_size', 3)
    base_filters = option('num_filters', 128)
    filter_cap = option('max_num_filters', 512)
    discriminator_count = option('num_discriminators', 2)
    layer_count = option('num_layers', 5)
    act_norm = option('activation_norm_type', 'none')
    wt_norm = option('weight_norm_type', 'spectral')

    print('\tBase filter number: %d' % base_filters)
    print('\tNumber of discriminators: %d' % discriminator_count)
    print('\tNumber of layers in a discriminator: %d' % layer_count)
    print('\tWeight norm type: %s' % wt_norm)

    self.model = MultiResPatchDiscriminator(
        discriminator_count,
        kernel_size,
        num_input_channels,
        base_filters,
        layer_count,
        filter_cap,
        act_norm,
        wt_norm)
    print('Done with the Multi-resolution patch '
          'discriminator initialization.')
def __init__(self, dis_cfg, data_cfg):
    r"""Build the N-layer patch discriminators and the FPSE discriminator.

    Args:
        dis_cfg (obj): Discriminator config. All fields are optional:
            ``kernel_size`` (3), ``num_filters`` (128),
            ``max_num_filters`` (512), ``num_discriminators`` (2),
            ``num_layers`` (5), ``activation_norm_type`` ('none'),
            ``weight_norm_type`` ('spectral'), ``fpse_kernel_size`` (3)
            and ``fpse_activation_norm_type`` ('none').
        data_cfg (obj): Data config. ``type`` selects how label channels
            are counted; the paired-input helpers read the rest.
    """
    super(Discriminator, self).__init__()
    print('Multi-resolution patch discriminator initialization.')

    def option(name, default):
        # Read an optional hyper-parameter from the discriminator config.
        return getattr(dis_cfg, name, default)

    # We assume the first datum is the ground truth image.
    image_channels = get_paired_input_image_channel_number(data_cfg)

    # Paired-video datasets ask the helper for the video label count; every
    # other dataset type uses the helper's default behaviour.
    is_paired_video = data_cfg.type == 'imaginaire.datasets.paired_videos'
    label_kwargs = {'video': True} if is_paired_video else {}
    num_labels = get_paired_input_label_channel_number(data_cfg,
                                                       **label_kwargs)

    # Hyper-parameters, each with its default.
    kernel_size = option('kernel_size', 3)
    base_filters = option('num_filters', 128)
    filter_cap = option('max_num_filters', 512)
    discriminator_count = option('num_discriminators', 2)
    layer_count = option('num_layers', 5)
    act_norm = option('activation_norm_type', 'none')
    wt_norm = option('weight_norm_type', 'spectral')

    print('\tBase filter number: %d' % base_filters)
    print('\tNumber of discriminators: %d' % discriminator_count)
    print('\tNumber of layers in a discriminator: %d' % layer_count)
    print('\tWeight norm type: %s' % wt_norm)

    # Identically configured patch discriminators over image + label input.
    patch_input_channels = image_channels + num_labels
    self.discriminators = nn.ModuleList()
    for _ in range(discriminator_count):
        self.discriminators.append(
            NLayerPatchDiscriminator(kernel_size,
                                     patch_input_channels,
                                     base_filters,
                                     layer_count,
                                     filter_cap,
                                     act_norm,
                                     wt_norm))
    print('Done with the Multi-resolution patch '
          'discriminator initialization.')

    # FPSE discriminator: shares the base filter count and weight norm, but
    # has its own kernel size and activation norm settings.
    fpse_kernel_size = option('fpse_kernel_size', 3)
    fpse_act_norm = option('fpse_activation_norm_type', 'none')
    self.fpse_discriminator = FPSEDiscriminator(
        image_channels,
        num_labels,
        base_filters,
        fpse_kernel_size,
        wt_norm,
        fpse_act_norm)
def cluster_features(cfg, train_data_loader, net_E,
                     preprocess=None, small_ratio=0.0625,
                     is_cityscapes=True):
    r"""Cluster encoded instance features and store the centers in ``net_E``.

    Every batch of the training set is encoded with ``encode_features``.
    On the master process the per-label feature rows are accumulated, rows
    whose last column (the occupancy ratio) is not above ``small_ratio`` are
    dropped, and k-means is run per label. The resulting centers are written
    in place into the ``cluster_<label>`` attribute of ``net_E``.

    Args:
        cfg (obj): Global configuration. Reads ``cfg.data``,
            ``cfg.gen.enc.num_feat_channels`` and the optional
            ``cfg.gen.enc.num_clusters`` (default 10).
        train_data_loader (obj): Iterable over the training set. Each item
            (after ``preprocess``) must provide ``'images'`` and
            ``'instance_maps'``.
        net_E (nn.Module): Encoder network; must expose one
            ``cluster_<label>`` tensor per label.
        preprocess (function): Optional pre-processing applied to each batch.
        small_ratio (float): Only instances whose occupancy ratio is
            strictly greater than this value are clustered.
        is_cityscapes (bool): Whether this is the Cityscapes dataset; it is
            forwarded unchanged to ``encode_features``. In Cityscapes the
            instance labels for car start with 26001, 26002, ...

    Returns:
        None. The cluster centers are written into the buffers of ``net_E``.
    """
    label_nc = get_paired_input_label_channel_number(cfg.data)
    feat_nc = cfg.gen.enc.num_feat_channels
    max_clusters = getattr(cfg.gen.enc, 'num_clusters', 10)

    # Per-label list of feature chunks. Each list is seeded with an empty
    # (0, feat_nc + 1) array so the final concatenation is well defined even
    # when nothing was collected. Chunks are joined once at the end instead
    # of re-copying the whole array for every batch.
    chunks = {label: [np.zeros((0, feat_nc + 1))]
              for label in range(label_nc)}
    for data in train_data_loader:
        if preprocess is not None:
            data = preprocess(data)
        # Encoding runs on every process; only the master keeps the result.
        feat = encode_features(net_E, feat_nc, label_nc,
                               data['images'], data['instance_maps'],
                               is_cityscapes)
        if is_master():
            for label in range(label_nc):
                chunks[label].append(feat[label])

    # Clustering is only performed on the master process.
    if not is_master():
        return
    for label in range(label_nc):
        feat = np.concatenate(chunks[label], axis=0)
        # The last column is the occupancy ratio: keep rows above the
        # threshold and strip that column before clustering.
        feat = feat[feat[:, -1] > small_ratio, :-1]
        if not feat.shape[0]:
            continue
        # k-means cannot use more clusters than samples. The cap is kept
        # local to this label so a sparsely populated label does not reduce
        # the cluster count of the labels processed after it.
        label_clusters = min(feat.shape[0], max_clusters)
        kmeans = KMeans(n_clusters=label_clusters, random_state=0).fit(feat)
        centers = kmeans.cluster_centers_
        this_cluster = getattr(net_E, 'cluster_%d' % label)
        # Only the first rows are overwritten; the rest keep their values.
        this_cluster[0:centers.shape[0], :] = torch.Tensor(centers).float()
def __init__(self, enc_cfg, data_cfg):
    r"""Build the feature encoder and its per-label cluster buffers.

    The network is a 7x7 input convolution, ``num_downsamples`` stride-2
    blocks that double the channel count, the mirrored nearest-upsample
    blocks that halve it again, and a 7x7 output convolution with tanh.

    Args:
        enc_cfg (obj): Encoder config. All fields are optional:
            ``num_feat_channels`` (3), ``num_clusters`` (10),
            ``num_filters`` (64), ``num_downsamples`` (4),
            ``weight_norm_type`` ('none'), ``activation_norm_type``
            ('instance') and ``padding_mode`` ('reflect').
        data_cfg (obj): Data config, used to count the label and image
            channels of the paired input.
    """
    super(Encoder, self).__init__()
    label_nc = get_paired_input_label_channel_number(data_cfg)
    # Read the feature width once, with its default, so the cluster buffers
    # and the output layer always agree and a config without
    # ``num_feat_channels`` uses the default rather than raising.
    self.num_feat_channels = feat_nc = getattr(enc_cfg,
                                               'num_feat_channels', 3)
    n_clusters = getattr(enc_cfg, 'num_clusters', 10)

    # One zero-initialised (n_clusters, feat_nc) buffer per label, named
    # ``cluster_<label>``.
    for i in range(label_nc):
        self.register_buffer(
            'cluster_%d' % i,
            torch.zeros(n_clusters, feat_nc, dtype=torch.float32))

    num_img_channels = get_paired_input_image_channel_number(data_cfg)
    num_filters = getattr(enc_cfg, 'num_filters', 64)
    num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
    weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
    activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                   'instance')
    padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
    base_conv_block = partial(Conv2dBlock,
                              padding_mode=padding_mode,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity='relu')

    model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
    # Downsample.
    for i in range(num_downsamples):
        ch = num_filters * (2 ** i)
        model.append(base_conv_block(ch, ch * 2, 3, stride=2, padding=1))
    # Upsample.
    for i in reversed(range(num_downsamples)):
        ch = num_filters * (2 ** i)
        model.append(NearestUpsample(scale_factor=2))
        model.append(base_conv_block(ch * 2, ch, 3, padding=1))
    # The output layer uses a plain Conv2dBlock, so it does not inherit the
    # norm settings bound into ``base_conv_block``.
    model.append(Conv2dBlock(num_filters, self.num_feat_channels, 7,
                             padding=3, padding_mode=padding_mode,
                             nonlinearity='tanh'))

    self.model = nn.Sequential(*model)
def __init__(self, gen_cfg, data_cfg):
    r"""Build the image generation part of the video generator.

    Creates the optional label embedding network, the upsampling residual
    blocks ``up_<i>``, the final image convolution and the input layer
    ``fc``, then calls ``self.init_temporal_network()``.

    Args:
        gen_cfg (obj): Generator config. ``flow``, ``activation_norm_type``
            and ``activation_norm_params`` are required; every other field
            has a default. ``activation_norm_params`` is modified in place.
        data_cfg (obj): Data config. ``num_frames_G`` is required, and
            ``val.augmentations`` must define ``center_crop_h_w`` or
            ``resize_h_w`` as an ``'h,w'`` string.

    Raises:
        ValueError: If the validation augmentations specify neither
            ``center_crop_h_w`` nor ``resize_h_w``.
    """
    super().__init__()
    self.gen_cfg = gen_cfg
    self.data_cfg = data_cfg
    self.num_frames_G = data_cfg.num_frames_G
    # Number of residual blocks in generator.
    self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
    # Number of downsamplings for previous frame.
    self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
    # Number of filters in the first layer.
    self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
    self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
    self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
    padding = kernel_size // 2

    # For pose dataset.
    self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
    if self.is_pose_data:
        pose_cfg = data_cfg.for_pose_dataset
        self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
        self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                          False)

    # Input data params.
    self.num_input_channels = num_input_channels = \
        get_paired_input_label_channel_number(data_cfg)
    num_img_channels = get_paired_input_image_channel_number(data_cfg)
    aug_cfg = data_cfg.val.augmentations
    if hasattr(aug_cfg, 'center_crop_h_w'):
        crop_h_w = aug_cfg.center_crop_h_w
    elif hasattr(aug_cfg, 'resize_h_w'):
        crop_h_w = aug_cfg.resize_h_w
    else:
        raise ValueError('Need to specify output size.')
    crop_h, crop_w = (int(side) for side in crop_h_w.split(','))
    # Spatial size at the bottle neck of generator.
    self.sh = crop_h // (2 ** num_layers)
    self.sw = crop_w // (2 ** num_layers)

    # Noise vector dimension.
    self.z_dim = getattr(gen_cfg, 'style_dims', 256)
    self.use_segmap_as_input = \
        getattr(gen_cfg, 'use_segmap_as_input', False)

    # Label / image embedding network.
    self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
    # The default is the boolean True. It used to be the string 'True',
    # which has the same truthiness but is not a proper flag value.
    self.use_embed = getattr(emb_cfg, 'use_embed', True)
    self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
    if self.use_embed:
        self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)

    # Flow network.
    self.flow_cfg = flow_cfg = gen_cfg.flow
    # Use SPADE to combine warped and hallucinated frames instead of
    # linear combination.
    self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
    # Number of layers to perform multi-spade combine. Read from the value
    # fetched above so a flow config without ``multi_spade_combine`` falls
    # back to the default instead of raising AttributeError.
    self.num_multi_spade_layers = getattr(self.spade_combine,
                                          'num_layers', 3)
    # At beginning of training, only train an image generator.
    self.temporal_initialized = False
    # Whether to output hallucinated frame (when training temporal network)
    # for additional loss.
    self.generate_raw_output = False

    # Image generation network.
    weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
    activation_norm_type = gen_cfg.activation_norm_type
    activation_norm_params = gen_cfg.activation_norm_params
    if self.use_embed and \
            not hasattr(activation_norm_params, 'num_filters'):
        activation_norm_params.num_filters = 0
    nonlinearity = 'leakyrelu'

    self.base_res_block = base_res_block = partial(
        Res2dBlock, kernel_size=kernel_size, padding=padding,
        weight_norm_type=weight_norm_type,
        activation_norm_type=activation_norm_type,
        activation_norm_params=activation_norm_params,
        nonlinearity=nonlinearity, order='NACNAC')

    # Upsampling residual blocks. The shared ``activation_norm_params``
    # object is updated before each block is built, so every block is
    # constructed with the conditioning settings of its own layer.
    for i in range(num_layers, -1, -1):
        activation_norm_params.cond_dims = self.get_cond_dims(i)
        activation_norm_params.partial = self.get_partial(i) if hasattr(
            self, 'get_partial') else False
        layer = base_res_block(self.get_num_filters(i + 1),
                               self.get_num_filters(i))
        setattr(self, 'up_%d' % i, layer)

    # Final conv layer.
    self.conv_img = Conv2dBlock(num_filters, num_img_channels,
                                kernel_size, padding=padding,
                                nonlinearity=nonlinearity, order='AC')

    # Width of the input layer, capped at ``max_num_filters``.
    num_filters = min(self.max_num_filters,
                      num_filters * (2 ** (self.num_layers + 1)))
    if self.use_segmap_as_input:
        self.fc = Conv2dBlock(num_input_channels, num_filters,
                              kernel_size=3, padding=1)
    else:
        self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)

    # Misc.
    self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
    self.upsample = partial(F.interpolate, scale_factor=2)
    self.init_temporal_network()
def __init__(self, flow_cfg, data_cfg):
    r"""Build the flow network.

    Labels and previous frames are downsampled by two separate branches,
    followed by residual blocks, an upsampling branch and two output heads:
    a 2-channel flow head and a 1-channel mask head with sigmoid.

    Args:
        flow_cfg (obj): Flow network config. All fields are optional:
            ``num_filters`` (32), ``max_num_filters`` (1024),
            ``num_downsamples`` (5), ``kernel_size`` (3),
            ``num_res_blocks`` (6), ``flow_output_multiplier`` (20),
            ``activation_norm_type`` ('sync_batch') and
            ``weight_norm_type`` ('spectral').
        data_cfg (obj): Data config. ``num_frames_G`` is required.
    """
    super().__init__()
    label_channels = get_paired_input_label_channel_number(data_cfg)
    image_channels = get_paired_input_image_channel_number(data_cfg)
    frame_count = data_cfg.num_frames_G  # Num. of input frames.

    self.num_filters = base_filters = getattr(flow_cfg, 'num_filters', 32)
    self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
    downsample_count = getattr(flow_cfg, 'num_downsamples', 5)
    kernel_size = getattr(flow_cfg, 'kernel_size', 3)
    padding = kernel_size // 2
    self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
    # Multiplier on the flow output.
    self.flow_output_multiplier = getattr(flow_cfg,
                                          'flow_output_multiplier', 20)
    act_norm = getattr(flow_cfg, 'activation_norm_type', 'sync_batch')
    wt_norm = getattr(flow_cfg, 'weight_norm_type', 'spectral')

    conv_block = partial(Conv2dBlock,
                         kernel_size=kernel_size,
                         padding=padding,
                         weight_norm_type=wt_norm,
                         activation_norm_type=act_norm,
                         nonlinearity='leakyrelu')

    def strided_block(level):
        # Stride-2 block going from the width of ``level`` to ``level + 1``.
        return conv_block(self.get_num_filters(level),
                          self.get_num_filters(level + 1),
                          stride=2)

    # Will downsample the labels and prev frames separately, then combine.
    # The label branch takes all ``frame_count`` frames; the image branch
    # takes one frame fewer.
    label_branch = [conv_block(label_channels * frame_count, base_filters)]
    image_branch = [conv_block(image_channels * (frame_count - 1),
                               base_filters)]
    for level in range(downsample_count):
        # Kept interleaved so the layers are created in the original order.
        label_branch.append(strided_block(level))
        image_branch.append(strided_block(level))

    # Resnet blocks at the bottleneck width.
    bottleneck = self.get_num_filters(downsample_count)
    residual_blocks = [
        Res2dBlock(bottleneck, bottleneck, kernel_size,
                   padding=padding,
                   weight_norm_type=wt_norm,
                   activation_norm_type=act_norm,
                   order='CNACN')
        for _ in range(self.num_res_blocks)
    ]

    # Upsample.
    upsample_branch = []
    for level in reversed(range(downsample_count)):
        upsample_branch.append(nn.Upsample(scale_factor=2))
        upsample_branch.append(conv_block(self.get_num_filters(level + 1),
                                          self.get_num_filters(level)))

    # Output heads: 2-channel flow and 1-channel sigmoid mask.
    flow_head = [Conv2dBlock(base_filters, 2, kernel_size, padding=padding)]
    mask_head = [Conv2dBlock(base_filters, 1, kernel_size, padding=padding,
                             nonlinearity='sigmoid')]

    self.down_lbl = nn.Sequential(*label_branch)
    self.down_img = nn.Sequential(*image_branch)
    self.res_flow = nn.Sequential(*residual_blocks)
    self.up_flow = nn.Sequential(*upsample_branch)
    self.conv_flow = nn.Sequential(*flow_head)
    self.conv_mask = nn.Sequential(*mask_head)
def __init__(self, gen_cfg, data_cfg):
    r"""Build the pix2pixHD generator.

    The generator has one global generator, ``num_enhancers`` local
    enhancers registered as ``enhancer_<n>``, and an optional feature
    encoder that is only created when the last input label is
    ``'instance_maps'`` and ``gen_cfg.enc.num_feat_channels`` > 0.

    Args:
        gen_cfg (obj): Generator config. ``global_generator`` and
            ``local_enhancer`` are required; ``enc``,
            ``activation_norm_type`` ('instance'),
            ``activation_norm_params`` (None), ``weight_norm_type`` ('')
            and ``padding_mode`` ('reflect') are optional.
        data_cfg (obj): Data config. ``input_labels`` must be a non-empty
            sequence.
    """
    super().__init__()
    # pix2pixHD has a global generator.
    global_cfg = gen_cfg.global_generator
    global_filters = getattr(global_cfg, 'num_filters', 64)
    # Optionally, it can have several local enhancers. They are useful
    # for generating high resolution images.
    local_cfg = gen_cfg.local_enhancer
    enhancer_count = getattr(local_cfg, 'num_enhancers', 1)
    self.num_local_enhancers = enhancer_count

    # By default, pix2pixHD using instance normalization.
    padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
    shared_kwargs = dict(
        padding_mode=padding_mode,
        weight_norm_type=getattr(gen_cfg, 'weight_norm_type', ''),
        activation_norm_type=getattr(gen_cfg, 'activation_norm_type',
                                     'instance'),
        activation_norm_params=getattr(gen_cfg, 'activation_norm_params',
                                       None),
        nonlinearity='relu')
    base_conv_block = partial(Conv2dBlock, **shared_kwargs)
    base_res_block = partial(Res2dBlock, order='CNACN', **shared_kwargs)

    # Know what is the number of available segmentation labels.
    num_input_channels = get_paired_input_label_channel_number(data_cfg)
    self.concat_features = False
    # Check whether label input contains specific type of data (e.g.
    # instance_maps).
    self.contain_instance_map = bool(
        data_cfg.input_labels[-1] == 'instance_maps')
    # The feature encoder is only useful when the instance map is provided.
    if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
        feature_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
        if feature_channels > 0:
            # Encoded features are appended to the label input.
            num_input_channels += feature_channels
            self.concat_features = True
            self.encoder = Encoder(gen_cfg.enc, data_cfg)

    # Global generator model.
    global_generator = GlobalGenerator(global_cfg, data_cfg,
                                       num_input_channels, padding_mode,
                                       base_conv_block, base_res_block)
    if enhancer_count == 0:
        self.global_model = global_generator
    else:
        # Get rid of the last layer, so the enhancers receive what the
        # layer before it produces.
        layers = global_generator.model
        kept_layers = [layers[idx] for idx in range(len(layers) - 1)]
        self.global_model = nn.Sequential(*kept_layers)

    # Local enhancer model. Enhancer n uses the global filter count divided
    # by 2 ** (n + 1), and only the last enhancer outputs an image.
    last_enhancer = enhancer_count - 1
    for n in range(enhancer_count):
        enhancer = LocalEnhancer(local_cfg, data_cfg,
                                 num_input_channels,
                                 global_filters // (2 ** (n + 1)),
                                 padding_mode,
                                 base_conv_block, base_res_block,
                                 n == last_enhancer)
        setattr(self, 'enhancer_%d' % n, enhancer)

    self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                   count_include_pad=False)
def __init__(self, gen_cfg, data_cfg):
    r"""Build the SPADE generator and, if styles are used, its encoder.

    Args:
        gen_cfg (obj): Generator config. All fields are optional:
            ``num_filters`` (128), ``kernel_size`` (3),
            ``weight_norm_type`` ('spectral'), ``style_dims`` (None),
            ``attribute_dims``, ``skip_activation_norm`` (True),
            ``activation_norm_params``, ``global_adaptive_norm_type``
            ('sync_batch'), ``use_posenc_in_input_layer`` (True) and
            ``style_enc``. ``activation_norm_params`` and ``style_enc``
            are filled with defaults in place when they are provided.
        data_cfg (obj): Data config. ``train.augmentations`` is used to
            obtain the crop size.
    """
    super(Generator, self).__init__()
    print('SPADE generator initialization.')

    def fill_missing(namespace, defaults):
        # Set each (name, value) default only where the namespace lacks it.
        for name, value in defaults:
            if not hasattr(namespace, name):
                setattr(namespace, name, value)

    # We assume the first datum is the ground truth image.
    image_channels = get_paired_input_image_channel_number(data_cfg)
    # Calculate number of channels in the input label.
    num_labels = get_paired_input_label_channel_number(data_cfg)
    crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
    # Build the generator. It is sized by the shorter output side.
    out_image_small_side_size = min(crop_h, crop_w)
    num_filters = getattr(gen_cfg, 'num_filters', 128)
    kernel_size = getattr(gen_cfg, 'kernel_size', 3)
    weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')

    # The conditional input is the style code plus the attribute code,
    # whichever of the two are enabled.
    cond_dims = 0
    # Check whether we use the style code.
    style_dims = getattr(gen_cfg, 'style_dims', None)
    self.style_dims = style_dims
    self.use_style = style_dims is not None
    if self.use_style:
        print('\tStyle code dimensions: %d' % style_dims)
        cond_dims += style_dims
    # Check whether we use the attribute code.
    self.use_attribute = hasattr(gen_cfg, 'attribute_dims')
    if self.use_attribute:
        self.attribute_dims = gen_cfg.attribute_dims
        cond_dims += gen_cfg.attribute_dims
    self.use_style_encoder = self.use_style or self.use_attribute

    print('\tBase filter number: %d' % num_filters)
    print('\tConvolution kernel size: %d' % kernel_size)
    print('\tWeight norm type: %s' % weight_norm_type)

    skip_activation_norm = getattr(gen_cfg, 'skip_activation_norm', True)

    # Activation norm settings: start from the configured namespace (or an
    # empty one) and fill in whatever is missing.
    norm_params = getattr(gen_cfg, 'activation_norm_params', None)
    if norm_params is None:
        norm_params = types.SimpleNamespace()
    fill_missing(norm_params, (
        ('num_filters', 128),
        ('kernel_size', 3),
        ('activation_norm_type', 'sync_batch'),
        ('separate_projection', False),
    ))
    if not hasattr(norm_params, 'activation_norm_params'):
        inner_params = types.SimpleNamespace()
        norm_params.activation_norm_params = inner_params
        inner_params.affine = True
    # The conditioning width always equals the label channel count, so any
    # configured value is overwritten.
    norm_params.cond_dims = num_labels
    fill_missing(norm_params, (('weight_norm_type', weight_norm_type),))

    global_adaptive_norm_type = getattr(gen_cfg,
                                        'global_adaptive_norm_type',
                                        'sync_batch')
    use_posenc_in_input_layer = getattr(gen_cfg,
                                        'use_posenc_in_input_layer',
                                        True)
    print(norm_params)
    self.spade_generator = SPADEGenerator(
        num_labels,
        out_image_small_side_size,
        image_channels,
        num_filters,
        kernel_size,
        cond_dims,
        norm_params,
        weight_norm_type,
        global_adaptive_norm_type,
        skip_activation_norm,
        use_posenc_in_input_layer,
        self.use_style_encoder)

    if self.use_style:
        # Build the encoder. The image channel count and style dimension
        # always come from this generator, not from the config.
        style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
        if style_enc_cfg is None:
            style_enc_cfg = types.SimpleNamespace()
        fill_missing(style_enc_cfg, (
            ('num_filters', 128),
            ('kernel_size', 3),
            ('weight_norm_type', weight_norm_type),
        ))
        style_enc_cfg.input_image_channels = image_channels
        style_enc_cfg.style_dims = style_dims
        self.style_encoder = StyleEncoder(style_enc_cfg)

    self.z = None
    print('Done with the SPADE generator initialization.')