Example #1
def train_random_holdout(config: dict):
    """Train on a random hold-out split and save the training history CSVs."""
    configs_train = config['train']
    psm = configs_train['path_save_model']
    mn = configs_train['model_name']
    ph = configs_train['path_history']
    path_history = ph + mn
    path_save_model = psm + mn + '.pth'  # checkpoint path (naming follows the k-fold variant)
    config_dataloader = config['dataloader']
    path_df = config_dataloader['path_df']

    df = dl.transform_df(path_df)
    train_df, valid_df = dl.get_train_valid_df(path_df, valid_size=0.1)
    train_trf = trfs.ImgAugTrainTransform()
    valid_trf = trfs.to_tensor()

    train_data_loader, valid_data_loader = dl.get_train_valid_dataloaders(
        config_dataloader, train_df, valid_df, dl.collate_fn,
        train_trf, valid_trf)

    model = get_model()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.0005, momentum=0.9,
        weight_decay=0.0005)
    lr_scheduler = None
    num_epochs = configs_train['epochs']

    df_history, df_scores_train, df_scores_valid = train_model(train_data_loader,
        valid_data_loader, model, optimizer, num_epochs, lr_scheduler,
        path_save_model)

    df_history.to_csv(path_history + '_history.csv', index=False)
    df_scores_train.to_csv(path_history + '_scores_train.csv', index=False)
    df_scores_valid.to_csv(path_history + '_scores_valid.csv', index=False)
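
For orientation, the config dict this function expects can be reconstructed from the keys it reads. A minimal sketch follows; the paths and values are placeholders, and config['dataloader'] will also need whatever settings dl.get_train_valid_dataloaders consumes.

# Hypothetical config for train_random_holdout; only the key names come from
# the lookups above, every value here is a placeholder.
config = {
    'train': {
        'path_save_model': 'models/',
        'model_name': 'detector_baseline',
        'path_history': 'history/',
        'epochs': 10,
    },
    'dataloader': {
        'path_df': 'data/train.csv',
        # ... plus the loader settings dl.get_train_valid_dataloaders expects
    },
}
train_random_holdout(config)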
Example #2
    def __init__(self, args):
        self.args = args

        trainset = ChangeDetection(root=args.data_root,
                                   mode="train",
                                   use_pseudo_label=args.use_pseudo_label)
        valset = ChangeDetection(root=args.data_root, mode="val")
        self.trainloader = DataLoader(trainset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      pin_memory=False,
                                      num_workers=16,
                                      drop_last=True)
        self.valloader = DataLoader(valset,
                                    batch_size=args.val_batch_size,
                                    shuffle=False,
                                    pin_memory=True,
                                    num_workers=16,
                                    drop_last=False)

        self.model = get_model(args.model, args.backbone, args.pretrained,
                               len(trainset.CLASSES) - 1, args.lightweight)
        if args.pretrain_from:
            self.model.load_state_dict(torch.load(args.pretrain_from),
                                       strict=False)

        if args.load_from:
            self.model.load_state_dict(torch.load(args.load_from), strict=True)

        if args.use_pseudo_label:
            weight = torch.FloatTensor([1, 1, 1, 1, 1, 1]).cuda()
        else:
            weight = torch.FloatTensor([2, 1, 2, 2, 1, 1]).cuda()
        self.criterion = CrossEntropyLoss(ignore_index=-1, weight=weight)
        self.criterion_bin = BCELoss(reduction='none')

        # Two parameter groups: the backbone trains at the base lr, everything
        # else (the newly initialised head) at 10x the base lr.
        self.optimizer = Adam(
            [
                {
                    "params": [param for name, param in self.model.named_parameters()
                               if "backbone" in name],
                    "lr": args.lr,
                },
                {
                    "params": [param for name, param in self.model.named_parameters()
                               if "backbone" not in name],
                    "lr": args.lr * 10.0,
                },
            ],
            lr=args.lr,
            weight_decay=args.weight_decay)

        self.model = DataParallel(self.model).cuda()

        self.iters = 0
        self.total_iters = len(self.trainloader) * args.epochs
        self.previous_best = 0.0
Example #3
def train_random_holdout(config: dict):
    """"""
    cwd = os.getcwd()
    en = config['experiment_name']
    ph = os.path.join(cwd, *config['path_history'], en)
    psm = os.path.join(cwd, *config['path_save_model'], en + '.pth')
    config_dataloader = config['dataloader']
    config_train = config['train']
    path_df = os.path.join(cwd, *config_dataloader['path_df'])

    optimizer_config = config_train['optimizer']
    num_epochs = config_train['epochs']
    batch_size = config_dataloader['train_loader']['batch_size']
    lr_scheduler = None
    hparams = {
        **optimizer_config,
        'epochs': num_epochs,
        'batch_size': batch_size
    }

    make_dir(ph)
    writer = SummaryWriter(ph, comment=en)
    #writer.add_hparams(hparams)
    train_df, valid_df = dl.get_train_valid_df(path_df, valid_size=0.2)
    train_trf = trfs.ImgAugTrainTransform()
    valid_trf = trfs.to_tensor()

    train_data_loader, valid_data_loader = dl.get_train_valid_data_loaders(
        config_dataloader, train_df, valid_df, dl.collate_fn,
        train_trf, valid_trf)

    model = get_model()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, **optimizer_config)

    df_history, df_scores_train, df_scores_valid = train_model(
        train_data_loader, valid_data_loader, model, optimizer,
        num_epochs, psm, writer, lr_scheduler)

    df_history.iloc[0].to_dict()

    df_history.to_csv(os.path.join(ph, en + '_history.csv'), index=False)
    df_scores_train.to_csv(
        os.path.join(ph, en + '_scores_train.csv'),
        index=False
    )
    df_scores_valid.to_csv(
        os.path.join(ph, en + '_scores_valid.csv'),
        index=False
    )
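
The commented-out writer.add_hparams(hparams) call above is missing its second argument: SummaryWriter.add_hparams expects a metric dict alongside the hyper-parameter dict. A hedged sketch of how it could be wired up, assuming the final metrics can be pulled from a row of df_history (the column names 'loss' and 'score' are assumptions):

# Sketch only: log this run's hyper-parameters next to its final metrics.
final_metrics = df_history.iloc[-1].to_dict()
writer.add_hparams(hparams, {'hparam/loss': final_metrics.get('loss', 0.0),
                             'hparam/score': final_metrics.get('score', 0.0)})
writer.close()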
Example #4
def set_up_evaluation(config: dict, weights_path, output_path):
    cwd = os.getcwd()

    config_dataloader = config['dataloader']
    path_df = os.path.join(cwd, *config_dataloader['path_df'])

    train_df, valid_df = dl.get_train_valid_df(path_df, valid_size=0.1)
    valid_trf = trfs.to_tensor()
    _, valid_data_loader = dl.get_train_valid_data_loaders(config_dataloader,
                                                           train_df,
                                                           valid_df,
                                                           dl.collate_fn,
                                                           valid_trf=valid_trf)

    model = get_model()
    predict_model(valid_data_loader, model, weights_path, output_path)
Example #5
def evaluate_for_best_epoch(
    fold,
    model_path,
    device,
    valid_loader,
    model_name,
    valid_targets,
    epoch="final",
    meta_features=None,
):
    args = get_args()
    with open(args.config) as file:
        config_file = yaml.load(file, Loader=yaml.FullLoader)
    config = wandb.config  # Initialize config
    config.update(config_file)

    print(f"Evaluating on epoch {epoch} from {model_path}")
    model = get_model(
        config.model_backbone,
        config.model_name,
        config.num_classes,
        config.input_size,
        config.use_metadata,
        meta_features,
    )
    model.load_state_dict(torch.load(model_path))

    model.to(device, non_blocking=True)
    predictions, valid_loss = Engine.evaluate(
        valid_loader,
        model,
        device=device,
        wandb=wandb,
        epoch=epoch,
        upload_image=True,
        use_sigmoid=True,
    )
    predictions = np.vstack(predictions).ravel()

    auc = metrics.roc_auc_score(valid_targets, predictions)
    oof_file = config.oof_file.replace(".npy", "_" + str(fold) + ".npy")
    np.save(oof_file, valid_targets)
    print(f"Epoch = {epoch}, AUC = {auc}")
    wandb.log({
        "best_valid_auc": auc,
    })
Example #6
def train_skfold(config: dict):
    """"""
    configs_train = config['train']
    psm = configs_train['path_save_model']
    mn = configs_train['model_name']
    ph = configs_train['path_history']
    path_history = ph + mn
    config_dataloader = config['dataloader']
    path_df = config_dataloader['path_df']

    df = dl.transform_df(path_df)
    df = dl.split_stratifiedKFolds_bbox_count(df, config_dataloader['n_splits'])
    folds = list(df['fold'].unique())
    train_trf = trfs.ImgAugTrainTransform()
    valid_trf = trfs.to_tensor()

    for fold in folds:

        print(f"{'_'*30}Training on fold {fold}...{'_'*30}")
        path_save_model = psm + mn + f'_fold{fold}.pth'
        train_df, valid_df = dl.get_train_valid_df_skfold(df, fold)
        train_data_loader, valid_data_loader = dl.get_train_valid_dataloaders(
            config_dataloader, train_df, valid_df, dl.collate_fn,
            train_trf, valid_trf)

        model = get_model()
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=0.0005, momentum=0.9,
            weight_decay=0.0005)
        lr_scheduler = None
        num_epochs = configs_train['epochs']

        df_history, df_scores_train, df_scores_valid = train_model(
            train_data_loader, valid_data_loader, model, optimizer, num_epochs,
            lr_scheduler, path_save_model
        )
        df_history.to_csv(path_history + f'fold{fold}_history.csv', index=False)
        df_scores_train.to_csv(path_history + f'fold{fold}_scores_train.csv',
            index=False)
        df_scores_valid.to_csv(path_history + f'fold{fold}_scores_valid.csv',
            index=False)
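
dl.split_stratifiedKFolds_bbox_count is not shown in these examples. A minimal sketch of an analogous splitter, assuming the dataframe has one row per image with a 'bbox_count' column to stratify on (the column name and random seed are assumptions):

import pandas as pd
from sklearn.model_selection import StratifiedKFold

def split_stratified_kfolds_by_bbox_count(df: pd.DataFrame, n_splits: int) -> pd.DataFrame:
    """Assign a 'fold' column, stratifying images by their bounding-box count."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    df = df.copy()
    df['fold'] = -1
    for fold, (_, valid_idx) in enumerate(skf.split(df, df['bbox_count'])):
        df.loc[df.index[valid_idx], 'fold'] = fold
    return df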
Example #7
def main():
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    _, _, transform_infer = transforms.get_transform(args.dataset)
    galleryset = datasets.get_dataset(args.dataset,
                                      root='/home/ace19/dl_data/materials/train',
                                      transform=transform_infer)
    queryset = datasets.get_dataset(args.dataset,
                                    split='eval',
                                    root='/home/ace19/dl_data/materials/query',
                                    transform=transform_infer)
    gallery_loader = DataLoader(
        galleryset, batch_size=args.batch_size, num_workers=args.workers)
    query_loader = torch.utils.data.DataLoader(
        queryset, batch_size=args.test_batch_size, num_workers=args.workers)

    # init the model
    model = model_zoo.get_model(args.model)
    print(model)

    if args.cuda:
        model.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)

    # check point
    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("=> loading checkpoint '{}'".format(args.checkpoint))
            checkpoint = torch.load(args.checkpoint)
            args.start_epoch = checkpoint['epoch'] + 1
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.checkpoint, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no infer checkpoint found at '{}'"
                               .format(args.checkpoint))
    else:
        raise RuntimeError("=> no checkpoint specified: 'args.checkpoint' is '{}'"
                           .format(args.checkpoint))


    gallery_features_list = []
    gallery_path_list = []
    query_features_list = []
    query_path_list = []
    def retrieval():
        model.eval()

        print(" ==> Loading gallery ... ")
        tbar = tqdm(gallery_loader, desc='\r')
        for batch_idx, (gallery_paths, data, gt) in enumerate(tbar):
            if args.cuda:
                data, gt = data.cuda(), gt.cuda()

            with torch.no_grad():
                # features [256, 2048]
                # output [256, 128]
                # features, output = model(data)

                # TTA
                batch_size, n_crops, c, h, w = data.size()
                # fuse batch size and ncrops
                features, _ = model(data.view(-1, c, h, w))
                # avg over crops
                features = features.view(batch_size, n_crops, -1).mean(1)
                gallery_features_list.extend(features)
                gallery_path_list.extend(gallery_paths)
        # end of for

        print(" ==> Loading query ... ")
        tbar = tqdm(query_loader, desc='\r')
        for batch_idx, (query_paths, data) in enumerate(tbar):
            if args.cuda:
                data = data.cuda()

            with torch.no_grad():
                # TTA
                batch_size, n_crops, c, h, w = data.size()
                # fuse batch size and ncrops
                features, _ = model(data.view(-1, c, h, w))
                # avg over crops
                features = features.view(batch_size, n_crops, -1).mean(1)
                query_features_list.extend(features)
                query_path_list.extend(query_paths)
        # end of for

        if len(query_features_list) == 0:
            print('No query data!!')
            return

        # matching
        top_n_indice, top_n_distance = \
            match_n(TOP_N,
                    torch.stack(gallery_features_list).cpu(),
                    torch.stack(query_features_list).cpu())

        # Show n images from the gallery similar to the query image.
        show_retrieval_result(top_n_indice, top_n_distance, gallery_path_list, query_path_list)

    retrieval()
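
match_n is referenced above but not shown. A plausible sketch under the assumption that it returns, for each query, the indices and distances of the TOP_N closest gallery features (Euclidean distance here; the real code might use cosine similarity instead):

import torch

def match_n(top_n: int, gallery_features: torch.Tensor, query_features: torch.Tensor):
    """Return (indices, distances) of the top_n nearest gallery entries per query."""
    # pairwise distances: [num_queries, num_gallery]
    distances = torch.cdist(query_features, gallery_features)
    top_n_distance, top_n_indice = torch.topk(distances, k=top_n, dim=1, largest=False)
    return top_n_indice, top_n_distance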
Example #8
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

if __name__ == '__main__':
    args = Options().parse()

    dataset = ChangeDetection(root=args.data_root, mode='pseudo_labeling')
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=4,
                            drop_last=False)

    model1 = get_model('fcn', 'hrnet_w40', False,
                       len(dataset.CLASSES) - 1, True)
    model1.load_state_dict(torch.load('outdir/models/fcn_hrnet_w40_37.37.pth'),
                           strict=True)

    model2 = get_model('fcn', 'hrnet_w48', False,
                       len(dataset.CLASSES) - 1, True)
    model2.load_state_dict(torch.load('outdir/models/fcn_hrnet_w48_37.46.pth'),
                           strict=True)

    model3 = get_model('pspnet', 'hrnet_w40', False,
                       len(dataset.CLASSES) - 1, True)
    model3.load_state_dict(
        torch.load('outdir/models/pspnet_hrnet_w40_37.69.pth'), strict=True)

    model4 = get_model('pspnet', 'hrnet_w48', False,
                       len(dataset.CLASSES) - 1, True)
    TTA_TIME = 0

    args = Options().parse()

    torch.backends.cudnn.benchmark = True

    print(torch.cuda.is_available())
    testset = ChangeDetection(root=args.data_root, mode="test")
    testloader = DataLoader(testset,
                            batch_size=8,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=4,
                            drop_last=False)

    model1 = get_model('pspnet', 'hrnet_w40', False,
                       len(testset.CLASSES) - 1, True)
    model1.load_state_dict(
        torch.load('outdir/models/pspnet_hrnet_w40_39.37.pth'), strict=True)
    model2 = get_model('pspnet', 'hrnet_w18', False,
                       len(testset.CLASSES) - 1, True)
    model2.load_state_dict(
        torch.load('outdir/models/pspnet_hrnet_w18_38.74.pth'), strict=True)

    models = [model1, model2]
    for i in range(len(models)):
        models[i] = models[i].cuda()
        models[i].eval()

    cmap = color_map()

    tbar = tqdm(testloader)
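
The snippet stops right after building tbar, before the inference loop. A heavily hedged sketch of how an ensembled test pass over tbar might continue; the batch layout (img1, img2, ids), the model call signature, and the softmax-averaging ensemble are all assumptions:

# Sketch only: average softmax probabilities across the ensemble, then argmax.
with torch.no_grad():
    for img1, img2, ids in tbar:          # assumed batch structure
        img1, img2 = img1.cuda(), img2.cuda()
        prob_sum = None
        for model in models:
            out = model(img1, img2)        # assumed: [B, C, H, W] logits
            prob = torch.softmax(out, dim=1)
            prob_sum = prob if prob_sum is None else prob_sum + prob
        pred = (prob_sum / len(models)).argmax(dim=1).cpu().numpy()
        # pred can then be colourised with `cmap` and written out per id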
Example #10
def main():
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # init dataloader
    transform_train, transform_val, _ = transforms.get_transform(args.dataset)
    trainset = datasets.get_dataset(args.dataset,
                                    root='/home/ace19/dl_data/materials/train',
                                    transform=transform_train)
    valset = datasets.get_dataset(
        args.dataset,
        root='/home/ace19/dl_data/materials/validation',
        transform=transform_val)

    # balanced sampling between classes
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              sampler=ImbalancedDatasetSampler(trainset))
    # train_loader = DataLoader(
    #     trainset, batch_size=args.batch_size, shuffle=True,
    #     num_workers=args.workers, pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # init the backbone model
    if args.pretrained is not None:
        model = model_zoo.get_model(args.model, backbone=args.backbone)
    else:
        model = model_zoo.get_model(args.model,
                                    backbone_pretrained=True,
                                    backbone=args.backbone)
    print(model)

    # criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
    #                             weight_decay=args.weight_decay)
    if args.cuda:
        model.cuda()
        criterion.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)
    # check point
    if args.pretrained is not None:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained)
            args.start_epoch = checkpoint['epoch'] + 1
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pretrained, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no pretrained checkpoint found at '{}'". \
                               format(args.pretrained))

    scheduler = lr_scheduler.LR_Scheduler(args.lr_scheduler,
                                          args.lr, args.epochs,
                                          len(train_loader), args.lr_step)

    def train(epoch):
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()

        global best_pred, acclist_train, acclist_val

        tbar = tqdm(train_loader, desc='\r')
        for batch_idx, (_, images, targets) in enumerate(tbar):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            # display_data(images)
            # TODO: Convert from list of 3D to 4D
            # images = np.stack(images, axis=1)
            # images = torch.from_numpy(images)

            if args.cuda:
                images, targets = images.cuda(), targets.cuda()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            _, output = model(images)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

            acc1 = accuracy(output, targets)
            top1.update(acc1[0], images.size(0))
            losses.update(loss.item(), images.size(0))
            tbar.set_description('\rLoss: %.3f | Top1: %.3f' %
                                 (losses.avg, top1.avg))

        acclist_train += [top1.avg]

    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        confusion_matrix = torch.zeros(args.nclass, args.nclass)

        global best_pred, acclist_train, acclist_val
        is_best = False

        tbar = tqdm(val_loader, desc='\r')
        # TTA(TenCrop) input, target = batch # input is a 5d tensor, target is 2d
        # bs, ncrops, c, h, w = input.size()
        # result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
        # result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
        for batch_idx, (fnames, images, targets) in enumerate(tbar):
            # Convert from list of 3D to 4D
            # images = np.stack(images, axis=1)
            # images = torch.from_numpy(images)

            if args.cuda:
                images, targets = images.cuda(), targets.cuda()
                # images, targets = Variable(images), Variable(targets)
            with torch.no_grad():
                # _, output = model(images)

                # TTA
                batch_size, n_crops, c, h, w = images.size()
                # fuse batch size and ncrops
                _, output = model(images.view(-1, c, h, w))
                # avg over crops
                output = output.view(batch_size, n_crops, -1).mean(1)
                # accuracy
                acc1, acc5 = accuracy(output, targets, topk=(1, 1))  # note: topk=(1, 1), so both are top-1
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # confusion matrix
                _, preds = torch.max(output, 1)
                for t, p in zip(targets.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1

            tbar.set_description('Top1: %.3f | Top5: %.3f' %
                                 (top1.avg, top5.avg))
        # end of for

        print('\n----------------------------------')
        print('confusion matrix:\n', confusion_matrix)
        # get the per-class accuracy
        print('\nper-class accuracy(precision):\n',
              confusion_matrix.diag() / confusion_matrix.sum(1))
        print('----------------------------------\n')

        if args.eval:
            print('Top1 Acc: %.3f | Top5 Acc: %.3f ' % (top1.avg, top5.avg))
            return

        # save checkpoint
        acclist_val += [top1.avg]
        if top1.avg > best_pred:
            best_pred = top1.avg
            is_best = True
        files.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=is_best)

    if args.eval:
        validate(args.start_epoch)
        # writer.close()
        return

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(epoch)
        validate(epoch)
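
Both validate() above and the other TTA snippets rely on the dataloader yielding 5-D batches of shape [batch, n_crops, C, H, W]. For reference, a ten-crop transform that produces such batches with torchvision could look like this (the resize and crop sizes are placeholders):

import torch
from torchvision import transforms

# Each image becomes a stack of 10 crops; the default collate then yields
# batches of shape [batch_size, 10, C, H, W], matching the view/mean TTA code above.
tencrop_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
])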
Example #11
def main():
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # init dataloader
    _, _, transform_infer = transforms.get_transform(args.dataset)
    infer_set = datasets.get_dataset(
        args.dataset,
        split='eval',
        root='/home/ace19/dl_data/v2-plant-seedlings-dataset2/evalset',
        transform=transform_infer)
    infer_loader = torch.utils.data.DataLoader(infer_set,
                                               batch_size=args.test_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # init the model
    model = model_zoo.get_model(args.model)
    print(model)

    if args.cuda:
        model.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)

    # check point
    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("=> loading checkpoint '{}'".format(args.checkpoint))
            checkpoint = torch.load(args.checkpoint)
            args.start_epoch = checkpoint['epoch'] + 1
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.checkpoint, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no infer checkpoint found at '{}'"
                               .format(args.checkpoint))
    else:
        raise RuntimeError("=> no checkpoint specified: 'args.checkpoint' is '{}'"
                           .format(args.checkpoint))

    def eval():
        model.eval()

        submission = {}

        tbar = tqdm(infer_loader, desc='\r')
        for batch_idx, (fnames, images) in enumerate(tbar):
            if args.cuda:
                images = images.cuda()

            with torch.no_grad():
                # TTA
                batch_size, n_crops, c, h, w = images.size()
                # fuse batch size and ncrops
                output = model(images.view(-1, c, h, w))
                # avg over crops
                output = output.view(batch_size, n_crops, -1).mean(1)
                _, preds = torch.max(output, 1)

            size = len(fnames)
            for i in range(size):
                submission[fnames[i]] = preds[i].cpu()
        # end of for

        ########################
        # make submission.csv
        ########################
        if not os.path.exists('result'):
            os.makedirs('result')

        fout = open(os.path.join('result', args.result + '#20.csv'),
                    'w',
                    encoding='utf-8',
                    newline='')
        writer = csv.writer(fout)
        # writer.writerow(['id', 'label'])
        for key in sorted(submission.keys()):
            name = key.split('/')[-1]
            writer.writerow([name, submission[key].numpy()])
        fout.close()

    eval()
Example #12
def train(fold):
    args = get_args()
    with open(args.config) as file:
        config_file = yaml.load(file, Loader=yaml.FullLoader)

    wandb.init(
        project="siim2020",
        entity="siim_melanoma",
        # name=f"20200718-effb0-adamw-consineaneal-{fold}",
        name=f"2017-2018-rexnet-test-{fold}",
        #name=f"swav-test-{fold}",
        #name=f"RAdam-b6-384x384-{fold}"
    )
    config = wandb.config  # Initialize config
    config.update(config_file)
    device = config.device

    model_path = config.model_path.format(fold)

    seed_everything(config.seed)
    df = pd.read_csv(config.train_csv_fold)
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_train["image_name"] = config.training_data_path + df_train[
        "image_name"] + ".jpg"

    if config.supplement_data["use_supplement"]:
        print(f"training shape before merge {df_train.shape}")
        df_supplement = pd.read_csv(config.supplement_data["csv_file"])
        df_supplement = df_supplement[df_supplement["tfrecord"] % 2 == 0]
        df_supplement = df_supplement[df_supplement["target"] == 1]
        df_supplement["image_name"] = (config.supplement_data["file_path"] +
                                       df_supplement["image_name"] + ".jpg")
        df_train = pd.concat([df_train, df_supplement]).reset_index(drop=True)
        df_train = df_train.sample(
            frac=1, random_state=config.seed).reset_index(drop=True)
        del df_supplement
        print(f"training shape after merge {df_train.shape}")

    df_valid = df[df.kfold == fold].reset_index(drop=True)
    df_valid["image_name"] = config.training_data_path + df_valid[
        "image_name"] + ".jpg"

    if config.use_metadata:
        df_train, meta_features = get_meta_feature(df_train)
        df_valid, _ = get_meta_feature(df_valid)
    else:
        meta_features = None

    model = get_model(
        config.model_backbone,
        config.model_name,
        config.num_classes,
        config.input_size,
        config.use_metadata,
        meta_features,
    )

    model = model.to(config.device)
    print("watching model")
    wandb.watch(model, log="all")

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_aug = albumentations.Compose([
        AdvancedHairAugmentation(hairs_folder="../input/melanoma-hairs/"),
        # albumentations.augmentations.transforms.CenterCrop(64, 64, p=0.8),
        albumentations.augmentations.transforms.RandomBrightnessContrast(),
        albumentations.augmentations.transforms.HueSaturationValue(),
        # Microscope(p=0.4),
        albumentations.augmentations.transforms.RandomResizedCrop(
            config.input_size, config.input_size, scale=(0.7, 1.0), p=0.4),
        albumentations.augmentations.transforms.VerticalFlip(p=0.4),
        albumentations.augmentations.transforms.Cutout(p=0.3),  # doesn't work
        albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                        scale_limit=0.1,
                                        rotate_limit=15),
        albumentations.Flip(p=0.5),
        RandomAugMix(severity=7, width=7, alpha=5, p=0.3),
        # albumentations.augmentations.transforms.Resize(
        #    config.input_size, config.input_size, p=1
        # ),
        albumentations.Normalize(mean,
                                 std,
                                 max_pixel_value=255.0,
                                 always_apply=True),
    ])

    valid_aug = albumentations.Compose([
        albumentations.Normalize(mean,
                                 std,
                                 max_pixel_value=255.0,
                                 always_apply=True),
    ])

    train_images = df_train.image_name.values.tolist()
    # train_images = [
    #    os.path.join(config.training_data_path, i + ".jpg") for i in train_images
    # ]
    train_targets = df_train.target.values

    valid_images = df_valid.image_name.values.tolist()
    # valid_images = [
    #    os.path.join(config.training_data_path, i + ".jpg") for i in valid_images
    # ]
    valid_targets = df_valid.target.values

    train_dataset = ClassificationDataset(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug,
        meta_features=meta_features,
        df_meta_features=df_train,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        # num_workers=4,
        num_workers=1,
        pin_memory=True,
        shuffle=True,
        #sampler=BalanceClassSampler(labels=train_targets, mode="upsampling"),
        drop_last=True,
    )

    valid_dataset = ClassificationDataset(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug,
        meta_features=meta_features,
        df_meta_features=df_valid,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.test_batch_size,
        shuffle=False,
        # num_workers=4,
        num_workers=1,
        pin_memory=True,
        # drop_last=True
    )

    #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    optimizer = RAdam(model.parameters(), lr=config.lr)
    if config.swa["use_swa"]:
        optimizer = SWA(optimizer, swa_start=12, swa_freq=1)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=2,
                                                           threshold=0.0001,
                                                           mode="max")
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #    optimizer, len(train_loader) * config.epochs
    # )

    #scheduler = torch.optim.lr_scheduler.CyclicLR(
    #   optimizer,
    #   base_lr=config.lr / 10,
    #   max_lr=config.lr * 100,
    #   mode="triangular2",
    #   cycle_momentum=False,
    #)

    #scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #    optimizer, max_lr=3e-3, steps_per_epoch=len(train_loader), epochs=config.epochs
    #)

    es = EarlyStopping(patience=6, mode="max")
    if config.fp16:
        print("************* using fp16 *************")
        scaler = GradScaler()
    else:
        scaler = False

    for epoch in range(config.epochs):
        train_loss = Engine.train(
            train_loader,
            model,
            optimizer,
            device=config.device,
            wandb=wandb,
            accumulation_steps=config.accumulation_steps,
            fp16=config.fp16,
            scaler=scaler,
        )
        predictions, valid_loss = Engine.evaluate(
            valid_loader,
            model,
            device=config.device,
            wandb=wandb,
            epoch=epoch,
            upload_image=False,
            use_sigmoid=True,
        )
        predictions = np.vstack(predictions).ravel()

        auc = metrics.roc_auc_score(valid_targets, predictions)
        print(f"Epoch = {epoch}, AUC = {auc}")
        wandb.log({
            "valid_auc": auc,
        })

        scheduler.step(auc)

        es(auc, model, model_path=model_path)
        if es.early_stop:
            print("Early stopping")
            break
    if config.swa["use_swa"]:
        print("saving the model using SWA")
        optimizer.swap_swa_sgd()
        torch.save(model.state_dict(), config.swa["model_path"].format(fold))

    evaluate_for_best_epoch(
        fold,
        model_path,
        config.device,
        valid_loader,
        config.model_name,
        valid_targets,
        "final",
        meta_features=meta_features,
    )
    if config.swa["use_swa"]:
        model_path = config.swa["model_path"].format(fold)
        evaluate_for_best_epoch(
            fold,
            model_path,
            config.device,
            valid_loader,
            config.model_name,
            valid_targets,
            "swa",
            meta_features=meta_features,
        )
Example #13
def predict(fold):
    print(f"Prediction on test set fold {fold}")
    args = get_args()
    with open(args.config) as file:
        config_file = yaml.load(file, Loader=yaml.FullLoader)
    config = wandb.config  # Initialize config
    config.update(config_file)

    df_test = pd.read_csv(config.test_csv)
    if config.use_metadata:
        df_test, meta_features = get_meta_feature(df_test)
    else:
        meta_features = None

    if config.swa["use_swa"]:
        model_path = config.swa["model_path"].format(fold)
        print(f"using SWA, loading checkpoint from {model_path}")
    else:
        model_path = config.model_path.format(fold)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # aug = albumentations.Compose(
    #    [albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True)]
    # )
    aug = albumentations.Compose([
        AdvancedHairAugmentation(hairs_folder="../input/melanoma-hairs/"),
        # albumentations.augmentations.transforms.CenterCrop(64, 64, p=0.8),
        albumentations.augmentations.transforms.RandomBrightnessContrast(),
        albumentations.augmentations.transforms.HueSaturationValue(),
        # Microscope(p=0.4),
        albumentations.augmentations.transforms.RandomResizedCrop(
            config.input_size, config.input_size, scale=(0.7, 1.0), p=0.4),
        albumentations.augmentations.transforms.VerticalFlip(p=0.4),
        # albumentations.augmentations.transforms.Cutout(p=0.8), # doesn't work
        albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                        scale_limit=0.1,
                                        rotate_limit=15),
        albumentations.Flip(p=0.5),
        # RandomAugMix(severity=7, width=7, alpha=5, p=1),
        # albumentations.augmentations.transforms.Resize(
        #    config.input_size, config.input_size, p=1
        # ),
        albumentations.Normalize(mean,
                                 std,
                                 max_pixel_value=255.0,
                                 always_apply=True),
    ])

    images = df_test.image_name.values.tolist()
    images = [os.path.join(config.test_data_path, i + ".jpg") for i in images]
    targets = np.zeros(len(images))

    test_dataset = ClassificationDataset(
        image_paths=images,
        targets=targets,
        resize=None,
        augmentations=aug,
        meta_features=meta_features,
        df_meta_features=df_test,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.test_batch_size,
        shuffle=False,
        num_workers=4,
    )

    model = get_model(
        config.model_backbone,
        config.model_name,
        config.num_classes,
        config.input_size,
        config.use_metadata,
        meta_features,
    )
    model.load_state_dict(torch.load(model_path))
    model.to(config.device)

    if config.tta:
        predictions = get_tta_prediction(config.tta, test_loader, model,
                                         config.device, True, len(images))
    else:
        predictions = Engine.predict(test_loader,
                                     model,
                                     device=config.device,
                                     use_sigmoid=True)

        predictions = np.vstack(predictions).ravel()

    return predictions
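
get_tta_prediction is not shown in these examples. Under the assumption that it simply repeats the stochastic-augmentation pass config.tta times and averages the sigmoid outputs, a sketch could look like this (the signature mirrors the call above; the averaging strategy is an assumption):

import numpy as np

def get_tta_prediction(tta_rounds, data_loader, model, device, use_sigmoid, num_images):
    """Average predictions over several passes through a stochastic-augmentation loader."""
    averaged = np.zeros(num_images)
    for _ in range(tta_rounds):
        preds = Engine.predict(data_loader, model, device=device, use_sigmoid=use_sigmoid)
        averaged += np.vstack(preds).ravel()
    return averaged / tta_rounds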