class Encoder(nn.Module):
    """PaccMann VAE encoder followed by a single GCN layer and a PReLU.

    ``forward`` first encodes the raw node inputs with the wrapped PaccMann
    encoder (``EncoderPaccmann.encode``), then propagates the resulting node
    features over the graph with ``GCNConv`` and applies a channel-wise PReLU.

    Args:
        PARAM_FILE: Forwarded unchanged to ``EncoderPaccmann`` (presumably
            the model-parameter file; confirm against ``EncoderPaccmann``).
        LANG_FILE: Forwarded unchanged to ``EncoderPaccmann`` (presumably the
            SMILES language / vocabulary file).
        WEIGHT_FILE: Forwarded unchanged to ``EncoderPaccmann`` (presumably
            the pretrained weights).
        device: Device that all sub-components are moved to at construction.
        in_channels: Input feature size of the GCN layer. NOTE(review): this
            must match the feature size produced by ``paccmann_vae.encode`` —
            confirm for the chosen PARAM_FILE.
        hidden_channels: Output feature size of the GCN layer; also the number
            of learnable PReLU slopes.
        batch_mode: Batch mode passed to ``EncoderPaccmann``. Its lower-cased
            form is also set on the underlying GRU encoder. The default
            ``'Packed'`` reproduces the previous hard-coded behavior.
    """

    def __init__(
        self,
        PARAM_FILE,
        LANG_FILE,
        WEIGHT_FILE,
        device,
        in_channels=256,
        hidden_channels=512,
        batch_mode='Packed',
    ):
        super().__init__()
        # Previously `batch_mode` was accepted but ignored ('Packed' was
        # hard-coded). It is now honored; the default yields identical calls.
        self.paccmann_vae = EncoderPaccmann(
            PARAM_FILE,
            LANG_FILE,
            WEIGHT_FILE,
            batch_size=1,
            batch_mode=batch_mode,
        )
        # The original passed the lower-case spelling ('packed') here, so keep
        # that convention for whatever mode the caller chose.
        self.paccmann_vae.gru_vae.encoder.set_batch_mode(batch_mode.lower())
        # cached=False: do not cache the normalized graph structure between
        # calls — presumably because each batch may carry a different graph.
        self.conv = GCNConv(in_channels, hidden_channels, cached=False)
        # One learnable negative slope per output channel of the GCN layer.
        self.prelu = nn.PReLU(hidden_channels)
        self.to(device)

    def forward(self, x, edge_index):
        """Encode node inputs and run one round of graph convolution.

        Args:
            x: Raw node inputs accepted by ``EncoderPaccmann.encode``
                (format not visible here — TODO confirm).
            edge_index: Graph connectivity passed to ``GCNConv``.

        Returns:
            Node embeddings after GCN + PReLU. Presumably shaped
            ``(num_nodes, hidden_channels)`` — confirm.
        """
        x = self.paccmann_vae.encode(x)
        x = self.conv(x, edge_index)
        x = self.prelu(x)
        return x

    def update_batch_size(self, batch_size: int) -> None:
        """Adjust PaccMann's internal batch size before a forward pass.

        The wrapped GRU encoder/decoder keep a fixed batch size (constructed
        with ``batch_size=1``), so it must be updated to match the next input.

        Args:
            batch_size: Number of samples in the upcoming forward call.
        """
        # NOTE(review): __init__ reaches the encoder via `gru_vae.encoder`,
        # while this uses `gru_encoder` / `gru_decoder`; confirm that
        # EncoderPaccmann exposes both attribute paths.
        self.paccmann_vae.gru_decoder._update_batch_size(batch_size)
        self.paccmann_vae.gru_encoder._update_batch_size(batch_size)

    def to(self, device, *args, **kwargs):
        """Move all sub-components to ``device`` and return ``self``.

        Overridden so that the PaccMann VAE is moved explicitly alongside the
        GCN layer and PReLU. NOTE(review): this suggests EncoderPaccmann is not
        itself a registered ``nn.Module``; confirm.

        Args:
            device: Target device (or any first argument ``nn.Module.to``
                accepts).
            *args, **kwargs: Extra arguments forwarded to every component's
                ``.to`` (e.g. ``non_blocking=True``).

        Returns:
            This module, matching ``nn.Module.to`` so that
            ``model = model.to(device)`` keeps working.
        """
        self.paccmann_vae.gru_vae.to(device, *args, **kwargs)
        self.conv.to(device, *args, **kwargs)
        self.prelu.to(device, *args, **kwargs)
        return self