Пример #1
0
def set_up_predictor(
    method,  # type: str
    n_unit,  # type: int
    conv_layers,  # type: int
    class_num,  # type: int
    label_scaler=None,  # type: Optional[chainer.Link]
    postprocess_fn=None,  # type: Optional[chainer.FunctionNode]
    conv_kwargs=None  # type: Optional[Dict[str, Any]]
):
    # type: (...) -> GraphConvPredictor
    """Build a ``GraphConvPredictor`` from a graph-convolution backbone.

    The predictor pairs the backbone selected by ``method`` with an MLP
    head of width ``n_unit`` and output size ``class_num``. The
    ``'schnet'`` backbone is the exception. It is constructed with
    ``out_dim=class_num`` and used without an MLP head.

    Args:
        method (str): Backbone name. One of ``'nfp'``, ``'ggnn'``,
            ``'schnet'``, ``'weavenet'``, ``'rsgcn'``, ``'relgcn'``,
            ``'relgat'``, ``'gin'``, ``'nfp_gwm'``, ``'ggnn_gwm'``,
            ``'rsgcn_gwm'`` or ``'gin_gwm'``.
        n_unit (int): Number of hidden units. It is used for the backbone
            and as the MLP head's hidden size.
        conv_layers (int): Number of convolutional layers for the graph
            convolution network, passed as ``n_update_layers``. It is not
            forwarded for ``'weavenet'`` or ``'relgcn'``.
        class_num (int): Number of output classes. It is the MLP head's
            output size, or SchNet's ``out_dim`` when ``method`` is
            ``'schnet'``.
        label_scaler (chainer.Link or None): Scaler link, passed through
            to ``GraphConvPredictor``.
        postprocess_fn (chainer.FunctionNode or None): Postprocess
            function for prediction, passed through to
            ``GraphConvPredictor``.
        conv_kwargs (dict or None): Extra keyword arguments forwarded to
            the backbone constructor. ``None`` means no extra arguments.

    Returns:
        GraphConvPredictor: The assembled predictor.

    Raises:
        ValueError: If ``method`` is not one of the names listed above.
    """
    # The MLP head is always constructed first. This happens even for
    # 'schnet', which discards it below, and for an unknown method name.
    head = MLP(out_dim=class_num, hidden_dim=n_unit)  # type: Optional[MLP]
    extra = conv_kwargs if conv_kwargs is not None else {}

    def uniform(conv_cls):
        # Backbones that share the (out_dim, hidden_channels,
        # n_update_layers) constructor arguments.
        return conv_cls(out_dim=n_unit,
                        hidden_channels=n_unit,
                        n_update_layers=conv_layers,
                        **extra)

    # Each entry is (method name, banner printed before building, factory).
    # Class names sit inside lambdas, so a class is only looked up when its
    # method is actually selected.
    registry = (
        ('nfp', 'Set up NFP predictor...', lambda: uniform(NFP)),
        ('ggnn', 'Set up GGNN predictor...', lambda: uniform(GGNN)),
        ('schnet', 'Set up SchNet predictor...',
         lambda: SchNet(out_dim=class_num,
                        hidden_channels=n_unit,
                        n_update_layers=conv_layers,
                        **extra)),
        ('weavenet', 'Set up WeaveNet predictor...',
         lambda: WeaveNet(hidden_dim=n_unit, **extra)),
        ('rsgcn', 'Set up RSGCN predictor...', lambda: uniform(RSGCN)),
        # RelGCN gets a fixed edge-type count of 4 and has adjacency
        # scaling enabled. NOTE(review): 4 is presumably the number of
        # bond types; confirm.
        ('relgcn', 'Set up Relational GCN predictor...',
         lambda: RelGCN(out_dim=n_unit,
                        n_edge_types=4,
                        scale_adj=True,
                        **extra)),
        ('relgat', 'Set up Relational GAT predictor...',
         lambda: uniform(RelGAT)),
        ('gin', 'Set up GIN predictor...', lambda: uniform(GIN)),
        ('nfp_gwm', 'Set up NFP_GWM predictor...', lambda: uniform(NFP_GWM)),
        ('ggnn_gwm', 'Set up GGNN_GWM predictor...',
         lambda: uniform(GGNN_GWM)),
        ('rsgcn_gwm', 'Set up RSGCN_GWM predictor...',
         lambda: uniform(RSGCN_GWM)),
        ('gin_gwm', 'Set up GIN_GWM predictor...', lambda: uniform(GIN_GWM)),
    )

    for name, banner, build in registry:
        if method == name:
            print(banner)
            conv = build()
            if name == 'schnet':
                # SchNet is built with out_dim=class_num, so it is used
                # without the MLP head.
                head = None
            break
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))

    return GraphConvPredictor(conv, head, label_scaler, postprocess_fn)
Пример #2
0
def model_with_nfp_no_dropout():
    """Return an RSGCN that uses an NFP readout and a dropout ratio of 0."""
    # The readout's input and output sizes both follow the module-level
    # ``out_dim``.
    nfp_readout = NFPReadout(in_channels=out_dim, out_size=out_dim)
    return RSGCN(out_dim=out_dim, readout=nfp_readout, dropout_ratio=0.0)
Пример #3
0
def model_no_dropout():
    """Return an RSGCN whose dropout ratio is zero.

    Checking backward gradients with ``gradient_check`` requires skipping
    the stochastic dropout function, hence the zero ratio.
    """
    zero_dropout = 0.0
    return RSGCN(out_dim=out_dim, dropout_ratio=zero_dropout)
Пример #4
0
def model_with_nfp():
    """Return an RSGCN that uses an NFP readout.

    ``dropout_ratio`` is left at RSGCN's own default.
    """
    nfp_readout = NFPReadout(in_channels=out_dim, out_size=out_dim)
    return RSGCN(out_dim=out_dim, readout=nfp_readout)
Пример #5
0
def model():
    """Return an RSGCN built with only ``out_dim`` specified."""
    network = RSGCN(out_dim=out_dim)
    return network