Example #1
 def build_readout(self, readout_feats):
     normalization_fgru = pt_utils.get_norm(self.normalization_fgru)
     self.readout_norm = normalization_fgru(
         self.output_feats, **self.normalization_fgru_params)
     init.constant_(self.readout_norm.weight, 0.1)
     init.constant_(self.readout_norm.bias, 0)
     self.readout_conv = Conv2dSamePadding(self.output_feats, readout_feats,
                                           1)
     init.kaiming_normal_(self.readout_conv.weight)
     init.constant_(self.readout_conv.bias, 0)
Example #2
    def __init__(
            self,
            input_size,
            output_size,
            filter_size,
            layers,
            normalization=True,
            normalization_type='InstanceNorm2d',  # 'BatchNorm2D'
            normalization_params=None,
            non_linearity='ReLU',
            norm_pre_nl=False):
        super().__init__()

        if normalization_params is None:
            normalization_params = {}

        curr_feat = input_size
        self.module_list = []

        for i in range(layers):
            if i == layers - 1:
                next_feat = output_size
            elif i < layers // 2:
                next_feat = curr_feat // 2
            else:
                next_feat = curr_feat * 2

            conv = Conv2dSamePadding(curr_feat, next_feat, filter_size)
            init.orthogonal_(conv.weight)  # xavier_normal_
            init.constant_(conv.bias, 0)
            self.module_list.append(conv)

            if non_linearity is not None:
                nl = pt_utils.get_nl(non_linearity)

            if normalization is not None:
                norm = pt_utils.get_norm(normalization)(next_feat,
                                                        **normalization_params)
                init.constant_(norm.weight, 0.1)
                init.constant_(norm.bias, 0)

            if norm_pre_nl:
                if normalization is not None:
                    self.module_list.append(norm)
                if non_linearity is not None:
                    self.module_list.append(nl)
            else:
                if non_linearity is not None:
                    self.module_list.append(nl)
                if normalization is not None:
                    self.module_list.append(norm)

            curr_feat = next_feat
        self.attention = nn.Sequential(*self.module_list)
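
The loop above builds a small bottleneck: the first half of the layers halve the channel width, the second half double it again, and the last layer always maps to output_size. A minimal standalone sketch of just that schedule (pure Python; the helper name and the sizes below are made up for illustration):

def channel_schedule(input_size, output_size, layers):
    # Mirrors the width schedule of the constructor above.
    widths, curr = [], input_size
    for i in range(layers):
        if i == layers - 1:        # last layer always maps to output_size
            nxt = output_size
        elif i < layers // 2:      # first half: shrink
            nxt = curr // 2
        else:                      # second half: grow back
            nxt = curr * 2
        widths.append((curr, nxt))
        curr = nxt
    return widths

print(channel_schedule(64, 64, 4))
# [(64, 32), (32, 16), (16, 32), (32, 64)]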
Example #3
    def create_us_block(self, input_feat, output_feat):
        # us options: norm top_h, resize before or after block, ...
        normalization_fgru = pt_utils.get_norm(self.normalization_fgru)

        norm = normalization_fgru(input_feat, **self.normalization_fgru_params)
        init.constant_(norm.weight, 0.1)
        init.constant_(norm.bias, 0)
        conv1 = Conv2dSamePadding(input_feat, output_feat, 1)
        init.kaiming_normal_(conv1.weight)
        init.constant_(conv1.bias, 0)
        nl1 = nn.ReLU()
        conv2 = Conv2dSamePadding(output_feat, output_feat, 1)
        init.kaiming_normal_(conv2.weight)
        init.constant_(conv2.bias, 0)
        nl2 = nn.ReLU()

        module_list = [norm, conv1, nl1, conv2, nl2]
        # bilinear resize -> dependent on the other size
        # other version : norm -> conv 1*1 -> norm -> (extra conv 1*1 ->) resize
        # other version : transpose_conv 4*4/2 -> conv 3*3 -> norm
        return nn.Sequential(*module_list)
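
Without Conv2dSamePadding and pt_utils at hand, a rough plain-PyTorch equivalent can be sketched as follows, assuming a 1x1 Conv2dSamePadding behaves like nn.Conv2d(kernel_size=1) and that normalization_fgru resolves to nn.InstanceNorm2d with affine=True; the feature sizes are hypothetical and the weight initializations from the original are omitted for brevity. The bilinear resize mentioned in the comments depends on the target spatial size, so it is applied outside the Sequential with F.interpolate:

import torch
from torch import nn
import torch.nn.functional as F

input_feat, output_feat = 128, 64  # hypothetical sizes

us_block = nn.Sequential(
    nn.InstanceNorm2d(input_feat, affine=True),  # stand-in for normalization_fgru
    nn.Conv2d(input_feat, output_feat, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(output_feat, output_feat, kernel_size=1),
    nn.ReLU(),
)

x = torch.randn(1, input_feat, 16, 16)
h_up = F.interpolate(us_block(x), size=(32, 32),
                     mode='bilinear', align_corners=False)
print(h_up.shape)  # torch.Size([1, 64, 32, 32])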
Example #4
    def __init__(
            self,
            input_size,
            hidden_size,
            kernel_size,
            hidden_init='identity',
            attention='gala',  # 'se', None
            attention_layers=2,
            # attention_normalization=True,
            saliency_filter_size=5,
            tied_kernels=None,
            norm_attention=False,
            normalization_fgru='InstanceNorm2d',
            normalization_fgru_params={'affine': True},
            normalization_gate='InstanceNorm2d',
            normalization_gate_params={'affine': True},
            ff_non_linearity='ReLU',
            force_alpha_divisive=True,
            force_non_negativity=True,
            multiplicative_excitation=True,
            gate_bias_init='chronos',  #'ones'
            timesteps=8):
        super().__init__()

        self.padding = 'same'  # kernel_size // 2

        self.kernel_size = kernel_size
        self.tied_kernels = tied_kernels

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hidden_init = hidden_init

        #self.ff_nl = ff_non_linearity
        self.ff_nl = pt_utils.get_nl(ff_non_linearity)

        self.normalization_fgru = normalization_fgru
        self.normalization_gate = normalization_gate
        self.normalization_fgru_params = normalization_fgru_params if normalization_fgru_params is not None else {}
        self.normalization_gate_params = normalization_gate_params if normalization_gate_params is not None else {}

        if self.normalization_fgru:
            normalization_fgru = pt_utils.get_norm(normalization_fgru)
        if self.normalization_gate:
            normalization_gate = pt_utils.get_norm(normalization_gate)

        self.force_alpha_divisive = force_alpha_divisive
        self.force_non_negativity = force_non_negativity
        self.multiplicative_excitation = multiplicative_excitation

        # add attention
        if attention is not None and attention_layers > 0:
            if attention == 'se':
                self.attention = SE_Attention(
                    hidden_size,
                    hidden_size,
                    1,
                    layers=attention_layers,
                    normalization=self.normalization_gate
                    if norm_attention else None,  # 'BatchNorm2D'
                    normalization_params=self.normalization_gate_params,
                    non_linearity=ff_non_linearity,
                    norm_pre_nl=False)
            elif attention == 'gala':
                self.attention = GALA_Attention(
                    hidden_size,
                    hidden_size,
                    saliency_filter_size,
                    layers=attention_layers,
                    normalization=self.normalization_gate
                    if norm_attention else None,  # 'BatchNorm2D'
                    normalization_params=self.normalization_gate_params,
                    non_linearity=ff_non_linearity,
                    norm_pre_nl=False)
            else:
                raise ValueError('attention type unknown')
        else:
            self.conv_g1_w = nn.Parameter(
                torch.empty(hidden_size, hidden_size, 1, 1))
            init.orthogonal_(self.conv_g1_w)  # xavier_normal_

        self.conv_g1_b = nn.Parameter(torch.empty(hidden_size, 1, 1))

        if self.normalization_gate:
            self.bn_g1 = normalization_gate(hidden_size,
                                            track_running_stats=False,
                                            **self.normalization_gate_params)
            init.constant_(self.bn_g1.weight, 0.1)
            init.constant_(self.bn_g1.bias, 0)

        if self.normalization_fgru:
            self.bn_c1 = normalization_fgru(hidden_size,
                                            track_running_stats=False,
                                            **self.normalization_fgru_params)
            init.constant_(self.bn_c1.weight, 0.1)
            init.constant_(self.bn_c1.bias, 0)

        if tied_kernels == 'depth':
            self.conv_c1_w = nn.Parameter(
                torch.empty(hidden_size, 1, kernel_size, kernel_size))
        else:
            self.conv_c1_w = nn.Parameter(
                torch.empty(hidden_size, hidden_size, kernel_size,
                            kernel_size))

        self.conv_g2_w = nn.Parameter(
            torch.empty(hidden_size, hidden_size, 1, 1))
        init.orthogonal_(self.conv_g2_w)

        self.conv_g2_b = nn.Parameter(torch.empty(hidden_size, 1, 1))

        if self.normalization_gate:
            self.bn_g2 = normalization_gate(hidden_size,
                                            track_running_stats=False,
                                            **self.normalization_gate_params)
            init.constant_(self.bn_g2.weight, 0.1)
            init.constant_(self.bn_g2.bias, 0)

        if self.normalization_fgru:
            self.bn_c2 = normalization_fgru(hidden_size,
                                            track_running_stats=False,
                                            **self.normalization_fgru_params)
            init.constant_(self.bn_c2.weight, 0.1)
            init.constant_(self.bn_c2.bias, 0)

        if tied_kernels == 'depth':
            self.conv_c2_w = nn.Parameter(
                torch.empty(hidden_size, 1, kernel_size, kernel_size))
        else:
            self.conv_c2_w = nn.Parameter(
                torch.empty(hidden_size, hidden_size, kernel_size,
                            kernel_size))

        init.orthogonal_(self.conv_c1_w)
        init.orthogonal_(self.conv_c2_w)

        # if tied_kernels!='depth':
        #     self.conv_c1_w.register_hook(lambda grad: (grad + torch.transpose(grad,1,0))*0.5)
        #     self.conv_c2_w.register_hook(lambda grad: (grad + torch.transpose(grad,1,0))*0.5)

        if gate_bias_init == 'chronos':
            init_chronos = np.log(
                np.random.uniform(1.0, max(float(timesteps - 1), 1.0),
                                  [hidden_size, 1, 1]))

            self.conv_g1_b.data = torch.FloatTensor(init_chronos)
            self.conv_g2_b.data = torch.FloatTensor(-init_chronos)
        else:
            init.constant_(self.conv_g1_b, 1)
            init.constant_(self.conv_g2_b, 1)

        self.alpha = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        self.mu = nn.Parameter(torch.empty((hidden_size, 1, 1)))

        self.omega = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        # self.gamma = nn.Parameter(torch.empty((hidden_size,1,1)))
        self.kappa = nn.Parameter(torch.empty((hidden_size, 1, 1)))

        init.constant_(self.alpha, 0.1)
        init.constant_(self.mu, 1.0)

        init.constant_(self.omega, 0.5)
        # init.constant_(self.gamma, 1.0)
        init.constant_(self.kappa, 0.5)
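
The 'chronos' branch above appears to follow the chrono initialization of Tallec & Ollivier ("Can Recurrent Neural Networks Warp Time?", ICLR 2018): gate biases are drawn as log U(1, T-1), and the second gate gets the negated values, so the gates start with time constants spread roughly over the number of timesteps. A standalone sketch of just that step (sizes are hypothetical):

import numpy as np
import torch

hidden_size, timesteps = 64, 8  # hypothetical sizes

# Log-uniform biases in [0, log(T-1)); the max(...) guard keeps the upper
# bound valid when timesteps <= 2.
init_chronos = np.log(
    np.random.uniform(1.0, max(float(timesteps - 1), 1.0),
                      [hidden_size, 1, 1]))

g1_bias = torch.as_tensor(init_chronos, dtype=torch.float32)
g2_bias = -g1_bias  # second gate: negated biases
print(g1_bias.shape)  # torch.Size([64, 1, 1])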