def __init__(self, feat_stride=16, init_filter_reg=1e-2, init_gauss_sigma=1.0, num_dist_bins=5,
             bin_displacement=1.0, mask_init_factor=4.0, score_act='bentpar', act_param=None,
             mask_act='sigmoid'):
    super().__init__()
    self.filter_reg = nn.Parameter(init_filter_reg * torch.ones(1))
    self.feat_stride = feat_stride
    self.distance_map = DistanceMap(num_dist_bins, bin_displacement)

    # Distance coordinates: distance of each bin center from the target location
    d = torch.arange(num_dist_bins, dtype=torch.float32).reshape(1, -1, 1, 1) * bin_displacement
    if init_gauss_sigma == 0:
        init_gauss = torch.zeros_like(d)
        init_gauss[0, 0, 0, 0] = 1
    else:
        init_gauss = torch.exp(-1/2 * (d / init_gauss_sigma)**2)

    # Module that predicts the target label function (y in the paper),
    # initialized to a Gaussian profile over the distance bins
    self.label_map_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
    self.label_map_predictor.weight.data = init_gauss - init_gauss.min()

    # Module that predicts the target mask (m in the paper)
    mask_layers = [nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)]
    if mask_act == 'sigmoid':
        mask_layers.append(nn.Sigmoid())
        init_bias = 0.0
    elif mask_act == 'linear':
        init_bias = 0.5
    else:
        raise ValueError('Unknown activation')
    self.target_mask_predictor = nn.Sequential(*mask_layers)
    self.target_mask_predictor[0].weight.data = mask_init_factor * torch.tanh(2.0 - d) + init_bias

    # Module that predicts the residual weights (v in the paper)
    self.spatial_weight_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
    self.spatial_weight_predictor.weight.data.fill_(1.0)

    # The score activation
    if score_act == 'bentpar':
        self.score_activation = activation.BentIdentPar(act_param)
    elif score_act == 'relu':
        self.score_activation = activation.LeakyReluPar()
    else:
        raise ValueError('Unknown score activation')
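# Illustration (not part of the original module): a minimal, self-contained sketch of the
# 1x1-conv weight initialization above, assuming only standard PyTorch. It recomputes the
# per-bin values that seed label_map_predictor (a Gaussian over the distance bins) and
# target_mask_predictor (a tanh profile), so their initial behaviour can be inspected.
# All variable names below are hypothetical.
import torch

num_dist_bins, bin_displacement = 5, 1.0
init_gauss_sigma, mask_init_factor, init_bias = 1.0, 4.0, 0.0

# Distance of each bin center from the target (same 'd' as in __init__)
d = torch.arange(num_dist_bins, dtype=torch.float32).reshape(1, -1, 1, 1) * bin_displacement

# Label-map weights: Gaussian in the bin distance, shifted so the farthest bin maps to ~0
label_weights = torch.exp(-0.5 * (d / init_gauss_sigma) ** 2)
label_weights = label_weights - label_weights.min()

# Mask weights: strongly positive near the target and strongly negative far from it, so the
# sigmoid in the mask predictor initially saturates to ~1 inside the target and ~0 outside
mask_weights = mask_init_factor * torch.tanh(2.0 - d) + init_bias

print(label_weights.flatten())  # roughly [1.00, 0.61, 0.13, 0.01, 0.00]
print(mask_weights.flatten())   # roughly [3.86, 3.05, 0.00, -3.05, -3.86]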
def __init__(self, num_iter=1, feat_stride=16, init_step_length=1.0, init_filter_reg=1e-2,
             init_gauss_sigma=1.0, num_dist_bins=5, bin_displacement=1.0, mask_init_factor=4.0,
             score_act='relu', act_param=None, min_filter_reg=1e-3, mask_act='sigmoid',
             detach_length=float('Inf'), alpha_eps=0):
    super().__init__()
    self.num_iter = num_iter
    self.feat_stride = feat_stride
    self.log_step_length = nn.Parameter(math.log(init_step_length) * torch.ones(1))
    self.filter_reg = nn.Parameter(init_filter_reg * torch.ones(1))
    self.distance_map = DistanceMap(num_dist_bins, bin_displacement)
    self.min_filter_reg = min_filter_reg
    self.detach_length = detach_length
    self.alpha_eps = alpha_eps

    # Distance coordinates: distance of each bin center from the target location
    d = torch.arange(num_dist_bins, dtype=torch.float32).reshape(1, -1, 1, 1) * bin_displacement
    if init_gauss_sigma == 0:
        init_gauss = torch.zeros_like(d)
        init_gauss[0, 0, 0, 0] = 1
    else:
        init_gauss = torch.exp(-1/2 * (d / init_gauss_sigma)**2)

    # Module that predicts the target label function (y in the paper)
    self.label_map_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
    self.label_map_predictor.weight.data = init_gauss - init_gauss.min()

    # Module that predicts the target mask (m in the paper)
    mask_layers = [nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)]
    if mask_act == 'sigmoid':
        mask_layers.append(nn.Sigmoid())
        init_bias = 0.0
    elif mask_act == 'linear':
        init_bias = 0.5
    else:
        raise ValueError('Unknown activation')
    self.target_mask_predictor = nn.Sequential(*mask_layers)
    self.target_mask_predictor[0].weight.data = mask_init_factor * torch.tanh(2.0 - d) + init_bias

    # Module that predicts the residual weights (v in the paper)
    self.spatial_weight_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
    self.spatial_weight_predictor.weight.data.fill_(1.0)

    # The score activation and its derivative
    if score_act == 'bentpar':
        self.score_activation = activation.BentIdentPar(act_param)
        self.score_activation_deriv = activation.BentIdentParDeriv(act_param)
    elif score_act == 'relu':
        self.score_activation = activation.LeakyReluPar()
        self.score_activation_deriv = activation.LeakyReluParDeriv()
    else:
        raise ValueError('Unknown score activation')
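# Illustration (an assumption about how these parameters are consumed, not the module's
# actual forward pass): the optimizer step length is stored as its logarithm so that the
# effective step length, recovered via exp(), stays strictly positive under unconstrained
# gradient updates, and the filter regularization is typically clamped from below by
# min_filter_reg so it cannot collapse to zero during training. A minimal sketch, assuming
# only standard PyTorch; names and the clamping pattern are hypothetical.
import math
import torch
import torch.nn as nn

log_step_length = nn.Parameter(math.log(1.0) * torch.ones(1))   # init_step_length = 1.0
filter_reg = nn.Parameter(1e-2 * torch.ones(1))                 # init_filter_reg = 1e-2
min_filter_reg = 1e-3

step_length = torch.exp(log_step_length)                               # always > 0
reg_weight = (filter_reg * filter_reg).clamp(min=min_filter_reg ** 2)  # bounded below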