def _make_transpose(self, transblock, planes, blocks, stride=1):
    """Build one decoder stage made of ``blocks`` transposed residual blocks.

    The first ``blocks - 1`` blocks keep ``self.inplanes * transblock.expansion``
    channels. The last block maps to ``planes`` channels and receives the
    shortcut ("upsample") branch when one is needed.

    Side effect: ``self.inplanes`` is set to ``planes // transblock.expansion``
    so the next stage starts from this stage's output width.

    Returns an ``scn.Sequential`` holding the blocks.
    """
    in_channels = self.inplanes * transblock.expansion

    # Shortcut branch selection:
    # - a resolution change (stride != 1) uses a dense 2x2 transposed conv
    #   wrapped in sparse<->dense conversions;
    # - a channel-only change uses a 1x1 NetworkInNetwork;
    # - otherwise there is no shortcut module.
    if stride != 1:
        upsample = scn.Sequential(
            scn.SparseToDense(2, in_channels),
            nn.ConvTranspose2d(in_channels, planes, kernel_size=2,
                               stride=stride, padding=0, bias=False),
            scn.DenseToSparse(2),
            scn.BatchNormalization(planes),
        )
    elif in_channels != planes:
        upsample = scn.Sequential(
            scn.NetworkInNetwork(in_channels, planes, False),
            scn.BatchNormalization(planes),
        )
    else:
        upsample = None

    stage = [transblock(self.inplanes, in_channels) for _ in range(blocks - 1)]
    stage.append(transblock(self.inplanes, planes, stride, upsample))

    self.inplanes = planes // transblock.expansion
    return scn.Sequential(*stage)
def __init__(self, num_classes=5):
    """Sparse ResNet classifier for 2-D inputs.

    The dense input is sparsified and passed through this stack:
    - a valid convolution (3 -> 8 channels);
    - max pooling;
    - an ``scn.SparseResNet``;
    - a convolution to 64 channels with BatchNormReLU.

    The result is densified (64 channels) and fed to a linear layer that
    expects 6400 flattened features and outputs ``num_classes`` scores.
    """
    nn.Module.__init__(self)

    # Stage specification handed to scn.SparseResNet; presumably
    # [block type, planes, repetitions, stride] — confirm against scn docs.
    resnet_stages = [
        ['b', 8, 3, 1],
        ['b', 16, 2, 2],
        ['b', 24, 2, 2],
        ['b', 32, 2, 2],
    ]

    backbone = scn.Sequential()
    backbone.add(scn.DenseToSparse(2))
    backbone.add(scn.ValidConvolution(2, 3, 8, 2, False))
    backbone.add(scn.MaxPooling(2, 4, 2))
    backbone.add(scn.SparseResNet(2, 8, resnet_stages))
    backbone.add(scn.Convolution(2, 32, 64, 4, 1, False))
    backbone.add(scn.BatchNormReLU(64))
    backbone.add(scn.SparseToDense(2, 64))
    self.sparseModel = backbone

    # 6400 must equal 64 channels times the dense spatial size reached by the
    # backbone (presumably 10 x 10 for the expected input size — confirm).
    self.linear = nn.Linear(6400, num_classes)
def normalize_input(self, x, ones_conv):
    """Apply a sparsity-invariant normalisation to the sparse tensor ``x``.

    Steps:
    - densify ``x``;
    - divide every site by ``ones_conv`` applied to the binary mask of
      active sites, plus 1e-5;
    - sparsify the result again.

    Going through a dense representation is a workaround, not an efficient
    implementation.

    NOTE(review): DenseToSparse presumably rebuilds the active set from
    non-zero values, so active sites whose features are all zero would be
    dropped — confirm.
    """
    eps = 1e-5
    n_channels = x.features.shape[-1]
    dense_x = scn.SparseToDense(2, n_channels)(x)

    # A site is treated as active when any of its channels is non-zero.
    active_mask = (dense_x.abs().sum(1, keepdim=True) > 0).float()
    normalizer = ones_conv(active_mask) + eps

    return scn.DenseToSparse(2)(dense_x / normalizer)
def forward(self, x):
    """Decode one input row per sample into a flattened 88 x 168 output.

    Column groups of ``x`` are encoded by separate MLPs:
    - columns 0:3 (PID);
    - columns 3:6 (position);
    - columns 6:9 (direction);
    - column 9 (energy).

    The encodings are concatenated, mixed by ``self._mlp``, reshaped to a
    (batch, 64, 11, 21) map, sparsified and decoded by ``self._upconvs``.
    """
    # NOTE(review): the name of the PID encoder was lost from the original
    # source; ``_mlp_pid`` is a reconstruction — confirm the attribute name.
    net = torch.cat((self._mlp_pid(x[:, 0:3]),
                     self._mlp_pos(x[:, 3:6]),
                     self._mlp_dir(x[:, 6:9]),
                     # 9:10 keeps the (batch, 1) column shape the MLP expects.
                     self._mlp_E(x[:, 9:10])), 1)
    net = self._mlp(net)
    net = net.view(-1, 64, 11, 21)
    # DenseToSparse is a module parameterised by the number of spatial
    # dimensions: build it for 2-D, then apply it to the tensor. The previous
    # ``scn.DenseToSparse(net)`` passed the tensor as the dimension and never
    # converted anything.
    net = scn.DenseToSparse(2)(net)
    # Assumes ``self._upconvs`` ends with a dense output so ``.view`` applies.
    return self._upconvs(net).view(-1, 88 * 168)
def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs):
    """Sparse transposed basic residual block.

    Layers:
    - ``conv1``: 3x3 sparse conv, ``inplanes`` -> ``inplanes``, followed by
      ``bn1`` (BatchNormReLU).
    - ``conv2``: maps to ``planes``.
      - With an ``upsample`` shortcut and ``stride != 1``: a dense 2x2
        ConvTranspose2d wrapped in sparse<->dense conversions.
      - Otherwise: a strided 3x3 sparse conv.
      It is followed by ``bn2`` (BatchNormalization).

    ``upsample`` (may be None) and ``stride`` are stored. ``self.add`` (an
    AddTable) and ``self.relu`` are presumably used by ``forward`` to merge
    the branches. Extra ``kwargs`` are accepted and ignored.
    """
    super(TransBasicBlockSparse, self).__init__()
    self.conv1 = conv3x3_sparse(inplanes, inplanes)
    self.bn1 = scn.BatchNormReLU(inplanes)
    self.relu = scn.ReLU()

    use_transposed_conv = upsample is not None and stride != 1
    if not use_transposed_conv:
        self.conv2 = conv3x3_sparse(inplanes, planes, stride)
    else:
        self.conv2 = scn.Sequential(
            scn.SparseToDense(2, inplanes),
            nn.ConvTranspose2d(inplanes, planes, kernel_size=2,
                               stride=stride, padding=0,
                               output_padding=0, bias=False),
            scn.DenseToSparse(2),
        )

    self.bn2 = scn.BatchNormalization(planes)
    self.add = scn.AddTable()
    self.upsample = upsample
    self.stride = stride
def __init__(self, transblock, layers, num_classes=150):
    """Build the sparse transposed-ResNet decoder.

    Submodules are registered in this order:
    - ``deconv1``-``deconv4``: four stride-2 stages from ``transblock`` with
      ``layers[i]`` blocks each;
    - ``skip0``-``skip4``: skip projections;
    - ``densify0``-``densify3``: SparseToDense converters;
    - ``final_conv``: a final transposed stage;
    - ``final_deconv``: a dense 2x2 ConvTranspose2d head;
    - ``out6_conv`` and ``out5_conv``-``out2_conv``: per-scale heads
      producing ``num_classes`` channels;
    - ``sparse_to_dense``.

    ``_make_transpose`` updates ``self.inplanes`` after each stage.
    """
    expansion = transblock.expansion
    super(ResNetTransposeSparse, self).__init__()
    # Channel count entering the first decoder stage.
    self.inplanes = 512

    self.dense_to_sparse = scn.DenseToSparse(2)
    self.add = AddSparseDense()

    # Decoder stages (each call advances self.inplanes).
    self.deconv1 = self._make_transpose(transblock, 256 * expansion, layers[0], stride=2)
    self.deconv2 = self._make_transpose(transblock, 128 * expansion, layers[1], stride=2)
    self.deconv3 = self._make_transpose(transblock, 64 * expansion, layers[2], stride=2)
    self.deconv4 = self._make_transpose(transblock, 64 * expansion, layers[3], stride=2)

    # skip0..skip4 as (input channels, output planes before expansion).
    # NOTE(review): input widths presumably match the encoder stages — confirm.
    skip_specs = ((128, 64), (256, 64), (512, 128), (1024, 256), (2048, 512))
    for idx, (in_channels, out_planes) in enumerate(skip_specs):
        setattr(self, "skip{}".format(idx),
                self._make_skip_layer(in_channels, out_planes * expansion))

    for idx, out_planes in enumerate((64, 64, 128, 256)):
        setattr(self, "densify{}".format(idx),
                scn.SparseToDense(2, out_planes * expansion))

    self.inplanes = 64
    self.final_conv = self._make_transpose(transblock, 64 * expansion, 3)
    self.final_deconv = scn.Sequential(
        scn.SparseToDense(2, self.inplanes * expansion),
        nn.ConvTranspose2d(self.inplanes * expansion, num_classes,
                           kernel_size=2, stride=2, padding=0, bias=True),
    )

    self.out6_conv = nn.Conv2d(2048, num_classes, kernel_size=1, stride=1, bias=True)
    head_specs = (("out5_conv", 256), ("out4_conv", 128),
                  ("out3_conv", 64), ("out2_conv", 64))
    for attr_name, head_planes in head_specs:
        setattr(self, attr_name,
                scn.NetworkInNetwork(head_planes * expansion, num_classes, True))
    self.sparse_to_dense = scn.SparseToDense(2, num_classes)
def infer(model, data, input_shape=(256,256,256), use_gpu=True, course_grained=False):
    """Run ``model`` on one batch and, if ground truth is present, score it.

    Parameters
    ----------
    model : callable applied to the prepared (sparse) input batch.
    data : dict whose ``"data"`` entry is one of:
        - an ``scn.InputBatch``;
        - a list/tuple of per-sample feature arrays, with parallel
          ``"indices"`` (and optionally ``"truth"``) lists;
        - a dense ``torch.FloatTensor``.
        Optional keys: ``"truth"``, ``"weight"``, ``"sample_weights"``.
    input_shape : spatial size used when building ``scn.InputBatch`` objects.
    use_gpu : move inputs, labels and weights to CUDA.
    course_grained : unused; kept for interface compatibility.

    Returns
    -------
    ``outputs`` when no ground truth is supplied. Otherwise
    ``(outputs, mlog)``, where ``mlog`` is a MeterLogger holding the
    weighted cross-entropy loss and accuracy/mAP meters for this batch.

    Raises
    ------
    RuntimeError
        Unsupported batch container or element type.
    NotImplementedError
        Dense input together with ground truth. Scoring uses sparse label
        features, which dense labels do not have.
    """
    inputSpatialSize = torch.LongTensor(input_shape)
    mlog = MeterLogger(nclass=2, title="Sparse 3D UNet Inference")
    phase = "test"
    epoch = 0
    has_truth = data.get("truth") is not None

    def _to_weight(key):
        # Optional numpy weight array -> float tensor, on GPU only if requested.
        arr = data.get(key, None)
        if arr is None:
            return None
        tensor = torch.from_numpy(arr).float()
        return tensor.cuda() if use_gpu else tensor

    batch_weight = _to_weight("weight")
    sample_weights = _to_weight("sample_weights")
    use_size_average = False
    weight = sample_weights if use_size_average else batch_weight

    batch = data["data"]
    if batch.__class__.__name__ == "InputBatch":
        sparse_input = True
        inputs = batch
        labels = data["truth"] if has_truth else None
        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()
        inputs = inputs.to_variable(requires_grad=False)
        if labels is not None:
            labels = labels.to_variable()
    elif isinstance(batch, (list, tuple)):
        sparse_input = True
        inputs = scn.InputBatch(3, inputSpatialSize)
        labels = scn.InputBatch(3, inputSpatialSize) if has_truth else None

        first = batch[0] if len(batch) else None
        if isinstance(first, np.ndarray):
            long_tensor = lambda arr: torch.from_numpy(arr).long()
            float_tensor = lambda arr: torch.from_numpy(arr).float()
        elif isinstance(first, (list, tuple)):
            long_tensor = torch.LongTensor
            float_tensor = torch.FloatTensor
        else:
            raise RuntimeError("invalid datatype")

        # Without ground truth there is no "truth" list to zip against.
        truths = data["truth"] if has_truth else [None] * len(batch)
        for indices, features, truth in zip(data["indices"], batch, truths):
            inputs.addSample()
            if labels is not None:
                labels.addSample()
            try:
                indices = long_tensor(indices)
                features = float_tensor(features)
                if labels is not None:
                    truth = float_tensor(truth)
            except RuntimeError as e:
                print(e)
                continue
            inputs.setLocations(indices, features, 0)  # Use 1 to remove duplicate coords?
            if labels is not None:
                labels.setLocations(indices, truth, 0)

        inputs.precomputeMetadata(1)
        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()
        inputs = inputs.to_variable(requires_grad=False)
        if labels is not None:
            labels = labels.to_variable()
    elif isinstance(batch, torch.FloatTensor):
        print("Input is Dense")
        sparse_input = False
        # Move the dense tensor first, then sparsify. The old code called
        # .cuda() on ``inputs`` before it was assigned.
        dense = batch.cuda() if use_gpu else batch
        inputs = scn.DenseToSparse(3)(Variable(dense))
        labels = None
        if has_truth:
            truth = data["truth"]
            labels = Variable(truth.cuda() if use_gpu else truth)
    else:
        raise RuntimeError("Invalid data from dataset")

    try:
        outputs = model(inputs)
    except AssertionError as e:
        print(e)
        raise

    if labels is None:
        return outputs
    if not sparse_input:
        raise NotImplementedError(
            "Scoring dense input against dense truth is not supported")

    # One-hot label features -> class indices.
    targets = torch.max(labels.features, 1)[1]
    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
    loss = loss_fn(outputs, targets)
    mlog.update_loss(loss, meter='loss')
    mlog.update_meter(outputs, targets, meters={'accuracy', 'map'})
    add_to_logger(mlog, "Train" if phase == "train" else "Test", epoch,
                  outputs, labels.features, batch_weight)
    return outputs, mlog
def train(ibis_data, input_shape=(264, 264, 264), model_prefix=None,
          check_point=True, save_final=True, only_aa=False, only_atom=False,
          non_geom_features=False, use_deepsite_features=False,
          expand_atom=False, num_workers=None, num_epochs=30, batch_size=20,
          shuffle=True, use_gpu=True, initial_learning_rate=0.0001,
          learning_rate_drop=0.5, learning_rate_epochs=10, lr_decay=4e-2,
          data_split=0.8, course_grained=False, no_batch_norm=False,
          use_resnet_unet=False, unclustered=False, undersample=False,
          oversample=False, nFeatures=None, allow_feature_combos=False,
          bs_feature=None, bs_feature2=None, bs_features=None, stripes=False,
          data_parallel=False, dropout_depth=False, dropout_width=False,
          dropout_p=0.5, wide_model=False, cellular_organisms=False,
          autoencoder=False, checkpoint_callback=None):
    """Train a sparse 3-D segmentation network (UNet3D or ResNetUNet).

    ``ibis_data`` is either ``"spheres"`` (synthetic SphereDataset) or a path
    to an IBIS dataset file. Each epoch runs a "train" and a "val" phase.
    Training uses Nesterov SGD with an exponential LambdaLR schedule.
    Sparse batches use a (batch-weighted) cross-entropy loss; dense batches
    use DiceLoss.

    Checkpointing (when ``check_point`` is true):
    - After every phase, the current epoch number and the model
      ``state_dict`` are written to ``<model_prefix>_checkpoint_epoch.pth``
      and ``<model_prefix>_checkpoint_model.pth``.
    - An existing pair is loaded at start-up, and training resumes with the
      following epoch.
    - When ``save_final`` is true, the final weights go to
      ``<model_prefix>.pth``.
    - ``checkpoint_callback(epoch, statsfile, graphs, epoch_file,
      model_file)`` is called after each phase and once at the end.

    ``expand_atom``, ``learning_rate_drop``, ``learning_rate_epochs`` and
    ``unclustered`` are not used; they are kept for interface compatibility.

    Returns the trained model. Raises RuntimeError for unknown ``ibis_data``
    or an unsupported batch format.
    """
    if model_prefix is None:
        model_prefix = "./molmimic_model_{}".format(
            time.strftime('%Y-%m-%d_%H:%M:%S'))

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        nFeatures = nFeatures or 3
        datasets = SphereDataset.get_training_and_validation(
            input_shape, cnt=1, n_samples=1000, nFeatures=nFeatures,
            allow_feature_combos=allow_feature_combos, bs_feature=bs_feature,
            bs_feature2=bs_feature2, bs_features=bs_features, stripes=stripes,
            data_split=0.99)
        validation_batch_size = 1
        nClasses = len(bs_features) + 1 if bs_features is not None else 2
    elif os.path.isfile(ibis_data):
        print("{} {}".format(allow_feature_combos, nFeatures))
        random_features = None
        if nFeatures is not None:
            random_features = (nFeatures, allow_feature_combos or False,
                               bs_feature, bs_feature2)
        elif allow_feature_combos:
            print("ignoring --allow-feature-combos")
        datasets = IBISDataset.get_training_and_validation(
            ibis_data, input_shape=input_shape, only_aa=only_aa,
            only_atom=only_atom, non_geom_features=non_geom_features,
            use_deepsite_features=use_deepsite_features,
            data_split=data_split, course_grained=course_grained,
            oversample=oversample, undersample=undersample,
            cellular_organisms=cellular_organisms,
            random_features=random_features)
        nFeatures = datasets["train"].get_number_of_features()
        nClasses = nFeatures if autoencoder else 2
        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1
    # Make the count odd, then halve it. Floor division keeps it an int on
    # Python 3. The old code overwrote the result with a hard-coded 6, which
    # silently ignored the ``num_workers`` argument.
    if num_workers % 2 == 0:
        num_workers -= 1
    num_workers = max(0, num_workers // 2)

    dataloaders = {
        name: ds.get_data_loader(
            batch_size if ds.train else validation_batch_size,
            shuffle, num_workers)
        for name, ds in datasets.items()}

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'

    if use_resnet_unet:
        model = ResNetUNet(nFeatures, nClasses, dropout_depth=dropout_depth,
                           dropout_width=dropout_width, dropout_p=dropout_p,
                           wide_model=wide_model)
    else:
        model = UNet3D(nFeatures, nClasses, batchnorm=not no_batch_norm)
    if data_parallel:
        model = torch.nn.DataParallel(model)
    model.type(dtype)

    optimizer = SGD(model.parameters(), lr=initial_learning_rate,
                    momentum=0.999, weight_decay=1e-4, nesterov=True)
    scheduler = LambdaLR(optimizer,
                         lambda epoch: math.exp((1 - epoch) * lr_decay))

    check_point_model_file = "{}_checkpoint_model.pth".format(model_prefix)
    check_point_epoch_file = "{}_checkpoint_epoch.pth".format(model_prefix)
    start_epoch = 0
    if (check_point and os.path.isfile(check_point_model_file)
            and os.path.isfile(check_point_epoch_file)):
        # The epoch file holds the last checkpointed epoch, so resume with the
        # next one. The old loop re-ran the saved epoch despite the message.
        start_epoch = torch.load(check_point_epoch_file) + 1
        print("Restarting at epoch {} from {}".format(start_epoch, check_point_model_file))
        model.load_state_dict(torch.load(check_point_model_file))
        # Fast-forward the LR schedule. It is stepped once per training epoch
        # below, so a fresh scheduler would otherwise restart at the initial LR.
        for _ in range(start_epoch):
            scheduler.step()

    inputSpatialSize = torch.LongTensor(input_shape)
    draw_graph = True
    mlog = MeterLogger(nclass=nClasses, title="Sparse 3D UNet", server="cn4216")

    def _as_weight(arr):
        # Optional numpy weight array -> float tensor, on GPU only if requested.
        if arr is None:
            return None
        tensor = torch.from_numpy(arr).float()
        return tensor.cuda() if use_gpu else tensor

    # Start clean. The previous loop over gc.get_objects() only deleted its own
    # loop variable and freed nothing.
    gc.collect()

    for epoch in range(start_epoch, num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)
        mlog.timer.reset()

        for phase in ['train', 'val']:
            phase_name = "Train" if phase == "train" else "Test"
            datasets[phase].epoch = epoch
            num_batches = int(np.ceil(len(datasets[phase]) / float(
                batch_size if phase == "train" else validation_batch_size)))

            if phase == 'train':
                scheduler.step()
                model.train(True)   # training mode
            else:
                model.train(False)  # evaluation mode

            bar = tqdm(enumerate(dataloaders[phase]), total=num_batches,
                       unit="batch", desc="Loading data", leave=True)
            for data_iter_num, data in bar:
                datasets[phase].batch = data_iter_num
                batch_weight = _as_weight(data.get("weight", None))
                sample_weights = _as_weight(data.get("sample_weights", None))

                batch = data["data"]
                if batch.__class__.__name__ == "InputBatch":
                    sparse_input = True
                    inputs, labels = batch, data["truth"]
                    if use_gpu:
                        inputs, labels = inputs.cuda(), labels.cuda()
                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()
                elif isinstance(batch, (list, tuple)):
                    sparse_input = True
                    inputs = scn.InputBatch(3, inputSpatialSize)
                    labels = scn.InputBatch(3, inputSpatialSize)

                    first = batch[0] if len(batch) else None
                    if isinstance(first, np.ndarray):
                        long_tensor = lambda arr: torch.from_numpy(arr).long()
                        float_tensor = lambda arr: torch.from_numpy(arr).float()
                    elif isinstance(first, (list, tuple)):
                        long_tensor = torch.LongTensor
                        float_tensor = torch.FloatTensor
                    else:
                        raise RuntimeError("invalid datatype")

                    for indices, features, truth, pdb_id in zip(
                            data["indices"], batch, data["truth"], data["id"]):
                        inputs.addSample()
                        labels.addSample()
                        try:
                            indices = long_tensor(indices)
                            features = float_tensor(features)
                            truth = float_tensor(truth)
                        except RuntimeError as e:
                            print(e)
                            continue
                        try:
                            inputs.setLocations(indices, features, 0)  # Use 1 to remove duplicate coords?
                            labels.setLocations(indices, truth, 0)
                        except AssertionError:
                            print("Error with PDB: {}".format(pdb_id))
                            with open("bad_pdbs.txt", "a") as f:
                                f.write("{}\n".format(pdb_id))

                    inputs.precomputeMetadata(1)
                    if use_gpu:
                        inputs, labels = inputs.cuda(), labels.cuda()
                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()
                elif isinstance(batch, torch.FloatTensor):
                    print("Input is Dense")
                    sparse_input = False
                    # Move tensors before wrapping. The old code called .cuda()
                    # on ``inputs``/``labels`` before they were assigned.
                    dense_truth = data["truth"]
                    if use_gpu:
                        batch, dense_truth = batch.cuda(), dense_truth.cuda()
                    inputs = scn.DenseToSparse(3)(Variable(batch, requires_grad=True))
                    labels = Variable(dense_truth)
                else:
                    raise RuntimeError("Invalid data from dataset")

                optimizer.zero_grad()

                try:
                    outputs = model(inputs)
                except AssertionError:
                    print("{} {}".format(nFeatures, inputs))
                    raise

                if sparse_input:
                    use_size_average = False
                    weight = sample_weights if use_size_average else batch_weight
                    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
                    targets = torch.max(labels.features, 1)[1]
                    loss = loss_fn(outputs, targets)
                    if draw_graph:
                        var_dot = dot.make_dot(loss)
                        var_dot.render('SparseUnet3dCNN_graph.pdf')
                        draw_graph = False
                        del var_dot
                    mlog.update_loss(loss, meter='loss')
                    mlog.update_meter(outputs, targets, meters={'accuracy', 'map'})
                    add_to_logger(mlog, phase_name, epoch, outputs,
                                  labels.features, batch_weight,
                                  n_classes=nClasses)
                else:
                    outputs = scn.SparseToDense(3, 1)(outputs)
                    criterion = DiceLoss(size_average=False)
                    loss = criterion(outputs.cpu(), labels.cpu())
                    # Dense labels have no ``.features``. Only the loss meter
                    # can be updated; the old ``stats.update`` referenced an
                    # undefined name in this function.
                    mlog.update_loss(loss, meter='loss')

                # Backward + optimize only in the training phase.
                if phase == 'train':
                    before = list(model.parameters())[0].clone().data
                    loss.backward()
                    optimizer.step()
                    after = list(model.parameters())[0].clone().data
                    if torch.equal(before, after):
                        print("NOT UPDATED")
                    del before, after

                bar.set_description("{}: [{}][{}/{}]".format(
                    phase, epoch, data_iter_num + 1, num_batches))
                # Show only meters that exist (dense batches populate "loss" only).
                postfix = {}
                for key, meter_name in (("loss", "loss"),
                                        ("dice_class1", "dice_class1"),
                                        ("weight_dice", "weighted_dice_wavg")):
                    if meter_name in mlog.meter:
                        meter = mlog.meter[meter_name]
                        postfix[key] = "{:.4f} ({:.4f})".format(meter.val, meter.mean)
                bar.set_postfix(refresh=False, **postfix)
                bar.refresh()

                # Release this batch before the next one is loaded.
                del inputs, labels, outputs, loss

            statsfile, graphs = graph_logger(mlog, phase_name, epoch)
            mlog.reset_meter(epoch, phase_name)

            if check_point:
                torch.save(epoch, check_point_epoch_file)
                torch.save(model.state_dict(), check_point_model_file)
                if callable(checkpoint_callback):
                    checkpoint_callback(epoch, statsfile, graphs,
                                        check_point_epoch_file,
                                        check_point_model_file)
            elif callable(checkpoint_callback):
                checkpoint_callback(epoch, statsfile, graphs, None, None)

    # The last phase of every epoch is "val" (logged as "Test").
    last_epoch = num_epochs - 1
    statsfile, graphs = graph_logger(mlog, "Test", last_epoch, final=True)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    if save_final:
        torch.save(model.state_dict(), "{}.pth".format(model_prefix))

    if callable(checkpoint_callback):
        checkpoint_callback(last_epoch, statsfile, graphs,
                            check_point_epoch_file, check_point_model_file)

    return model
def test(model_file, ibis_data, input_shape=(512, 512, 512), only_aa=False,
         only_atom=False, expand_atom=False, num_workers=None, batch_size=20,
         shuffle=True, use_gpu=True, data_split=0.8, test_full=False,
         no_batch_norm=False):
    """Evaluate a saved single-output UNet3D on the validation split.

    ``ibis_data`` is ``"spheres"`` or a path to an IBIS dataset file. The
    weights in ``model_file`` are loaded into ``UNet3D(nFeatures, 1)``. Every
    validation batch is scored with DiceLoss and recorded in a ModelStats
    object, which prints per-batch statistics and plots a final summary.

    Raises IOError if ``model_file`` does not exist, and RuntimeError for
    unknown ``ibis_data`` or batch formats.
    """
    # Fail fast, before any data is loaded.
    if not os.path.isfile(model_file):
        raise IOError("Model cannot be opened")

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1
    print("Using {} workers".format(num_workers))

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        datasets = SphereDataset.get_training_and_validation(
            input_shape, cnt=1, n_samples=1000, data_split=0.99)
        nFeatures = 1
        validation_batch_size = 1
        # NOTE(review): batches are built at 96^3 although the dataset was
        # created with ``input_shape`` — confirm this is intended.
        input_shape = (96, 96, 96)
    elif os.path.isfile(ibis_data):
        datasets = IBISDataset.get_training_and_validation(
            ibis_data, input_shape=input_shape, only_aa=only_aa,
            only_atom=only_atom, expand_atom=expand_atom,
            data_split=data_split, train_full=test_full,
            validate_full=test_full)
        if only_atom:
            nFeatures = 5
        elif only_aa:
            nFeatures = 21
        else:
            nFeatures = 59
        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    val_set = datasets["val"]
    dataloader = val_set.get_data_loader(
        batch_size if val_set.train else validation_batch_size,
        shuffle, num_workers)

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
    model = UNet3D(nFeatures, 1, batchnorm=not no_batch_norm)
    model.type(dtype)
    model.load_state_dict(torch.load(model_file))
    model.train(False)  # evaluation mode

    criterion = DiceLoss()
    inputSpatialSize = torch.LongTensor(input_shape)
    stats = ModelStats()

    print("Starting Test...")

    for data_iter_num, data in enumerate(dataloader):
        batch = data["data"]
        if batch.__class__.__name__ == "InputBatch":
            sparse_input = True
            inputs, labels = batch, data["truth"]
            if use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
            # Evaluation only: no gradient w.r.t. the inputs is needed.
            inputs = inputs.to_variable(requires_grad=False)
            labels = labels.to_variable()
        elif isinstance(batch, (list, tuple)):
            sparse_input = True
            inputs = scn.InputBatch(3, inputSpatialSize)
            labels = scn.InputBatch(3, inputSpatialSize)
            for indices, features, truth in zip(data["indices"], batch, data["truth"]):
                inputs.addSample()
                labels.addSample()
                indices = torch.LongTensor(indices)
                features = torch.FloatTensor(features)
                truth = torch.FloatTensor(truth)
                try:
                    inputs.setLocations(indices, features, 0)  # Use 1 to remove duplicate coords?
                    labels.setLocations(indices, truth, 0)
                except AssertionError:
                    # PDB didn't fit in grid?
                    continue
            inputs.precomputeMetadata(1)
            if use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs = inputs.to_variable(requires_grad=False)
            labels = labels.to_variable()
        elif isinstance(batch, torch.FloatTensor):
            print("Input is Dense")
            sparse_input = False
            # Move tensors before wrapping. The old code called .cuda() on
            # ``inputs``/``labels`` before they were assigned.
            dense_truth = data["truth"]
            if use_gpu:
                batch, dense_truth = batch.cuda(), dense_truth.cuda()
            inputs = scn.DenseToSparse(3)(Variable(batch))
            labels = Variable(dense_truth)
        else:
            raise RuntimeError("Invalid data from dataset")

        outputs = model(inputs)

        if sparse_input:
            loss = criterion(outputs.features, labels.features)
            predicted, target = outputs.features, labels.features
        else:
            outputs = scn.SparseToDense(3, 1)(outputs)
            loss = criterion(outputs.cpu(), labels.cpu())
            predicted, target = outputs, labels

        # NOTE(review): ``loss.data[0]`` is the pre-0.4 PyTorch scalar idiom
        # used throughout this file; use ``loss.item()`` on newer versions.
        loss_value = loss.data[0]
        if math.isnan(loss_value):
            # Previously this dropped into pdb, which hangs non-interactive
            # runs. Skip the batch so NaN does not poison the statistics.
            print("Loss is Nan? Skipping batch {}".format(data_iter_num))
            continue

        # NOTE(review): the arguments to ``stats.update`` were lost from the
        # original source and are reconstructed here — confirm against
        # ModelStats.update.
        stats.update(predicted.data.cpu(), target.data.cpu(), loss_value)

        print("Batch {}: corrects:{:.2f}% nll:{:.2f}% dice:{:.4f}% time:{:.1f}s".format(
            data_iter_num, stats.correctpct(), stats.nllpct(),
            loss_value * -100, time.time() - since))

        save_batch_prediction(outputs)

    # NOTE(review): the original source also had a call taking ("val", 0)
    # here whose name was lost and could not be reconstructed; restore it if
    # needed.
    stats.plot_final()

    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))