def forward(self, input):
    """Run the U-Net encoder/decoder and return the output of ``final_layer``.

    Encoder: ``down_layers[0]`` is applied to the input, and every further
    down layer is applied after ``self.max_pool``.  Each encoder output is
    kept as a skip connection.

    Decoder: each ``(upconv, block)`` pair in ``self.up_layers`` upsamples the
    running features, merges them with the matching skip connection via
    ``crop_and_merge``, and runs the block.  The skip connections are consumed
    from the second-deepest to the shallowest encoder output.
    """
    features = self.down_layers[0](input)
    skips = [features]
    for down_block in self.down_layers[1:]:
        features = down_block(self.max_pool(features))
        skips.append(features)

    # The deepest encoder output is `features` itself, so the decoder pairs
    # with the remaining outputs in reverse order.
    for (upconv, up_block), skip in zip(self.up_layers, reversed(skips[:-1])):
        features = up_block(crop_and_merge(skip, upconv(features)))

    return self.final_layer(features)
def forward(self, input):
    """Segment ``input`` with an atlas-guided U-Net and return the estimated pose.

    The encoder outputs are passed to ``self.find_pose``.  The resulting pose
    registers the batch-tiled prior atlas via ``self.register_atlas``.  The
    registered atlas is then resized to the spatial size of the features and
    merged into them at three points:
      * before each decoder block,
      * after each decoder block,
      * once more before ``self.final_layer``.

    Args:
        input: Batched network input.  ``input.shape[0]`` is used as the batch
            size when tiling the prior, and ``input.shape`` is passed to
            ``register_atlas``.

    Returns:
        Tuple ``(x, pose)``, where ``x`` is the output of ``self.final_layer``
        and ``pose`` is the output of ``self.find_pose``.
    """
    # Encoder: first block at input resolution, then pool + block per level.
    x = self.down_layers[0](input)
    down_outputs = [x]
    for unet_layer in self.down_layers[1:]:
        x = self.max_pool(x)
        x = unet_layer(x)
        down_outputs.append(x)

    # Pose estimation and atlas registration.
    # Tile the prior input.shape[0] times along dim 0; every other dim is kept.
    # NOTE(review): assumes self.prior has 2 + config.ndims dims with a
    # leading size-1 batch axis — confirm against where the prior is built.
    prior = self.prior.repeat(input.shape[0], 1, *([1] * self.config.ndims))
    pose = self.find_pose(down_outputs)
    registered_atlas = self.register_atlas(prior, pose, input.shape)
    # NOTE: a variant registers each structure's prior separately, using that
    # structure's own slice of the pose parameters (first n parameters for
    # structure 1, last n for structure 2), and concatenates the registered
    # priors along dim 1.  It is not implemented here.

    def merge_atlas(features):
        # Resize the registered atlas to the features' spatial dims (everything
        # after batch and channel) and merge it in.  F.interpolate replaces the
        # deprecated F.upsample, which only forwarded to it with a warning.
        resized_atlas = F.interpolate(
            registered_atlas, size=features.shape[2:], mode=self.upsample_mode)
        return crop_and_merge(resized_atlas, features)

    # Decoder: skip connections are consumed from second-deepest to shallowest.
    for (upconv_layer, unet_layer), down_output in zip(self.up_layers, down_outputs[-2::-1]):
        x = upconv_layer(x)
        x = crop_and_merge(down_output, x)
        # Re-resized after the block because it may change the spatial size
        # (e.g. unpadded convolutions).
        x = merge_atlas(x)
        x = unet_layer(x)

    x = merge_atlas(x)
    x = self.final_layer(x)
    return x, pose
def forward(self, data):
    """Run the voxel encoder/decoder and deform one template mesh per class.

    Args:
        data: Mapping with two keys:
            * ``'x'``: the input fed to ``down_layers[0]``.
            * ``'unpool'``: one entry per decoder stage; a stage unpools the
              mesh when ``entry[0] == 1``.

    Returns:
        ``pred``, a list of length ``config.num_classes``.
          * The last element is left as ``None``.
          * For each ``k`` in ``range(num_classes - 1)``, ``pred[k]`` is a list
            of ``[vertices, faces, latent_features, voxel_pred, sphere_vertices]``
            entries.
          * Entry 0 is a clone of the template sphere, with ``latent_features``
            and ``voxel_pred`` set to ``None``.
          * Entry ``i + 1`` is the mesh after decoder stage ``i``.
          * ``voxel_pred`` is non-``None`` only on entries produced at stage
            index ``len(self.up_std_conv_layers) - 1``.
    """
    x = data['x']
    unpool_indices = data['unpool']

    # Template mesh: every class starts from a clone of the same sphere.
    sphere_vertices = self.sphere_vertices.clone()
    vertices = sphere_vertices.clone()
    faces = self.sphere_faces.clone()

    # Encoder: first block at input resolution, then pool + block per level.
    x = self.down_layers[0](x)
    down_outputs = [x]
    for unet_layer in self.down_layers[1:]:
        x = self.max_pool(x)
        x = unet_layer(x)
        down_outputs.append(x)

    num_mesh_classes = self.config.num_classes - 1
    pred = [None] * self.config.num_classes
    for k in range(num_mesh_classes):
        pred[k] = [[vertices.clone(), faces.clone(), None, None, sphere_vertices.clone()]]

    last_stage = len(self.up_std_conv_layers) - 1
    stages = zip(self.up_std_conv_layers, self.up_f2f_layers, self.up_f2v_layers,
                 down_outputs[::-1], self.skip_count, unpool_indices)
    for i, ((skip_connection, grid_upconv_layer, grid_unet_layer), up_f2f_layers,
            up_f2v_layers, down_output, skip_amount, do_unpool) in enumerate(stages):
        # Voxel decoder step.  Stage 0 keeps the deepest encoder output as-is;
        # a stage without an up-conv layer takes its encoder output directly.
        if grid_upconv_layer is not None and i > 0:
            x = grid_upconv_layer(x)
            x = crop_and_merge(down_output, x)
            x = grid_unet_layer(x)
        elif grid_upconv_layer is None:
            x = down_output

        unpool_now = do_unpool[0] == 1
        # final_layer depends only on x, so it is computed once per stage
        # (previously it was recomputed identically for every class).
        voxel_pred = self.final_layer(x) if i == last_stage else None

        for k in range(num_mesh_classes):
            # Load mesh state produced by the previous iteration for class k.
            vertices, faces, latent_features, _, sphere_vertices = pred[k][i]
            graph_unet_layer = up_f2f_layers[k]
            feature2vertex = up_f2v_layers[k]

            if unpool_now:
                faces_prev = faces
                _, N_prev, _ = vertices.shape
                # Get candidate vertices using uniform unpool.  All three calls
                # use the pre-unpool faces; `faces` is replaced afterwards.
                # NOTE(review): latent_features is None at stage 0, so this
                # presumably requires data['unpool'][0][0] != 1 — confirm.
                vertices, faces_ = uniform_unpool(vertices, faces)
                latent_features, _ = uniform_unpool(latent_features, faces)
                sphere_vertices, _ = uniform_unpool(sphere_vertices, faces)
                faces = faces_

            A, D = adjacency_matrix(vertices, faces)
            # Sample the first `skip_amount` channels of the voxel features
            # at the vertex positions.
            skipped_features = skip_connection(x[:, :skip_amount], vertices)
            if latent_features is None:
                latent_features = torch.cat([skipped_features, vertices], dim=2)
            else:
                latent_features = torch.cat([latent_features, skipped_features, vertices], dim=2)

            latent_features = graph_unet_layer(latent_features, A, D, vertices, faces)
            deltaV = feature2vertex(latent_features, A, D, vertices, faces)
            vertices = vertices + deltaV

            if unpool_now:
                # Discard the vertices that were introduced by the uniform
                # unpool and didn't deform much.
                vertices, faces, latent_features, sphere_vertices = adoptive_unpool(
                    vertices, faces_prev, sphere_vertices, latent_features, N_prev)

            pred[k].append([vertices, faces, latent_features, voxel_pred, sphere_vertices])

    return pred