# Example no. 1
# 0
def train(train_loader, model, optimizer, epoch):
    """Run one training epoch with deep supervision on five model outputs.

    For every batch, the ``BCEDiceLoss`` of ``pred[0]`` .. ``pred[4]``
    against the ground truth is summed with equal weights and
    back-propagated. The running average loss is printed once, after the
    last step of the epoch. Every 50th epoch, ``model``'s state dict is
    saved under ``./snapshots/<opt.save_root>/``.

    Relies on names defined elsewhere in the script: ``opt`` (``batchsize``,
    ``epoch``, ``save_root``), ``AvgMeter`` and ``BCEDiceLoss``.

    Args:
        train_loader: sized iterable of dicts with ``'image'`` and
            ``'label'`` tensors.
        model: network whose output can be indexed with 0..4.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch number, used in the log line and snapshot name.
    """
    model.train()
    # Derived from the loader rather than read from a module-level
    # ``total_step`` global, so the function has no hidden dependency.
    total_step = len(train_loader)
    loss_record = AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # ---- data prepare ----
        # ``torch.autograd.Variable`` is deprecated; tensors are used directly.
        images = pack['image'].cuda()
        gts = pack['label'].cuda()
        # ---- forward ----
        pred = model(images)
        # ---- loss function: equal-weight sum over outputs 0..4 ----
        loss = sum(BCEDiceLoss(pred[k], gts) for k in range(5))
        # ---- backward ----
        loss.backward()
        optimizer.step()
        # ---- recording loss ----
        loss_record.update(loss.data, opt.batchsize)
        # ---- train visualization (last step of the epoch only) ----
        if i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  '[loss: {:.4f}]'.format(datetime.now(), epoch, opt.epoch, i,
                                          total_step, loss_record.show()))
    save_path = './snapshots/{}/'.format(opt.save_root)
    os.makedirs(save_path, exist_ok=True)
    if epoch % 50 == 0:
        snapshot = save_path + 'model-%d.pth' % epoch
        torch.save(model.state_dict(), snapshot)
        print('[Saving Snapshot:]', snapshot)
# Example no. 2
# 0
def train(train_loader, model, optimizer, epoch):
    """Run one PraNet training epoch with multi-scale inputs.

    Each batch is trained three times, once per scale in ``size_rates``:
    0.75x and 1.25x of ``opt.trainsize`` (rounded to a multiple of 32), and
    1x at the loader's native size. The ``structure_loss`` values of the
    four side outputs are summed with equal weights. Gradients are clipped
    with ``clip_gradient(optimizer, opt.clip)``, and the optimizer steps
    once per scale. Per-output losses from the 1x pass are averaged and
    printed every 20 steps and at the last step. A snapshot is written when
    ``(epoch + 1) % 10 == 0``.

    Relies on names defined elsewhere in the script: ``opt``, ``AvgMeter``,
    ``structure_loss``, ``clip_gradient`` and ``F``.

    Args:
        train_loader: sized iterable yielding ``(images, gts)`` tensor pairs.
        model: network returning four side outputs
            ``(lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2)``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch number, used in the log line and snapshot name.
    """
    model.train()
    # Derived from the loader rather than read from a module-level global.
    total_step = len(train_loader)
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]
    loss_record2, loss_record3, loss_record4, loss_record5 = (
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter())
    for i, pack in enumerate(train_loader, start=1):
        # ---- data prepare: copy to the GPU once per batch, not per scale ----
        images_full, gts_full = pack
        images_full = images_full.cuda()
        gts_full = gts_full.cuda()
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- rescale ----
            if rate != 1:
                trainsize = int(round(opt.trainsize * rate / 32) * 32)
                # F.interpolate replaces the deprecated F.upsample.
                images = F.interpolate(images_full,
                                       size=(trainsize, trainsize),
                                       mode='bilinear',
                                       align_corners=True)
                gts = F.interpolate(gts_full,
                                    size=(trainsize, trainsize),
                                    mode='bilinear',
                                    align_corners=True)
            else:
                # NOTE(review): the 1x pass reuses the batch tensors directly;
                # this assumes the model does not modify its input in place.
                images, gts = images_full, gts_full
            # ---- forward ----
            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = model(
                images)
            # ---- loss function ----
            loss5 = structure_loss(lateral_map_5, gts)
            loss4 = structure_loss(lateral_map_4, gts)
            loss3 = structure_loss(lateral_map_3, gts)
            loss2 = structure_loss(lateral_map_2, gts)
            loss = loss2 + loss3 + loss4 + loss5  # TODO: try different weights for loss
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- recording loss (native-size pass only) ----
            if rate == 1:
                loss_record2.update(loss2.data, opt.batchsize)
                loss_record3.update(loss3.data, opt.batchsize)
                loss_record4.update(loss4.data, opt.batchsize)
                loss_record5.update(loss5.data, opt.batchsize)
        # ---- train visualization ----
        if i % 20 == 0 or i == total_step:
            print(
                '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                '[lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'
                .format(datetime.now(), epoch, opt.epoch, i, total_step,
                        loss_record2.show(), loss_record3.show(),
                        loss_record4.show(), loss_record5.show()))
    save_path = 'snapshots/{}/'.format(opt.train_save)
    os.makedirs(save_path, exist_ok=True)
    # The condition tests ``epoch + 1`` but the file is named with ``epoch``
    # (e.g. epoch 9 -> PraNet-9.pth); kept unchanged to preserve the naming.
    if (epoch + 1) % 10 == 0:
        snapshot = save_path + 'PraNet-%d.pth' % epoch
        torch.save(model.state_dict(), snapshot)
        print('[Saving Snapshot:]', snapshot)
    def fit(self,
            train_loader,
            is_val=False,
            test_loader=None,
            img_size=352,
            start_from=0,
            num_epochs=200,
            batchsize=16,
            clip=0.5,
            fold=4):
        """Train ``self.net`` with multi-scale inputs, logging and checkpoints.

        Each batch is trained three times, once per scale: 0.75x and 1.25x of
        ``img_size`` (rounded to a multiple of 32), and 1x at the loader's
        native size. Only the loss on the first network output
        (``lateral_map_5``) is optimised.

        Losses from the 1x pass are accumulated into running averages and
        written to ``self.writer``. They are also logged via ``self.logger``
        every 25 steps and at the last step of each epoch.

        After each epoch, validation optionally runs and a checkpoint may be
        saved to ``self.save_dir``. Finally, ``self.scheduler`` is stepped.
        ``self.writer`` is flushed and closed at the end, even if training
        raises.

        Args:
            train_loader: sized iterable yielding ``(images, gts)`` tensor pairs.
            is_val: if true, call ``self.val(test_loader, epoch)`` after
                every epoch.
            test_loader: loader passed to ``self.val``.
            img_size: base resolution used to compute the rescaled sizes.
            start_from: first (0-based) epoch index to run; allows resuming.
            num_epochs: epochs ``start_from .. num_epochs - 1`` are run.
            batchsize: weight given to each batch's loss in the running
                averages.
            clip: value passed to ``clip_gradient``.
            fold: fold id; only used in checkpoint file names.
        """
        size_rates = [0.75, 1, 1.25]

        test_fold = f'fold{fold}'
        total_step = len(train_loader)
        start = timeit.default_timer()
        try:
            for epoch in range(start_from, num_epochs):

                self.net.train()
                loss_all, loss_record5 = AvgMeter(), AvgMeter()
                for i, pack in enumerate(train_loader, start=1):
                    # ---- data prepare: one GPU copy per batch ----
                    images_full, gts_full = pack
                    images_full = images_full.cuda()
                    gts_full = gts_full.cuda()
                    # ``epoch`` is 0-based here, so batch ``i`` of epoch ``e``
                    # gets global step ``e * total_step + i``. This step is
                    # non-negative and strictly increasing across epochs.
                    global_step = epoch * total_step + i

                    for rate in size_rates:
                        self.optimizer.zero_grad()

                        if rate != 1:
                            trainsize = int(round(img_size * rate / 32) * 32)
                            # F.interpolate replaces the deprecated F.upsample.
                            images = F.interpolate(images_full,
                                                   size=(trainsize, trainsize),
                                                   mode='bilinear',
                                                   align_corners=True)
                            gts = F.interpolate(gts_full,
                                                size=(trainsize, trainsize),
                                                mode='bilinear',
                                                align_corners=True)
                        else:
                            # NOTE(review): assumes self.net does not modify
                            # its input in place (the tensor is reused).
                            images, gts = images_full, gts_full

                        # Only the first output is supervised; the other
                        # three side outputs are computed but ignored.
                        lateral_map_5, _, _, _ = self.net(images)
                        loss = self.loss(lateral_map_5, gts)

                        loss.backward()
                        clip_gradient(self.optimizer, clip)
                        self.optimizer.step()

                        if rate == 1:
                            loss_record5.update(loss.data, batchsize)
                            loss_all.update(loss.data, batchsize)
                            self.writer.add_scalar(
                                "Loss5", loss_record5.show(), global_step)
                            self.writer.add_scalar(
                                "Loss", loss_all.show(), global_step)

                    if i % 25 == 0 or i == total_step:
                        self.logger.info(
                            '{} Epoch [{:03d}/{:03d}], with lr = {}, '
                            'Step [{:04d}/{:04d}], [loss_record5: {:.4f}]'
                            .format(datetime.now(), epoch, num_epochs,
                                    self.optimizer.param_groups[0]["lr"], i,
                                    total_step, loss_record5.show()))

                if is_val:
                    self.val(test_loader, epoch)

                os.makedirs(self.save_dir, exist_ok=True)
                # Save every 3rd epoch once past ``self.save_from``. Epoch 23
                # is always saved; that value is hard-coded and its reason is
                # not visible here.
                if ((epoch + 1) % 3 == 0
                        and epoch > self.save_from) or epoch == 23:
                    ckpt_path = os.path.join(
                        self.save_dir,
                        'PraNetDG-' + test_fold + '-%d.pth' % epoch)
                    torch.save(
                        {
                            "model_state_dict": self.net.state_dict(),
                            "lr": self.optimizer.param_groups[0]["lr"]
                        }, ckpt_path)
                    self.logger.info('[Saving Snapshot:]' + ckpt_path)

                self.scheduler.step()
        finally:
            # Flush and close even if training is interrupted, so the scalars
            # logged so far are not lost.
            self.writer.flush()
            self.writer.close()
        end = timeit.default_timer()

        self.logger.info("Training cost: " + str(end - start) + ' seconds')