def calculate_features(pdb, chain, id):
    from molmimic.biopdbtools import Structure
    inputSpatialSize = torch.LongTensor((264, 264, 264))
    struct = Structure(pdb, chain, id=id)
    for rotation in struct.rotate(1000):
        indices, data = struct.map_atoms_to_voxel_space()
        inputs = scn.InputBatch(3, inputSpatialSize)
        inputs.addSample()
        try:
            inputs.setLocations(
                torch.from_numpy(indices).long(),
                torch.from_numpy(data).float(), 0)
        except AssertionError:
            theta, phi, z = rotation[1:]
            min_coord = np.min(indices, axis=0)
            max_coord = np.max(indices, axis=0)
            dist = int(np.ceil(np.linalg.norm(max_coord - min_coord)))

            with open("{}_{}_{}.txt".format(pdb, chain, id), "a") as f:
                print("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".
                      format(pdb, chain, id, theta, phi, z, dist, min_coord[0],
                             min_coord[1], min_coord[2], max_coord[0],
                             max_coord[1], max_coord[2]),
                      file=f)
        del inputs
        del indices
        del data
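For orientation, the core InputBatch workflow shared by this and the following examples is: construct the batch with a dimension and spatial size, call addSample() once per sample, hand integer voxel coordinates and per-voxel features to setLocations(), and optionally precompute the rule books. A minimal sketch with made-up coordinates and feature values:

import torch
import sparseconvnet as scn

spatial_size = torch.LongTensor((264, 264, 264))
batch = scn.InputBatch(3, spatial_size)
batch.addSample()
# Three dummy voxels, one feature channel each (values are illustrative only).
indices = torch.LongTensor([[10, 12, 14], [10, 12, 15], [11, 12, 15]])
features = torch.FloatTensor([[1.0], [0.5], [0.25]])
batch.setLocations(indices, features, 0)  # last flag: the examples below suggest 1 removes duplicate coords
batch.precomputeMetadata(1)               # optional: prebuild convolution rule books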
Example #2
 def merge(tbl):
     inp = scn.InputBatch(2, spatial_size)
     center = spatial_size.float().view(1, 2) / 2
     p = torch.LongTensor(2)
     v = torch.FloatTensor([1, 0, 0])
     for char in tbl['input']:
         inp.add_sample()
         for stroke in char:
             stroke = stroke.float() * (Scale - 0.01) / 255 - 0.5 * (Scale -
                                                                     0.01)
             stroke += center.expand_as(stroke)
             ###############################################################
             # To avoid GIL problems use a helper function:
             scn.dim_fn(2, 'drawCurve')(inp.metadata.ffi, inp.features,
                                        stroke)
             ###############################################################
             # Above is equivalent to :
             # x1,x2,y1,y2,l=0,stroke[0][0],0,stroke[0][1],0
             # for i in range(1,stroke.size(0)):
             #     x1=x2
             #     y1=y2
             #     x2=stroke[i][0]
             #     y2=stroke[i][1]
             #     l=1e-10+((x2-x1)**2+(y2-y1)**2)**0.5
             #     v[1]=(x2-x1)/l
             #     v[2]=(y2-y1)/l
             #     l=max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
             #     for j in numpy.arange(0,1,1/l):
             #         p[0]=math.floor(x1*j+x2*(1-j))
             #         p[1]=math.floor(y1*j+y2*(1-j))
             #         inp.set_location(p,v,False)
             ###############################################################
     inp.precomputeMetadata(precomputeSize)
     return {'input': inp, 'target': torch.LongTensor(tbl['target'])}
Example #3
def SampleLocation(dimension, x, sample, spatialSize):
    locations = x.getSpatialLocations(spatialSize)
    sample_idx = sample.nonzero().view(-1)
    sample_locations = torch.index_select(locations, 0, sample_idx)
    x_sample = scn.InputBatch(dimension, x.getSpatialSize())
    x_sample.setLocations(sample_locations,
                          torch.Tensor(sample_locations.size(0), 1))
    return x_sample
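A plausible way to build the `sample` mask is hinted at by the commented-out line in the next example: draw a random keep/drop decision per active location. A hedged usage sketch, assuming `x` is an existing sparse tensor and `keep_p` an arbitrary keep probability:

keep_p = 0.7  # illustrative value
locations = x.getSpatialLocations(x.getSpatialSize())
sample = (torch.Tensor(locations.size(0)).uniform_(0, 1) < keep_p).float()
x_subset = SampleLocation(3, x, sample, x.getSpatialSize())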
Example #4
def Sample(dimension, x, sample):
    locations = x.getSpatialLocations()
    # sample = (torch.Tensor(locations.size(0)).uniform_(0,1) < sample_p).float()
    sample_idx = sample.nonzero().view(-1)
    sample_locations = torch.index_select(locations, 0, sample_idx)
    sample_features = torch.index_select(
        x.features, 0, Variable(sample_idx.cuda(), requires_grad=False))
    x_sample = scn.InputBatch(dimension, x.getSpatialSize())
    x_sample.setLocations(sample_locations,
                          torch.Tensor(sample_features.data.shape))
    x_sample.features = sample_features
    # x.locations = locations
    return x_sample
Example #5
 def merge(tbl):
     inp = scn.InputBatch(2, spatial_size)
     center = spatial_size.float().view(1, 2) / 2
     p = torch.LongTensor(2)
     v = torch.FloatTensor([1, 0, 0])
     np_random = np.random.RandomState(tbl['idx'])
     for char in tbl['input']:
         inp.add_sample()
         m = torch.eye(2)
         r = np_random.randint(1, 4)  # 1, 2 or 3: pick one of three random distortions (numpy's randint excludes the upper bound)
         alpha = random.uniform(-0.2, 0.2)
         if r == 1:
             m[0][1] = alpha
         elif r == 2:
             m[1][0] = alpha
         else:
             m = torch.mm(m, torch.FloatTensor(
                 [[math.cos(alpha), math.sin(alpha)],
                  [-math.sin(alpha), math.cos(alpha)]]))
         c = center + torch.FloatTensor(1, 2).uniform_(-8, 8)
         for stroke in char:
             stroke = stroke.float() / 255 - 0.5
             stroke = c.expand_as(stroke) + \
                 torch.mm(stroke, m * (Scale - 0.01))
             ###############################################################
             # To avoid GIL problems use a helper function:
             scn.dim_fn(
                 2,
                 'drawCurve')(
                 inp.metadata.ffi,
                 inp.features,
                 stroke)
             ###############################################################
             # Above is equivalent to :
             # x1,x2,y1,y2,l=0,stroke[0][0],0,stroke[0][1],0
             # for i in range(1,stroke.size(0)):
             #     x1=x2
             #     y1=y2
             #     x2=stroke[i][0]
             #     y2=stroke[i][1]
             #     l=1e-10+((x2-x1)**2+(y2-y1)**2)**0.5
             #     v[1]=(x2-x1)/l
             #     v[2]=(y2-y1)/l
             #     l=max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
             #     for j in np.arange(0,1,1/l):
             #         p[0]=math.floor(x1*j+x2*(1-j))
             #         p[1]=math.floor(y1*j+y2*(1-j))
             #         inp.set_location(p,v,False)
             ###############################################################
     inp.precomputeMetadata(precomputeSize)
     return {'input': inp, 'target': torch.LongTensor(tbl['target']) - 1}
Example #6
 def merge(tbl):
     inp = scn.InputBatch(2, spatial_size)
     center = spatial_size.float().view(1, 2) / 2
     p = torch.LongTensor(2)
     v = torch.FloatTensor([1, 0, 0])
     for char in tbl['input']:
         inp.add_sample()
         for stroke in char:
             stroke = stroke.float() * (Scale - 0.01) / 255 - 0.5 * (Scale -
                                                                     0.01)
             stroke += center.expand_as(stroke)
             scn.dim_fn(2, 'drawCurve')(inp.metadata.ffi, inp.features,
                                        stroke)
     inp.precomputeMetadata(precomputeSize)
     return {'input': inp, 'target': torch.LongTensor(tbl['target'])}
Example #7
def spatialGroupConv(x, model, abc, group_num, group_x=None):
    # Use the precomputed rule books if available (built by precomputeMetadata());
    # otherwise they are computed on the fly. See getRuleBook() in Metadata.h.
    if group_x is None:
        locations = x.getSpatialLocations()
        batch_size = torch.LongTensor(locations[:, -1]).max() + 1
        locations = partition(locations, abc[0], abc[1], abc[2], group_num,
                              batch_size)
        group_x = scn.InputBatch(3, x.getSpatialSize())
        group_x.setLocations(locations,
                             torch.Tensor(len(locations)).view(-1, 1))
    assert group_x.features.size(0) == x.features.size(0)
    group_x.features = x.features
    group_x = model(group_x)
    group_x.metadata = x.metadata  # gather operation
    return group_x
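The comment in spatialGroupConv refers to precomputeMetadata(): if the rule books were precomputed on the input they are reused, otherwise they are built on demand. A hedged calling sketch; `model`, `abc` and `group_num` below are placeholders rather than values from the original code:

x.precomputeMetadata(1)                       # optional: prebuild rule books so they can be reused
y = spatialGroupConv(x, model, abc=(4, 4, 4), group_num=8)
# y.features are produced by running `model` on the re-grouped copy of x;
# y.metadata is reset to x.metadata (the "gather operation" noted above).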
Example #8
def abstract(dimension, input, prediction, kernel_size=1):
    prediction = prediction.features.data
    _, predicted = prediction.max(1)
    if predicted.sum() == 0:
        print('dangerous! no predicted structure')
        predicted[:] = 1
    structure = input.extractStructure(predicted.cpu(),
                                       kernel_size).nonzero().view(-1)
    input_locations = input.getSpatialLocations()
    input_features = input.features
    output_locations = torch.index_select(input_locations, 0, structure)

    structure = Variable(structure, requires_grad=False).cuda()
    output_features = torch.index_select(input_features, 0, structure)

    output = scn.InputBatch(dimension, input.getSpatialSize())
    output.setLocations(output_locations,
                        torch.Tensor(output_features.data.shape))
    output.features = output_features  #Preserve autograd continuity
    return output
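abstract() keeps only the voxels that the network marked as structure (dilated by kernel_size), so one plausible use is to prune the input before a refinement pass. A sketch with placeholder model names:

coarse_pred = coarse_model(inputs)                    # per-voxel class scores (sparse tensor)
pruned = abstract(3, inputs, coarse_pred, kernel_size=3)
refined_pred = refine_model(pruned)                   # refine only where structure was predicted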
Example #9
    def __init__(self, model, selected_layer, selected_filter, use_gpu=True):
        self.model = model
        self.model.eval()
        self.selected_layer = selected_layer
        self.selected_filter = selected_filter
        self.conv_output = 0

        # Generate a random 3D volume that fills the entire grid, represented as a sparse tensor
        dim = np.arange(0, 96)
        x, y, z = np.meshgrid(dim, dim, dim)
        points = list(zip(x.ravel(), y.ravel(), z.ravel()))
        features = np.array(list(product([0, 1], repeat=3)))
        features = features[np.random.choice(8, 96 * 96 * 96)]

        self.inputSpatialSize = torch.LongTensor((96, 96, 96))
        self.created_image = scn.InputBatch(3, self.inputSpatialSize)
        self.created_image.addSample()
        indices = torch.LongTensor(points)
        labels = torch.from_numpy(features).float()
        self.created_image.setLocations(indices, labels, 0)

        if use_gpu:
            self.created_image = self.created_image.cuda()

        self.created_image = self.created_image.to_variable(requires_grad=True)

        del indices
        del labels
        del features
        del points
        del x
        del y
        del z
        del dim

        # Create the folder to export images if it does not exist
        if not os.path.exists('generated'):
            os.makedirs('generated')
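The constructor only builds the random sparse volume; presumably the rest of the class pushes created_image through the model and records the chosen filter's response. A speculative sketch of that step (the method name and layer iteration are assumptions, not part of the original class):

    def visualise_layer(self):
        # Walk the model's children up to the selected layer and keep the
        # activations of the selected filter (illustrative only).
        x = self.created_image
        for index, layer in enumerate(self.model.children()):
            x = layer(x)
            if index == self.selected_layer:
                self.conv_output = x.features[:, self.selected_filter]
                break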
Example #10
    def forward(self, input):
        assert input.features.ndimension() == 0 or input.features.size(
            1) == self.nIn
        input2 = sparseconvnet.InputBatch(self.dimension, input.spatial_size)
        input2.setLocations(input.getSpatialLocations(),
                            torch.Tensor(input.features.data.shape))
        input2.features = input.features

        output = SparseConvNetTensor()
        output.metadata = input2.metadata
        output.spatial_size =\
            (input.spatial_size - 1) * self.filter_stride + self.filter_size
        output.features = DenseDeconvolutionFunction().apply(
            input2.features,
            self.weight,
            self.bias,
            input2.metadata,
            input2.spatial_size,
            output.spatial_size,
            self.dimension,
            self.filter_size,
            self.filter_stride,
        )
        return output
Example #11
import torch
import sparseconvnet as scn

use_gpu = torch.cuda.is_available()

# Build the 2D model (as in SparseConvNet's hello-world example).
model = scn.Sequential().add(
    scn.SparseVggNet(2, 1,
                     [['C',  8], ['C',  8], ['MP', 3, 2],
                      ['C', 16], ['C', 16], ['MP', 3, 2],
                      ['C', 24], ['C', 24], ['MP', 3, 2]])
).add(
    scn.ValidConvolution(2, 24, 32, 3, False)
).add(
    scn.BatchNormReLU(32)
).add(
    scn.SparseToDense(2,32)
)
if use_gpu:
    model.cuda()

# output will be 10x10
inputSpatialSize = model.input_spatial_size(torch.LongTensor([10, 10]))
input = scn.InputBatch(2, inputSpatialSize)

msg = [
    " X   X  XXX  X    X    XX     X       X   XX   XXX   X    XXX   ",
    " X   X  X    X    X   X  X    X       X  X  X  X  X  X    X  X  ",
    " XXXXX  XX   X    X   X  X    X   X   X  X  X  XXX   X    X   X ",
    " X   X  X    X    X   X  X     X X X X   X  X  X  X  X    X  X  ",
    " X   X  XXX  XXX  XXX  XX       X   X     XX   X  X  XXX  XXX   "]

#Add a sample using setLocation
input.addSample()
for y, line in enumerate(msg):
    for x, c in enumerate(line):
        if c == 'X':
            location = torch.LongTensor([y, x])
            featureVector = torch.FloatTensor([1])
            input.setLocation(location, featureVector, 0)
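The snippet stops after filling in the sample; in SparseConvNet's hello-world example, which this snippet appears to be based on, the batch is then pushed through the network. A minimal continuation sketch, assuming the model and input built above:

if use_gpu:
    input = input.cuda()
output = model.forward(input)   # SparseToDense(2, 32) at the end yields a dense tensor
print(output.size())            # expected to be (batchSize, 32, 10, 10) for the 10x10 output grid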
Example #12
def infer(model, data, input_shape=(256,256,256), use_gpu=True, course_grained=False):
    model_prefix = "./molmimic_model_{}".format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
    inputSpatialSize = torch.LongTensor(input_shape)

    labels = None

    mlog = MeterLogger(nclass=2, title="Sparse 3D UNet Inference")

    phase = "test"
    epoch = 0

    batch_weight = data.get("weight", None)
    if batch_weight is not None:
        batch_weight = torch.from_numpy(batch_weight).float().cuda()
    sample_weights = data.get("sample_weights", None)
    if sample_weights is not None:
        sample_weights = torch.from_numpy(sample_weights).float().cuda()

    # Pick the loss weights up front so `weight` is defined even when
    # sample_weights is missing (mirrors the training loop below).
    use_size_average = False
    weight = sample_weights if use_size_average else batch_weight

    if data["data"].__class__.__name__ == "InputBatch":
        sparse_input = True
        inputs = data["data"]
        labels = data["truth"] if "truth" in data else None
        if use_gpu:
            inputs = inputs.cuda().to_variable(requires_grad=False)
            labels = labels.cuda().to_variable()
        else:
            inputs = inputs.to_variable(requires_grad=False)
            labels = labels.to_variable()

    elif isinstance(data["data"], (list, tuple)):
        sparse_input = True
        inputs = scn.InputBatch(3, inputSpatialSize)
        labels = scn.InputBatch(3, inputSpatialSize) if "truth" in data else None

        if isinstance(data["data"][0], np.ndarray):
            long_tensor = lambda arr: torch.from_numpy(arr).long()
            float_tensor = lambda arr: torch.from_numpy(arr).float()
        elif isinstance(data["data"][0], (list, tuple)):
            long_tensor = lambda arr: torch.LongTensor(arr)
            float_tensor = lambda arr: torch.FloatTensor(arr)
        else:
            raise RuntimeError("invalid datatype")

        for sample, (indices, features, truth) in enumerate(izip(data["indices"], data["data"], data["truth"])):
            inputs.addSample()
            if labels is not None:
                labels.addSample()

            try:
                indices = long_tensor(indices)
                features = float_tensor(features)
                if labels is not None:
                    truth = float_tensor(truth)
            except RuntimeError as e:
                print e
                continue

            inputs.setLocations(indices, features, 0) #Use 1 to remove duplicate coords?
            if labels is not None:
                labels.setLocations(indices, truth, 0)

        del data
        del indices
        del truth

        inputs.precomputeMetadata(1)

        if use_gpu:
            inputs = inputs.cuda()
            labels = labels.cuda()

        inputs = inputs.to_variable(requires_grad=True)
        labels = labels.to_variable()

    elif isinstance(data["data"], torch.FloatTensor):
        #Input is dense
        print "Input is Dense"
        sparse_input = False
        # Build the variables first, then (optionally) move them to the GPU.
        inputs = Variable(data["data"], requires_grad=True)
        labels = Variable(data["truth"])
        if use_gpu:
            inputs = inputs.cuda()
            labels = labels.cuda()
        inputs = scn.DenseToSparse(3)(inputs)
        try:
            inputs = inputs.cuda().to_variable(requires_grad=True)
        except:
            pass

    else:
        raise RuntimeError("Invalid data from dataset")

    # forward
    try:
        outputs = model(inputs)
    except AssertionError as e:
        print e
        #print nFeatures, inputs
        raise

    if labels is None:
        return outputs
    else:
        loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

        loss = loss_fn(outputs, torch.max(labels.features, 1)[1])

        mlog.update_loss(loss, meter='loss')
        mlog.update_meter(outputs, torch.max(labels.features, 1)[1], meters={'accuracy', 'map'})
        add_to_logger(mlog,  "Train" if phase=="train" else "Test", epoch, outputs, labels.features, batch_weight)

        del inputs
        del labels
        del loss
        del loss_fn
        del batch_weight
        del sample_weights

        return outputs, mlog
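infer() expects a batch dictionary shaped like the one produced by sparse_collate in Example #14 below. A hedged calling sketch; `model` and `samples` are placeholders:

batch = sparse_collate(samples, input_shape=(256, 256, 256))
outputs, mlog = infer(model, batch, input_shape=(256, 256, 256),
                      use_gpu=torch.cuda.is_available())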
Example #13
def train(ibis_data,
          input_shape=(264, 264, 264),
          model_prefix=None,
          check_point=True,
          save_final=True,
          only_aa=False,
          only_atom=False,
          non_geom_features=False,
          use_deepsite_features=False,
          expand_atom=False,
          num_workers=None,
          num_epochs=30,
          batch_size=20,
          shuffle=True,
          use_gpu=True,
          initial_learning_rate=0.0001,
          learning_rate_drop=0.5,
          learning_rate_epochs=10,
          lr_decay=4e-2,
          data_split=0.8,
          course_grained=False,
          no_batch_norm=False,
          use_resnet_unet=False,
          unclustered=False,
          undersample=False,
          oversample=False,
          nFeatures=None,
          allow_feature_combos=False,
          bs_feature=None,
          bs_feature2=None,
          bs_features=None,
          stripes=False,
          data_parallel=False,
          dropout_depth=False,
          dropout_width=False,
          dropout_p=0.5,
          wide_model=False,
          cellular_organisms=False,
          autoencoder=False,
          checkpoint_callback=None):
    if model_prefix is None:
        model_prefix = "./molmimic_model_{}".format(
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        nFeatures = nFeatures or 3
        datasets = SphereDataset.get_training_and_validation(
            input_shape,
            cnt=1,
            n_samples=1000,
            nFeatures=nFeatures,
            allow_feature_combos=allow_feature_combos,
            bs_feature=bs_feature,
            bs_feature2=bs_feature2,
            bs_features=bs_features,
            stripes=stripes,
            data_split=0.99)
        validation_batch_size = 1
        if bs_features is not None:
            nClasses = len(bs_features) + 1
        else:
            nClasses = 2
    elif os.path.isfile(ibis_data):
        dataset = IBISDataset
        print allow_feature_combos, nFeatures
        if allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, allow_feature_combos, bs_feature,
                               bs_feature2)
        elif not allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, False, bs_feature, bs_feature2)
        elif allow_feature_combos and nFeatures is None:
            random_features = None
            print "ignoring --allow-feature-combos"
        else:
            random_features = None

        datasets = dataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            non_geom_features=non_geom_features,
            use_deepsite_features=use_deepsite_features,
            data_split=data_split,
            course_grained=course_grained,
            oversample=oversample,
            undersample=undersample,
            cellular_organisms=cellular_organisms,
            random_features=random_features)
        nFeatures = datasets["train"].get_number_of_features()
        nClasses = 2 if not autoencoder else nFeatures

        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    if num_workers % 2 == 0:
        num_workers -= 1
    num_workers /= 2
    num_workers = 6

    dataloaders = {name:dataset.get_data_loader(
        batch_size if dataset.train else validation_batch_size,
        shuffle,
        num_workers) \
        for name, dataset in datasets.iteritems()}

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    if use_resnet_unet:
        model = ResNetUNet(nFeatures,
                           nClasses,
                           dropout_depth=dropout_depth,
                           dropout_width=dropout_width,
                           dropout_p=dropout_p,
                           wide_model=wide_model)
    else:
        model = UNet3D(nFeatures, nClasses, batchnorm=not no_batch_norm)

    if data_parallel:
        model = torch.nn.DataParallel(model)

    model.type(dtype)

    optimizer = SGD(model.parameters(),
                    lr=initial_learning_rate,
                    momentum=0.999,
                    weight_decay=1e-4,
                    nesterov=True)

    scheduler = LambdaLR(optimizer, lambda epoch: math.exp(
        (1 - epoch) * lr_decay))

    check_point_model_file = "{}_checkpoint_model.pth".format(model_prefix)
    check_point_epoch_file = "{}_checkpoint_epoch.pth".format(model_prefix)
    if check_point and os.path.isfile(
            check_point_model_file) and os.path.isfile(check_point_epoch_file):
        start_epoch = torch.load(check_point_epoch_file)
        print "Restarting at epoch {} from {}".format(start_epoch + 1,
                                                      check_point_model_file)
        model.load_state_dict(torch.load(check_point_model_file))
    else:
        start_epoch = 0

    inputSpatialSize = torch.LongTensor(input_shape)

    draw_graph = True

    mlog = MeterLogger(nclass=nClasses,
                       title="Sparse 3D UNet",
                       server="cn4216")

    #Start clean
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data')
                                        and torch.is_tensor(obj.data)):
                #print type(obj), obj.size()
                del obj
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            pass

    for epoch in xrange(start_epoch, num_epochs):
        print "Epoch {}/{}".format(epoch, num_epochs - 1)
        print "-" * 10

        mlog.timer.reset()
        for phase in ['train', 'val']:
            datasets[phase].epoch = epoch
            num_batches = int(
                np.ceil(
                    len(datasets[phase]) /
                    float(batch_size if phase ==
                          "train" else validation_batch_size)))

            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            # Iterate over data.
            bar = tqdm(enumerate(dataloaders[phase]),
                       total=num_batches,
                       unit="batch",
                       desc="Loading data",
                       leave=True)
            for data_iter_num, data in bar:
                datasets[phase].batch = data_iter_num
                batch_weight = data.get("weight", None)

                if batch_weight is not None:
                    batch_weight = torch.from_numpy(batch_weight).float()
                    if use_gpu:
                        batch_weight = batch_weight.cuda()

                sample_weights = data.get("sample_weights", None)

                if sample_weights is not None:
                    sample_weights = torch.from_numpy(sample_weights).float()
                    if use_gpu:
                        sample_weights = sample_weights.cuda()

                if data["data"].__class__.__name__ == "InputBatch":
                    sparse_input = True
                    inputs = data["data"]
                    labels = data["truth"]
                    if use_gpu:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                        labels = labels.cuda().to_variable()
                    else:
                        inputs = inputs.to_variable(requires_grad=True)
                        labels = labels.to_variable()

                elif isinstance(data["data"], (list, tuple)):
                    sparse_input = True
                    inputs = scn.InputBatch(3, inputSpatialSize)
                    labels = scn.InputBatch(3, inputSpatialSize)

                    if isinstance(data["data"][0], np.ndarray):
                        long_tensor = lambda arr: torch.from_numpy(arr).long()
                        float_tensor = lambda arr: torch.from_numpy(arr).float(
                        )
                    elif isinstance(data["data"][0], (list, tuple)):
                        long_tensor = lambda arr: torch.LongTensor(arr)
                        float_tensor = lambda arr: torch.FloatTensor(arr)
                    else:
                        raise RuntimeError("invalid datatype")

                    for sample, (indices, features, truth, id) in enumerate(
                            izip(data["indices"], data["data"], data["truth"],
                                 data["id"])):
                        inputs.addSample()
                        labels.addSample()

                        try:
                            indices = long_tensor(indices)
                            features = float_tensor(features)
                            truth = float_tensor(truth)
                        except RuntimeError as e:
                            print e
                            continue

                        try:
                            inputs.setLocations(
                                indices, features,
                                0)  #Use 1 to remove duplicate coords?
                            labels.setLocations(indices, truth, 0)
                        except AssertionError:
                            print "Error with PDB:", id
                            with open("bad_pdbs.txt", "a") as f:
                                print >> f, id

                    del data
                    del indices
                    del truth

                    inputs.precomputeMetadata(1)

                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()

                elif isinstance(data["data"], torch.FloatTensor):
                    #Input is dense
                    print "Input is Dense"
                    sparse_input = False
                    # Build the variables first, then (optionally) move them to the GPU.
                    inputs = Variable(data["data"], requires_grad=True)
                    labels = Variable(data["truth"])
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()
                    inputs = scn.DenseToSparse(3)(inputs)
                    try:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                    except:
                        pass

                else:
                    raise RuntimeError("Invalid data from dataset")

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                try:
                    outputs = model(inputs)
                except AssertionError:
                    print nFeatures, inputs
                    raise

                if sparse_input:
                    use_size_average = False
                    weight = sample_weights if use_size_average else batch_weight

                    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

                    loss = loss_fn(outputs, torch.max(labels.features, 1)[1])

                    if draw_graph:
                        var_dot = dot.make_dot(loss)
                        var_dot.render('SparseUnet3dCNN_graph.pdf')
                        draw_graph = False
                        del var_dot

                else:
                    outputs = scn.SparseToDense(3, 1)(outputs)
                    criterion = DiceLoss(size_average=False)
                    loss = criterion(outputs.cpu(), labels.cpu(
                    ))  #, inputs.getSpatialLocations(), scaling)
                    stats.update(outputs.data.cpu().view(-1),
                                 labels.data.cpu().view(-1), loss.data[0])

                mlog.update_loss(loss, meter='loss')
                mlog.update_meter(outputs,
                                  torch.max(labels.features, 1)[1],
                                  meters={'accuracy', 'map'})
                add_to_logger(mlog,
                              "Train" if phase == "train" else "Test",
                              epoch,
                              outputs,
                              labels.features,
                              batch_weight,
                              n_classes=nClasses)

                # backward + optimize only if in training phase
                if phase == 'train':
                    a = list(model.parameters())[0].clone().data
                    loss.backward()
                    optimizer.step()
                    b = list(model.parameters())[0].clone().data
                    if torch.equal(a, b): print "NOT UPDATED"
                    del a
                    del b

                bar.set_description("{}: [{}][{}/{}]".format(
                    phase, epoch, data_iter_num + 1, num_batches))
                bar.set_postfix(loss="{:.4f} ({:.4f})".format(
                    mlog.meter["loss"].val, mlog.meter["loss"].mean),
                                dice_class1="{:.4f} ({:.4f})".format(
                                    mlog.meter["dice_class1"].val,
                                    mlog.meter["dice_class1"].mean),
                                weight_dice="{:.4f} ({:.4f})".format(
                                    mlog.meter["weighted_dice_wavg"].val,
                                    mlog.meter["weighted_dice_wavg"].mean),
                                refresh=False)
                bar.refresh()

                del inputs
                del labels
                del outputs
                del loss
                del loss_fn
                del batch_weight
                del sample_weights

                #Delete all unused objects on the GPU
                for obj in gc.get_objects():
                    try:
                        if torch.is_tensor(obj) or (hasattr(obj, 'data') and
                                                    torch.is_tensor(obj.data)):
                            #print type(obj), obj.size()
                            del obj
                    except (SystemExit, KeyboardInterrupt):
                        raise
                    except Exception as e:
                        pass

            statsfile, graphs = graph_logger(
                mlog, "Train" if phase == "train" else "Test", epoch)
            mlog.reset_meter(epoch, "Train" if phase == "train" else "Test")

            if check_point:
                torch.save(epoch, check_point_epoch_file)
                torch.save(model.state_dict(), check_point_model_file)
                if callable(checkpoint_callback):
                    checkpoint_callback(epoch, statsfile, graphs,
                                        check_point_epoch_file,
                                        check_point_model_file)
            elif callable(checkpoint_callback):
                checkpoint_callback(epoch, statsfile, graphs, None, None)

    #stats.plot_final()

    statsfile, graphs = graph_logger(mlog,
                                     "Train" if phase == "train" else "Test",
                                     epoch,
                                     final=True)

    time_elapsed = time.time() - since
    print 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed / 60, time_elapsed % 60)

    torch.save(model.state_dict(), "{}.pth".format(model_prefix))

    if callable(checkpoint_callback):
        checkpoint_callback(epoch, statsfile, graphs, check_point_epoch_file,
                            check_point_model_file)

    return model
Example #14
def sparse_collate(data, input_shape=(256,256,256), create_tensor=False):
    if scn is None or not create_tensor:
        batch = {
            "indices": [],
            "data": [],
            "truth": [],
            "id": []
            }

        def add_sample(indices, features, truth, id):
            batch["indices"].append(indices)
            batch["data"].append(features)
            batch["truth"].append(truth)
            batch["id"].append(id)

    else:
        inputSpatialSize = torch.LongTensor(input_shape)
        batch = {
            "data": scn.InputBatch(3, inputSpatialSize),
            "truth": scn.InputBatch(3, inputSpatialSize),
            "id": []
        }

        def add_sample(indices, features, truth, id):
            batch["data"].addSample()
            batch["truth"].addSample()
            indices = torch.from_numpy(indices).long()
            try:
                batch["data"].setLocations(indices, torch.from_numpy(features).float(), 0) #Use 1 to remove duplicate coords?
                batch["truth"].setLocations(indices, torch.from_numpy(truth).float(), 0)
                batch["id"].append(id)
            except AssertionError:
                #PDB didn't fit in grid?
                pass
            #del features
            #del truth

    sample_weights = []
    batch_weight = None
    num_data = 0.0
    for i, d in enumerate(data):
        if d["data"] is None: continue
        if batch_weight is None:
            batch_weight = 0.0 if d["truth"].shape[1] == 2 else np.zeros(data[0]["truth"].shape[1])

        add_sample(d["indices"], d["data"], d["truth"], d["id"])
        if d["truth"].shape[1] == 2:
            num_true = np.sum(d["truth"][:, 0])
            true_prob = num_true/float(d["truth"].shape[0])
            sample_weights.append(np.array((1-true_prob, true_prob)))
            batch_weight += num_true
            num_data += d["truth"].shape[0]
        else:
            num_true = np.sum(d["truth"], axis=0)
            batch_weight += num_true
            sample_weights.append(num_true/float(d["truth"].shape[0]))
            num_data += d["truth"].shape[0]

    batch_weight /= float(num_data)

    #print "Made batch"
    if create_tensor:
        batch["data"].precomputeMetadata(1)

    if isinstance(batch_weight, float):
        batch["sample_weights"] = np.array(sample_weights)
        batch["weight"] = np.array([1-batch_weight, batch_weight]) #None #1.-float(num_true)/len(data) #(256*256*256)
    else:
        batch["sample_weights"] = np.array(sample_weights)
        batch["weight"] = batch_weight
    return batch
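sparse_collate is shaped like a PyTorch collate function, so one plausible way to wire it up is as the collate_fn of a DataLoader. A sketch, assuming `dataset` yields dicts with the "indices", "data", "truth" and "id" fields consumed above:

from functools import partial
from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=20,
                    shuffle=True,
                    collate_fn=partial(sparse_collate, create_tensor=True))
for batch in loader:
    inputs, truth = batch["data"], batch["truth"]   # scn.InputBatch objects when create_tensor=True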
Example #15
def test(model_file,
         ibis_data,
         input_shape=(512, 512, 512),
         only_aa=False,
         only_atom=False,
         expand_atom=False,
         num_workers=None,
         batch_size=20,
         shuffle=True,
         use_gpu=True,
         data_split=0.8,
         test_full=False,
         no_batch_norm=False):

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1
    print "Using {} workers".format(num_workers)

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        datasets = SphereDataset.get_training_and_validation(input_shape,
                                                             cnt=1,
                                                             n_samples=1000,
                                                             data_split=0.99)
        nFeatures = 1
        validation_batch_size = 1
        input_shape = (96, 96, 96)
    elif os.path.isfile(ibis_data):
        datasets = IBISDataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            expand_atom=expand_atom,
            data_split=data_split,
            train_full=test_full,
            validate_full=test_full)
        if only_atom:
            nFeatures = 5
        elif only_aa:
            nFeatures = 21
        else:
            nFeatures = 59
        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    dataloader = datasets["val"].get_data_loader(
        batch_size if datasets["val"].train else validation_batch_size,
        shuffle, num_workers)

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    model = UNet3D(nFeatures, 1, batchnorm=not no_batch_norm)
    model.type(dtype)

    if not os.path.isfile(model_file):
        raise IOError("Model cannot be opened")

    model.load_state_dict(torch.load(model_file))
    model.train(False)  # Set model to evaluate mode

    criterion = DiceLoss()

    inputSpatialSize = torch.LongTensor(input_shape)

    stats = ModelStats()

    print "Starting Test..."

    for data_iter_num, data in enumerate(dataloader):
        if data["data"].__class__.__name__ == "InputBatch":
            sparse_input = True
            inputs = data["data"]
            labels = data["truth"]
            if use_gpu:
                inputs = inputs.cuda().to_variable(requires_grad=True)
                labels = labels.cuda().to_variable()
            else:
                inputs = inputs.to_variable(requires_grad=True)
                labels = labels.to_variable()

        elif isinstance(data["data"], (list, tuple)):
            sparse_input = True
            inputs = scn.InputBatch(3, inputSpatialSize)
            labels = scn.InputBatch(3, inputSpatialSize)

            for sample, (indices, features, truth) in enumerate(
                    izip(data["indices"], data["data"], data["truth"])):
                inputs.addSample()
                labels.addSample()

                indices = torch.LongTensor(indices)
                features = torch.FloatTensor(features)
                truth = torch.FloatTensor(truth)

                try:
                    inputs.setLocations(indices, features,
                                        0)  #Use 1 to remove duplicate coords?
                    labels.setLocations(indices, truth, 0)
                except AssertionError:
                    #PDB didn't fit in grid?
                    continue

            del data

            inputs.precomputeMetadata(1)

            if use_gpu:
                inputs = inputs.cuda().to_variable(requires_grad=True)
                labels = labels.cuda().to_variable()
            else:
                inputs = inputs.to_variable(requires_grad=True)
                labels = labels.to_variable()

        elif isinstance(data["data"], torch.FloatTensor):
            #Input is dense
            print "Input is Dense"
            sparse_input = False
            # Build the variables first, then (optionally) move them to the GPU.
            inputs = Variable(data["data"], requires_grad=True)
            labels = Variable(data["truth"])
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = scn.DenseToSparse(3)(inputs)
            try:
                inputs = inputs.cuda().to_variable(requires_grad=True)
            except:
                pass

        else:
            raise RuntimeError("Invalid data from dataset")

        outputs = model(inputs)

        if sparse_input:
            loss = criterion(outputs.features, labels.features)

            if math.isnan(loss.data[0]):
                print "Loss is Nan?"
                import pdb
                pdb.set_trace()

            stats.update(outputs.features.data, labels.features.data,
                         loss.data[0])
        else:
            outputs = scn.SparseToDense(3, 1)(outputs)
            loss = criterion(outputs.cpu(), labels.cpu())
            stats.update(outputs.data.cpu().view(-1),
                         labels.data.cpu().view(-1), loss.data[0])

        print "Batch {}: corrects:{:.2f}% nll:{:.2f}% dice:{:.4f}% time:{:.1f}s".format(
            data_iter_num, stats.correctpct(), stats.nllpct(),
            loss.data[0] * -100,
            time.time() - since)

        save_batch_prediction(outputs)

        stats.save("val", 0)
    stats.plot_final()
    time_elapsed = time.time() - since
    print 'Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed / 60,
                                                       time_elapsed % 60)