Example #1
0
def main():
    """Train ``args.base_model`` on the full training set and save weights.

    Relies on module-level ``args`` (CLI options), ``model_dispatcher``,
    ``ImageSamplerDataset``, ``train``, ``focal_loss_fn``, ``MODEL_MEAN``
    and ``MODEL_STD`` defined elsewhere in this file.
    """
    if args.device == 'cuda':
        # Speeds up cuDNN autotuning when input shapes do not vary.
        torch.backends.cudnn.benchmark = True
        print('Using cudnn.benchmark.')

    model = model_dispatcher(True, args.base_model, args.nclass)
    model.to(args.device)

    train_size = len(pd.read_csv(args.train_file))
    print(train_size)

    train_dataset = ImageSamplerDataset(phase='train',
                                        train_file=args.train_file,
                                        image_file_path=args.image_file,
                                        image_height=args.image_height,
                                        image_width=args.image_width,
                                        mean=MODEL_MEAN,
                                        std=MODEL_STD,
                                        binclass=args.binclass)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers)

    optimizer = AdamW(model.parameters(),
                      lr=args.lr,
                      weight_decay=args.weight_decay)
    scheduler_cosine = CosineAnnealingLR(optimizer, args.epochs)

    if torch.cuda.device_count() > 1:
        # BUG FIX: nn.DataParallel() was called without the model, which
        # raises a TypeError at runtime; the model must be passed in.
        model = nn.DataParallel(model)

    for epoch in range(args.epochs):
        tr_loss = train(dataset_size=train_size,
                        dataloader=train_dataloader,
                        model=model,
                        optimizer=optimizer,
                        device=args.device,
                        loss_fn=focal_loss_fn)
        print(f'Epoch_{epoch+1} Train Loss:{tr_loss}')

        # step() without the deprecated `epoch` argument; the scheduler
        # advances exactly one epoch per call already.
        scheduler_cosine.step()

    # Unwrap DataParallel (if used) so checkpoint keys carry no 'module.'
    # prefix and load cleanly into a bare model in the eval scripts.
    torch.save(
        getattr(model, 'module', model).state_dict(),
        os.path.join(args.save_dir, f'{args.base_model}_on_all_epoch11.bin'))
    print('train on all is complete')
Example #2
0
def main():
    """Run inference on the test set and write class probabilities to CSV.

    Loads pretrained weights from ``args.save_dir``/``args.model_weights``,
    predicts softmax probabilities for every test image, and writes a
    submission file ``{args.output_name}.csv`` with columns
    image_ids, A, B, C.
    """
    model = model_dispatcher(False, args.base_model, args.nclass)
    model.to(args.device)
    weights_path = os.path.join(args.save_dir, args.model_weights)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    print(f'Loading pretrained model: {args.base_model} for eval')

    test_dataset = ImageTestDataset(file_path=args.image_file,
                                    image_height=args.image_height,
                                    image_width=args.image_width,
                                    mean=MODEL_MEAN,
                                    std=MODEL_STD)

    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    id_batches = []
    prob_batches = []
    softmax = torch.nn.Softmax(dim=1)  # hoisted out of the loop

    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            images = batch['image'].to(args.device, dtype=torch.float)
            prob_batches.append(softmax(model(images)))
            id_batches.append(batch['image_id'])

    preds = torch.cat(prob_batches).cpu().numpy()
    ids = list(chain(*id_batches))

    # Assemble the submission: one id column followed by the class columns.
    sub = pd.concat([pd.DataFrame(ids, columns=['image_ids']),
                     pd.DataFrame(preds, columns=['A', 'B', 'C'])],
                    axis=1)
    sub.to_csv(f'{args.output_name}.csv', index=False)
Example #3
0
def main():
    """Extract per-image feature vectors with a pretrained model.

    Runs a forward pass over the test set; the model is expected to expose
    the last-computed features via ``model.imgfeatures`` after each call.
    Results are written to ``{args.base_model}_imagefeatures.csv``.
    """
    model = model_dispatcher(False, args.base_model, args.nclass)
    model.to(args.device)
    weights_path = os.path.join(args.save_dir, args.model_weights)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    print(f'Loading pretrained model: {args.base_model} for eval')

    dataset = ImageTestDataset(file_path=args.image_file,
                               image_height=args.image_height,
                               image_width=args.image_width,
                               mean=MODEL_MEAN,
                               std=MODEL_STD)

    loader = DataLoader(dataset=dataset,
                        batch_size=args.test_batch_size,
                        shuffle=False,
                        num_workers=args.num_workers)

    id_batches = []
    features = []

    with torch.no_grad():
        for batch in tqdm(loader):
            images = batch['image'].to(args.device, dtype=torch.float)
            # The forward pass populates model.imgfeatures as a side effect;
            # the logits themselves are not needed here.
            model(images)
            features.extend(model.imgfeatures.cpu().numpy().tolist())
            id_batches.append(batch['image_id'])

    ids = list(chain(*id_batches))

    sub = pd.DataFrame({'image_ids': ids, 'features': features})
    sub.to_csv(f'{args.base_model}_imagefeatures.csv', index=False)
Example #4
0
def _tta_transform(num_tta):
    """Map a TTA round index to its augmentation pipeline.

    Round 0 uses the plain test transform, rounds 1-3 use dedicated TTA
    transforms, rounds 4-7 fall back to ``data_transforms_tta0`` and any
    later round uses the default ``data_transforms``.
    """
    if num_tta == 0:
        return data_transforms_test
    if num_tta == 1:
        return data_transforms_tta1
    if num_tta == 2:
        return data_transforms_tta2
    if num_tta == 3:
        return data_transforms_tta3
    if num_tta < 8:
        return data_transforms_tta0
    return data_transforms


def main():
    """Average softmax predictions over ``args.num_tta`` TTA rounds.

    Each round runs the full test set through the model with a different
    augmentation; the per-round probabilities are pre-divided by the round
    count so the running sum ends up as the mean. Writes the averaged
    predictions to ``{args.output_name}.csv``.
    """
    model = model_dispatcher(False, args.base_model, args.nclass)
    model.to(args.device)
    model.load_state_dict(
        torch.load(os.path.join(args.save_dir, args.model_weights)))
    model.eval()
    print(f'Loading pretrained model: {args.base_model} for eval')

    for num_tta in range(args.num_tta):
        # One dataloader per round; only the transform differs between
        # rounds (the original repeated this construction six times).
        test_dataset = ImageTTADataset(file_path=args.image_file,
                                       transform=_tta_transform(num_tta))
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers)

        image_id_list = []
        image_pred_list = []

        with torch.no_grad():
            for d in test_dataloader:
                image = d['image'].to(args.device, dtype=torch.float)
                outputs = model(image)
                pred_prob = torch.nn.Softmax(dim=1)(outputs)

                image_id_list.append(d['image_id'])
                # Pre-divide so accumulating across rounds yields the mean.
                image_pred_list.append(pred_prob / args.num_tta)

            if num_tta == 0:
                # Image ids are identical across rounds; capture them once.
                ids = list(chain(*image_id_list))
                preds = torch.cat(image_pred_list).cpu().numpy()
            else:
                preds += torch.cat(image_pred_list).cpu().numpy()

        print(num_tta)

    df_pred = pd.DataFrame(preds, columns=['A', 'B', 'C'])
    df_id = pd.DataFrame(ids, columns=['image_ids'])
    sub = pd.concat([df_id, df_pred], axis=1)
    sub.to_csv(f'{args.output_name}.csv', index=False)
Example #5
0
def main():
    """Train with a train/validation split, checkpointing the best model.

    Saves weights whenever validation accuracy beats the running benchmark
    and pickles the collected loss/accuracy curves at the end. Relies on
    module-level ``args``, ``model_dispatcher``, ``get_train_valid_indice``,
    ``ImageSamplerDataset``, ``train``, ``evaluate``, ``focal_loss_fn``,
    ``VALID_FOLDS``, ``MODEL_MEAN`` and ``MODEL_STD``.
    """
    if args.device == 'cuda':
        # Speeds up cuDNN autotuning when input shapes do not vary.
        torch.backends.cudnn.benchmark = True
        print('Using cudnn.benchmark.')

    model = model_dispatcher(True, args.base_model, args.nclass)
    model.to(args.device)

    train_indices, val_indices = get_train_valid_indice(
        test_size=args.test_size, random_state=args.random_state)

    train_size = len(train_indices)
    valid_size = len(val_indices)

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_dataset = ImageSamplerDataset(phase='train',
                                        train_file=args.train_file,
                                        image_file_path=args.image_file,
                                        image_height=args.image_height,
                                        image_width=args.image_width,
                                        mean=MODEL_MEAN,
                                        std=MODEL_STD,
                                        binclass=args.binclass)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers,
                                  sampler=train_sampler)

    valid_dataset = ImageSamplerDataset(phase='valid',
                                        train_file=args.train_file,
                                        image_file_path=args.image_file,
                                        image_height=args.image_height,
                                        image_width=args.image_width,
                                        mean=MODEL_MEAN,
                                        std=MODEL_STD,
                                        binclass=args.binclass)

    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=args.test_batch_size,
                                  num_workers=args.num_workers,
                                  sampler=valid_sampler)

    optimizer = AdamW(model.parameters(),
                      lr=args.lr,
                      weight_decay=args.weight_decay)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='max',
                                               patience=5,
                                               factor=0.3)

    if torch.cuda.device_count() > 1:
        # BUG FIX: nn.DataParallel() was called without the model, which
        # raises a TypeError at runtime; the model must be passed in.
        model = nn.DataParallel(model)

    val_accu_benchmark = 0.34  # minimum accuracy before checkpointing starts
    val_loss_list = []
    val_accu_list = []
    tr_loss_list = []
    best_epoch = 0
    for epoch in range(args.epochs):
        tr_loss = train(dataset_size=train_size,
                        dataloader=train_dataloader,
                        model=model,
                        optimizer=optimizer,
                        device=args.device,
                        loss_fn=focal_loss_fn)
        val_loss, val_accu = evaluate(dataset_size=valid_size,
                                      dataloader=valid_dataloader,
                                      model=model,
                                      device=args.device,
                                      loss_fn=focal_loss_fn,
                                      tag='valid')
        print(f'Epoch_{epoch+1} Train Loss:{tr_loss}')
        print(f'Epoch_{epoch+1} Valid Loss:{val_loss}')
        # BUG FIX: the scheduler was built with mode='max' (metric should
        # increase) but was stepped with the decreasing validation loss,
        # inverting plateau detection. Step on validation accuracy instead.
        scheduler.step(val_accu)

        tr_loss_list.append(tr_loss)
        val_loss_list.append(val_loss)
        val_accu_list.append(val_accu)
        if val_accu > val_accu_benchmark:
            best_epoch = epoch + 1
            print(f'save {args.base_model} model on epoch {epoch+1}')
            # Unwrap DataParallel (if used) so checkpoint keys carry no
            # 'module.' prefix and load cleanly into a bare model.
            torch.save(
                getattr(model, 'module', model).state_dict(),
                os.path.join(args.save_dir,
                             f'{args.base_model}_fold_{VALID_FOLDS[0]}.bin'))
            val_accu_benchmark = val_accu
    print(f'Save the best model on epoch {best_epoch}')

    stored_metrics = {
        'train': {
            'tr_loss_list': tr_loss_list
        },
        'valid': {
            'val_loss_list': val_loss_list,
            'val_accu_list': val_accu_list
        }
    }

    # Context manager guarantees the file is closed even if dump() raises.
    with open(os.path.join(args.save_dir, 'stored_metrics.pickle'),
              'wb') as metrics_file:
        pickle.dump(stored_metrics, metrics_file)