Example #1
def main():
    args = parse_args()
    args.uncropped = True  # override the CLI flag: always run with args.uncropped enabled

    with open('models/detection/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # config["tvec"] = False
    print('-'*20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-'*20)

    cudnn.benchmark = False

    df = pd.read_csv('inputs/sample_submission.csv')
    img_ids = df['ImageId'].values
    img_paths = np.array('inputs/test_images/' + df['ImageId'].values + '.jpg')
    mask_paths = np.array('inputs/test_masks/' + df['ImageId'].values + '.jpg')
    labels = np.array([convert_str_to_labels(s, names=['yaw', 'pitch', 'roll',
                       'x', 'y', 'z', 'score']) for s in df['PredictionString']])

    if not args.uncropped:
        cropped_img_ids = pd.read_csv('inputs/testset_cropped_imageids.csv')['ImageId'].values
        for i, img_id in enumerate(img_ids):
            if img_id in cropped_img_ids:
                img_paths[i] = 'inputs/test_images_uncropped/' + img_id + '.jpg'
                mask_paths[i] = 'inputs/test_masks_uncropped/' + img_id + '.jpg'

    test_set = Dataset(
        img_paths,
        mask_paths,
        labels,
        input_w=config['input_w'],
        input_h=config['input_h'],
        transform=None,
        test=True,
        lhalf=config['lhalf'])
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=16,
        shuffle=False,
        num_workers=0,
        # num_workers=config['num_workers'],
        # pin_memory=True,
    )

    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2
    
    if config['tvec']:
        heads['tvec'] = 3

    name = args.name
    if args.uncropped:
        name += '_uncropped'
    if args.hflip:
        name += '_hf'

    if os.path.exists('outputs/raw/test/%s.pth' %name):
        merged_outputs = torch.load('outputs/raw/test/%s.pth' %name)

    else:
        merged_outputs = {}
        for i in tqdm(range(len(df))):
            img_id = df.loc[i, 'ImageId']

            output = {
                'hm': 0,
                'reg': 0,
                'depth': 0,
                'eular': 0 if config['rot'] == 'eular' else None,
                'trig': 0 if config['rot'] == 'trig' else None,
                'quat': 0 if config['rot'] == 'quat' else None,
                'wh': 0 if config['wh'] else None,
                'tvec': 0 if config['tvec'] else None,
            }

            merged_outputs[img_id] = output

        preds = []
        for fold in range(config['n_splits']):
            print('Fold [%d/%d]' %(fold + 1, config['n_splits']))

            model = get_model(config['arch'], heads=heads,
                              head_conv=config['head_conv'],
                              num_filters=config['num_filters'],
                              dcn=config['dcn'],
                              gn=config['gn'], ws=config['ws'],
                              freeze_bn=config['freeze_bn'])
            model = model.cuda()

            model_path = 'models/detection/%s/model_%d.pth' % (config['name'], fold+1)
            if not os.path.exists(model_path):
                print('%s does not exist.' % model_path)
                continue
            model.load_state_dict(torch.load(model_path))

            model.eval()

            preds_fold = []
            outputs_fold = {}
            with torch.no_grad():
                pbar = tqdm(total=len(test_loader))
                for i, batch in enumerate(test_loader):
                    input = batch['input'].cuda()
                    mask = batch['mask'].cuda()

                    output = model(input)
                    # print(output)

                    if args.hflip:
                        output_hf = model(torch.flip(input, (-1,)))
                        output_hf['hm'] = torch.flip(output_hf['hm'], (-1,))
                        output_hf['reg'] = torch.flip(output_hf['reg'], (-1,))
                        output_hf['reg'][:, 0] = 1 - output_hf['reg'][:, 0]
                        output_hf['depth'] = torch.flip(output_hf['depth'], (-1,))
                        if config['rot'] == 'trig':
                            output_hf['trig'] = torch.flip(output_hf['trig'], (-1,))
                            yaw = torch.atan2(output_hf['trig'][:, 1], output_hf['trig'][:, 0])
                            yaw *= -1.0
                            output_hf['trig'][:, 0] = torch.cos(yaw)
                            output_hf['trig'][:, 1] = torch.sin(yaw)
                            roll = torch.atan2(output_hf['trig'][:, 5], output_hf['trig'][:, 4])
                            roll = rotate(roll, -np.pi)
                            roll *= -1.0
                            roll = rotate(roll, np.pi)
                            output_hf['trig'][:, 4] = torch.cos(roll)
                            output_hf['trig'][:, 5] = torch.sin(roll)

                        if config['wh']:
                            output_hf['wh'] = torch.flip(output_hf['wh'], (-1,))
                        
                        if config['tvec']:
                            output_hf['tvec'] = torch.flip(output_hf['tvec'], (-1,))
                            output_hf['tvec'][:, 0] *= -1.0

                        output['hm'] = (output['hm'] + output_hf['hm']) / 2
                        output['reg'] = (output['reg'] + output_hf['reg']) / 2
                        output['depth'] = (output['depth'] + output_hf['depth']) / 2
                        if config['rot'] == 'trig':
                            output['trig'] = (output['trig'] + output_hf['trig']) / 2
                        if config['wh']:
                            output['wh'] = (output['wh'] + output_hf['wh']) / 2
                        if config['tvec']:
                            output['tvec'] = (output['tvec'] + output_hf['tvec']) / 2

                    for b in range(len(batch['img_path'])):
                        img_id = os.path.splitext(os.path.basename(batch['img_path'][b]))[0]

                        outputs_fold[img_id] = {
                            'hm': output['hm'][b:b+1].cpu(),
                            'reg': output['reg'][b:b+1].cpu(),
                            'depth': output['depth'][b:b+1].cpu(),
                            'eular': output['eular'][b:b+1].cpu() if config['rot'] == 'eular' else None,
                            'trig': output['trig'][b:b+1].cpu() if config['rot'] == 'trig' else None,
                            'quat': output['quat'][b:b+1].cpu() if config['rot'] == 'quat' else None,
                            'wh': output['wh'][b:b+1].cpu() if config['wh'] else None,
                            'tvec': output['tvec'][b:b+1].cpu() if config['tvec'] else None,
                            'mask': mask[b:b+1].cpu(),
                        }

                        merged_outputs[img_id]['hm'] += outputs_fold[img_id]['hm'] / config['n_splits']
                        merged_outputs[img_id]['reg'] += outputs_fold[img_id]['reg'] / config['n_splits']
                        merged_outputs[img_id]['depth'] += outputs_fold[img_id]['depth'] / config['n_splits']
                        if config['rot'] == 'eular':
                            merged_outputs[img_id]['eular'] += outputs_fold[img_id]['eular'] / config['n_splits']
                        if config['rot'] == 'trig':
                            merged_outputs[img_id]['trig'] += outputs_fold[img_id]['trig'] / config['n_splits']
                        if config['rot'] == 'quat':
                            merged_outputs[img_id]['quat'] += outputs_fold[img_id]['quat'] / config['n_splits']
                        if config['wh']:
                            merged_outputs[img_id]['wh'] += outputs_fold[img_id]['wh'] / config['n_splits']
                        if config['tvec']:
                            merged_outputs[img_id]['tvec'] += outputs_fold[img_id]['tvec'] / config['n_splits']
                        merged_outputs[img_id]['mask'] = outputs_fold[img_id]['mask']

                    batch_det = decode(
                        config,
                        output['hm'],
                        output['reg'],
                        output['depth'],
                        eular=output['eular'] if config['rot'] == 'eular' else None,
                        trig=output['trig'] if config['rot'] == 'trig' else None,
                        quat=output['quat'] if config['rot'] == 'quat' else None,
                        wh=output['wh'] if config['wh'] else None,
                        tvec=output['tvec'] if config['tvec'] else None,
                        mask=mask,
                    )
                    batch_det = batch_det.cpu().numpy()

                    for k, det in enumerate(batch_det):
                        if args.nms:
                            det = nms(det, dist_th=args.nms_th)
                        preds_fold.append(convert_labels_to_str(det[det[:, 6] > args.score_th, :7]))

                        if args.show and not config['cv']:
                            img = cv2.imread(batch['img_path'][k])
                            img_pred = visualize(img, det[det[:, 6] > args.score_th])
                            plt.imshow(img_pred[..., ::-1])
                            plt.show()

                    pbar.update(1)
                pbar.close()

            if not config['cv']:
                df['PredictionString'] = preds_fold
                name = '%s_1_%.2f' %(args.name, args.score_th)
                if args.uncropped:
                    name += '_uncropped'
                if args.nms:
                    name += '_nms%.2f' %args.nms_th
                df.to_csv('outputs/submissions/test/%s.csv' %name, index=False)
                return

        if not args.uncropped:
            # ensemble duplicate images
            dup_df = pd.read_csv('processed/test_image_hash.csv')
            dups = dup_df.hash.value_counts()
            dups = dups.loc[dups>1]

            for i in range(len(dups)):
                img_ids = dup_df[dup_df.hash == dups.index[i]].ImageId

                output = {
                    'hm': 0,
                    'reg': 0,
                    'depth': 0,
                    'eular': 0 if config['rot'] == 'eular' else None,
                    'trig': 0 if config['rot'] == 'trig' else None,
                    'quat': 0 if config['rot'] == 'quat' else None,
                    'wh': 0 if config['wh'] else None,
                    'tvec': 0 if config['tvec'] else None,
                    'mask': 0,
                }
                for img_id in img_ids:
                    if img_id in cropped_img_ids:
                        print('duplicate image %s is in the cropped test set' % img_id)
                    output['hm'] += merged_outputs[img_id]['hm'] / len(img_ids)
                    output['reg'] += merged_outputs[img_id]['reg'] / len(img_ids)
                    output['depth'] += merged_outputs[img_id]['depth'] / len(img_ids)
                    if config['rot'] == 'eular':
                        output['eular'] += merged_outputs[img_id]['eular'] / len(img_ids)
                    if config['rot'] == 'trig':
                        output['trig'] += merged_outputs[img_id]['trig'] / len(img_ids)
                    if config['rot'] == 'quat':
                        output['quat'] += merged_outputs[img_id]['quat'] / len(img_ids)
                    if config['wh']:
                        output['wh'] += merged_outputs[img_id]['wh'] / len(img_ids)
                    if config['tvec']:
                        output['tvec'] += merged_outputs[img_id]['tvec'] / len(img_ids)
                    output['mask'] += merged_outputs[img_id]['mask'] / len(img_ids)

                for img_id in img_ids:
                    merged_outputs[img_id] = output

        torch.save(merged_outputs, 'outputs/raw/test/%s.pth' %name)

    # decode
    dets = {}
    for i in tqdm(range(len(df))):
        img_id = df.loc[i, 'ImageId']

        output = merged_outputs[img_id]

        det = decode(
            config,
            output['hm'],
            output['reg'],
            output['depth'],
            eular=output['eular'] if config['rot'] == 'eular' else None,
            trig=output['trig'] if config['rot'] == 'trig' else None,
            quat=output['quat'] if config['rot'] == 'quat' else None,
            wh=output['wh'] if config['wh'] else None,
            tvec=output['tvec'] if config['tvec'] else None,
            mask=output['mask'],
        )
        det = det.numpy()[0]

        dets[img_id] = det.tolist()

        if args.nms:
            det = nms(det, dist_th=args.nms_th)

        if np.sum(det[:, 6] > args.score_th) >= args.min_samples:
            det = det[det[:, 6] > args.score_th]
        else:
            det = det[:args.min_samples]

        if args.show:
            img = cv2.imread('inputs/test_images/%s.jpg' %img_id)
            img_pred = visualize(img, det)
            plt.imshow(img_pred[..., ::-1])
            plt.show()

        df.loc[i, 'PredictionString'] = convert_labels_to_str(det[:, :7])

    with open('outputs/decoded/test/%s.json' %name, 'w') as f:
        json.dump(dets, f)

    name = '%s_%.2f' %(args.name, args.score_th)
    if args.uncropped:
        name += '_uncropped'
    if args.nms:
        name += '_nms%.2f' %args.nms_th
    if args.hflip:
        name += '_hf'
    if args.min_samples > 0:
        name += '_min%d' %args.min_samples
    df.to_csv('outputs/submissions/test/%s.csv' %name, index=False)
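
# The horizontal-flip TTA above adjusts yaw and roll through a `rotate` helper.
# A minimal sketch of what such a helper is assumed to do (shift an angle and
# wrap it back into [-pi, pi)); the project's actual implementation may differ:
def _rotate_sketch(angle, offset):
    """Shift `angle` (radians, torch.Tensor) by `offset` and wrap into [-pi, pi)."""
    return (angle + offset + np.pi) % (2.0 * np.pi) - np.pi
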
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_%s' % (args.arch, datetime.now().strftime('%m%d%H'))

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    if args.resume:
        args = joblib.load('models/%s/args.pkl' % args.name)
        args.resume = True

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    if args.seed is not None and not args.resume:
        print('set random seed')
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    if args.loss == 'BCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss().cuda()
    elif args.loss == 'WeightedBCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.]),
                                      smooth=args.label_smooth).cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'WeightedFocalLoss':
        criterion = FocalLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.])).cuda()
    else:
        raise NotImplementedError

    if args.pred_type == 'all':
        num_outputs = 6
    elif args.pred_type == 'except_any':
        num_outputs = 5
    else:
        raise NotImplementedError

    cudnn.benchmark = True

    # create model
    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p,
                      pooling=args.pooling,
                      lp_p=args.lp_p)
    model = model.cuda()

    train_transform = Compose([
        transforms.Resize(args.img_size, args.img_size),
        transforms.HorizontalFlip() if args.hflip else NoOp(),
        transforms.VerticalFlip() if args.vflip else NoOp(),
        transforms.ShiftScaleRotate(
            shift_limit=args.shift_limit,
            scale_limit=args.scale_limit,
            rotate_limit=args.rotate_limit,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=args.shift_scale_rotate_p
        ) if args.shift_scale_rotate else NoOp(),
        transforms.RandomContrast(
            limit=args.contrast_limit,
            p=args.contrast_p
        ) if args.contrast else NoOp(),
        RandomErase() if args.random_erase else NoOp(),
        transforms.CenterCrop(args.crop_size, args.crop_size) if args.center_crop else NoOp(),
        ForegroundCenterCrop(args.crop_size) if args.foreground_center_crop else NoOp(),
        transforms.RandomCrop(args.crop_size, args.crop_size) if args.random_crop else NoOp(),
        transforms.Normalize(mean=model.mean, std=model.std),
        ToTensor(),
    ])

    if args.img_type:
        stage_1_train_dir = 'processed/stage_1_train_%s' %args.img_type
    else:
        stage_1_train_dir = 'processed/stage_1_train'

    df = pd.read_csv('inputs/stage_1_train.csv')
    img_paths = np.array([stage_1_train_dir + '/' + '_'.join(s.split('_')[:-1]) + '.png' for s in df['ID']][::6])
    labels = np.array([df.loc[c::6, 'Label'].values for c in range(6)]).T.astype('float32')

    df = df[::6]
    df['img_path'] = img_paths
    for c in range(6):
        df['label_%d' %c] = labels[:, c]
    df['ID'] = df['ID'].apply(lambda s: '_'.join(s.split('_')[:-1]))

    meta_df = pd.read_csv('processed/stage_1_train_meta.csv')
    meta_df['ID'] = meta_df['SOPInstanceUID']
    test_meta_df = pd.read_csv('processed/stage_1_test_meta.csv')
    df = pd.merge(df, meta_df, how='left')

    patient_ids = meta_df['PatientID'].unique()
    test_patient_ids = test_meta_df['PatientID'].unique()
    if args.remove_test_patient_ids:
        patient_ids = np.array([s for s in patient_ids if not s in test_patient_ids])

    train_img_paths = np.hstack(df[['img_path', 'PatientID']].groupby(['PatientID'])['img_path'].apply(np.array).loc[patient_ids].to_list()).astype('str')
    train_labels = []
    for c in range(6):
        train_labels.append(np.hstack(df[['label_%d' %c, 'PatientID']].groupby(['PatientID'])['label_%d' %c].apply(np.array).loc[patient_ids].to_list()))
    train_labels = np.array(train_labels).T

    if args.resume:
        checkpoint = torch.load('models/%s/checkpoint.pth.tar' % args.name)

    # train
    train_set = Dataset(
        train_img_paths,
        train_labels,
        transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # pin_memory=True,
    )

    if args.optimizer == 'Adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamW':
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'RAdam':
        optimizer = RAdam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                              momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
    else:
        raise NotImplementedError

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    if args.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr)
    elif args.scheduler == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer,
            milestones=[int(e) for e in args.milestones.split(',')], gamma=args.gamma)
    else:
        raise NotImplementedError

    log = {
        'epoch': [],
        'loss': [],
    }

    start_epoch = 0

    if args.resume:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']
        log = pd.read_csv('models/%s/log.csv' % args.name).to_dict(orient='list')

    for epoch in range(start_epoch, args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        # train for one epoch
        train_loss = train(args, train_loader, model, criterion, optimizer, epoch)

        if args.scheduler == 'CosineAnnealingLR':
            scheduler.step()

        print('loss %.4f' % (train_loss))

        log['epoch'].append(epoch)
        log['loss'].append(train_loss)

        pd.DataFrame(log).to_csv('models/%s/log.csv' % args.name, index=False)

        torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
        print("=> saved model")

        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(state, 'models/%s/checkpoint.pth.tar' % args.name)
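
# The label construction above assumes stage_1_train.csv lists six consecutive
# rows per image (one row per hemorrhage subtype), so that df.loc[c::6, 'Label']
# picks out the c-th subtype for every image. A toy check of that slicing
# pattern (the IDs and label values here are made up for illustration):
_toy = pd.DataFrame({
    'ID': ['img0_sub%d' % c for c in range(6)] + ['img1_sub%d' % c for c in range(6)],
    'Label': [0, 1, 0, 0, 1, 1] + [0, 0, 0, 0, 0, 0],
})
_toy_labels = np.array([_toy.loc[c::6, 'Label'].values for c in range(6)]).T
assert _toy_labels.shape == (2, 6)  # one row per image, one column per subtype
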
Example #3
def main():
    args = parse_args()

    with open('models/detection/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values +
                         '.jpg')
    mask_paths = np.array('inputs/train_masks/' + df['ImageId'].values +
                          '.jpg')
    labels = np.array(
        [convert_str_to_labels(s) for s in df['PredictionString']])

    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    pred_df = df.copy()
    pred_df['PredictionString'] = np.nan

    dets = {}
    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        train_img_paths, val_img_paths = img_paths[train_idx], img_paths[
            val_idx]
        train_mask_paths, val_mask_paths = mask_paths[train_idx], mask_paths[
            val_idx]
        train_labels, val_labels = labels[train_idx], labels[val_idx]

        val_set = Dataset(val_img_paths,
                          val_mask_paths,
                          val_labels,
                          input_w=config['input_w'],
                          input_h=config['input_h'],
                          transform=None,
                          lhalf=config['lhalf'])
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        model = get_model(config['arch'],
                          heads=heads,
                          head_conv=config['head_conv'],
                          num_filters=config['num_filters'],
                          dcn=config['dcn'],
                          gn=config['gn'],
                          ws=config['ws'],
                          freeze_bn=config['freeze_bn'])
        model = model.cuda()

        model_path = 'models/detection/%s/model_%d.pth' % (config['name'],
                                                           fold + 1)
        if not os.path.exists(model_path):
            print('%s does not exist.' % model_path)
            continue
        model.load_state_dict(torch.load(model_path))

        model.eval()

        outputs = {}

        with torch.no_grad():
            pbar = tqdm(total=len(val_loader))
            for i, batch in enumerate(val_loader):
                input = batch['input'].cuda()
                mask = batch['mask'].cuda()
                hm = batch['hm'].cuda()
                reg_mask = batch['reg_mask'].cuda()

                output = model(input)

                if args.hflip:
                    output_hf = model(torch.flip(input, (-1, )))
                    output_hf['hm'] = torch.flip(output_hf['hm'], (-1, ))
                    output_hf['reg'] = torch.flip(output_hf['reg'], (-1, ))
                    output_hf['reg'][:, 0] = 1 - output_hf['reg'][:, 0]
                    output_hf['depth'] = torch.flip(output_hf['depth'], (-1, ))
                    if config['rot'] == 'trig':
                        output_hf['trig'] = torch.flip(output_hf['trig'],
                                                       (-1, ))
                        yaw = torch.atan2(output_hf['trig'][:, 1],
                                          output_hf['trig'][:, 0])
                        yaw *= -1.0
                        output_hf['trig'][:, 0] = torch.cos(yaw)
                        output_hf['trig'][:, 1] = torch.sin(yaw)
                        roll = torch.atan2(output_hf['trig'][:, 5],
                                           output_hf['trig'][:, 4])
                        roll = rotate(roll, -np.pi)
                        roll *= -1.0
                        roll = rotate(roll, np.pi)
                        output_hf['trig'][:, 4] = torch.cos(roll)
                        output_hf['trig'][:, 5] = torch.sin(roll)

                    if config['wh']:
                        output_hf['wh'] = torch.flip(output_hf['wh'], (-1, ))

                    output['hm'] = (output['hm'] + output_hf['hm']) / 2
                    output['reg'] = (output['reg'] + output_hf['reg']) / 2
                    output['depth'] = (output['depth'] +
                                       output_hf['depth']) / 2
                    if config['rot'] == 'trig':
                        output['trig'] = (output['trig'] +
                                          output_hf['trig']) / 2
                    if config['wh']:
                        output['wh'] = (output['wh'] + output_hf['wh']) / 2

                batch_det = decode(
                    config,
                    output['hm'],
                    output['reg'],
                    output['depth'],
                    eular=output['eular']
                    if config['rot'] == 'eular' else None,
                    trig=output['trig'] if config['rot'] == 'trig' else None,
                    quat=output['quat'] if config['rot'] == 'quat' else None,
                    wh=output['wh'] if config['wh'] else None,
                    mask=mask,
                )
                batch_det = batch_det.cpu().numpy()

                for k, det in enumerate(batch_det):
                    img_id = os.path.splitext(
                        os.path.basename(batch['img_path'][k]))[0]

                    outputs[img_id] = {
                        'hm':
                        output['hm'][k:k + 1].cpu(),
                        'reg':
                        output['reg'][k:k + 1].cpu(),
                        'depth':
                        output['depth'][k:k + 1].cpu(),
                        'eular':
                        output['eular'][k:k + 1].cpu()
                        if config['rot'] == 'eular' else None,
                        'trig':
                        output['trig'][k:k + 1].cpu()
                        if config['rot'] == 'trig' else None,
                        'quat':
                        output['quat'][k:k + 1].cpu()
                        if config['rot'] == 'quat' else None,
                        'wh':
                        output['wh'][k:k + 1].cpu() if config['wh'] else None,
                        'mask':
                        mask[k:k + 1].cpu(),
                    }

                    dets[img_id] = det.tolist()
                    if args.nms:
                        det = nms(det, dist_th=args.nms_th)
                    pred_df.loc[pred_df.ImageId == img_id,
                                'PredictionString'] = convert_labels_to_str(
                                    det[det[:, 6] > args.score_th, :7])

                    if args.show:
                        gt = batch['gt'].numpy()[k]

                        img = cv2.imread(batch['img_path'][k])
                        img_gt = visualize(img, gt[gt[:, -1] > 0])
                        img_pred = visualize(img,
                                             det[det[:, 6] > args.score_th])

                        plt.subplot(121)
                        plt.imshow(img_gt[..., ::-1])
                        plt.subplot(122)
                        plt.imshow(img_pred[..., ::-1])
                        plt.show()

                pbar.update(1)
            pbar.close()

        torch.save(outputs,
                   'outputs/raw/val/%s_%d.pth' % (args.name, fold + 1))

        torch.cuda.empty_cache()

        if not config['cv']:
            break

    with open('outputs/decoded/val/%s.json' % args.name, 'w') as f:
        json.dump(dets, f)

    name = '%s_%.2f' % (args.name, args.score_th)
    if args.nms:
        name += '_nms%.2f' % args.nms_th
    if args.hflip:
        name += '_hf'
    pred_df.to_csv('outputs/submissions/val/%s.csv' % name, index=False)
    print(pred_df.head())
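
# convert_str_to_labels / convert_labels_to_str above are assumed to round-trip
# the competition PredictionString format: a flat, space-separated sequence with
# seven values per object. A minimal sketch under that assumption (not the
# project's actual implementation):
def _str_to_labels_sketch(s, num_fields=7):
    vals = np.array(s.split(), dtype='float32')
    return vals.reshape(-1, num_fields)

def _labels_to_str_sketch(labels):
    return ' '.join(str(v) for v in np.asarray(labels).reshape(-1))
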
def main():
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = '%s_%s' % (config['arch'],
                                    datetime.now().strftime('%m%d%H'))

    config['num_filters'] = [int(n) for n in config['num_filters'].split(',')]

    if not os.path.exists('models/detection/%s' % config['name']):
        os.makedirs('models/detection/%s' % config['name'])

    if config['resume']:
        with open('models/detection/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config['resume'] = True

    with open('models/detection/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-' * 20)
    for key in config.keys():
        print('- %s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values +
                         '.jpg')
    mask_paths = np.array('inputs/train_masks/' + df['ImageId'].values +
                          '.jpg')
    labels = np.array(
        [convert_str_to_labels(s) for s in df['PredictionString']])

    test_img_paths = None
    test_mask_paths = None
    test_outputs = None
    if config['pseudo_label'] is not None:
        test_df = pd.read_csv('inputs/sample_submission.csv')
        test_img_paths = np.array('inputs/test_images/' +
                                  test_df['ImageId'].values + '.jpg')
        test_mask_paths = np.array('inputs/test_masks/' +
                                   test_df['ImageId'].values + '.jpg')
        ext = os.path.splitext(config['pseudo_label'])[1]
        if ext == '.pth':
            test_outputs = torch.load('outputs/raw/test/%s' %
                                      config['pseudo_label'])
        elif ext == '.csv':
            test_labels = pd.read_csv('outputs/submissions/test/%s' %
                                      config['pseudo_label'])
            null_idx = test_labels.isnull().any(axis=1)
            test_img_paths = test_img_paths[~null_idx]
            test_mask_paths = test_mask_paths[~null_idx]
            test_labels = test_labels.dropna()
            test_labels = np.array([
                convert_str_to_labels(
                    s, names=['pitch', 'yaw', 'roll', 'x', 'y', 'z', 'score'])
                for s in test_labels['PredictionString']
            ])
            print(test_labels)
        else:
            raise NotImplementedError

    if config['resume']:
        checkpoint = torch.load('models/detection/%s/checkpoint.pth.tar' %
                                config['name'])

    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    criterion = OrderedDict()
    for head in heads.keys():
        criterion[head] = losses.__dict__[config[head + '_loss']]().cuda()

    train_transform = Compose([
        transforms.ShiftScaleRotate(shift_limit=config['shift_limit'],
                                    scale_limit=0,
                                    rotate_limit=0,
                                    border_mode=cv2.BORDER_CONSTANT,
                                    value=0,
                                    p=config['shift_p'])
        if config['shift'] else NoOp(),
        OneOf([
            transforms.HueSaturationValue(hue_shift_limit=config['hue_limit'],
                                          sat_shift_limit=config['sat_limit'],
                                          val_shift_limit=config['val_limit'],
                                          p=config['hsv_p'])
            if config['hsv'] else NoOp(),
            transforms.RandomBrightness(
                limit=config['brightness_limit'],
                p=config['brightness_p'],
            ) if config['brightness'] else NoOp(),
            transforms.RandomContrast(
                limit=config['contrast_limit'],
                p=config['contrast_p'],
            ) if config['contrast'] else NoOp(),
        ],
              p=1),
        transforms.ISONoise(p=config['iso_noise_p'], )
        if config['iso_noise'] else NoOp(),
        transforms.CLAHE(p=config['clahe_p'], ) if config['clahe'] else NoOp(),
    ],
                              keypoint_params=KeypointParams(
                                  format='xy', remove_invisible=False))

    val_transform = None

    folds = []
    best_losses = []
    # best_scores = []

    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        if (config['resume'] and fold < checkpoint['fold'] - 1) or (
                not config['resume']
                and os.path.exists('models/detection/%s/model_%d.pth' %
                                   (config['name'], fold + 1))):
            log = pd.read_csv('models/detection/%s/log_%d.csv' %
                              (config['name'], fold + 1))
            best_loss = log.loc[log['val_loss'].values.argmin(), 'val_loss']
            # best_loss, best_score = log.loc[log['val_loss'].values.argmin(), ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            # best_scores.append(best_score)
            continue

        train_img_paths, val_img_paths = img_paths[train_idx], img_paths[
            val_idx]
        train_mask_paths, val_mask_paths = mask_paths[train_idx], mask_paths[
            val_idx]
        train_labels, val_labels = labels[train_idx], labels[val_idx]

        if config['pseudo_label'] is not None:
            train_img_paths = np.hstack((train_img_paths, test_img_paths))
            train_mask_paths = np.hstack((train_mask_paths, test_mask_paths))
            train_labels = np.hstack((train_labels, test_labels))

        # train
        train_set = Dataset(
            train_img_paths,
            train_mask_paths,
            train_labels,
            input_w=config['input_w'],
            input_h=config['input_h'],
            transform=train_transform,
            lhalf=config['lhalf'],
            hflip=config['hflip_p'] if config['hflip'] else 0,
            scale=config['scale_p'] if config['scale'] else 0,
            scale_limit=config['scale_limit'],
            # test_img_paths=test_img_paths,
            # test_mask_paths=test_mask_paths,
            # test_outputs=test_outputs,
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        val_set = Dataset(val_img_paths,
                          val_mask_paths,
                          val_labels,
                          input_w=config['input_w'],
                          input_h=config['input_h'],
                          transform=val_transform,
                          lhalf=config['lhalf'])
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        # create model
        model = get_model(config['arch'],
                          heads=heads,
                          head_conv=config['head_conv'],
                          num_filters=config['num_filters'],
                          dcn=config['dcn'],
                          gn=config['gn'],
                          ws=config['ws'],
                          freeze_bn=config['freeze_bn'])
        model = model.cuda()

        if config['load_model'] is not None:
            model.load_state_dict(
                torch.load('models/detection/%s/model_%d.pth' %
                           (config['load_model'], fold + 1)))

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'AdamW':
            optimizer = optim.AdamW(params,
                                    lr=config['lr'],
                                    weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'RAdam':
            optimizer = RAdam(params,
                              lr=config['lr'],
                              weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['apex']:
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=config['factor'],
                patience=config['patience'],
                verbose=1,
                min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        else:
            raise NotImplementedError

        log = {
            'epoch': [],
            'loss': [],
            # 'score': [],
            'val_loss': [],
            # 'val_score': [],
        }

        best_loss = float('inf')
        # best_score = float('inf')

        start_epoch = 0

        if config['resume'] and fold == checkpoint['fold'] - 1:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch']
            log = pd.read_csv(
                'models/detection/%s/log_%d.csv' %
                (config['name'], fold + 1)).to_dict(orient='list')
            best_loss = checkpoint['best_loss']

        for epoch in range(start_epoch, config['epochs']):
            print('Epoch [%d/%d]' % (epoch + 1, config['epochs']))

            # train for one epoch
            train_loss = train(config, heads, train_loader, model, criterion,
                               optimizer, epoch)
            # evaluate on validation set
            val_loss = validate(config, heads, val_loader, model, criterion)

            if config['scheduler'] == 'CosineAnnealingLR':
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            print('loss %.4f - val_loss %.4f' % (train_loss, val_loss))
            # print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
            #       % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            # log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            # log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/detection/%s/log_%d.csv' %
                                     (config['name'], fold + 1),
                                     index=False)

            if val_loss < best_loss:
                torch.save(
                    model.state_dict(), 'models/detection/%s/model_%d.pth' %
                    (config['name'], fold + 1))
                best_loss = val_loss
                # best_score = val_score
                print("=> saved best model")

            state = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(
                state,
                'models/detection/%s/checkpoint.pth.tar' % config['name'])

        print('val_loss:  %f' % best_loss)
        # print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        # best_scores.append(best_score)

        results = pd.DataFrame({
            'fold':
            folds + ['mean'],
            'best_loss':
            best_losses + [np.mean(best_losses)],
            # 'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/detection/%s/results.csv' % config['name'],
                       index=False)

        del model
        torch.cuda.empty_cache()

        del train_set, train_loader
        del val_set, val_loader
        gc.collect()

        if not config['cv']:
            break
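
# Both the training loop above and the validation script earlier build their
# folds with KFold(n_splits=config['n_splits'], shuffle=True, random_state=41),
# relying on the split being deterministic for a fixed seed and input length.
# A quick illustrative check of that determinism:
_idx = np.arange(100)
_kf_a = [v for _, v in KFold(n_splits=5, shuffle=True, random_state=41).split(_idx)]
_kf_b = [v for _, v in KFold(n_splits=5, shuffle=True, random_state=41).split(_idx)]
assert all(np.array_equal(a, b) for a, b in zip(_kf_a, _kf_b))
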
Example #5
def main():
    test_args = parse_args()

    args = joblib.load('models/%s/args.pkl' % test_args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    if args.pred_type == 'classification':
        num_outputs = 5
    elif args.pred_type == 'regression':
        num_outputs = 1
    elif args.pred_type == 'multitask':
        num_outputs = 6
    else:
        raise NotImplementedError

    cudnn.benchmark = True

    test_transform = transforms.Compose([
        transforms.Resize(args.input_size),  # a single int resizes the shorter edge, keeping aspect ratio
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # data loading code
    test_dir = preprocess('test',
                          args.img_size,
                          scale=args.scale_radius,
                          norm=args.normalize,
                          pad=args.padding,
                          remove=args.remove)
    test_df = pd.read_csv('inputs/test.csv')
    test_img_paths = test_dir + '/' + test_df['id_code'].values + '.png'
    test_labels = np.zeros(len(test_img_paths))

    test_set = Dataset(test_img_paths, test_labels, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4)

    preds = []
    for fold in range(args.n_splits):
        print('Fold [%d/%d]' % (fold + 1, args.n_splits))

        # create model
        model_path = 'models/%s/model_%d.pth' % (args.name, fold + 1)
        if not os.path.exists(model_path):
            print('%s does not exist.' % model_path)
            continue
        model = get_model(model_name=args.arch,
                          num_outputs=num_outputs,
                          freeze_bn=args.freeze_bn,
                          dropout_p=args.dropout_p)
        model = model.cuda()
        model.load_state_dict(torch.load(model_path))

        model.eval()

        preds_fold = []
        with torch.no_grad():
            for i, (input, _) in tqdm(enumerate(test_loader),
                                      total=len(test_loader)):
                if test_args.tta:
                    outputs = []
                    for input in apply_tta(input):
                        input = input.cuda()
                        output = model(input)
                        outputs.append(output.data.cpu().numpy()[:, 0])
                    preds_fold.extend(np.mean(outputs, axis=0))
                else:
                    input = input.cuda()
                    output = model(input)

                    preds_fold.extend(output.data.cpu().numpy()[:, 0])
        preds_fold = np.array(preds_fold)
        preds.append(preds_fold)

        if not args.cv:
            break

    preds = np.mean(preds, axis=0)

    if test_args.tta:
        args.name += '_tta'

    test_df['diagnosis'] = preds
    test_df.to_csv('probs/%s.csv' % args.name, index=False)

    thrs = [0.5, 1.5, 2.5, 3.5]
    preds[preds < thrs[0]] = 0
    preds[(preds >= thrs[0]) & (preds < thrs[1])] = 1
    preds[(preds >= thrs[1]) & (preds < thrs[2])] = 2
    preds[(preds >= thrs[2]) & (preds < thrs[3])] = 3
    preds[preds >= thrs[3]] = 4
    preds = preds.astype('int')

    test_df['diagnosis'] = preds
    test_df.to_csv('submissions/%s.csv' % args.name, index=False)
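
# The four thresholds above map the regression output to diagnosis grades 0-4.
# An equivalent formulation with np.digitize, using the same bin edges
# (shown only as an illustration of the mapping):
_demo = np.array([0.2, 0.7, 1.6, 2.9, 3.4, 4.2])
_grades = np.digitize(_demo, bins=[0.5, 1.5, 2.5, 3.5])
assert _grades.tolist() == [0, 1, 2, 3, 3, 4]
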
Example #6
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_%s' % (args.mode, args.arch)
    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # enable cudnn benchmark mode: slightly faster training, with small
    # run-to-run fluctuations in the forward results
    cudnn.benchmark = True
    # deterministic mode, more stable/reproducible:
    # cudnn.deterministic = True

    img_path, img_labels, num_outputs = img_path_generator(
        dataset=args.train_dataset)
    if args.pred_type == 'regression':
        num_outputs = 1

    skf = StratifiedKFold(n_splits=args.n_splits,
                          shuffle=True,
                          random_state=30)
    img_paths = []
    labels = []
    for fold, (train_idx,
               val_idx) in enumerate(skf.split(img_path, img_labels)):
        img_paths.append((img_path[train_idx], img_path[val_idx]))
        labels.append((img_labels[train_idx], img_labels[val_idx]))

    train_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        # transforms.RandomAffine(
        #     degrees=(args.rotate_min, args.rotate_max) if args.rotate else 0,
        #     translate=(args.translate_min, args.translate_max) if args.translate else None,
        #     scale=(args.rescale_min, args.rescale_max) if args.rescale else None,
        #     shear=(args.shear_min, args.shear_max) if args.shear else None,
        # ),
        transforms.RandomCrop(args.input_size, padding=4),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
        # transforms.ColorJitter(
        #     brightness=0,
        #     contrast=args.contrast,
        #     saturation=0,
        #     hue=0),
        RandomErase(prob=args.random_erase_prob if args.random_erase else 0,
                    sl=args.random_erase_sl,
                    sh=args.random_erase_sh,
                    r=args.random_erase_r),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if args.loss == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'MSELoss':
        criterion = nn.MSELoss().cuda()
    elif args.loss == 'LabelSmoothingLoss':
        criterion = LabelSmoothingLoss(classes=num_outputs,
                                       smoothing=0.8).cuda()
    else:
        raise NotImplementedError

    folds = []
    best_losses = []
    best_ac_scores = []
    best_epochs = []

    for fold, ((train_img_paths, val_img_paths),
               (train_labels, val_labels)) in enumerate(zip(img_paths,
                                                            labels)):
        print('Fold [%d/%d]' % (fold + 1, len(img_paths)))

        # if os.path.exists('models/%s/model_%d.pth' % (args.name, fold+1)):
        #     log = pd.read_csv('models/%s/log_%d.csv' % (args.name, fold+1))
        #     best_loss, best_ac_score = log.loc[log['val_loss'].values.argmin(
        #     ), ['val_loss', 'val_score', 'val_ac_score']].values
        #     folds.append(str(fold + 1))
        #     best_losses.append(best_loss)
        #     best_ac_scores.append(best_ac_score)
        #     continue

        # train
        train_set = Dataset(train_img_paths,
                            train_labels,
                            transform=train_transform)

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   sampler=None)

        val_set = Dataset(val_img_paths, val_labels, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=0)

        # create model
        if args.mode == 'baseline':
            model = get_model(model_name=args.arch,
                              num_outputs=num_outputs,
                              freeze_bn=args.freeze_bn,
                              dropout_p=args.dropout_p)
        elif args.mode == 'gcn':
            model_path = 'models/%s/model_%d.pth' % ('baseline_' + args.arch,
                                                     fold + 1)
            if not os.path.exists(model_path):
                print('%s does not exist' % model_path)
                continue
            model = SoftLabelGCN(cnn_model_name=args.arch,
                                 cnn_pretrained=False,
                                 num_outputs=num_outputs)
            pretrained_dict = torch.load(model_path)
            model_dict = model.cnn.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model.cnn.load_state_dict(model_dict)
            for p in model.cnn.parameters():
                p.requires_grad = False
        else:
            # model = RA(cnn_model_name=args.arch, input_size=args.input_size, hidden_size=args.lstm_hidden,
            #            layer_num=args.lstm_layers, recurrent_num=args.lstm_recurrence, class_num=num_outputs, pretrain=True)
            model_path = 'models/%s/model_%d.pth' % ('baseline_' + args.arch,
                                                     fold + 1)
            if not os.path.exists(model_path):
                print('%s does not exist' % model_path)
                continue
            model = RA(cnn_model_name=args.arch,
                       input_size=args.input_size,
                       hidden_size=args.lstm_hidden,
                       layer_num=args.lstm_layers,
                       recurrent_num=args.lstm_recurrence,
                       class_num=num_outputs)
            pretrained_dict = torch.load(model_path)
            model_dict = model.cnn.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model.cnn.load_state_dict(model_dict)
            for p in model.cnn.parameters():
                p.requires_grad = False

        device = torch.device('cuda')
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)
        # model = model.cuda()
        if args.pretrained_model is not None:
            model.load_state_dict(
                torch.load('models/%s/model_%d.pth' %
                           (args.pretrained_model, fold + 1)))

        # print(model)

        if args.optimizer == 'Adam':
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          model.parameters()),
                                   lr=args.lr)
        elif args.optimizer == 'AdamW':
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr)
        elif args.optimizer == 'RAdam':
            optimizer = RAdam(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.lr)
        elif args.optimizer == 'SGD':
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay,
                                  nesterov=args.nesterov)
            # optimizer = optim.SGD(model.get_config_optim(args.lr, args.lr * 10, args.lr * 10), lr=args.lr,
            #                       momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

        if args.scheduler == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=args.epochs,
                                                       eta_min=args.min_lr)
        elif args.scheduler == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=args.factor,
                                                       patience=args.patience,
                                                       verbose=1,
                                                       min_lr=args.min_lr)

        log = {
            'epoch': [],
            'loss': [],
            'ac_score': [],
            'val_loss': [],
            'val_ac_score': [],
        }

        best_loss = float('inf')
        best_ac_score = 0
        best_epoch = 0

        for epoch in range(args.epochs):
            print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

            # train for one epoch
            train_loss, train_ac_score = train(args, train_loader, model,
                                               criterion, optimizer, epoch)

            # evaluate on validation set
            val_loss, val_ac_score = validate(args, val_loader, model,
                                              criterion)

            if args.scheduler == 'CosineAnnealingLR':
                scheduler.step()
            elif args.scheduler == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            print(
                'loss %.4f - ac_score %.4f - val_loss %.4f - val_ac_score %.4f'
                % (train_loss, train_ac_score, val_loss, val_ac_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            log['ac_score'].append(train_ac_score)
            log['val_loss'].append(val_loss)
            log['val_ac_score'].append(val_ac_score)

            pd.DataFrame(log).to_csv('models/%s/log_%d.csv' %
                                     (args.name, fold + 1),
                                     index=False)

            if val_ac_score > best_ac_score:
                if args.mode == 'baseline':
                    torch.save(
                        model.state_dict(),
                        'models/%s/model_%d.pth' % (args.name, fold + 1))
                best_loss = val_loss
                best_ac_score = val_ac_score
                best_epoch = epoch
                print("=> saved best model")

        print('val_loss:  %f' % best_loss)
        print('val_ac_score: %f' % best_ac_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        best_ac_scores.append(best_ac_score)
        best_epochs.append(best_epoch)

        results = pd.DataFrame({
            'fold': folds + ['mean'],
            'best_loss': best_losses + [np.mean(best_losses)],
            'best_ac_score': best_ac_scores + [np.mean(best_ac_scores)],
            'best_epoch': best_epochs + [''],
        })

        print(results)
        results.to_csv('models/%s/results.csv' % args.name, index=False)
        torch.cuda.empty_cache()

        if not args.cv:
            break
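
The snippet above loads a CNN checkpoint into `model.cnn` by filtering the state dict on matching keys, freezes the backbone, and then builds the optimizer over `filter(lambda p: p.requires_grad, ...)` so only the remaining parameters are trained. A minimal, self-contained sketch of that pattern; the `TinyModel` backbone/head pair is hypothetical and merely stands in for the RA model:

import torch.nn as nn
import torch.optim as optim


class TinyModel(nn.Module):
    # hypothetical stand-in for RA: a "cnn" backbone plus a trainable head
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.cnn(x))


model = TinyModel()
checkpoint = TinyModel().cnn.state_dict()  # stands in for torch.load(model_path)

# partial load: keep only checkpoint entries whose keys exist in the backbone
backbone_dict = model.cnn.state_dict()
backbone_dict.update({k: v for k, v in checkpoint.items() if k in backbone_dict})
model.cnn.load_state_dict(backbone_dict)

# freeze the backbone; the optimizer then sees only the head's parameters
for p in model.cnn.parameters():
    p.requires_grad = False
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)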
Example #7
0
def main():
    test_args = parse_args()

    args = joblib.load('models/%s/args.pkl' % test_args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    if args.pred_type == 'all':
        num_outputs = 6
    elif args.pred_type == 'except_any':
        num_outputs = 5
    else:
        raise NotImplementedError

    # create model
    model_path = 'models/%s/model.pth' % args.name
    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p,
                      pooling=args.pooling,
                      lp_p=args.lp_p)
    model = model.cuda()
    model.load_state_dict(torch.load(model_path))

    model.eval()

    cudnn.benchmark = True

    test_transform = Compose([
        transforms.Resize(args.img_size, args.img_size),
        ForegroundCenterCrop(args.crop_size),
        transforms.Normalize(mean=model.mean, std=model.std),
        ToTensor(),
    ])

    # data loading code
    # if args.img_type:
    #     stage_1_test_dir = 'processed/stage_1_test_%s' %args.img_type
    # else:
    #     stage_1_test_dir = 'processed/stage_1_test'
    if args.img_type:
        stage_2_test_dir = 'processed/stage_2_test_%s' % args.img_type
    else:
        stage_2_test_dir = 'processed/stage_2_test'

    # test_df = pd.read_csv('inputs/stage_1_sample_submission.csv')
    test_df = pd.read_csv('inputs/stage_2_sample_submission.csv')
    # test_img_paths = np.array([stage_1_test_dir + '/' + '_'.join(s.split('_')[:-1]) + '.png' for s in test_df['ID']][::6])
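    # the submission lists 6 rows per image (one per subtype), so take every
    # 6th ID to recover one image path per study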
    test_img_paths = np.array([
        stage_2_test_dir + '/' + '_'.join(s.split('_')[:-1]) + '.png'
        for s in test_df['ID']
    ][::6])
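    # gather the 6 per-subtype labels of each image into an (N, 6) float array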
    test_labels = np.array([
        test_df.loc[c::6, 'Label'].values for c in range(6)
    ]).T.astype('float32')

    test_set = Dataset(test_img_paths, test_labels, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)

    preds = []
    preds_fold = []
    with torch.no_grad():
        for i, (input, _) in tqdm(enumerate(test_loader),
                                  total=len(test_loader)):
            outputs = []
            for input in apply_tta(args, input):
                input = input.cuda()
                output = model(input)
                output = torch.sigmoid(output)
                if args.pred_type == 'except_any':
                    output = torch.cat(
                        [output, torch.max(output, 1, keepdim=True)[0]], 1)
                outputs.append(output.data.cpu().numpy())
            preds_fold.extend(np.mean(outputs, axis=0))
    preds_fold = np.vstack(preds_fold)
    preds.append(preds_fold)

    preds = np.mean(preds, axis=0)

    test_df = pd.DataFrame(preds,
                           columns=[
                               'epidural', 'intraparenchymal',
                               'intraventricular', 'subarachnoid', 'subdural',
                               'any'
                           ])
    test_df['ID'] = test_img_paths

    # Unpivot table, i.e. wide (N x 6) to long format (6N x 1)
    test_df = test_df.melt(id_vars=['ID'])

    # Combine the filename column with the variable column
    test_df['ID'] = test_df.ID.apply(lambda x: os.path.basename(x).replace(
        '.png', '')) + '_' + test_df.variable
    test_df['Label'] = test_df['value']

    if test_args.hflip:
        args.name += '_hflip'
    test_df[['ID', 'Label']].to_csv('submissions/%s.csv' % args.name,
                                    index=False)
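
`apply_tta` is not defined in this snippet; the loop above simply averages the sigmoid outputs over whatever views it returns. A minimal sketch, assuming a horizontal-flip-only TTA gated by an `hflip` flag (the helper actually used may apply more augmentations):

import torch


def apply_tta(args, input):
    # sketch: return the original NCHW batch plus, optionally, its horizontal flip
    views = [input]
    if getattr(args, 'hflip', False):
        views.append(torch.flip(input, dims=[3]))  # flip along the width axis
    return views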
Example #8
0
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_%s' % (args.arch, datetime.now().strftime('%m%d%H'))

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    if args.loss == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'MSELoss':
        criterion = nn.MSELoss().cuda()
    elif args.loss == 'multitask':
        criterion = {
            'classification': nn.CrossEntropyLoss().cuda(),
            'regression': nn.MSELoss().cuda(),
        }
    else:
        raise NotImplementedError

    if args.pred_type == 'classification':
        num_outputs = 5
    elif args.pred_type == 'regression':
        num_outputs = 1
    elif args.loss == 'multitask':
        num_outputs = 6
    else:
        raise NotImplementedError

    cudnn.benchmark = True

    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p)

    train_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.RandomAffine(
            degrees=(args.rotate_min, args.rotate_max) if args.rotate else 0,
            translate=(args.translate_min, args.translate_max) if args.translate else None,
            scale=(args.rescale_min, args.rescale_max) if args.rescale else None,
            shear=(args.shear_min, args.shear_max) if args.shear else None,
        ),
        transforms.CenterCrop(args.input_size),
        transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
        transforms.RandomVerticalFlip(p=0.5 if args.flip else 0),
        transforms.ColorJitter(
            brightness=0,
            contrast=args.contrast,
            saturation=0,
            hue=0),
        RandomErase(
            prob=args.random_erase_prob if args.random_erase else 0,
            sl=args.random_erase_sl,
            sh=args.random_erase_sh,
            r=args.random_erase_r),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # data loading code
    if 'diabetic_retinopathy' in args.train_dataset:
        diabetic_retinopathy_dir = preprocess(
            'diabetic_retinopathy',
            args.img_size,
            scale=args.scale_radius,
            norm=args.normalize,
            pad=args.padding,
            remove=args.remove)
        diabetic_retinopathy_df = pd.read_csv('inputs/diabetic-retinopathy-resized/trainLabels.csv')
        diabetic_retinopathy_img_paths = \
            diabetic_retinopathy_dir + '/' + diabetic_retinopathy_df['image'].values + '.jpeg'
        diabetic_retinopathy_labels = diabetic_retinopathy_df['level'].values

    if 'aptos2019' in args.train_dataset:
        aptos2019_dir = preprocess(
            'aptos2019',
            args.img_size,
            scale=args.scale_radius,
            norm=args.normalize,
            pad=args.padding,
            remove=args.remove)
        aptos2019_df = pd.read_csv('inputs/train.csv')
        aptos2019_img_paths = aptos2019_dir + '/' + aptos2019_df['id_code'].values + '.png'
        aptos2019_labels = aptos2019_df['diagnosis'].values

    if args.train_dataset == 'aptos2019':
        skf = StratifiedKFold(n_splits=args.n_splits, shuffle=True, random_state=41)
        img_paths = []
        labels = []
        for fold, (train_idx, val_idx) in enumerate(skf.split(aptos2019_img_paths, aptos2019_labels)):
            img_paths.append((aptos2019_img_paths[train_idx], aptos2019_img_paths[val_idx]))
            labels.append((aptos2019_labels[train_idx], aptos2019_labels[val_idx]))
    elif args.train_dataset == 'diabetic_retinopathy':
        img_paths = [(diabetic_retinopathy_img_paths, aptos2019_img_paths)]
        labels = [(diabetic_retinopathy_labels, aptos2019_labels)]
    elif 'diabetic_retinopathy' in args.train_dataset and 'aptos2019' in args.train_dataset:
        skf = StratifiedKFold(n_splits=args.n_splits, shuffle=True, random_state=41)
        img_paths = []
        labels = []
        for fold, (train_idx, val_idx) in enumerate(skf.split(aptos2019_img_paths, aptos2019_labels)):
            img_paths.append((np.hstack((aptos2019_img_paths[train_idx], diabetic_retinopathy_img_paths)), aptos2019_img_paths[val_idx]))
            labels.append((np.hstack((aptos2019_labels[train_idx], diabetic_retinopathy_labels)), aptos2019_labels[val_idx]))
    # else:
    #     raise NotImplementedError

    if args.pseudo_labels:
        test_df = pd.read_csv('probs/%s.csv' % args.pseudo_labels)
        test_dir = preprocess(
            'test',
            args.img_size,
            scale=args.scale_radius,
            norm=args.normalize,
            pad=args.padding,
            remove=args.remove)
        test_img_paths = test_dir + '/' + test_df['id_code'].values + '.png'
        test_labels = test_df['diagnosis'].values
        for fold in range(len(img_paths)):
            img_paths[fold] = (np.hstack((img_paths[fold][0], test_img_paths)), img_paths[fold][1])
            labels[fold] = (np.hstack((labels[fold][0], test_labels)), labels[fold][1])

    if 'messidor' in args.train_dataset:
        test_dir = preprocess(
            'messidor',
            args.img_size,
            scale=args.scale_radius,
            norm=args.normalize,
            pad=args.padding,
            remove=args.remove)

    folds = []
    best_losses = []
    best_scores = []

    for fold, ((train_img_paths, val_img_paths), (train_labels, val_labels)) in enumerate(zip(img_paths, labels)):
        print('Fold [%d/%d]' %(fold+1, len(img_paths)))

        if os.path.exists('models/%s/model_%d.pth' % (args.name, fold+1)):
            log = pd.read_csv('models/%s/log_%d.csv' %(args.name, fold+1))
            best_loss, best_score = log.loc[log['val_loss'].values.argmin(), ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            best_scores.append(best_score)
            continue

        if args.remove_duplicate:
            md5_df = pd.read_csv('inputs/strMd5.csv')
            duplicate_img_paths = aptos2019_dir + '/' + md5_df[(md5_df.strMd5_count > 1) & (~md5_df.diagnosis.isnull())]['id_code'].values + '.png'
            print(duplicate_img_paths)
            for duplicate_img_path in duplicate_img_paths:
                train_labels = train_labels[train_img_paths != duplicate_img_path]
                train_img_paths = train_img_paths[train_img_paths != duplicate_img_path]
                val_labels = val_labels[val_img_paths != duplicate_img_path]
                val_img_paths = val_img_paths[val_img_paths != duplicate_img_path]

        # train
        train_set = Dataset(
            train_img_paths,
            train_labels,
            transform=train_transform)

        _, class_sample_counts = np.unique(train_labels, return_counts=True)
        # when --class_aware is set, oversample rare classes with
        # inverse-frequency weights; otherwise fall back to plain shuffling
        # (the sample count and replacement setting here are assumptions)
        sampler = None
        if args.class_aware:
            weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
            samples_weights = weights[torch.as_tensor(train_labels, dtype=torch.long)]
            sampler = torch.utils.data.WeightedRandomSampler(
                weights=samples_weights,
                num_samples=len(samples_weights),
                replacement=True)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            shuffle=sampler is None,
            num_workers=4,
            sampler=sampler)

        val_set = Dataset(
            val_img_paths,
            val_labels,
            transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4)

        # create model
        model = get_model(model_name=args.arch,
                          num_outputs=num_outputs,
                          freeze_bn=args.freeze_bn,
                          dropout_p=args.dropout_p)
        model = model.cuda()
        if args.pretrained_model is not None:
            model.load_state_dict(torch.load('models/%s/model_%d.pth' % (args.pretrained_model, fold+1)))

        # print(model)

        if args.optimizer == 'Adam':
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
        elif args.optimizer == 'AdamW':
            optimizer = optim.AdamW(
                filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
        elif args.optimizer == 'RAdam':
            optimizer = RAdam(
                filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
        elif args.optimizer == 'SGD':
            optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                                  momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

        if args.scheduler == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs, eta_min=args.min_lr)
        elif args.scheduler == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.factor, patience=args.patience,
                                                       verbose=1, min_lr=args.min_lr)

        log = {
            'epoch': [],
            'loss': [],
            'score': [],
            'val_loss': [],
            'val_score': [],
        }

        best_loss = float('inf')
        best_score = 0
        for epoch in range(args.epochs):
            print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

            # train for one epoch
            train_loss, train_score = train(
                args, train_loader, model, criterion, optimizer, epoch)
            # evaluate on validation set
            val_loss, val_score = validate(args, val_loader, model, criterion)

            if args.scheduler == 'CosineAnnealingLR':
                scheduler.step()
            elif args.scheduler == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
                  % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/%s/log_%d.csv' % (args.name, fold+1), index=False)

            if val_loss < best_loss:
                torch.save(model.state_dict(), 'models/%s/model_%d.pth' % (args.name, fold+1))
                best_loss = val_loss
                best_score = val_score
                print("=> saved best model")

        print('val_loss:  %f' % best_loss)
        print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        best_scores.append(best_score)

        results = pd.DataFrame({
            'fold': folds + ['mean'],
            'best_loss': best_losses + [np.mean(best_losses)],
            'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/%s/results.csv' % args.name, index=False)

        torch.cuda.empty_cache()

        if not args.cv:
            break
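
`RandomErase` is referenced in the training transform but not defined in this snippet. A minimal sketch of a random-erasing transform with the same constructor arguments (`prob`, `sl`, `sh`, `r`), applied to a PIL image before `ToTensor()`; the implementation actually used may differ:

import math
import random

import numpy as np
from PIL import Image


class RandomErase(object):
    # sketch of random erasing: zero out a rectangle whose area ratio is drawn
    # from [sl, sh] and whose aspect ratio is drawn from [r, 1/r]
    def __init__(self, prob=0.5, sl=0.02, sh=0.4, r=0.3):
        self.prob, self.sl, self.sh, self.r = prob, sl, sh, r

    def __call__(self, img):
        if random.random() > self.prob:
            return img
        arr = np.array(img)
        h, w = arr.shape[:2]
        for _ in range(10):  # retry a few times until the rectangle fits
            area = random.uniform(self.sl, self.sh) * h * w
            aspect = random.uniform(self.r, 1.0 / self.r)
            eh = int(round(math.sqrt(area * aspect)))
            ew = int(round(math.sqrt(area / aspect)))
            if eh < h and ew < w:
                y = random.randint(0, h - eh)
                x = random.randint(0, w - ew)
                arr[y:y + eh, x:x + ew] = 0
                break
        return Image.fromarray(arr)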
Example #9
0
def main():
    args = parse_args()
    np.random.seed(args.seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    if args.name is None:
        args.name = '%s_%s' % (args.arch, datetime.now().strftime('%m%d%H'))

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    if args.loss == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'MSELoss':
        criterion = nn.MSELoss().cuda()
    elif args.loss == 'multitask':
        criterion = {
            'classification': nn.CrossEntropyLoss().cuda(),
            'regression': nn.MSELoss().cuda(),
        }
    else:
        raise NotImplementedError

    if args.pred_type == 'classification':
        num_outputs = 5
    elif args.pred_type == 'regression':
        num_outputs = 1
    elif args.loss == 'multitask':
        num_outputs = 6
    else:
        raise NotImplementedError

    train_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.RandomAffine(
            degrees=(args.rotate_min, args.rotate_max) if args.rotate else 0,
            translate=(args.translate_min,
                       args.translate_max) if args.translate else None,
            scale=(args.rescale_min,
                   args.rescale_max) if args.rescale else None,
            shear=(args.shear_min, args.shear_max) if args.shear else None,
        ),
        transforms.CenterCrop(args.input_size),
        transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
        transforms.RandomVerticalFlip(p=0.5 if args.flip else 0),
        transforms.ColorJitter(brightness=0,
                               contrast=args.contrast,
                               saturation=0,
                               hue=0),
        RandomErase(prob=args.random_erase_prob if args.random_erase else 0,
                    sl=args.random_erase_sl,
                    sh=args.random_erase_sh,
                    r=args.random_erase_r),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # data loading code
    if 'diabetic_retinopathy' in args.train_dataset:
        diabetic_retinopathy_dir = preprocess('diabetic_retinopathy',
                                              args.img_size,
                                              scale=args.scale_radius,
                                              norm=args.normalize,
                                              pad=args.padding,
                                              remove=args.remove)
        diabetic_retinopathy_df = pd.read_csv(
            'inputs/diabetic-retinopathy-resized/trainLabels.csv')
        diabetic_retinopathy_img_paths = \
            diabetic_retinopathy_dir + '/' + diabetic_retinopathy_df['image'].values + '.jpeg'
        diabetic_retinopathy_labels = diabetic_retinopathy_df['level'].values

    if 'aptos2019' in args.train_dataset:
        aptos2019_dir = preprocess('aptos2019',
                                   args.img_size,
                                   scale=args.scale_radius,
                                   norm=args.normalize,
                                   pad=args.padding,
                                   remove=args.remove)
        aptos2019_df = pd.read_csv('inputs/train.csv')
        aptos2019_img_paths = aptos2019_dir + '/' + aptos2019_df[
            'id_code'].values + '.png'
        aptos2019_labels = aptos2019_df['diagnosis'].values

    if 'chestxray' in args.train_dataset:
        chestxray_dir = preprocess('chestxray',
                                   args.img_size,
                                   scale=args.scale_radius,
                                   norm=args.normalize,
                                   pad=args.padding,
                                   remove=args.remove)

        chestxray_img_paths = []
        chestxray_labels = []
        # the NORMAL/PNEUMONIA folders of every split are pooled into one list
        for split in ['train', 'test', 'val']:
            for label, folder in [(0, 'NORMAL'), (1, 'PNEUMONIA')]:
                for path in glob('chest_xray/chest_xray/%s/%s/*.jpeg' % (split, folder)):
                    p = path.split('/')[-1]
                    chestxray_img_paths.append(chestxray_dir + '/' + p)
                    chestxray_labels.append(label)

        chestxray_img_paths = np.array(chestxray_img_paths)
        chestxray_labels = np.array(chestxray_labels)

    if args.train_dataset == 'aptos2019':
        skf = StratifiedKFold(n_splits=args.n_splits,
                              shuffle=True,
                              random_state=41)
        img_paths = []
        labels = []
        for fold, (train_idx, val_idx) in enumerate(
                skf.split(aptos2019_img_paths, aptos2019_labels)):
            img_paths.append(
                (aptos2019_img_paths[train_idx], aptos2019_img_paths[val_idx]))
            labels.append(
                (aptos2019_labels[train_idx], aptos2019_labels[val_idx]))
    elif args.train_dataset == 'diabetic_retinopathy':
        img_paths = [(diabetic_retinopathy_img_paths, aptos2019_img_paths)]
        labels = [(diabetic_retinopathy_labels, aptos2019_labels)]
    elif 'diabetic_retinopathy' in args.train_dataset and 'aptos2019' in args.train_dataset:
        skf = StratifiedKFold(n_splits=args.n_splits,
                              shuffle=True,
                              random_state=41)
        img_paths = []
        labels = []
        for fold, (train_idx, val_idx) in enumerate(
                skf.split(aptos2019_img_paths, aptos2019_labels)):
            img_paths.append((np.hstack((aptos2019_img_paths[train_idx],
                                         diabetic_retinopathy_img_paths)),
                              aptos2019_img_paths[val_idx]))
            labels.append((np.hstack(
                (aptos2019_labels[train_idx], diabetic_retinopathy_labels)),
                           aptos2019_labels[val_idx]))

    # FL setting: separate data into users
    if 'diabetic_retinopathy' in args.train_dataset and 'aptos2019' in args.train_dataset:
        combined_paths = np.hstack(
            (aptos2019_img_paths, diabetic_retinopathy_img_paths))
        combined_labels = np.hstack(
            (aptos2019_labels, diabetic_retinopathy_labels))
    elif 'chestxray' in args.train_dataset:
        combined_paths = chestxray_img_paths
        combined_labels = chestxray_labels
    else:
        raise NotImplementedError
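    # split_dataset (not shown here) partitions sample indices across args.num_users
    # clients (IID or non-IID according to args.iid) and holds out a shared test set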
    user_ind_dict, ind_test = split_dataset(combined_labels, args.num_users,
                                            args.iid)

    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p)
    model = model.cuda()
    test_set = Dataset(combined_paths[ind_test],
                       combined_labels[ind_test],
                       transform=val_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4)

    test_acc = []
    test_scores = []
    test_scores_f1 = []
    lr = args.lr
    for epoch in range(args.epochs):

        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))
        weight_list = []
        selected_ind = np.random.choice(args.num_users,
                                        int(args.num_users / 10),
                                        replace=False)
        for i in selected_ind:
            print('user: %d' % (i + 1))
            train_set = Dataset(combined_paths[user_ind_dict[i]],
                                combined_labels[user_ind_dict[i]],
                                transform=train_transform)
            # NOTE: this script defines no class-aware sampler, so the loader
            # always shuffles (the --class_aware flag is ignored here)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=4)

            # train for one epoch
            train_loss, train_score, ret_w = train(args, train_loader,
                                                   copy.deepcopy(model),
                                                   criterion, lr)
            weight_list.append(ret_w)
            print('loss %.4f - score %.4f' % (train_loss, train_score))
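        # aggregate the selected clients' updated weights into a new global model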
        weights = fedavg(weight_list)
        model.load_state_dict(weights)
        test_loss, test_score, test_scoref1, accuracy, confusion_matrix = test(
            args, test_loader, copy.deepcopy(model), criterion)
        print('loss %.4f - score %.4f - accuracy %.4f' %
              (test_loss, test_score, accuracy))
        test_acc.append(accuracy)
        test_scores.append(test_score)
        test_scores_f1.append(test_scoref1)
        lr *= 0.992

    np.savez('./accuracy-xray-iid' + str(args.iid) + '-' + str(args.epochs) +
             '-beta' + str(args.beta) + '-seed' + str(args.seed),
             acc=np.array(test_acc),
             score=np.array(test_scores),
             scoref1=np.array(test_scores_f1),
             confusion=confusion_matrix)
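
Neither `fedavg` nor `split_dataset` is shown in this example; each round trains a deep copy of the global model on every selected user's shard and then averages the returned state dicts. A minimal sketch of an unweighted FedAvg aggregation consistent with that call (the real helper might instead weight clients by local sample count):

import copy


def fedavg(weight_list):
    # sketch: element-wise mean of the clients' state_dicts (uniform weights)
    avg = copy.deepcopy(weight_list[0])
    for key in avg:
        for w in weight_list[1:]:
            avg[key] = avg[key] + w[key]
        # note: integer buffers (e.g. BatchNorm's num_batches_tracked) are
        # averaged as floats here; a stricter implementation might skip them
        avg[key] = avg[key] / len(weight_list)
    return avg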