Example #1
    def __init__(self, in_channels, out_channels, edge_channels, activation,
                 edge_mode, normalize_emb, aggr):
        super(EGraphSage, self).__init__(aggr=aggr)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_channels = edge_channels
        self.edge_mode = edge_mode
        self.tanh = nn.Tanh()

        # edge_mode selects which inputs the message network sees:
        # node features (in_channels), edge features (edge_channels), or both.
        if edge_mode == 0:
            # Node-only messages, attention-weighted over [x_i, x_j, edge_attr].
            self.message_lin = nn.Linear(in_channels, out_channels)
            self.attention_lin = nn.Linear(2 * in_channels + edge_channels, 1)
        elif edge_mode == 1:
            # Messages from one endpoint plus edge features.
            self.message_lin = nn.Linear(in_channels + edge_channels,
                                         out_channels)
        elif edge_mode == 2:
            # Messages from both endpoints plus edge features.
            self.message_lin = nn.Linear(2 * in_channels + edge_channels,
                                         out_channels)
        elif edge_mode == 3:
            # Two-layer MLP over both endpoints plus edge features.
            self.message_lin = nn.Sequential(
                nn.Linear(2 * in_channels + edge_channels, out_channels),
                get_activation(activation),
                nn.Linear(out_channels, out_channels),
            )
        elif edge_mode == 4:
            # out_channels values per edge channel, from one endpoint.
            self.message_lin = nn.Linear(in_channels,
                                         out_channels * edge_channels)
        elif edge_mode == 5:
            # out_channels values per edge channel, from both endpoints.
            self.message_lin = nn.Linear(2 * in_channels,
                                         out_channels * edge_channels)
        elif edge_mode == 6:
            # Project the source node to 8 dims, then mix with edge features.
            self.source_lin = nn.Linear(in_channels, 8)
            self.message_lin = nn.Linear(8 + edge_channels, out_channels)
        elif edge_mode == 7:
            # Messages from edge features alone.
            self.message_lin = nn.Linear(edge_channels, out_channels)

        self.agg_lin = nn.Linear(in_channels + out_channels, out_channels)

        self.message_activation = get_activation(activation)
        self.update_activation = get_activation(activation)
        self.normalize_emb = normalize_emb

        ## GGNN-style update: GRU over the aggregated messages.
        self.GRU = nn.GRU(out_channels, in_channels, 1)
        self.Lin_GRU = nn.Linear(in_channels, out_channels)

        ## Sigmoid-gated residual path.
        self.lin_res = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.Sigmoid(),
        )
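Every snippet on this page calls a project-local get_activation helper that maps an activation name to a module or layer. For reference, a minimal PyTorch sketch of such a helper (the exact set of supported names is an assumption, not GRAPE's actual implementation):

import torch.nn as nn

def get_activation(activation):
    # Sketch only: map a name to an nn.Module; 'none' yields the identity.
    if activation == 'relu':
        return nn.ReLU()
    if activation == 'prelu':
        return nn.PReLU()
    if activation == 'tanh':
        return nn.Tanh()
    if activation in (None, 'none'):
        return nn.Identity()
    raise NotImplementedError(f'unknown activation: {activation}')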
Example #2
    def build_edge_update_mlps(self, node_dim, edge_input_dim, edge_dim,
                               gnn_layer_num, activation):
        # One edge-update MLP per GNN layer. The first layer consumes the raw
        # edge features (edge_input_dim); later layers consume the updated
        # edge_dim-sized edge embeddings.
        edge_update_mlps = nn.ModuleList()
        edge_update_mlp = nn.Sequential(
            nn.Linear(node_dim + node_dim + edge_input_dim, edge_dim),
            get_activation(activation),
        )
        edge_update_mlps.append(edge_update_mlp)
        for _ in range(1, gnn_layer_num):
            edge_update_mlp = nn.Sequential(
                nn.Linear(node_dim + node_dim + edge_dim, edge_dim),
                get_activation(activation),
            )
            edge_update_mlps.append(edge_update_mlp)
        return edge_update_mlps
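Note the asymmetry: only the first MLP consumes the raw edge features (edge_input_dim); every later layer consumes the updated edge_dim-sized embedding. A quick illustrative check with hypothetical dimensions:

import torch
import torch.nn as nn

# Mirrors build_edge_update_mlps(node_dim=16, edge_input_dim=4, edge_dim=32,
# gnn_layer_num=3, activation='relu'): layer 0 maps 16+16+4 -> 32,
# layers 1 and 2 map 16+16+32 -> 32.
mlps = nn.ModuleList([
    nn.Sequential(nn.Linear(16 + 16 + 4, 32), nn.ReLU()),
    nn.Sequential(nn.Linear(16 + 16 + 32, 32), nn.ReLU()),
    nn.Sequential(nn.Linear(16 + 16 + 32, 32), nn.ReLU()),
])
x_i, x_j, e = torch.randn(10, 16), torch.randn(10, 16), torch.randn(10, 4)
e = mlps[0](torch.cat([x_i, x_j, e], dim=-1))  # -> shape (10, 32)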
Example #3
    def build_node_post_mlp(self, input_dim, output_dim, hidden_dims, dropout,
                            activation):
        # A 0 anywhere in hidden_dims disables the post-MLP: return a bare
        # identity activation instead of a stack of layers.
        if 0 in hidden_dims:
            return get_activation('none')
        layers = []
        for hidden_dim in hidden_dims:
            layers.append(nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                get_activation(activation),
                nn.Dropout(dropout),
            ))
            input_dim = hidden_dim
        # Final projection has no activation or dropout.
        layers.append(nn.Linear(input_dim, output_dim))
        return nn.Sequential(*layers)
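For concreteness, this is the stack the method assembles for hidden_dims=[64, 32] (all dimensions here are hypothetical): each hidden block is Linear -> activation -> Dropout, and the final projection is a bare Linear.

import torch
import torch.nn as nn

# Equivalent to build_node_post_mlp(16, 4, [64, 32], 0.1, 'relu').
mlp = nn.Sequential(
    nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Dropout(0.1)),
    nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.1)),
    nn.Linear(32, 4),
)
out = mlp(torch.randn(8, 16))  # -> shape (8, 4)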
Example #4
    def _dense_branches(x):
        # Closure over units / activation / dropout settings from the
        # enclosing builder; ends in a softmax classification head.
        for idx, unit in enumerate(units):
            x = Dense(units=unit, name=f'{name}_dense_{idx}')(x)
            x = get_activation(activation)(x)
            if use_dropout:
                x = Dropout(rate=dropout_rate)(x)
        x = Dense(units=output_num,
                  activation='softmax',
                  name=f'{name}_output')(x)
        return x
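The Keras examples (#4, #5, #7, #8) all call get_activation(activation)(x), i.e. the helper returns a layer rather than applying one. A minimal sketch under that assumption ('lrelu' and its slope are guesses; the real helpers may differ):

from tensorflow.keras.layers import Activation, LeakyReLU

def get_activation(activation):
    # Sketch only: return a Keras layer for the given name.
    if activation == 'lrelu':
        return LeakyReLU(alpha=0.2)
    if activation in (None, 'none'):
        return Activation('linear')  # identity
    return Activation(activation)  # 'relu', 'tanh', 'sigmoid', ...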
Example #5
    def _deform_conv_block(x):
        # Each step learns sampling offsets (ConvOffset2D) and then applies
        # a regular convolution on the deformed feature map.
        for idx, channel in enumerate(channels):
            x = ConvOffset2D(deform_channels[idx],
                             name=f'{name}_deform_offset_{idx + name_offset}')(x)
            x = Conv2D(filters=channel,
                       kernel_size=kernel_size,
                       strides=strides,
                       padding='same',
                       name=f'{name}_deform_conv_{idx + name_offset}')(x)
            x = get_activation(activation)(x)
        return x
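Examples #4 and #5 are inner functions that close over configuration (units, channels, name, ...) from an enclosing builder. A hypothetical wrapper, not taken from either project, shows how such a branch is typically wired into a model:

from tensorflow.keras.layers import Input, Dense, Dropout, Activation
from tensorflow.keras.models import Model

def build_classifier(input_dim, units=(128, 64), output_num=10,
                     activation='relu', use_dropout=True, dropout_rate=0.5,
                     name='clf'):
    # Hypothetical enclosing builder for a _dense_branches-style closure;
    # Activation(activation) stands in for the project's get_activation.
    def _dense_branches(x):
        for idx, unit in enumerate(units):
            x = Dense(units=unit, name=f'{name}_dense_{idx}')(x)
            x = Activation(activation)(x)
            if use_dropout:
                x = Dropout(rate=dropout_rate)(x)
        return Dense(units=output_num, activation='softmax',
                     name=f'{name}_output')(x)

    inputs = Input(shape=(input_dim,))
    return Model(inputs=inputs, outputs=_dense_branches(inputs), name=name)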
Example #6
File: egsage.py Project: maxiaoba/GRAPE
    def __init__(self, in_channels, out_channels, edge_channels, activation,
                 edge_mode, normalize_emb, aggr):
        super(EGraphSage, self).__init__(aggr=aggr)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_channels = edge_channels
        self.edge_mode = edge_mode

        if edge_mode == 0:
            self.message_lin = nn.Linear(in_channels, out_channels)
            self.attention_lin = nn.Linear(2 * in_channels + edge_channels, 1)
        elif edge_mode == 1:
            self.message_lin = nn.Linear(in_channels + edge_channels,
                                         out_channels)
        elif edge_mode == 2:
            self.message_lin = nn.Linear(2 * in_channels + edge_channels,
                                         out_channels)
        elif edge_mode == 3:
            self.message_lin = nn.Sequential(
                nn.Linear(2 * in_channels + edge_channels, out_channels),
                get_activation(activation),
                nn.Linear(out_channels, out_channels),
            )
        elif edge_mode == 4:
            self.message_lin = nn.Linear(in_channels,
                                         out_channels * edge_channels)
        elif edge_mode == 5:
            self.message_lin = nn.Linear(2 * in_channels,
                                         out_channels * edge_channels)

        self.agg_lin = nn.Linear(in_channels + out_channels, out_channels)

        self.message_activation = get_activation(activation)
        self.update_activation = get_activation(activation)
        self.normalize_emb = normalize_emb
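Both EGraphSage excerpts (#1 and #6) share this constructor signature, so instantiation is direct; the argument values below are illustrative only:

# edge_mode=1 concatenates node and edge features before the message
# linear layer (see the branch above); values are illustrative.
conv = EGraphSage(in_channels=64, out_channels=64, edge_channels=16,
                  activation='relu', edge_mode=1,
                  normalize_emb=True, aggr='mean')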
Example #7
def conv2d(x,
           filters,
           kernel_size=3,
           strides=1,
           use_sn=False,
           pad_type=None,
           pad_size=None,
           norm=None,
           activation=None):
    # Explicit zero padding, then a default ('valid') convolution.
    if pad_type == 'zero':
        x = ZeroPadding2D(padding=pad_size)(x)
    if pad_type in (None, 'zero'):
        conv = Conv2D(filters, kernel_size=kernel_size, strides=strides)
    else:
        # 'same' / 'valid' padding is handled by Conv2D itself.
        conv = Conv2D(filters, kernel_size=kernel_size, strides=strides,
                      padding=pad_type)
    # Optional spectral normalization around the convolution.
    x = SN(conv)(x) if use_sn else conv(x)

    # Normalization (identity when norm is None)
    x = get_normalization(norm)(x)

    # Activation (identity when activation is None)
    x = get_activation(activation)(x)
    return x
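conv2d applies get_normalization(norm)(x) and get_activation(activation)(x) unconditionally, so both helpers must return an identity layer when given None. A sketch of get_normalization under that assumption (the InstanceNormalization import from tensorflow_addons is a guess; keras-contrib has an equivalent layer):

import tensorflow_addons as tfa
from tensorflow.keras.layers import Activation, BatchNormalization

def get_normalization(norm):
    # Sketch only: identity when no normalization is requested.
    if norm in (None, 'none'):
        return Activation('linear')
    if norm == 'bn':
        return BatchNormalization()
    if norm == 'in':
        # Instance normalization, matching norm='in' in Example #8.
        return tfa.layers.InstanceNormalization()
    raise NotImplementedError(f'unknown normalization: {norm}')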
Example #8
    def define_model(self):
        input_layer = Input(shape=self.config.input_shape)

        x = conv2d(input_layer,
                   filters=64,
                   kernel_size=7,
                   strides=2,
                   pad_type='zero',
                   pad_size=3,
                   norm='in',
                   activation='lrelu')
        x = ZeroPadding2D(padding=(1, 1))(x)
        x = MaxPooling2D((3, 3), 2)(x)

        shortcut = conv2d(x,
                          filters=256,
                          kernel_size=1,
                          strides=1,
                          pad_type='valid',
                          norm='in')
        for i in range(3):
            x = conv2d(x,
                       filters=64,
                       kernel_size=1,
                       strides=1,
                       pad_type='valid',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=64,
                       kernel_size=3,
                       strides=1,
                       pad_type='same',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=256,
                       kernel_size=1,
                       strides=1,
                       pad_type='valid',
                       norm='in')

            x = Add()([x, shortcut])
            x = get_activation('lrelu')(x)
            shortcut = x

        shortcut = conv2d(x,
                          filters=512,
                          kernel_size=1,
                          strides=2,
                          pad_type='valid',
                          norm='in')
        for i in range(4):
            x = conv2d(x,
                       filters=128,
                       kernel_size=1,
                       strides=2 if i == 0 else 1,
                       pad_type='valid',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=128,
                       kernel_size=3,
                       strides=1,
                       pad_type='same',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=512,
                       kernel_size=1,
                       strides=1,
                       pad_type='valid',
                       norm='in')

            x = Add()([x, shortcut])
            x = get_activation('lrelu')(x)
            shortcut = x

        shortcut = conv2d(x,
                          filters=1024,
                          kernel_size=1,
                          strides=2,
                          pad_type='valid',
                          norm='in')
        for i in range(6):
            x = conv2d(x,
                       filters=256,
                       kernel_size=1,
                       strides=2 if i == 0 else 1,
                       pad_type='valid',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=256,
                       kernel_size=3,
                       strides=1,
                       pad_type='same',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=1024,
                       kernel_size=1,
                       strides=1,
                       pad_type='valid',
                       norm='in')

            x = Add()([x, shortcut])
            x = get_activation('lrelu')(x)
            shortcut = x

        shortcut = conv2d(x,
                          filters=2048,
                          kernel_size=1,
                          strides=2,
                          pad_type='valid',
                          norm='in')
        for i in range(3):
            x = conv2d(x,
                       filters=512,
                       kernel_size=1,
                       strides=2 if i == 0 else 1,
                       pad_type='valid',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=512,
                       kernel_size=3,
                       strides=1,
                       pad_type='same',
                       norm='in',
                       activation='lrelu')
            x = conv2d(x,
                       filters=2048,
                       kernel_size=1,
                       strides=1,
                       pad_type='valid',
                       norm='in')

            x = Add()([x, shortcut])
            x = get_activation('lrelu')(x)
            shortcut = x

        x = GlobalAveragePooling2D()(x)
        out = Dense(self.config.num_classes,
                    activation='softmax',
                    name='class_output')(x)
        return Model(inputs=input_layer, outputs=out, name='resnet50')
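The four stages in define_model repeat the same 1x1 -> 3x3 -> 1x1 bottleneck with a projection shortcut, downsampling on the first block of every stage after the first. A hypothetical helper (not part of the original file) that factors out one stage:

def _bottleneck_stage(x, mid_filters, out_filters, blocks, downsample):
    # Projection shortcut; strided when the stage downsamples.
    shortcut = conv2d(x, filters=out_filters, kernel_size=1,
                      strides=2 if downsample else 1,
                      pad_type='valid', norm='in')
    for i in range(blocks):
        x = conv2d(x, filters=mid_filters, kernel_size=1,
                   strides=2 if (downsample and i == 0) else 1,
                   pad_type='valid', norm='in', activation='lrelu')
        x = conv2d(x, filters=mid_filters, kernel_size=3, strides=1,
                   pad_type='same', norm='in', activation='lrelu')
        x = conv2d(x, filters=out_filters, kernel_size=1, strides=1,
                   pad_type='valid', norm='in')
        x = Add()([x, shortcut])
        x = get_activation('lrelu')(x)
        shortcut = x
    return x

# The stem would then be followed by:
# x = _bottleneck_stage(x, 64, 256, blocks=3, downsample=False)
# x = _bottleneck_stage(x, 128, 512, blocks=4, downsample=True)
# x = _bottleneck_stage(x, 256, 1024, blocks=6, downsample=True)
# x = _bottleneck_stage(x, 512, 2048, blocks=3, downsample=True)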