Example #1
import copy
import random

import numpy as np
import torch
from torch import autograd, optim
from tqdm import tqdm

# Project-specific helpers (w, UNetClassify, get_loss, OneraPreloader,
# find_clusters, evaluate_combined) and the module-level globals
# (data_dir, train_csv, test_csv, input_size, bands, batch_size) are
# assumed to come from the surrounding repository.
def fit(epochs, verbose=False, layers=6, lr=0.001, init_filters=16, loss='focal', init_val=0.5):
    net = w(UNetClassify(layers=layers, init_filters=init_filters, init_val=init_val))
    criterion = get_loss(loss)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    train_dataset = OneraPreloader(data_dir, train_csv, input_size, bands, None, None, None)
    train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

    test_dataset = OneraPreloader(data_dir, test_csv, input_size, bands, None, None, None)
    test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

    best_iou = -1.0
    best_net_dict = None
    best_epoch = -1
    best_loss = 1000.0

    for epoch in tqdm(range(epochs)):
        net.train()
        train_losses = []
        for batch, labels in train:
            batch = w(autograd.Variable(batch))
            labels = w(autograd.Variable(labels))

            optimizer.zero_grad()
            output = net(batch)
            loss = criterion(output, labels.view(-1, 1, 32, 32).float())
            loss.backward()
            train_losses.append(loss.item())

            optimizer.step()
        print('train loss', np.mean(train_losses))

        net.eval()
        losses = []
        iou = []
        to_show = random.randint(0, len(test) - 1)  # unused here; left over from the commented-out visualization below
        for batch, labels_true in test:
            labels = w(autograd.Variable(labels_true))
            batch = w(autograd.Variable(batch))
            output = net(batch)
            loss = criterion(output, labels.view(-1, 1, 32, 32).float())
            losses += [loss.item()] * batch.size()[0]
            result = (torch.sigmoid(output).data.cpu().numpy() > 0.5).astype(np.uint8)
            for label, res in zip(labels_true, result):
                label = label.cpu().numpy()
#                 plt.imshow(label, cmap='tab20c')
#                 plt.show()
#                 plt.imshow(find_clusters(res), cmap='tab20c')
#                 plt.show()
                iou.append(evaluate_combined(label, find_clusters(res[0])))

        cur_iou = np.mean(iou)
        if cur_iou > best_iou or (cur_iou == best_iou and np.mean(losses) < best_loss):
            best_iou = cur_iou
            best_epoch = epoch
            best_net_dict = copy.deepcopy(net.state_dict())
            best_loss = np.mean(losses)
        print('val loss', np.mean(losses), 'val IoU', np.mean(iou), 'best loss', best_loss, 'best IoU', best_iou)
    return best_iou, best_loss, best_epoch, best_net_dict
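
A short usage sketch for fit (not from the source): the epoch count and checkpoint path below are made-up values, and the module-level globals the function reads (data_dir, train_csv, test_csv, input_size, bands, batch_size) must already be defined.

best_iou, best_loss, best_epoch, best_net_dict = fit(epochs=50, lr=0.001, loss='focal')
print('best IoU {:.4f} (loss {:.4f}) at epoch {}'.format(best_iou, best_loss, best_epoch))
torch.save(best_net_dict, 'best_unet_classify.pt')  # hypothetical checkpoint path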
Example #2
import logging

import torch

# get_train_val_metadata, full_onera_loader and OneraPreloader come from the
# surrounding repository.
def get_loaders(opt):
    """Given user arguments, loads dataset metadata, loads the full Onera
       dataset, defines preloaders and returns train and val dataloaders.

    Parameters
    ----------
    opt : argparse.Namespace
        Object holding the options/flags (accessed as attributes below)

    Returns
    -------
    (DataLoader, DataLoader)
        returns train and val dataloaders

    """
    train_samples, val_samples = get_train_val_metadata(
        opt.dataset_dir, opt.validation_cities, opt.patch_size, opt.stride)
    print('train samples : ', len(train_samples))
    print('val samples : ', len(val_samples))

    logging.info('STARTING Dataset Creation')

    full_load = full_onera_loader(opt.dataset_dir, opt)

    train_dataset = OneraPreloader(opt.dataset_dir, train_samples, full_load,
                                   opt.patch_size, opt.augmentation)
    val_dataset = OneraPreloader(opt.dataset_dir, val_samples, full_load,
                                 opt.patch_size, False)

    logging.info('STARTING Dataloading')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=opt.num_workers)
    return train_loader, val_loader
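
A sketch of how get_loaders might be driven from the command line: the flag names simply mirror the attributes the function reads, and the default values are made up.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', default='../datasets/onera/')
parser.add_argument('--validation_cities', default='paris,rennes')  # hypothetical default
parser.add_argument('--patch_size', type=int, default=32)
parser.add_argument('--stride', type=int, default=16)
parser.add_argument('--augmentation', action='store_true')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=4)
opt = parser.parse_args()

train_loader, val_loader = get_loaders(opt)
print('train batches:', len(train_loader), 'val batches:', len(val_loader))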
Example #3
                # (snippet begins mid-function: the best-so-far model is
                # checkpointed from inside the training loop)
                torch.save(model, weights_dir + '3dconv_seg.pt')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best test Acc: {:.4f}'.format(best_acc))

    model = torch.load(weights_dir + '3dconv_seg.pt')

    return model


# Dataloader and generator initialization
image_datasets = {
    'train':
    OneraPreloader(data_dir, train_csv, input_size, bands, bands_mean,
                   bands_std, bands_max),
    'test':
    OneraPreloader(data_dir, test_csv, input_size, bands, bands_mean,
                   bands_std, bands_max)
}
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=1)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
file_name = __file__.split('/')[-1].split('.')[0]  # script name without directory or extension

# Create model and initialize/freeze weights
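
The dataloaders and dataset_sizes dictionaries built above feed the standard two-phase epoch loop. A minimal sketch, assuming model, criterion, optimizer and num_epochs are defined elsewhere in the original script:

for epoch in range(num_epochs):
    for phase in ['train', 'test']:
        if phase == 'train':
            model.train()
        else:
            model.eval()
        running_loss = 0.0
        for inputs, labels in dataloaders[phase]:
            optimizer.zero_grad()
            # gradients are only tracked during the training phase
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        print(phase, 'loss:', running_loss / dataset_sizes[phase])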
Example #4
import torch
from torch import optim

# w, UNetClassify, get_loss, full_onera_loader, OneraPreloader and
# onera_siamese_loader_late_pooling come from the surrounding repository;
# layers, init_filters, init_val, loss_func, lr, input_size, batch_size and
# epochs are defined earlier in the original script.

# The 13 Sentinel-2 spectral bands used by the Onera dataset
bands = [
    'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09',
    'B10', 'B11', 'B12'
]
data_dir = '../datasets/onera/'
weights_dir = '../weights/onera/'
train_csv = '../datasets/onera/train.csv'
test_csv = '../datasets/onera/test.csv'

net = w(
    UNetClassify(layers=layers, init_filters=init_filters, init_val=init_val))
criterion = get_loss(loss_func)
optimizer = optim.Adam(net.parameters(), lr=lr)

full_load = full_onera_loader(data_dir, bands)
train_dataset = OneraPreloader(data_dir, train_csv, input_size, full_load,
                               onera_siamese_loader_late_pooling)
train = torch.utils.data.DataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4)

test_dataset = OneraPreloader(data_dir, test_csv, input_size, full_load,
                              onera_siamese_loader_late_pooling)
test = torch.utils.data.DataLoader(test_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)

best_iou = -1.0
best_net_dict = None
best_epoch = -1
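
The listing breaks off right after the bookkeeping variables; the loop that would follow mirrors Example #1. A rough sketch, assuming the siamese loader yields image pairs plus a label mask (the three-way unpacking and the two-argument net call are inferred from the loader's name, not confirmed by the source), with imports as in Example #1:

for epoch in tqdm(range(epochs)):
    net.train()
    for image1, image2, labels in train:  # assumed siamese batch layout
        optimizer.zero_grad()
        loss = criterion(net(image1, image2), labels.float())
        loss.backward()
        optimizer.step()

    net.eval()
    ious = []
    with torch.no_grad():
        for image1, image2, labels in test:
            preds = (torch.sigmoid(net(image1, image2)) > 0.5).float()
            # per-batch intersection-over-union of the predicted change masks
            inter = (preds * labels.float()).sum()
            union = ((preds + labels.float()) > 0).float().sum()
            ious.append((inter / union.clamp(min=1)).item())
    cur_iou = float(np.mean(ious))
    if cur_iou > best_iou:
        best_iou, best_epoch = cur_iou, epoch
        best_net_dict = copy.deepcopy(net.state_dict())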