def forward(self, f1, f2, x1, x2):
    """Score all anchor pairings of two clouds and regress a value per pairing.

    Both inputs are wrapped as spherical point clouds and reduced with
    ``self._pooling``. The pooled features are tiled against each other
    along the anchor axis, concatenated on the channel axis, and passed
    through every layer in ``self.linear`` (each followed by ReLU).

    Args:
        f1, f2: Features of the two clouds (the original shape note reads
            "nb, nc, np, na -> nb, nc, na" -- confirm against the caller).
        x1, x2: Coordinates paired with ``f1`` and ``f2`` respectively.

    Returns:
        A tuple ``(confidence, y)``. ``confidence`` is the softmax over
        dim 1 of the temperature-scaled attention logits viewed as
        ``[nb, na, na]``; ``y`` is the output of ``self.regressor_layer``
        (noted as ``[nb, n_out, na, na]`` in the original comment).
    """
    cloud_a = zptk.SphericalPointCloud(x1, f1, None)
    cloud_b = zptk.SphericalPointCloud(x2, f2, None)

    pooled_a = self._pooling(cloud_a)
    pooled_b = self._pooling(cloud_b)

    batch_size = pooled_a.shape[0]
    num_anchors = pooled_a.shape[2]

    # Tile each pooled tensor along the other cloud's anchor axis so that
    # every anchor pairing gets a feature vector (nb, nc*2, na_tgt, na_src).
    tiled_b = pooled_b.unsqueeze(-1).expand(-1, -1, -1, num_anchors).contiguous()
    tiled_a = pooled_a.unsqueeze(-2).expand(-1, -1, num_anchors, -1).contiguous()
    pair_feat = torch.cat((tiled_a, tiled_b), 1)

    # Fully connected stack, ReLU after each layer.
    for layer in self.linear:
        pair_feat = F.relu(layer(pair_feat))

    logits = self.attention_layer(pair_feat).view(batch_size, num_anchors, num_anchors)
    confidence = F.softmax(logits * self.temperature, dim=1)
    y = self.regressor_layer(pair_feat)

    return confidence, y
def forward(self, frag, clouds):
    """Propagate features, then normalise and optionally activate / drop out.

    Args:
        frag: First argument forwarded to ``self.prop``.
        clouds: Second argument forwarded to ``self.prop``.

    Returns:
        A ``zptk.SphericalPointCloud`` that reuses the ``xyz`` and
        ``anchors`` of the propagated result and carries the processed
        features.
    """
    propagated = self.prop(frag, clouds)
    out = self.norm(propagated.feats)

    # The activation is optional; dropout is only applied while training.
    if self.relu is not None:
        out = self.relu(out)
    if self.training and self.dropout is not None:
        out = self.dropout(out)

    return zptk.SphericalPointCloud(propagated.xyz, out, propagated.anchors)
def preprocess_input(x, na, add_center=True):
    """Wrap a raw point tensor as a ``zptk.SphericalPointCloud``.

    Args:
        x: Point tensor indexed as ``x[:, :, :3]`` for coordinates; a last
            dimension of size 6 is treated as "has normals" (presumably
            [batch, points, channels] -- confirm against the caller).
        na: Forwarded to ``sptk.get_occupancy_features``.
        add_center: When true and the input has no normals, the mean over
            dim 1 is inserted at index zero and the last point is dropped,
            so the point count is unchanged.

    Returns:
        A ``zptk.SphericalPointCloud`` built from the permuted coordinates
        and the occupancy features, with ``None`` as the anchors argument.
    """
    channel_count = x.shape[2]
    if add_center and channel_count != 6:
        # Dummy centre point goes first; trimming the tail keeps the size.
        centroid = x.mean(1, keepdim=True)
        with_center = torch.cat((centroid, x), dim=1)
        x = with_center[:, :-1]

    coords = x[:, :, :3].permute(0, 2, 1).contiguous()
    occupancy = sptk.get_occupancy_features(x, na, add_center)
    return zptk.SphericalPointCloud(coords, occupancy, None)
def forward(self, x):
    """Convolve a point cloud, then normalise and optionally activate / drop out.

    Args:
        x: Input accepted by ``self.conv`` (the original comment describes
            it as ``[b, 3, p] x [b, c1, p]`` -- confirm against the caller).

    Returns:
        A ``zptk.SphericalPointCloud`` with the convolved cloud's ``xyz``
        and ``anchors`` and the processed features (described as
        ``[b, 3, p] x [b, c2, p]`` in the original comment).
    """
    convolved = self.conv(x)
    out = self.norm(convolved.feats)

    # The activation is optional; dropout is only applied while training.
    if self.relu is not None:
        out = self.relu(out)
    if self.training and self.dropout is not None:
        out = self.dropout(out)

    return zptk.SphericalPointCloud(convolved.xyz, out, convolved.anchors)
def forward(self, x, inter_idx=None, inter_w=None):
    """Run the convolution, then normalise and optionally activate / drop out.

    Args:
        x: Input point cloud passed to ``self.conv``.
        inter_idx: Optional value forwarded to ``self.conv``; the value
            returned by ``self.conv`` is handed back to the caller.
        inter_w: Optional value forwarded to ``self.conv``; handled like
            ``inter_idx``.

    Returns:
        A tuple ``(inter_idx, inter_w, sample_idx, cloud)``. The first three
        items come straight from ``self.conv``. ``cloud`` is a
        ``zptk.SphericalPointCloud`` with the convolved ``xyz`` and
        ``anchors`` and the processed features.
    """
    inter_idx, inter_w, sample_idx, x = self.conv(x, inter_idx, inter_w)

    feat = self.norm(x.feats)
    if self.relu is not None:
        feat = self.relu(feat)
    # Dropout is only applied while training.
    if self.training and self.dropout is not None:
        feat = self.dropout(feat)

    out_cloud = zptk.SphericalPointCloud(x.xyz, feat, x.anchors)
    return inter_idx, inter_w, sample_idx, out_cloud
def forward(self, x, inter_idx, inter_w):
    """Inter conv, optional intra conv, plus a projected skip connection.

    Args:
        x: Input point cloud; its ``feats`` feed the skip branch.
        inter_idx: Forwarded to ``self.inter_conv``.
        inter_w: Forwarded to ``self.inter_conv``.

    Returns:
        A tuple ``(inter_idx, inter_w, sample_idx, cloud)``. The first three
        items are returned by ``self.inter_conv``. ``cloud`` holds the sum
        of the main-branch features and the skip-branch features.
    """
    residual = x.feats

    inter_idx, inter_w, sample_idx, x = self.inter_conv(x, inter_idx, inter_w)
    if self.use_intra:
        x = self.intra_conv(x)

    # With a stride above one, gather the skip features at the sampled
    # indices along dim 2 before adding them to the main branch.
    if self.stride > 1:
        residual = zptk.functional.batched_index_select(residual, 2, sample_idx.long())

    residual = self.skip_conv(residual)
    residual = self.norm(residual)
    residual = self.relu(residual)

    merged = zptk.SphericalPointCloud(x.xyz, x.feats + residual, x.anchors)
    return inter_idx, inter_w, sample_idx, merged
def forward(self, x):
    """Attention-pool the last feature axis, then embed with ``self.pointnet``.

    Args:
        x: Point cloud whose ``feats`` tensor must be 4-D (unpacked as
            nb, nc, np, na in the original).

    Returns:
        A tuple ``(embedding, attn)``. ``embedding`` is the pointnet output
        flattened to ``[nb, -1]`` and L2-normalised along dim 1. ``attn``
        is the softmax over dim 3 of ``self.attention_layer(x.feats)``.
    """
    batch_size, _, _, _ = x.feats.shape

    # Attention first: weight the features, then sum over the last axis
    # (keeping it as a singleton).
    weights = F.softmax(self.attention_layer(x.feats), dim=3)
    weighted = (x.feats * weights).sum(-1, keepdim=True)

    pooled_cloud = zptk.SphericalPointCloud(x.xyz, weighted, None)
    embedding = self.pointnet(pooled_cloud).view(batch_size, -1)

    return F.normalize(embedding, p=2, dim=1), weights
def forward(self, x, label=None):
    """Linear stack -> pointnet -> pooling -> final fc layer.

    Args:
        x: Point cloud providing ``feats``, ``xyz`` and ``anchors``.
        label: Accepted but not used by this method.

    Returns:
        A tuple ``(output, feature)``. In debug mode (``self.debug``) this
        is the double mean over the last two axes of the first 40 channels
        of ``x.feats`` and ``None``. Otherwise ``output`` is
        ``self.fc2`` applied to the pooled features, and ``feature`` is
        the squeezed attention logits for attention pooling or the squeezed
        output of the linear stack for every other pooling mode.

    Raises:
        NotImplementedError: If ``self.pooling_method`` is not supported.
    """
    feats = x.feats
    if self.debug:
        return feats[:, :40].mean(-1).mean(-1), None

    # One norm per linear layer, followed by one more norm (at index
    # ``num_linear``) for the pointnet output.
    num_linear = len(self.linear)
    hidden = feats
    for idx, layer in enumerate(self.linear):
        hidden = F.relu(self.norm[idx](layer(hidden)))

    out_feat = hidden
    cloud = zptk.SphericalPointCloud(x.xyz, out_feat, x.anchors)
    pooled = self.pointnet(cloud)
    pooled = F.relu(self.norm[num_linear](pooled))

    mode = self.pooling_method
    if mode == 'mean':
        pooled = pooled.mean(dim=2)
    elif mode == 'debug':
        # For debugging only.
        pooled = pooled[..., 0].mean(2)
    elif mode == 'max':
        pooled = pooled.max(2)[0]
    elif mode.startswith('attention'):
        # Attention logits are Bx1xA or BxCxA per the original comment.
        out_feat = self.attention_layer(pooled)
        confidence = F.softmax(out_feat * self.temperature, dim=2)
        pooled = (pooled * confidence).sum(-1)
    else:
        raise NotImplementedError(f"Pooling mode {self.pooling_method} is not implemented!")

    result = self.fc2(pooled)
    return result, out_feat.squeeze()
def forward(self, feats, label=None):
    """Linear stack -> mean over dim 2 -> optional intra convs -> pooling -> fc.

    Args:
        feats: Input feature tensor (presumably [B, C, P, A] -- confirm
            against the caller; dim 2 is averaged right after the linear
            stack).
        label: Optional labels. When given and ``self.pooling_method`` is
            not 'mean', 'debug' or 'max', pooling is done with a one-hot
            mask built from the labels (debug-only path).

    Returns:
        A tuple ``(output, feature)``. ``output`` is the result of the
        ``self.fc1`` layers (each followed by ReLU) and ``self.fc2``.
        ``feature`` is the squeezed attention logits for attention pooling
        or the squeezed output of the linear stack otherwise.

    Raises:
        NotImplementedError: If ``self.pooling_method`` is not supported
            and no label is given.
    """
    # Norm layers are consumed in order: the first ``num_linear`` belong to
    # the linear stack, the following ones to the intra skip connections.
    num_linear = len(self.linear)
    hidden = feats
    for idx, layer in enumerate(self.linear):
        hidden = F.relu(self.norm[idx](layer(hidden)))

    out_feat = hidden
    # Mean pool at xyz.
    pooled = hidden.mean(2, keepdim=True)

    # Group convolution after the mean pool, each with a skip connection.
    if hasattr(self, 'intra'):
        cloud = zptk.SphericalPointCloud(None, pooled, None)
        for lid, conv in enumerate(self.intra):
            shortcut = cloud.feats
            cloud = conv(cloud)

            norm = self.norm[num_linear + lid]
            shortcut = F.relu(norm(self.skipconv[lid](shortcut)))
            cloud = zptk.SphericalPointCloud(None, shortcut + cloud.feats, None)
        pooled = cloud.feats

    mode = self.pooling_method
    if mode == 'mean':
        pooled = pooled.mean(dim=3).mean(dim=2)
    elif mode == 'debug':
        # For debugging only.
        pooled = pooled[..., 0].mean(2)
    elif mode == 'max':
        pooled = pooled.mean(2).max(-1)[0]
    elif label is not None:
        # DEBUG ONLY: pool with a one-hot mask derived from the labels.
        def to_one_hot(label, num_class):
            """Turn labels [B, ...] into a float one-hot tensor [B, ..., num_class]."""
            classes = torch.arange(num_class).long().to(label.device)
            # Prepend one singleton axis per label dimension for broadcasting.
            classes = classes.reshape((1,) * label.dim() + (num_class,))
            return (label.unsqueeze(-1) == classes).float()

        pooled = pooled.mean(2)
        label = label.squeeze()
        if label.dim() == 2:
            channel_dim = pooled.shape[1]
            label = label.repeat(1, 5)[:, :channel_dim]
        confidence = to_one_hot(label, pooled.shape[2])
        if confidence.dim() < 3:
            confidence = confidence.unsqueeze(1)
        pooled = (pooled * confidence).sum(-1)
    elif mode.startswith('attention'):
        pooled = pooled.mean(2)
        # Attention logits are Bx1xA or BxCxA per the original comment.
        out_feat = self.attention_layer(pooled)
        confidence = F.softmax(out_feat * self.temperature, dim=2)
        pooled = (pooled * confidence).sum(-1)
    else:
        raise NotImplementedError(f"Pooling mode {self.pooling_method} is not implemented!")

    # Fully connected layers with ReLU (no norm on this stack).
    for layer in self.fc1:
        pooled = F.relu(layer(pooled))

    result = self.fc2(pooled)
    return result, out_feat.squeeze()