Example #1
import torch
import torch.nn as nn
import sparseconvnet as scn
from data import get_iterators  # assumed import path for the dataset helper used below


class Model(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        # 2D sparse VGG-style backbone followed by a dense head
        self.sparseModel = scn.Sequential(
            # SparseVggNet call opening is reconstructed (2D, 3 input planes assumed)
            scn.SparseVggNet(2, 3, [
                ['C', 16, 8], 'MP', ['C', 24, 8], ['C', 24, 8], 'MP']),
            scn.Convolution(2, 32, 64, 5, 1, False), scn.BatchNormReLU(64),
            scn.SparseToDense(2, 64))
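        # Input grid size needed for the network to produce a 1x1 spatial output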
        self.spatial_size = self.sparseModel.input_spatial_size(
            torch.LongTensor([1, 1]))
        self.inputLayer = scn.InputLayer(2, self.spatial_size, 2)
        self.linear = nn.Linear(64, 183)

    def forward(self, x):
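        # x is a [coordinates, features] pair; the input layer builds a sparse tensor from it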
        x = self.inputLayer(x)
        x = self.sparseModel(x)
        x = x.view(-1, 64)
        x = self.linear(x)
        return x


model = Model()
scale = 63
dataset = get_iterators(model.spatial_size, scale)
print('Input spatial size:', model.spatial_size, 'Data scale:', scale)

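# SparseConvNet's built-in classification training/validation helper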
scn.ClassificationTrainValidate(
    model, dataset, {
        'n_epochs': 100,
        'initial_lr': 0.1,
        'lr_decay': 0.05,
        'weight_decay': 1e-4,
        'use_cuda': torch.cuda.is_available(),
        'check_point': False,
    })
Example #2
import numpy as np
import pandas as pd
import torch
import sparseconvnet as scn
from sklearn.model_selection import train_test_split

# Project-local modules (import paths assumed):
# from datasets import NeuroMorpho, SparseMerge
# from models import SparseResNet2D


def main():

    root_dir = '~/Dropbox/lib/deep_neuro_morpho/data'
    data_dir = 'png_mip_256_fit_2d'
    classes = np.arange(6)
    n_classes = len(classes)
    batch_size = 64

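    # Metadata table mapping each neuron id to its class label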
    metadata = pd.read_pickle(
        '../data/rodent_3d_dendrites_br-ct-filter-3_all_mainclasses_use_filter.pkl'
    )
    metadata = metadata[metadata['label1_id'].isin(classes)]
    neuron_ids = metadata['neuron_id'].values
    labels = metadata['label1_id'].values  # contains the same set of values as classes
    unique, counts = np.unique(labels, return_counts=True)

    transform_train = None

    transform_test = None

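    # Stratified splits: hold out 15% for test, then 15% of the remainder for validation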
    train_ids, test_ids, train_y, test_y = \
        train_test_split(neuron_ids, labels, test_size=0.15, random_state=42, stratify=labels)

    train_ids, val_ids, train_y, val_y = \
        train_test_split(train_ids, train_y, test_size=0.15, random_state=42, stratify=train_y)

    kwargs = {'num_workers': 4, 'pin_memory': True}
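    # SparseMerge() is the project's collate_fn; presumably it merges per-sample
    # sparse coordinates/features into one batched sparse input for scn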
    train_loader = torch.utils.data.DataLoader(
        NeuroMorpho(root_dir, data_dir, train_ids, train_y,
                    img_size=256, transform=transform_train, rgb=False),
        collate_fn=SparseMerge(),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

    val_loader = torch.utils.data.DataLoader(
        NeuroMorpho(root_dir, data_dir, val_ids, val_y,
                    img_size=256, transform=transform_test, rgb=False),
        collate_fn=SparseMerge(),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

    test_loader = torch.utils.data.DataLoader(
        NeuroMorpho(root_dir, data_dir, test_ids, test_y,
                    img_size=256, transform=transform_test, rgb=False),
        collate_fn=SparseMerge(),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

    model = SparseResNet2D(n_classes)
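    # scn.ClassificationTrainValidate expects a dict with 'train' and 'val' loaders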
    dataset = {'train': train_loader, 'val': val_loader}
    print('Input spatial size:', model.spatial_size)

    scn.ClassificationTrainValidate(
        model, dataset, {
            'n_epochs': 100,
            'initial_lr': 0.1,
            'lr_decay': 0.05,
            'weight_decay': 1e-4,
            'use_cuda': torch.cuda.is_available(),
            'check_point': False,
        })