コード例 #1
0
 def __init__(self, ngf = 32, n_layers = 5):
     """Assemble the texture generator as a single sequential stack.

     The stack is stored in ``self.model`` and is built in this order:
     a replication-padded 9x9 input convolution with ReLU, two
     ``myTConv`` stages, a run of ``myTBlock`` units, two transposed
     convolutions (each with batch norm and ReLU), and a
     replication-padded 9x9 output convolution followed by ``Tanh``.

     Args:
         ngf: Base channel count; later stages use ``ngf*2`` and ``ngf*4``.
         n_layers: Number of ``myTBlock`` units (for ``n_layers >= 1``).
             Only the middle unit is created with dropout ``p=0.5``.
     """
     super(TextureGenerator, self).__init__()

     half = int(n_layers / 2)
     wide = ngf * 4

     # Input stage. The myTConv arguments are positional; presumably
     # (num_filter, stride, in_channels) like myGConv -- confirm.
     layers = [
         ReplicationPad2d(padding=4),
         Conv2d(in_channels=3, out_channels=ngf, kernel_size=9, padding=0),
         ReLU(),
         myTConv(ngf * 2, 2, ngf),
         myTConv(wide, 2, ngf * 2),
     ]

     # Middle stage: plain blocks around one dropout block, which is
     # there to make the model more robust.
     layers += [myTBlock(wide, p=0.0) for _ in range(half)]
     layers.append(myTBlock(wide, p=0.5))
     layers += [myTBlock(wide, p=0.0) for _ in range(half + 1, n_layers)]

     # Output stage: two stride-2 transposed convolutions, then the final
     # 9x9 convolution squashed by Tanh.
     layers += [
         ConvTranspose2d(in_channels=wide, out_channels=ngf * 2,
                         kernel_size=4, stride=2, padding=0),
         BatchNorm2d(num_features=ngf * 2, track_running_stats=True),
         ReLU(),
         ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf,
                         kernel_size=4, stride=2, padding=0),
         BatchNorm2d(num_features=ngf, track_running_stats=True),
         ReLU(),
         ReplicationPad2d(padding=1),
         Conv2d(in_channels=ngf, out_channels=3, kernel_size=9, padding=0),
         Tanh(),
     ]

     self.model = nn.Sequential(*layers)
コード例 #2
0
    def __init__(self, ngf=32, n_layers = 5):
        """Build the encoder, transformer and decoder of the glyph generator.

        Three ``nn.Sequential`` containers are created and stored as
        ``self.encoder``, ``self.transformer`` and ``self.decoder``.

        Args:
            ngf: Base channel count; later stages use ``ngf*2`` and ``ngf*4``.
            n_layers: Number of ``myGCombineBlock`` units in the transformer
                (for ``n_layers >= 2``). The two middle units are created
                with dropout ``p=0.5``, the others with ``p=0.0``.
        """
        super(GlyphGenerator, self).__init__()

        half = int(n_layers / 2)
        wide = ngf * 4

        # Encoder: 9x9 input convolution, then two stride-2 myGConv stages
        # that widen the features to ngf*4 channels.
        encoder = [
            ReplicationPad2d(padding=4),
            Conv2d(in_channels=3, out_channels=ngf, kernel_size=9, padding=0),
            LeakyReLU(0.2),
            myGConv(ngf * 2, 2, ngf),
            myGConv(wide, 2, ngf * 2),
        ]

        # Transformer: plain blocks around two dropout blocks, which are
        # there to make the model more robust.
        transformer = [myGCombineBlock(wide, p=0.0) for _ in range(half - 1)]
        transformer += [myGCombineBlock(wide, p=0.5) for _ in range(2)]
        transformer += [
            myGCombineBlock(wide, p=0.0) for _ in range(half + 1, n_layers)
        ]

        # Decoder: two stride-2 transposed convolutions, then the final
        # 9x9 convolution squashed by Tanh.
        decoder = [
            ConvTranspose2d(in_channels=wide, out_channels=ngf * 2,
                            kernel_size=4, stride=2, padding=0),
            BatchNorm2d(num_features=ngf * 2, track_running_stats=True),
            LeakyReLU(0.2),
            ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf,
                            kernel_size=4, stride=2, padding=0),
            BatchNorm2d(num_features=ngf, track_running_stats=True),
            LeakyReLU(0.2),
            ReplicationPad2d(padding=1),
            Conv2d(in_channels=ngf, out_channels=3, kernel_size=9, padding=0),
            Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder)
        self.transformer = nn.Sequential(*transformer)
        self.decoder = nn.Sequential(*decoder)
コード例 #3
0
    def __init__(self):
        """Create every layer of the convolutional VAE.

        Encoder: five 4x4, stride-2, padding-1 convolutions (``e1``-``e5``),
        each paired with a ``BatchNorm2d`` (``bn1``-``bn5``), followed by two
        linear heads ``fc1`` and ``fc2`` that map ``NDF * 8 * 4 * 4`` features
        to ``LATENT_VARIABLE_SIZE`` (presumably mean and log-variance --
        confirm against the encode/forward code).

        Decoder: linear layer ``d1``, then five stages of nearest-neighbour
        2x upsampling, replication padding of 1 and a 3x3 stride-1
        convolution (``d2``-``d6``). The first four stages also have a
        ``BatchNorm2d`` with ``eps=1e-3`` (``bn6``-``bn9``).

        Channel widths come from the module-level constants ``NDF``, ``NGF``
        and ``IMAGE_CHANNEL``.
        """
        super(VAE, self).__init__()

        def down(c_in, c_out):
            # Encoder convolution: kernel 4, stride 2, padding 1.
            return Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        def refine(c_in, c_out):
            # Decoder convolution: kernel 3, stride 1, no built-in padding
            # (a separate ReplicationPad2d(1) layer is created for it).
            return Conv2d(c_in, c_out, kernel_size=3, stride=1)

        def upsample():
            return UpsamplingNearest2d(scale_factor=2)

        # ---- encoder ----
        self.e1 = down(IMAGE_CHANNEL, NDF)
        self.bn1 = BatchNorm2d(NDF)

        self.e2 = down(NDF, NDF * 2)
        self.bn2 = BatchNorm2d(NDF * 2)

        self.e3 = down(NDF * 2, NDF * 4)
        self.bn3 = BatchNorm2d(NDF * 4)

        self.e4 = down(NDF * 4, NDF * 8)
        self.bn4 = BatchNorm2d(NDF * 8)

        self.e5 = down(NDF * 8, NDF * 8)
        self.bn5 = BatchNorm2d(NDF * 8)

        encoded_features = NDF * 8 * 4 * 4
        self.fc1 = Linear(encoded_features, LATENT_VARIABLE_SIZE)
        self.fc2 = Linear(encoded_features, LATENT_VARIABLE_SIZE)

        # ---- decoder ----
        self.d1 = Linear(LATENT_VARIABLE_SIZE, NGF * 8 * 2 * 4 * 4)

        self.up1 = upsample()
        self.pd1 = ReplicationPad2d(1)
        self.d2 = refine(NGF * 8 * 2, NGF * 8)
        self.bn6 = BatchNorm2d(NGF * 8, eps=1.e-3)

        self.up2 = upsample()
        self.pd2 = ReplicationPad2d(1)
        self.d3 = refine(NGF * 8, NGF * 4)
        self.bn7 = BatchNorm2d(NGF * 4, eps=1.e-3)

        self.up3 = upsample()
        self.pd3 = ReplicationPad2d(1)
        self.d4 = refine(NGF * 4, NGF * 2)
        self.bn8 = BatchNorm2d(NGF * 2, eps=1.e-3)

        self.up4 = upsample()
        self.pd4 = ReplicationPad2d(1)
        self.d5 = refine(NGF * 2, NGF)
        self.bn9 = BatchNorm2d(NGF, eps=1.e-3)

        self.up5 = upsample()
        self.pd5 = ReplicationPad2d(1)
        self.d6 = refine(NGF, IMAGE_CHANNEL)

        # ---- activations ----
        self.leakyrelu = LeakyReLU(0.2)
        self.relu = ReLU()
        self.sigmoid = Sigmoid()
コード例 #4
0
 def __init__(self, num_filter=128):
     """Create the layers of a glyph block that keeps the channel count.

     Args:
         num_filter: Number of input and output channels of every layer
             created here.
     """
     super(myGBlock, self).__init__()

     channels = num_filter

     # First half: a stride-1 myGConv unit.
     self.myconv = myGConv(num_filter=channels, stride=1, in_channels=channels)

     # Second half: explicit replication padding, an unpadded 3x3
     # convolution and batch norm.
     self.pad = ReplicationPad2d(padding=1)
     self.conv = Conv2d(in_channels=channels, out_channels=channels,
                        kernel_size=3, padding=0)
     self.bn = BatchNorm2d(num_features=channels, track_running_stats=True)
コード例 #5
0
 def __init__(self, num_filter=128, stride=1, in_channels=128):
     """Create the layers of a pad / 3x3 conv / batch-norm / LeakyReLU unit.

     Args:
         num_filter: Number of output channels of the convolution.
         stride: Stride of the convolution.
         in_channels: Number of input channels of the convolution.
     """
     super(myGConv, self).__init__()

     # Padding is a separate replication-pad layer, so the convolution
     # itself is created with padding=0.
     self.pad = ReplicationPad2d(padding=1)
     self.conv = Conv2d(
         in_channels=in_channels,
         out_channels=num_filter,
         kernel_size=3,
         stride=stride,
         padding=0,
     )
     self.bn = BatchNorm2d(num_features=num_filter, track_running_stats=True)

     # Either ReLU or LeakyReLU is OK here.
     self.relu = LeakyReLU(0.2)
コード例 #6
0
 def __init__(self, num_filter=128, p=0.0):
     """Create the layers of a texture block that keeps the channel count.

     Args:
         num_filter: Number of input and output channels of every layer
             created here.
         p: Dropout probability passed to ``nn.Dropout``.
     """
     super(myTBlock, self).__init__()

     # Input channels follow num_filter instead of a hard-coded 128, so the
     # block also works when the generator is built with ngf != 32. With
     # the default num_filter=128 the layers are unchanged.
     self.myconv = myTConv(num_filter=num_filter, stride=1, in_channels=num_filter)
     self.pad = ReplicationPad2d(padding=1)
     self.conv = Conv2d(out_channels=num_filter, kernel_size=3, padding=0, in_channels=num_filter)
     self.bn = BatchNorm2d(num_features=num_filter, track_running_stats=True)
     self.relu = ReLU()
     self.dropout = nn.Dropout(p=p)
コード例 #7
0
    def __call__(self, sample):
        """Resize the thermal image and replicate-pad its borders.

        Args:
            sample: Mapping with the keys ``"long_exposure"``,
                ``"short_exposure"`` and ``"thermal_response"``. The thermal
                entry must offer ``resize`` and be accepted by
                ``transforms.ToTensor`` (a PIL image, per the original
                comment).

        Returns:
            A new dict with the same three keys. The exposure entries are
            passed through untouched; the thermal entry is resized to
            ``self.therm_shape``, padded by ``self.padding`` and converted
            back with ``transforms.ToPILImage``.
        """
        long_exp = sample["long_exposure"]
        short_exp = sample["short_exposure"]
        thermal = sample["thermal_response"]

        thermal = thermal.resize(self.therm_shape)
        pad_layer = ReplicationPad2d(self.padding)

        # The pad layer is fed a batched tensor: add a leading batch axis,
        # pad, then drop the axis again before converting back to an image.
        batched = transforms.ToTensor()(thermal).unsqueeze(0)
        padded = pad_layer(batched).squeeze(0)

        return {
            "long_exposure": long_exp,
            "short_exposure": short_exp,
            "thermal_response": transforms.ToPILImage()(padded),
        }
コード例 #8
0
    def forward(self, x):
        """Run ``self.net`` and fuse its output along the leading axis.

        ``self.net(x)`` must return a 4-D tensor of size ``(nt, c, h, w)``.
        ``nt`` is treated as ``n_batch * self.n_segment`` (presumably clips
        times frames -- confirm with the caller) and ``c`` as
        ``c_per_group * self.n_group``. The configuration flags on ``self``
        select one of five paths:

        * ``fuse_correlation`` and ``fuse_downsample``: local correlations
          between even and odd entries of the leading axis are fed to
          ``self.weight``; the result weights a ``correlation_num`` x
          ``correlation_num`` neighbourhood sum. Output ``(nt // 2, c, h, w)``.
        * ``fuse_correlation`` without ``fuse_downsample``: correlations with
          the previous / current / next segment. See the review note in the
          body; this path looks unfinished.
        * ``dilation``: every layer in ``self.weight_list`` yields weights for
          an even/odd pairwise sum; results are summed over the layers.
          Output ``(nt // 2, c, h, w)``. ``fuse_ave`` and ``bn_out`` are not
          applied on this path.
        * ``fuse_downsample`` only: even/odd pairwise weighted sum using
          ``self.weight``. Output ``(nt // 2, c, h, w)``.
        * otherwise: 3-tap weighted sum over the replication-padded segment
          axis. Output ``(nt, c, h, w)``.

        Except for the ``dilation`` path, the result is divided by the
        number of fused terms when ``self.fuse_ave`` is true and then passed
        through ``self.bn_out``.

        Args:
            x: Input accepted by ``self.net``.

        Returns:
            The fused tensor, shaped as described above.

        Raises:
            NotImplementedError: If ``fuse_correlation`` is enabled while
                ``GroupConv`` is false. That combination never worked; it
                used to fail later with ``UnboundLocalError``.
        """
        x = self.net(x)
        nt, c, h, w = x.size()
        c_per_group = c // self.n_group
        n_batch = nt // self.n_segment

        if self.fuse_correlation:
            neighbor_k = self.correlation_num
            boundary_pad = neighbor_k // 2
            if not self.GroupConv:
                raise NotImplementedError(
                    "fuse_correlation is only implemented for GroupConv=True")

            if self.fuse_downsample:
                x_pad = ReplicationPad2d(boundary_pad)(x)

                # Correlate each even entry with every (i, j)-shifted window
                # of the following odd entry, reduced per channel group.
                correlation_list = []
                for i in range(neighbor_k):
                    for j in range(neighbor_k):
                        correlation = x[::2] * x_pad[1::2, :, i:i + h, j:j + w]
                        correlation = correlation.view(
                            nt // 2, self.n_group, c_per_group, h, w)
                        correlation_list.append(correlation.sum(dim=2))
                x_correlation = torch.stack(correlation_list).permute(
                    1, 0, 2, 3, 4).contiguous().view(
                        nt // 2, self.in_channels, h, w)
                x_correlation = self.weight(x_correlation)

                x_pad = x_pad.view((nt, c_per_group, self.n_group) +
                                   x_pad.shape[-2:]).permute(1, 0, 2, 3, 4)
                x_correlation = x_correlation.view(nt // 2, neighbor_k,
                                                   neighbor_k,
                                                   self.n_group * 2, h, w)

                # Weighted sum of the even and odd windows per offset.
                x_list = []
                for i in range(neighbor_k):
                    for j in range(neighbor_k):
                        even = (x_pad[:, ::2, :, i:i + h, j:j + w] *
                                x_correlation[:, i, j, ::2])
                        odd = (x_pad[:, 1::2, :, i:i + h, j:j + w] *
                               x_correlation[:, i, j, 1::2])
                        x_list.append(even + odd)
                x = torch.stack(x_list).sum(dim=0).permute(
                    1, 0, 2, 3, 4).contiguous().view(nt // 2, c, h, w)
                x = x / 2 if self.fuse_ave else x
                return self.bn_out(x)

            # NOTE(review): this path appears unfinished. It used to stop in
            # pdb / IPython breakpoints (removed here), and the fusion step
            # below is indexed exactly like the downsample path although
            # x_pad and x_correlation have one more axis. Verify the intended
            # maths before relying on it.
            temporal_pad = 1
            m1 = ReplicationPad1d(temporal_pad)
            x = x.view(n_batch, self.n_segment, c, h,
                       w).permute(0, 2, 3, 4, 1)
            x = x.contiguous().view(n_batch * c, -1, self.n_segment)
            x = m1(x)
            pad_t = x.shape[-1]
            x = x.view(n_batch * c, h, w, pad_t).permute(0, 3, 1, 2)
            m2 = ReplicationPad2d(boundary_pad)
            x_pad = m2(x)

            seg = self.n_segment
            correlation_list = []
            for i in range(neighbor_k):
                for j in range(neighbor_k):
                    centre = x[:, 1:1 + seg]
                    correlation_past = centre * x_pad[:, :seg,
                                                      i:i + h, j:j + w]
                    correlation_now = centre * x_pad[:, 1:1 + seg,
                                                     i:i + h, j:j + w]
                    correlation_future = centre * x_pad[:, 2:2 + seg,
                                                        i:i + h, j:j + w]
                    for correlation in (correlation_past, correlation_now,
                                        correlation_future):
                        correlation = correlation.view(
                            nt, self.n_group, c_per_group, h, w)
                        correlation_list.append(correlation.sum(dim=2))
            x_correlation = torch.stack(correlation_list).permute(
                1, 0, 2, 3, 4).contiguous().view(nt, self.in_channels, h, w)
            x_correlation = self.weight(x_correlation)

            x_pad = x_pad.view((n_batch, c_per_group, self.n_group) +
                               x_pad.shape[-3:])
            x_correlation = x_correlation.view(n_batch, self.n_segment,
                                               neighbor_k, neighbor_k,
                                               self.n_group * 3, h, w)
            x_list = []
            for i in range(neighbor_k):
                for j in range(neighbor_k):
                    even = (x_pad[:, ::2, :, i:i + h, j:j + w] *
                            x_correlation[:, i, j, ::2])
                    odd = (x_pad[:, 1::2, :, i:i + h, j:j + w] *
                           x_correlation[:, i, j, 1::2])
                    x_list.append(even + odd)
            x = torch.stack(x_list).sum(dim=0).permute(
                1, 0, 2, 3, 4).contiguous().view(nt // 2, c, h, w)
            x = x / 2 if self.fuse_ave else x
            return self.bn_out(x)

        # No correlation: regroup to (n_batch, c, n_segment, h, w).
        x = x.view(n_batch, self.n_segment, c, h,
                   w).permute(0, 2, 1, 3, 4)

        if self.dilation:
            # The grouped view of x is identical for every weight layer, so
            # build it once: (c_per_group, n_group, n_batch, n_segment, h, w).
            grouped = x.reshape(
                n_batch, c_per_group, self.n_group, self.n_segment, h,
                w).permute(1, 2, 0, 3, 4, 5)
            x_list = []
            for weight_layer in self.weight_list:
                weight = weight_layer(x).permute(1, 0, 2, 3, 4)
                xx = (grouped[:, :, :, ::2] * weight[::2] +
                      grouped[:, :, :, 1::2] * weight[1::2])
                x_list.append(
                    xx.permute(2, 3, 0, 1, 4,
                               5).contiguous().view(nt // 2, c, h, w))
            return torch.stack(x_list).sum(dim=0)

        if self.fuse_downsample:
            weight = self.weight(x).permute(1, 0, 2, 3, 4)
            # (c_per_group, n_group, n_batch, n_segment, h, w)
            x = x.reshape(
                n_batch, c_per_group, self.n_group, self.n_segment, h,
                w).permute(1, 2, 0, 3, 4, 5)
            x = (x[:, :, :, ::2] * weight[::2] +
                 x[:, :, :, 1::2] * weight[1::2])
            x = x.permute(2, 3, 0, 1, 4,
                          5).contiguous().view(nt // 2, c, h, w)
            x = x / 2 if self.fuse_ave else x
            return self.bn_out(x)

        # Same-length fusion: each segment position is a weighted sum of
        # itself and its two neighbours along the replication-padded axis.
        weight = self.weight(x)
        x = x.permute(1, 0, 3, 4,
                      2).contiguous().view(c, -1, self.n_segment)
        m = ReplicationPad1d(1)
        x = m(x).view(c_per_group, -1, self.n_segment + 2)
        weight = weight.permute(1, 0, 3, 4, 2).contiguous().view(
            3, -1, self.n_segment).permute(2, 1, 0)
        x_list = [(x[:, :, i:i + 3] * weight[i]).sum(dim=2)
                  for i in range(self.n_segment)]
        x = torch.stack(x_list).view(
            self.n_segment, c, n_batch,
            h * w).permute(2, 0, 1, 3).contiguous().view(nt, c, h, w)
        x = x / 3 if self.fuse_ave else x
        return self.bn_out(x)