# Example 1
def prepare_model_and_loader(config):
    """Build a pretrained PSE-TAE model and a test DataLoader (Slovenia tile).

    Args:
        config (dict): Run configuration. Keys read here: 'dataset_folder',
            'geomfeat', 'npixel', 'batch_size', 'num_workers', the model
            hyper-parameters ('input_dim', 'mlp1'..'mlp4', 'pooling',
            'n_head', 'd_k', 'dropout', 'T', 'lms', 'positions'),
            'weight_dir', 'device' and 'fold'.

    Returns:
        tuple: (model, dl) — the pretrained model wrapper and the DataLoader
        over the full dataset (with sample ids, since return_id=True).
    """
    # Context manager so the stats file is closed even if unpickling raises
    # (the original `pkl.load(open(...))` leaked the file handle).
    with open(config['dataset_folder'] + '/S2-2019-T33TWM-meanstd.pkl',
              'rb') as f:
        mean_std = pkl.load(f)
    extra = 'geomfeat' if config['geomfeat'] else None

    ##################
    # TOP 20 CLASSES SLOVENIA
    ##################
    dt = PixelSetData(config['dataset_folder'],
                      labels='c_group_co',
                      npixel=config['npixel'],
                      sub_classes=[
                          33200000, 33101060, 33101010, 33101040, 33301010,
                          33304000, 33111023, 33109000, 33103000, 33107000,
                          33101070, 33106042, 33101050, 33101030, 33111022,
                          33101100, 33301040, 33106020, 33106040, 33101080
                      ],
                      norm=mean_std,
                      extra_feature=extra,
                      return_id=True)
    dl = data.DataLoader(dt,
                         batch_size=config['batch_size'],
                         num_workers=config['num_workers'])

    model_config = dict(input_dim=config['input_dim'],
                        mlp1=config['mlp1'],
                        pooling=config['pooling'],
                        mlp2=config['mlp2'],
                        n_head=config['n_head'],
                        d_k=config['d_k'],
                        mlp3=config['mlp3'],
                        dropout=config['dropout'],
                        T=config['T'],
                        len_max_seq=config['lms'],
                        # 'bespoke' uses the dataset's own acquisition dates
                        # for the positional encoding instead of defaults.
                        positions=dt.date_positions
                        if config['positions'] == 'bespoke' else None,
                        mlp4=config['mlp4'])

    # Geometric features, when enabled, are a fixed 4-element extra vector.
    if config['geomfeat']:
        model_config.update(with_extra=True, extra_size=4)
    else:
        model_config.update(with_extra=False, extra_size=None)

    model = PseTae_pretrained(config['weight_dir'],
                              model_config,
                              device=config['device'],
                              fold=config['fold'])

    return model, dl
# Example 2
def prepare_model_and_loader(config):
    """Build a pretrained PSE-TAE model and a test DataLoader (France tile).

    Args:
        config (dict): Run configuration. Keys read here: 'dataset_folder',
            'geomfeat', 'npixel', 'batch_size', 'num_workers', the model
            hyper-parameters ('input_dim', 'mlp1'..'mlp4', 'pooling',
            'n_head', 'd_k', 'dropout', 'T', 'lms', 'positions'),
            'weight_dir', 'device' and 'fold'.

    Returns:
        tuple: (model, dl) — the pretrained model wrapper and the DataLoader
        over the full dataset (with sample ids, since return_id=True).
    """
    # Context manager so the stats file is closed even if unpickling raises
    # (the original `pkl.load(open(...))` leaked the file handle).
    with open(config['dataset_folder'] + '/S2-2017-T31TFM-meanstd.pkl',
              'rb') as f:
        mean_std = pkl.load(f)
    extra = 'geomfeat' if config['geomfeat'] else None

    # Top-20 crop classes of the 44-class French nomenclature.
    dt = PixelSetData(config['dataset_folder'],
                      labels='label_44class',
                      npixel=config['npixel'],
                      sub_classes=[
                          1, 3, 4, 5, 6, 8, 9, 12, 13, 14, 16, 18, 19, 23, 28,
                          31, 33, 34, 36, 39
                      ],
                      norm=mean_std,
                      extra_feature=extra,
                      return_id=True)
    dl = data.DataLoader(dt,
                         batch_size=config['batch_size'],
                         num_workers=config['num_workers'])

    model_config = dict(input_dim=config['input_dim'],
                        mlp1=config['mlp1'],
                        pooling=config['pooling'],
                        mlp2=config['mlp2'],
                        n_head=config['n_head'],
                        d_k=config['d_k'],
                        mlp3=config['mlp3'],
                        dropout=config['dropout'],
                        T=config['T'],
                        len_max_seq=config['lms'],
                        # 'bespoke' uses the dataset's own acquisition dates
                        # for the positional encoding instead of defaults.
                        positions=dt.date_positions
                        if config['positions'] == 'bespoke' else None,
                        mlp4=config['mlp4'])

    # Geometric features, when enabled, are a fixed 4-element extra vector.
    if config['geomfeat']:
        model_config.update(with_extra=True, extra_size=4)
    else:
        model_config.update(with_extra=False, extra_size=None)

    model = PseTae_pretrained(config['weight_dir'],
                              model_config,
                              device=config['device'],
                              fold=config['fold'])

    return model, dl
# Example 3
def main(config):
    """Train and evaluate PSE-TAE with k-fold cross-validation (France tile).

    For each fold: trains for config['epochs'] epochs, checkpoints the model
    with the best validation mIoU, reloads that checkpoint and evaluates it
    on the test split, then reports overall performance across folds.

    Args:
        config (dict): Full run configuration (seeds, dataset and model
            hyper-parameters, 'preload', 'kfold', 'epochs', 'gamma',
            'res_dir', 'device', ...).
    """
    np.random.seed(config['rdm_seed'])
    torch.manual_seed(config['rdm_seed'])
    prepare_output(config)

    # Context manager so the stats file is closed even if unpickling raises
    # (the original `pkl.load(open(...))` leaked the file handle).
    with open(config['dataset_folder'] + '/S2-2017-T31TFM-meanstd.pkl',
              'rb') as f:
        mean_std = pkl.load(f)
    extra = 'geomfeat' if config['geomfeat'] else None

    # Top-20 crop classes of the 44-class French nomenclature. The preloaded
    # and lazy dataset variants take identical arguments, so select the class
    # once instead of duplicating the whole constructor call.
    sub_classes = [
        1, 3, 4, 5, 6, 8, 9, 12, 13, 14, 16, 18, 19, 23, 28,
        31, 33, 34, 36, 39
    ]
    dataset_cls = PixelSetData_preloaded if config['preload'] else PixelSetData
    dt = dataset_cls(config['dataset_folder'],
                     labels='label_44class',
                     npixel=config['npixel'],
                     sub_classes=sub_classes,
                     norm=mean_std,
                     extra_feature=extra)
    device = torch.device(config['device'])

    loaders = get_loaders(dt, config['kfold'], config)
    for fold, (train_loader, val_loader, test_loader) in enumerate(loaders):
        print('Starting Fold {}'.format(fold + 1))
        print('Train {}, Val {}, Test {}'.format(len(train_loader),
                                                 len(val_loader),
                                                 len(test_loader)))

        model_config = dict(input_dim=config['input_dim'],
                            mlp1=config['mlp1'],
                            pooling=config['pooling'],
                            mlp2=config['mlp2'],
                            n_head=config['n_head'],
                            d_k=config['d_k'],
                            mlp3=config['mlp3'],
                            dropout=config['dropout'],
                            T=config['T'],
                            len_max_seq=config['lms'],
                            # 'bespoke' uses the dataset's own acquisition
                            # dates for the positional encoding.
                            positions=dt.date_positions
                            if config['positions'] == 'bespoke' else None,
                            mlp4=config['mlp4'])

        # Geometric features, when enabled, are a fixed 4-element extra vector.
        if config['geomfeat']:
            model_config.update(with_extra=True, extra_size=4)
        else:
            model_config.update(with_extra=False, extra_size=None)

        model = PseTae(**model_config)

        print(model.param_ratio())

        model = model.to(device)
        model.apply(weight_init)
        optimizer = torch.optim.Adam(model.parameters())
        criterion = FocalLoss(config['gamma'])

        trainlog = {}

        best_mIoU = 0
        for epoch in range(1, config['epochs'] + 1):
            print('EPOCH {}/{}'.format(epoch, config['epochs']))

            model.train()
            train_metrics = train_epoch(model,
                                        optimizer,
                                        criterion,
                                        train_loader,
                                        device=device,
                                        config=config)

            print('Validation . . . ')
            model.eval()
            val_metrics = evaluation(model,
                                     criterion,
                                     val_loader,
                                     device=device,
                                     config=config,
                                     mode='val')

            print('Loss {:.4f},  Acc {:.2f},  IoU {:.4f}'.format(
                val_metrics['val_loss'], val_metrics['val_accuracy'],
                val_metrics['val_IoU']))

            trainlog[epoch] = {**train_metrics, **val_metrics}
            checkpoint(fold + 1, trainlog, config)

            # Keep only the checkpoint with the best validation mIoU so far.
            if val_metrics['val_IoU'] >= best_mIoU:
                best_mIoU = val_metrics['val_IoU']
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    os.path.join(config['res_dir'], 'Fold_{}'.format(fold + 1),
                                 'model.pth.tar'))

        # Reload the best validation checkpoint before testing.
        print('Testing best epoch . . .')
        model.load_state_dict(
            torch.load(
                os.path.join(config['res_dir'], 'Fold_{}'.format(fold + 1),
                             'model.pth.tar'))['state_dict'])
        model.eval()

        test_metrics, conf_mat = evaluation(model,
                                            criterion,
                                            test_loader,
                                            device=device,
                                            mode='test',
                                            config=config)

        print('Loss {:.4f},  Acc {:.2f},  IoU {:.4f}'.format(
            test_metrics['test_loss'], test_metrics['test_accuracy'],
            test_metrics['test_IoU']))
        save_results(fold + 1, test_metrics, conf_mat, config)

    overall_performance(config)