Code example #1
File: create_splits_seq.py  Project: zpeng1989/CLAM
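This snippet computes per-class validation and test counts from the dataset's patient-level class IDs, then generates k cross-validation splits for one or more label fractions and saves each fold as a CSV (plus a boolean-style variant and a descriptor summary) under splits/<task>_<label fraction>.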
num_slides_cls = np.array(
    [len(cls_ids) for cls_ids in dataset.patient_cls_ids])
val_num = np.round(num_slides_cls * args.val_frac).astype(int)
test_num = np.round(num_slides_cls * args.test_frac).astype(int)

if __name__ == '__main__':
    if args.label_frac > 0:
        label_fracs = [args.label_frac]
    else:
        label_fracs = [0.1, 0.25, 0.5, 0.75, 1.0]

    for lf in label_fracs:
        split_dir = 'splits/' + str(args.task) + '_{}'.format(int(lf * 100))
        os.makedirs(split_dir, exist_ok=True)
        dataset.create_splits(k=args.k,
                              val_num=val_num,
                              test_num=test_num,
                              label_frac=lf)
        for i in range(args.k):
            dataset.set_splits()
            descriptor_df = dataset.test_split_gen(return_descriptor=True)
            splits = dataset.return_splits(from_id=True)
            save_splits(splits, ['train', 'val', 'test'],
                        os.path.join(split_dir, 'splits_{}.csv'.format(i)))
            save_splits(splits, ['train', 'val', 'test'],
                        os.path.join(split_dir,
                                     'splits_{}_bool.csv'.format(i)),
                        boolean_style=True)
            descriptor_df.to_csv(
                os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
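For orientation, a small sketch of reading one of the generated files back. The directory name is illustrative, and the assumption that boolean_style=True yields one boolean column per split key ('train', 'val', 'test') is inferred from the keys passed to save_splits above, not taken from the repository.

import os
import pandas as pd

# Hypothetical task/label-fraction directory, following the naming scheme above.
split_dir = 'splits/task_1_tumor_vs_normal_100'
df = pd.read_csv(os.path.join(split_dir, 'splits_0_bool.csv'), index_col=0)

# Assuming one boolean column per split key, recover the IDs in each split.
train_ids = df.index[df['train'].values].tolist()
val_ids = df.index[df['val'].values].tolist()
test_ids = df.index[df['test'].values].tolist()
print(len(train_ids), len(val_ids), len(test_ids))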
Code example #2
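eval_model rebuilds the model for fold cur, loads the saved checkpoint s_{cur}_checkpoint.pt, and evaluates the survival model on the validation split, returning the per-case results and the validation c-index.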
def eval_model(datasets: tuple, cur: int, args):
    """   
        evaluate a single fold
    """
    print('\nEvaluating Fold {}!'.format(cur))

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split = datasets
    save_splits(datasets, ['train', 'val'],
                os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))

    print('\nInit loss function...', end=' ')
    if args.task_type == 'survival':
        if args.bag_loss == 'ce_surv':
            loss_fn = CrossEntropySurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'nll_surv':
            loss_fn = NLLSurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'cox_surv':
            loss_fn = CoxSurvLoss()
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    if args.reg_type == 'omic':
        reg_fn = l1_reg_all
    elif args.reg_type == 'pathomic':
        reg_fn = l1_reg_modules
    else:
        reg_fn = None

    print('Done!')

    print('\nInit Model...', end=' ')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes}
    if args.model_type == 'attention_mil':
        if args.task_type == 'survival':
            model = MIL_Attention_fc_surv(**model_dict)
        else:
            raise NotImplementedError
    elif args.model_type == 'mm_attention_mil':
        model_dict.update({
            'input_dim': args.omic_input_dim,
            'fusion': args.fusion,
            'model_size_wsi': args.model_size_wsi,
            'model_size_omic': args.model_size_omic,
            'gate_path': args.gate_path,
            'gate_omic': args.gate_omic,
            'n_classes': args.n_classes
        })
        if args.task_type == 'survival':
            model = MM_MIL_Attention_fc_surv(**model_dict)
        else:
            raise NotImplementedError
    elif args.model_type == 'max_net':
        model_dict = {
            'input_dim': args.omic_input_dim,
            'model_size_omic': args.model_size_omic,
            'n_classes': args.n_classes
        }
        if args.task_type == 'survival':
            model = MaxNet(**model_dict)
        else:
            raise NotImplementedError

    else:
        raise NotImplementedError

    model.relocate()
    print('Done!')
    print_network(model)
    ckpt = torch.load(
        os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))
    model.load_state_dict(ckpt, strict=False)
    model.eval()

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split,
                                    training=True,
                                    testing=args.testing,
                                    weighted=args.weighted_sample,
                                    task_type=args.task_type,
                                    batch_size=args.batch_size)
    val_loader = get_split_loader(val_split,
                                  testing=args.testing,
                                  task_type=args.task_type,
                                  batch_size=args.batch_size)
    print('Done!')

    results = {}
    if args.task_type == 'survival':
        results_val_dict, val_c_index = summary_survival(
            model, val_loader, args.n_classes)
        return results_val_dict, val_c_index
Code example #3
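Per-fold training routine for the CLAM / MIL classification models: it saves the fold's splits, builds the model and losses, trains with optional early stopping, restores or saves the fold checkpoint, and logs validation and test error/AUC.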
def train(datasets, cur, args):
    """   
        train for a single fold
    """
    print('\nTraining Fold {}!'.format(cur))
    writer_dir = os.path.join(args.results_dir, str(cur))
    if not os.path.isdir(writer_dir):
        os.mkdir(writer_dir)

    if args.log_data:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)

    else:
        writer = None

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split, test_split = datasets
    save_splits(datasets, ['train', 'val', 'test'],
                os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))
    print("Testing on {} samples".format(len(test_split)))

    print('\nInit loss function...', end=' ')
    if args.bag_loss == 'svm':
        from topk import SmoothTop1SVM
        loss_fn = SmoothTop1SVM(n_classes=args.n_classes)
        if device.type == 'cuda':
            loss_fn = loss_fn.cuda()
    else:
        loss_fn = nn.CrossEntropyLoss()
    print('Done!')

    print('\nInit Model...', end=' ')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes}
    if args.model_type == 'clam' and args.subtyping:
        model_dict.update({'subtyping': True})

    if args.model_size is not None and args.model_type != 'mil':
        model_dict.update({"size_arg": args.model_size})

    if args.model_type == 'clam':
        if args.inst_loss == 'svm':
            from topk import SmoothTop1SVM
            instance_loss_fn = SmoothTop1SVM(n_classes=2)
            if device.type == 'cuda':
                instance_loss_fn = instance_loss_fn.cuda()
        else:
            instance_loss_fn = nn.CrossEntropyLoss()

        model = CLAM(**model_dict, instance_loss_fn=instance_loss_fn)

    else:  # args.model_type == 'mil'
        if args.n_classes > 2:
            model = MIL_fc_mc(**model_dict)
        else:
            model = MIL_fc(**model_dict)

    model.relocate()
    print('Done!')
    print_network(model)

    print('\nInit optimizer ...', end=' ')
    optimizer = get_optim(model, args)
    print('Done!')

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split,
                                    training=True,
                                    testing=args.testing,
                                    weighted=args.weighted_sample)
    val_loader = get_split_loader(val_split, testing=args.testing)
    test_loader = get_split_loader(test_split, testing=args.testing)
    print('Done!')

    print('\nSetup EarlyStopping...', end=' ')
    if args.early_stopping:
        early_stopping = EarlyStopping(patience=20,
                                       stop_epoch=50,
                                       verbose=True)

    else:
        early_stopping = None
    print('Done!')

    for epoch in range(args.max_epochs):
        if args.model_type == 'clam':
            train_loop_clam(epoch, model, train_loader, optimizer,
                            args.n_classes, args.bag_weight, writer, loss_fn)
            stop = validate_clam(cur, epoch, model, val_loader, args.n_classes,
                                 early_stopping, writer, loss_fn,
                                 args.results_dir)

        else:
            train_loop(epoch, model, train_loader, optimizer, args.n_classes,
                       writer, loss_fn)
            stop = validate(cur, epoch, model, val_loader, args.n_classes,
                            early_stopping, writer, loss_fn, args.results_dir)

        if stop:
            break

    if args.early_stopping:
        model.load_state_dict(
            torch.load(
                os.path.join(args.results_dir,
                             "s_{}_checkpoint.pt".format(cur))))
    else:
        torch.save(
            model.state_dict(),
            os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))

    _, val_error, val_auc, _ = summary(model, val_loader, args.n_classes)
    print('Val error: {:.4f}, ROC AUC: {:.4f}'.format(val_error, val_auc))

    results_dict, test_error, test_auc, acc_logger = summary(
        model, test_loader, args.n_classes)
    print('Test error: {:.4f}, ROC AUC: {:.4f}'.format(test_error, test_auc))

    for i in range(args.n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

        if writer:
            writer.add_scalar('final/test_class_{}_acc'.format(i), acc, 0)

    if writer:
        writer.add_scalar('final/val_error', val_error, 0)
        writer.add_scalar('final/val_auc', val_auc, 0)
        writer.add_scalar('final/test_error', test_error, 0)
        writer.add_scalar('final/test_auc', test_auc, 0)

    if writer:
        writer.close()
    return results_dict, test_auc, val_auc, 1 - test_error, 1 - val_error
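For context, a minimal sketch of how a k-fold driver might call the train() function above, reusing dataset and dataset.return_splits from example #1. args.split_dir and the csv_path keyword are assumptions about the surrounding project and are not shown in these excerpts.

# Hypothetical k-fold driver (not part of the excerpt): load each saved
# split and collect the metrics returned by train().
all_val_auc, all_test_auc = [], []
for fold in range(args.k):
    # csv_path loading is assumed to mirror the files written in example #1.
    datasets = dataset.return_splits(
        from_id=False,
        csv_path='{}/splits_{}.csv'.format(args.split_dir, fold))
    results_dict, test_auc, val_auc, test_acc, val_acc = train(datasets, fold, args)
    all_val_auc.append(val_auc)
    all_test_auc.append(test_auc)

print('mean val AUC: {:.4f}, mean test AUC: {:.4f}'.format(
    sum(all_val_auc) / len(all_val_auc),
    sum(all_test_auc) / len(all_test_auc)))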
Code example #4
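Survival variant of the per-fold training routine: it supports attention MIL, multimodal attention MIL, and omics-only (MaxNet) models, trains with a survival loss plus optional L1 regularization, and returns the validation c-index.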
def train(datasets: tuple, cur: int, args: Namespace):
    """   
        train for a single fold
    """
    print('\nTraining Fold {}!'.format(cur))
    writer_dir = os.path.join(args.results_dir, str(cur))
    if not os.path.isdir(writer_dir):
        os.mkdir(writer_dir)

    if args.log_data:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)

    else:
        writer = None

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split = datasets
    save_splits(datasets, ['train', 'val'],
                os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))

    print('\nInit loss function...', end=' ')
    if args.task_type == 'survival':
        if args.bag_loss == 'ce_surv':
            loss_fn = CrossEntropySurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'nll_surv':
            loss_fn = NLLSurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'cox_surv':
            loss_fn = CoxSurvLoss()
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    if args.reg_type == 'omic':
        reg_fn = l1_reg_all
    elif args.reg_type == 'pathomic':
        reg_fn = l1_reg_modules
    else:
        reg_fn = None

    print('Done!')

    print('\nInit Model...', end=' ')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes}

    if args.model_type == 'attention_mil':
        if args.task_type == 'survival':
            model = MIL_Attention_fc_surv(**model_dict)
        else:
            raise NotImplementedError
    elif args.model_type == 'mm_attention_mil':
        model_dict.update({
            'input_dim': args.omic_input_dim,
            'fusion': args.fusion,
            'model_size_wsi': args.model_size_wsi,
            'model_size_omic': args.model_size_omic,
            'gate_path': args.gate_path,
            'gate_omic': args.gate_omic,
            'n_classes': args.n_classes
        })

        if args.task_type == 'survival':
            model = MM_MIL_Attention_fc_surv(**model_dict)
        else:
            raise NotImplementedError
    elif args.model_type == 'max_net':
        model_dict = {
            'input_dim': args.omic_input_dim,
            'model_size_omic': args.model_size_omic,
            'n_classes': args.n_classes
        }
        if args.task_type == 'survival':
            model = MaxNet(**model_dict)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    model.relocate()
    print('Done!')
    print_network(model)

    print('\nInit optimizer ...', end=' ')
    optimizer = get_optim(model, args)
    print('Done!')

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split,
                                    training=True,
                                    testing=args.testing,
                                    weighted=args.weighted_sample,
                                    task_type=args.task_type,
                                    batch_size=args.batch_size)
    val_loader = get_split_loader(val_split,
                                  testing=args.testing,
                                  task_type=args.task_type,
                                  batch_size=args.batch_size)
    print('Done!')

    print('\nSetup EarlyStopping...', end=' ')
    if args.early_stopping:
        early_stopping = EarlyStopping(warmup=0,
                                       patience=10,
                                       stop_epoch=20,
                                       verbose=True)

    else:
        early_stopping = None

    print('\nSetup Validation C-Index Monitor...', end=' ')
    monitor_cindex = Monitor_CIndex()
    print('Done!')

    for epoch in range(args.max_epochs):
        if args.task_type == 'survival':
            train_loop_survival(epoch, model, train_loader, optimizer,
                                args.n_classes, writer, loss_fn, reg_fn,
                                args.lambda_reg, args.gc)
            stop = validate_survival(cur, epoch, model, val_loader,
                                     args.n_classes, early_stopping,
                                     monitor_cindex, writer, loss_fn, reg_fn,
                                     args.lambda_reg, args.results_dir)

        # Honor the early-stopping signal (otherwise the flag computed above
        # is never used and training always runs to max_epochs).
        if stop:
            break

    torch.save(
        model.state_dict(),
        os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))
    model.load_state_dict(
        torch.load(
            os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur))))
    results_val_dict, val_cindex = summary_survival(model, val_loader,
                                                    args.n_classes)
    print('Val c-Index: {:.4f}'.format(val_cindex))
    if writer:
        writer.close()
    return results_val_dict, val_cindex
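The EarlyStopping helper configured in examples #3-#5 is not part of these excerpts. The sketch below only illustrates the contract implied by its constructor arguments (warmup, patience, stop_epoch) and by the validation helpers that presumably call it with the epoch, a validation loss, the model, and a checkpoint path, then report early_stopping.early_stop as stop. It is an illustration, not the implementation from any of these projects.

import torch


class EarlyStoppingSketch:
    """Illustrative stand-in for the EarlyStopping helper used above."""

    def __init__(self, warmup=0, patience=10, stop_epoch=20, verbose=False):
        self.warmup = warmup          # epochs to ignore completely
        self.patience = patience      # epochs without improvement before stopping
        self.stop_epoch = stop_epoch  # never stop before this epoch
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, epoch, val_loss, model, ckpt_name='checkpoint.pt'):
        if epoch < self.warmup:
            return
        if val_loss < self.best_loss:
            # Improvement: reset the counter and keep the best weights so far.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), ckpt_name)
        else:
            self.counter += 1
            if self.verbose:
                print('EarlyStopping counter: {}/{}'.format(
                    self.counter, self.patience))
            if self.counter >= self.patience and epoch >= self.stop_epoch:
                self.early_stop = True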
Code example #5
File: core_utils.py  Project: miccai2021anon/2410
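Per-fold training routine for set- and graph-based survival models (DeepSet, attention MIL, MI-FCN, DeepGraphConv, Patch-GCN); the training loop dispatches to a cluster-mode or standard survival loop and returns the validation c-index.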
def train(datasets: tuple, cur: int, args: Namespace):
    """   
        train for a single fold
    """
    print('\nTraining Fold {}!'.format(cur))
    writer_dir = os.path.join(args.results_dir, str(cur))
    if not os.path.isdir(writer_dir):
        os.mkdir(writer_dir)

    if args.log_data:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)

    else:
        writer = None

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split = datasets
    save_splits(datasets, ['train', 'val'],
                os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))

    print('\nInit loss function...', end=' ')
    if args.task_type == 'survival':
        if args.bag_loss == 'ce_surv':
            loss_fn = CrossEntropySurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'nll_surv':
            loss_fn = NLLSurvLoss(alpha=args.alpha_surv)
        elif args.bag_loss == 'cox_surv':
            loss_fn = CoxSurvLoss()
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    # The survival loops below also expect a regularization function. The
    # original core_utils.py derives it from args.reg_type (compare example
    # #4); default to no regularization here so this excerpt runs on its own.
    reg_fn = None

    print('\nInit Model...', end=' ')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes}
    if args.model_type == 'deepset':
        model_dict = {'n_classes': args.n_classes}
        model = MIL_Sum_FC_surv(**model_dict)
    elif args.model_type == 'amil':
        model_dict = {'n_classes': args.n_classes}
        model = MIL_Attention_FC_surv(**model_dict)
    elif args.model_type == 'mifcn':
        model_dict = {'num_clusters': 10, 'n_classes': args.n_classes}
        model = MIL_Cluster_FC_surv(**model_dict)
    elif args.model_type == 'dgc':
        model_dict = {
            'edge_agg': args.edge_agg,
            'resample': args.resample,
            'n_classes': args.n_classes
        }
        model = DeepGraphConv_Surv(**model_dict)
    elif args.model_type == 'patchgcn':
        model_dict = {
            'num_layers': args.num_gcn_layers,
            'edge_agg': args.edge_agg,
            'resample': args.resample,
            'n_classes': args.n_classes
        }
        model = PatchGCN_Surv(**model_dict)
    else:
        raise NotImplementedError

    if hasattr(model, "relocate"):
        model.relocate()
    else:
        model = model.to(torch.device('cuda'))
    print('Done!')
    print_network(model)

    print('\nInit optimizer ...', end=' ')
    optimizer = get_optim(model, args)
    print('Done!')

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split,
                                    training=True,
                                    testing=args.testing,
                                    weighted=args.weighted_sample,
                                    mode=args.mode,
                                    batch_size=args.batch_size)
    val_loader = get_split_loader(val_split,
                                  testing=args.testing,
                                  mode=args.mode,
                                  batch_size=args.batch_size)
    print('Done!')

    print('\nSetup EarlyStopping...', end=' ')
    if args.early_stopping:
        early_stopping = EarlyStopping(warmup=0,
                                       patience=10,
                                       stop_epoch=20,
                                       verbose=True)
    else:
        early_stopping = None

    print('\nSetup Validation C-Index Monitor...', end=' ')
    monitor_cindex = Monitor_CIndex()
    print('Done!')

    for epoch in range(args.max_epochs):
        if args.task_type == 'survival':
            if args.mode == 'cluster':
                # VAE is not defined in this excerpt; it is assumed to come
                # from the surrounding file in the original project.
                train_loop_survival_cluster(epoch, model, train_loader,
                                            optimizer, args.n_classes, writer,
                                            loss_fn, reg_fn, args.lambda_reg,
                                            args.gc, VAE)
                stop = validate_survival_cluster(
                    cur, epoch, model, val_loader, args.n_classes,
                    early_stopping, monitor_cindex, writer, loss_fn, reg_fn,
                    args.lambda_reg, args.results_dir, VAE)
            else:
                train_loop_survival(epoch, model, train_loader, optimizer,
                                    args.n_classes, writer, loss_fn, reg_fn,
                                    args.lambda_reg, args.gc)
                stop = validate_survival(cur, epoch, model, val_loader,
                                         args.n_classes, early_stopping,
                                         monitor_cindex, writer, loss_fn,
                                         reg_fn, args.lambda_reg,
                                         args.results_dir)

        # Honor the early-stopping signal (otherwise the flag computed above
        # is never used and training always runs to max_epochs).
        if stop:
            break

    torch.save(
        model.state_dict(),
        os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))
    model.load_state_dict(
        torch.load(
            os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur))))
    results_val_dict, val_cindex = summary_survival(model, val_loader,
                                                    args.n_classes)
    print('Val c-Index: {:.4f}'.format(val_cindex))
    if writer:
        writer.close()
    return results_val_dict, val_cindex