示例#1
0
def load_victim_model(data, model_name='gcn', device='cpu', file_path=None):
    """Load a saved victim model, or train one if no checkpoint exists.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        Graph data. ``data.name`` builds the default checkpoint path, and
        ``data.features`` / ``data.labels`` size the GCN input and output.
    model_name : str
        Victim model name. Currently only ``'gcn'`` is supported.
    device : str
        'cpu' or 'cuda'.
    file_path : str or None
        Directory containing the checkpoint. The file actually read is
        ``<file_path>/<model_name>_checkpoint``. When None, the checkpoint
        path is ``results/saved_models/<data.name>/<model_name>_checkpoint``.

    Returns
    -------
    victim_model
        If the checkpoint exists: a GCN with the saved weights, moved to
        ``device`` and switched to eval mode. Otherwise: whatever
        ``train_victim_model`` returns when given the checkpoint's directory.

    Raises
    ------
    AssertionError
        If ``model_name`` is not ``'gcn'``.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept so existing callers behave the same.
    if model_name != 'gcn':
        raise AssertionError('Currently only support gcn as victim model...')

    if file_path is None:
        file_path = 'results/saved_models/{0}/{1}_checkpoint'.format(
            data.name, model_name)
    else:
        # A caller-supplied file_path is a directory; append the file name.
        file_path = osp.join(file_path, '{}_checkpoint'.format(model_name))

    if osp.exists(file_path):
        # Architecture is hard-coded here; presumably it must match the one
        # train_victim_model used when saving the checkpoint -- confirm.
        # nclass assumes labels are 0-based class ids (max id + 1).
        victim_model = GCN(nfeat=data.features.shape[1],
                           nclass=data.labels.max().item() + 1,
                           nhid=16,
                           dropout=0.5,
                           weight_decay=5e-4,
                           device=device)

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted locations.
        victim_model.load_state_dict(torch.load(file_path,
                                                map_location=device))
        victim_model.to(device)
        victim_model.eval()
        return victim_model

    # No checkpoint yet: train a model, handing over the checkpoint's
    # directory (the same directory probed above).
    return train_victim_model(data=data,
                              model_name=model_name,
                              device=device,
                              file_path=osp.dirname(file_path))
示例#2
0
def load_victim_model(data, model_name='gcn', device='cpu', file_path=None):
    """Load a saved victim model, or train one if no checkpoint exists.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        Graph data. ``data.name`` builds the default checkpoint path, and
        ``data.features`` / ``data.labels`` size the GCN input and output.
    model_name : str
        Victim model name, e.g. ('gcn', 'deepwalk'). Currently only ``'gcn'``
        is supported.
    device : str
        'cpu' or 'cuda'.
    file_path : str or None
        Directory containing the checkpoint. The file actually read is
        ``<file_path>/<model_name>_checkpoint``. When None, the checkpoint
        path is ``results/saved_models/<data.name>/<model_name>_checkpoint``.

    Returns
    -------
    victim_model
        If the checkpoint exists: a GCN with the saved weights, moved to
        ``device`` and switched to eval mode. Otherwise: whatever
        ``train_victim_model`` returns when given the checkpoint's directory.

    Raises
    ------
    AssertionError
        If ``model_name`` is not ``'gcn'``.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept so existing callers behave the same.
    if model_name != 'gcn':
        raise AssertionError('Currently only support gcn as victim model...')

    if file_path is None:
        file_path = 'results/saved_models/{0}/{1}_checkpoint'.format(data.name, model_name)
    else:
        # A caller-supplied file_path is a directory; append the file name.
        file_path = osp.join(file_path, '{}_checkpoint'.format(model_name))

    if osp.exists(file_path):
        # Architecture is hard-coded here; presumably it must match the one
        # train_victim_model used when saving the checkpoint -- confirm.
        # nclass assumes labels are 0-based class ids (max id + 1).
        victim_model = GCN(nfeat=data.features.shape[1], nclass=data.labels.max().item()+1,
                    nhid=16, dropout=0.5, weight_decay=5e-4, device=device)

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted locations.
        victim_model.load_state_dict(torch.load(file_path, map_location=device))
        victim_model.to(device)
        victim_model.eval()
        return victim_model

    # No checkpoint yet: train a model, handing over the checkpoint's
    # directory (the same directory probed above).
    return train_victim_model(data=data, model_name=model_name,
                              device=device, file_path=osp.dirname(file_path))