Example #1
import torch.nn as nn

# compute_same_pad and ActNorm2d are assumed to be defined elsewhere in the
# same module.
class Conv2d(nn.Module):  # class name assumed; the snippet shows only __init__
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding="same",
                 do_actnorm=True,
                 weight_std=0.05):
        super().__init__()

        # "same" padding preserves the spatial size at stride 1; "valid" pads
        # nothing.
        if padding == "same":
            padding = compute_same_pad(kernel_size, stride)
        elif padding == "valid":
            padding = 0

        # ActNorm supplies a per-channel bias of its own, so the conv bias is
        # only needed when actnorm is disabled.
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size,
                              stride,
                              padding,
                              bias=(not do_actnorm))

        # init weight with std
        self.conv.weight.data.normal_(mean=0.0, std=weight_std)

        if not do_actnorm:
            self.conv.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)

        self.do_actnorm = do_actnorm
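
    # The snippet ends before the forward pass; a minimal sketch consistent
    # with this __init__, assuming ActNorm2d.forward returns an
    # (output, logdet) pair as in Glow-style implementations:
    def forward(self, x):
        x = self.conv(x)
        if self.do_actnorm:
            x, _ = self.actnorm(x)
        return x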
Example #2
import torch
import torch.nn as nn

class Conv2dZeros(nn.Module):  # class name assumed; the snippet shows only __init__
    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", logscale_factor=3):
        super().__init__()

        if padding == "same":
            padding = compute_same_pad(kernel_size, stride)
        elif padding == "valid":
            padding = 0

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding)

        # Start as an exact zero map, as Glow-style blocks do for their last
        # layer.
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

        self.logscale_factor = logscale_factor
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))
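
    # A matching forward pass (a sketch, not part of the original snippet):
    # the zero-initialized conv output is scaled by
    # exp(logs * logscale_factor), so the layer starts as a zero map and
    # learns a per-channel scale.
    def forward(self, x):
        output = self.conv(x)
        return output * torch.exp(self.logs * self.logscale_factor)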
Example #3
import torch
import torch.nn as nn

class CondConv2dBlock(nn.Module):  # class name assumed; the snippet shows only __init__
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_classes,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding="same",
                 do_actnorm=True,
                 weight_std=0.05,
                 logscale_factor=3):
        super().__init__()

        if padding == "same":
            # conv1/conv3 use 3x3 kernels, conv2 uses 1x1, all at stride 1
            padding_a = compute_same_pad((3, 3), (1, 1))
            padding_b = compute_same_pad((1, 1), (1, 1))
        elif padding == "valid":
            padding_a = 0
            padding_b = 0

        # Alternative conditioning via zero-initialized linear projections,
        # left disabled in the original in favor of the cond_fc/cond_conv
        # layers below:
        # self.project_ycond_in = LinearZeros(num_classes, hidden_channels)
        # self.project_ycond_hidden = LinearZeros(num_classes, hidden_channels)
        # self.project_ycond_out = LinearZeros(num_classes, out_channels)

        # Projects the class label to a flat conditioning vector; the sizing
        # apparently assumes 96x96 inputs (in_channels * 96 * 96 /
        # (in_channels * in_channels) == 96 * 96 // in_channels features).
        self.cond_fc1 = nn.Linear(
            num_classes,
            int(in_channels * 96 * 96 / (in_channels * in_channels)))
        # Expects the input concatenated channel-wise with a same-shaped
        # conditioning map, hence in_channels + in_channels.
        self.conv1 = nn.Conv2d(in_channels + in_channels,
                               hidden_channels,
                               kernel_size,
                               stride,
                               padding_a,
                               bias=(not do_actnorm))
        self.cond_conv1 = nn.Conv2d(in_channels,
                                    in_channels,
                                    kernel_size,
                                    stride,
                                    padding_a,
                                    bias=(not do_actnorm))

        self.cond_fc2 = nn.Linear(
            num_classes,
            int(in_channels * 96 * 96 / (in_channels * in_channels)))
        self.conv2 = nn.Conv2d(hidden_channels + in_channels,
                               hidden_channels,
                               kernel_size=(1, 1),
                               stride=stride,
                               padding=padding_b,
                               bias=(not do_actnorm))
        self.cond_conv2 = nn.Conv2d(in_channels,
                                    in_channels,
                                    kernel_size=(1, 1),
                                    stride=stride,
                                    padding=padding_b,
                                    bias=(not do_actnorm))

        self.cond_fc3 = nn.Linear(
            num_classes,
            int(in_channels * 96 * 96 / (in_channels * in_channels)))
        self.conv3 = nn.Conv2d(hidden_channels + in_channels, out_channels,
                               kernel_size, stride, padding_a)

        self.cond_conv3 = nn.Conv2d(in_channels,
                                    in_channels,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding_a)

        # init weight with std
        nn.init.normal_(self.cond_fc1.weight, 0., std=weight_std)
        nn.init.normal_(self.cond_fc2.weight, 0., std=weight_std)
        nn.init.normal_(self.cond_fc3.weight, 0., std=weight_std)

        self.conv1.weight.data.normal_(mean=0.0, std=weight_std)
        self.cond_conv1.weight.data.normal_(mean=0.0, std=weight_std)

        self.conv2.weight.data.normal_(mean=0.0, std=weight_std)
        self.cond_conv2.weight.data.normal_(mean=0.0, std=weight_std)

        self.conv3.weight.data.zero_()
        self.conv3.bias.data.zero_()
        self.cond_conv3.weight.data.zero_()
        self.cond_conv3.bias.data.zero_()

        self.logscale_factor = logscale_factor
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))

        if not do_actnorm:
            # Without actnorm the convs carry their own biases; start them at
            # zero (conv3/cond_conv3 biases are already zeroed above).
            self.conv1.bias.data.zero_()
            self.cond_conv1.bias.data.zero_()
            self.conv2.bias.data.zero_()
            self.cond_conv2.bias.data.zero_()
        else:
            self.actnorm1 = ActNorm2d(hidden_channels)
            self.actnorm2 = ActNorm2d(hidden_channels)

        self.do_actnorm = do_actnorm
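
All three examples rely on a compute_same_pad helper that is not shown. A plausible implementation, assuming it returns the per-dimension padding that keeps the spatial size unchanged at stride 1 (matching the utility of the same name in Glow-style PyTorch code):

def compute_same_pad(kernel_size, stride):
    # Accept ints or equal-length iterables, mirroring nn.Conv2d conventions.
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size]
    if isinstance(stride, int):
        stride = [stride]
    assert len(stride) == len(kernel_size)
    return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)]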