Example #1
0
def set_up_predictor(method, n_unit, conv_layers, class_num, scaler):
    """Builds the predictor for the requested graph convolution method.

    A graph convolution network is paired with a multilayer perceptron head.
    ``'schnet'`` is the exception: it gets no perceptron and the network is
    created with ``class_num`` outputs instead.

    Args:
        method (str): Name of the graph convolution network. One of
            ``'nfp'``, ``'ggnn'``, ``'schnet'``, ``'weavenet'``,
            ``'rsgcn'``, ``'relgcn'`` or ``'relgat'``.
        n_unit (int): Number of hidden units.
        conv_layers (int): Number of convolutional layers for the graph
            convolution network.
        class_num (int): Number of output classes.
        scaler: Object passed through unchanged as the third argument of
            ``GraphConvPredictor``.

    Returns:
        The ``GraphConvPredictor`` instance built for ``method``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # The perceptron is always instantiated first, even for SchNet where it
    # ends up unused.
    head = MLP(out_dim=class_num, hidden_dim=n_unit)

    if method == 'nfp':
        print('Training an NFP predictor...')
        graph_conv = NFP(out_dim=n_unit, hidden_dim=n_unit,
                         n_layers=conv_layers)
    elif method == 'ggnn':
        print('Training a GGNN predictor...')
        graph_conv = GGNN(out_dim=n_unit, hidden_dim=n_unit,
                          n_layers=conv_layers)
    elif method == 'schnet':
        print('Training an SchNet predictor...')
        graph_conv = SchNet(out_dim=class_num, hidden_dim=n_unit,
                            n_layers=conv_layers)
        # SchNet already outputs `class_num` values, so no perceptron.
        head = None
    elif method == 'weavenet':
        print('Training a WeaveNet predictor...')
        graph_conv = WeaveNet(weave_channels=[50] * conv_layers,
                              hidden_dim=n_unit,
                              n_sub_layer=1,
                              n_atom=20)
    elif method == 'rsgcn':
        print('Training an RSGCN predictor...')
        graph_conv = RSGCN(out_dim=n_unit, hidden_dim=n_unit,
                           n_layers=conv_layers)
    elif method == 'relgcn':
        print('Use Relational GCN predictor...')
        graph_conv = RelGCN(out_channels=n_unit,
                            num_edge_type=4,
                            scale_adj=True)
    elif method == 'relgat':
        print('Train Relational GAT predictor...')
        graph_conv = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
                            n_layers=conv_layers)
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))

    return GraphConvPredictor(graph_conv, head, scaler)
Example #2
0
def set_up_predictor(method, n_unit, conv_layers, class_num):
    """Sets up the graph convolution network predictor.

    Args:
        method: Method name. The supported ones are `nfp`, `ggnn`,
                `schnet`, `weavenet`, `rsgcn` and `relgcn`.
        n_unit: Number of hidden units.
        conv_layers: Number of convolutional layers for the graph convolution
                     network.
        class_num: Number of output classes.

    Returns:
        An instance of the selected predictor. For `schnet` and `relgcn`
        the network is created with `class_num` outputs and no multilayer
        perceptron is attached; every other method feeds an `n_unit`-wide
        network into a multilayer perceptron with `class_num` outputs.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    # Instantiated before the method is inspected, so it exists (unused) even
    # for the methods that end up without a perceptron.
    head = MLP(out_dim=class_num, hidden_dim=n_unit)

    if method == 'nfp':
        print('Training an NFP predictor...')
        return GraphConvPredictor(
            NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            head)

    if method == 'ggnn':
        print('Training a GGNN predictor...')
        return GraphConvPredictor(
            GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            head)

    if method == 'schnet':
        print('Training an SchNet predictor...')
        return GraphConvPredictor(
            SchNet(out_dim=class_num, hidden_dim=n_unit,
                   n_layers=conv_layers),
            None)

    if method == 'weavenet':
        print('Training a WeaveNet predictor...')
        return GraphConvPredictor(
            WeaveNet(weave_channels=[50] * conv_layers,
                     hidden_dim=n_unit,
                     n_sub_layer=1,
                     n_atom=20),
            head)

    if method == 'rsgcn':
        print('Training an RSGCN predictor...')
        return GraphConvPredictor(
            RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            head)

    if method == 'relgcn':
        print('Training an RelGCN predictor...')
        return GraphConvPredictor(
            RelGCN(out_channels=class_num, num_edge_type=4, scale_adj=True),
            None)

    raise ValueError('[ERROR] Invalid method: {}'.format(method))
Example #3
0
def build_predictor(method, n_unit, conv_layers, class_num):
    """Build the ``GraphConvPredictor`` selected by ``method``.

    Args:
        method: Name of the graph convolution network. One of ``'nfp'``,
            ``'ggnn'``, ``'schnet'``, ``'weavenet'``, ``'rsgcn'``,
            ``'relgcn'`` or ``'relgat'``.
        n_unit: Number of hidden units.
        conv_layers: Number of convolutional layers for the graph
            convolution network.
        class_num: Number of output classes.

    Returns:
        A ``GraphConvPredictor`` wrapping the chosen network and, except for
        ``'schnet'``, a multilayer perceptron with ``class_num`` outputs.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    with_mlp = True

    if method == 'nfp':
        print('Use NFP predictor...')
        graph_conv = NFP(out_dim=n_unit, hidden_dim=n_unit,
                         n_layers=conv_layers)
    elif method == 'ggnn':
        print('Use GGNN predictor...')
        graph_conv = GGNN(out_dim=n_unit, hidden_dim=n_unit,
                          n_layers=conv_layers)
    elif method == 'schnet':
        print('Use SchNet predictor...')
        graph_conv = SchNet(out_dim=class_num,
                            hidden_dim=n_unit,
                            n_layers=conv_layers,
                            readout_hidden_dim=n_unit)
        # MLP layer is not necessary for SchNet
        with_mlp = False
    elif method == 'weavenet':
        print('Use WeaveNet predictor...')
        graph_conv = WeaveNet(weave_channels=[50] * conv_layers,
                              hidden_dim=n_unit,
                              n_sub_layer=1,
                              n_atom=20)
    elif method == 'rsgcn':
        print('Use RSGCN predictor...')
        graph_conv = RSGCN(out_dim=n_unit, hidden_dim=n_unit,
                           n_layers=conv_layers)
    elif method == 'relgcn':
        print('Use Relational GCN predictor...')
        graph_conv = RelGCN(out_channels=n_unit,
                            num_edge_type=4,
                            scale_adj=True)
    elif method == 'relgat':
        print('Use GAT predictor...')
        graph_conv = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
                            n_layers=conv_layers)
    else:
        raise ValueError('[ERROR] Invalid predictor: method={}'.format(method))

    # The perceptron is created after the graph convolution network, right
    # before both are handed to the predictor.
    if with_mlp:
        head = MLP(out_dim=class_num, hidden_dim=n_unit)
    else:
        head = None
    return GraphConvPredictor(graph_conv, head)
Example #4
0
def set_up_predictor(method, n_unit, conv_layers, class_num):
    """Sets up the predictor for the requested graph convolution method.

    Plain methods are wrapped in `GraphConvPredictor`; the `*_gwm` variants
    are wrapped in `GraphConvPredictorForGWM`.
    Args:
        method: Method name. One of `nfp`, `nfp_gwm`, `ggnn`, `ggnn_gwm`,
                `schnet`, `weavenet`, `rsgcn`, `rsgcn_gwm`, `relgcn`,
                `relgat`, `gin` or `gin_gwm`.
        n_unit: Number of hidden units.
        conv_layers: Number of convolutional layers for the graph convolution
                     network.
        class_num: Number of output classes.
    Returns:
        An instance of the selected predictor. `schnet` is created with
        `class_num` outputs and no multilayer perceptron; every other method
        is paired with a multilayer perceptron with `class_num` outputs.
    Raises:
        ValueError: If `method` is not one of the supported names.
    """

    # Always instantiated first, even for SchNet where it ends up unused.
    head = MLP(out_dim=class_num, hidden_dim=n_unit)
    use_gwm = False

    if method == 'nfp':
        print('Training an NFP predictor...')
        graph_conv = NFP(out_dim=n_unit, hidden_dim=n_unit,
                         n_layers=conv_layers)
    elif method == 'nfp_gwm':
        print('Training an NFP+GWM predictor...')
        graph_conv = NFP_GWM(out_dim=n_unit,
                             hidden_dim=n_unit,
                             hidden_dim_super=n_unit,
                             n_layers=conv_layers,
                             dropout_ratio=0.5)
        use_gwm = True
    elif method == 'ggnn':
        print('Training a GGNN predictor...')
        graph_conv = GGNN(out_dim=n_unit, hidden_dim=n_unit,
                          n_layers=conv_layers)
    elif method == 'ggnn_gwm':
        print('Train GGNN+GWM model...')
        graph_conv = GGNN_GWM(out_dim=n_unit,
                              hidden_dim=n_unit,
                              hidden_dim_super=n_unit,
                              n_layers=conv_layers,
                              dropout_ratio=0.5,
                              weight_tying=True)
        use_gwm = True
    elif method == 'schnet':
        print('Training an SchNet predictor...')
        graph_conv = SchNet(out_dim=class_num,
                            hidden_dim=n_unit,
                            n_layers=conv_layers)
        # SchNet already outputs `class_num` values, so no perceptron.
        head = None
    elif method == 'weavenet':
        print('Training a WeaveNet predictor...')
        graph_conv = WeaveNet(weave_channels=[50] * conv_layers,
                              hidden_dim=n_unit,
                              n_sub_layer=1,
                              n_atom=20)
    elif method == 'rsgcn':
        print('Training an RSGCN predictor...')
        graph_conv = RSGCN(out_dim=n_unit, hidden_dim=n_unit,
                           n_layers=conv_layers)
    elif method == 'rsgcn_gwm':
        print('Training an RSGCN+GWM predictor...')
        graph_conv = RSGCN_GWM(out_dim=n_unit,
                               hidden_dim=n_unit,
                               hidden_dim_super=n_unit,
                               n_layers=conv_layers,
                               dropout_ratio=0.5)
        use_gwm = True
    elif method == 'relgcn':
        print('Training an RelGCN predictor...')
        graph_conv = RelGCN(out_channels=n_unit,
                            num_edge_type=4,
                            scale_adj=True)
    elif method == 'relgat':
        print('Train Relational GAT model...')
        graph_conv = RelGAT(out_dim=n_unit,
                            hidden_dim=n_unit,
                            n_layers=conv_layers)
    elif method == 'gin':
        print('Training a GIN predictor...')
        graph_conv = GIN(out_dim=n_unit, hidden_dim=n_unit,
                         n_layers=conv_layers)
    elif method == 'gin_gwm':
        print('Training a GIN+GWM predictor...')
        graph_conv = GIN_GWM(out_dim=n_unit,
                             hidden_dim=n_unit,
                             hidden_dim_super=n_unit,
                             n_layers=conv_layers,
                             dropout_ratio=0.5,
                             weight_tying=True)
        use_gwm = True
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))

    if use_gwm:
        return GraphConvPredictorForGWM(graph_conv, head)
    return GraphConvPredictor(graph_conv, head)
def main():
    """Train a graph convolution model on a molnet dataset.

    Steps, in order:

    1. Parse the command line options.
    2. Load the train/valid/test splits from the ``input`` cache directory
       when all three cache files exist; otherwise preprocess the dataset
       and write the cache files.
    3. Build the predictor selected by ``--method``.
    4. Wrap it in a ``Regressor`` or ``Classifier`` according to the
       dataset's configured task type and train it with Adam, reporting the
       metrics configured for the dataset.
    5. Pickle the trained model into the output directory.

    Raises:
        ValueError: If the selected method has no predictor branch.
        NotImplementedError: If the dataset's task type is neither
            ``'regression'`` nor ``'classification'``.
        TypeError: If a configured metric is neither a plain function nor a
            ``BatchEvaluator`` subclass.
    """
    method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn']
    dataset_names = list(molnet_default_config.keys())

    parser = argparse.ArgumentParser(description='molnet example')
    parser.add_argument('--method', '-m', type=str, choices=method_list,
                        default='nfp')
    parser.add_argument('--label', '-l', type=str, default='',
                        help='target label for regression, empty string means '
                        'to predict all property at once')
    parser.add_argument('--conv-layers', '-c', type=int, default=4)
    parser.add_argument('--batchsize', '-b', type=int, default=32)
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--out', '-o', type=str, default='result')
    parser.add_argument('--epoch', '-e', type=int, default=20)
    parser.add_argument('--unit-num', '-u', type=int, default=16)
    parser.add_argument('--dataset', '-d', type=str, choices=dataset_names,
                        default='bbbp')
    parser.add_argument('--protocol', type=int, default=2)
    parser.add_argument('--model-filename', type=str, default='regressor.pkl')
    # The trailing space keeps the two concatenated sentences apart in the
    # rendered help text.
    parser.add_argument('--num-data', type=int, default=-1,
                        help='Number of data to be parsed from parser. '
                             '-1 indicates to parse all data.')
    args = parser.parse_args()
    dataset_name = args.dataset
    method = args.method
    num_data = args.num_data
    n_unit = args.unit_num
    conv_layers = args.conv_layers
    print('Use {} dataset'.format(dataset_name))

    if args.label:
        # A single label was requested: the cache directory is label-specific.
        labels = args.label
        cache_dir = os.path.join('input', '{}_{}_{}'.format(dataset_name,
                                                            method, labels))
        class_num = len(labels) if isinstance(labels, list) else 1
    else:
        # No label given: predict every task configured for the dataset.
        labels = None
        cache_dir = os.path.join('input', '{}_{}_all'.format(dataset_name,
                                                             method))
        class_num = len(molnet_default_config[args.dataset]['tasks'])

    # Dataset preparation
    def get_dataset_paths(cache_dir, num_data):
        """Return the train/valid/test cache file paths, in that order."""
        filepaths = []
        for filetype in ['train', 'valid', 'test']:
            filename = filetype+'_data'
            if num_data >= 0:
                # Truncated datasets are cached under a size-specific name.
                filename += '_' + str(num_data)
            filename += '.npz'
            filepath = os.path.join(cache_dir, filename)
            filepaths.append(filepath)
        return filepaths
    filepaths = get_dataset_paths(cache_dir, num_data)
    if all(os.path.exists(fpath) for fpath in filepaths):
        datasets = []
        for fpath in filepaths:
            print('load from cache {}'.format(fpath))
            datasets.append(NumpyTupleDataset.load(fpath))
    else:
        print('preprocessing dataset...')
        preprocessor = preprocess_method_dict[method]()
        # Restrict preprocessing to the first `num_data` entries (handy for
        # debugging) when num_data >= 0; otherwise use everything.
        target_index = numpy.arange(num_data) if num_data >= 0 else None
        datasets = D.molnet.get_molnet_dataset(dataset_name, preprocessor,
                                               labels=labels,
                                               target_index=target_index)
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        datasets = datasets['dataset']
        for i, fpath in enumerate(filepaths):
            NumpyTupleDataset.save(fpath, datasets[i])

    # The test split is loaded/cached but not used during training.
    train, val, _ = datasets

    # Network
    if method == 'nfp':
        print('Train NFP model...')
        predictor = GraphConvPredictor(NFP(out_dim=n_unit, hidden_dim=n_unit,
                                       n_layers=conv_layers),
                                       MLP(out_dim=class_num,
                                           hidden_dim=n_unit))
    elif method == 'ggnn':
        print('Train GGNN model...')
        predictor = GraphConvPredictor(GGNN(out_dim=n_unit, hidden_dim=n_unit,
                                            n_layers=conv_layers),
                                       MLP(out_dim=class_num,
                                           hidden_dim=n_unit))
    elif method == 'schnet':
        print('Train SchNet model...')
        # SchNet is created with `class_num` outputs; no MLP is attached.
        predictor = GraphConvPredictor(
            SchNet(out_dim=class_num, hidden_dim=n_unit, n_layers=conv_layers),
            None)
    elif method == 'weavenet':
        print('Train WeaveNet model...')
        n_atom = 20
        n_sub_layer = 1
        weave_channels = [50] * conv_layers
        predictor = GraphConvPredictor(
            WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
                     n_sub_layer=n_sub_layer, n_atom=n_atom),
            MLP(out_dim=class_num, hidden_dim=n_unit))
    elif method == 'rsgcn':
        print('Train RSGCN model...')
        predictor = GraphConvPredictor(
            RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            MLP(out_dim=class_num, hidden_dim=n_unit))
    elif method == 'relgcn':
        print('Train RelGCN model...')
        num_edge_type = 4
        # RelGCN is created with `class_num` outputs; no MLP is attached.
        predictor = GraphConvPredictor(
            RelGCN(out_channels=class_num, num_edge_type=num_edge_type,
                   scale_adj=True),
            None)
    else:
        raise ValueError('[ERROR] Invalid method {}'.format(method))

    train_iter = iterators.SerialIterator(train, args.batchsize)
    val_iter = iterators.SerialIterator(val, args.batchsize,
                                        repeat=False, shuffle=False)

    metrics = molnet_default_config[dataset_name]['metrics']
    # Only plain functions are passed to the model as `metrics_fun`; the
    # remaining metrics are registered as trainer extensions further below.
    metrics_fun = {k: v for k, v in metrics.items()
                   if isinstance(v, types.FunctionType)}
    # loss_fun = molnet_default_config[dataset_name]['loss']
    task_type = molnet_default_config[dataset_name]['task_type']
    if task_type == 'regression':
        loss_fun = regression_loss_fun
        model = Regressor(predictor, lossfun=loss_fun, metrics_fun=metrics_fun,
                          device=args.gpu)
        # TODO(nakago): Use standard scaler for regression task
    elif task_type == 'classification':
        loss_fun = F.sigmoid_cross_entropy
        model = Classifier(predictor, lossfun=loss_fun,
                           metrics_fun=metrics_fun, device=args.gpu)
    else:
        raise NotImplementedError(
            'Not implemented task_type = {}'.format(task_type))

    optimizer = optimizers.Adam()
    optimizer.setup(model)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu,
                                       converter=concat_mols)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(E.Evaluator(val_iter, model, device=args.gpu,
                               converter=concat_mols))
    # Snapshot once, when the final epoch completes.
    trainer.extend(E.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(E.LogReport())
    print_report_targets = ['epoch', 'main/loss', 'validation/main/loss']
    for metric_name, metric_fun in metrics.items():
        if isinstance(metric_fun, types.FunctionType):
            print_report_targets.append('main/' + metric_name)
            print_report_targets.append('validation/main/' + metric_name)
        elif (isinstance(metric_fun, type)
                and issubclass(metric_fun, BatchEvaluator)):
            # `issubclass` itself raises on non-class arguments, hence the
            # `isinstance(..., type)` guard so those reach the `else` below.
            # Evaluation for train data takes time, skip for now.
            # trainer.extend(metric_fun(
            #     train_iter, model, device=args.gpu, eval_func=predictor,
            #     converter=concat_mols, name='train',
            #     raise_value_error=False))
            # print_report_targets.append('train/main/roc_auc')
            trainer.extend(metric_fun(
                val_iter, model, device=args.gpu, eval_func=predictor,
                converter=concat_mols, name='val',
                raise_value_error=False))
            print_report_targets.append('val/main/' + metric_name)
        else:
            # Report the type of the offending entry, not of the filtered
            # `metrics_fun` dict.
            raise TypeError('{} is not supported for metrics function.'
                            .format(type(metric_fun)))
    print_report_targets.append('elapsed_time')
    trainer.extend(E.PrintReport(print_report_targets))
    trainer.extend(E.ProgressBar())
    trainer.run()

    # --- save model ---
    protocol = args.protocol
    model.save_pickle(os.path.join(args.out, args.model_filename),
                      protocol=protocol)