Example 1
def main():

    # config
    parser = argparse.ArgumentParser()
    parser.add_argument('-s',
                        '--save',
                        default='./save',
                        help='folder path for saving; if the name already exists, a "-" is appended to distinguish it')
    parser.add_argument('-is',
                        '--image_size',
                        default=224,
                        type=int,
                        help='size to which each patch is resized, default 224 x 224')
    parser.add_argument('-vts',
                        '--valid_test_size',
                        default=(0.1, 0.1),
                        type=float,
                        nargs=2,
                        help='sizes of the validation and test sets, default 0.1 0.1')
    parser.add_argument('-bs',
                        '--batch_size',
                        default=64,
                        type=int,
                        help='batch size, default 64')
    parser.add_argument('-nw',
                        '--num_workers',
                        default=12,
                        type=int,
                        help='number of worker processes, default 12')
    parser.add_argument('-lr',
                        '--learning_rate',
                        default=0.0001,
                        type=float,
                        help='learning rate, default 0.0001')
    parser.add_argument('-e',
                        '--epoch',
                        default=10,
                        type=int,
                        help='number of epochs, default 10')
    parser.add_argument('-tp',
                        '--test_patches',
                        default=None,
                        type=int,
                        help=('number of patches randomly sampled from each patient '
                              'at test time; default None, i.e. all patches are used'))
    parser.add_argument('--cindex_reduction',
                        default='mean',
                        help='how to aggregate the patches of the same slide, default mean')
    parser.add_argument('--loss_type',
                        default='cox',
                        help='type of loss to use, default cox; svmloss is also available')
    parser.add_argument('--zoom', default='40.0', help="magnification level to use, default 40.0")
    parser.add_argument('--rank_ratio',
                        default=1.0,
                        type=float,
                        help="rank_ratio for svmloss, default 1.0")
    args = parser.parse_args()
    save = args.save
    image_size = (args.image_size, args.image_size)
    valid_size, test_size = args.valid_test_size
    batch_size = args.batch_size
    num_workers = args.num_workers
    lr = args.learning_rate
    epoch = args.epoch
    test_patches = args.test_patches
    cindex_reduction = args.cindex_reduction
    zoom = args.zoom
    rank_ratio = args.rank_ratio

    # ----- load data -----
    demographic_file = '/home/dl/NewDisk/Slides/TCGA-OV/demographic.csv'
    tiles_dir = '/home/dl/NewDisk/Slides/TCGA-OV/Tiles'

    dat = SlidePatchData.from_demographic(demographic_file,
                                          tiles_dir,
                                          transfer=transforms.ToTensor(),
                                          zoom=zoom)
    train_transfer = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    test_transfer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
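    # split by patients: first hold out valid_size + test_size of the patients,
    # then divide that holdout by test_size / (valid_size + test_size), so the
    # final validation and test fractions are valid_size and test_size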
    train_dat, valid_dat = dat.split_by_patients(valid_size + test_size,
                                                 train_transfer=train_transfer,
                                                 test_transfer=test_transfer)
    valid_dat, test_dat = valid_dat.split_by_patients(test_size /
                                                      (valid_size + test_size))
    train_sampler = OneEveryPatientSampler(train_dat)
    if test_patches is not None:
        test_sampler = OneEveryPatientSampler(valid_dat,
                                              num_per_patients=test_patches)
    else:
        test_sampler = None
    dataloaders = {
        'train':
        data.DataLoader(train_dat,
                        batch_size=batch_size,
                        sampler=train_sampler,
                        num_workers=num_workers),
        'valid':
        data.DataLoader(valid_dat,
                        batch_size=batch_size,
                        sampler=test_sampler,
                        num_workers=num_workers),
        'test':
        data.DataLoader(test_dat,
                        batch_size=batch_size,
                        sampler=test_sampler,
                        num_workers=num_workers),
    }

    # ----- build network and optimizer -----
    net = SurvivalPatchCNN()
    if args.loss_type == 'cox':
        criterion = NegativeLogLikelihood()
    elif args.loss_type == 'svmloss':
        criterion = SvmLoss(rank_ratio=rank_ratio)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scorings = [mm.Loss(), mm.CIndexForSlide(reduction=cindex_reduction)]

    # ----- train the network -----
    net, hist = train(net,
                      criterion,
                      optimizer,
                      dataloaders,
                      epoch=epoch,
                      metrics=scorings)
    print('')

    # ----- final test -----
    test_hist = evaluate(net, dataloaders['test'], criterion, metrics=scorings)

    # save results
    dirname = check_update_dirname(save)
    torch.save(net.state_dict(), os.path.join(dirname, 'model.pth'))
    pd.DataFrame(hist).to_csv(os.path.join(dirname, 'train.csv'))
    with open(os.path.join(dirname, 'test.json'), 'w') as f:
        json.dump(test_hist, f)
    with open(os.path.join(dirname, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f)
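A possible command-line invocation of the main() above, assuming the example is saved as train_patch.py (a file name not given in the snippet); all flags come from the argparse definitions shown:

import subprocess

# launch one training run with the Cox loss at 40.0x magnification
subprocess.run([
    'python', 'train_patch.py',
    '--save', './save/cox-40x',
    '--image_size', '224',
    '--batch_size', '64',
    '--loss_type', 'cox',
    '--zoom', '40.0',
], check=True)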
Example 2
def train(model,
          criterion,
          optimizer,
          dataloaders,
          scheduler=NoneScheduler(None),
          epoch=100,
          device=torch.device('cuda:0'),
          l2=0.0,
          metrics=(mm.Loss(), mm.CIndexForSlide()),
          standard_metric_index=1,
          clip_grad=False):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metric = 0.0
    best_metric_name = metrics[standard_metric_index].__class__.__name__ + \
        '_valid'
    history = {
        m.__class__.__name__ + p: []
        for p in ['_train', '_valid'] for m in metrics
    }
    model.to(device)

    for e in range(epoch):
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()
                prefix = "Train: "
            else:
                model.eval()
                prefix = "Valid: "
            # progressbar
            format_custom_text = pb.FormatCustomText('Loss: %(loss).4f',
                                                     dict(loss=0.))
            widgets = [
                prefix, " ",
                pb.Counter(), ' ',
                pb.Bar(), ' ',
                pb.Timer(), ' ',
                pb.AdaptiveETA(), ' ', format_custom_text
            ]
            iterator = pb.progressbar(dataloaders[phase], widgets=widgets)

            for m in metrics:
                m.reset()
            for batch_x, batch_y, (batch_ids, batch_files) in iterator:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    logit = model(batch_x)
                    loss = criterion(logit, batch_y)
                    # apply l2 regularization to weight parameters only
                    if l2 > 0.0:
                        for p_n, p_v in model.named_parameters():
                            if p_n == 'weight':
                                loss += l2 * p_v.norm()
                    if phase == 'train':
                        loss.backward()
                        if clip_grad:
                            nn.utils.clip_grad_norm_(model.parameters(),
                                                     max_norm=1)
                        optimizer.step()
                with torch.no_grad():
                    for m in metrics:
                        if isinstance(m, mm.Loss):
                            m.add(loss.cpu().item(), batch_x.size(0))
                            format_custom_text.update_mapping(loss=m.value())
                        else:
                            m.add(logit.squeeze(), batch_y, batch_ids)

            for m in metrics:
                history[m.__class__.__name__ + '_' + phase].append(m.value())
            print("Epoch: %d, Phase:%s, " % (e, phase) + ", ".join([
                '%s: %.4f' % (m.__class__.__name__,
                              history[m.__class__.__name__ + '_' + phase][-1])
                for m in metrics
            ]))

            if phase == 'valid':
                epoch_metric = history[best_metric_name][-1]
                if epoch_metric > best_metric:
                    best_metric = epoch_metric
                    best_model_wts = copy.deepcopy(model.state_dict())

    print("Best metric: %.4f" % best_metric)
    model.load_state_dict(best_model_wts)
    return model, history
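NoneScheduler is not defined in this snippet; a minimal placeholder consistent with how it is used here (constructed as NoneScheduler(None) and only required to expose step()) might look like this:

class NoneScheduler:
    """No-op learning-rate scheduler used when no LR schedule is wanted."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        # intentionally does nothing, so the training loop stays uniform
        pass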
Example 3
def main():
    warnings.filterwarnings('ignore')

    # config
    parser = argparse.ArgumentParser()
    parser.add_argument('-s',
                        '--save',
                        default='./save',
                        help='folder path for saving; if the name already exists, a "-" is appended to distinguish it')
    parser.add_argument('-is',
                        '--image_size',
                        default=224,
                        type=int,
                        help='size to which each patch is resized, default 224 x 224')
    parser.add_argument('-vts',
                        '--valid_test_size',
                        default=(0.1, 0.1),
                        type=float,
                        nargs=2,
                        help='sizes of the validation and test sets, default 0.1 0.1')
    parser.add_argument('-bs',
                        '--batch_size',
                        default=32,
                        type=int,
                        help='batch size, default 32')
    parser.add_argument('-nw',
                        '--num_workers',
                        default=12,
                        type=int,
                        help='number of worker processes, default 12')
    parser.add_argument('-lr',
                        '--learning_rate',
                        default=0.0001,
                        type=float,
                        help='learning rate, default 0.0001')
    parser.add_argument('-e',
                        '--epoch',
                        default=10,
                        type=int,
                        help='number of epochs, default 10')
    parser.add_argument('--reduction',
                        default='mean',
                        help='how to aggregate the instances of the same bag, default mean')
    parser.add_argument('--multipler',
                        default=2.0,
                        type=float,
                        help="to balance pos and neg, the weight is further multiplied by a number greater than 1, default 2.0")
    args = parser.parse_args()
    save = args.save
    image_size = (args.image_size, args.image_size)
    valid_size, test_size = args.valid_test_size
    batch_size = args.batch_size
    num_workers = args.num_workers
    lr = args.learning_rate
    epoch = args.epoch
    reduction = args.reduction
    multipler = args.multipler

    # ----- load data -----
    neg_dir = './DATA/TCT/negative'
    pos_dir = './DATA/TCT/positive'

    dat = MilData.from2dir(neg_dir, pos_dir)
    train_transfer = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    test_transfer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dat, valid_dat, test_dat = dat.split_by_bag(
        test_size,
        valid_size,
        train_transfer=train_transfer,
        valid_transfer=test_transfer,
        test_transfer=test_transfer)
    dataloaders = {
        'train':
        data.DataLoader(train_dat,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        shuffle=True),
        'valid':
        data.DataLoader(
            valid_dat,
            batch_size=batch_size,
            num_workers=num_workers,
        ),
        'test':
        data.DataLoader(
            test_dat,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    }

    # ----- build network and optimizer -----
    net = NormalCnn()
    criterion = nn.BCELoss(reduction='none')
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scorings = [
        mm.Loss(),
        mm.Recall(reduction=reduction),
        mm.ROCAUC(reduction=reduction),
        mm.BalancedAccuracy(reduction=reduction),
        mm.F1Score(reduction=reduction),
        mm.Precision(reduction=reduction),
        mm.Accuracy(reduction=reduction)
    ]

    # ----- train the network -----
    try:
        net, hist, weighter = train(net,
                                    criterion,
                                    optimizer,
                                    dataloaders,
                                    epoch=epoch,
                                    metrics=scorings,
                                    weighter_multipler=multipler)

        test_hist = evaluate(net, dataloaders['test'], criterion, scorings)
    except Exception as e:
        import ipdb
        ipdb.set_trace()  # XXX BREAKPOINT

    # save results
    dirname = check_update_dirname(save)
    torch.save(net.state_dict(), os.path.join(dirname, 'model.pth'))
    torch.save(weighter, os.path.join(dirname, 'weighter.pth'))
    pd.DataFrame(hist).to_csv(os.path.join(dirname, 'train.csv'))
    with open(os.path.join(dirname, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f)
    with open(os.path.join(dirname, 'test.json'), 'w') as f:
        json.dump(test_hist, f)
Example 4
def train(model,
          criterion,
          optimizer,
          dataloaders,
          scheduler=NoneScheduler(None),
          epoch=100,
          device=torch.device('cuda:0'),
          l2=0.0,
          metrics=(mm.Loss(), ),
          standard_metric_index=1,
          clip_grad=False,
          weighter_multipler=1.0):
    weighter = Weighter(dataloaders['train'].dataset,
                        device,
                        multipler=weighter_multipler)
    # variables to keep track of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metric = 0.0
    best_metric_name = metrics[standard_metric_index].__class__.__name__ + \
        '_valid'
    best_weighter = copy.deepcopy(weighter)  # weighter corresponding to the best model
    # dict to record results during training
    history = {
        m.__class__.__name__ + p: []
        for p in ['_train', '_valid'] for m in metrics
    }
    model.to(device)

    for e in range(epoch):
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()
                prefix = "Train: "
            else:
                model.eval()
                prefix = "Valid: "
            # progressbar
            format_custom_text = pb.FormatCustomText('Loss: %(loss).4f',
                                                     dict(loss=0.))
            widgets = [
                prefix, " ",
                pb.Percentage(), ' ',
                pb.SimpleProgress(format='(%s)' %
                                  pb.SimpleProgress.DEFAULT_FORMAT), ' ',
                pb.Bar(), ' ',
                pb.Timer(), ' ',
                pb.AdaptiveETA(), ' ', format_custom_text
            ]
            iterator = pb.progressbar(dataloaders[phase], widgets=widgets)

            for m in metrics:
                m.reset()
            for batch_x, batch_y, bag_ids, inst_ids in iterator:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    proba = model(batch_x).squeeze()  # note: the model is expected to output probabilities
                    # compute the weight of each sample
                    w = weighter(proba, batch_y, bag_ids, inst_ids)
                    # this criterion must not reduce (i.e. reduction='none')
                    loss_es = criterion(proba, batch_y.float())
                    # apply the computed weights
                    loss = (loss_es * w).mean()
                    # apply l2 regularization to weight parameters only
                    if l2 > 0.0:
                        for p_n, p_v in model.named_parameters():
                            if p_n == 'weight':
                                loss += l2 * p_v.norm()
                    if phase == 'train':
                        loss.backward()
                        if clip_grad:
                            nn.utils.clip_grad_norm_(model.parameters(),
                                                     max_norm=1)
                        optimizer.step()
                with torch.no_grad():
                    for m in metrics:
                        if isinstance(m, mm.Loss):
                            m.add(loss.cpu().item(), batch_x.size(0))
                            format_custom_text.update_mapping(loss=m.value())
                        else:
                            m.add(proba.squeeze(), batch_y, bag_ids)

            for m in metrics:
                history[m.__class__.__name__ + '_' + phase].append(m.value())
            print("Epoch: %d, Phase:%s, " % (e, phase) + ", ".join([
                '%s: %.4f' % (m.__class__.__name__,
                              history[m.__class__.__name__ + '_' + phase][-1])
                for m in metrics
            ]))

            if phase == 'valid':
                epoch_metric = history[best_metric_name][-1]
                if epoch_metric > best_metric:
                    best_metric = epoch_metric
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_weighter = copy.deepcopy(weighter)

    print("Best metric: %.4f" % best_metric)
    model.load_state_dict(best_model_wts)
    return model, history, best_weighter
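Weighter is also project-specific and not shown. Judging only from its usage above (built from the training dataset, called per batch with probabilities, labels, bag ids and instance ids, and returning per-sample weights, with positives boosted by weighter_multipler as the help text in Example 3 describes), a minimal hypothetical stand-in could be:

import torch

class Weighter:
    """Hypothetical stand-in: weight 1.0 for negatives, `multipler` for positives."""

    def __init__(self, dataset, device, multipler=1.0):
        self.device = device
        self.multipler = multipler

    def __call__(self, proba, batch_y, bag_ids, inst_ids):
        # per-sample weights on the same device as the labels
        w = torch.ones_like(batch_y, dtype=torch.float, device=self.device)
        w[batch_y > 0] = self.multipler
        return w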
Example 5
    def fit(self, datas):
        # get dataloaders
        train_data = data.ConcatDataset([datas['subject'], datas['qc']]) \
            if self.train_with_qc else datas['subject']
        dataloaders = {
            'train': data.DataLoader(train_data, batch_size=self.bs,
                                     num_workers=self.nw, shuffle=True),
            'qc': data.DataLoader(datas['qc'], batch_size=self.bs,
                                  num_workers=self.nw)
        }
        # begin training
        pbar = tqdm(total=self.epoch)
        for e in range(self.epoch):
            self.e = e
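            # phase schedule: reconstruction (autoencoder) pretraining for the
            # first rec_epoch epochs, discriminator pretraining for the next
            # disc_epoch epochs, then both are trained together for the rest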
            if e < self.rec_epoch:
                self.phase = 'rec_pretrain'
            elif e < self.rec_epoch + self.disc_epoch:
                self.phase = 'disc_pretrain'
            else:
                self.phase = 'iter_train'
            pbar.set_description(self.phase)

            # --- train phase ---
            for model in self.models.values():
                model.train()
            disc_b_loss_obj = mm.Loss()
            disc_o_loss_obj = mm.Loss()
            adv_b_loss_obj = mm.Loss()
            adv_o_loss_obj = mm.Loss()
            rec_loss_obj = mm.Loss()
            for batch_x, batch_y in tqdm(dataloaders['train'], 'Batch: '):
                batch_x = batch_x.to(self.device).float()
                batch_y = batch_y.to(self.device).float()
                bs0 = batch_x.size(0)
                for optimizer in self.optimizers.values():
                    optimizer.zero_grad()
                if self.phase in ['disc_pretrain', 'iter_train']:
                    disc_b_loss, disc_o_loss = \
                        self._forward_discriminate(batch_x, batch_y)
                    disc_b_loss_obj.add(disc_b_loss, bs0)
                    disc_o_loss_obj.add(disc_o_loss, bs0)
                if self.phase in ['rec_pretrain', 'iter_train']:
                    rec_loss, adv_b_loss, adv_o_loss = \
                        self._forward_autoencode(batch_x, batch_y)
                    rec_loss_obj.add(rec_loss, bs0)
                    adv_b_loss_obj.add(adv_b_loss, bs0)
                    adv_o_loss_obj.add(adv_o_loss, bs0)
            # record loss
            self.history['disc_b_loss'].append(disc_b_loss_obj.value())
            self.history['disc_o_loss'].append(disc_o_loss_obj.value())
            self.history['adv_b_loss'].append(adv_b_loss_obj.value())
            self.history['adv_o_loss'].append(adv_o_loss_obj.value())
            self.history['rec_loss'].append(rec_loss_obj.value())
            # visualize epoch losses
            self.visobj.add_epoch_loss(
                winname='disc_losses',
                disc_b_loss=self.history['disc_b_loss'][-1],
                disc_o_loss=self.history['disc_o_loss'][-1],
                adv_b_loss=self.history['adv_b_loss'][-1],
                adv_o_loss=self.history['adv_o_loss'][-1],
            )
            self.visobj.add_epoch_loss(
                winname='recon_losses',
                recon_loss=self.history['rec_loss'][-1]
            )

            # --- valid phase ---
            all_data = ConcatData(datas['subject'], datas['qc'])
            all_reses_dict, qc_loss = self.generate(
                all_data, verbose=False, compute_qc_loss=True)
            # pca
            subject_pca, qc_pca = pca_for_dict(all_reses_dict, 3)
            # plot pca
            pca_plot(subject_pca, qc_pca)
            # display in visdom
            self.visobj.vis.matplot(plt, win='PCA', opts={'title': 'PCA'})
            plt.close()

            # --- early stopping ---
            qc_dist = mm.mean_distance(qc_pca['Rec_nobe'])
            self.history['qc_rec_loss'].append(qc_loss)
            self.history['qc_distance'].append(qc_dist)
            self.visobj.add_epoch_loss(winname='qc_rec_loss', qc_loss=qc_loss)
            self.visobj.add_epoch_loss(winname='qc_distance', qc_dist=qc_dist)
            if e >= (self.epoch - 200):
                self._check_qc(qc_dist, qc_loss)

            # progressbar
            pbar.update(1)
        pbar.close()

        # early stop information and save visdom env
        if self.visdom_env != 'main':
            self.visobj.vis.save([self.visdom_env])
        print('')
        print('The best epoch is %d' % self.early_stop_objs['best_epoch'])
        print('The best qc loss is %.4f' %
              self.early_stop_objs['best_qc_loss'])
        print('The best qc distance is %.4f' %
              self.early_stop_objs['best_qc_distance'])
        for k, v in self.models.items():
            v.load_state_dict(self.early_stop_objs['best_models'][k])
        self.early_stop_objs.pop('best_models')
        return self.models, self.history, self.early_stop_objs
Example 6
    def generate(self, data_loader, verbose=True, compute_qc_loss=False):
        for m in self.models.values():
            m.to(self.device).eval()
        if isinstance(data_loader, data.Dataset):
            data_loader = data.DataLoader(
                data_loader, batch_size=self.bs, num_workers=self.nw)
        x_ori, x_rec, x_rec_nobe, ys, codes = [], [], [], [], []
        qc_loss = mm.Loss()

        # encoding
        if verbose:
            print('----- encoding -----')
            iterator = tqdm(data_loader, 'encode and decode: ')
        else:
            iterator = data_loader
        with torch.no_grad():
            for batch_x, batch_y in iterator:
                # return x and y
                x_ori.append(batch_x)
                ys.append(batch_y)
                # return latent representation
                batch_x = batch_x.to(self.device, torch.float)
                batch_y = batch_y.to(self.device, torch.float)
                hidden = self.models['encoder'](batch_x)
                codes.append(hidden)
                # return rec with and without batch effects
                batch_ys = [
                    torch.eye(self.batch_label_num)[batch_y[:, 1].long()].to(
                        hidden),
                    batch_y[:, [0]]
                ]
                batch_ys = torch.cat(batch_ys, dim=1)
                hidden_be = hidden + self.models['map'](batch_ys)
                x_rec.append(self.models['decoder'](hidden_be))
                x_rec_nobe.append(self.models['decoder'](hidden))
                # return qc loss
                if compute_qc_loss:
                    qc_index = batch_y[:, -1] == 0.
                    if qc_index.sum() > 0:
                        batch_qc_loss = self.criterions['rec'](
                            batch_x[qc_index], x_rec[-1][qc_index])
                        qc_loss.add(
                            batch_qc_loss,
                            qc_index.sum().detach().cpu().item()
                        )
                    else:
                        qc_loss.add(torch.tensor(0.), 0)

        # return dataframe
        res = {
            'Ori': torch.cat(x_ori), 'Ys': torch.cat(ys),
            'Codes': torch.cat(codes), 'Rec': torch.cat(x_rec),
            'Rec_nobe': torch.cat(x_rec_nobe)
        }
        for k, v in res.items():
            if v is not None:
                if k == 'Ys':
                    res[k] = pd.DataFrame(
                        v.detach().cpu().numpy(),
                        index=data_loader.dataset.y_df.index,
                        columns=data_loader.dataset.y_df.columns
                    )
                elif k != 'Codes':
                    res[k] = pd.DataFrame(
                        v.detach().cpu().numpy(),
                        index=data_loader.dataset.x_df.index,
                        columns=data_loader.dataset.x_df.columns
                    )
                    res[k] = self.pre_transfer.inverse_transform(
                        res[k], None)[0]
                else:
                    res[k] = pd.DataFrame(
                        v.detach().cpu().numpy(),
                        index=data_loader.dataset.x_df.index,
                    )

        if compute_qc_loss:
            return res, qc_loss.value()
        return res
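The one-hot encoding above indexes rows of an identity matrix with the integer batch labels; a standalone illustration of the same trick (the label count of 3 is only an assumption for the example):

import torch

batch_label_num = 3                    # assumed number of distinct batch labels
labels = torch.tensor([0, 2, 1, 2])    # integer labels, as in batch_y[:, 1].long()
one_hot = torch.eye(batch_label_num)[labels]
# one_hot:
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.],
#         [0., 0., 1.]])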
Example 7
def main():

    # ----- load different data and choose loss / metrics according to the data argument -----
    if config.args.data == 'brca':
        rna = RnaData.predicted_data(config.brca_cli, config.brca_rna,
                                     {'PAM50Call_RNAseq': 'pam50'})
        rna.transform(tf.LabelMapper(config.brca_label_mapper))
        out_shape = len(config.brca_label_mapper)
        criterion = nn.CrossEntropyLoss()
        scorings = (mm.Loss(), mm.Accuracy(), mm.BalancedAccuracy(),
                    mm.F1Score(average='macro'), mm.Precision(average='macro'),
                    mm.Recall(average='macro'), mm.ROCAUC(average='macro'))
    elif config.args.data == 'survival':
        if os.path.exists('./DATA/temp_pan.pth'):
            rna = RnaData.load('./DATA/temp_pan.pth')
        else:
            rna = RnaData.survival_data(config.pan_cli, config.pan_rna,
                                        '_OS_IND', '_OS')
        out_shape = 1
        if config.args.loss_type == 'cox':
            criterion = NegativeLogLikelihood()
        elif config.args.loss_type == 'svm':
            criterion = SvmLoss(rank_ratio=config.args.svm_rankratio)
        scorings = (mm.Loss(), mm.CIndex())
    rna.transform(tf.ZeroFilterCol(0.8))
    rna.transform(tf.MeanFilterCol(1))
    rna.transform(tf.StdFilterCol(0.5))
    norm = tf.Normalization()
    rna.transform(norm)

    # ----- build network and optimizer -----
    inpt_shape = rna.X.shape[1]
    if config.args.net_type == 'mlp':
        net = MLP(inpt_shape, out_shape, config.args.hidden_num,
                  config.args.block_num).cuda()
    elif config.args.net_type == 'atten':
        net = SelfAttentionNet(inpt_shape, out_shape, config.args.hidden_num,
                               config.args.bottle_num, config.args.block_num,
                               config.args.no_res, config.act,
                               config.args.no_head, config.args.no_bottle,
                               config.args.no_atten,
                               config.args.dropout_rate).cuda()
    elif config.args.net_type == 'resnet':
        net = ResidualNet(inpt_shape, out_shape, config.args.hidden_num,
                          config.args.bottle_num,
                          config.args.block_num).cuda()

    # ----- train the network with cross validation -----
    split_iterator = rna.split_cv(config.args.test_size,
                                  config.args.cross_valid)
    train_hists = []
    test_hists = []
    for split_index, (train_rna, test_rna) in enumerate(split_iterator):
        print('##### save: %s, split: %d #####' %
              (config.args.save, split_index))
        # split a validation set from the training set to decide when to stop
        train_rna, valid_rna = train_rna.split(0.1)
        dats = {
            'train': train_rna.to_torchdat(),
            'valid': valid_rna.to_torchdat(),
        }
        dataloaders = {
            k: data.DataLoader(v, batch_size=config.args.batch_size)
            for k, v in dats.items()
        }
        test_dataloader = data.DataLoader(test_rna.to_torchdat(),
                                          batch_size=config.args.batch_size)
        # reset the parameters before each training run to avoid influence from previous splits
        net.reset_parameters()
        # train
        optimizer = optim.Adamax(net.parameters(),
                                 lr=config.args.learning_rate)
        lrs = config.lrs(optimizer)
        net, hist = train(
            net,
            criterion,
            optimizer,
            dataloaders,
            epoch=config.args.epoch,
            metrics=scorings,
            l2=config.args.l2,
            standard_metric_index=config.args.standard_metric_index,
            scheduler=lrs)
        # test
        test_res = evaluate(net, criterion, test_dataloader, metrics=scorings)
        # collect the training results of all splits into one dataframe
        hist = pd.DataFrame(hist)
        hist['split_index'] = split_index
        train_hists.append(hist)
        # collect the test results of all splits
        test_res['split_index'] = split_index
        test_hists.append(test_res)
        # save the model trained on each split as a separate file
        torch.save(net.state_dict(),
                   os.path.join(config.save_dir, 'model%d.pth' % split_index))
    # save the training results
    train_hists = pd.concat(train_hists)
    train_hists.to_csv(os.path.join(config.save_dir, 'train.csv'))
    # save the test results
    test_hists = pd.DataFrame(test_hists)
    test_hists.to_csv(os.path.join(config.save_dir, 'test.csv'))