def main():
    # ----- config -----
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--save', default='./save',
                        help='directory to save results; if the name already '
                             'exists, "-" is appended to distinguish it')
    parser.add_argument('-is', '--image_size', default=224, type=int,
                        help='size the patches are resized to, default 224 x 224')
    parser.add_argument('-vts', '--valid_test_size', default=(0.1, 0.1),
                        type=float, nargs=2,
                        help='sizes of the validation and test sets, default 0.1 0.1')
    parser.add_argument('-bs', '--batch_size', default=64, type=int,
                        help='batch size, default 64')
    parser.add_argument('-nw', '--num_workers', default=12, type=int,
                        help='number of worker processes, default 12')
    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float,
                        help='learning rate, default 0.0001')
    parser.add_argument('-e', '--epoch', default=10, type=int,
                        help='number of epochs, default 10')
    parser.add_argument('-tp', '--test_patches', default=None, type=int,
                        help=('number of patches randomly sampled from each '
                              'patient at test time; default None, i.e. all '
                              'patches are used'))
    parser.add_argument('--cindex_reduction', default='mean',
                        help='how patches of the same slide are aggregated, default mean')
    parser.add_argument('--loss_type', default='cox',
                        help='loss type, default cox; svmloss is also supported')
    parser.add_argument('--zoom', default='40.0',
                        help='magnification to use, default 40.0')
    parser.add_argument('--rank_ratio', default=1.0, type=float,
                        help='rank_ratio of svmloss, default 1.0')
    args = parser.parse_args()
    save = args.save
    image_size = (args.image_size, args.image_size)
    valid_size, test_size = args.valid_test_size
    batch_size = args.batch_size
    num_workers = args.num_workers
    lr = args.learning_rate
    epoch = args.epoch
    test_patches = args.test_patches
    cindex_reduction = args.cindex_reduction
    zoom = args.zoom
    rank_ratio = args.rank_ratio

    # ----- load data -----
    demographic_file = '/home/dl/NewDisk/Slides/TCGA-OV/demographic.csv'
    tiles_dir = '/home/dl/NewDisk/Slides/TCGA-OV/Tiles'
    dat = SlidePatchData.from_demographic(demographic_file, tiles_dir,
                                          transfer=transforms.ToTensor(),
                                          zoom=zoom)
    train_transfer = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    test_transfer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dat, valid_dat = dat.split_by_patients(valid_size + test_size,
                                                 train_transfer=train_transfer,
                                                 test_transfer=test_transfer)
    valid_dat, test_dat = valid_dat.split_by_patients(
        test_size / (valid_size + test_size))
    train_sampler = OneEveryPatientSampler(train_dat)
    if test_patches is not None:
        # one sampler per dataset, so the sampled indices match the dataset
        # they are used with
        valid_sampler = OneEveryPatientSampler(valid_dat,
                                               num_per_patients=test_patches)
        test_sampler = OneEveryPatientSampler(test_dat,
                                              num_per_patients=test_patches)
    else:
        valid_sampler = test_sampler = None
    dataloaders = {
        'train': data.DataLoader(train_dat, batch_size=batch_size,
                                 sampler=train_sampler,
                                 num_workers=num_workers),
        'valid': data.DataLoader(valid_dat, batch_size=batch_size,
                                 sampler=valid_sampler,
                                 num_workers=num_workers),
        'test': data.DataLoader(test_dat, batch_size=batch_size,
                                sampler=test_sampler,
                                num_workers=num_workers),
    }

    # ----- build network and optimizer -----
    net = SurvivalPatchCNN()
    if args.loss_type == 'cox':
        criterion = NegativeLogLikelihood()
    elif args.loss_type == 'svmloss':
        criterion = SvmLoss(rank_ratio=rank_ratio)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scorings = [mm.Loss(), mm.CIndexForSlide(reduction=cindex_reduction)]

    # ----- train the network -----
    net, hist = train(net, criterion, optimizer, dataloaders, epoch=epoch,
                      metrics=scorings)
    print('')

    # ----- final test -----
    test_hist = evaluate(net, dataloaders['test'], criterion,
                         metrics=scorings)

    # save the results
    dirname = check_update_dirname(save)
    torch.save(net.state_dict(), os.path.join(dirname, 'model.pth'))
    pd.DataFrame(hist).to_csv(os.path.join(dirname, 'train.csv'))
    with open(os.path.join(dirname, 'test.json'), 'w') as f:
        json.dump(test_hist, f)
    with open(os.path.join(dirname, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f)
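# --------------------------------------------------------------------------
# Hypothetical sketch (assumption, not part of the original listing): the
# helper check_update_dirname used above is defined elsewhere in the repo.
# Based on the --save help text ("if the name already exists, '-' is
# appended"), a minimal implementation could look like this.
def check_update_dirname(dirname, indicator='-'):
    # keep appending the indicator until the path is unused, then create it
    while os.path.exists(dirname):
        dirname += indicator
    os.makedirs(dirname)
    return dirname
# --------------------------------------------------------------------------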
def train(model, criterion, optimizer, dataloaders,
          scheduler=NoneScheduler(None), epoch=100,
          device=torch.device('cuda:0'), l2=0.0,
          metrics=(mm.Loss(), mm.CIndexForSlide()), standard_metric_index=1,
          clip_grad=False):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metric = 0.0
    best_metric_name = metrics[standard_metric_index].__class__.__name__ + \
        '_valid'
    history = {
        m.__class__.__name__ + p: []
        for p in ['_train', '_valid'] for m in metrics
    }
    model.to(device)
    for e in range(epoch):
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()
                prefix = "Train: "
            else:
                model.eval()
                prefix = "Valid: "
            # progressbar
            format_custom_text = pb.FormatCustomText('Loss: %(loss).4f',
                                                     dict(loss=0.))
            widgets = [
                prefix, " ", pb.Counter(), ' ', pb.Bar(), ' ', pb.Timer(),
                ' ', pb.AdaptiveETA(), ' ', format_custom_text
            ]
            iterator = pb.progressbar(dataloaders[phase], widgets=widgets)
            for m in metrics:
                m.reset()
            for batch_x, batch_y, (batch_ids, batch_files) in iterator:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    logit = model(batch_x)
                    loss = criterion(logit, batch_y)
                    # apply l2 regularization to weight parameters only
                    if l2 > 0.0:
                        for p_n, p_v in model.named_parameters():
                            if p_n.endswith('weight'):
                                loss += l2 * p_v.norm()
                    if phase == 'train':
                        loss.backward()
                        if clip_grad:
                            nn.utils.clip_grad_norm_(model.parameters(),
                                                     max_norm=1)
                        optimizer.step()
                with torch.no_grad():
                    for m in metrics:
                        if isinstance(m, mm.Loss):
                            m.add(loss.cpu().item(), batch_x.size(0))
                            format_custom_text.update_mapping(loss=m.value())
                        else:
                            m.add(logit.squeeze(), batch_y, batch_ids)
            for m in metrics:
                history[m.__class__.__name__ + '_' + phase].append(m.value())
            print("Epoch: %d, Phase:%s, " % (e, phase) + ", ".join([
                '%s: %.4f' % (m.__class__.__name__,
                              history[m.__class__.__name__ + '_' + phase][-1])
                for m in metrics
            ]))
            if phase == 'valid':
                epoch_metric = history[best_metric_name][-1]
                if epoch_metric > best_metric:
                    best_metric = epoch_metric
                    best_model_wts = copy.deepcopy(model.state_dict())
    print("Best metric: %.4f" % best_metric)
    model.load_state_dict(best_model_wts)
    return model, history
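# --------------------------------------------------------------------------
# Hypothetical sketch (assumption): NoneScheduler, used as the default
# scheduler above, is defined elsewhere in the repo. It only needs a no-op
# step() so the training loop can call scheduler.step() unconditionally.
class NoneScheduler:
    def __init__(self, optimizer=None):
        self.optimizer = optimizer

    def step(self):
        # deliberately does nothing: the learning rate stays fixed
        pass
# --------------------------------------------------------------------------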
def main():
    warnings.filterwarnings('ignore')
    # ----- config -----
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--save', default='./save',
                        help='directory to save results; if the name already '
                             'exists, "-" is appended to distinguish it')
    parser.add_argument('-is', '--image_size', default=224, type=int,
                        help='size the patches are resized to, default 224 x 224')
    parser.add_argument('-vts', '--valid_test_size', default=(0.1, 0.1),
                        type=float, nargs=2,
                        help='sizes of the validation and test sets, default 0.1 0.1')
    parser.add_argument('-bs', '--batch_size', default=32, type=int,
                        help='batch size, default 32')
    parser.add_argument('-nw', '--num_workers', default=12, type=int,
                        help='number of worker processes, default 12')
    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float,
                        help='learning rate, default 0.0001')
    parser.add_argument('-e', '--epoch', default=10, type=int,
                        help='number of epochs, default 10')
    parser.add_argument('--reduction', default='mean',
                        help='how instances of the same bag are aggregated, default mean')
    parser.add_argument('--multipler', default=2.0, type=float,
                        help='factor (> 1) multiplied into the weights to balance '
                             'positives and negatives, default 2.0')
    args = parser.parse_args()
    save = args.save
    image_size = (args.image_size, args.image_size)
    valid_size, test_size = args.valid_test_size
    batch_size = args.batch_size
    num_workers = args.num_workers
    lr = args.learning_rate
    epoch = args.epoch
    reduction = args.reduction
    multipler = args.multipler

    # ----- load data -----
    neg_dir = './DATA/TCT/negative'
    pos_dir = './DATA/TCT/positive'
    dat = MilData.from2dir(neg_dir, pos_dir)
    train_transfer = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    test_transfer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dat, valid_dat, test_dat = dat.split_by_bag(
        test_size, valid_size, train_transfer=train_transfer,
        valid_transfer=test_transfer, test_transfer=test_transfer)
    dataloaders = {
        'train': data.DataLoader(train_dat, batch_size=batch_size,
                                 num_workers=num_workers, shuffle=True),
        'valid': data.DataLoader(valid_dat, batch_size=batch_size,
                                 num_workers=num_workers),
        'test': data.DataLoader(test_dat, batch_size=batch_size,
                                num_workers=num_workers)
    }

    # ----- build network and optimizer -----
    net = NormalCnn()
    criterion = nn.BCELoss(reduction='none')
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scorings = [
        mm.Loss(),
        mm.Recall(reduction=reduction),
        mm.ROCAUC(reduction=reduction),
        mm.BalancedAccuracy(reduction=reduction),
        mm.F1Score(reduction=reduction),
        mm.Precision(reduction=reduction),
        mm.Accuracy(reduction=reduction)
    ]

    # ----- train the network -----
    try:
        net, hist, weighter = train(net, criterion, optimizer, dataloaders,
                                    epoch=epoch, metrics=scorings,
                                    weighter_multipler=multipler)
        test_hist = evaluate(net, dataloaders['test'], criterion, scorings)
    except Exception as e:
        import ipdb
        ipdb.set_trace()  # XXX BREAKPOINT

    # save the results
    dirname = check_update_dirname(save)
    torch.save(net.state_dict(), os.path.join(dirname, 'model.pth'))
    torch.save(weighter, os.path.join(dirname, 'weigher.pth'))
    pd.DataFrame(hist).to_csv(os.path.join(dirname, 'train.csv'))
    with open(os.path.join(dirname, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f)
    with open(os.path.join(dirname, 'test.json'), 'w') as f:
        json.dump(test_hist, f)
def train(model, criterion, optimizer, dataloaders,
          scheduler=NoneScheduler(None), epoch=100,
          device=torch.device('cuda:0'), l2=0.0, metrics=(mm.Loss(), ),
          standard_metric_index=1, clip_grad=False, weighter_multipler=1.0):
    weighter = Weighter(dataloaders['train'].dataset, device,
                        multipler=weighter_multipler)
    # variables that keep track of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metric = 0.0
    best_metric_name = metrics[standard_metric_index].__class__.__name__ + \
        '_valid'
    best_weighter = copy.deepcopy(weighter)  # weighter belonging to the best model
    # dict that stores the metrics of every epoch
    history = {
        m.__class__.__name__ + p: []
        for p in ['_train', '_valid'] for m in metrics
    }
    model.to(device)
    for e in range(epoch):
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()
                prefix = "Train: "
            else:
                model.eval()
                prefix = "Valid: "
            # progressbar
            format_custom_text = pb.FormatCustomText('Loss: %(loss).4f',
                                                     dict(loss=0.))
            widgets = [
                prefix, " ", pb.Percentage(), ' ',
                pb.SimpleProgress(
                    format='(%s)' % pb.SimpleProgress.DEFAULT_FORMAT),
                ' ', pb.Bar(), ' ', pb.Timer(), ' ', pb.AdaptiveETA(), ' ',
                format_custom_text
            ]
            iterator = pb.progressbar(dataloaders[phase], widgets=widgets)
            for m in metrics:
                m.reset()
            for batch_x, batch_y, bag_ids, inst_ids in iterator:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    # note: the model must output probabilities
                    proba = model(batch_x).squeeze()
                    # compute the weight of every instance
                    w = weighter(proba, batch_y, bag_ids, inst_ids)
                    # the criterion must not reduce, so the weights can be applied
                    loss_es = criterion(proba, batch_y.float())
                    # apply the computed weights
                    loss = (loss_es * w).mean()
                    # apply l2 regularization to weight parameters only
                    if l2 > 0.0:
                        for p_n, p_v in model.named_parameters():
                            if p_n.endswith('weight'):
                                loss += l2 * p_v.norm()
                    if phase == 'train':
                        loss.backward()
                        if clip_grad:
                            nn.utils.clip_grad_norm_(model.parameters(),
                                                     max_norm=1)
                        optimizer.step()
                with torch.no_grad():
                    for m in metrics:
                        if isinstance(m, mm.Loss):
                            m.add(loss.cpu().item(), batch_x.size(0))
                            format_custom_text.update_mapping(loss=m.value())
                        else:
                            m.add(proba.squeeze(), batch_y, bag_ids)
            for m in metrics:
                history[m.__class__.__name__ + '_' + phase].append(m.value())
            print("Epoch: %d, Phase:%s, " % (e, phase) + ", ".join([
                '%s: %.4f' % (m.__class__.__name__,
                              history[m.__class__.__name__ + '_' + phase][-1])
                for m in metrics
            ]))
            if phase == 'valid':
                epoch_metric = history[best_metric_name][-1]
                if epoch_metric > best_metric:
                    best_metric = epoch_metric
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_weighter = copy.deepcopy(weighter)
    print("Best metric: %.4f" % best_metric)
    model.load_state_dict(best_model_wts)
    return model, history, best_weighter
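# --------------------------------------------------------------------------
# Hypothetical sketch (assumption): the Weighter class used above is defined
# elsewhere in the repo; only its call signature is visible here. A minimal
# version could ignore the bag/instance ids and simply up-weight positive
# instances by `multipler` to balance the two classes.
class Weighter:
    def __init__(self, dataset, device, multipler=1.0):
        self.device = device
        self.multipler = multipler

    def __call__(self, proba, batch_y, bag_ids, inst_ids):
        # weight 1 for negative instances, `multipler` for positive ones
        w = torch.ones_like(batch_y, dtype=torch.float, device=self.device)
        w[batch_y == 1] = self.multipler
        return w
# --------------------------------------------------------------------------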
def fit(self, datas):
    # get dataloaders
    train_data = data.ConcatDataset([datas['subject'], datas['qc']]) \
        if self.train_with_qc else datas['subject']
    dataloaders = {
        'train': data.DataLoader(train_data, batch_size=self.bs,
                                 num_workers=self.nw, shuffle=True),
        'qc': data.DataLoader(datas['qc'], batch_size=self.bs,
                              num_workers=self.nw)
    }

    # begin training
    pbar = tqdm(total=self.epoch)
    for e in range(self.epoch):
        self.e = e
        if e < self.rec_epoch:
            self.phase = 'rec_pretrain'
        elif e < self.rec_epoch + self.disc_epoch:
            self.phase = 'disc_pretrain'
        else:
            self.phase = 'iter_train'
        pbar.set_description(self.phase)

        # --- train phase ---
        for model in self.models.values():
            model.train()
        disc_b_loss_obj = mm.Loss()
        disc_o_loss_obj = mm.Loss()
        adv_b_loss_obj = mm.Loss()
        adv_o_loss_obj = mm.Loss()
        rec_loss_obj = mm.Loss()
        for batch_x, batch_y in tqdm(dataloaders['train'], 'Batch: '):
            batch_x = batch_x.to(self.device).float()
            batch_y = batch_y.to(self.device).float()
            bs0 = batch_x.size(0)
            for optimizer in self.optimizers.values():
                optimizer.zero_grad()
            if self.phase in ['disc_pretrain', 'iter_train']:
                disc_b_loss, disc_o_loss = \
                    self._forward_discriminate(batch_x, batch_y)
                disc_b_loss_obj.add(disc_b_loss, bs0)
                disc_o_loss_obj.add(disc_o_loss, bs0)
            if self.phase in ['rec_pretrain', 'iter_train']:
                rec_loss, adv_b_loss, adv_o_loss = \
                    self._forward_autoencode(batch_x, batch_y)
                rec_loss_obj.add(rec_loss, bs0)
                adv_b_loss_obj.add(adv_b_loss, bs0)
                adv_o_loss_obj.add(adv_o_loss, bs0)
        # record loss
        self.history['disc_b_loss'].append(disc_b_loss_obj.value())
        self.history['disc_o_loss'].append(disc_o_loss_obj.value())
        self.history['adv_b_loss'].append(adv_b_loss_obj.value())
        self.history['adv_o_loss'].append(adv_o_loss_obj.value())
        self.history['rec_loss'].append(rec_loss_obj.value())
        # visual epoch loss
        self.visobj.add_epoch_loss(
            winname='disc_losses',
            disc_b_loss=self.history['disc_b_loss'][-1],
            disc_o_loss=self.history['disc_o_loss'][-1],
            adv_b_loss=self.history['adv_b_loss'][-1],
            adv_o_loss=self.history['adv_o_loss'][-1],
        )
        self.visobj.add_epoch_loss(
            winname='recon_losses',
            recon_loss=self.history['rec_loss'][-1]
        )

        # --- valid phase ---
        all_data = ConcatData(datas['subject'], datas['qc'])
        all_reses_dict, qc_loss = self.generate(
            all_data, verbose=False, compute_qc_loss=True)
        # pca
        subject_pca, qc_pca = pca_for_dict(all_reses_dict, 3)
        # plot pca
        pca_plot(subject_pca, qc_pca)
        # display in visdom
        self.visobj.vis.matplot(plt, win='PCA', opts={'title': 'PCA'})
        plt.close()

        # --- early stopping ---
        qc_dist = mm.mean_distance(qc_pca['Rec_nobe'])
        self.history['qc_rec_loss'].append(qc_loss)
        self.history['qc_distance'].append(qc_dist)
        self.visobj.add_epoch_loss(winname='qc_rec_loss', qc_loss=qc_loss)
        self.visobj.add_epoch_loss(winname='qc_distance', qc_dist=qc_dist)
        if e >= (self.epoch - 200):
            self._check_qc(qc_dist, qc_loss)

        # progressbar
        pbar.update(1)
    pbar.close()

    # early stop information and save visdom env
    if self.visdom_env != 'main':
        self.visobj.vis.save([self.visdom_env])
    print('')
    print('The best epoch is %d' % self.early_stop_objs['best_epoch'])
    print('The best qc loss is %.4f' % self.early_stop_objs['best_qc_loss'])
    print('The best qc distance is %.4f' %
          self.early_stop_objs['best_qc_distance'])
    for k, v in self.models.items():
        v.load_state_dict(self.early_stop_objs['best_models'][k])
    self.early_stop_objs.pop('best_models')
    return self.models, self.history, self.early_stop_objs
def generate(self, data_loader, verbose=True, compute_qc_loss=False):
    for m in self.models.values():
        m.to(self.device).eval()
    if isinstance(data_loader, data.Dataset):
        data_loader = data.DataLoader(data_loader, batch_size=self.bs,
                                      num_workers=self.nw)
    x_ori, x_rec, x_rec_nobe, ys, codes = [], [], [], [], []
    qc_loss = mm.Loss()

    # encoding
    if verbose:
        print('----- encoding -----')
        iterator = tqdm(data_loader, 'encode and decode: ')
    else:
        iterator = data_loader
    with torch.no_grad():
        for batch_x, batch_y in iterator:
            # return x and y
            x_ori.append(batch_x)
            ys.append(batch_y)
            # return latent representation
            batch_x = batch_x.to(self.device, torch.float)
            batch_y = batch_y.to(self.device, torch.float)
            hidden = self.models['encoder'](batch_x)
            codes.append(hidden)
            # return rec with and without batch effects
            batch_ys = [
                torch.eye(self.batch_label_num)[batch_y[:, 1].long()].to(
                    hidden),
                batch_y[:, [0]]
            ]
            batch_ys = torch.cat(batch_ys, dim=1)
            hidden_be = hidden + self.models['map'](batch_ys)
            x_rec.append(self.models['decoder'](hidden_be))
            x_rec_nobe.append(self.models['decoder'](hidden))
            # return qc loss
            if compute_qc_loss:
                qc_index = batch_y[:, -1] == 0.
                if qc_index.sum() > 0:
                    batch_qc_loss = self.criterions['rec'](
                        batch_x[qc_index], x_rec[-1][qc_index])
                    qc_loss.add(
                        batch_qc_loss,
                        qc_index.sum().detach().cpu().item()
                    )
                else:
                    qc_loss.add(torch.tensor(0.), 0)

    # return dataframe
    res = {
        'Ori': torch.cat(x_ori),
        'Ys': torch.cat(ys),
        'Codes': torch.cat(codes),
        'Rec': torch.cat(x_rec),
        'Rec_nobe': torch.cat(x_rec_nobe)
    }
    for k, v in res.items():
        if v is not None:
            if k == 'Ys':
                res[k] = pd.DataFrame(
                    v.detach().cpu().numpy(),
                    index=data_loader.dataset.y_df.index,
                    columns=data_loader.dataset.y_df.columns
                )
            elif k != 'Codes':
                res[k] = pd.DataFrame(
                    v.detach().cpu().numpy(),
                    index=data_loader.dataset.x_df.index,
                    columns=data_loader.dataset.x_df.columns
                )
                res[k] = self.pre_transfer.inverse_transform(res[k], None)[0]
            else:
                res[k] = pd.DataFrame(
                    v.detach().cpu().numpy(),
                    index=data_loader.dataset.x_df.index,
                )

    if compute_qc_loss:
        return res, qc_loss.value()
    return res
def main():
    # ----- load the data and pick loss and metrics depending on `data` -----
    if config.args.data == 'brca':
        rna = RnaData.predicted_data(config.brca_cli, config.brca_rna,
                                     {'PAM50Call_RNAseq': 'pam50'})
        rna.transform(tf.LabelMapper(config.brca_label_mapper))
        out_shape = len(config.brca_label_mapper)
        criterion = nn.CrossEntropyLoss()
        scorings = (mm.Loss(), mm.Accuracy(), mm.BalancedAccuracy(),
                    mm.F1Score(average='macro'),
                    mm.Precision(average='macro'),
                    mm.Recall(average='macro'), mm.ROCAUC(average='macro'))
    elif config.args.data == 'survival':
        if os.path.exists('./DATA/temp_pan.pth'):
            rna = RnaData.load('./DATA/temp_pan.pth')
        else:
            rna = RnaData.survival_data(config.pan_cli, config.pan_rna,
                                        '_OS_IND', '_OS')
        out_shape = 1
        if config.args.loss_type == 'cox':
            criterion = NegativeLogLikelihood()
        elif config.args.loss_type == 'svm':
            criterion = SvmLoss(rank_ratio=config.args.svm_rankratio)
        scorings = (mm.Loss(), mm.CIndex())
    rna.transform(tf.ZeroFilterCol(0.8))
    rna.transform(tf.MeanFilterCol(1))
    rna.transform(tf.StdFilterCol(0.5))
    norm = tf.Normalization()
    rna.transform(norm)

    # ----- build network and optimizer -----
    inpt_shape = rna.X.shape[1]
    if config.args.net_type == 'mlp':
        net = MLP(inpt_shape, out_shape, config.args.hidden_num,
                  config.args.block_num).cuda()
    elif config.args.net_type == 'atten':
        net = SelfAttentionNet(inpt_shape, out_shape, config.args.hidden_num,
                               config.args.bottle_num, config.args.block_num,
                               config.args.no_res, config.act,
                               config.args.no_head, config.args.no_bottle,
                               config.args.no_atten,
                               config.args.dropout_rate).cuda()
    elif config.args.net_type == 'resnet':
        net = ResidualNet(inpt_shape, out_shape, config.args.hidden_num,
                          config.args.bottle_num,
                          config.args.block_num).cuda()

    # ----- train the network with cross validation -----
    split_iterator = rna.split_cv(config.args.test_size,
                                  config.args.cross_valid)
    train_hists = []
    test_hists = []
    for split_index, (train_rna, test_rna) in enumerate(split_iterator):
        print('##### save: %s, split: %d #####' %
              (config.args.save, split_index))
        # split part of the training set off as a validation set for early stopping
        train_rna, valid_rna = train_rna.split(0.1)
        dats = {
            'train': train_rna.to_torchdat(),
            'valid': valid_rna.to_torchdat(),
        }
        dataloaders = {
            k: data.DataLoader(v, batch_size=config.args.batch_size)
            for k, v in dats.items()
        }
        test_dataloader = data.DataLoader(test_rna.to_torchdat(),
                                          batch_size=config.args.batch_size)
        # reset the parameters before every split so earlier training has no influence
        net.reset_parameters()
        # train
        optimizer = optim.Adamax(net.parameters(),
                                 lr=config.args.learning_rate)
        lrs = config.lrs(optimizer)
        net, hist = train(
            net, criterion, optimizer, dataloaders,
            epoch=config.args.epoch, metrics=scorings, l2=config.args.l2,
            standard_metric_index=config.args.standard_metric_index,
            scheduler=lrs)
        # test
        test_res = evaluate(net, criterion, test_dataloader,
                            metrics=scorings)
        # collect the training results of all splits in one dataframe
        hist = pd.DataFrame(hist)
        hist['split_index'] = split_index
        train_hists.append(hist)
        # collect the test results of all splits
        test_res['split_index'] = split_index
        test_hists.append(test_res)
        # save the model of every split to its own file
        torch.save(net.state_dict(),
                   os.path.join(config.save_dir, 'model%d.pth' % split_index))
    # save the training results
    train_hists = pd.concat(train_hists)
    train_hists.to_csv(os.path.join(config.save_dir, 'train.csv'))
    # save the test results
    test_hists = pd.DataFrame(test_hists)
    test_hists.to_csv(os.path.join(config.save_dir, 'test.csv'))
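# --------------------------------------------------------------------------
# Hypothetical sketch (assumption): the evaluate helper called above is not
# shown in this listing, and its exact signature and metric interface are
# assumptions. For this script, whose dataloaders yield plain (x, y) batches,
# a minimal version could run one pass over the dataloader and return the
# metric values as a dict.
def evaluate(model, criterion, dataloader, metrics=(mm.Loss(),),
             device=torch.device('cuda:0')):
    model.to(device).eval()
    for m in metrics:
        m.reset()
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logit = model(batch_x)
            loss = criterion(logit, batch_y)
            for m in metrics:
                if isinstance(m, mm.Loss):
                    m.add(loss.cpu().item(), batch_x.size(0))
                else:
                    m.add(logit.squeeze(), batch_y)
    return {m.__class__.__name__: m.value() for m in metrics}
# --------------------------------------------------------------------------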