class Beam_Classifier(nn.Module):
    """Score narrow beams from measurements taken with a wide-beam codebook.

    The input is first passed through a codebook (a trainable
    ``PhaseShifter`` or a fixed DFT block matrix), then through
    ``ComputePower``, a ReLU and a three-layer fully connected network
    whose last layer has ``n_narrow_beam`` outputs.

    Args:
        n_antenna: Number of antennas. The trainable codebook is built with
            ``2 * n_antenna`` input features (presumably real and imaginary
            parts stacked -- confirm against ``PhaseShifter``).
        n_wide_beam: Number of wide beams in the codebook; also the input
            width of the first dense layer.
        n_narrow_beam: Number of outputs of the last dense layer.
        trainable_codebook: If true, use a ``PhaseShifter`` module as the
            codebook; otherwise use the fixed matrix returned by
            ``DFT_codebook_blockmatrix``.
        theta: Forwarded to ``PhaseShifter``; unused with a fixed codebook.
    """

    def __init__(self, n_antenna, n_wide_beam, n_narrow_beam, trainable_codebook=True, theta=None):
        super().__init__()
        self.trainable_codebook = trainable_codebook
        self.n_antenna = n_antenna
        self.n_wide_beam = n_wide_beam
        self.n_narrow_beam = n_narrow_beam
        if trainable_codebook:
            self.codebook = PhaseShifter(
                in_features=2 * n_antenna,
                out_features=n_wide_beam,
                scale=np.sqrt(n_antenna),
                theta=theta,
            )
        else:
            dft_codebook = DFT_codebook_blockmatrix(n_antenna=n_antenna, nseg=n_wide_beam)
            # Register the fixed codebook as a buffer so that .to(device) /
            # .cuda() move it together with the rest of the module.  It is
            # non-persistent so the state_dict keys stay exactly as before.
            self.register_buffer(
                "codebook",
                torch.from_numpy(dft_codebook).float(),
                persistent=False,
            )
        self.compute_power = ComputePower(2 * n_wide_beam)
        self.relu = nn.ReLU()
        self.dense1 = nn.Linear(in_features=n_wide_beam, out_features=2 * n_wide_beam)
        self.dense2 = nn.Linear(in_features=2 * n_wide_beam, out_features=3 * n_wide_beam)
        self.dense3 = nn.Linear(in_features=3 * n_wide_beam, out_features=n_narrow_beam)
        # NOTE: not applied in forward(), which returns the raw dense3 output.
        self.softmax = nn.Softmax()

    def forward(self, x) -> torch.Tensor:
        """Return the output of the last dense layer for input ``x``.

        No softmax is applied here; the result is the raw ``dense3`` output.
        """
        if self.trainable_codebook:
            bf_signal = self.codebook(x)
        else:
            bf_signal = torch.matmul(x, self.codebook)
        bf_power = self.compute_power(bf_signal)
        out = self.relu(bf_power)
        out = self.relu(self.dense1(out))
        out = self.relu(self.dense2(out))
        return self.dense3(out)

    def get_codebook(self) -> np.ndarray:
        """Return the wide-beam codebook as a NumPy array.

        For a trainable codebook this is a detached copy of the weights
        reported by the ``PhaseShifter``; otherwise it is the transpose of
        ``DFT_codebook(nseg=n_wide_beam, n_antenna=n_antenna)``.
        """
        if self.trainable_codebook:
            weights = self.codebook.get_weights()
            # .cpu() makes this work when the model lives on an accelerator;
            # .clone() keeps the returned array from aliasing the parameters.
            return weights.detach().cpu().clone().numpy()
        return DFT_codebook(nseg=self.n_wide_beam, n_antenna=self.n_antenna).T
class AnalogBeamformer(nn.Module):
    """Apply a ``PhaseShifter`` codebook and report the selected beam power.

    Args:
        n_antenna: Number of antennas; the codebook is built with
            ``2 * n_antenna`` input features (presumably real and imaginary
            parts stacked -- confirm against ``PhaseShifter``).
        n_beam: Number of beams in the codebook.
        theta: Forwarded to ``PhaseShifter``.
    """

    def __init__(self, n_antenna=64, n_beam=64, theta=None):
        super().__init__()
        self.codebook = PhaseShifter(
            in_features=2 * n_antenna,
            out_features=n_beam,
            scale=np.sqrt(n_antenna),
            theta=theta,
        )
        self.beam_selection = PowerPooling(2 * n_beam)
        self.compute_power = ComputePower(2 * n_beam)

    def forward(self, x, z=None) -> torch.Tensor:
        """Return the selected beam power for input ``x``.

        Args:
            x: Input passed to the codebook.
            z: Optional replacement for the beamformed signal.  When it is
                ``None`` the codebook output is handed to ``beam_selection``.
                When given, the forward value of the beamformed signal is
                overridden by ``z`` while the codebook output stays in the
                autograd graph, and the maximum of ``compute_power`` over
                the last dimension is returned with that dimension kept at
                size one.
        """
        bf_signal = self.codebook(x)
        if z is None:
            return self.beam_selection(bf_signal)
        # Adding (z - detached signal) makes the forward value equal to z
        # while gradients still flow back through bf_signal to the codebook.
        bf_signal = bf_signal + (z - bf_signal.detach())
        bf_power = self.compute_power(bf_signal)
        return torch.max(bf_power, dim=-1, keepdim=True).values

    def get_theta(self) -> torch.Tensor:
        """Return the value of ``get_theta()`` of the underlying codebook."""
        return self.codebook.get_theta()

    def get_weights(self) -> torch.Tensor:
        """Return the value of ``get_weights()`` of the underlying codebook."""
        return self.codebook.get_weights()