Beispiel #1
0
def train(train_loader, model, optimizer, epoch, train_save, device='cuda'):
    """Train ``model`` for one epoch with multi-scale inputs.

    Each batch from ``train_loader`` is used once per scale in
    ``size_rates``; every pass does its own forward, backward, gradient
    clipping and optimizer step. The loss is the sum of ``joint_loss`` on the
    four lateral maps plus ``BCE`` on the edge map. Running averages (of the
    scale-1 passes only) are printed every 20 steps and at the last step, and
    the weights are saved every 10th epoch.

    Relies on module-level names defined elsewhere in the file: ``opt``
    (reads ``trainsize``, ``clip``, ``batchsize`` and ``epoch``),
    ``total_step``, ``AvgMeter``, ``joint_loss``, ``BCE``, ``clip_gradient``,
    ``F``, ``datetime``, ``os`` and ``torch``.

    Args:
        train_loader: iterable of ``(images, gts, edges)`` tensor batches.
        model: network whose forward returns five outputs, unpacked as
            ``(lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2,
            lateral_edge)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index shown in the log line; a snapshot named
            ``Inf-Net-<epoch+1>.pth`` is written when ``(epoch + 1) % 10 == 0``.
        train_save: sub-directory of ``./Snapshots/save_weights/`` that
            receives the snapshots.
        device: where batches are moved before the forward pass. Defaults to
            ``'cuda'``, matching the previously hard-coded ``.cuda()`` calls.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.train()
    # ---- multi-scale training ----
    # Replace with your desired scales; larger scales may help on small objects.
    size_rates = [0.75, 1, 1.25]
    loss_record1, loss_record2, loss_record3, loss_record4, loss_record5 = (
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter())
    for i, pack in enumerate(train_loader, start=1):
        # ---- data prepare ----
        # Transfer the batch once; every scale below starts from these
        # tensors instead of re-transferring the same data per scale.
        batch_images, batch_gts, batch_edges = pack
        batch_images = batch_images.to(device)
        batch_gts = batch_gts.to(device)
        batch_edges = batch_edges.to(device)
        for rate in size_rates:
            optimizer.zero_grad()
            images, gts, edges = batch_images, batch_gts, batch_edges
            # ---- rescaling the inputs (img/gt/edge) ----
            # Side length is rounded to a multiple of 32 (presumably the
            # network's total down-sampling factor — TODO confirm).
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                size = (trainsize, trainsize)
                # F.interpolate is the non-deprecated replacement for F.upsample.
                images = F.interpolate(images, size=size, mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=size, mode='bilinear', align_corners=True)
                edges = F.interpolate(edges, size=size, mode='bilinear', align_corners=True)

            # ---- forward ----
            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2, lateral_edge = model(images)
            # ---- loss function ----
            loss5 = joint_loss(lateral_map_5, gts)
            loss4 = joint_loss(lateral_map_4, gts)
            loss3 = joint_loss(lateral_map_3, gts)
            loss2 = joint_loss(lateral_map_2, gts)
            loss1 = BCE(lateral_edge, edges)
            loss = loss1 + loss2 + loss3 + loss4 + loss5
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- recording loss (only the unscaled pass is tracked) ----
            if rate == 1:
                loss_record1.update(loss1.detach(), opt.batchsize)
                loss_record2.update(loss2.detach(), opt.batchsize)
                loss_record3.update(loss3.detach(), opt.batchsize)
                loss_record4.update(loss4.detach(), opt.batchsize)
                loss_record5.update(loss5.detach(), opt.batchsize)
        # ---- train logging ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], [lateral-edge: {:.4f}, '
                  'lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(),
                         loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show()))

    # ---- save model_lung_infection ----
    # The directory is only created when a snapshot is actually written.
    if (epoch + 1) % 10 == 0:
        save_path = './Snapshots/save_weights/{}/'.format(train_save)
        os.makedirs(save_path, exist_ok=True)
        snapshot_path = save_path + 'Inf-Net-%d.pth' % (epoch + 1)
        torch.save(model.state_dict(), snapshot_path)
        print('[Saving Snapshot:]', snapshot_path)

    return model
Beispiel #2
0
def train(train_loader, test_loader, model, optimizer, epoch, train_save,
          device, opt):
    """Run one multi-scale training epoch, evaluate, and keep the best weights.

    Training: each batch from ``train_loader`` is used once per scale in
    ``size_rates``; every pass does its own forward, backward, gradient
    clipping and optimizer step, and its losses are written to
    ``train_writer``. Evaluation: the model runs in eval mode, without
    autograd, over ``test_loader``. The averaged side-output losses and Dice
    scores are written to ``test_writer``. If the mean test loss is lower
    than the module-level ``best_loss``, the weights are saved under
    ``./Snapshots/save_weights/<train_save>/``.

    Relies on module-level names defined elsewhere in the file:
    ``global_current_iteration``, ``best_loss`` and ``focal_loss_criterion``
    (all rebound here), ``total_step``, ``train_writer``, ``test_writer``,
    ``Lookahead``, ``AvgMeter``, ``joint_loss``, ``BCE``, ``clip_gradient``,
    ``dice_similarity_coefficient``, ``F``, ``datetime``, ``os`` and
    ``torch``.

    Args:
        train_loader: iterable of ``(images, gts, edges)`` tensor batches.
        test_loader: iterable of 4-tuples whose first two items are
            ``(image, gt)``; the remaining two are ignored.
        model: network whose forward returns five outputs, unpacked as
            ``(lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2,
            lateral_edge)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index shown in the log line; a saved snapshot is named
            ``Inf-Net-<epoch+1>.pth``.
        train_save: sub-directory of ``./Snapshots/save_weights/``.
        device: device the batches (and ``focal_loss_criterion``) are moved to.
        opt: options object; reads ``lookahead``, ``trainsize``, ``clip``,
            ``batchsize`` and ``epoch``, and is forwarded to ``joint_loss``.

    Returns:
        The test loss averaged over test batches and the four side outputs.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    global global_current_iteration
    global best_loss
    global focal_loss_criterion

    # NOTE(review): the wrapper is created anew on every call and is not
    # returned, so whatever state Lookahead keeps (presumably its slow
    # weights) starts fresh each epoch. Confirm this is intended; wrapping
    # once in the caller would preserve it.
    if opt.lookahead:
        optimizer = Lookahead(optimizer, k=5, alpha=0.5)
    optimizer.zero_grad()
    # Not used directly below; presumably consumed by joint_loss — verify.
    focal_loss_criterion = focal_loss_criterion.to(device)

    model.train()
    # ---- multi-scale training ----
    # Replace with your desired scales; larger scales may help on small objects.
    size_rates = [0.75, 1, 1.25]
    loss_record1, loss_record2, loss_record3, loss_record4, loss_record5 = (
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter())
    for i, pack in enumerate(train_loader, start=1):
        global_current_iteration += 1
        # ---- data prepare ----
        # Transfer the batch once; every scale below starts from these
        # tensors instead of re-transferring the same data per scale.
        batch_images, batch_gts, batch_edges = pack
        batch_images = batch_images.to(device)
        batch_gts = batch_gts.to(device)
        batch_edges = batch_edges.to(device)
        for rate in size_rates:
            optimizer.zero_grad()
            images, gts, edges = batch_images, batch_gts, batch_edges
            # ---- rescaling the inputs (img/gt/edge) ----
            # Side length is rounded to a multiple of 32 (presumably the
            # network's total down-sampling factor — TODO confirm).
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                size = (trainsize, trainsize)
                # F.interpolate is the non-deprecated replacement for F.upsample.
                images = F.interpolate(images, size=size, mode='bilinear',
                                       align_corners=True)
                gts = F.interpolate(gts, size=size, mode='bilinear',
                                    align_corners=True)
                edges = F.interpolate(edges, size=size, mode='bilinear',
                                      align_corners=True)

            # ---- forward ----
            (lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2,
             lateral_edge) = model(images)
            # ---- loss function ----
            loss5 = joint_loss(lateral_map_5, gts, opt)
            loss4 = joint_loss(lateral_map_4, gts, opt)
            loss3 = joint_loss(lateral_map_3, gts, opt)
            loss2 = joint_loss(lateral_map_2, gts, opt)
            loss1 = BCE(lateral_edge, edges)
            loss = loss1 + loss2 + loss3 + loss4 + loss5

            train_writer.add_scalar('train/edge_loss', loss1.item(),
                                    global_current_iteration)
            train_writer.add_scalar('train/loss2', loss2.item(),
                                    global_current_iteration)
            train_writer.add_scalar('train/loss3', loss3.item(),
                                    global_current_iteration)
            train_writer.add_scalar('train/loss4', loss4.item(),
                                    global_current_iteration)
            train_writer.add_scalar('train/loss5', loss5.item(),
                                    global_current_iteration)
            # NOTE(review): this "total" omits the edge loss (loss1) that is
            # part of the optimized `loss`; confirm that is intended.
            scalar_total_loss = (loss2.item() + loss3.item() + loss4.item() +
                                 loss5.item())
            train_writer.add_scalar('train/total_loss', scalar_total_loss,
                                    global_current_iteration)

            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- recording loss (only the unscaled pass is tracked) ----
            if rate == 1:
                loss_record1.update(loss1.detach(), opt.batchsize)
                loss_record2.update(loss2.detach(), opt.batchsize)
                loss_record3.update(loss3.detach(), opt.batchsize)
                loss_record4.update(loss4.detach(), opt.batchsize)
                loss_record5.update(loss5.detach(), opt.batchsize)
        # ---- train logging ----
        if i % 20 == 0 or i == total_step:
            print(
                '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], [lateral-edge: {:.4f}, '
                'lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'
                .format(datetime.now(), epoch, opt.epoch, i, total_step,
                        loss_record1.show(), loss_record2.show(),
                        loss_record3.show(), loss_record4.show(),
                        loss_record5.show()))

    # ---- check testing error ----
    total_test_step = 0
    total_loss_5 = total_loss_4 = total_loss_3 = total_loss_2 = 0
    total_dice_5 = total_dice_4 = total_dice_3 = total_dice_2 = 0
    model.eval()
    try:
        # Evaluation needs no gradients; no_grad stops autograd from building
        # (and keeping alive) a graph for every test batch.
        with torch.no_grad():
            for pack in test_loader:
                total_test_step += 1
                image, gt, _, _ = pack
                image = image.to(device)
                gt = gt.to(device)
                # ---- forward ----
                (lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2,
                 _) = model(image)
                # ---- loss function ----
                total_loss_5 += joint_loss(lateral_map_5, gt, opt).item()
                total_loss_4 += joint_loss(lateral_map_4, gt, opt).item()
                total_loss_3 += joint_loss(lateral_map_3, gt, opt).item()
                total_loss_2 += joint_loss(lateral_map_2, gt, opt).item()
                # ---- Dice at a 0.5 threshold on the sigmoid outputs ----
                total_dice_5 += dice_similarity_coefficient(
                    lateral_map_5.sigmoid(), gt, 0.5)
                total_dice_4 += dice_similarity_coefficient(
                    lateral_map_4.sigmoid(), gt, 0.5)
                total_dice_3 += dice_similarity_coefficient(
                    lateral_map_3.sigmoid(), gt, 0.5)
                total_dice_2 += dice_similarity_coefficient(
                    lateral_map_2.sigmoid(), gt, 0.5)
    finally:
        # Restore training mode even if evaluation raised.
        model.train()

    if total_test_step == 0:
        raise ValueError(
            'test_loader yielded no batches; cannot compute the test loss')

    total_average_loss = (total_loss_2 + total_loss_3 + total_loss_4 +
                          total_loss_5) / total_test_step / 4
    test_writer.add_scalar('test/loss2', total_loss_2 / total_test_step,
                           global_current_iteration)
    test_writer.add_scalar('test/loss3', total_loss_3 / total_test_step,
                           global_current_iteration)
    test_writer.add_scalar('test/loss4', total_loss_4 / total_test_step,
                           global_current_iteration)
    test_writer.add_scalar('test/loss5', total_loss_5 / total_test_step,
                           global_current_iteration)
    test_writer.add_scalar('test/total_loss', total_average_loss,
                           global_current_iteration)
    test_writer.add_scalar(
        'test/dice',
        (total_dice_2 + total_dice_3 + total_dice_4 + total_dice_5) /
        total_test_step / 4, global_current_iteration)

    if total_average_loss < best_loss:
        best_loss = total_average_loss
        # ---- save model_lung_infection ----
        save_path = './Snapshots/save_weights/{}/'.format(train_save)
        os.makedirs(save_path, exist_ok=True)
        snapshot_path = save_path + 'Inf-Net-%d.pth' % (epoch + 1)
        torch.save(model.state_dict(), snapshot_path)
        print('[Saving Snapshot:]', snapshot_path)
    return total_average_loss