# Module-level dependencies assumed by the training utilities below; `opt`
# (the argparse namespace) and `total_step` are defined at module scope.
from datetime import datetime
import os

import torch
import torch.nn.functional as F


def train(train_loader, model, optimizer, epoch):
    model.train()
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]
    loss_record2, loss_record3, loss_record4, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = images.cuda()  # Variable() wrappers are a no-op since PyTorch 0.4
            gts = gts.cuda()
            # ---- rescale: snap to a multiple of 32 (e.g. 352 -> 256 / 352 / 448) ----
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                # F.interpolate replaces the deprecated F.upsample
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                # gts = F.interpolate(gts, size=(trainsize, trainsize), mode='nearest')  # nearest keeps masks binary
            # ---- forward ----
            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = model(images)
            # ---- loss function ----
            loss5 = structure_loss(lateral_map_5, gts)
            loss4 = structure_loss(lateral_map_4, gts)
            loss3 = structure_loss(lateral_map_3, gts)
            loss2 = structure_loss(lateral_map_2, gts)
            loss = loss2 + loss3 + loss4 + loss5  # TODO: try different weights per loss term
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- record losses (only at the native scale) ----
            if rate == 1:
                loss_record2.update(loss2.data, opt.batchsize)
                loss_record3.update(loss3.data, opt.batchsize)
                loss_record4.update(loss4.data, opt.batchsize)
                loss_record5.update(loss5.data, opt.batchsize)
        # ---- train visualization ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  '[lateral-2: {:.4f}, lateral-3: {:.4f}, lateral-4: {:.4f}, lateral-5: {:.4f}]'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step,
                         loss_record2.show(), loss_record3.show(),
                         loss_record4.show(), loss_record5.show()))
    save_path = 'snapshots/{}/'.format(opt.train_save)
    os.makedirs(save_path, exist_ok=True)
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), save_path + 'PraNet-%d.pth' % epoch)
        print('[Saving Snapshot:]', save_path + 'PraNet-%d.pth' % epoch)
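
# --- Hedged sketches of the helpers referenced above (AvgMeter, clip_gradient,
# structure_loss). These are minimal reconstructions, not necessarily the
# repo's exact definitions: AvgMeter keeps a window of recent values so that
# show() returns a smoothed loss, clip_gradient clamps gradients element-wise,
# and structure_loss is the weighted-BCE + weighted-IoU loss used by PraNet. ---
class AvgMeter(object):
    def __init__(self, num=40):
        self.num = num  # window size for the smoothed value returned by show()
        self.losses = []
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(float(val))

    def show(self):
        window = self.losses[-self.num:]
        return sum(window) / max(len(window), 1)


def clip_gradient(optimizer, grad_clip):
    # Clamp every gradient element into [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def structure_loss(pred, mask):
    # Pixels near mask boundaries get higher weight via a local-mean contrast.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()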
def train_on_epoches(self, epoch):
    loss_g_a_meter = AvgMeter()
    loss_g_b_meter = AvgMeter()
    loss_cyc_a_meter = AvgMeter()
    loss_cyc_b_meter = AvgMeter()
    loss_d_a = AvgMeter()
    loss_d_b = AvgMeter()
    loss_meters = [loss_g_a_meter, loss_g_b_meter, loss_cyc_a_meter,
                   loss_cyc_b_meter, loss_d_a, loss_d_b]
    loss_names = ['G_A', 'G_B', 'Cyc_A', 'Cyc_B', 'D_A', 'D_B']

    # Only rank 0 drives a progress bar; the other ranks iterate the bare loader.
    if self.rank == 0:
        progress_bar = tqdm(self.loader, desc='Epoch train')
    else:
        progress_bar = self.loader

    for iter_idx, sample in enumerate(progress_bar):
        losses_set = self.train_on_step(sample)
        for loss, meter in zip(losses_set, loss_meters):
            # Sum each loss across processes, then divide by the process
            # count so every rank logs the global mean.
            dist.all_reduce(loss)
            loss = loss / self.args.gpus_num
            meter.update(loss)

        cur_lr = self.optim_G.param_groups[0]['lr']
        step = iter_idx + 1 + epoch * self.each_epoch_iters

        if self.rank == 0:
            str_content = f'epoch: {epoch:d}; lr: {cur_lr:.6f};'
            for meter, name in zip(loss_meters, loss_names):
                str_content += f' {name}: {meter.avg:.5f};'
            progress_bar.set_postfix(logger=str_content)

            if (iter_idx + 1) % 200 == 0:
                # TensorBoard logging. Note: newer torchvision versions rename
                # make_grid's `range` keyword to `value_range`.
                realA = make_grid(self.realA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                realB = make_grid(self.realB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                fakeA = make_grid(self.fakeA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                fakeB = make_grid(self.fakeB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                recA = make_grid(self.recA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                recB = make_grid(self.recB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                self.td.add_image('realA', realA, step)
                self.td.add_image('fakeA', fakeA, step)
                self.td.add_image('realB', realB, step)
                self.td.add_image('fakeB', fakeB, step)
                self.td.add_image('recA', recA, step)
                self.td.add_image('recB', recB, step)
                for name, meter in zip(loss_names, loss_meters):
                    self.td.add_scalar(name, meter.avg, step)
                self.td.flush()

    if self.rank == 0:
        progress_bar.close()
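
# Minimal sketch of the cross-process loss averaging used in train_on_epoches,
# assuming one process per GPU launched via torch.distributed with gpus_num
# equal to the world size. all_reduce sums the tensor in place across ranks;
# dividing by the world size yields the mean that each rank logs.
import torch
import torch.distributed as dist


def reduce_mean(loss: torch.Tensor) -> torch.Tensor:
    reduced = loss.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced / dist.get_world_size()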
def val(self, test_loader, epoch):
    len_test = len(test_loader)
    self.net.eval()
    # Create the meters once, outside the loop, so they accumulate over the
    # whole validation pass instead of being reset on every batch.
    loss_record2, loss_record3, loss_record4, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    with torch.no_grad():
        for i, pack in enumerate(test_loader, start=1):
            image, gt = pack
            image = image.cuda()
            gt = gt.cuda()
            res5, res4, res3, res2 = self.net(image)
            loss5 = self.loss(res5, gt)
            loss4 = self.loss(res4, gt)
            loss3 = self.loss(res3, gt)
            loss2 = self.loss(res2, gt)
            loss = loss2 + loss3 + loss4 + loss5
            loss_record2.update(loss2.data, 1)
            loss_record3.update(loss3.data, 1)
            loss_record4.update(loss4.data, 1)
            loss_record5.update(loss5.data, 1)
            self.writer.add_scalar("Loss1_test", loss_record2.show(), (epoch - 1) * len_test + i)
            # Log once on the last batch (enumerate starts at 1).
            if i == len_test:
                self.logger.info(
                    'TEST: {} Epoch [{:03d}], with lr = {}, Step [{:04d}], '
                    '[loss_record2: {:.4f}, loss_record3: {:.4f}, loss_record4: {:.4f}, loss_record5: {:.4f}]'.format(
                        datetime.now(), epoch, self.optimizer.param_groups[0]["lr"], i,
                        loss_record2.show(), loss_record3.show(),
                        loss_record4.show(), loss_record5.show()))
def main():
    parser = ArgumentParser()
    parser.add_argument("-c", "--config", required=True, default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []
    folds = config["test"]["folds"]
    print(folds)
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(folds.keys()) == 1:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_{list(folds.keys())[0]}_{dataset}.log',
            rotation="10 MB",
        )
    else:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_kfold.log',
            rotation="10 MB",
        )

    for id in list(folds.keys()):
        test_img_paths = []
        test_mask_paths = []
        test_data_path = config["dataset"]["test_data_path"]
        # NOTE: images and masks are globbed from the same directory pattern.
        for path in test_data_path:
            test_img_paths.extend(glob(os.path.join(path, "*")))
            test_mask_paths.extend(glob(os.path.join(path, "*")))
        test_img_paths.sort()
        test_mask_paths.sort()

        test_transform = None
        test_loader = get_loader(
            test_img_paths,
            test_mask_paths,
            transform=test_transform,
            **config["test"]["dataloader"],
            type="test",
        )
        test_size = len(test_loader)

        epochs = folds[id]
        if not isinstance(epochs, list):
            epochs = [3 * (epochs // 3) + 2]
        elif len(epochs) == 2:
            epochs = [i for i in range(epochs[0], epochs[1])]
        elif len(epochs) == 1:
            epochs = [3 * (epochs[0] // 3) + 2]
        else:
            logger.debug("Each fold entry must be an int or a list of one or two epoch numbers")
            break

        for e in epochs:
            # ---- model ----
            logger.info("Loading model")
            model_prams = config["model"]
            import network.models as models

            arch = model_prams["arch"]
            model = models.__dict__[arch]()  # PraNet
            if "save_dir" not in model_prams:
                save_dir = os.path.join("snapshots", model_prams["arch"] + "_kfold")
            else:
                save_dir = config["model"]["save_dir"]
            model_path = os.path.join(save_dir, f"PraNetDG-fold{id}-{e}.pth")

            model.cuda()
            model.eval()
            logger.info(f"Loading from {model_path}")
            try:
                # Newer snapshots store a dict; older ones store the raw state_dict.
                model.load_state_dict(torch.load(model_path)["model_state_dict"])
            except RuntimeError:
                model.load_state_dict(torch.load(model_path))

            test_fold = "fold" + str(id)
            logger.info(f"Start testing fold{id} epoch {e}")
            if "visualize_dir" not in config["test"]:
                visualize_dir = "results"
            else:
                visualize_dir = config["test"]["visualize_dir"]
            logger.info(f"Start testing {len(test_loader)} images in {dataset} dataset")

            vals = AvgMeter()
            H, W, T = 240, 240, 155  # BraTS volume dimensions
            for i, pack in tqdm.tqdm(enumerate(test_loader, start=1)):
                image, gt, filename, img = pack
                name = os.path.splitext(filename[0])[0]
                ext = os.path.splitext(filename[0])[1]

                gt = gt[0]
                gt = np.asarray(gt, np.float32)
                image = image.cuda()
                res5, res4, res3, res2 = model(image)

                output = res2[0, :, :H, :W, :T].cpu().detach().numpy()
                output = output.argmax(0)  # (num_classes, H, W, T) one-hot -> class ids
                target_cpu = gt[:H, :W, :T]  # already a NumPy array; no .numpy() needed
                scores = softmax_output_dice(output, target_cpu)
                vals.update(np.array(scores))

                seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8)  # same as res.round()
                seg_img[np.where(output == 1)] = 1
                seg_img[np.where(output == 2)] = 2
                seg_img[np.where(output == 3)] = 4  # class 3 maps to BraTS label 4 (ET)

                # `verbose` is assumed to be defined at module scope.
                if verbose:
                    logger.info(
                        f'1: {np.sum(seg_img == 1)} | 2: {np.sum(seg_img == 2)} | 4: {np.sum(seg_img == 4)}')
                    logger.info(
                        f'WT: {np.sum((seg_img == 1) | (seg_img == 2) | (seg_img == 4))} | '
                        f'TC: {np.sum((seg_img == 1) | (seg_img == 4))} | '
                        f'ET: {np.sum(seg_img == 4)}')

                overwrite = config["test"]["vis_overwrite"]
                vis_x = config["test"]["vis_x"]
                if config["test"]["visualize"]:
                    oname = os.path.join(visualize_dir, 'submission', name[:-8] + '_pred.nii.gz')
                    save_img(oname, seg_img, "nib", overwrite)
            logger.info(vals.avg)
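
# Hedged sketch of softmax_output_dice under the BraTS label convention used
# above (prediction classes 0..3, where class 3 corresponds to label 4 = ET;
# ground truth uses labels 0, 1, 2, 4). The repo's metric may aggregate the
# regions differently.
import numpy as np


def _dice(pred, target, eps=1e-8):
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)


def softmax_output_dice(output, target):
    wt = _dice(output > 0, target > 0)                                        # whole tumor
    tc = _dice((output == 1) | (output == 3), (target == 1) | (target == 4))  # tumor core
    et = _dice(output == 3, target == 4)                                      # enhancing tumor
    return wt, tc, et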
def train(train_loader, model, optimizer, epochs, batch_size, train_size, clip, test_path):
    best_dice_score = 0
    for epoch in range(1, epochs):
        # `lr` is assumed to be the initial learning rate defined at module scope.
        adjust_lr(optimizer, lr, epoch, 0.1, 200)
        model.train()
        size_rates = [0.75, 1, 1.25]
        loss1_record, loss2_record = AvgMeter(), AvgMeter()
        criterion = WIoUBCELoss()
        for i, pack in enumerate(train_loader, start=1):
            for rate in size_rates:
                optimizer.zero_grad()
                images, gts = pack
                images = images.cuda()
                gts = gts.cuda()
                trainsize = int(round(train_size * rate / 32) * 32)
                if rate != 1:
                    images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                    gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                # predict
                attention_maps, detection_maps = model(images)
                loss1 = criterion(attention_maps, gts)
                loss2 = criterion(detection_maps, gts)
                loss = loss1 + loss2
                loss.backward()
                clip_gradient(optimizer, clip)
                optimizer.step()
                if rate == 1:
                    loss1_record.update(loss1.data, batch_size)
                    loss2_record.update(loss2.data, batch_size)
            if i % 20 == 0 or i == total_step:
                msg = (f'{datetime.now()} Epoch [{epoch}/{epochs}], Step [{i}/{total_step}], '
                       f'Loss: [{loss1_record.show()}, {loss2_record.show()}]')
                print(msg)
                train_logger.info(msg)
        save_path = 'checkpoints/'
        os.makedirs(save_path, exist_ok=True)
        # Validate every epoch and keep the best-Dice snapshot.
        meandice = validation(model, test_path)
        print(f'meandice: {meandice}')
        train_logger.info(f'meandice: {meandice}')
        if meandice > best_dice_score:
            best_dice_score = meandice
            torch.save(model.state_dict(), save_path + 'effnetv2cpd.pth')
            print('[Saving Snapshots:]', save_path + 'effnetv2cpd.pth', meandice)
        if epoch in [50, 60, 70]:
            file_ = f'effnetv2cpd_{epoch}.pth'  # f-string fixes the str + int concatenation
            torch.save(model.state_dict(), save_path + file_)
            print('[Saving Snapshots:]', save_path + file_, meandice)
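
# Hedged sketch of the adjust_lr referenced above, assuming step decay that
# matches the call adjust_lr(optimizer, lr, epoch, 0.1, 200): the learning
# rate is multiplied by decay_rate once per decay_epoch epochs. The repo's
# own scheduler may differ.
def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=200):
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay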
def fit(self, train_loader, is_val=False, test_loader=None, img_size=352,
        start_from=0, num_epochs=200, batchsize=16, clip=0.5, fold=4):
    size_rates = [0.75, 1, 1.25]
    test_fold = f'fold{fold}'
    start = timeit.default_timer()
    for epoch in range(start_from, num_epochs):
        self.net.train()
        loss_all, loss_record5 = AvgMeter(), AvgMeter()
        for i, pack in enumerate(train_loader, start=1):
            for rate in size_rates:
                self.optimizer.zero_grad()
                # ---- data prepare ----
                images, gts = pack
                images = images.cuda()
                gts = gts.cuda()
                trainsize = int(round(img_size * rate / 32) * 32)
                if rate != 1:
                    images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                    gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = self.net(images)
                # Only the deepest lateral map is supervised in this variant;
                # the other three outputs are currently unused.
                loss5 = self.loss(lateral_map_5, gts)
                loss = loss5
                loss.backward()
                clip_gradient(self.optimizer, clip)
                self.optimizer.step()
                if rate == 1:
                    loss_record5.update(loss5.data, batchsize)
                    loss_all.update(loss.data, batchsize)
                    self.writer.add_scalar("Loss5", loss_record5.show(), (epoch - 1) * len(train_loader) + i)
                    self.writer.add_scalar("Loss", loss_all.show(), (epoch - 1) * len(train_loader) + i)
            total_step = len(train_loader)
            if i % 25 == 0 or i == total_step:
                self.logger.info(
                    '{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}], '
                    '[loss_record5: {:.4f}]'.format(
                        datetime.now(), epoch, num_epochs,
                        self.optimizer.param_groups[0]["lr"], i, total_step,
                        loss_record5.show()))
        if is_val:
            self.val(test_loader, epoch)
        os.makedirs(self.save_dir, exist_ok=True)
        if ((epoch + 1) % 3 == 0 and epoch > self.save_from) or epoch == 23:
            snapshot_path = os.path.join(self.save_dir, 'PraNetDG-' + test_fold + '-%d.pth' % epoch)
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "lr": self.optimizer.param_groups[0]["lr"],
                },
                snapshot_path)
            self.logger.info('[Saving Snapshot:]' + snapshot_path)
        self.scheduler.step()
    self.writer.flush()
    self.writer.close()
    end = timeit.default_timer()
    self.logger.info("Training cost: " + str(end - start) + " seconds")
def fit(self, train_loader, is_val=False, test_loader=None, img_size=352,
        start_from=0, num_epochs=200, batchsize=16, clip=0.5, fold=4):
    # Single-scale variant of fit() above: the multi-scale size_rates loop
    # and the unused loss meters have been dropped.
    test_fold = f"fold{fold}"
    start = timeit.default_timer()
    for epoch in range(start_from, num_epochs):
        self.net.train()
        loss_record5 = AvgMeter()
        for i, pack in enumerate(train_loader, start=1):
            self.optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = images.cuda()
            gts = gts.cuda()
            lateral_map_5 = self.net(images)
            loss5 = self.loss(lateral_map_5, gts)
            loss5.backward()
            clip_gradient(self.optimizer, clip)
            self.optimizer.step()
            loss_record5.update(loss5.data, batchsize)
            self.writer.add_scalar("Loss5", loss_record5.show(), (epoch - 1) * len(train_loader) + i)
            total_step = len(train_loader)
            if i % 25 == 0 or i == total_step:
                self.logger.info(
                    "{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}], "
                    "[loss_record5: {:.4f}]".format(
                        datetime.now(), epoch, num_epochs,
                        self.optimizer.param_groups[0]["lr"], i, total_step,
                        loss_record5.show()))
        if is_val:
            self.val(test_loader, epoch)
        os.makedirs(self.save_dir, exist_ok=True)
        if ((epoch + 1) % 3 == 0 and epoch > self.save_from) or epoch == 23:
            snapshot_path = os.path.join(self.save_dir, "PraNetDG-" + test_fold + "-%d.pth" % epoch)
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "lr": self.optimizer.param_groups[0]["lr"],
                },
                snapshot_path)
            self.logger.info("[Saving Snapshot:]" + snapshot_path)
        self.scheduler.step()
    self.writer.flush()
    self.writer.close()
    end = timeit.default_timer()
    self.logger.info("Training cost: " + str(end - start) + " seconds")
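
# Usage sketch: restoring a snapshot written by fit() above. The helper name
# and the fold/epoch values in the example path are illustrative assumptions,
# not fixed by the code; snapshots follow the 'PraNetDG-fold{fold}-{epoch}.pth'
# pattern and store the weights plus the learning rate.
import os
import torch


def load_snapshot(net, path):
    ckpt = torch.load(path, map_location="cpu")
    net.load_state_dict(ckpt["model_state_dict"])
    return ckpt["lr"]  # learning rate saved alongside the weights

# resume_lr = load_snapshot(net, os.path.join("snapshots", "PraNetDG-fold4-23.pth"))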