def __init__(self, in_channels, graph_args, edge_importance_weighting, **kwargs):
    """Build the relation-aware graph network.

    Creates the graph buffers, a single ``rt_gcn`` block, the
    edge-importance modules selected by ``edge_importance_weighting`` and
    a 1x1 convolution prediction head.

    Args:
        in_channels: Number of input feature channels.
        graph_args: Keyword arguments forwarded to ``Graph``.
        edge_importance_weighting: ``'Uniform'``, ``'Weight'`` or
            ``'Time-aware'``. Any other value leaves every
            edge-importance attribute unset.
        **kwargs: Options for ``rt_gcn``; the ``'dropout'`` entry is
            removed before they are passed on.
    """
    super().__init__()

    # Graph structure, stored as float32 buffers: saved in state_dict and
    # moved with the module by .to()/.cuda(), but never trained.
    self.graph = Graph(**graph_args)
    adjacency = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
    relation = torch.tensor(self.graph.relation, dtype=torch.float32, requires_grad=False)
    self.register_buffer('A', adjacency)
    self.register_buffer('relation', relation)
    self.edge_importance_weighting = edge_importance_weighting

    num_subsets = adjacency.size(0)  # used as the spatial kernel size
    num_nodes = adjacency.size(1)
    relation_dim = relation.size(2)
    kernel_size = (9, num_subsets)  # (temporal, spatial)

    # NOTE(review): presumably normalises input flattened to
    # (N, in_channels * V, T) — confirm against forward().
    self.data_bn = nn.BatchNorm1d(in_channels * num_nodes)
    rt_kwargs = {name: value for name, value in kwargs.items() if name != 'dropout'}
    self.linear = nn.Linear(relation_dim, 1)
    # Single block: in_channels -> 8 channels, stride 2, with residual.
    self.rt_gcn_networks = nn.ModuleList([
        rt_gcn(in_channels, 8, kernel_size, 2, residual=True, **rt_kwargs),
    ])

    num_blocks = len(self.rt_gcn_networks)
    mode = self.edge_importance_weighting
    if mode == 'Uniform':
        # Constant weight of 1 per block; nothing is learned.
        self.edge_importance = [1] * num_blocks
    elif mode == 'Weight':
        # One learnable projection of the relation tensor's last dim per block.
        self.edge_importance = nn.ModuleList(
            [nn.Linear(relation_dim, 1) for _ in range(num_blocks)])
    elif mode == 'Time-aware':
        # NOTE(review): this branch sets edge_importance1/edge_importance2
        # but not edge_importance, and always builds two attention modules
        # (embed dims in_channels and 8) regardless of num_blocks.
        self.edge_importance1 = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=in_channels, num_heads=1),
            nn.MultiheadAttention(embed_dim=8, num_heads=1),
        ])
        self.edge_importance2 = nn.ModuleList(
            [nn.Linear(relation_dim, 1) for _ in range(num_blocks)])

    # Prediction head: 1x1 conv from 8 channels to 1 (the original author's
    # note gives the output as (N, 1, 1, V)).
    self.fcn = nn.Conv2d(8, 1, kernel_size=1)
def __init__(self, in_channels, num_class, graph_args, edge_importance_weighting, **kwargs):
    """Build the two-stream ST-GCN model.

    Two parallel stacks of ten ``st_gcn`` blocks are created: one taking
    ``in_channels`` inputs and one taking ``in_channels * 2`` (a
    concatenated stream, per the original comments). Each stream has its
    own input BatchNorm, edge-importance weights and MLP head ending in a
    single output unit; ``self.fuse`` holds two scalar weights
    initialised to 1.

    Args:
        in_channels: Number of input channels of the single stream.
        num_class: Accepted but not used by this constructor.
        graph_args: Keyword arguments forwarded to ``Graph``.
        edge_importance_weighting: If truthy, every block gets a learnable
            mask initialised to ones with the shape of ``A``; otherwise a
            constant 1 per block.
        **kwargs: Options forwarded to every ``st_gcn`` block except the
            first of each stream, which receives them without ``'dropout'``.
    """
    super().__init__()

    # Graph adjacency as a non-trainable float32 buffer.
    self.graph = Graph(**graph_args)
    adjacency = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
    self.register_buffer('A', adjacency)

    num_subsets = adjacency.size(0)  # used as the spatial kernel size
    num_nodes = adjacency.size(1)
    kernel_size = (3, num_subsets)  # (temporal, spatial)

    # NOTE(review): the factor 100 is hard-coded; presumably the fixed
    # number of input frames — confirm against the data pipeline.
    frames = 100
    self.data_bn = nn.BatchNorm1d(in_channels * num_nodes * frames)       # single stream
    self.data_bn2 = nn.BatchNorm1d(in_channels * 2 * num_nodes * frames)  # concatenated stream

    first_kwargs = {name: value for name, value in kwargs.items() if name != 'dropout'}
    # (in, out, stride) for blocks 2..10. Block 1 maps the stream's input
    # channels to 16 with stride 1, no residual and no 'dropout' option.
    tail_plan = (
        (16, 16, 1),
        (16, 64, 2),
        (64, 64, 1),
        (64, 256, 2),
        (256, 256, 1),
        (256, 64, 2),
        (64, 64, 1),
        (64, 16, 2),
        (16, 16, 1),
    )

    def build_stream(stream_channels):
        # Blocks are constructed strictly in order so parameter init matches.
        blocks = [st_gcn(stream_channels, 16, kernel_size, 1, residual=False, **first_kwargs)]
        for c_in, c_out, stride in tail_plan:
            blocks.append(st_gcn(c_in, c_out, kernel_size, stride, **kwargs))
        return nn.ModuleList(blocks)

    self.st_gcn_networks = build_stream(in_channels)
    self.st_gcn_networks2 = build_stream(in_channels * 2)

    def build_importance(stream):
        # One mask per block: learnable ones shaped like A, or a constant 1.
        if edge_importance_weighting:
            return nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for _ in stream])
        return [1] * len(stream)

    self.edge_importance = build_importance(self.st_gcn_networks)
    self.edge_importance2 = build_importance(self.st_gcn_networks2)

    # Two scalar stream-fusion weights, both starting at 1.
    self.fuse = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(2)])

    # NOTE(review): 16 * 18 presumably is last-block channels x node count
    # (18 per the original "6* 18" note); it is hard-coded rather than
    # derived from num_nodes — confirm against forward().
    head_widths = (16 * 18, 128, 64, 32)

    def build_head():
        # Linear/BatchNorm pairs down the widths, then a final Linear to 1.
        layers = []
        for width_in, width_out in zip(head_widths, head_widths[1:]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.BatchNorm1d(width_out))
        layers.append(nn.Linear(head_widths[-1], 1))
        return nn.ModuleList(layers)

    self.FCN = build_head()
    self.FCN2 = build_head()