Ejemplo n.º 1
0
 def download(self):
     """Download the resource at ``self.url`` into ``self.raw_dir``.

     Delegates to ``download_url`` and requests the file name
     ``processed.zip``.
     """
     source_url = self.url
     target_dir = self.raw_dir
     download_url(source_url, target_dir, name="processed.zip")
Ejemplo n.º 2
0
    def train(self, data):
        """Embed ``data`` using a pretrained GCC checkpoint.

        If no file exists at ``self.load_path``, the pretrained checkpoint is
        first downloaded from the ``cenyk1230/gcc-data`` GitHub repository to
        that path. The checkpoint is then loaded on CPU, a ``GraphEncoder`` is
        rebuilt from the options stored under ``checkpoint["opt"]``, its
        weights are restored from ``checkpoint["model"]``, and ``test_moco`` is
        run over the whole dataset as a single batch.

        Args:
            data: A ``list`` is wrapped in ``GraphClassificationDataset``;
                anything else is wrapped in ``NodeClassificationDataset``.

        Returns:
            The result of ``test_moco(...)`` converted with ``.numpy()``.
        """
        if not os.path.isfile(self.load_path):
            print("=> no checkpoint found at '{}'".format(self.load_path))
            url = "https://github.com/cenyk1230/gcc-data/raw/master/saved/gcc_pretrained.pth"
            # os.path.split honours the platform's separators (e.g. "\\" on
            # Windows) and keeps "/" as the directory for "/file.pth", unlike
            # the previous manual split on "/".
            path, name = os.path.split(self.load_path)
            download_url(url, path, name=name)

        print("=> loading checkpoint '{}'".format(self.load_path))
        # NOTE(review): torch.load unpickles the file, and the file may have
        # just been downloaded. Only load checkpoints from trusted sources.
        # "opt" holds an attribute-style options object, not plain tensors.
        # TODO confirm: newer torch releases default to weights_only=True,
        # which would likely reject it.
        checkpoint = torch.load(self.load_path, map_location="cpu")
        print("=> loaded successfully '{}' (epoch {})".format(
            self.load_path, checkpoint["epoch"]))
        # Hyper-parameters saved at pretraining time drive both the dataset
        # and the model construction below.
        args = checkpoint["opt"]

        args.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Graph-level inputs arrive as a list; everything else is treated as
        # a single graph whose nodes are sampled.
        dataset_cls = (
            GraphClassificationDataset
            if isinstance(data, list)
            else NodeClassificationDataset
        )
        train_dataset = dataset_cls(
            data=data,
            rw_hops=args.rw_hops,
            subgraph_size=args.subgraph_size,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
        )

        # Use the entire dataset as one batch. shuffle=False keeps the batch
        # in dataset order.
        args.batch_size = len(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            collate_fn=batcher(),
            shuffle=False,
            num_workers=args.num_workers,
        )

        # create model and optimizer
        model = GraphEncoder(
            positional_embedding_size=args.positional_embedding_size,
            max_node_freq=args.max_node_freq,
            max_edge_freq=args.max_edge_freq,
            max_degree=args.max_degree,
            freq_embedding_size=args.freq_embedding_size,
            degree_embedding_size=args.degree_embedding_size,
            output_dim=args.hidden_size,
            node_hidden_dim=args.hidden_size,
            edge_hidden_dim=args.hidden_size,
            num_layers=args.num_layer,
            num_step_set2set=args.set2set_iter,
            num_layer_set2set=args.set2set_lstm_layer,
            gnn_model=args.model,
            norm=args.norm,
            degree_input=True,
        )

        model = model.to(args.device)

        model.load_state_dict(checkpoint["model"])

        # The weights now live in the model; drop the checkpoint dict so it is
        # not kept alive during inference.
        del checkpoint

        # test_moco presumably returns a CPU torch tensor (it must support
        # .numpy()) - TODO confirm against its definition.
        emb = test_moco(train_loader, model, args)

        return emb.numpy()