def test_argva():
    """ARGVA should produce a strictly positive KL loss after encoding a toy batch."""
    # The stub encoder hands back its input twice (presumably consumed as the
    # mean and log-variance); the stub discriminator always scores 0.5.
    argva = ARGVA(
        encoder=lambda inp: (inp, inp),
        discriminator=lambda inp: T([0.5]),
    )
    points = torch.Tensor([[1, -1], [1, 2], [2, 1]])

    argva.encode(points)
    argva.reparametrize(argva.mu, argva.logvar)

    kl_value = argva.kl_loss().item()
    assert kl_value > 0
def test_init():
    """Smoke-test that every autoencoder wrapper accepts plain Linear building blocks."""
    # Creation order (encoder, decoder, discriminator) is kept so that weight
    # initialisation draws from the RNG in the same sequence.
    enc = torch.nn.Linear(16, 32)
    dec = torch.nn.Linear(32, 16)
    disc = torch.nn.Linear(32, 1)

    # Non-adversarial variants take (encoder, decoder).
    for build in (GAE, VGAE):
        build(enc, dec)

    # Adversarial variants additionally take the discriminator before the decoder.
    for build in (ARGA, ARGVA):
        build(enc, disc, dec)
def __init__(self, *args, **kwargs):
    """Build an ARGVA model from keyword hyperparameters.

    Required keyword arguments (a missing one raises ``KeyError``):
    ``num_features``, ``hidden_channels``, ``in_channels``,
    ``out_channels`` and ``n_critic``. All positional and keyword
    arguments are also forwarded to the parent ``__init__``.
    """
    super().__init__(*args, **kwargs)
    # Must be called directly from __init__: it records this frame's arguments.
    self.save_hyperparameters()

    # Tuple elements are evaluated left to right, so the lookup order (and thus
    # which KeyError surfaces first) matches a sequence of separate lookups.
    n_feat, hidden, disc_in, latent, self._n_critic = (
        kwargs["num_features"],
        kwargs["hidden_channels"],
        kwargs["in_channels"],
        kwargs["out_channels"],
        kwargs["n_critic"],
    )

    self.encoder = Encoder(n_feat, hidden_channels=hidden, out_channels=latent)
    # NOTE(review): the discriminator consumes the encoder output, whose width is
    # `out_channels`; presumably callers pass in_channels == out_channels — confirm.
    self.discriminator = Discriminator(
        in_channels=disc_in,
        hidden_channels=2 * hidden,
        out_channels=disc_in,
    )
    self.model = ARGVA(self.encoder, self.discriminator)
self.lin1 = torch.nn.Linear(in_channels, hidden_channels) self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels) self.lin3 = torch.nn.Linear(hidden_channels, out_channels) def forward(self, x): x = F.relu(self.lin1(x)) x = F.relu(self.lin2(x)) x = self.lin3(x) return x encoder = Encoder(data.num_features, hidden_channels=32, out_channels=32) discriminator = Discriminator(in_channels=32, hidden_channels=64, out_channels=32) model = ARGVA(encoder, discriminator) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, data = model.to(device), data.to(device) discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001) encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005) def train(): model.train() encoder_optimizer.zero_grad() z = model.encode(data.x, data.train_pos_edge_index) for i in range(5):
super().__init__() self.lin1 = Linear(in_channels, hidden_channels) self.lin2 = Linear(hidden_channels, hidden_channels) self.lin3 = Linear(hidden_channels, out_channels) def forward(self, x): x = self.lin1(x).relu() x = self.lin2(x).relu() return self.lin3(x) encoder = Encoder(train_data.num_features, hidden_channels=32, out_channels=32) discriminator = Discriminator(in_channels=32, hidden_channels=64, out_channels=32) model = ARGVA(encoder, discriminator).to(device) encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005) discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001) def train(): model.train() encoder_optimizer.zero_grad() z = model.encode(train_data.x, train_data.edge_index) # We optimize the discriminator more frequently than the encoder. for i in range(5): discriminator_optimizer.zero_grad() discriminator_loss = model.discriminator_loss(z)
def __init__(self, embedding_type, dataset, model_name, graph_type="directed", mode="train", n_latent=16,
             learning_rate=0.001, weight_decay=0, dropout=0, dis_loss_para=1, reg_loss_para=1, epochs=200,
             gpu=None):
    """Configure the device, load the data, and build the ARGA/ARGVA model.

    Args:
        embedding_type: Sub-directory of ``data/interim/scibert_arga`` holding
            the raw data; also forwarded to ``ARGADataset``.
        dataset: Dataset identifier forwarded to ``ARGADataset``.
        model_name: ``"ARGA"`` builds an ``ARGA`` model; any other value builds
            an ``ARGVA`` model. Also forwarded to ``Encoder``.
        graph_type: Forwarded to ``ARGADataset`` (default ``"directed"``).
        mode: When ``"train"``, edges are split via ``self.split_edges`` and an
            Adam optimizer is created as ``self.optimizer``.
        n_latent: Latent size passed to ``Encoder`` and ``Discriminator``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: Dropout value passed to ``Encoder``.
        dis_loss_para: Stored on the instance; not used in this method
            (presumably a discriminator-loss weight).
        reg_loss_para: Stored on the instance; not used in this method
            (presumably a regularization-loss weight).
        epochs: Stored on the instance and printed; not used in this method.
        gpu: GPU index to run on when CUDA is available; ``None`` selects CPU.
    """
    # Set device. Decide once with a real boolean instead of comparing
    # self.device to a string by identity (`is not "cpu"`), which is unreliable
    # and emits a SyntaxWarning on Python >= 3.8. On the CPU path self.device
    # stays the plain string "cpu", as before.
    use_gpu = torch.cuda.is_available() and gpu is not None
    if use_gpu:
        # NOTE(review): torch.cuda.is_available() has already queried CUDA here,
        # so this variable may not affect the current process. If it does take
        # effect, the chosen GPU is renumbered to "cuda:0" and "cuda:<gpu>" below
        # would be invalid for gpu > 0 — confirm on a multi-GPU host.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
        print(f"Using GPU device: {gpu}.")
        self.device = torch.device(f"cuda:{gpu}")
    else:
        self.device = "cpu"

    self.embedding_type = embedding_type
    self.dataset = dataset
    self.model_name = model_name
    self.graph_type = graph_type
    self.n_latent = n_latent
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.dropout = dropout
    self.dis_loss_para = dis_loss_para
    self.reg_loss_para = reg_loss_para
    self.epochs = epochs
    self.mode = mode

    # Load training data from <this file>/../../../data/interim/scibert_arga/<embedding_type>.
    path_data_raw = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "..", "..", "..", "data", "interim", "scibert_arga", self.embedding_type)
    self.data = ARGADataset(path_data_raw, self.embedding_type, dataset, self.graph_type)[0]
    n_total_features = self.data.num_features

    # Initialize encoder and discriminator.
    encoder = Encoder(n_total_features, self.n_latent, self.model_name, self.dropout)
    discriminator = Discriminator(self.n_latent)
    if use_gpu:
        encoder.to(self.device)
        discriminator.to(self.device)

    # Choose and initialize model.
    model_cls = ARGA if self.model_name == "ARGA" else ARGVA
    self.model = model_cls(encoder=encoder, discriminator=discriminator, decoder=None)
    if use_gpu:
        self.model.to(self.device)

    if self.mode == "train":
        print("Preprocessing data...")
        self.data = self.split_edges(self.data)
        print("Data preprocessed.\n")

    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate,
                                      weight_decay=self.weight_decay)

    # Set model file; the name encodes the main hyperparameters.
    self.model_dir = self._model_dir()
    self.model_file = f'{self.model_name}_{self.n_latent}_{self.learning_rate}_{self.weight_decay}_{self.dropout}.pt'

    print(f"Model: {self.model_name}")
    print(f"\tEmbedding: {self.embedding_type}, Dataset: {self.dataset}, Graph type: {self.graph_type}")
    print(f"\tHidden units: {self.n_latent}")
    print(f"\tLearning rate: {self.learning_rate}")
    print(f"\tWeight decay: {self.weight_decay}")
    print(f"\tDropout: {self.dropout}")
    # A single trailing blank line closes the configuration summary.
    print(f"\tEpochs: {self.epochs}\n")