def calculate_features(pdb, chain, id):
    """Voxelise a structure under many rotations and log the ones that do not fit.

    For each of the 1000 rotations yielded by ``Structure.rotate`` the atoms
    are mapped to voxel space and loaded into a fresh one-sample
    ``scn.InputBatch`` of size 264^3. When ``setLocations`` raises
    ``AssertionError`` a tab-separated row (pdb, chain, id, the three rotation
    parameters, the diagonal extent and the min/max voxel coordinates) is
    appended to ``"<pdb>_<chain>_<id>.txt"``.

    Args:
        pdb: Structure identifier passed to ``Structure``.
        chain: Chain identifier passed to ``Structure``.
        id: Value forwarded as ``Structure(..., id=id)`` and used in the
            report file name.
    """
    from molmimic.biopdbtools import Structure

    grid_size = torch.LongTensor((264, 264, 264))
    structure = Structure(pdb, chain, id=id)

    for rotation in structure.rotate(1000):
        indices, data = structure.map_atoms_to_voxel_space()

        inputs = scn.InputBatch(3, grid_size)
        inputs.addSample()

        try:
            locations = torch.from_numpy(indices).long()
            features = torch.from_numpy(data).float()
            inputs.setLocations(locations, features, 0)
        except AssertionError:
            # Everything after the first entry of ``rotation`` is unpacked as
            # the three rotation parameters.
            theta, phi, z = rotation[1:]
            lower = np.min(indices, axis=0)
            upper = np.max(indices, axis=0)
            extent = int(np.ceil(np.linalg.norm(upper - lower)))
            row = (pdb, chain, id, theta, phi, z, extent,
                   lower[0], lower[1], lower[2],
                   upper[0], upper[1], upper[2])
            with open("{}_{}_{}.txt".format(pdb, chain, id), "a") as f:
                print("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".
                      format(*row), file=f)

        # Drop the per-rotation buffers before the next iteration allocates.
        del inputs
        del indices
        del data
def merge(tbl):
    """Draw every character of a batch into one sparse 2-D ``scn.InputBatch``.

    Args:
        tbl: Mapping with ``'input'`` (an iterable of characters, each an
            iterable of stroke tensors) and ``'target'`` (labels).

    Returns:
        dict with ``'input'`` (the populated batch, metadata precomputed with
        ``precomputeSize``) and ``'target'`` (``torch.LongTensor`` of
        ``tbl['target']``).
    """
    inp = scn.InputBatch(2, spatial_size)
    center = spatial_size.float().view(1, 2) / 2
    # p and v are only referenced by the commented reference implementation
    # further down.
    p = torch.LongTensor(2)
    v = torch.FloatTensor([1, 0, 0])
    span = Scale - 0.01

    for char in tbl['input']:
        inp.add_sample()
        for raw_stroke in char:
            # Rescale from the 0..255 range to a window of width ``span``
            # centred on zero, then shift to the centre of the grid.
            stroke = raw_stroke.float() * span / 255 - 0.5 * span
            stroke = stroke + center.expand_as(stroke)
            ###############################################################
            # To avoid GIL problems use a helper function:
            draw_curve = scn.dim_fn(2, 'drawCurve')
            draw_curve(inp.metadata.ffi, inp.features, stroke)
            ###############################################################
            # Above is equivalent to :
            # x1,x2,y1,y2,l=0,stroke[0][0],0,stroke[0][1],0
            # for i in range(1,stroke.size(0)):
            #     x1=x2
            #     y1=y2
            #     x2=stroke[i][0]
            #     y2=stroke[i][1]
            #     l=1e-10+((x2-x1)**2+(y2-y1)**2)**0.5
            #     v[1]=(x2-x1)/l
            #     v[2]=(y2-y1)/l
            #     l=max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
            #     for j in numpy.arange(0,1,1/l):
            #         p[0]=math.floor(x1*j+x2*(1-j))
            #         p[1]=math.floor(y1*j+y2*(1-j))
            #         inp.set_location(p,v,False)
            ###############################################################

    inp.precomputeMetadata(precomputeSize)
    target = torch.LongTensor(tbl['target'])
    return {'input': inp, 'target': target}
def SampleLocation(dimension, x, sample, sptialSize):
    """Build a new ``scn.InputBatch`` holding only the selected locations of ``x``.

    Args:
        dimension: Dimensionality passed to ``scn.InputBatch``.
        x: Sparse tensor providing ``getSpatialLocations`` / ``getSpatialSize``.
        sample: Mask tensor; rows at its non-zero positions are kept.
        sptialSize: Spatial size forwarded to ``x.getSpatialLocations``.

    Returns:
        A batch with the kept locations and an uninitialised single-column
        feature tensor of matching length.
    """
    all_locations = x.getSpatialLocations(sptialSize)
    keep = sample.nonzero().view(-1)
    kept_locations = torch.index_select(all_locations, 0, keep)

    x_sample = scn.InputBatch(dimension, x.getSpatialSize())
    # torch.Tensor(n, 1) only allocates; the values are placeholders.
    placeholder = torch.Tensor(kept_locations.size(0), 1)
    x_sample.setLocations(kept_locations, placeholder)
    return x_sample
def Sample(dimension, x, sample):
    """Sub-sample a sparse tensor, keeping locations and their features together.

    Args:
        dimension: Dimensionality passed to ``scn.InputBatch``.
        x: Sparse tensor providing ``getSpatialLocations``,
            ``getSpatialSize`` and ``features``.
        sample: Mask tensor; rows at its non-zero positions are kept.

    Returns:
        A new batch whose ``features`` attribute is the ``index_select`` result
        taken from ``x.features``.
    """
    keep = sample.nonzero().view(-1)
    kept_locations = torch.index_select(x.getSpatialLocations(), 0, keep)

    # The feature gather runs with a CUDA index wrapped in a Variable.
    keep_on_gpu = Variable(keep.cuda(), requires_grad=False)
    kept_features = torch.index_select(x.features, 0, keep_on_gpu)

    x_sample = scn.InputBatch(dimension, x.getSpatialSize())
    # Register the locations with a placeholder of the right shape, then swap
    # in the gathered features.
    placeholder = torch.Tensor(kept_features.data.shape)
    x_sample.setLocations(kept_locations, placeholder)
    x_sample.features = kept_features
    return x_sample
def merge(tbl):
    """Draw a batch of characters into a sparse 2-D batch with random augmentation.

    Each character gets its own random 2x2 transform (horizontal shear,
    vertical shear or rotation) and a random translation of up to 8 units per
    axis. Every stroke is then transformed and rasterised with the
    ``drawCurve`` helper.

    Args:
        tbl: Mapping with ``'idx'`` (seed for the transform choice),
            ``'input'`` (iterable of characters, each an iterable of stroke
            tensors) and ``'target'`` (labels).

    Returns:
        dict with ``'input'`` (the populated ``scn.InputBatch``, metadata
        precomputed with ``precomputeSize``) and ``'target'``
        (``torch.LongTensor(tbl['target']) - 1``).
    """
    inp = scn.InputBatch(2, spatial_size)
    center = spatial_size.float().view(1, 2) / 2
    # p and v are only referenced by the commented reference implementation
    # further down.
    p = torch.LongTensor(2)
    v = torch.FloatTensor([1, 0, 0])
    np_random = np.random.RandomState(tbl['idx'])
    for char in tbl['input']:
        inp.add_sample()
        m = torch.eye(2)
        # RandomState.randint excludes the upper bound, so 4 is needed to draw
        # from {1, 2, 3} and keep all three transforms reachable.
        r = np_random.randint(1, 4)
        alpha = random.uniform(-0.2, 0.2)
        # The selector is ``r``. The previous code compared ``alpha`` (a float
        # in [-0.2, 0.2]) with 1 and 2, so the shear branches never ran.
        if r == 1:
            m[0][1] = alpha
        elif r == 2:
            m[1][0] = alpha
        else:
            m = torch.mm(m, torch.FloatTensor(
                [[math.cos(alpha), math.sin(alpha)],
                 [-math.sin(alpha), math.cos(alpha)]]))
        c = center + torch.FloatTensor(1, 2).uniform_(-8, 8)
        for stroke in char:
            stroke = stroke.float() / 255 - 0.5
            stroke = c.expand_as(stroke) + \
                torch.mm(stroke, m * (Scale - 0.01))
            ###############################################################
            # To avoid GIL problems use a helper function:
            scn.dim_fn(
                2,
                'drawCurve')(
                inp.metadata.ffi,
                inp.features,
                stroke)
            ###############################################################
            # Above is equivalent to :
            # x1,x2,y1,y2,l=0,stroke[0][0],0,stroke[0][1],0
            # for i in range(1,stroke.size(0)):
            #     x1=x2
            #     y1=y2
            #     x2=stroke[i][0]
            #     y2=stroke[i][1]
            #     l=1e-10+((x2-x1)**2+(y2-y1)**2)**0.5
            #     v[1]=(x2-x1)/l
            #     v[2]=(y2-y1)/l
            #     l=max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
            #     for j in np.arange(0,1,1/l):
            #         p[0]=math.floor(x1*j+x2*(1-j))
            #         p[1]=math.floor(y1*j+y2*(1-j))
            #         inp.set_location(p,v,False)
            ###############################################################
    inp.precomputeMetadata(precomputeSize)
    return {'input': inp, 'target': torch.LongTensor(tbl['target']) - 1}
def merge(tbl):
    """Draw every character of a batch into one sparse 2-D ``scn.InputBatch``.

    Args:
        tbl: Mapping with ``'input'`` (an iterable of characters, each an
            iterable of stroke tensors) and ``'target'`` (labels).

    Returns:
        dict with ``'input'`` (the populated batch, metadata precomputed with
        ``precomputeSize``) and ``'target'`` (``torch.LongTensor`` of
        ``tbl['target']``).
    """
    # The scratch tensors ``p`` and ``v`` that used to be allocated here were
    # never read, so they are gone.
    inp = scn.InputBatch(2, spatial_size)
    center = spatial_size.float().view(1, 2) / 2
    for char in tbl['input']:
        inp.add_sample()
        for stroke in char:
            # Rescale from the 0..255 range to a window of width
            # (Scale - 0.01) centred on zero, then shift to the grid centre.
            stroke = stroke.float() * (Scale - 0.01) / 255 - 0.5 * (Scale - 0.01)
            stroke += center.expand_as(stroke)
            scn.dim_fn(2, 'drawCurve')(inp.metadata.ffi, inp.features, stroke)
    inp.precomputeMetadata(precomputeSize)
    return {'input': inp, 'target': torch.LongTensor(tbl['target'])}
def spatialGroupConv(x, model, abc, group_num, group_x=None):
    """Run ``model`` on a regrouped copy of ``x`` and hand back ``x``'s metadata.

    Args:
        x: Sparse tensor providing ``getSpatialLocations``,
            ``getSpatialSize``, ``features`` and ``metadata``.
        model: Callable applied to the grouped batch.
        abc: Sequence whose first three entries are forwarded to
            ``partition``.
        group_num: Group count forwarded to ``partition``.
        group_x: Optional pre-built grouped batch. When ``None`` one is built
            from ``x``'s locations.

    Returns:
        The output of ``model`` with its ``metadata`` replaced by
        ``x.metadata``.
    """
    # Precomputed rules (from precomputeMetadata()) are used when available;
    # otherwise the rules are computed on demand, see getRuleBook() in
    # Metadata.h.
    # Identity test: ``== None`` could dispatch to an overloaded ``__eq__`` on
    # a tensor-like object.
    if group_x is None:
        locations = x.getSpatialLocations()
        # The last column is read as the sample index, so max + 1 is used as
        # the batch size.
        batch_size = torch.LongTensor(locations[:, -1]).max() + 1
        locations = partition(locations, abc[0], abc[1], abc[2], group_num,
                              batch_size)
        group_x = scn.InputBatch(3, x.getSpatialSize())
        # Placeholder features; the real ones are attached just below.
        group_x.setLocations(locations,
                             torch.Tensor(len(locations)).view(-1, 1))
    assert group_x.features.size(0) == x.features.size(0)
    group_x.features = x.features
    group_x = model(group_x)
    group_x.metadata = x.metadata  # gather operation
    return group_x
def abstract(dimension, input, prediction, kernel_size=1):
    """Keep only the part of ``input`` selected by the predicted structure.

    Args:
        dimension: Dimensionality passed to ``scn.InputBatch``.
        input: Sparse tensor providing ``extractStructure``,
            ``getSpatialLocations``, ``getSpatialSize`` and ``features``.
        prediction: Sparse tensor whose ``features.data`` holds per-location
            scores; the arg-max over dimension 1 is the predicted class.
        kernel_size: Forwarded to ``input.extractStructure``.

    Returns:
        A new batch with the selected locations and the matching rows of
        ``input.features``.
    """
    scores = prediction.features.data
    predicted = scores.max(1)[1]
    if predicted.sum() == 0:
        # Nothing was predicted: fall back to keeping everything.
        print('dangerous! no predicted structure')
        predicted[:] = 1

    mask = input.extractStructure(predicted.cpu(), kernel_size)
    keep = mask.nonzero().view(-1)

    output_locations = torch.index_select(input.getSpatialLocations(), 0, keep)
    keep_on_gpu = Variable(keep, requires_grad=False).cuda()
    output_features = torch.index_select(input.features, 0, keep_on_gpu)

    output = scn.InputBatch(dimension, input.getSpatialSize())
    placeholder = torch.Tensor(output_features.data.shape)
    output.setLocations(output_locations, placeholder)
    # Attach the gathered Variable itself to preserve autograd continuity.
    output.features = output_features
    return output
def __init__(self, model, selected_layer, selected_filter, use_gpu=True):
    """Prepare a random, fully populated 96^3 sparse volume to optimise.

    Args:
        model: Network to visualise; it is switched to eval mode.
        selected_layer: Stored on the instance as ``selected_layer``.
        selected_filter: Stored on the instance as ``selected_filter``.
        use_gpu: When true the created batch is moved to CUDA before it is
            turned into a variable.

    Side effects:
        Creates the ``generated`` directory if it does not exist.
    """
    self.model = model
    self.model.eval()
    self.selected_layer = selected_layer
    self.selected_filter = selected_filter
    self.conv_output = 0

    # Generate a 3D random volume; it fills the entire grid but is stored
    # through the sparse API.
    side = 96
    axis = np.arange(0, side)
    x, y, z = np.meshgrid(axis, axis, axis)
    # One (x, y, z) row per voxel, built in numpy. The previous
    # ``torch.LongTensor(zip(...))`` breaks on Python 3, where ``zip`` is an
    # iterator, and needed hundreds of thousands of Python tuples.
    coords = np.column_stack((x.ravel(), y.ravel(), z.ravel()))

    # Every voxel gets one of the 8 binary 3-vectors, chosen at random.
    corners = np.array(list(product([0, 1], repeat=3)))
    features = corners[np.random.choice(8, side * side * side)]

    self.inputSpatialSize = torch.LongTensor((side, side, side))
    self.created_image = scn.InputBatch(3, self.inputSpatialSize)
    self.created_image.addSample()
    indices = torch.from_numpy(coords).long()
    labels = torch.from_numpy(features).float()
    self.created_image.setLocations(indices, labels, 0)
    if use_gpu:
        self.created_image = self.created_image.cuda()
    self.created_image = self.created_image.to_variable(requires_grad=True)

    # Create the folder to export images if it does not exist. The temporary
    # arrays above are released when this method returns.
    if not os.path.exists('generated'):
        os.makedirs('generated')
def forward(self, input):
    """Apply the dense deconvolution to a sparse tensor.

    The input is first copied into a fresh ``sparseconvnet.InputBatch`` (same
    locations, same feature tensor). The output spatial size is
    ``(input.spatial_size - 1) * filter_stride + filter_size``.

    Args:
        input: Sparse tensor with ``features``, ``spatial_size`` and
            ``getSpatialLocations``.

    Returns:
        A ``SparseConvNetTensor`` carrying the rebuilt metadata, the new
        spatial size and the deconvolved features.
    """
    in_features = input.features
    assert in_features.ndimension() == 0 or in_features.size(1) == self.nIn

    rebuilt = sparseconvnet.InputBatch(self.dimension, input.spatial_size)
    placeholder = torch.Tensor(in_features.data.shape)
    rebuilt.setLocations(input.getSpatialLocations(), placeholder)
    rebuilt.features = in_features

    output = SparseConvNetTensor()
    output.metadata = rebuilt.metadata
    output.spatial_size = \
        (input.spatial_size - 1) * self.filter_stride + self.filter_size
    deconvolve = DenseDeconvolutionFunction()
    output.features = deconvolve.apply(
        rebuilt.features,
        self.weight,
        self.bias,
        rebuilt.metadata,
        rebuilt.spatial_size,
        output.spatial_size,
        self.dimension,
        self.filter_size,
        self.filter_stride,
    )
    return output
[['C', 8], ['C', 8], ['MP', 3, 2], ['C', 16], ['C', 16], ['MP', 3, 2], ['C', 24], ['C', 24], ['MP', 3, 2]]) ).add( scn.ValidConvolution(2, 24, 32, 3, False) ).add( scn.BatchNormReLU(32) ).add( scn.SparseToDense(2,32) ) if use_gpu: model.cuda() # output will be 10x10 inputSpatialSize = model.input_spatial_size(torch.LongTensor([10, 10])) input = scn.InputBatch(2, inputSpatialSize) msg = [ " X X XXX X X XX X X XX XXX X XXX ", " X X X X X X X X X X X X X X X X ", " XXXXX XX X X X X X X X X X XXX X X X ", " X X X X X X X X X X X X X X X X X X ", " X X XXX XXX XXX XX X X XX X X XXX XXX "] #Add a sample using setLocation input.addSample() for y, line in enumerate(msg): for x, c in enumerate(line): if c == 'X': location = torch.LongTensor([y, x]) featureVector = torch.FloatTensor([1])
def infer(model, data, input_shape=(256,256,256), use_gpu=True, course_grained=False):
    """Run ``model`` on one collated batch and score it when truth is present.

    Args:
        model: Callable network applied to the prepared inputs.
        data: Batch mapping. ``data["data"]`` may be an ``InputBatch``, a
            list/tuple of per-sample feature arrays (with ``data["indices"]``
            alongside), or a dense ``torch.FloatTensor``. Optional keys:
            ``"truth"``, ``"weight"``, ``"sample_weights"``.
        input_shape: Spatial size used when the sparse batch is built here.
        use_gpu: Move inputs, labels and weights to CUDA.
        course_grained: Unused; kept for interface compatibility.

    Returns:
        ``outputs`` when the batch carries no labels, otherwise
        ``(outputs, mlog)`` where ``mlog`` is the updated ``MeterLogger``.

    Raises:
        RuntimeError: If ``data["data"]`` has an unsupported type.
    """
    inputSpatialSize = torch.LongTensor(input_shape)
    labels = None
    mlog = MeterLogger(nclass=2, title="Sparse 3D UNet Inference")
    phase = "test"
    epoch = 0

    # Weights follow ``use_gpu`` like everything else, so CPU-only inference
    # does not require CUDA.
    batch_weight = data.get("weight", None)
    if batch_weight is not None:
        batch_weight = torch.from_numpy(batch_weight).float()
        if use_gpu:
            batch_weight = batch_weight.cuda()

    sample_weights = data.get("sample_weights", None)
    if sample_weights is not None:
        sample_weights = torch.from_numpy(sample_weights).float()
        if use_gpu:
            sample_weights = sample_weights.cuda()

    use_size_average = False
    weight = sample_weights if use_size_average else batch_weight

    has_truth = "truth" in data

    if data["data"].__class__.__name__ == "InputBatch":
        inputs = data["data"]
        labels = data["truth"] if has_truth else None
        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()
        inputs = inputs.to_variable(requires_grad=False)
        # Labels are optional: only convert them when they exist.
        if labels is not None:
            labels = labels.to_variable()
    elif isinstance(data["data"], (list, tuple)):
        inputs = scn.InputBatch(3, inputSpatialSize)
        labels = scn.InputBatch(3, inputSpatialSize) if has_truth else None

        if isinstance(data["data"][0], np.ndarray):
            long_tensor = lambda arr: torch.from_numpy(arr).long()
            float_tensor = lambda arr: torch.from_numpy(arr).float()
        elif isinstance(data["data"][0], (list, tuple)):
            long_tensor = lambda arr: torch.LongTensor(arr)
            float_tensor = lambda arr: torch.FloatTensor(arr)
        else:
            raise RuntimeError("invalid datatype")

        # Without truth, pair every sample with ``None`` instead of indexing a
        # key that is not there.
        truths = data["truth"] if has_truth else [None] * len(data["data"])
        for indices, features, truth in izip(data["indices"], data["data"], truths):
            inputs.addSample()
            if labels is not None:
                labels.addSample()

            try:
                indices = long_tensor(indices)
                features = float_tensor(features)
                if labels is not None:
                    truth = float_tensor(truth)
            except RuntimeError as e:
                # Best effort: report and leave this sample empty.
                print(e)
                continue

            inputs.setLocations(indices, features, 0)  # Use 1 to remove duplicate coords?
            if labels is not None:
                labels.setLocations(indices, truth, 0)

        inputs.precomputeMetadata(1)

        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()
        inputs = inputs.to_variable(requires_grad=True)
        if labels is not None:
            labels = labels.to_variable()
    elif isinstance(data["data"], torch.FloatTensor):
        # Input is dense
        print("Input is Dense")
        dense = data["data"]
        truth = data["truth"]
        # Move the raw tensors first. The previous code called ``.cuda()`` on
        # ``inputs`` and ``labels`` before either had been assigned.
        if use_gpu:
            dense = dense.cuda()
            truth = truth.cuda()
        inputs = scn.DenseToSparse(3)(Variable(dense, requires_grad=True))
        try:
            inputs = inputs.cuda().to_variable(requires_grad=True)
        except Exception:
            # Best-effort conversion kept from the original; now limited to
            # ordinary exceptions so Ctrl-C still works.
            pass
        labels = Variable(truth)
        # NOTE(review): the scoring code below reads ``labels.features``,
        # which a plain Variable may not provide - confirm the dense path.
    else:
        raise RuntimeError("Invalid data from dataset")

    # forward
    try:
        outputs = model(inputs)
    except AssertionError as e:
        print(e)
        raise

    if labels is None:
        return outputs

    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
    target = torch.max(labels.features, 1)[1]
    loss = loss_fn(outputs, target)

    mlog.update_loss(loss, meter='loss')
    mlog.update_meter(outputs, target, meters={'accuracy', 'map'})
    add_to_logger(mlog, "Train" if phase=="train" else "Test", epoch, outputs, labels.features, batch_weight)

    return outputs, mlog
def train(ibis_data, input_shape=(264, 264, 264), model_prefix=None, check_point=True, save_final=True, only_aa=False, only_atom=False, non_geom_features=False, use_deepsite_features=False, expand_atom=False, num_workers=None, num_epochs=30, batch_size=20, shuffle=True, use_gpu=True, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=10, lr_decay=4e-2, data_split=0.8, course_grained=False, no_batch_norm=False, use_resnet_unet=False, unclustered=False, undersample=False, oversample=False, nFeatures=None, allow_feature_combos=False, bs_feature=None, bs_feature2=None, bs_features=None, stripes=False, data_parallel=False, dropout_depth=False, dropout_width=False, dropout_p=0.5, wide_model=False, cellular_organisms=False, autoencoder=False, checkpoint_callback=None):
    """Train a sparse 3-D U-Net and return the trained model.

    Data source: ``ibis_data == "spheres"`` selects the synthetic
    ``SphereDataset``; an existing file path selects ``IBISDataset``; anything
    else raises ``RuntimeError``.

    Each epoch runs a ``'train'`` and a ``'val'`` phase. Per batch the data is
    turned into sparse inputs/labels, the model is run, a cross-entropy loss
    is computed for sparse input, meters are updated, and (train phase only)
    an SGD step is taken.

    Checkpoints are written to ``"<model_prefix>_checkpoint_model.pth"`` and
    ``"<model_prefix>_checkpoint_epoch.pth"``; the final weights go to
    ``"<model_prefix>.pth"``. When both checkpoint files already exist and
    ``check_point`` is true, training resumes from the stored epoch.

    ``checkpoint_callback``, when callable, is invoked as
    ``checkpoint_callback(epoch, statsfile, graphs, epoch_file, model_file)``.

    Parameters accepted but not read in this body: ``save_final``,
    ``expand_atom``, ``learning_rate_drop``, ``learning_rate_epochs``,
    ``unclustered``.

    NOTE(review): this body was recovered from whitespace-collapsed text.
    The nesting of the per-phase logging and checkpoint code should be
    confirmed against the original file.
    """
    if model_prefix is None:
        model_prefix = "./molmimic_model_{}".format(
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        nFeatures = nFeatures or 3
        datasets = SphereDataset.get_training_and_validation(
            input_shape,
            cnt=1,
            n_samples=1000,
            nFeatures=nFeatures,
            allow_feature_combos=allow_feature_combos,
            bs_feature=bs_feature,
            bs_feature2=bs_feature2,
            bs_features=bs_features,
            stripes=stripes,
            data_split=0.99)
        validation_batch_size = 1
        if bs_features is not None:
            nClasses = len(bs_features) + 1
        else:
            nClasses = 2
    elif os.path.isfile(ibis_data):
        dataset = IBISDataset
        print allow_feature_combos, nFeatures
        # random_features is only built when an explicit feature count was
        # requested.
        if allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, allow_feature_combos, bs_feature, bs_feature2)
        elif not allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, False, bs_feature, bs_feature2)
        elif allow_feature_combos and nFeatures is None:
            random_features = None
            print "ignoring --allow-feature-combos"
        else:
            random_features = None

        datasets = dataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            non_geom_features=non_geom_features,
            use_deepsite_features=use_deepsite_features,
            data_split=data_split,
            course_grained=course_grained,
            oversample=oversample,
            undersample=undersample,
            cellular_organisms=cellular_organisms,
            random_features=random_features)
        nFeatures = datasets["train"].get_number_of_features()
        nClasses = 2 if not autoencoder else nFeatures

        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    if num_workers % 2 == 0:
        num_workers -= 1
    num_workers /= 2
    # NOTE(review): this assignment discards the worker count computed above
    # (and any caller-supplied value) - confirm it is intentional.
    num_workers = 6

    dataloaders = {name:dataset.get_data_loader(
        batch_size if dataset.train else validation_batch_size,
        shuffle,
        num_workers) \
        for name, dataset in datasets.iteritems()}

    # The tensor type follows CUDA availability, not ``use_gpu``.
    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    if use_resnet_unet:
        model = ResNetUNet(nFeatures, nClasses, dropout_depth=dropout_depth,
                           dropout_width=dropout_width, dropout_p=dropout_p,
                           wide_model=wide_model)
    else:
        model = UNet3D(nFeatures, nClasses, batchnorm=not no_batch_norm)

    if data_parallel:
        model = torch.nn.DataParallel(model)

    model.type(dtype)

    optimizer = SGD(model.parameters(),
                    lr=initial_learning_rate,
                    momentum=0.999,
                    weight_decay=1e-4,
                    nesterov=True)

    # Learning-rate multiplier is exp((1 - epoch) * lr_decay).
    scheduler = LambdaLR(optimizer, lambda epoch: math.exp(
        (1 - epoch) * lr_decay))

    check_point_model_file = "{}_checkpoint_model.pth".format(model_prefix)
    check_point_epoch_file = "{}_checkpoint_epoch.pth".format(model_prefix)
    if check_point and os.path.isfile(
            check_point_model_file) and os.path.isfile(check_point_epoch_file):
        start_epoch = torch.load(check_point_epoch_file)
        print "Restarting at epoch {} from {}".format(start_epoch + 1,
                                                      check_point_model_file)
        model.load_state_dict(torch.load(check_point_model_file))
    else:
        start_epoch = 0

    inputSpatialSize = torch.LongTensor(input_shape)

    # The autograd graph is rendered once, on the first sparse batch.
    draw_graph = True

    mlog = MeterLogger(nclass=nClasses, title="Sparse 3D UNet", server="cn4216")

    #Start clean
    # NOTE(review): ``del obj`` only unbinds the loop variable; other
    # references to the tensor are untouched.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                #print type(obj), obj.size()
                del obj
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            pass

    for epoch in xrange(start_epoch, num_epochs):
        print "Epoch {}/{}".format(epoch, num_epochs - 1)
        print "-" * 10

        mlog.timer.reset()
        for phase in ['train', 'val']:
            datasets[phase].epoch = epoch
            num_batches = int(
                np.ceil(
                    len(datasets[phase]) /
                    float(batch_size
                          if phase == "train" else validation_batch_size)))

            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            # Iterate over data.
            bar = tqdm(enumerate(dataloaders[phase]),
                       total=num_batches,
                       unit="batch",
                       desc="Loading data",
                       leave=True)
            for data_iter_num, data in bar:
                datasets[phase].batch = data_iter_num
                batch_weight = data.get("weight", None)

                if batch_weight is not None:
                    batch_weight = torch.from_numpy(batch_weight).float()
                    if use_gpu:
                        batch_weight = batch_weight.cuda()

                sample_weights = data.get("sample_weights", None)

                if sample_weights is not None:
                    sample_weights = torch.from_numpy(sample_weights).float()
                    if use_gpu:
                        sample_weights = sample_weights.cuda()

                # Branch 1: the collate function already built InputBatch
                # objects.
                if data["data"].__class__.__name__ == "InputBatch":
                    sparse_input = True
                    inputs = data["data"]
                    labels = data["truth"]
                    if use_gpu:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                        labels = labels.cuda().to_variable()
                    else:
                        inputs = inputs.to_variable(requires_grad=True)
                        labels = labels.to_variable()

                # Branch 2: per-sample arrays/lists; build the sparse batches
                # here.
                elif isinstance(data["data"], (list, tuple)):
                    sparse_input = True
                    inputs = scn.InputBatch(3, inputSpatialSize)
                    labels = scn.InputBatch(3, inputSpatialSize)

                    if isinstance(data["data"][0], np.ndarray):
                        long_tensor = lambda arr: torch.from_numpy(arr).long()
                        float_tensor = lambda arr: torch.from_numpy(arr).float(
                        )
                    elif isinstance(data["data"][0], (list, tuple)):
                        long_tensor = lambda arr: torch.LongTensor(arr)
                        float_tensor = lambda arr: torch.FloatTensor(arr)
                    else:
                        raise RuntimeError("invalid datatype")

                    for sample, (indices, features, truth, id) in enumerate(
                            izip(data["indices"], data["data"], data["truth"],
                                 data["id"])):
                        inputs.addSample()
                        labels.addSample()

                        try:
                            indices = long_tensor(indices)
                            features = float_tensor(features)
                            truth = float_tensor(truth)
                        except RuntimeError as e:
                            # The sample slot was already added, so it stays
                            # empty.
                            print e
                            continue

                        try:
                            inputs.setLocations(
                                indices, features,
                                0)  #Use 1 to remove duplicate coords?
                            labels.setLocations(indices, truth, 0)
                        except AssertionError:
                            # Record the offending id and carry on with the
                            # rest of the batch.
                            print "Error with PDB:", id
                            with open("bad_pdbs.txt", "a") as f:
                                print >> f, id

                    # NOTE(review): ``indices`` and ``truth`` are unbound if
                    # the loop above never ran - an empty batch would raise
                    # NameError here.
                    del data
                    del indices
                    del truth

                    inputs.precomputeMetadata(1)

                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()

                # Branch 3: dense tensor input.
                elif isinstance(data["data"], torch.FloatTensor):
                    #Input is dense
                    print "Input is Dense"
                    sparse_input = False
                    # NOTE(review): ``inputs``/``labels`` are read here before
                    # this branch assigns them; on a first batch this raises,
                    # later it would reuse the previous batch's objects.
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()
                    inputs = Variable(data["data"], requires_grad=True)
                    inputs = scn.DenseToSparse(3)(inputs)
                    try:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                    except:
                        pass
                    labels = Variable(data["truth"])

                else:
                    raise RuntimeError("Invalid data from dataset")

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                try:
                    outputs = model(inputs)
                except AssertionError:
                    print nFeatures, inputs
                    raise

                if sparse_input:
                    # use_size_average is hard-wired off, so the batch-level
                    # weight is always the one used.
                    use_size_average = False
                    weight = sample_weights if use_size_average else batch_weight

                    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

                    # Targets are the arg-max over the label feature columns.
                    loss = loss_fn(outputs, torch.max(labels.features, 1)[1])

                    if draw_graph:
                        var_dot = dot.make_dot(loss)
                        var_dot.render('SparseUnet3dCNN_graph.pdf')
                        draw_graph = False
                        del var_dot

                else:
                    outputs = scn.SparseToDense(3, 1)(outputs)
                    criterion = DiceLoss(size_average=False)
                    loss = criterion(outputs.cpu(), labels.cpu(
                    ))  #, inputs.getSpatialLocations(), scaling)
                    # NOTE(review): ``stats`` is not defined anywhere in this
                    # function, and ``loss_fn`` (deleted below) is never
                    # created on this path.
                    stats.update(outputs.data.cpu().view(-1),
                                 labels.data.cpu().view(-1), loss.data[0])

                mlog.update_loss(loss, meter='loss')
                mlog.update_meter(outputs,
                                  torch.max(labels.features, 1)[1],
                                  meters={'accuracy', 'map'})
                add_to_logger(mlog,
                              "Train" if phase == "train" else "Test",
                              epoch,
                              outputs,
                              labels.features,
                              batch_weight,
                              n_classes=nClasses)

                # backward + optimize only if in training phase
                if phase == 'train':
                    # Snapshot the first parameter tensor before and after the
                    # step to detect a step that changed nothing.
                    a = list(model.parameters())[0].clone().data
                    loss.backward()
                    optimizer.step()
                    b = list(model.parameters())[0].clone().data
                    if torch.equal(a, b):
                        print "NOT UPDATED"
                    del a
                    del b

                bar.set_description("{}: [{}][{}/{}]".format(
                    phase, epoch, data_iter_num + 1, num_batches))
                bar.set_postfix(loss="{:.4f} ({:.4f})".format(
                    mlog.meter["loss"].val, mlog.meter["loss"].mean),
                                dice_class1="{:.4f} ({:.4f})".format(
                                    mlog.meter["dice_class1"].val,
                                    mlog.meter["dice_class1"].mean),
                                weight_dice="{:.4f} ({:.4f})".format(
                                    mlog.meter["weighted_dice_wavg"].val,
                                    mlog.meter["weighted_dice_wavg"].mean),
                                refresh=False)
                bar.refresh()

                # Drop per-batch references before the next batch is loaded.
                del inputs
                del labels
                del outputs
                del loss
                del loss_fn
                del batch_weight
                del sample_weights

                #Delete all unused objects on the GPU
                # NOTE(review): as above, ``del obj`` only unbinds the loop
                # variable.
                for obj in gc.get_objects():
                    try:
                        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                            #print type(obj), obj.size()
                            del obj
                    except (SystemExit, KeyboardInterrupt):
                        raise
                    except Exception as e:
                        pass

            statsfile, graphs = graph_logger(
                mlog, "Train" if phase == "train" else "Test", epoch)
            mlog.reset_meter(epoch, "Train" if phase == "train" else "Test")

            # NOTE(review): placed inside the phase loop because it consumes
            # the per-phase ``statsfile``/``graphs``; confirm it was not meant
            # to run once per epoch.
            if check_point:
                torch.save(epoch, check_point_epoch_file)
                torch.save(model.state_dict(), check_point_model_file)
                if callable(checkpoint_callback):
                    checkpoint_callback(epoch, statsfile, graphs,
                                        check_point_epoch_file,
                                        check_point_model_file)
            elif callable(checkpoint_callback):
                checkpoint_callback(epoch, statsfile, graphs, None, None)

    #stats.plot_final()

    # NOTE(review): ``phase`` and ``epoch`` are the last loop values here;
    # both are unbound if the epoch loop never ran
    # (start_epoch >= num_epochs).
    statsfile, graphs = graph_logger(mlog,
                                     "Train" if phase == "train" else "Test",
                                     epoch,
                                     final=True)

    time_elapsed = time.time() - since
    print 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed / 60, time_elapsed % 60)

    # The final weights are always saved; ``save_final`` is not consulted.
    torch.save(model.state_dict(), "{}.pth".format(model_prefix))

    if callable(checkpoint_callback):
        checkpoint_callback(epoch, statsfile, graphs, check_point_epoch_file,
                            check_point_model_file)

    return model
def sparse_collate(data, input_shape=(256,256,256), create_tensor=False):
    """Collate per-structure samples into one batch plus class weights.

    Args:
        data: Iterable of dicts with ``"indices"``, ``"data"``, ``"truth"``
            and ``"id"``. Samples whose ``"data"`` is ``None`` are skipped.
        input_shape: Spatial size of the sparse batches (tensor mode only).
        create_tensor: Build ``scn.InputBatch`` objects instead of plain
            lists. Falls back to list mode when ``scn`` is unavailable.

    Returns:
        dict with ``"data"``, ``"truth"`` and ``"id"`` (plus ``"indices"`` in
        list mode), and:

        * ``"sample_weights"``: one row per kept sample. For two-column truth
          the row is ``(1 - p, p)`` with ``p = sum(truth[:, 0]) / n_rows``;
          otherwise it is the per-column sum divided by ``n_rows``.
        * ``"weight"``: the same quantity pooled over the whole batch.

        Both weight entries are ``None`` when no sample was usable.
    """
    # ``create_tensor`` is tested first so the optional ``scn`` module is only
    # consulted when tensors were actually requested.
    build_tensors = create_tensor and scn is not None

    if not build_tensors:
        batch = {
            "indices": [],
            "data": [],
            "truth": [],
            "id": []
        }

        def add_sample(indices, features, truth, sample_id):
            batch["indices"].append(indices)
            batch["data"].append(features)
            batch["truth"].append(truth)
            batch["id"].append(sample_id)
    else:
        inputSpatialSize = torch.LongTensor(input_shape)
        batch = {
            "data": scn.InputBatch(3, inputSpatialSize),
            "truth": scn.InputBatch(3, inputSpatialSize),
            "id": []
        }

        def add_sample(indices, features, truth, sample_id):
            batch["data"].addSample()
            batch["truth"].addSample()
            indices = torch.from_numpy(indices).long()
            try:
                batch["data"].setLocations(indices, torch.from_numpy(features).float(), 0)  # Use 1 to remove duplicate coords?
                batch["truth"].setLocations(indices, torch.from_numpy(truth).float(), 0)
                batch["id"].append(sample_id)
            except AssertionError:
                # PDB didn't fit in grid? Deliberately best-effort: the sample
                # slot stays empty and its id is not recorded.
                pass

    sample_weights = []
    batch_weight = None
    num_data = 0.0
    for d in data:
        if d["data"] is None:
            continue

        truth = d["truth"]
        n_rows = truth.shape[0]
        is_binary = truth.shape[1] == 2

        if batch_weight is None:
            # Size the accumulator from the first usable sample. ``data[0]``
            # may be one of the skipped entries.
            batch_weight = 0.0 if is_binary else np.zeros(truth.shape[1])

        add_sample(d["indices"], d["data"], truth, d["id"])

        if is_binary:
            num_true = np.sum(truth[:, 0])
            true_prob = num_true / float(n_rows)
            sample_weights.append(np.array((1 - true_prob, true_prob)))
        else:
            num_true = np.sum(truth, axis=0)
            sample_weights.append(num_true / float(n_rows))
        batch_weight += num_true
        num_data += n_rows

    # Only a real InputBatch has precomputeMetadata. In list mode
    # ``batch["data"]`` is a plain list, even if create_tensor was requested.
    if build_tensors:
        batch["data"].precomputeMetadata(1)

    if batch_weight is None:
        # Every sample was skipped: there is nothing to weigh.
        batch["sample_weights"] = None
        batch["weight"] = None
        return batch

    batch_weight /= float(num_data)

    batch["sample_weights"] = np.array(sample_weights)
    if isinstance(batch_weight, float):
        batch["weight"] = np.array([1 - batch_weight, batch_weight])
    else:
        batch["weight"] = batch_weight

    return batch
def test(model_file, ibis_data, input_shape=(512, 512, 512), only_aa=False, only_atom=False, expand_atom=False, num_workers=None, batch_size=20, shuffle=True, use_gpu=True, data_split=0.8, test_full=False, no_batch_norm=False):
    """Evaluate a saved ``UNet3D`` on the validation split and report Dice loss.

    Data source: ``ibis_data == "spheres"`` selects the synthetic
    ``SphereDataset``; an existing file path selects ``IBISDataset``; anything
    else raises ``RuntimeError``.

    The model weights are loaded from ``model_file`` (``IOError`` if the file
    is missing). Each validation batch is converted to sparse inputs/labels
    and run through the model; statistics are accumulated in a ``ModelStats``
    object that is saved and plotted at the end. Nothing is returned.

    NOTE(review): this body was recovered from whitespace-collapsed text.
    Confirm the nesting of the per-batch statements against the original file.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1
    print "Using {} workers".format(num_workers)

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        datasets = SphereDataset.get_training_and_validation(input_shape,
                                                             cnt=1,
                                                             n_samples=1000,
                                                             data_split=0.99)
        nFeatures = 1
        validation_batch_size = 1
        # Reassigned after the dataset was created with the caller's shape;
        # only the sparse batch size built below sees this value.
        input_shape = (96, 96, 96)
    elif os.path.isfile(ibis_data):
        datasets = IBISDataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            expand_atom=expand_atom,
            data_split=data_split,
            train_full=test_full,
            validate_full=test_full)
        # Feature count is hard-coded per feature-selection mode.
        if only_atom:
            nFeatures = 5
        elif only_aa:
            nFeatures = 21
        else:
            nFeatures = 59
        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    dataloader = datasets["val"].get_data_loader(
        batch_size if datasets["val"].train else validation_batch_size,
        shuffle,
        num_workers)

    # The tensor type follows CUDA availability, not ``use_gpu``.
    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    model = UNet3D(nFeatures, 1, batchnorm=not no_batch_norm)
    model.type(dtype)

    if not os.path.isfile(model_file):
        raise IOError("Model cannot be opened")
    model.load_state_dict(torch.load(model_file))
    model.train(False)  # Set model to evaluate mode

    criterion = DiceLoss()

    inputSpatialSize = torch.LongTensor(input_shape)

    stats = ModelStats()

    print "Starting Test..."

    for data_iter_num, data in enumerate(dataloader):
        # Branch 1: the collate function already built InputBatch objects.
        if data["data"].__class__.__name__ == "InputBatch":
            sparse_input = True
            inputs = data["data"]
            labels = data["truth"]
            if use_gpu:
                inputs = inputs.cuda().to_variable(requires_grad=True)
                labels = labels.cuda().to_variable()
            else:
                inputs = inputs.to_variable(requires_grad=True)
                labels = labels.to_variable()

        # Branch 2: per-sample lists; build the sparse batches here.
        elif isinstance(data["data"], (list, tuple)):
            sparse_input = True
            inputs = scn.InputBatch(3, inputSpatialSize)
            labels = scn.InputBatch(3, inputSpatialSize)

            for sample, (indices, features, truth) in enumerate(
                    izip(data["indices"], data["data"], data["truth"])):
                inputs.addSample()
                labels.addSample()

                indices = torch.LongTensor(indices)
                features = torch.FloatTensor(features)
                truth = torch.FloatTensor(truth)

                try:
                    inputs.setLocations(indices, features, 0)  #Use 1 to remove duplicate coords?
                    labels.setLocations(indices, truth, 0)
                except AssertionError:
                    #PDB didn't fit in grid?
                    continue

            del data

            inputs.precomputeMetadata(1)

            if use_gpu:
                inputs = inputs.cuda().to_variable(requires_grad=True)
                labels = labels.cuda().to_variable()
            else:
                inputs = inputs.to_variable(requires_grad=True)
                labels = labels.to_variable()

        # Branch 3: dense tensor input.
        elif isinstance(data["data"], torch.FloatTensor):
            #Input is dense
            print "Input is Dense"
            sparse_input = False
            # NOTE(review): ``inputs``/``labels`` are read here before this
            # branch assigns them; on a first batch this raises, later it
            # would reuse the previous batch's objects.
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = Variable(data["data"], requires_grad=True)
            inputs = scn.DenseToSparse(3)(inputs)
            try:
                inputs = inputs.cuda().to_variable(requires_grad=True)
            except:
                pass
            labels = Variable(data["truth"])

        else:
            raise RuntimeError("Invalid data from dataset")

        outputs = model(inputs)

        if sparse_input:
            loss = criterion(outputs.features, labels.features)

            if math.isnan(loss.data[0]):
                print "Loss is Nan?"
                # NOTE(review): drops into the interactive debugger, which
                # blocks unattended runs - confirm this should stay.
                import pdb
                pdb.set_trace()

            stats.update(outputs.features.data, labels.features.data, loss.data[0])
        else:
            outputs = scn.SparseToDense(3, 1)(outputs)
            loss = criterion(outputs.cpu(), labels.cpu())
            stats.update(outputs.data.cpu().view(-1), labels.data.cpu().view(-1), loss.data[0])

        # The "dice" column prints the loss multiplied by -100.
        print "Batch {}: corrects:{:.2f}% nll:{:.2f}% dice:{:.4f}% time:{:.1f}s".format(
            data_iter_num, stats.correctpct(), stats.nllpct(), loss.data[0] * -100, time.time() - since)

        save_batch_prediction(outputs)

    stats.save("val", 0)
    stats.plot_final()
    time_elapsed = time.time() - since
    print 'Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed / 60, time_elapsed % 60)