def trace(self, pth='./tmp', name_prefix=''):
    """Export this module through ``op_utils.export`` using random dummy inputs.

    ``op_utils.export`` is presumably an ONNX export wrapper; confirm in
    op_utils. The input names are the parameter names of ``self.forward`` in
    declaration order. The only output name given is ``'y'``. The first three
    inputs each get axis 0 declared dynamic under the name ``'n_edge'``.

    Dummy inputs, drawn in this order:
        * ``(1, self.dim_node)``
        * ``(1, self.dim_edge)``
        * ``(1, self.dim_node)``

    Args:
        pth: Directory joined with the model name to form the export path.
        name_prefix: Prefix prepended, followed by ``'_'``, to ``self.name``.

    Returns:
        A one-entry dict of the form
        ``{'model_<name>': {'path': <name>, 'input': [...], 'output': ['y']}}``
        where ``<name>`` is ``name_prefix + '_' + self.name``.
    """
    # Input names follow forward()'s signature order (self is not included
    # because forward is accessed as a bound method).
    input_names = list(inspect.signature(self.forward).parameters)
    output_names = ['y']

    node_a = torch.rand(1, self.dim_node)
    edge = torch.rand(1, self.dim_edge)
    node_b = torch.rand(1, self.dim_node)
    dummy_inputs = (node_a, edge, node_b)

    # Forward pass whose result is discarded (presumably a sanity check that
    # the dummy inputs are accepted before exporting).
    self(*dummy_inputs)

    model_name = name_prefix + '_' + self.name
    # One fresh axis mapping per input; raises IndexError if forward() has
    # fewer than three parameters.
    dynamic_axes = {input_names[k]: {0: 'n_edge'} for k in range(3)}
    op_utils.export(self, dummy_inputs, os.path.join(pth, model_name),
                    input_names=input_names, output_names=output_names,
                    dynamic_axes=dynamic_axes)

    entry = {'path': model_name, 'input': input_names, 'output': output_names}
    return {'model_' + model_name: entry}
def trace(self, pth='./tmp', name_prefix=''):
    """Export this layer in two pieces and return a description of them.

    The two pieces are handled differently:

    * ``self.edgeatten`` is exported through its own ``trace`` method.
    * ``self.prop`` is exported here via ``op_utils.export``. Its example
      input is rebuilt from a tiny dummy graph (2 nodes, 4 edges) by
      repeating the ``index_get -> edgeatten -> index_aggr -> cat`` steps
      on that graph.

    Side effect: calls ``self.eval()``, so the module is left in eval mode.

    Args:
        pth: Output directory; joined with each exported model's name.
        name_prefix: Prefix prepended, followed by ``'_'``, to ``self.name``.

    Returns:
        ``{name: {'atten': <self.edgeatten.trace(pth, name) result>,
        'prop': {'model_' + prop_name: {'path': prop_name,
        'input': ['x_in'], 'output': ['x_out']}}}}``
        with ``name = name_prefix + '_' + self.name`` and
        ``prop_name = name + '_prop'``.
    """
    n_node = 2
    n_edge = 4
    x = torch.rand(n_node, self.dim_node)
    edge_feature = torch.rand(n_edge, self.dim_edge)
    # Dummy connectivity: row 0 is all zeros and row 1 is all ones, so every
    # edge connects node 0 with node 1. This is built directly rather than by
    # drawing a random index tensor and then overwriting both of its rows.
    edge_index = torch.stack([
        torch.zeros(n_edge, dtype=torch.long),
        torch.ones(n_edge, dtype=torch.long),
    ])

    self.eval()
    # The computations below only produce example tensors for export, so no
    # autograd graph is needed.
    with torch.no_grad():
        # Full forward pass; the result is discarded (presumably a sanity
        # check that the dummy graph is accepted).
        self(x, edge_feature, edge_index)
        # Rebuild self.prop's input, following y = self.prop(cat([x, xx])).
        x_i, x_j = self.index_get(x, edge_index)
        atten_out, _, _ = self.edgeatten(x_i, edge_feature, x_j)
        atten_out = self.index_aggr(atten_out, edge_index, dim_size=x.shape[0])
        prop_input = torch.cat([x, atten_out], dim=1)

    name = name_prefix + '_' + self.name
    prop_name = name + '_prop'
    prop_in_names = ['x_in']
    prop_out_names = ['x_out']
    # Axis 0 is declared dynamic ('n_node') so the exported prop model is not
    # fixed to the 2-node dummy graph.
    op_utils.export(self.prop, prop_input, os.path.join(pth, prop_name),
                    input_names=prop_in_names, output_names=prop_out_names,
                    dynamic_axes={prop_in_names[0]: {0: 'n_node'}})
    names_nn = {
        'model_' + prop_name: {
            'path': prop_name,
            'input': prop_in_names,
            'output': prop_out_names,
        }
    }

    names_atten = self.edgeatten.trace(pth, name)
    return {name: {'atten': names_atten, 'prop': names_nn}}
def trace(self, pth='./tmp', name_prefix=''):
    """Export this module via ``op_utils.export`` using one random dummy row.

    The dummy input has shape ``(1, self.in_size)``. The input and output
    are named ``'x'`` and ``'y'``, and axis 0 of the input is declared
    dynamic under the name ``'n_node'``.

    Args:
        pth: Directory joined with the model name to form the export path.
        name_prefix: Prefix prepended, followed by ``'_'``, to ``self.name``.

    Returns:
        ``{'model_<name>': {'path': <name>, 'input': ['x'], 'output': ['y']}}``
        where ``<name>`` is ``name_prefix + '_' + self.name``.
    """
    import os

    x = torch.rand(1, self.in_size)
    names_i = ['x']
    names_o = ['y']
    name = name_prefix + '_' + self.name
    # Only axis 0 is declared dynamic. The dummy input is 2-D, so a dynamic
    # axis 2 ('n_pts') would name a dimension the traced input does not have.
    # NOTE(review): if forward() actually expects an (n_node, in_size, n_pts)
    # tensor, make the dummy input 3-D and add {2: 'n_pts'} back. Confirm
    # against forward().
    op_utils.export(self, x, os.path.join(pth, name),
                    input_names=names_i, output_names=names_o,
                    dynamic_axes={names_i[0]: {0: 'n_node'}})
    return {
        'model_' + name: {
            'path': name,
            'input': names_i,
            'output': names_o,
        }
    }