Example #1
def get_dataloaders(data_dir,
                    patch_size: int,
                    box_coder,
                    train_batch_size=1,
                    valid_batch_size=1,
                    workers=4,
                    fold=0,
                    fast=False):
    train_ids, valid_ids = get_train_test_split_for_fold(fold, ships_only=True)
    if fast:
        train_ids = train_ids[:train_batch_size * 64]
        valid_ids = valid_ids[:valid_batch_size * 64]

    groundtruth = pd.read_csv(
        os.path.join(data_dir, 'train_ship_segmentations_v2.csv'))

    trainset = D.RSSDDataset(sample_ids=train_ids,
                             data_dir=data_dir,
                             transform=get_transform(training=True,
                                                     width=patch_size,
                                                     height=patch_size),
                             groundtruth=groundtruth,
                             box_coder=box_coder)

    validset = D.RSSDDataset(sample_ids=valid_ids,
                             data_dir=data_dir,
                             transform=get_transform(training=False,
                                                     width=patch_size,
                                                     height=patch_size),
                             groundtruth=groundtruth,
                             box_coder=box_coder)

    # In fast mode, swap full-epoch shuffling for a fixed draw of 1024
    # uniformly weighted samples so each epoch stays short.
    shuffle = True
    sampler = None
    if fast:
        shuffle = False  # DataLoader disallows shuffle=True together with a sampler
        sampler = WeightedRandomSampler(np.ones(len(trainset)), 1024)

    trainloader = DataLoader(trainset,
                             batch_size=train_batch_size,
                             num_workers=workers,
                             pin_memory=True,
                             drop_last=True,
                             shuffle=shuffle,
                             sampler=sampler)

    validloader = DataLoader(
        validset,
        batch_size=valid_batch_size,
        num_workers=workers,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
    )

    print('Train set', len(trainset), len(trainloader), 'Valid set',
          len(validset), len(validloader))
    return trainloader, validloader
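
A minimal usage sketch for the factory above. The RSSDBoxCoder(width, height) call mirrors its use in Example #2; the patch size, batch sizes, and data directory are illustrative assumptions, not values from the original code.

box_coder = RSSDBoxCoder(512, 512)  # box coder sized to the training patches
train_loader, valid_loader = get_dataloaders(
    'd:\\datasets\\airbus',  # data_dir, matching the argparse default in Example #2
    patch_size=512,
    box_coder=box_coder,
    train_batch_size=8,
    valid_batch_size=8,
    workers=4,
    fold=0,
    fast=True)  # fast=True trims both splits for a quick smoke test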
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='d:\\datasets\\airbus',
                        help='Data dir')
    parser.add_argument('-m', '--model', type=str, default='rssd512', help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=4,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-r',
                        '--resume',
                        type=str,
                        default=None,
                        help='Checkpoint filename to resume')
    parser.add_argument('-w',
                        '--workers',
                        default=4,
                        type=int,
                        help='Num workers')
    parser.add_argument('-p', '--patch-size', type=int, default=768, help='')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    # --resume is effectively required for this inference script (its default
    # of None would fail here): the checkpoint path also determines the
    # output directory and file prefix.
    fname = auto_file(args.resume)
    exp_dir = os.path.dirname(fname)
    prefix = os.path.splitext(os.path.basename(fname))[0]

    model = get_model(args.model, num_classes=1).cuda()

    start_epoch, train_history, best_score = restore_checkpoint(fname, model)
    print(train_history)

    testset_full = D.RSSDDataset(
        sample_ids=all_test_ids(args.data_dir),
        test=True,
        data_dir=args.data_dir,
        transform=get_transform(training=False,
                                width=args.patch_size,
                                height=args.patch_size),
        box_coder=RSSDBoxCoder(args.patch_size, args.patch_size))
    test_predictions = model.predict_as_csv(testset_full,
                                            batch_size=args.batch_size,
                                            workers=args.workers)
    test_predictions.to_csv(os.path.join(exp_dir,
                                         f'{prefix}_test_predictions.csv'),
                            index=False)
    print('Predictions saved')
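
None of these examples shows restore_checkpoint itself; the sketch below is a hypothetical reconstruction inferred only from its call sites, which pass a filename plus a model and unpack an epoch, a training-history frame, and a best score. The checkpoint key names are assumptions, not the author's format.

import torch
import pandas as pd

def restore_checkpoint(checkpoint_file, model):
    # Hypothetical sketch: the real helper is imported from elsewhere.
    # The key names ('model_state_dict', 'epoch', ...) are assumed, chosen
    # only to match the three values the call sites unpack.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint.get('epoch', 0)
    train_history = pd.DataFrame(checkpoint.get('train_history', []))
    best_score = checkpoint.get('best_score', 0)
    return start_epoch, train_history, best_score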
Example #3
def get_dataloaders(data_dir,
                    train_batch_size=1,
                    patch_size=768,
                    workers=4,
                    fold=0,
                    fast=False,
                    oversample_ships=True):
    # NOTE: oversample_ships is accepted but never used in this snippet;
    # see the sampler sketch after this example.
    train_ids, valid_ids = get_train_test_split_for_fold(fold, ships_only=True)
    if fast:
        train_ids = train_ids[:train_batch_size * 8]
        valid_ids = valid_ids[:train_batch_size * 8]

    groundtruth = pd.read_csv(os.path.join(data_dir, 'train_ship_segmentations_v2.csv'))

    trainset = D.SegmentationDataset(sample_ids=train_ids,
                                     data_dir=data_dir,
                                     transform=get_transform(True, patch_size, patch_size),
                                     groundtruth=groundtruth)

    validset = D.SegmentationDataset(sample_ids=valid_ids,
                                     data_dir=data_dir,
                                     transform=get_transform(False, patch_size, patch_size),
                                     groundtruth=groundtruth)

    trainloader = DataLoader(trainset,
                             batch_size=train_batch_size,
                             num_workers=workers,
                             pin_memory=True,
                             drop_last=True,
                             shuffle=True
                             )

    validloader = DataLoader(validset,
                             batch_size=train_batch_size,
                             num_workers=workers,
                             pin_memory=True,
                             drop_last=False,
                             shuffle=False,
                             )

    print('Train set', len(trainset), len(trainloader),
          'Valid set', len(validset), len(validloader),
          'Batch size', train_batch_size)
    return trainloader, validloader
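
As noted above, oversample_ships is accepted but never used in this snippet. Below is a hedged sketch of a sampler construction that could replace the trainloader block, reusing the WeightedRandomSampler pattern from Example #1. Weighting by per-image ship count is one plausible interpretation; the ImageId column name follows the public Airbus train_ship_segmentations_v2.csv, and none of this comes from the original code.

import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler

# Hypothetical: weight each training image by its ship count so that
# multi-ship scenes are drawn more often than single-ship ones.
ships_per_image = groundtruth.groupby('ImageId').size()
weights = np.array([ships_per_image.get(sample_id, 0) + 1.0
                    for sample_id in train_ids])
sampler = WeightedRandomSampler(weights, num_samples=len(trainset))
trainloader = DataLoader(trainset,
                         batch_size=train_batch_size,
                         num_workers=workers,
                         pin_memory=True,
                         drop_last=True,
                         sampler=sampler)  # a sampler replaces shuffle=True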
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='d:\\datasets\\airbus',
                        help='Data dir')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='rretina_net',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=4,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=150,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f',
                        '--fold',
                        default=0,
                        type=int,
                        help='Fold to train')
    parser.add_argument('-fe',
                        '--freeze-encoder',
                        type=int,
                        default=0,
                        help='Freeze encoder parameters for N epochs')
    parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-3,
                        help='Initial learning rate')
    parser.add_argument('-lrs',
                        '--lr-scheduler',
                        default=None,
                        help='LR scheduler')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-r',
                        '--resume',
                        type=str,
                        default=None,
                        help='Checkpoint filename to resume')
    parser.add_argument('-w',
                        '--workers',
                        default=4,
                        type=int,
                        help='Num workers')
    parser.add_argument('-wd',
                        '--weight-decay',
                        type=float,
                        default=0,
                        help='L2 weight decay')
    parser.add_argument('-p', '--patch-size', type=int, default=768, help='')
    parser.add_argument('-ew', '--encoder-weights', default=None, type=str)

    args = parser.parse_args()
    set_manual_seed(args.seed)

    train_session_args = vars(args)
    train_session = get_random_name()
    current_time = datetime.now().strftime('%b%d_%H_%M')
    prefix = f'{current_time}_{args.model}_f{args.fold}_{train_session}_{args.patch_size}'
    if args.fast:
        prefix += '_fast'

    print(prefix)
    print(args)

    log_dir = os.path.join('runs', prefix)
    exp_dir = os.path.join('experiments', args.model, prefix)
    os.makedirs(exp_dir, exist_ok=True)

    model = get_model(args.model,
                      num_classes=1,
                      image_size=(args.patch_size, args.patch_size))
    print(count_parameters(model))

    train_loader, valid_loader = get_dataloaders(
        args.data_dir,
        box_coder=model.box_coder,
        fold=args.fold,
        patch_size=args.patch_size,
        train_batch_size=args.batch_size,
        valid_batch_size=args.batch_size,
        fast=args.fast)

    # Declare variables we will use during training
    start_epoch = 0
    train_history = pd.DataFrame()

    best_metric_val = 0
    best_lb_checkpoint = os.path.join(exp_dir, f'{prefix}.pth')

    # Optionally warm-start the detector's encoder from a separately
    # trained seresnext classifier checkpoint.
    if args.encoder_weights:
        classifier = get_model('seresnext_cls', num_classes=1)
        restore_checkpoint(auto_file(args.encoder_weights), classifier)
        encoder_state = classifier.encoder.state_dict()
        model.encoder.load_state_dict(encoder_state)
        del classifier

    if args.resume:
        fname = auto_file(args.resume)
        start_epoch, train_history, best_score = restore_checkpoint(
            fname, model)
        print(train_history)
        print('Resuming training from epoch', start_epoch, ' and score',
              best_score, args.resume)

    writer = SummaryWriter(log_dir)
    writer.add_text('train/params',
                    '```' + json.dumps(train_session_args, indent=2) + '```',
                    0)
    # log_model_graph(writer, model)

    config_fname = os.path.join(exp_dir, f'{train_session}.json')
    with open(config_fname, 'w') as f:
        f.write(json.dumps(train_session_args, indent=2))

    # Main training phase
    model.cuda()
    trainable_parameters = filter(lambda p: p.requires_grad,
                                  model.parameters())
    optimizer = get_optimizer(args.optimizer,
                              trainable_parameters,
                              args.learning_rate,
                              weight_decay=args.weight_decay)
    # scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=50, factor=0.5, min_lr=1e-5)
    scheduler = None

    train_history, best_metric_val, start_epoch = train(
        model,
        optimizer,
        scheduler,
        train_loader,
        valid_loader,
        writer,
        start_epoch,
        epochs=args.epochs,
        early_stopping=args.early_stopping,
        train_history=train_history,
        experiment_dir=exp_dir,
        best_metric_val=best_metric_val,
        checkpoint_filename=best_lb_checkpoint)

    train_history.to_csv(os.path.join(exp_dir, 'train_history.csv'),
                         index=False)
    print('Training finished')
    del train_loader, valid_loader, optimizer

    # Restore to best model
    restore_checkpoint(best_lb_checkpoint, model)

    # Make OOF predictions
    _, valid_ids = get_train_test_split_for_fold(args.fold)
    validset_full = D.RSSDDataset(sample_ids=valid_ids,
                                  data_dir=args.data_dir,
                                  transform=get_transform(
                                      training=False,
                                      width=args.patch_size,
                                      height=args.patch_size),
                                  box_coder=model.box_coder)
    oof_predictions = model.predict_as_csv(validset_full,
                                           batch_size=args.batch_size,
                                           workers=args.workers)
    oof_predictions.to_csv(os.path.join(exp_dir,
                                        f'{prefix}_oof_predictions.csv'),
                           index=False)
    del validset_full

    testset_full = D.RSSDDataset(sample_ids=all_test_ids(args.data_dir),
                                 test=True,
                                 data_dir=args.data_dir,
                                 transform=get_transform(
                                     training=False,
                                     width=args.patch_size,
                                     height=args.patch_size),
                                 box_coder=model.box_coder)
    test_predictions = model.predict_as_csv(testset_full,
                                            batch_size=args.batch_size,
                                            workers=args.workers)
    test_predictions.to_csv(os.path.join(exp_dir,
                                         f'{prefix}_test_predictions.csv'),
                            index=False)
    print('Predictions saved')
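
Both entry points call set_manual_seed(args.seed) without showing it. The sketch below is a conventional minimal implementation for PyTorch projects, not the codebase's actual helper.

import random
import numpy as np
import torch

def set_manual_seed(seed):
    # Hypothetical sketch: seed the RNGs a typical PyTorch run touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op without CUDA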