Пример #1
0
 def __init__(self, num_distr):
     """Instantiate every layer and parameter held by this network.

     Only construction happens here; how the layers are connected is
     defined outside this method.

     Args:
         num_distr: Stored under the ``'num_distr'`` key of the settings
             dict that is passed to ``JS.SURE_pure4D``.
     """
     super(ManifoldNetRes, self).__init__()

     # Layers taken from the project's `complex` module.
     kernel_5x1 = (5, 1)
     self.complex_conv1 = complex.ComplexConv2Deffangle4Dxy(
         1, 20, kernel_5x1, kernel_5x1)
     self.proj1 = complex.ReLU4Dsp(20)

     # SURE module from `JS`; its second argument is whatever calc_next
     # returns for these constants.
     sure_settings = dict(num_classes=11, num_distr=num_distr, num_repeat=2)
     sure_shape = calc_next(128, (5, 1), 5, 20)
     self.SURE = JS.SURE_pure4D(sure_settings, sure_shape, 20)

     # Activation and dropout (p = 0.5).
     self.relu = nn.ReLU()
     self.dropout = nn.Dropout(0.5)

     # Convolution / max-pooling layers: 20 -> 30, 40 -> 50 and 60 -> 80
     # channels. The 30 -> 40 and 50 -> 60 steps are covered by the
     # res*/id* layers created at the end of this method.
     pool_window = (2, 1)
     self.conv_1 = nn.Conv2d(20, 30, kernel_5x1)
     self.mp_1 = nn.MaxPool2d(pool_window)
     self.conv_2 = nn.Conv2d(40, 50, kernel_5x1)
     self.mp_2 = nn.MaxPool2d(pool_window)
     self.conv_3 = nn.Conv2d(60, 80, (3, 1))

     # One batch-norm per convolution, sized to that convolution's output.
     self.bn_1 = nn.BatchNorm2d(30)
     self.bn_2 = nn.BatchNorm2d(50)
     self.bn_3 = nn.BatchNorm2d(80)

     # Fully connected layers: 80 -> 40 -> 11.
     self.linear_2 = nn.Linear(80, 40)
     self.linear_3 = nn.Linear(40, 11)

     # Trainable one-element tensor, randomly initialised by torch.rand.
     initial_loss_weight = torch.rand(1)
     self.loss_weight = torch.nn.Parameter(
         initial_loss_weight, requires_grad=True)

     self.name = "Residual Network"

     # id1 / id2 are 1x1 convolutions (30 -> 40 and 50 -> 60 channels).
     # res1 / res2 wrap the layers returned by make_res_block for the same
     # channel pairs -- presumably the matching residual branches; confirm
     # against the forward pass.
     first_res_layers = self.make_res_block(30, 40)
     self.res1 = nn.Sequential(*first_res_layers)
     self.id1 = nn.Conv2d(30, 40, (1, 1))
     second_res_layers = self.make_res_block(50, 60)
     self.res2 = nn.Sequential(*second_res_layers)
     self.id2 = nn.Conv2d(50, 60, (1, 1))
Пример #2
0
 def __init__(self):
     """Instantiate every layer held by this network.

     Only construction happens here; how the layers are connected is
     defined outside this method. Unlike the sibling networks in this
     file, no ``JS.SURE_pure4D`` module is created.
     """
     super(ManifoldNetW1, self).__init__()

     # Layers taken from the project's `complex` module.
     kernel_5x1 = (5, 1)
     self.complex_conv1 = complex.ComplexConv2Deffangle4Dxy(
         1, 20, kernel_5x1, kernel_5x1)
     self.proj1 = complex.ReLU4Dsp(20)

     # Activation and dropout (p = 0.5).
     self.relu = nn.ReLU()
     self.dropout = nn.Dropout(0.5)

     # NOTE(review): the meaning of the 500 argument is defined by the
     # `complex` module and is not visible from here.
     self.linear_1 = complex.ComplexLinearangle4Dmw_outfield(500)

     # Convolution / max-pooling layers: 20 -> 40 -> 60 -> 80 channels.
     pool_window = (2, 1)
     kernel_3x1 = (3, 1)
     self.conv_1 = nn.Conv2d(20, 40, (7, 1))
     self.mp_1 = nn.MaxPool2d(pool_window)
     self.conv_2 = nn.Conv2d(40, 60, kernel_3x1)
     self.mp_2 = nn.MaxPool2d(pool_window)
     self.conv_3 = nn.Conv2d(60, 80, kernel_3x1)

     # One batch-norm per convolution, sized to that convolution's output.
     self.bn_1 = nn.BatchNorm2d(40)
     self.bn_2 = nn.BatchNorm2d(60)
     self.bn_3 = nn.BatchNorm2d(80)

     # Fully connected layers: 80 -> 40 -> 11.
     self.linear_2 = nn.Linear(80, 40)
     self.linear_3 = nn.Linear(40, 11)

     self.name = "Without Shrinkage"
Пример #3
0
 def __init__(self, num_distr):
     """Instantiate every layer and parameter held by this network.

     Only construction happens here; how the layers are connected is
     defined outside this method.

     Args:
         num_distr: Stored under the ``'num_distr'`` key of the settings
             dict that is passed to ``JS.SURE_pure4D``.
     """
     super(STFT2, self).__init__()

     # Layers taken from the project's `complex` module.
     kernel_1x3 = (1, 3)
     self.complex_conv1 = complex.ComplexConv2Deffangle4Dxy(
         1, 20, kernel_1x3, kernel_1x3)
     self.proj1 = complex.ReLU4Dsp(20)

     # SURE module from `JS`, given a hard-coded torch.Size as its second
     # argument.
     sure_settings = dict(num_classes=11, num_distr=num_distr, num_repeat=2)
     sure_shape = torch.Size([2, 20, 8, 11])
     self.SURE = JS.SURE_pure4D(sure_settings, sure_shape, 20)

     # Activation and dropout (p = 0.5).
     self.relu = nn.ReLU()
     self.dropout = nn.Dropout(0.5)

     # Convolution / max-pooling layers: 20 -> 40 -> 60 -> 80 channels.
     # The two pooling windows differ: (1, 2) first, then (2, 1).
     kernel_3x3 = (3, 3)
     self.conv_1 = nn.Conv2d(20, 40, kernel_3x3)
     self.mp_1 = nn.MaxPool2d((1, 2))
     self.conv_2 = nn.Conv2d(40, 60, kernel_3x3)
     self.mp_2 = nn.MaxPool2d((2, 1))
     self.conv_3 = nn.Conv2d(60, 80, (2, 2))

     # One batch-norm per convolution, sized to that convolution's output.
     self.bn_1 = nn.BatchNorm2d(40)
     self.bn_2 = nn.BatchNorm2d(60)
     self.bn_3 = nn.BatchNorm2d(80)

     # Fully connected layers: 80 -> 40 -> 11.
     self.linear_2 = nn.Linear(80, 40)
     self.linear_3 = nn.Linear(40, 11)

     # Trainable one-element tensor, randomly initialised by torch.rand.
     initial_loss_weight = torch.rand(1)
     self.loss_weight = torch.nn.Parameter(
         initial_loss_weight, requires_grad=True)

     self.name = "Regular Network STFT"