Example #1
0
class Beam_Classifier(nn.Module):
    """Score ``n_narrow_beam`` narrow beams from wide-beam power measurements.

    The input is beamformed with ``n_wide_beam`` wide beams, either through a
    trainable ``PhaseShifter`` or through a fixed DFT block-matrix codebook.
    The per-beam power is then computed and passed through a three-layer MLP.
    The MLP produces ``n_narrow_beam`` unnormalized scores (logits).

    Args:
        n_antenna: Number of antenna elements.
        n_wide_beam: Number of wide beams in the probing codebook.
        n_narrow_beam: Number of output scores (narrow-beam classes).
        trainable_codebook: If True, the codebook is a trainable
            ``PhaseShifter``. Otherwise it is a fixed DFT codebook.
        theta: Optional argument forwarded to ``PhaseShifter``. It is ignored
            when ``trainable_codebook`` is False.
    """

    def __init__(self, n_antenna, n_wide_beam, n_narrow_beam, trainable_codebook=True, theta=None):
        super().__init__()
        self.trainable_codebook = trainable_codebook
        self.n_antenna = n_antenna
        self.n_wide_beam = n_wide_beam
        self.n_narrow_beam = n_narrow_beam
        if trainable_codebook:
            # NOTE(review): in_features=2*n_antenna presumably means real and
            # imaginary parts are stacked along the last axis. Confirm against
            # PhaseShifter.
            self.codebook = PhaseShifter(in_features=2*n_antenna, out_features=n_wide_beam,
                                         scale=np.sqrt(n_antenna), theta=theta)
        else:
            dft_codebook = DFT_codebook_blockmatrix(n_antenna=n_antenna, nseg=n_wide_beam)
            # Register as a buffer (not a plain attribute) so that .to(device),
            # .cuda() and .double() move or convert it with the rest of the
            # module. A plain tensor stays on the CPU and makes forward() fail
            # on GPU.
            # persistent=False keeps state_dict keys unchanged, so existing
            # checkpoints still load.
            # Buffers are not trained; the from_numpy tensor does not require grad.
            self.register_buffer("codebook", torch.from_numpy(dft_codebook).float(),
                                 persistent=False)
        self.compute_power = ComputePower(2*n_wide_beam)
        self.relu = nn.ReLU()
        self.dense1 = nn.Linear(in_features=n_wide_beam, out_features=2*n_wide_beam)
        self.dense2 = nn.Linear(in_features=2*n_wide_beam, out_features=3*n_wide_beam)
        self.dense3 = nn.Linear(in_features=3*n_wide_beam, out_features=n_narrow_beam)
        # Not used by forward(), which returns raw logits. It is kept so that
        # external code referencing `model.softmax` keeps working.
        self.softmax = nn.Softmax()

    def forward(self, x):
        """Return ``n_narrow_beam`` logits for input ``x``.

        In the fixed-codebook case, ``x`` is right-multiplied by the codebook
        matrix (``torch.matmul(x, codebook)``). Its last dimension must
        therefore match the codebook's first dimension.
        """
        if self.trainable_codebook:
            bf_signal = self.codebook(x)
        else:
            bf_signal = torch.matmul(x, self.codebook)
        bf_power = self.compute_power(bf_signal)
        out = self.relu(bf_power)
        out = self.relu(self.dense1(out))
        out = self.relu(self.dense2(out))
        # No final activation: the caller gets unnormalized scores.
        return self.dense3(out)

    def get_codebook(self) -> np.ndarray:
        """Return the current codebook as a NumPy array (a copy, on the CPU)."""
        if self.trainable_codebook:
            # .cpu() is required: Tensor.numpy() raises for CUDA tensors.
            # .clone() ensures the returned array never aliases the live weights.
            return self.codebook.get_weights().detach().cpu().clone().numpy()
        # NOTE(review): this uses DFT_codebook, while __init__ builds the forward
        # matrix with DFT_codebook_blockmatrix. It presumably returns the same
        # codebook in a different layout; confirm.
        return DFT_codebook(nseg=self.n_wide_beam, n_antenna=self.n_antenna).T
class AnalogBeamformer(nn.Module):
    """Analog beamformer with a trainable phase-shifter codebook.

    Args:
        n_antenna: Number of antenna elements (default 64).
        n_beam: Number of beams in the codebook (default 64).
        theta: Optional argument forwarded to ``PhaseShifter``.
    """

    def __init__(self, n_antenna=64, n_beam=64, theta=None):
        super().__init__()
        # NOTE(review): the 2*n sizes presumably reflect real and imaginary parts
        # stacked along the last axis. Confirm against PhaseShifter, PowerPooling
        # and ComputePower.
        self.codebook = PhaseShifter(in_features=2*n_antenna, out_features=n_beam,
                                     scale=np.sqrt(n_antenna), theta=theta)
        self.beam_selection = PowerPooling(2*n_beam)
        self.compute_power = ComputePower(2*n_beam)

    def forward(self, x, z=None) -> torch.Tensor:
        """Beamform ``x`` and return the power of the selected beam.

        Args:
            x: Input passed to the phase-shifter codebook.
            z: Optional override for the beamformed signal. When given, the
                forward value of the beamformed signal is replaced by ``z``.
                Gradients still flow back into the codebook (and into ``z``).
                In that case the beam power is reduced with a max over the last
                axis, keeping that axis as size 1. When ``None`` (the default),
                selection is delegated to ``PowerPooling``.

        Returns:
            Tensor of selected-beam power.
        """
        bf_signal = self.codebook(x)
        if z is None:
            return self.beam_selection(bf_signal)
        # Straight-through trick: the forward value becomes z. Because
        # bf_signal.detach() carries no gradient, d(out)/d(bf_signal) is the
        # identity, so the codebook is still trained.
        bf_signal = bf_signal + (z - bf_signal.detach())
        bf_power = self.compute_power(bf_signal)
        return torch.max(bf_power, dim=-1, keepdim=True).values

    def get_theta(self) -> torch.Tensor:
        """Return the codebook's ``get_theta()`` result."""
        return self.codebook.get_theta()

    def get_weights(self) -> torch.Tensor:
        """Return the codebook's ``get_weights()`` result."""
        return self.codebook.get_weights()