Example #1
def train(train_loader, model, optimizer, epoch):
    model.train()
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]
    loss_record2, loss_record3, loss_record4, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = Variable(images).cuda()
            gts = Variable(gts).cuda()
            # ---- rescale ----
            trainsize = int(round(opt.trainsize*rate/32)*32)
            if rate != 1:
                # F.upsample is deprecated; F.interpolate is the current equivalent
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                # gts = F.interpolate(gts, size=(trainsize, trainsize), mode='nearest')
            # ---- forward ----
            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = model(images)
            # ---- loss function ----
            loss5 = structure_loss(lateral_map_5, gts)
            loss4 = structure_loss(lateral_map_4, gts)
            loss3 = structure_loss(lateral_map_3, gts)
            loss2 = structure_loss(lateral_map_2, gts)
            loss = loss2 + loss3 + loss4 + loss5    # TODO: try different weights for loss
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- recording loss ----
            if rate == 1:
                loss_record2.update(loss2.data, opt.batchsize)
                loss_record3.update(loss3.data, opt.batchsize)
                loss_record4.update(loss4.data, opt.batchsize)
                loss_record5.update(loss5.data, opt.batchsize)
        # ---- train visualization ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  '[lateral-2: {:.4f}, lateral-3: {:.4f}, lateral-4: {:.4f}, lateral-5: {:.4f}]'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step,
                         loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show()))
    save_path = 'snapshots/{}/'.format(opt.train_save)
    os.makedirs(save_path, exist_ok=True)
    if (epoch+1) % 10 == 0:
        torch.save(model.state_dict(), save_path + 'PraNet-%d.pth' % epoch)
        print('[Saving Snapshot:]', save_path + 'PraNet-%d.pth'% epoch)
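Every example on this page leans on an AvgMeter utility that is not shown here. A minimal sketch, modeled on the PraNet helper and consistent with how the examples call it (update(val, n), show(), and .avg); the window size num=40 is an assumption:

import torch

class AvgMeter(object):
    """Running average; show() averages the most recent `num` updates."""
    def __init__(self, num=40):
        self.num = num
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
        self.losses = []

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        # mean over the last `num` recorded values (assumes tensor inputs)
        return torch.mean(torch.stack(self.losses[-self.num:]))

Example #1 also calls structure_loss, which in the PraNet codebase is a boundary-weighted combination of BCE and IoU. A sketch of that formulation (reduction='none' replaces the deprecated reduce='none' of the original):

import torch
import torch.nn.functional as F

def structure_loss(pred, mask):
    # weight pixels more heavily near mask boundaries (deviation from a 31x31 local average)
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()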
Example #2
    def train_on_epoches(self, epoch):
        loss_g_a_meter = AvgMeter()
        loss_g_b_meter = AvgMeter()
        loss_cyc_a_meter = AvgMeter()
        loss_cyc_b_meter = AvgMeter()
        loss_d_a = AvgMeter()
        loss_d_b = AvgMeter()
        loss_meters = [loss_g_a_meter, loss_g_b_meter, loss_cyc_a_meter, loss_cyc_b_meter, loss_d_a, loss_d_b]
        loss_names = ['G_A', 'G_B', 'Cyc_A', 'Cyc_B', 'D_A', 'D_B']

        if self.rank == 0:
            progress_bar = tqdm(self.loader, desc='Epoch train')
        else:
            progress_bar = self.loader
        for iter_idx, sample in enumerate(progress_bar):
            losses_set = self.train_on_step(sample)
            for loss, meter in zip(losses_set, loss_meters):
                dist.all_reduce(loss)
                loss = loss / self.args.gpus_num
                meter.update(loss)

            cur_lr = self.optim_G.param_groups[0]['lr']
            step = iter_idx + 1 + epoch * self.each_epoch_iters
            if self.rank == 0:
                str_content = f'epoch: {epoch:d}; lr:{cur_lr:.6f};'
                for meter, name in zip(loss_meters, loss_names):
                    str_content += f' {name}: {meter.avg:.5f};'
                progress_bar.set_postfix(logger=str_content)

                if (iter_idx+1) % 200 == 0:  # tensorboard
                    # print('tensorboard logging.')
                    # `range` was renamed `value_range` in recent torchvision
                    realA = make_grid(self.realA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    realB = make_grid(self.realB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    fakeA = make_grid(self.fakeA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    fakeB = make_grid(self.fakeB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    recA = make_grid(self.recA, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    recB = make_grid(self.recB, nrow=5, padding=2, normalize=True, value_range=(-1, 1))
                    self.td.add_image('realA', realA, step)
                    self.td.add_image('fakeA', fakeA, step)
                    self.td.add_image('realB', realB, step)
                    self.td.add_image('fakeB', fakeB, step)
                    self.td.add_image('recA', recA, step)
                    self.td.add_image('recB', recB, step)
                    for name, meter in zip(loss_names, loss_meters):
                        self.td.add_scalar(name, meter.avg, step)
                    self.td.flush()

        if self.rank == 0:
            progress_bar.close()
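Example #2's loss averaging works because dist.all_reduce sums the tensor in place across all ranks; the division by gpus_num then turns that sum into a mean. The same pattern as a standalone helper (a sketch; assumes the default process group has been initialized):

import torch.distributed as dist

def reduce_mean(tensor, world_size):
    """Average a tensor across ranks without mutating the caller's copy."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # in-place sum across ranks
    return rt / world_size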
Example #3
    def val(self, test_loader, epoch):
        len_test = len(test_loader)
        self.net.eval()

        # Create the meters once per validation pass: building them inside the
        # loop (as the original did) resets them every batch, so show() only
        # ever reflected the most recent batch.
        loss_record2, loss_record3, loss_record4, loss_record5 = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        for i, pack in enumerate(test_loader, start=1):
            image, gt = pack
            image = image.cuda()
            gt = gt.cuda()

            res5, res4, res3, res2 = self.net(image)

            loss5 = self.loss(res5, gt)
            loss4 = self.loss(res4, gt)
            loss3 = self.loss(res3, gt)
            loss2 = self.loss(res2, gt)
            loss = loss2 + loss3 + loss4 + loss5

            loss_record2.update(loss2.data, 1)
            loss_record3.update(loss3.data, 1)
            loss_record4.update(loss4.data, 1)
            loss_record5.update(loss5.data, 1)

            self.writer.add_scalar("Loss1_test", loss_record2.show(),
                                   (epoch - 1) * len(test_loader) + i)
            # writer.add_scalar("Loss2", loss_record3.show(), (epoch-1)*len(train_loader) + i)
            # writer.add_scalar("Loss3", loss_record4.show(), (epoch-1)*len(train_loader) + i)
            # writer.add_scalar("Loss4", loss_record5.show(), (epoch-1)*len(train_loader) + i)

            if i == len_test:  # enumerate starts at 1, so this is the final batch
                self.logger.info(
                    'TEST:{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}], '
                    '[loss_record2: {:.4f}, loss_record3: {:.4f}, loss_record4: {:.4f}, loss_record5: {:.4f}]'.format(
                        datetime.now(), epoch, epoch, self.optimizer.param_groups[0]["lr"], i,
                        loss_record2.show(), loss_record3.show(), loss_record4.show(),
                        loss_record5.show()))
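Example #3 runs its forward passes with autograd enabled, which costs memory during validation. A compact no-grad variant of the same loop (a sketch; assumes the AvgMeter above and a four-output network like the one in Example #1):

import torch

@torch.no_grad()
def val_no_grad(net, loader, loss_fn):
    net.eval()
    meter = AvgMeter()
    for image, gt in loader:
        image, gt = image.cuda(), gt.cuda()
        res5, res4, res3, res2 = net(image)
        loss = sum(loss_fn(r, gt) for r in (res2, res3, res4, res5))
        meter.update(loss.detach(), image.size(0))
    return meter.avg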
Example #4
def main():

    parser = ArgumentParser()
    # `default` is ignored when required=True, so the dead default is dropped
    parser.add_argument("-c", "--config", required=True)
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []

    folds = config["test"]["folds"]
    print(folds)
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(folds.keys()) == 1:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_{list(folds.keys())[0]}_{dataset}.log',
            rotation="10 MB",
        )
    else:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_kfold.log',
            rotation="10 MB",
        )

    for fold_id in folds:  # renamed from `id`, which shadowed the builtin

        test_img_paths = []
        test_mask_paths = []
        test_data_path = config["dataset"]["test_data_path"]
        for i in test_data_path:
            # NB: images and masks are globbed from the same directory here, as in the source
            test_img_paths.extend(glob(os.path.join(i, "*")))
            test_mask_paths.extend(glob(os.path.join(i, "*")))
        test_img_paths.sort()
        test_mask_paths.sort()

        test_transform = None

        test_loader = get_loader(
            test_img_paths,
            test_mask_paths,
            transform=test_transform,
            **config["test"]["dataloader"],
            type="test",
        )
        test_size = len(test_loader)

        epochs = folds[fold_id]
        if not isinstance(epochs, list):
            epochs = [3 * (epochs // 3) + 2]
        elif len(epochs) == 2:
            epochs = list(range(epochs[0], epochs[1]))
            # epochs = [3 * i + 2 for i in range(epochs[0] // 3, (epochs[1] + 1) // 3)]
        elif len(epochs) == 1:
            epochs = [3 * (epochs[0] // 3) + 2]
        else:
            logger.debug("Each folds entry must be an int or a list of length 1 or 2")
            break
        for e in epochs:
            # MODEL

            logger.info("Loading model")
            model_prams = config["model"]
            import network.models as models

            arch = model_prams["arch"]

            model = models.__dict__[arch]()  # Pranet
            if "save_dir" not in model_prams:
                save_dir = os.path.join("snapshots",
                                        model_prams["arch"] + "_kfold")
            else:
                save_dir = config["model"]["save_dir"]

            model_path = os.path.join(
                save_dir,
                f"PraNetDG-fold{id}-{e}.pth",
            )

            device = torch.device("cpu")
            # model.cpu()

            model.cuda()
            model.eval()

            logger.info(f"Loading from {model_path}")
            try:
                # checkpoints saved by fit() wrap the weights in a dict
                model.load_state_dict(
                    torch.load(model_path)["model_state_dict"])
            except (KeyError, RuntimeError):
                # fall back to checkpoints that are a bare state_dict
                model.load_state_dict(torch.load(model_path))

            test_fold = "fold" + str(config["dataset"]["fold"])
            logger.info(f"Start testing fold{id} epoch {e}")
            if "visualize_dir" not in config["test"]:
                visualize_dir = "results"
            else:
                visualize_dir = config["test"]["visualize_dir"]

            test_fold = "fold" + str(id)
            logger.info(
                f"Start testing {len(test_loader)} images in {dataset} dataset"
            )
            vals = AvgMeter()
            H, W, T = 240, 240, 155

            for i, pack in tqdm.tqdm(enumerate(test_loader, start=1)):
                image, gt, filename, img = pack
                name = os.path.splitext(filename[0])[0]
                ext = os.path.splitext(filename[0])[1]
                # print(gt.shape,image.shape,"ppp")
                # import sys
                # sys.exit()
                gt = gt[0]
                gt = np.asarray(gt, np.float32)
                res2 = 0
                image = image.cuda()

                res5, res4, res3, res2 = model(image)

                # res = res2
                # res = F.upsample(
                #     res, size=gt.shape, mode="bilinear", align_corners=False
                # )
                # res = res.sigmoid().data.cpu().numpy().squeeze()
                # res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                output = res2[0, :, :H, :W, :T].cpu().detach().numpy()
                # collapse the class dimension: (C, H, W, D) -> (H, W, D) label map
                output = output.argmax(0)

                # gt is already a NumPy array here, so no .numpy() call is needed
                target_cpu = gt[:H, :W, :T]
                scores = softmax_output_dice(output, target_cpu)
                vals.update(np.array(scores))
                # msg += ', '.join(['{}: {:.4f}'.format(k, v) for k, v in zip(keys, scores)])

                seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8)

                # remap network class ids to the BraTS label convention (3 -> 4 = enhancing tumor)
                seg_img[np.where(output == 1)] = 1
                seg_img[np.where(output == 2)] = 2
                seg_img[np.where(output == 3)] = 4
                logger.info(
                    f'1: {np.sum(seg_img==1)} | 2: {np.sum(seg_img==2)} | 4: {np.sum(seg_img==4)}'
                )
                logger.info(
                    f'WT: {np.sum((seg_img==1)|(seg_img==2)|(seg_img==4))} | TC: {np.sum((seg_img==1)|(seg_img==4))} | ET: {np.sum(seg_img==4)}'
                )

                overwrite = config["test"]["vis_overwrite"]
                vis_x = config["test"]["vis_x"]
                if config["test"]["visualize"]:
                    oname = os.path.join(visualize_dir, 'submission',
                                         name[:-8] + '_pred.nii.gz')
                    save_img(
                        oname,
                        seg_img,
                        "nib",
                        overwrite,
                    )
            logger.info(vals.avg)
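softmax_output_dice is not defined on this page. A hypothetical reconstruction consistent with how Example #4 calls it, assuming network class ids 0-3 and BraTS ground-truth labels 1, 2, 4 (matching the 3 -> 4 remapping applied before saving); WT/TC/ET are the standard BraTS regions:

import numpy as np

def dice(pred, target, eps=1e-8):
    inter = (pred * target).sum()
    return 2.0 * inter / (pred.sum() + target.sum() + eps)

def softmax_output_dice(output, target):
    # whole tumor, tumor core, enhancing tumor
    wt = dice((output > 0).astype(np.float32), (target > 0).astype(np.float32))
    tc = dice(np.isin(output, (1, 3)).astype(np.float32),
              np.isin(target, (1, 4)).astype(np.float32))
    et = dice((output == 3).astype(np.float32), (target == 4).astype(np.float32))
    return wt, tc, et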
Example #5
def train(train_loader, model, optimizer, epochs, batch_size, train_size, clip,
          test_path):
    best_dice_score = 0
    for epoch in range(1, epochs):
        adjust_lr(optimizer, lr, epoch, 0.1, 200)  # lr, total_step and train_logger are module-level globals in the source
        model.train()
        size_rates = [0.75, 1, 1.25]
        loss1_record, loss2_record = AvgMeter(), AvgMeter()
        criterion = WIoUBCELoss()
        for i, pack in enumerate(train_loader, start=1):
            for rate in size_rates:
                optimizer.zero_grad()
                images, gts = pack
                images = Variable(images).cuda()
                gts = Variable(gts).cuda()
                trainsize = int(round(train_size * rate / 32) * 32)
                if rate != 1:
                    # F.upsample is deprecated; F.interpolate is the current equivalent
                    images = F.interpolate(images,
                                           size=(trainsize, trainsize),
                                           mode='bilinear',
                                           align_corners=True)
                    gts = F.interpolate(gts,
                                        size=(trainsize, trainsize),
                                        mode='bilinear',
                                        align_corners=True)
                # predict
                attention_maps, detection_maps = model(images)
                loss1 = criterion(attention_maps, gts)
                loss2 = criterion(detection_maps, gts)
                loss = loss1 + loss2
                loss.backward()
                clip_gradient(optimizer, clip)
                optimizer.step()

                if rate == 1:
                    loss1_record.update(loss1.data, batch_size)
                    loss2_record.update(loss2.data, batch_size)

            if i % 20 == 0 or i == total_step:
                print(
                    f'{datetime.now()} Epoch [{epoch}/{epochs}], Step [{i}/{total_step}], Loss: [{loss1_record.show()}, {loss2_record.show()}]'
                )
                train_logger.info(
                    f'{datetime.now()} Epoch [{epoch}/{epochs}], Step [{i}/{total_step}], Loss: [{loss1_record.show()}, {loss2_record.show()}]'
                )

        save_path = 'checkpoints/'
        os.makedirs(save_path, exist_ok=True)

        # `(epoch + 1) % 1 == 0` was always true, so validate every epoch
        meandice = validation(model, test_path)
        print(f'meandice: {meandice}')
        train_logger.info(f'meandice: {meandice}')
        if meandice > best_dice_score:
            best_dice_score = meandice
            torch.save(model.state_dict(), save_path + 'effnetv2cpd.pth')
            print('[Saving Snapshots:]', save_path + 'effnetv2cpd.pth', meandice)

        if epoch in [50, 60, 70]:
            file_ = f'effnetv2cpd_{epoch}.pth'  # str + int concatenation raised TypeError
            torch.save(model.state_dict(), save_path + file_)
            print('[Saving Snapshots:]', save_path + file_, meandice)
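Examples #1, #5, #6, and #7 call clip_gradient, and Example #5 calls adjust_lr; neither appears on this page. Sketches matching the PraNet-style utilities these snippets appear to be based on (element-wise gradient clamping and step decay):

def clip_gradient(optimizer, grad_clip):
    """Clamp every parameter gradient to [-grad_clip, grad_clip]."""
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    """Multiply the initial lr by decay_rate every decay_epoch epochs."""
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay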
Example #6
    def fit(self,
            train_loader,
            is_val=False,
            test_loader=None,
            img_size=352,
            start_from=0,
            num_epochs=200,
            batchsize=16,
            clip=0.5,
            fold=4):

        size_rates = [0.75, 1, 1.25]

        test_fold = f'fold{fold}'
        start = timeit.default_timer()
        for epoch in range(start_from, num_epochs):

            self.net.train()
            loss_all, loss_record2, loss_record3, loss_record4, loss_record5 = (
                AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter())
            for i, pack in enumerate(train_loader, start=1):
                for rate in size_rates:
                    self.optimizer.zero_grad()

                    # ---- data prepare ----
                    images, gts = pack
                    # images, gts, paths, oriimgs = pack

                    images = Variable(images).cuda()
                    gts = Variable(gts).cuda()

                    trainsize = int(round(img_size * rate / 32) * 32)

                    if rate != 1:
                        # F.upsample is deprecated; F.interpolate is the current equivalent
                        images = F.interpolate(images,
                                               size=(trainsize, trainsize),
                                               mode='bilinear',
                                               align_corners=True)
                        gts = F.interpolate(gts,
                                            size=(trainsize, trainsize),
                                            mode='bilinear',
                                            align_corners=True)

                    lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = self.net(
                        images)
                    # lateral_map_5 = self.net(images)

                    loss5 = self.loss(lateral_map_5, gts)
                    # loss4 = self.loss(lateral_map_4, gts)
                    # loss3 = self.loss(lateral_map_3, gts)
                    # loss2 = self.loss(lateral_map_2, gts)

                    # loss = loss2 + loss3 + loss4 + loss5
                    loss = loss5

                    loss.backward()
                    clip_gradient(self.optimizer, clip)
                    self.optimizer.step()

                    if rate == 1:
                        # loss_record2.update(loss2.data, batchsize)
                        # loss_record3.update(loss3.data, batchsize)
                        # loss_record4.update(loss4.data, batchsize)
                        loss_record5.update(loss5.data, batchsize)
                        loss_all.update(loss.data, batchsize)

                        # self.writer.add_scalar("Loss2", loss_record2.show(), (epoch-1)*len(train_loader) + i)
                        # self.writer.add_scalar("Loss3", loss_record3.show(), (epoch-1)*len(train_loader) + i)
                        # self.writer.add_scalar("Loss4", loss_record4.show(), (epoch-1)*len(train_loader) + i)
                        self.writer.add_scalar(
                            "Loss5", loss_record5.show(),
                            (epoch - 1) * len(train_loader) + i)
                        self.writer.add_scalar(
                            "Loss", loss_all.show(),
                            (epoch - 1) * len(train_loader) + i)

                total_step = len(train_loader)
                if i % 25 == 0 or i == total_step:
                    # self.logger.info('{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}],\
                    #     [loss_record2: {:.4f},loss_record3: {:.4f},loss_record4: {:.4f},loss_record5: {:.4f}]'.
                    #     format(datetime.now(), epoch, epoch, self.optimizer.param_groups[0]["lr"],i, total_step,\
                    #             loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show()
                    #             ))
                    self.logger.info(
                        '{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}], '
                        '[loss_record5: {:.4f}]'.format(
                            datetime.now(), epoch, num_epochs,
                            self.optimizer.param_groups[0]["lr"], i,
                            total_step, loss_record5.show()))

            if (is_val):
                self.val(test_loader, epoch)

            os.makedirs(self.save_dir, exist_ok=True)
            if ((epoch + 1) % 3 == 0 and epoch > self.save_from) or epoch == 23:
                torch.save(
                    {
                        "model_state_dict": self.net.state_dict(),
                        "lr": self.optimizer.param_groups[0]["lr"]
                    },
                    os.path.join(self.save_dir,
                                 'PraNetDG-' + test_fold + '-%d.pth' % epoch))
                self.logger.info(
                    '[Saving Snapshot:]' +
                    os.path.join(self.save_dir, 'PraNetDG-' + test_fold +
                                 '-%d.pth' % epoch))

            self.scheduler.step()

        self.writer.flush()
        self.writer.close()
        end = timeit.default_timer()

        self.logger.info("Training cost: " + str(end - start) + 'seconds')
Example #7
    def fit(
        self,
        train_loader,
        is_val=False,
        test_loader=None,
        img_size=352,
        start_from=0,
        num_epochs=200,
        batchsize=16,
        clip=0.5,
        fold=4,
    ):

        test_fold = f"fold{fold}"
        start = timeit.default_timer()
        for epoch in range(start_from, num_epochs):

            self.net.train()
            loss_all, loss_record2, loss_record3, loss_record4, loss_record5 = (
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
            )
            for i, pack in enumerate(train_loader, start=1):

                self.optimizer.zero_grad()

                # ---- data prepare ----
                images, gts = pack
                # images, gts, paths, oriimgs = pack

                images = Variable(images).cuda()
                gts = Variable(gts).cuda()

                lateral_map_5 = self.net(images)
                loss5 = self.loss(lateral_map_5, gts)

                loss5.backward()
                clip_gradient(self.optimizer, clip)
                self.optimizer.step()

                # the unused `size_rates`/`rate = 1` machinery from Example #6 is
                # dropped here; the original `if rate == 1:` guard was always true
                loss_record5.update(loss5.data, batchsize)
                self.writer.add_scalar(
                    "Loss5",
                    loss_record5.show(),
                    (epoch - 1) * len(train_loader) + i,
                )

                total_step = len(train_loader)
                if i % 25 == 0 or i == total_step:
                    self.logger.info(
                        "{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}], "
                        "[loss_record5: {:.4f}]".format(
                            datetime.now(),
                            epoch,
                            num_epochs,
                            self.optimizer.param_groups[0]["lr"],
                            i,
                            total_step,
                            loss_record5.show(),
                        ))

            if is_val:
                self.val(test_loader, epoch)

            os.makedirs(self.save_dir, exist_ok=True)
            if ((epoch + 1) % 3 == 0 and epoch > self.save_from) or epoch == 23:
                torch.save(
                    {
                        "model_state_dict": self.net.state_dict(),
                        "lr": self.optimizer.param_groups[0]["lr"],
                    },
                    os.path.join(self.save_dir,
                                 "PraNetDG-" + test_fold + "-%d.pth" % epoch),
                )
                self.logger.info(
                    "[Saving Snapshot:]" +
                    os.path.join(self.save_dir, "PraNetDG-" + test_fold +
                                 "-%d.pth" % epoch))

            self.scheduler.step()

        self.writer.flush()
        self.writer.close()
        end = timeit.default_timer()

        self.logger.info("Training cost: " + str(end - start) + "seconds")