Example #1
0
    def __init__(self, args, p_array):
        """Build the per-iteration GRU decoders of the two-stream turbo decoder.

        For each of ``args.num_iteration`` iterations this creates two
        bidirectional, two-layer GRUs (``dec1_rnns`` / ``dec2_rnns``) that take
        ``1 + args.num_iter_ft`` input features, plus one linear head per GRU
        (``dec1_outputs`` / ``dec2_outputs``) reading ``2 * args.dec_num_unit``
        features. The second head emits ``args.num_iter_ft`` values on every
        iteration except the last, where it emits a single value.

        Args:
            args: configuration namespace; this constructor reads ``no_cuda``,
                ``num_iteration``, ``num_iter_ft``, ``dec_num_unit`` and
                ``dropout``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(DEC_LargeRNN_rate2, self).__init__()
        self.args = args

        gpu_wanted = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if gpu_wanted else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        # Assignment order below fixes the submodule registration order
        # (and therefore the parameters() / state_dict ordering).
        self.dec1_rnns = torch.nn.ModuleList()
        self.dec2_rnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        def new_gru():
            return torch.nn.GRU(1 + args.num_iter_ft, args.dec_num_unit,
                                num_layers=2, bias=True, batch_first=True,
                                dropout=args.dropout, bidirectional=True)

        final_iter = args.num_iteration - 1
        for it in range(args.num_iteration):
            # Creation order (GRU1, GRU2, head1, head2) is kept so seeded
            # weight initialisation draws are unchanged.
            self.dec1_rnns.append(new_gru())
            self.dec2_rnns.append(new_gru())
            # A bidirectional GRU emits 2 * dec_num_unit features per step.
            head_in = 2 * args.dec_num_unit
            self.dec1_outputs.append(torch.nn.Linear(head_in, args.num_iter_ft))
            head2_out = 1 if it == final_iter else args.num_iter_ft
            self.dec2_outputs.append(torch.nn.Linear(head_in, head2_out))
Example #2
0
    def __init__(self, args, p_array):
        """Build per-iteration 1D-CNN component decoders and their linear heads.

        For each of ``args.num_iteration`` iterations, two ``SameShapeConv1d``
        stacks with ``2 + args.num_iter_ft`` input channels and
        ``args.dec_num_unit`` output channels are created, each followed by a
        linear head. The second head emits ``args.num_iter_ft`` values except
        on the final iteration, where it emits one value.

        Args:
            args: configuration namespace; reads ``no_cuda``, ``num_iteration``,
                ``num_iter_ft``, ``dec_num_layer``, ``dec_num_unit`` and
                ``dec_kernel_size``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(DEC_LargeCNN, self).__init__()
        self.args = args

        gpu_wanted = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if gpu_wanted else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        # One entry per iteration; assignment order fixes registration order.
        self.dec1_cnns = torch.nn.ModuleList()
        self.dec2_cnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        def conv_stage():
            return SameShapeConv1d(num_layer=args.dec_num_layer,
                                   in_channels=2 + args.num_iter_ft,
                                   out_channels=args.dec_num_unit,
                                   kernel_size=args.dec_kernel_size)

        for it in range(args.num_iteration):
            self.dec1_cnns.append(conv_stage())
            self.dec2_cnns.append(conv_stage())
            self.dec1_outputs.append(
                torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))
            # Only the very last head collapses to a single output per position.
            is_last = it == args.num_iteration - 1
            self.dec2_outputs.append(
                torch.nn.Linear(args.dec_num_unit, 1 if is_last else args.num_iter_ft))
Example #3
0
    def __init__(self, args, p_array):
        """Build the 1D-CNN turbo decoder stages plus two SameShapeConv2d stacks.

        Per iteration: two ``SameShapeConv1d`` component decoders with
        ``2 + args.num_iter_ft`` input channels, each followed by a linear
        head. The second head emits ``args.code_rate_k`` values on the last
        iteration and ``args.num_iter_ft`` otherwise. ``ftstart`` maps
        ``args.img_channels`` to ``args.dec_num_unit`` channels and ``ftend``
        maps back.

        Args:
            args: configuration namespace. ``no_cuda`` is optional here and is
                treated as False when absent.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(TurboAE_decoder1D, self).__init__()
        self.args = args
        # Honour --no_cuda like the sibling decoders. Checking only CUDA
        # availability put this_device on CUDA even for CPU-only runs.
        # getattr keeps configs without a ``no_cuda`` field working as before.
        use_cuda = not getattr(args, 'no_cuda', False) and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if use_cuda else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        self.dec1_cnns = torch.nn.ModuleList()
        self.dec2_cnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        for idx in range(args.num_iteration):
            self.dec1_cnns.append(
                SameShapeConv1d(num_layer=args.dec_num_layer,
                                in_channels=2 + args.num_iter_ft,
                                out_channels=args.dec_num_unit,
                                kernel_size=args.dec_kernel_size))

            self.dec2_cnns.append(
                SameShapeConv1d(num_layer=args.dec_num_layer,
                                in_channels=2 + args.num_iter_ft,
                                out_channels=args.dec_num_unit,
                                kernel_size=args.dec_kernel_size))
            self.dec1_outputs.append(
                torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))

            if idx == args.num_iteration - 1:
                self.dec2_outputs.append(
                    torch.nn.Linear(args.dec_num_unit, args.code_rate_k))
            else:
                self.dec2_outputs.append(
                    torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))

        # Extra 2D CNNs (original note: "also need some CNN for f"):
        # img_channels -> dec_num_unit, then dec_num_unit -> img_channels.
        self.ftstart = SameShapeConv2d(num_layer=args.dec_num_layer,
                                       in_channels=args.img_channels,
                                       out_channels=args.dec_num_unit,
                                       kernel_size=args.dec_kernel_size)
        self.ftend = SameShapeConv2d(num_layer=1,
                                     in_channels=args.dec_num_unit,
                                     out_channels=args.img_channels,
                                     kernel_size=args.dec_kernel_size)
Example #4
0
    def __init__(self, args, p_array):
        """Build a turbo decoder whose two CNN stages are shared by all iterations.

        Unlike the per-iteration ModuleList decoders, this creates exactly one
        CNN and one linear head for each component decoder, plus
        ``final_outputs`` mapping ``args.num_iter_ft`` features to one value.
        The CNN class is ``DenseSameShapeConv1d`` when ``args.cnn_type`` is
        ``'dense'`` and ``SameShapeConv1d`` otherwise.

        Args:
            args: configuration namespace.
            p_array: permutation stored on ``self.p_array`` and handed to the
                interleaver / deinterleaver.
        """
        super(FTAE_Shareddecoder, self).__init__()

        on_gpu = torch.cuda.is_available() if not args.no_cuda else False
        self.this_device = torch.device("cuda" if on_gpu else "cpu")

        self.args = args
        # Permutation plus the layers that apply / undo it.
        self.p_array = p_array
        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        CNNModel = DenseSameShapeConv1d if args.cnn_type == 'dense' else SameShapeConv1d

        def conv_stage():
            return CNNModel(num_layer=args.dec_num_layer,
                            in_channels=2 + args.num_iter_ft,
                            out_channels=args.dec_num_unit,
                            kernel_size=args.dec_kernel_size)

        # Creation order (cnn1, head1, cnn2, head2, final) is kept as-is: it
        # fixes registration order and which RNG draws each layer gets.
        self.dec1_cnns = conv_stage()
        self.dec1_outputs = torch.nn.Linear(args.dec_num_unit, args.num_iter_ft)
        self.dec2_cnns = conv_stage()
        self.dec2_outputs = torch.nn.Linear(args.dec_num_unit, args.num_iter_ft)

        self.final_outputs = torch.nn.Linear(args.num_iter_ft, 1)
Example #5
0
    def __init__(self, args, p_array):
        """Build the single shared GRU decoder used for every half-iteration.

        ``dec_rnn`` is a two-layer bidirectional GRU with
        ``args.code_rate_n + args.num_iter_ft - 1`` input features. ``dec_out``
        maps its ``2 * args.dec_num_unit`` outputs to ``args.num_iter_ft``
        features, and ``dec_final`` maps those to one value. The chosen device
        is stored on ``self.device``.

        Args:
            args: configuration namespace.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(NeuralTurbofyDec, self).__init__()

        self.args = args

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        hidden = args.dec_num_unit
        # The input width is the received streams plus num_iter_ft prior
        # features (2 + num_iter_ft when code_rate_n == 3).
        rnn_in = args.code_rate_n + args.num_iter_ft - 1
        self.dec_rnn = torch.nn.GRU(rnn_in, hidden, num_layers=2, bias=True,
                                    batch_first=True, dropout=args.dropout,
                                    bidirectional=True)
        self.dec_out = torch.nn.Linear(2 * hidden, args.num_iter_ft)

        self.dec_final = torch.nn.Linear(args.num_iter_ft, 1)

        on_gpu = torch.cuda.is_available() if not args.no_cuda else False
        self.device = torch.device("cuda" if on_gpu else "cpu")
Example #6
0
    def __init__(self, args, p_array):
        """Build a CNN turbo decoder that uses two different interleavers.

        ``interleaver1`` / ``deinterleaver1`` use the supplied *p_array*.
        ``interleaver2`` / ``deinterleaver2`` use a second permutation of
        ``arange(args.block_len)`` drawn from ``mtrand.RandomState(1000)``.
        Both permutations are printed. The per-iteration CNN stages and heads
        mirror the single-interleaver CNN decoder: the second head emits one
        value on the final iteration and ``args.num_iter_ft`` otherwise.

        Args:
            args: configuration namespace.
            p_array: permutation for the first interleaver pair.
        """
        super(DEC_LargeCNN2Int, self).__init__()
        self.args = args

        gpu_wanted = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if gpu_wanted else "cpu")

        self.interleaver1 = Interleaver(args, p_array)
        self.deinterleaver1 = DeInterleaver(args, p_array)

        # A fixed seed makes the second permutation reproducible across runs.
        # NOTE(review): presumably it must match the encoder's second
        # interleaver — confirm.
        second_seed = 1000
        p_array2 = mtrand.RandomState(second_seed).permutation(arange(args.block_len))

        print('p_array1 dec', p_array)
        print('p_array2 dec', p_array2)

        self.interleaver2 = Interleaver(args, p_array2)
        self.deinterleaver2 = DeInterleaver(args, p_array2)

        self.dec1_cnns = torch.nn.ModuleList()
        self.dec2_cnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        def conv_stage():
            return SameShapeConv1d(num_layer=args.dec_num_layer,
                                   in_channels=2 + args.num_iter_ft,
                                   out_channels=args.dec_num_unit,
                                   kernel_size=args.dec_kernel_size)

        for it in range(args.num_iteration):
            self.dec1_cnns.append(conv_stage())
            self.dec2_cnns.append(conv_stage())
            self.dec1_outputs.append(
                torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))
            out_width = 1 if it == args.num_iteration - 1 else args.num_iter_ft
            self.dec2_outputs.append(torch.nn.Linear(args.dec_num_unit, out_width))
Example #7
0
    def __init__(self, args, p_array):
        """Build the per-iteration component decoders of the FTAE decoder.

        When ``args.codec == 'turboae_blockdelay_cnn'``, each stage is a CNN
        (``DenseSameShapeConv1d`` if ``args.cnn_type == 'dense'``, else
        ``SameShapeConv1d``) with ``args.dec_num_unit``-wide heads. Otherwise
        each stage is a bidirectional GRU with ``2 * args.dec_num_unit``-wide
        heads. Despite their names, the ``dec*_cnns`` lists hold GRUs in the
        RNN case. The second head emits ``args.code_rate_k`` values on the
        final iteration and ``args.num_iter_ft`` otherwise.

        Args:
            args: configuration namespace.
            p_array: permutation stored on ``self.p_array`` and handed to the
                interleaver / deinterleaver.
        """
        super(FTAE_decoder, self).__init__()

        on_gpu = torch.cuda.is_available() if not args.no_cuda else False
        self.this_device = torch.device("cuda" if on_gpu else "cpu")

        self.args = args
        # Permutation plus the layers that apply / undo it.
        self.p_array = p_array
        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        # Decoder stages and heads, one entry per iteration.
        self.dec1_cnns = torch.nn.ModuleList()
        self.dec2_cnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        CNNModel = DenseSameShapeConv1d if args.cnn_type == 'dense' else SameShapeConv1d

        def cnn_stage():
            return CNNModel(num_layer=args.dec_num_layer,
                            in_channels=2 + args.num_iter_ft,
                            out_channels=args.dec_num_unit,
                            kernel_size=args.dec_kernel_size)

        def rnn_stage():
            return torch.nn.GRU(2 + args.num_iter_ft, args.dec_num_unit,
                                num_layers=args.dec_num_layer, bias=True,
                                batch_first=True, dropout=0, bidirectional=True)

        for it in range(args.num_iteration):
            uses_cnn = self.args.codec == 'turboae_blockdelay_cnn'
            make_stage = cnn_stage if uses_cnn else rnn_stage
            # A bidirectional GRU emits twice dec_num_unit features per step.
            head_in = args.dec_num_unit if uses_cnn else 2 * args.dec_num_unit

            self.dec1_cnns.append(make_stage())
            self.dec2_cnns.append(make_stage())
            self.dec1_outputs.append(torch.nn.Linear(head_in, args.num_iter_ft))

            head2_out = args.code_rate_k if it == args.num_iteration - 1 else args.num_iter_ft
            self.dec2_outputs.append(torch.nn.Linear(head_in, head2_out))
Example #8
0
class DEC_LargeCNN(torch.nn.Module):
    """Iterative turbo-style decoder with separate 1D-CNN stages per iteration.

    Each iteration runs two component decoders. The first sees the systematic
    stream, parity stream 1 and the current prior. The second sees the
    interleaved systematic stream, parity stream 2 and the interleaved output
    of the first. When ``args.extrinsic`` is set, each component's input
    belief is subtracted from its output before it is passed on (except after
    the final second-decoder head).
    """

    def __init__(self, args, p_array):
        """Create the (de)interleavers and the per-iteration CNN stages/heads.

        Args:
            args: configuration namespace; reads ``no_cuda``, ``num_iteration``,
                ``num_iter_ft``, ``dec_num_layer``, ``dec_num_unit`` and
                ``dec_kernel_size``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(DEC_LargeCNN, self).__init__()
        self.args = args

        gpu_wanted = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if gpu_wanted else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        # One entry per iteration; assignment order fixes registration order.
        self.dec1_cnns = torch.nn.ModuleList()
        self.dec2_cnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        def conv_stage():
            return SameShapeConv1d(num_layer=args.dec_num_layer,
                                   in_channels=2 + args.num_iter_ft,
                                   out_channels=args.dec_num_unit,
                                   kernel_size=args.dec_kernel_size)

        for it in range(args.num_iteration):
            self.dec1_cnns.append(conv_stage())
            self.dec2_cnns.append(conv_stage())
            self.dec1_outputs.append(
                torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))
            # Only the very last head collapses to a single output per position.
            out_width = 1 if it == args.num_iteration - 1 else args.num_iter_ft
            self.dec2_outputs.append(torch.nn.Linear(args.dec_num_unit, out_width))

    def set_parallel(self):
        """Wrap every per-iteration stage and head in ``torch.nn.DataParallel``."""
        for it in range(self.args.num_iteration):
            for bank in (self.dec1_cnns, self.dec2_cnns,
                         self.dec1_outputs, self.dec2_outputs):
                bank[it] = torch.nn.DataParallel(bank[it])

    def set_interleaver(self, p_array):
        """Install a new permutation in both the interleaver and the deinterleaver."""
        for perm_layer in (self.interleaver, self.deinterleaver):
            perm_layer.set_parray(p_array)

    def forward(self, received):
        """Run ``args.num_iteration`` decoding iterations over *received*.

        Args:
            received: tensor whose last axis holds the systematic stream
                (index 0) and parity streams 1 and 2 (indices 1 and 2). It is
                cast to float32 and moved to ``self.this_device``, and each
                stream is reshaped to ``(args.batch_size, args.block_len, 1)``.

        Returns:
            Sigmoid of the deinterleaved output of the final second-decoder head.
        """
        received = received.type(torch.FloatTensor).to(self.this_device)
        cfg = self.args
        stream_shape = (cfg.batch_size, cfg.block_len, 1)

        r_sys = received[:, :, 0].view(stream_shape)
        r_sys_int = self.interleaver(r_sys)
        r_par1 = received[:, :, 1].view(stream_shape)
        r_par2 = received[:, :, 2].view(stream_shape)

        prior = torch.zeros((cfg.batch_size, cfg.block_len,
                             cfg.num_iter_ft)).to(self.this_device)

        def first_half(i, belief):
            # Component decoder 1; returns its (optionally extrinsic) output,
            # already interleaved for component decoder 2.
            feats = self.dec1_cnns[i](torch.cat([r_sys, r_par1, belief], dim=2))
            plr = self.dec1_outputs[i](feats)
            if cfg.extrinsic:
                plr = plr - belief
            return self.interleaver(plr)

        def second_half(i, belief_int):
            # Component decoder 2, working in the interleaved domain.
            feats = self.dec2_cnns[i](torch.cat([r_sys_int, r_par2, belief_int], dim=2))
            return self.dec2_outputs[i](feats)

        for i in range(cfg.num_iteration - 1):
            plr_int = first_half(i, prior)
            plr = second_half(i, plr_int)
            if cfg.extrinsic:
                plr = plr - plr_int
            prior = self.deinterleaver(plr)

        # Last round: the final head's output is used as-is (no subtraction).
        final_i = cfg.num_iteration - 1
        logits = second_half(final_i, first_half(final_i, prior))
        return torch.sigmoid(self.deinterleaver(logits))
Example #9
0
class DEC_LargeRNN(torch.nn.Module):
    """Iterative turbo-style decoder built from per-iteration bidirectional RNNs.

    The RNN type is chosen by ``args.dec_rnn`` ('gru', 'lstm', anything else
    selects a plain RNN). Each iteration runs two component decoders: the
    first sees the systematic stream, parity stream 1 and the prior; the
    second sees the interleaved systematic stream, parity stream 2 and the
    interleaved output of the first. Every head output passes through
    dropout and then ``dec_act``.
    """

    def __init__(self, args, p_array):
        """Create the (de)interleavers, dropout and per-iteration RNN stages/heads.

        Args:
            args: configuration namespace; reads ``no_cuda``, ``dec_rnn``,
                ``dropout``, ``num_iteration``, ``num_iter_ft`` and
                ``dec_num_unit``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(DEC_LargeRNN, self).__init__()
        self.args = args

        use_cuda = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if use_cuda else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        if args.dec_rnn == 'gru':
            RNN_MODEL = torch.nn.GRU
        elif args.dec_rnn == 'lstm':
            RNN_MODEL = torch.nn.LSTM
        else:
            RNN_MODEL = torch.nn.RNN

        self.dropout = torch.nn.Dropout(args.dropout)

        self.dec1_rnns = torch.nn.ModuleList()
        self.dec2_rnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        for idx in range(args.num_iteration):
            self.dec1_rnns.append(
                RNN_MODEL(2 + args.num_iter_ft,
                          args.dec_num_unit,
                          num_layers=2,
                          bias=True,
                          batch_first=True,
                          dropout=args.dropout,
                          bidirectional=True))

            self.dec2_rnns.append(
                RNN_MODEL(2 + args.num_iter_ft,
                          args.dec_num_unit,
                          num_layers=2,
                          bias=True,
                          batch_first=True,
                          dropout=args.dropout,
                          bidirectional=True))

            # Bidirectional RNN output is 2 * dec_num_unit wide.
            self.dec1_outputs.append(
                torch.nn.Linear(2 * args.dec_num_unit, args.num_iter_ft))

            if idx == args.num_iteration - 1:
                self.dec2_outputs.append(
                    torch.nn.Linear(2 * args.dec_num_unit, 1))
            else:
                self.dec2_outputs.append(
                    torch.nn.Linear(2 * args.dec_num_unit, args.num_iter_ft))

    def dec_act(self, inputs):
        """Apply the activation named by ``args.dec_act`` to *inputs*.

        Supported names are 'tanh', 'elu', 'relu', 'selu' and 'sigmoid'.
        Any other name, including 'linear', returns *inputs* unchanged.
        """
        act = self.args.dec_act
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        if act == 'tanh':
            return torch.tanh(inputs)
        if act == 'elu':
            return F.elu(inputs)
        if act == 'relu':
            return F.relu(inputs)
        if act == 'selu':
            return F.selu(inputs)
        if act == 'sigmoid':
            return torch.sigmoid(inputs)
        return inputs

    def set_parallel(self):
        """Wrap every per-iteration RNN and head in ``torch.nn.DataParallel``."""
        for idx in range(self.args.num_iteration):
            self.dec1_rnns[idx] = torch.nn.DataParallel(self.dec1_rnns[idx])
            self.dec2_rnns[idx] = torch.nn.DataParallel(self.dec2_rnns[idx])
            self.dec1_outputs[idx] = torch.nn.DataParallel(
                self.dec1_outputs[idx])
            self.dec2_outputs[idx] = torch.nn.DataParallel(
                self.dec2_outputs[idx])

    def set_interleaver(self, p_array):
        """Install a new permutation in both the interleaver and the deinterleaver."""
        self.interleaver.set_parray(p_array)
        self.deinterleaver.set_parray(p_array)

    def forward(self, received):
        """Run ``args.num_iteration`` decoding iterations over *received*.

        Args:
            received: tensor whose last axis holds the systematic stream
                (index 0) and parity streams 1 and 2 (indices 1 and 2). It is
                cast to float32, moved to ``self.this_device``, and each stream
                is reshaped to ``(args.batch_size, args.block_len, 1)``.

        Returns:
            Sigmoid of the deinterleaved output of the final second-decoder head.
        """
        cfg = self.args
        received = received.type(torch.FloatTensor).to(self.this_device)
        stream_shape = (cfg.batch_size, cfg.block_len, 1)

        # Turbo Decoder
        r_sys = received[:, :, 0].view(stream_shape)
        r_sys_int = self.interleaver(r_sys)
        r_par1 = received[:, :, 1].view(stream_shape)
        r_par2 = received[:, :, 2].view(stream_shape)

        prior = torch.zeros((cfg.batch_size, cfg.block_len,
                             cfg.num_iter_ft)).to(self.this_device)

        def component(rnns, heads, idx, x):
            # One component decoder: RNN -> linear head -> dropout -> activation.
            if cfg.is_parallel:
                # is_parallel assumes set_parallel() wrapped the RNN in
                # DataParallel (hence ``.module``).
                rnns[idx].module.flatten_parameters()
            feats, _ = rnns[idx](x)
            return self.dec_act(self.dropout(heads[idx](feats)))

        for idx in range(cfg.num_iteration - 1):
            x_plr = component(self.dec1_rnns, self.dec1_outputs, idx,
                              torch.cat([r_sys, r_par1, prior], dim=2))
            if cfg.extrinsic:
                x_plr = x_plr - prior
            x_plr_int = self.interleaver(x_plr)

            x_plr = component(self.dec2_rnns, self.dec2_outputs, idx,
                              torch.cat([r_sys_int, r_par2, x_plr_int], dim=2))
            if cfg.extrinsic:
                x_plr = x_plr - x_plr_int
            prior = self.deinterleaver(x_plr)

        # Last round: the final head's output is not made extrinsic.
        last = cfg.num_iteration - 1
        x_plr = component(self.dec1_rnns, self.dec1_outputs, last,
                          torch.cat([r_sys, r_par1, prior], dim=2))
        if cfg.extrinsic:
            x_plr = x_plr - prior
        x_plr_int = self.interleaver(x_plr)

        x_plr = component(self.dec2_rnns, self.dec2_outputs, last,
                          torch.cat([r_sys_int, r_par2, x_plr_int], dim=2))

        logit = self.deinterleaver(x_plr)
        return torch.sigmoid(logit)
Example #10
0
class NeuralTurbofyDec(torch.nn.Module):
    """Turbo-style iterative decoder that reuses one GRU for every half-iteration.

    Both component decoders and all iterations share ``dec_rnn`` / ``dec_out``.
    ``dec_final`` maps the last ``num_iter_ft`` features to a single value
    per position.
    """

    def __init__(self, args, p_array):
        """Build the shared GRU, its heads and the (de)interleavers.

        Args:
            args: configuration namespace; reads ``code_rate_n``,
                ``num_iter_ft``, ``dec_num_unit``, ``dropout`` and ``no_cuda``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(NeuralTurbofyDec, self).__init__()

        self.args = args

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        self.dec_rnn = torch.nn.GRU(args.code_rate_n + args.num_iter_ft - 1, args.dec_num_unit,
                                    num_layers=2, bias=True, batch_first=True,
                                    dropout=args.dropout, bidirectional=True)
        self.dec_out = torch.nn.Linear(2 * args.dec_num_unit, args.num_iter_ft)

        self.dec_final = torch.nn.Linear(args.num_iter_ft, 1)

        use_cuda = not args.no_cuda and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

    def enc_act(self, inputs):
        """Apply the activation named by ``args.enc_act`` to *inputs*.

        Supported names are 'tanh', 'elu', 'relu', 'selu' and 'sigmoid'.
        A missing or unknown name returns *inputs* unchanged. Not called by
        ``forward`` in this class.
        """
        # Read the name from args: comparing ``self.enc_act`` (this bound
        # method) with a string is never true, which made every branch dead.
        act = getattr(self.args, 'enc_act', None)
        if act == 'tanh':
            return torch.tanh(inputs)
        if act == 'elu':
            return F.elu(inputs)
        if act == 'relu':
            return F.relu(inputs)
        if act == 'selu':
            return F.selu(inputs)
        if act == 'sigmoid':
            return torch.sigmoid(inputs)
        return inputs

    def set_interleaver(self, p_array):
        """Install a new permutation in both the interleaver and the deinterleaver."""
        self.interleaver.set_parray(p_array)
        self.deinterleaver.set_parray(p_array)

    def forward(self, inputs):
        """Run ``args.num_iteration`` decoding iterations with the shared GRU.

        Args:
            inputs: tensor whose last axis holds the systematic stream
                (index 0) and parity streams 1 and 2 (indices 1 and 2). It is
                cast to float32 and moved to ``self.device``. The batch size
                is taken from ``inputs.shape[0]``.

        Returns:
            Deinterleaved sigmoid of ``dec_final`` applied to the last
            ``dec_out`` features.
        """
        inputs = inputs.type(torch.FloatTensor).to(self.device)
        batch = inputs.shape[0]
        stream_shape = (batch, self.args.block_len, 1)

        # Turbo Decoder
        r_sys = inputs[:, :, 0].view(stream_shape)
        r_sys_int = self.interleaver(r_sys)
        r_par1 = inputs[:, :, 1].view(stream_shape)
        r_par2 = inputs[:, :, 2].view(stream_shape)

        prior = torch.zeros((batch, self.args.block_len, self.args.num_iter_ft)).to(self.device)

        # ``extrinsic`` means "subtract the input belief", matching every
        # sibling decoder; this test used to be inverted (``if not ...``).
        for idx in range(self.args.num_iteration - 1):
            x_this_dec = torch.cat([r_sys, r_par1, prior], dim=2)
            x_dec, _ = self.dec_rnn(x_this_dec)
            x_plr = self.dec_out(x_dec)

            if self.args.extrinsic:
                x_plr = x_plr - prior

            x_plr_int = self.interleaver(x_plr)

            x_this_dec = torch.cat([r_sys_int, r_par2, x_plr_int], dim=2)
            x_dec, _ = self.dec_rnn(x_this_dec)
            x_plr = self.dec_out(x_dec)

            if self.args.extrinsic:
                x_plr = x_plr - x_plr_int

            prior = self.deinterleaver(x_plr)

        # last round
        x_this_dec = torch.cat([r_sys, r_par1, prior], dim=2)
        x_dec, _ = self.dec_rnn(x_this_dec)
        x_plr = self.dec_out(x_dec)

        if self.args.extrinsic:
            x_plr = x_plr - prior

        x_plr_int = self.interleaver(x_plr)

        x_this_dec = torch.cat([r_sys_int, r_par2, x_plr_int], dim=2)
        x_dec, _ = self.dec_rnn(x_this_dec)
        x_dec = self.dec_out(x_dec)
        x_final = self.dec_final(x_dec)
        # Sigmoid is elementwise, so applying it before deinterleaving is fine.
        x_plr = torch.sigmoid(x_final)

        return self.deinterleaver(x_plr)
Example #11
0
class DEC_LargeRNN_rate2(torch.nn.Module):
    """Iterative two-stream turbo-style decoder built from bidirectional GRUs.

    ``received[:, :, 0]`` feeds the first component decoder together with the
    current prior. ``received[:, :, 1]`` feeds the second component decoder,
    which works on the interleaved output of the first. NOTE(review): the
    second stream is presumably the interleaved systematic/parity stream —
    confirm against the encoder.
    """

    def __init__(self, args, p_array):
        """Create the (de)interleavers and the per-iteration GRU stages/heads.

        Args:
            args: configuration namespace; reads ``no_cuda``, ``num_iteration``,
                ``num_iter_ft``, ``dec_num_unit`` and ``dropout``.
            p_array: permutation handed to ``Interleaver`` / ``DeInterleaver``.
        """
        super(DEC_LargeRNN_rate2, self).__init__()
        self.args = args

        use_cuda = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if use_cuda else "cpu")

        self.interleaver = Interleaver(args, p_array)
        self.deinterleaver = DeInterleaver(args, p_array)

        self.dec1_rnns = torch.nn.ModuleList()
        self.dec2_rnns = torch.nn.ModuleList()
        self.dec1_outputs = torch.nn.ModuleList()
        self.dec2_outputs = torch.nn.ModuleList()

        for idx in range(args.num_iteration):

            self.dec1_rnns.append(torch.nn.GRU(1 + args.num_iter_ft, args.dec_num_unit,
                                               num_layers=2, bias=True, batch_first=True,
                                               dropout=args.dropout, bidirectional=True))

            self.dec2_rnns.append(torch.nn.GRU(1 + args.num_iter_ft, args.dec_num_unit,
                                               num_layers=2, bias=True, batch_first=True,
                                               dropout=args.dropout, bidirectional=True))

            # Bidirectional GRU output is 2 * dec_num_unit wide.
            self.dec1_outputs.append(torch.nn.Linear(2 * args.dec_num_unit, args.num_iter_ft))

            if idx == args.num_iteration - 1:
                self.dec2_outputs.append(torch.nn.Linear(2 * args.dec_num_unit, 1))
            else:
                self.dec2_outputs.append(torch.nn.Linear(2 * args.dec_num_unit, args.num_iter_ft))

    def set_interleaver(self, p_array):
        """Install a new permutation in both the interleaver and the deinterleaver."""
        self.interleaver.set_parray(p_array)
        self.deinterleaver.set_parray(p_array)

    def set_parallel(self):
        """Wrap every per-iteration GRU and head in ``torch.nn.DataParallel``."""
        for idx in range(self.args.num_iteration):
            self.dec1_rnns[idx] = torch.nn.DataParallel(self.dec1_rnns[idx])
            self.dec2_rnns[idx] = torch.nn.DataParallel(self.dec2_rnns[idx])
            self.dec1_outputs[idx] = torch.nn.DataParallel(self.dec1_outputs[idx])
            self.dec2_outputs[idx] = torch.nn.DataParallel(self.dec2_outputs[idx])

    def forward(self, received):
        """Run ``args.num_iteration`` decoding iterations over *received*.

        Args:
            received: tensor indexed as ``received[:, :, 0]`` and
                ``received[:, :, 1]``. It is cast to float32 and moved to
                ``self.this_device``. The batch size is taken from
                ``received.shape[0]``, so a final batch smaller than
                ``args.batch_size`` also decodes.

        Returns:
            Sigmoid of the deinterleaved output of the final second-decoder head.
        """
        cfg = self.args
        # Same input normalisation as the sibling decoders. Without it, a
        # float64 or CPU-resident tensor clashes with the float32 on-device
        # prior and GRU weights.
        received = received.type(torch.FloatTensor).to(self.this_device)
        batch = received.shape[0]
        stream_shape = (batch, cfg.block_len, 1)

        # Turbo Decoder
        r_sys = received[:, :, 0].view(stream_shape)
        r_int = received[:, :, 1].view(stream_shape)

        prior = torch.zeros((batch, cfg.block_len, cfg.num_iter_ft)).to(self.this_device)

        for idx in range(cfg.num_iteration - 1):
            x_dec, _ = self.dec1_rnns[idx](torch.cat([r_sys, prior], dim=2))
            x_plr = self.dec1_outputs[idx](x_dec)

            if cfg.extrinsic:
                x_plr = x_plr - prior

            x_plr_int = self.interleaver(x_plr)

            x_dec, _ = self.dec2_rnns[idx](torch.cat([r_int, x_plr_int], dim=2))
            x_plr = self.dec2_outputs[idx](x_dec)

            if cfg.extrinsic:
                x_plr = x_plr - x_plr_int

            prior = self.deinterleaver(x_plr)

        # Last round: the final head's output is not made extrinsic.
        last = cfg.num_iteration - 1
        x_dec, _ = self.dec1_rnns[last](torch.cat([r_sys, prior], dim=2))
        x_plr = self.dec1_outputs[last](x_dec)

        if cfg.extrinsic:
            x_plr = x_plr - prior

        x_plr_int = self.interleaver(x_plr)

        x_dec, _ = self.dec2_rnns[last](torch.cat([r_int, x_plr_int], dim=2))
        x_plr = self.dec2_outputs[last](x_dec)

        return torch.sigmoid(self.deinterleaver(x_plr))