示例#1
0
    def _make_transpose(self, transblock, planes, blocks, stride=1):
        """Build a stage of ``blocks`` transposed residual blocks.

        The first ``blocks - 1`` blocks keep the current width
        (``self.inplanes * transblock.expansion``). The final block maps to
        ``planes`` and receives ``stride`` plus an optional ``upsample``
        shortcut module.

        Args:
            transblock: block class, called as
                ``transblock(inplanes, planes[, stride, upsample])``; must
                expose an ``expansion`` attribute.
            planes: output channel count of the stage.
            blocks: total number of blocks in the stage.
            stride: stride of the last block. When it is not 1, the shortcut
                densifies, applies a kernel-2 ``nn.ConvTranspose2d`` and
                re-sparsifies.

        Returns:
            An ``scn.Sequential`` of the blocks. As a side effect,
            ``self.inplanes`` is set to ``planes // transblock.expansion``.
        """
        in_channels = self.inplanes * transblock.expansion

        if stride != 1:
            # Learned upsampling shortcut: sparse -> dense -> deconv -> sparse.
            upsample = scn.Sequential(
                scn.SparseToDense(2, in_channels),
                nn.ConvTranspose2d(in_channels, planes, kernel_size=2,
                                   stride=stride, padding=0, bias=False),
                scn.DenseToSparse(2),
                scn.BatchNormalization(planes),
            )
        elif in_channels != planes:
            # Channel-only shortcut projection.
            upsample = scn.Sequential(
                scn.NetworkInNetwork(in_channels, planes, False),
                scn.BatchNormalization(planes),
            )
        else:
            upsample = None

        stage = [transblock(self.inplanes, in_channels)
                 for _ in range(blocks - 1)]
        stage.append(transblock(self.inplanes, planes, stride, upsample))

        self.inplanes = planes // transblock.expansion
        return scn.Sequential(*stage)
示例#2
0
 def __init__(self, num_classes=5):
     """Build a 2-D sparse ResNet feature extractor plus a linear classifier.

     Args:
         num_classes: number of outputs of ``self.linear``.
     """
     nn.Module.__init__(self)
     # Stage specs for scn.SparseResNet; presumably
     # [block type, planes, repetitions, stride] -- confirm against scn.
     resnet_stages = [
         ['b', 8, 3, 1],
         ['b', 16, 2, 2],
         ['b', 24, 2, 2],
         ['b', 32, 2, 2],
     ]
     # Modules are listed in the same order the original .add() chain used,
     # so construction order and the "0".."6" submodule keys are unchanged.
     self.sparseModel = scn.Sequential(
         scn.DenseToSparse(2),
         scn.ValidConvolution(2, 3, 8, 2, False),
         scn.MaxPooling(2, 4, 2),
         scn.SparseResNet(2, 8, resnet_stages),
         scn.Convolution(2, 32, 64, 4, 1, False),
         scn.BatchNormReLU(64),
         scn.SparseToDense(2, 64),
     )
     # NOTE(review): 6400 presumably equals 64 dense channels x 10 x 10
     # spatial cells for the expected input size -- confirm.
     self.linear = nn.Linear(6400, num_classes)
    def normalize_input(self, x, ones_conv):
        """Normalize a 2-D sparse activation by its count of nearby active sites.

        Implements sparsity-invariant convolution normalization
        (http://www.cvlibs.net/publications/Uhrig2017THREEDV.pdf). The steps
        are:

        1. Densify ``x``.
        2. Build a mask of sites with any non-zero channel.
        3. Divide by ``ones_conv(mask) + 1e-5``.
        4. Convert back to a sparse tensor.

        The dense round-trip is a workaround, not an efficient implementation.

        Args:
            x: sparse tensor exposing ``features``; its last feature dimension
                is the channel count.
            ones_conv: callable applied to the (N, 1, ...) float mask.

        Returns:
            The normalized activation as a sparse tensor.
        """
        n_channels = x.features.shape[-1]
        to_dense = scn.SparseToDense(2, n_channels)
        to_sparse = scn.DenseToSparse(2)

        dense = to_dense(x)
        # A site is active when the absolute sum over channels is positive.
        mask = (dense.abs().sum(1, keepdim=True) > 0).float()
        normalizer = ones_conv(mask) + 1e-5
        return to_sparse(dense / normalizer)
示例#4
0
    def forward(self, x):
        """Map a batch of flat input vectors to a flattened 88 x 168 output.

        Column groups of ``x`` go through separate MLPs:

        - columns 0:3 -> ``_mlp_pid``
        - columns 3:6 -> ``_mlp_pos``
        - columns 6:9 -> ``_mlp_dir``
        - column 9 -> ``_mlp_E``

        The outputs are concatenated, passed through ``_mlp``, reshaped to
        64 channels on an 11 x 21 grid, sparsified, and fed to ``_upconvs``.

        Args:
            x: 2-D tensor with at least 10 columns (assumed (batch, features)).

        Returns:
            Tensor viewed as (-1, 88 * 168).
        """
        # x[:, 9:10] keeps the (N, 1) column shape that _mlp_E expects.
        net = torch.cat((
            self._mlp_pid(x[:, 0:3]),
            self._mlp_pos(x[:, 3:6]),
            self._mlp_dir(x[:, 6:9]),
            self._mlp_E(x[:, 9:10]),
        ), 1)

        net = self._mlp(net)

        # 64 channels on an 11 x 21 two-dimensional grid.
        net = net.view(-1, 64, 11, 21)

        # DenseToSparse is a module constructed with the spatial dimension
        # (2 here) and then applied to the dense tensor. The original
        # `scn.DenseToSparse(net)` built a module instead of converting `net`.
        net = scn.DenseToSparse(2)(net)

        # NOTE(review): .view requires _upconvs to end in a dense tensor
        # (e.g. via SparseToDense) -- confirm.
        return self._upconvs(net).view(-1, 88 * 168)
示例#5
0
 def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs):
     """Create the layers of a sparse transposed basic residual block.

     Args:
         inplanes: channels entering the block; ``conv1`` keeps this width.
         planes: channels produced by ``conv2`` and normalized by ``bn2``.
         stride: stride of ``conv2``. When ``upsample`` is given and the
             stride is not 1, ``conv2`` is instead a dense kernel-2
             ``nn.ConvTranspose2d`` wrapped in sparse<->dense conversions.
         upsample: optional shortcut module, stored as ``self.upsample``.
         **kwargs: accepted and ignored.
     """
     super(TransBasicBlockSparse, self).__init__()
     self.conv1 = conv3x3_sparse(inplanes, inplanes)
     self.bn1 = scn.BatchNormReLU(inplanes)
     self.relu = scn.ReLU()

     transposed = stride != 1 and upsample is not None
     if not transposed:
         self.conv2 = conv3x3_sparse(inplanes, planes, stride)
     else:
         # Densify, apply a learned kernel-2 transposed conv, re-sparsify.
         densify = scn.SparseToDense(2, inplanes)
         deconv = nn.ConvTranspose2d(inplanes, planes, kernel_size=2,
                                     stride=stride, padding=0,
                                     output_padding=0, bias=False)
         self.conv2 = scn.Sequential(densify, deconv, scn.DenseToSparse(2))

     self.bn2 = scn.BatchNormalization(planes)
     self.add = scn.AddTable()
     self.upsample = upsample
     self.stride = stride
示例#6
0
 def __init__(self, transblock, layers, num_classes=150):
     """Build the sparse transposed-ResNet decoder and its prediction heads.

     Args:
         transblock: block class handed to ``_make_transpose``; its
             ``expansion`` attribute scales the channel counts below.
         layers: block counts for the four decoder stages; only
             ``layers[0]`` .. ``layers[3]`` are read.
         num_classes: output channels of every prediction head.

     NOTE(review): ``_make_transpose`` reads ``self.inplanes`` (and, in the
     ``_make_transpose`` implementation visible in this file, resets it), so
     the order of the calls and ``self.inplanes`` assignments below is
     significant. Registration order also fixes ``parameters()`` /
     ``state_dict`` ordering -- do not reorder.
     """
     # Starting decoder width consumed by the first _make_transpose call.
     self.inplanes = 512
     super(ResNetTransposeSparse, self).__init__()
     
     self.dense_to_sparse = scn.DenseToSparse(2)
     # Presumably adds a sparse and a dense tensor (project class) -- confirm.
     self.add = AddSparseDense()
     
     # Four stride-2 decoder stages.
     self.deconv1 = self._make_transpose(transblock, 256 * transblock.expansion, layers[0], stride=2)
     self.deconv2 = self._make_transpose(transblock, 128 * transblock.expansion, layers[1], stride=2)
     self.deconv3 = self._make_transpose(transblock, 64 * transblock.expansion, layers[2], stride=2)
     self.deconv4 = self._make_transpose(transblock, 64 * transblock.expansion, layers[3], stride=2)
     
     # Skip projections from (presumably encoder) feature maps with
     # 128/256/512/1024/2048 channels to the matching decoder widths.
     self.skip0 = self._make_skip_layer(128, 64 * transblock.expansion)
     self.skip1 = self._make_skip_layer(256, 64 * transblock.expansion)
     self.skip2 = self._make_skip_layer(512, 128 * transblock.expansion)
     self.skip3 = self._make_skip_layer(1024, 256 * transblock.expansion)
     self.skip4 = self._make_skip_layer(2048, 512 * transblock.expansion)
     
     # Sparse -> dense converters, one per decoder width.
     self.densify0 = scn.SparseToDense(2, 64 * transblock.expansion)
     self.densify1 = scn.SparseToDense(2, 64 * transblock.expansion)
     self.densify2 = scn.SparseToDense(2, 128 * transblock.expansion)
     self.densify3 = scn.SparseToDense(2, 256 * transblock.expansion)
     
     # Reset the width before building the final (stride-1, 3-block) stage.
     self.inplanes = 64
     self.final_conv = self._make_transpose(transblock, 64 * transblock.expansion, 3)
     
     # Densify, then a kernel-2 / stride-2 / padding-0 transposed conv, which
     # doubles each spatial dimension while projecting to num_classes.
     self.final_deconv = scn.Sequential(
             scn.SparseToDense(2, self.inplanes * transblock.expansion),
             nn.ConvTranspose2d(self.inplanes * transblock.expansion, num_classes, kernel_size=2,
                                            stride=2, padding=0, bias=True)
         )            
     
     # Per-scale prediction heads: one dense 1x1 conv on 2048 channels and
     # sparse NetworkInNetwork projections on the decoder widths.
     self.out6_conv = nn.Conv2d(2048, num_classes, kernel_size=1, stride=1, bias=True)
     self.out5_conv = scn.NetworkInNetwork(256 * transblock.expansion, num_classes, True)
     self.out4_conv = scn.NetworkInNetwork(128 * transblock.expansion, num_classes, True)
     self.out3_conv = scn.NetworkInNetwork(64 * transblock.expansion, num_classes, True)
     self.out2_conv = scn.NetworkInNetwork(64 * transblock.expansion, num_classes, True)
     
     self.sparse_to_dense = scn.SparseToDense(2, num_classes)
示例#7
0
def infer(model, data, input_shape=(256,256,256), use_gpu=True, course_grained=False):
    """Run ``model`` on one batch and, when ground truth is present, score it.

    Args:
        model: network applied to the prepared input batch.
        data: batch dict. ``data["data"]`` is one of:

            - an ``InputBatch`` (then ``data["truth"]`` is optional);
            - a list/tuple of per-sample feature arrays, with matching
              ``data["indices"]`` and optional ``data["truth"]``;
            - a dense ``torch.FloatTensor``.

            Optional ``"weight"`` / ``"sample_weights"`` entries are numpy
            arrays.
        input_shape: spatial size of the sparse grid built for list input.
        use_gpu: move inputs, labels and weights to CUDA when true.
        course_grained: unused; kept for interface compatibility.

    Returns:
        ``outputs`` when there is no ground truth, otherwise
        ``(outputs, mlog)`` with the loss/accuracy meters updated.

    Raises:
        RuntimeError: unsupported ``data["data"]`` type or sample datatype.
    """
    inputSpatialSize = torch.LongTensor(input_shape)
    labels = None

    mlog = MeterLogger(nclass=2, title="Sparse 3D UNet Inference")

    phase = "test"
    epoch = 0

    batch_weight = data.get("weight", None)
    if batch_weight is not None:
        batch_weight = torch.from_numpy(batch_weight).float()
        if use_gpu:
            batch_weight = batch_weight.cuda()
    sample_weights = data.get("sample_weights", None)
    if sample_weights is not None:
        sample_weights = torch.from_numpy(sample_weights).float()
        if use_gpu:
            sample_weights = sample_weights.cuda()

    # Size averaging is hard-coded off, so the loss is always weighted by
    # batch_weight. This was previously defined only when sample_weights
    # existed, which raised NameError at the loss otherwise.
    use_size_average = False
    weight = sample_weights if use_size_average else batch_weight

    batch_data = data["data"]
    if batch_data.__class__.__name__ == "InputBatch":
        inputs = batch_data
        labels = data["truth"] if "truth" in data else None
        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()
        inputs = inputs.to_variable(requires_grad=False)
        if labels is not None:
            labels = labels.to_variable()

    elif isinstance(batch_data, (list, tuple)):
        has_truth = "truth" in data
        inputs = scn.InputBatch(3, inputSpatialSize)
        labels = scn.InputBatch(3, inputSpatialSize) if has_truth else None

        if isinstance(batch_data[0], np.ndarray):
            long_tensor = lambda arr: torch.from_numpy(arr).long()
            float_tensor = lambda arr: torch.from_numpy(arr).float()
        elif isinstance(batch_data[0], (list, tuple)):
            long_tensor = lambda arr: torch.LongTensor(arr)
            float_tensor = lambda arr: torch.FloatTensor(arr)
        else:
            raise RuntimeError("invalid datatype")

        # Without ground truth, pair each sample with None instead of
        # indexing the missing "truth" key.
        truths = data["truth"] if has_truth else [None] * len(batch_data)
        for indices, features, truth in izip(data["indices"], batch_data, truths):
            inputs.addSample()
            if labels is not None:
                labels.addSample()

            try:
                indices = long_tensor(indices)
                features = float_tensor(features)
                if labels is not None:
                    truth = float_tensor(truth)
            except RuntimeError as e:
                print(e)
                continue

            inputs.setLocations(indices, features, 0)  # Use 1 to remove duplicate coords?
            if labels is not None:
                labels.setLocations(indices, truth, 0)

        inputs.precomputeMetadata(1)

        if use_gpu:
            inputs = inputs.cuda()
            if labels is not None:
                labels = labels.cuda()

        # Inference only: no gradient w.r.t. the input (matches the
        # InputBatch branch above).
        inputs = inputs.to_variable(requires_grad=False)
        if labels is not None:
            labels = labels.to_variable()

    elif isinstance(batch_data, torch.FloatTensor):
        # Input is dense. Move the raw tensors (not the not-yet-created
        # `inputs`) to the GPU, then sparsify.
        print("Input is Dense")
        dense_inputs = batch_data
        dense_labels = data.get("truth")
        if use_gpu:
            dense_inputs = dense_inputs.cuda()
            if dense_labels is not None:
                dense_labels = dense_labels.cuda()
        inputs = scn.DenseToSparse(3)(Variable(dense_inputs, requires_grad=True))
        # NOTE(review): the loss below reads labels.features, which a dense
        # Variable does not have -- the dense path looks unfinished.
        labels = Variable(dense_labels) if dense_labels is not None else None

    else:
        raise RuntimeError("Invalid data from dataset")

    # forward
    try:
        outputs = model(inputs)
    except AssertionError as e:
        print(e)
        raise

    if labels is None:
        return outputs

    # One-hot label features -> class indices.
    targets = torch.max(labels.features, 1)[1]
    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
    loss = loss_fn(outputs, targets)

    mlog.update_loss(loss, meter='loss')
    mlog.update_meter(outputs, targets, meters={'accuracy', 'map'})
    add_to_logger(mlog, "Train" if phase == "train" else "Test", epoch, outputs, labels.features, batch_weight)

    return outputs, mlog
示例#8
0
def train(ibis_data,
          input_shape=(264, 264, 264),
          model_prefix=None,
          check_point=True,
          save_final=True,
          only_aa=False,
          only_atom=False,
          non_geom_features=False,
          use_deepsite_features=False,
          expand_atom=False,
          num_workers=None,
          num_epochs=30,
          batch_size=20,
          shuffle=True,
          use_gpu=True,
          initial_learning_rate=0.0001,
          learning_rate_drop=0.5,
          learning_rate_epochs=10,
          lr_decay=4e-2,
          data_split=0.8,
          course_grained=False,
          no_batch_norm=False,
          use_resnet_unet=False,
          unclustered=False,
          undersample=False,
          oversample=False,
          nFeatures=None,
          allow_feature_combos=False,
          bs_feature=None,
          bs_feature2=None,
          bs_features=None,
          stripes=False,
          data_parallel=False,
          dropout_depth=False,
          dropout_width=False,
          dropout_p=0.5,
          wide_model=False,
          cellular_organisms=False,
          autoencoder=False,
          checkpoint_callback=None):
    """Train a sparse 3-D segmentation network and return it.

    Data source:

    - ``ibis_data == "spheres"`` builds a synthetic ``SphereDataset``.
    - Any other value must be the path of an existing file, loaded with
      ``IBISDataset.get_training_and_validation``.

    Each epoch runs a ``"train"`` phase and a ``"val"`` phase. The train
    phase does a scheduler step, a backward pass and an optimizer step.
    After every phase:

    - the meters are graphed with ``graph_logger`` and reset;
    - when ``check_point`` is true, the epoch number and model weights are
      saved to ``"{model_prefix}_checkpoint_epoch.pth"`` /
      ``"{model_prefix}_checkpoint_model.pth"``.

    If both checkpoint files already exist, training resumes by re-running
    the saved epoch. The final weights are always saved to
    ``"{model_prefix}.pth"``.

    Selected args:
        model_prefix: output path prefix; defaults to a timestamped
            ``./molmimic_model_<date>`` name.
        num_workers: data-loader worker processes; ``None`` means 6.
        nFeatures: for "spheres", the number of input features (default 3).
            For file data it only enables random-feature selection, and is
            then replaced by the dataset's feature count.
        autoencoder: with file data, predict ``nFeatures`` classes
            instead of 2.
        checkpoint_callback: optional callable invoked as
            ``(epoch, statsfile, graphs, epoch_file, model_file)`` after every
            phase and once at the end.
        save_final, expand_atom, unclustered, learning_rate_drop,
        learning_rate_epochs: accepted but not used by this function.

    Returns:
        The trained model.

    Raises:
        RuntimeError: ``ibis_data`` is neither "spheres" nor an existing file,
            or a batch has an unsupported format.
    """
    if model_prefix is None:
        model_prefix = "./molmimic_model_{}".format(
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))

    # Previously a cpu_count()-based value was computed and then
    # unconditionally overwritten with 6, so the argument was ignored.
    # Keep 6 as the default but honour an explicit value.
    if num_workers is None:
        num_workers = 6

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        nFeatures = nFeatures or 3
        datasets = SphereDataset.get_training_and_validation(
            input_shape,
            cnt=1,
            n_samples=1000,
            nFeatures=nFeatures,
            allow_feature_combos=allow_feature_combos,
            bs_feature=bs_feature,
            bs_feature2=bs_feature2,
            bs_features=bs_features,
            stripes=stripes,
            data_split=0.99)
        validation_batch_size = 1
        nClasses = len(bs_features) + 1 if bs_features is not None else 2
    elif os.path.isfile(ibis_data):
        print("{} {}".format(allow_feature_combos, nFeatures))
        if nFeatures is not None:
            # Random feature selection; combos only when explicitly allowed.
            random_features = (nFeatures, allow_feature_combos or False,
                               bs_feature, bs_feature2)
        else:
            random_features = None
            if allow_feature_combos:
                print("ignoring --allow-feature-combos")

        datasets = IBISDataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            non_geom_features=non_geom_features,
            use_deepsite_features=use_deepsite_features,
            data_split=data_split,
            course_grained=course_grained,
            oversample=oversample,
            undersample=undersample,
            cellular_organisms=cellular_organisms,
            random_features=random_features)
        nFeatures = datasets["train"].get_number_of_features()
        nClasses = nFeatures if autoencoder else 2

        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    dataloaders = {
        name: split.get_data_loader(
            batch_size if split.train else validation_batch_size,
            shuffle,
            num_workers)
        for name, split in datasets.items()}

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'

    if use_resnet_unet:
        model = ResNetUNet(nFeatures,
                           nClasses,
                           dropout_depth=dropout_depth,
                           dropout_width=dropout_width,
                           dropout_p=dropout_p,
                           wide_model=wide_model)
    else:
        model = UNet3D(nFeatures, nClasses, batchnorm=not no_batch_norm)

    if data_parallel:
        model = torch.nn.DataParallel(model)

    model.type(dtype)

    optimizer = SGD(model.parameters(),
                    lr=initial_learning_rate,
                    momentum=0.999,
                    weight_decay=1e-4,
                    nesterov=True)

    scheduler = LambdaLR(optimizer, lambda epoch: math.exp(
        (1 - epoch) * lr_decay))

    check_point_model_file = "{}_checkpoint_model.pth".format(model_prefix)
    check_point_epoch_file = "{}_checkpoint_epoch.pth".format(model_prefix)
    if check_point and os.path.isfile(
            check_point_model_file) and os.path.isfile(check_point_epoch_file):
        start_epoch = torch.load(check_point_epoch_file)
        # The epoch loop starts *at* start_epoch (re-running the saved epoch),
        # so report that epoch rather than start_epoch + 1.
        print("Restarting at epoch {} from {}".format(start_epoch,
                                                      check_point_model_file))
        model.load_state_dict(torch.load(check_point_model_file))
    else:
        start_epoch = 0

    inputSpatialSize = torch.LongTensor(input_shape)

    draw_graph = True

    mlog = MeterLogger(nclass=nClasses,
                       title="Sparse 3D UNet",
                       server="cn4216")

    # Used by the final report if the loop below does not run at all
    # (num_epochs <= start_epoch); otherwise the loop overwrites both.
    epoch, phase = start_epoch, "val"

    for epoch in range(start_epoch, num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        mlog.timer.reset()
        for phase in ['train', 'val']:
            datasets[phase].epoch = epoch
            num_batches = int(
                np.ceil(
                    len(datasets[phase]) /
                    float(batch_size if phase ==
                          "train" else validation_batch_size)))

            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            # Iterate over data.
            bar = tqdm(enumerate(dataloaders[phase]),
                       total=num_batches,
                       unit="batch",
                       desc="Loading data",
                       leave=True)
            for data_iter_num, data in bar:
                datasets[phase].batch = data_iter_num
                batch_weight = data.get("weight", None)

                if batch_weight is not None:
                    batch_weight = torch.from_numpy(batch_weight).float()
                    if use_gpu:
                        batch_weight = batch_weight.cuda()

                sample_weights = data.get("sample_weights", None)

                if sample_weights is not None:
                    sample_weights = torch.from_numpy(sample_weights).float()
                    if use_gpu:
                        sample_weights = sample_weights.cuda()

                batch_data = data["data"]
                if batch_data.__class__.__name__ == "InputBatch":
                    sparse_input = True
                    inputs = batch_data
                    labels = data["truth"]
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()
                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()

                elif isinstance(batch_data, (list, tuple)):
                    sparse_input = True
                    inputs = scn.InputBatch(3, inputSpatialSize)
                    labels = scn.InputBatch(3, inputSpatialSize)

                    if isinstance(batch_data[0], np.ndarray):
                        long_tensor = lambda arr: torch.from_numpy(arr).long()
                        float_tensor = lambda arr: torch.from_numpy(arr).float()
                    elif isinstance(batch_data[0], (list, tuple)):
                        long_tensor = lambda arr: torch.LongTensor(arr)
                        float_tensor = lambda arr: torch.FloatTensor(arr)
                    else:
                        raise RuntimeError("invalid datatype")

                    for indices, features, truth, pdb_id in izip(
                            data["indices"], batch_data, data["truth"],
                            data["id"]):
                        inputs.addSample()
                        labels.addSample()

                        try:
                            indices = long_tensor(indices)
                            features = float_tensor(features)
                            truth = float_tensor(truth)
                        except RuntimeError as e:
                            print(e)
                            continue

                        try:
                            inputs.setLocations(
                                indices, features,
                                0)  #Use 1 to remove duplicate coords?
                            labels.setLocations(indices, truth, 0)
                        except AssertionError:
                            print("Error with PDB: {}".format(pdb_id))
                            with open("bad_pdbs.txt", "a") as f:
                                f.write("{}\n".format(pdb_id))

                    # Drop the raw per-sample arrays before the forward pass.
                    del data, batch_data

                    inputs.precomputeMetadata(1)

                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()

                elif isinstance(batch_data, torch.FloatTensor):
                    # Input is dense. Move the raw tensors (not the
                    # not-yet-created `inputs`) to the GPU, then sparsify.
                    print("Input is Dense")
                    sparse_input = False
                    dense_inputs = batch_data
                    dense_labels = data["truth"]
                    if use_gpu:
                        dense_inputs = dense_inputs.cuda()
                        dense_labels = dense_labels.cuda()
                    inputs = scn.DenseToSparse(3)(
                        Variable(dense_inputs, requires_grad=True))
                    labels = Variable(dense_labels)

                else:
                    raise RuntimeError("Invalid data from dataset")

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                try:
                    outputs = model(inputs)
                except AssertionError:
                    print("{} {}".format(nFeatures, inputs))
                    raise

                if sparse_input:
                    # Size averaging is hard-coded off, so class weights
                    # always come from batch_weight.
                    use_size_average = False
                    weight = sample_weights if use_size_average else batch_weight

                    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

                    loss = loss_fn(outputs, torch.max(labels.features, 1)[1])

                    if draw_graph:
                        var_dot = dot.make_dot(loss)
                        var_dot.render('SparseUnet3dCNN_graph.pdf')
                        draw_graph = False
                        del var_dot

                else:
                    outputs = scn.SparseToDense(3, 1)(outputs)
                    # Bound to loss_fn so the cleanup below works on this
                    # path too. The former `stats.update(...)` referenced an
                    # undefined name and was removed.
                    loss_fn = DiceLoss(size_average=False)
                    loss = loss_fn(outputs.cpu(), labels.cpu())
                    # NOTE(review): the meters below read labels.features,
                    # which a dense Variable lacks -- dense path unfinished.

                mlog.update_loss(loss, meter='loss')
                mlog.update_meter(outputs,
                                  torch.max(labels.features, 1)[1],
                                  meters={'accuracy', 'map'})
                add_to_logger(mlog,
                              "Train" if phase == "train" else "Test",
                              epoch,
                              outputs,
                              labels.features,
                              batch_weight,
                              n_classes=nClasses)

                # backward + optimize only if in training phase
                if phase == 'train':
                    a = list(model.parameters())[0].clone().data
                    loss.backward()
                    optimizer.step()
                    b = list(model.parameters())[0].clone().data
                    if torch.equal(a, b):
                        print("NOT UPDATED")
                    del a
                    del b

                bar.set_description("{}: [{}][{}/{}]".format(
                    phase, epoch, data_iter_num + 1, num_batches))
                bar.set_postfix(loss="{:.4f} ({:.4f})".format(
                    mlog.meter["loss"].val, mlog.meter["loss"].mean),
                                dice_class1="{:.4f} ({:.4f})".format(
                                    mlog.meter["dice_class1"].val,
                                    mlog.meter["dice_class1"].mean),
                                weight_dice="{:.4f} ({:.4f})".format(
                                    mlog.meter["weighted_dice_wavg"].val,
                                    mlog.meter["weighted_dice_wavg"].mean),
                                refresh=False)
                bar.refresh()

                # Release per-batch references before the next batch loads.
                # (The former gc.get_objects() sweep only unbound its own loop
                # variable, so it freed nothing and has been removed.)
                del inputs
                del labels
                del outputs
                del loss
                del loss_fn
                del batch_weight
                del sample_weights

            statsfile, graphs = graph_logger(
                mlog, "Train" if phase == "train" else "Test", epoch)
            mlog.reset_meter(epoch, "Train" if phase == "train" else "Test")

            if check_point:
                torch.save(epoch, check_point_epoch_file)
                torch.save(model.state_dict(), check_point_model_file)
                if callable(checkpoint_callback):
                    checkpoint_callback(epoch, statsfile, graphs,
                                        check_point_epoch_file,
                                        check_point_model_file)
            elif callable(checkpoint_callback):
                checkpoint_callback(epoch, statsfile, graphs, None, None)

    statsfile, graphs = graph_logger(mlog,
                                     "Train" if phase == "train" else "Test",
                                     epoch,
                                     final=True)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed / 60, time_elapsed % 60))

    torch.save(model.state_dict(), "{}.pth".format(model_prefix))

    if callable(checkpoint_callback):
        checkpoint_callback(epoch, statsfile, graphs, check_point_epoch_file,
                            check_point_model_file)

    return model
示例#9
0
def test(model_file,
         ibis_data,
         input_shape=(512, 512, 512),
         only_aa=False,
         only_atom=False,
         expand_atom=False,
         num_workers=None,
         batch_size=20,
         shuffle=True,
         use_gpu=True,
         data_split=0.8,
         test_full=False,
         no_batch_norm=False):
    """Evaluate a saved single-output ``UNet3D`` checkpoint on the "val" split.

    For every batch this function:

    - runs the model;
    - computes ``DiceLoss`` and updates a ``ModelStats``;
    - prints running accuracy/NLL/dice;
    - calls ``save_batch_prediction(outputs)``;
    - calls ``stats.save("val", 0)``.

    ``stats.plot_final()`` runs at the end.

    Args:
        model_file: path of the state dict to load.
        ibis_data: "spheres" for the synthetic dataset, otherwise an existing
            data file for ``IBISDataset``.
        num_workers: data-loader workers; defaults to ``cpu_count() - 1``.
        use_gpu: move batches to CUDA when true (the model's dtype follows
            ``torch.cuda.is_available()``).

    Raises:
        RuntimeError: invalid ``ibis_data`` or unsupported batch format.
        IOError: ``model_file`` does not exist.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1
    print("Using {} workers".format(num_workers))

    since = time.time()

    if ibis_data == "spheres":
        from torch_loader import SphereDataset
        datasets = SphereDataset.get_training_and_validation(input_shape,
                                                             cnt=1,
                                                             n_samples=1000,
                                                             data_split=0.99)
        nFeatures = 1
        validation_batch_size = 1
        # NOTE(review): the sparse grid uses 96^3 although the dataset was
        # built with the caller's input_shape -- confirm this is intended.
        input_shape = (96, 96, 96)
    elif os.path.isfile(ibis_data):
        datasets = IBISDataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            expand_atom=expand_atom,
            data_split=data_split,
            train_full=test_full,
            validate_full=test_full)
        if only_atom:
            nFeatures = 5
        elif only_aa:
            nFeatures = 21
        else:
            nFeatures = 59
        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    val_split = datasets["val"]
    dataloader = val_split.get_data_loader(
        batch_size if val_split.train else validation_batch_size,
        shuffle, num_workers)

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    model = UNet3D(nFeatures, 1, batchnorm=not no_batch_norm)
    model.type(dtype)

    if not os.path.isfile(model_file):
        raise IOError("Model cannot be opened")

    model.load_state_dict(torch.load(model_file))
    model.train(False)  # Set model to evaluate mode

    criterion = DiceLoss()

    inputSpatialSize = torch.LongTensor(input_shape)

    stats = ModelStats()

    print("Starting Test...")

    for data_iter_num, data in enumerate(dataloader):
        batch_data = data["data"]
        if batch_data.__class__.__name__ == "InputBatch":
            sparse_input = True
            inputs = batch_data
            labels = data["truth"]
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = inputs.to_variable(requires_grad=True)
            labels = labels.to_variable()

        elif isinstance(batch_data, (list, tuple)):
            sparse_input = True
            inputs = scn.InputBatch(3, inputSpatialSize)
            labels = scn.InputBatch(3, inputSpatialSize)

            for indices, features, truth in izip(data["indices"], batch_data,
                                                 data["truth"]):
                inputs.addSample()
                labels.addSample()

                indices = torch.LongTensor(indices)
                features = torch.FloatTensor(features)
                truth = torch.FloatTensor(truth)

                try:
                    inputs.setLocations(indices, features,
                                        0)  #Use 1 to remove duplicate coords?
                    labels.setLocations(indices, truth, 0)
                except AssertionError:
                    #PDB didn't fit in grid?
                    continue

            # Drop the raw per-sample arrays before the forward pass.
            del data, batch_data

            inputs.precomputeMetadata(1)

            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = inputs.to_variable(requires_grad=True)
            labels = labels.to_variable()

        elif isinstance(batch_data, torch.FloatTensor):
            # Input is dense. Move the raw tensors (not the not-yet-created
            # `inputs`) to the GPU, then sparsify.
            print("Input is Dense")
            sparse_input = False
            dense_inputs = batch_data
            dense_labels = data["truth"]
            if use_gpu:
                dense_inputs = dense_inputs.cuda()
                dense_labels = dense_labels.cuda()
            inputs = scn.DenseToSparse(3)(
                Variable(dense_inputs, requires_grad=True))
            labels = Variable(dense_labels)

        else:
            raise RuntimeError("Invalid data from dataset")

        outputs = model(inputs)

        if sparse_input:
            loss = criterion(outputs.features, labels.features)

            if math.isnan(loss.data[0]):
                # Report and keep evaluating. The former hard-coded
                # pdb.set_trace() hung non-interactive runs.
                print("Loss is Nan?")

            stats.update(outputs.features.data, labels.features.data,
                         loss.data[0])
        else:
            outputs = scn.SparseToDense(3, 1)(outputs)
            loss = criterion(outputs.cpu(), labels.cpu())
            stats.update(outputs.data.cpu().view(-1),
                         labels.data.cpu().view(-1), loss.data[0])

        print("Batch {}: corrects:{:.2f}% nll:{:.2f}% dice:{:.4f}% time:{:.1f}s".format(
            data_iter_num, stats.correctpct(), stats.nllpct(),
            loss.data[0] * -100,
            time.time() - since))

        save_batch_prediction(outputs)

        stats.save("val", 0)
    stats.plot_final()
    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed / 60,
                                                       time_elapsed % 60))