# Example 1
def train(train_loader, model, optimizer, epochs, batch_size, train_size, clip,
          test_path):
    """Train `model` with multi-scale inputs, validating and checkpointing per epoch.

    Args:
        train_loader: iterable yielding (images, gts) batch pairs.
        model: segmentation network returning a single prediction map.
        optimizer: optimizer updating `model`'s parameters.
        epochs: upper bound of the epoch loop (runs epochs 1 .. epochs-1).
        batch_size: count used to weight the running loss average.
        train_size: base spatial size; scaled by each rate in `size_rates`.
        clip: gradient-clipping threshold passed to `clip_gradient`.
        test_path: validation data path forwarded to `validation`.

    NOTE(review): depends on module-level names not defined in this block:
    `lr`, `total_step`, `train_logger`, `adjust_lr`, `clip_gradient`,
    `validation`, `AvgMeter`, `WIoUBCELoss` — confirm they exist at call time.
    """
    best_dice_score = 0
    save_path = 'checkpoints/'
    # Invariant across epochs, so create the directory once up front.
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(1, epochs):
        adjust_lr(optimizer, lr, epoch, 0.1, 200)
        model.train()
        size_rates = [0.75, 1, 1.25]  # multi-scale training rates
        loss_record = AvgMeter()
        criterion = WIoUBCELoss()
        for i, pack in enumerate(train_loader, start=1):
            for rate in size_rates:
                optimizer.zero_grad()
                images, gts = pack
                images = Variable(images).cuda()
                gts = Variable(gts).cuda()
                # Round the rescaled size to a multiple of 32 (network stride).
                trainsize = int(round(train_size * rate / 32) * 32)
                if rate != 1:
                    # F.upsample is deprecated; F.interpolate is the
                    # documented drop-in replacement with identical semantics.
                    images = F.interpolate(images,
                                           size=(trainsize, trainsize),
                                           mode='bilinear',
                                           align_corners=True)
                    gts = F.interpolate(gts,
                                        size=(trainsize, trainsize),
                                        mode='bilinear',
                                        align_corners=True)
                # predict
                predict_maps = model(images)
                loss = criterion(predict_maps, gts)
                loss.backward()
                clip_gradient(optimizer, clip)
                optimizer.step()

                # Record the loss only at the native scale so the running
                # average is comparable across steps.
                if rate == 1:
                    loss_record.update(loss.data, batch_size)

            if i % 20 == 0 or i == total_step:
                msg = (f'{datetime.now()} Epoch [{epoch}/{epochs}], '
                       f'Step [{i}/{total_step}], Loss: {loss_record.show()}')
                print(msg)
                train_logger.info(msg)

        # Validate every epoch (original guard `(epoch + 1) % 1 == 0` was
        # always true) and keep the checkpoint with the best mean Dice.
        meandice = validation(model, test_path)
        print(f'meandice: {meandice}')
        train_logger.info(f'meandice: {meandice}')
        if meandice > best_dice_score:
            best_dice_score = meandice
            torch.save(model.state_dict(), save_path + 'effnetv2pd.pth')
            print('[Saving Snapshots:]', save_path + 'effnetv2pd.pth',
                  meandice)
        # Also keep fixed-epoch snapshots for later comparison.
        if epoch in [50, 60, 70]:
            file_ = 'effnetv2pd_' + str(epoch) + '.pth'
            torch.save(model.state_dict(), save_path + file_)
            print('[Saving Snapshots:]', save_path + file_, meandice)
# Example 2
def train(train_loader, model, optimizer, epoch, test_path, best):
    """Run one training epoch for a two-head (attention + detection) model.

    Args:
        train_loader: iterable yielding (images, gts) batch pairs.
        model: network returning (attention_map, detection_map).
        optimizer: optimizer updating `model`'s parameters.
        epoch: current epoch number (used for logging and checkpointing).
        test_path: validation data path forwarded to `validation`.
        best: best mean-Dice score seen so far.

    Returns:
        The (possibly updated) best mean-Dice score.

    NOTE(review): depends on module-level names not defined in this block:
    `opt`, `total_step`, `train_logger`, `clip_gradient`, `validation`,
    `AvgMeter`, `WIoUBCELoss` — confirm they exist at call time.
    """
    model.train()
    size_rates = [0.75, 1, 1.25]  # multi-scale training rates
    loss_attention_record, loss_detection_record = AvgMeter(), AvgMeter()
    criterion = WIoUBCELoss()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            images, gts = pack
            images = Variable(images).cuda()
            gts = Variable(gts).cuda()
            # Round the rescaled size to a multiple of 32 (network stride).
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                # F.upsample is deprecated; F.interpolate is the documented
                # drop-in replacement with identical semantics.
                images = F.interpolate(images,
                                       size=(trainsize, trainsize),
                                       mode='bilinear',
                                       align_corners=True)
                gts = F.interpolate(gts,
                                    size=(trainsize, trainsize),
                                    mode='bilinear',
                                    align_corners=True)

            attention_map, detection_map = model(images)
            loss1 = criterion(attention_map, gts)
            loss2 = criterion(detection_map, gts)
            loss = loss1 + loss2
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()

            # Record losses only at the native scale so the running
            # averages are comparable across steps.
            if rate == 1:
                loss_attention_record.update(loss1.data, opt.batchsize)
                loss_detection_record.update(loss2.data, opt.batchsize)

        if i % 20 == 0 or i == total_step:
            msg = (f'{datetime.now()} Epoch [{epoch}/{opt.epoch}], '
                   f'Step [{i}/{total_step}], '
                   f'Loss [attention_loss: {loss_attention_record.show()}, '
                   f'detection_loss: {loss_detection_record.show()}]')
            print(msg)
            train_logger.info(msg)

    save_path = 'snapshots/{}/'.format(opt.train_save)
    os.makedirs(save_path, exist_ok=True)

    # Validate every epoch (the `% 1` guard is always true) and keep the
    # checkpoint with the best mean Dice.
    if (epoch + 1) % 1 == 0:
        meandice = validation(model, test_path)
        if meandice > best:
            best = meandice
            torch.save(model.state_dict(), save_path + 'HarDMSEG-best.pth')
            print('[Saving Snapshots:]', save_path + 'HarDMSEG-best.pth',
                  meandice)
            # Fixed: the log previously said "HarGMSEG-best", which did not
            # match the file actually written above.
            train_logger.info(
                f'[Saving Snapshots: {save_path + "HarDMSEG-best.pth"} {meandice}]'
            )

    return best
# Example 3
def train(train_loader, model, optimizer, epochs, batch_size, train_size, clip, test_path):
    """Train a model with four lateral supervision maps, checkpointing the best.

    Args:
        train_loader: iterable yielding (images, gts) batch pairs.
        model: network returning four lateral maps (5, 4, 3, 2).
        optimizer: optimizer updating `model`'s parameters.
        epochs: upper bound of the epoch loop (runs epochs 1 .. epochs-1).
        batch_size: count used to weight the running loss averages.
        train_size: base spatial size; scaled by each rate in `size_rates`.
        clip: gradient-clipping threshold passed to `clip_gradient`.
        test_path: validation data path forwarded to `validation`.

    NOTE(review): depends on module-level names not defined in this block:
    `lr`, `total_step`, `train_logger`, `adjust_lr`, `clip_gradient`,
    `validation`, `AvgMeter`, `WIoUBCELoss` — confirm they exist at call time.
    """
    best_dice_score = 0
    save_path = 'checkpoints/'
    # Invariant across epochs, so create the directory once up front.
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(1, epochs):
        adjust_lr(optimizer, lr, epoch, 0.1, 200)
        # Debug aid: show the learning rate currently in effect.
        for param in optimizer.param_groups:
            print(param['lr'])
        model.train()
        size_rates = [0.75, 1, 1.25]  # multi-scale training rates
        loss_record2, loss_record3, loss_record4, loss_record5 = (
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter())
        criterion = WIoUBCELoss()
        for i, pack in enumerate(train_loader, start=1):
            for rate in size_rates:
                optimizer.zero_grad()
                images, gts = pack
                images = Variable(images).cuda()
                gts = Variable(gts).cuda()
                # Round the rescaled size to a multiple of 32 (network stride).
                trainsize = int(round(train_size * rate / 32) * 32)
                if rate != 1:
                    # F.upsample is deprecated; F.interpolate is the
                    # documented drop-in replacement with identical semantics.
                    images = F.interpolate(images, size=(
                        trainsize, trainsize), mode='bilinear', align_corners=True)
                    gts = F.interpolate(gts, size=(
                        trainsize, trainsize), mode='bilinear', align_corners=True)
                # predict
                lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = model(
                    images)
                loss5 = criterion(lateral_map_5, gts)
                loss4 = criterion(lateral_map_4, gts)
                loss3 = criterion(lateral_map_3, gts)
                loss2 = criterion(lateral_map_2, gts)
                # Deep supervision: every lateral map contributes equally.
                loss = loss2 + loss3 + loss4 + loss5
                loss.backward()
                clip_gradient(optimizer, clip)
                optimizer.step()

                # Record losses only at the native scale so the running
                # averages are comparable across steps.
                if rate == 1:
                    loss_record2.update(loss2.data, batch_size)
                    loss_record3.update(loss3.data, batch_size)
                    loss_record4.update(loss4.data, batch_size)
                    loss_record5.update(loss5.data, batch_size)

            if i % 20 == 0 or i == total_step:
                msg = (f'{datetime.now()} Epoch [{epoch}/{epochs}], '
                       f'Step [{i}/{total_step}], '
                       f'[lateral-2: {loss_record2.show()}, '
                       f'lateral-3: {loss_record3.show()}, '
                       f'lateral-4: {loss_record4.show()}, '
                       f'lateral-5: {loss_record5.show()},]')
                print(msg)
                train_logger.info(msg)

        # Validate every epoch (original guard `(epoch+1) % 1 == 0` was
        # always true) and keep the checkpoint with the best mean Dice.
        meandice = validation(model, test_path)
        print(f'meandice: {meandice}')
        train_logger.info(f'meandice: {meandice}')
        if meandice > best_dice_score:
            best_dice_score = meandice
            torch.save(model.state_dict(), save_path + 'PraHarDNet.pth')
            print('[Saving Snapshots:]', save_path + 'PraHarDNet.pth', meandice)