Code Example #1
import torch
from torch import Tensor as T  # assumed alias for torch.Tensor, matching T([0.5]) below

from torch_geometric.nn import ARGVA


def test_argva():
    model = ARGVA(encoder=lambda x: (x, x), discriminator=lambda x: T([0.5]))

    x = torch.Tensor([[1, -1], [1, 2], [2, 1]])
    model.encode(x)
    model.reparametrize(model.mu, model.logvar)
    assert model.kl_loss().item() > 0
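The assertion holds because the KL divergence of a non-degenerate Gaussian posterior from the standard-normal prior is strictly positive. For reference, a minimal sketch of the closed-form term such a kl_loss evaluates, assuming an N(mu, exp(logvar)) posterior (the library's exact reduction and log-std convention may differ):

# Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), averaged over nodes.
def kl_sketch(mu, logvar):
    return -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))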
Code Example #2
import torch
from torch_geometric.nn import ARGA, ARGVA, GAE, VGAE


def test_init():
    encoder = torch.nn.Linear(16, 32)
    decoder = torch.nn.Linear(32, 16)
    discriminator = torch.nn.Linear(32, 1)

    GAE(encoder, decoder)
    VGAE(encoder, decoder)
    ARGA(encoder, discriminator, decoder)
    ARGVA(encoder, discriminator, decoder)
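Because each constructor accepts plain torch.nn.Module components, a quick smoke test of the resulting model follows directly. A minimal sketch, assuming a dummy feature matrix shaped to match the Linear encoder above:

# GAE.encode/decode simply forward their arguments to the supplied modules.
model = GAE(encoder, decoder)
x = torch.randn(4, 16)   # 4 nodes, 16 input features
z = model.encode(x)      # latent codes of shape [4, 32]
x_hat = model.decode(z)  # reconstruction of shape [4, 16]
assert x_hat.shape == x.shape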
Code Example #3
File: argva.py Project: tchaton/lightning-geometric
    # Excerpt: this __init__ belongs to a LightningModule-style wrapper around
    # ARGVA (note the save_hyperparameters() call); the class header is not
    # part of the quoted snippet.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.save_hyperparameters()

        num_features = kwargs["num_features"]
        hidden_channels = kwargs["hidden_channels"]
        in_channels = kwargs["in_channels"]
        out_channels = kwargs["out_channels"]

        self._n_critic = kwargs["n_critic"]

        self.encoder = Encoder(num_features,
                               hidden_channels=hidden_channels,
                               out_channels=out_channels)

        self.discriminator = Discriminator(
            in_channels=in_channels,
            hidden_channels=2 * hidden_channels,
            out_channels=in_channels,
        )

        self.model = ARGVA(self.encoder, self.discriminator)
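The keyword arguments read in this constructor imply a call of the following shape. A hedged sketch (the wrapper's class name, here ARGVAModule, and all values are placeholders; only the kwarg names come from the snippet):

# Hypothetical instantiation; class name and values are illustrative only.
module = ARGVAModule(
    num_features=1433,   # input feature dimension
    hidden_channels=32,  # encoder hidden size
    in_channels=32,      # discriminator input size (the latent dimension)
    out_channels=32,     # latent dimension
    n_critic=5,          # presumably discriminator steps per encoder step
)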
Code Example #4
import torch
import torch.nn.functional as F

from torch_geometric.nn import ARGVA


class Discriminator(torch.nn.Module):
    # The source excerpt begins inside __init__; the imports, class header,
    # and constructor signature are restored here to match the call below.
    # The GCN-based Encoder and the `data` object used further down are
    # defined elsewhere in the original example.
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x


encoder = Encoder(data.num_features, hidden_channels=32, out_channels=32)
discriminator = Discriminator(in_channels=32,
                              hidden_channels=64,
                              out_channels=32)
model = ARGVA(encoder, discriminator)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = model.to(device), data.to(device)

discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                           lr=0.001)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)


def train():
    model.train()
    encoder_optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)

    # Optimize the discriminator several times per encoder update. The loop
    # body mirrors Code Example #5; the closing backward/step lines and the
    # encoder update below are a standard completion of the truncated
    # snippet, not verbatim source.
    for i in range(5):
        discriminator_optimizer.zero_grad()
        discriminator_loss = model.discriminator_loss(z)
        discriminator_loss.backward()
        discriminator_optimizer.step()

    # Encoder/generator update: reconstruction, adversarial regularization,
    # and KL terms, as in the usual ARGVA recipe.
    loss = model.recon_loss(z, data.train_pos_edge_index)
    loss = loss + model.reg_loss(z)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    encoder_optimizer.step()
    return loss
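After training, link-prediction quality can be checked with the test helper ARGVA inherits from GAE, which returns AUC and average precision. A hedged sketch, assuming the data object carries test splits in the train_test_split_edges naming convention:

# Hedged evaluation sketch; the split attribute names are assumptions.
@torch.no_grad()
def test():
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    return model.test(z, data.test_pos_edge_index, data.test_neg_edge_index)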
Code Example #5
import torch
from torch.nn import Linear

from torch_geometric.nn import ARGVA


class Discriminator(torch.nn.Module):
    # As in Code Example #4, the excerpt begins inside __init__; the header
    # and constructor signature are restored to match the call further down.
    # `train_data` and `device` are defined earlier in the original example.
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()
        return self.lin3(x)


encoder = Encoder(train_data.num_features, hidden_channels=32, out_channels=32)
discriminator = Discriminator(in_channels=32,
                              hidden_channels=64,
                              out_channels=32)
model = ARGVA(encoder, discriminator).to(device)

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                           lr=0.001)


def train():
    model.train()
    encoder_optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We optimize the discriminator more frequently than the encoder.
    for i in range(5):
        discriminator_optimizer.zero_grad()
        discriminator_loss = model.discriminator_loss(z)
        # Closing backward/step restored; the source snippet is truncated here.
        discriminator_loss.backward()
        discriminator_optimizer.step()
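For reference, the adversarial objective behind discriminator_loss trains the discriminator to score samples from a standard-normal prior as real and the (detached) encoder output as fake. A minimal sketch, assuming the usual GAN binary cross-entropy (not verbatim library code):

# Sketch of an ARGA-style discriminator loss under the stated assumptions.
EPS = 1e-15

def discriminator_loss_sketch(model, z):
    real = torch.sigmoid(model.discriminator(torch.randn_like(z)))
    fake = torch.sigmoid(model.discriminator(z.detach()))
    real_loss = -torch.log(real + EPS).mean()
    fake_loss = -torch.log(1 - fake + EPS).mean()
    return real_loss + fake_loss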
Code Example #6
    # Excerpt: constructor of a trainer class wrapping ARGA/ARGVA; the class
    # header and its imports (os, torch, the project's Encoder, Discriminator,
    # and ARGADataset) are not part of the quoted snippet.
    def __init__(self,
                 embedding_type,
                 dataset,
                 model_name,
                 graph_type="directed",
                 mode="train",
                 n_latent=16,
                 learning_rate=0.001,
                 weight_decay=0,
                 dropout=0,
                 dis_loss_para=1,
                 reg_loss_para=1,
                 epochs=200,
                 gpu=None):

        # Set device
        if torch.cuda.is_available() and gpu is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
            print("Using GPU device: {}.".format(str(gpu)))
            self.device = torch.device("cuda:" + str(gpu))
        else:
            self.device = "cpu"

        self.embedding_type = embedding_type
        self.dataset = dataset
        self.model_name = model_name
        self.graph_type = graph_type
        self.n_latent = n_latent
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.dropout = dropout
        self.dis_loss_para = dis_loss_para
        self.reg_loss_para = reg_loss_para
        self.epochs = epochs
        self.mode = mode

        # Load training data
        path_data_raw = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "..", "..", "..",
            "data", "interim", "scibert_arga", self.embedding_type)
        self.data = ARGADataset(path_data_raw, self.embedding_type, dataset,
                                self.graph_type)[0]
        n_total_features = self.data.num_features

        # Initialize encoder and discriminator
        encoder = Encoder(n_total_features, self.n_latent, self.model_name,
                          self.dropout)
        discriminator = Discriminator(self.n_latent)
        if self.device != "cpu":  # `is not` compares identity, not equality
            encoder.to(self.device)
            discriminator.to(self.device)

        # Choose and initialize model
        if self.model_name == "ARGA":
            self.model = ARGA(encoder=encoder,
                              discriminator=discriminator,
                              decoder=None)
        else:
            self.model = ARGVA(encoder=encoder,
                               discriminator=discriminator,
                               decoder=None)
        if self.device != "cpu":
            self.model.to(self.device)

        if self.mode == "train":
            print("Preprocessing data...")
            self.data = self.split_edges(self.data)
            print("Data preprocessed.\n")

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=self.weight_decay)

        # Set model file
        self.model_dir = self._model_dir()
        self.model_file = f'{self.model_name}_{self.n_latent}_{self.learning_rate}_{self.weight_decay}_{self.dropout}.pt'

        print('Model: ' + self.model_name)
        print("\tEmbedding: {}, Dataset: {}, Graph type: {}".format(
            self.embedding_type, self.dataset, self.graph_type))
        print("\tHidden units: {}".format(self.n_latent))
        print("\tLearning rate: {}".format(self.learning_rate))
        print("\tWeight decay: {}".format(self.weight_decay))
        print("\tDropout: {}\n".format(self.dropout))
        print("\tEpochs: {}\n".format(self.epochs))