Example #1
    def forward(self, input):
        # Delegate to the functional complex max pool with this module's
        # configured pooling parameters.
        return complex_max_pool2d(input,
                                  kernel_size=self.kernel_size,
                                  stride=self.stride,
                                  padding=self.padding,
                                  dilation=self.dilation,
                                  ceil_mode=self.ceil_mode,
                                  return_indices=self.return_indices)
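For orientation, here is a minimal sketch of how a magnitude-based complex max pool of this form could be implemented. It is an illustration under stated assumptions, not the library's actual code: it presumes input is a complex-valued tensor, pools on the absolute value, and gathers the complex entries at the winning positions. The name complex_max_pool2d_sketch is invented for this sketch.

import torch
import torch.nn.functional as F

def complex_max_pool2d_sketch(input, kernel_size, stride=None, padding=0,
                              dilation=1, ceil_mode=False,
                              return_indices=False):
    # Sketch only: name and selection rule are assumptions, not library code.
    # Pool on the magnitude so the max is taken over |z|, then gather the
    # complex values at the argmax positions (real and imaginary parts stay
    # paired because they come from the same spatial location).
    _, indices = F.max_pool2d(input.abs(), kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              ceil_mode=ceil_mode, return_indices=True)
    flat = input.flatten(start_dim=2)                    # [N, C, H*W]
    output = flat.gather(2, indices.flatten(start_dim=2)).view_as(indices)
    return (output, indices) if return_indices else output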
Example #2
    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape
                [batch_size, self.in_chans, height, width, 2], where the last
                dimension holds the real and imaginary parts.
        Returns:
            (torch.Tensor, torch.Tensor): Real and imaginary output tensors,
                each of shape [batch_size, self.out_chans, height, width].
        """
        stack = []
        output_r = input[..., 0]
        output_i = input[..., 1]

        # Apply down-sampling layers; save each level's output for the
        # skip connections, then halve the spatial resolution.
        for layer in self.down_sample_layers:
            output_r, output_i = layer(output_r, output_i)
            stack.append([output_r, output_i])
            output_r, output_i = complex_max_pool2d(output_r,
                                                    output_i,
                                                    kernel_size=2,
                                                    stride=2,
                                                    padding=0)

        output_r, output_i = self.conv(output_r, output_i)

        # Apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output_r, output_i = transpose_conv(output_r, output_i)

            # Reflect pad on the right/bottom if needed to handle odd input dimensions.
            padding = [0, 0, 0, 0]
            if output_r.shape[-1] != downsample_layer[0].shape[-1]:
                padding[1] = 1  # Padding right
            if output_r.shape[-2] != downsample_layer[0].shape[-2]:
                padding[3] = 1  # Padding bottom
            if sum(padding) != 0:
                output_r = F.pad(output_r, padding, "reflect")
                output_i = F.pad(output_i, padding, "reflect")

            output_r = torch.cat([output_r, downsample_layer[0]], dim=1)
            output_i = torch.cat([output_i, downsample_layer[1]], dim=1)
            output_r, output_i = conv(output_r, output_i)

        if output_r.shape[1] > 1:
            # Data consistency: re-insert the known (non-zero) input values
            # before combining channels.
            mask_r = input[..., 0] != 0
            mask_i = input[..., 1] != 0
            output_r[mask_r] = input[..., 0][mask_r]
            output_i[mask_i] = input[..., 1][mask_i]
            output_r, output_i = self.combine_layer(output_r, output_i)
        output_r, output_i = self.f_layer1(output_r, output_i)
        output_r = output_r.squeeze(-1).unsqueeze(1)
        output_i = output_i.squeeze(-1).unsqueeze(1)
        output_r, output_i = self.f_layer2(output_r, output_i)
        output_r = output_r.squeeze(-1).unsqueeze(1)
        output_i = output_i.squeeze(-1).unsqueeze(1)

        return output_r, output_i
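Examples #2 and #3 call complex_max_pool2d with separate real and imaginary tensors. Below is a minimal sketch of that two-tensor variant, again assuming pooling positions are chosen by the complex magnitude; the function name and the exact selection rule are assumptions, not taken from the snippets.

import torch
import torch.nn.functional as F

def complex_max_pool2d_pair(xr, xi, kernel_size, stride=None, padding=0):
    # Sketch only: name and selection rule are assumptions, not library code.
    # Choose pooling positions by the complex magnitude so the real and
    # imaginary parts are always taken from the same spatial location.
    magnitude = torch.sqrt(xr ** 2 + xi ** 2)
    _, indices = F.max_pool2d(magnitude, kernel_size, stride=stride,
                              padding=padding, return_indices=True)
    idx = indices.flatten(start_dim=2)                   # [N, C, Hout*Wout]
    out_r = xr.flatten(start_dim=2).gather(2, idx).view_as(indices)
    out_i = xi.flatten(start_dim=2).gather(2, idx).view_as(indices)
    return out_r, out_i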
Example #3
    def forward(self, x):
        xr = x
        # Lift the real input to a complex pair with zero imaginary part.
        xi = torch.zeros_like(xr)
        xr, xi = self.conv1(xr, xi)
        xr, xi = complex_relu(xr, xi)
        xr, xi = complex_max_pool2d(xr, xi, 2, 2)

        xr, xi = self.bn(xr, xi)
        xr, xi = self.conv2(xr, xi)
        xr, xi = complex_relu(xr, xi)
        xr, xi = complex_max_pool2d(xr, xi, 2, 2)

        xr = xr.view(-1, 4 * 4 * 50)
        xi = xi.view(-1, 4 * 4 * 50)
        xr, xi = self.fc1(xr, xi)
        xr, xi = complex_relu(xr, xi)
        xr, xi = self.fc2(xr, xi)
        # take the absolute value as output
        x = torch.sqrt(torch.pow(xr, 2) + torch.pow(xi, 2))
        return F.log_softmax(x, dim=1)
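Example #3's fc1 and fc2 apply a linear layer to a (real, imaginary) pair. A common way to realize such a layer is with two real nn.Linear modules implementing complex multiplication; the sketch below illustrates the idea, without claiming it matches the layers used above, and the class name ComplexLinearSketch is invented for this example.

import torch.nn as nn

class ComplexLinearSketch(nn.Module):
    # Sketch only: name is an assumption, not taken from the snippet.
    # For complex weight W = Wr + i*Wi and input x = xr + i*xi:
    #     W x = (Wr xr - Wi xi) + i (Wr xi + Wi xr).
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc_r = nn.Linear(in_features, out_features)
        self.fc_i = nn.Linear(in_features, out_features)

    def forward(self, xr, xi):
        return (self.fc_r(xr) - self.fc_i(xi),
                self.fc_r(xi) + self.fc_i(xr))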