from math import comb

import torch
import torch.nn as nn
from torch.nn import Parameter


class LSTMConsensus(nn.Module):
    """LSTM consensus module.

    Args:
        input_size (int): The number of expected features in the input x.
        hidden_size (int): The number of features in the hidden state.
        num_layers (int): Number of recurrent layers. Default: 1.
        batch_first (bool): If True, input and output tensors are provided
            as (batch, seq, feature). Default: True.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
        super().__init__()
        # Learnable initial hidden and cell states, broadcast over the batch.
        self.h0 = Parameter(torch.zeros((num_layers, 1, hidden_size)))
        self.c0 = Parameter(torch.zeros((num_layers, 1, hidden_size)))
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=batch_first)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        """Defines the computation performed at every call."""
        h0 = self.h0.repeat_interleave(x.size(0), dim=1)
        c0 = self.c0.repeat_interleave(x.size(0), dim=1)
        # Flatten everything after the segment dimension before the LSTM.
        ht = self.lstm(x.view(x.size(0), x.size(1), -1), (h0, c0))[0]
        # Keep only the last time step as the consensus over the segments.
        return ht[:, -1, :].view(x.size(0), 1, self.hidden_size, *(x.size()[3:]))
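
# A minimal, self-contained usage sketch (not part of the original module).
# The 5-D input shape (N, num_segments, C, 1, 1) is an assumption chosen so
# that the final view back to (N, 1, hidden_size, 1, 1) is valid; the helper
# name `_demo_lstm_consensus` is hypothetical.
def _demo_lstm_consensus():
    consensus = LSTMConsensus(input_size=256, hidden_size=128, num_layers=1)
    x = torch.randn(4, 8, 256, 1, 1)  # (N, num_segments, C, 1, 1)
    out = consensus(x)
    assert out.shape == (4, 1, 128, 1, 1)
    return out
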
class CompositionalRecognizer(nn.Module):
    """Compositional recognizer over a sequence of per-frame class scores.

    Args:
        class_count (int): Number of base classes. Default: 10.
        hidden_size (int): Number of features in the LSTM hidden state.
            Default: 512.
        dropout (float): Dropout applied between LSTM layers. Default: 0.
    """

    def __init__(self, class_count=10, hidden_size=512, dropout=0):
        super().__init__()
        self.class_count = class_count
        self.hidden_size = hidden_size
        # Ordered pairs + unordered pairs + single classes.
        self.comp_count = class_count**2 + int(comb(class_count, 2)) + class_count
        # Learnable initial hidden and cell states, broadcast over the batch.
        self.h0 = Parameter(torch.zeros((1, 1, hidden_size)))
        self.c0 = Parameter(torch.zeros((1, 1, hidden_size)))
        self.lstm = nn.LSTM(
            class_count, hidden_size, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, self.comp_count)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, x):
        h0 = self.h0.repeat_interleave(x.size(0), dim=1)
        c0 = self.c0.repeat_interleave(x.size(0), dim=1)
        # Use the hidden state at the last time step for classification.
        ht = self.lstm(x.view(x.size(0), x.size(1), -1), (h0, c0))[0][:, -1, :]
        return self.fc(ht)

    def forward_loss(self, x, labels):
        result = self(x)
        return self.loss(result, labels)
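
# A minimal, self-contained usage sketch (not part of the original module).
# The input is assumed to be per-frame class scores of shape
# (N, T, class_count); the multi-hot composition labels are random
# placeholders, and the helper name `_demo_compositional_recognizer` is
# hypothetical.
def _demo_compositional_recognizer():
    model = CompositionalRecognizer(class_count=10, hidden_size=512)
    scores = torch.randn(2, 16, 10)  # (N, T, class_count)
    labels = torch.randint(0, 2, (2, model.comp_count)).float()
    loss = model.forward_loss(scores, labels)  # BCE-with-logits loss
    loss.backward()
    return loss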