def __init__(self, layout_dim, input_dim, output_dim,
             normalization='instance', activation='leakyrelu'):
    """Assemble a two-stage convolutional stack and store it as ``self.net``.

    Each stage is a 3x3 convolution (padding 1) followed by whatever
    ``get_normalization_2d`` and ``get_activation`` return for the given
    settings.

    Args:
        layout_dim: Channel count added to ``input_dim`` to form the input
            width of the first convolution.
        input_dim: Remaining channel count of the first convolution's input.
        output_dim: Output channel count of both convolutions.
        normalization: Forwarded to ``get_normalization_2d``.
        activation: Forwarded to ``get_activation``.
    """
    super(RefinementModule, self).__init__()

    # Built in one literal so modules are still constructed in the order
    # conv, norm, activation, conv, norm, activation.
    candidates = [
        nn.Conv2d(layout_dim + input_dim, output_dim,
                  kernel_size=3, padding=1),
        get_normalization_2d(output_dim, normalization),
        get_activation(activation),
        nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        get_normalization_2d(output_dim, normalization),
        get_activation(activation),
    ]

    # A None entry is treated as "no layer in this slot" and skipped.
    stages = list(filter(lambda module: module is not None, candidates))

    # Only the convolutions get Kaiming-normal weights.
    for conv in (module for module in stages
                 if isinstance(module, nn.Conv2d)):
        nn.init.kaiming_normal_(conv.weight)

    self.net = nn.Sequential(*stages)
def __init__(self, dims, normalization='instance', activation='leakyrelu'):
    """Build the chain of refinement modules and the output convolution head.

    Args:
        dims: Sequence of channel sizes. ``dims[0]`` is passed to every
            ``RefinementModule`` as ``layout_dim``; each later entry
            ``dims[i]`` is the ``output_dim`` of one module. The first
            module is built with ``input_dim=1``; every later module takes
            the previous entry ``dims[i - 1]`` as its ``input_dim``.
        normalization: Forwarded to each ``RefinementModule``.
        activation: Forwarded to each ``RefinementModule`` and to
            ``get_activation`` for the output head.

    Attributes set:
        refinement_modules: ``nn.ModuleList`` with ``len(dims) - 1`` modules.
        output_conv: ``nn.Sequential`` of a 3x3 convolution
            (``dims[-1]`` -> ``dims[-1]``), the activation if any, and a 1x1
            convolution (``dims[-1]`` -> 3 channels).
    """
    super(RefinementNetwork, self).__init__()
    layout_dim = dims[0]

    self.refinement_modules = nn.ModuleList()
    for i in range(1, len(dims)):
        # The first module has no preceding module, so it is built with a
        # single input channel instead of dims[i - 1].
        input_dim = 1 if i == 1 else dims[i - 1]
        output_dim = dims[i]
        mod = RefinementModule(layout_dim, input_dim, output_dim,
                               normalization=normalization,
                               activation=activation)
        self.refinement_modules.append(mod)

    final_dim = dims[-1]
    output_conv_layers = [
        nn.Conv2d(final_dim, final_dim, kernel_size=3, padding=1),
        get_activation(activation),
        nn.Conv2d(final_dim, 3, kernel_size=1, padding=0),
    ]
    # Skip a missing activation the same way RefinementModule does:
    # nn.Sequential cannot execute a None entry.
    output_conv_layers = [
        layer for layer in output_conv_layers if layer is not None
    ]
    # Initialize convolutions by type rather than by fixed list position, so
    # this stays correct whether or not an activation is present.
    for layer in output_conv_layers:
        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_normal_(layer.weight)
    self.output_conv = nn.Sequential(*output_conv_layers)