예제 #1
0
def train(epoch):
    """Train ``model`` for one epoch and validate it every 20 batches.

    Uses module-level objects defined elsewhere: ``model``, ``optimizer``,
    ``criterion``, ``train_loader``, ``val_loader``, ``writer``, ``device``,
    ``adjust_learning_rate`` and ``dice_coeff``.

    Args:
        epoch: Epoch number, assumed 1-based (the TensorBoard step is
            ``len(train_loader) * (epoch - 1) + batch_idx``).
    """
    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, epoch)

        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Global step that keeps increasing across epochs.
        step = len(train_loader) * (epoch - 1) + batch_idx
        writer.add_scalar("train_loss for Unet", loss.item(), step)

        if batch_idx % 20 == 0:
            model.eval()
            with torch.no_grad():
                val_losses = 0.0
                val_dice_coeff = 0.0
                # Validation batches get their own names so that `data` still
                # holds the training batch when the progress line is printed.
                for val_data, val_target in val_loader:
                    val_data, val_target = val_data.cuda(), val_target.cuda()
                    logits = model(val_data)
                    # .item() keeps the running sum a plain Python float
                    # rather than a tensor living on the GPU.
                    val_losses += criterion(logits, val_target).item()
                    # Hard mask: sigmoid probability thresholded at 0.5.
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    true_masks = val_target.to(device=device, dtype=torch.long)
                    val_dice_coeff += dice_coeff(pred, true_masks.float()).item()
            mean_dice_coeff = val_dice_coeff / len(val_loader)
            mean_val_loss = val_losses / len(val_loader)
            writer.add_scalar("validation/val_loss", mean_val_loss, step)
            writer.add_scalar("validation/dice_coeff", mean_dice_coeff, step)
            print(
                'Train Epoch: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain loss: {:>2.4f}\tmean val loss: {:>2.4f}\tmean '
                'dice coefficient: {:>2.4f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(),
                    mean_val_loss, mean_dice_coeff))
예제 #2
0
def train(epoch):
    """Train U-Net (``model``), FCN-8 and SegNet side by side for one epoch.

    Each model gets its own forward/backward/optimizer step on the same batch.
    Every 20 batches all three are evaluated on ``val_loader``. Their mean
    validation loss and Dice coefficient are logged to each model's writer,
    and the three training losses are printed.

    Uses module-level objects defined elsewhere: the three models, their
    optimizers, ``criterion``/``criterion_fcn8``/``criterion_segnet``,
    ``train_loader``, ``val_loader``, ``writer_unet``/``writer_fcn8``/
    ``writer_segnet``, ``device``, ``adjust_learning_rate`` and ``dice_coeff``.

    Args:
        epoch: Epoch number, assumed 1-based (the TensorBoard step is
            ``len(train_loader) * (epoch - 1) + batch_idx``).
    """

    def val_batch_metrics(net, val_data, val_target, true_masks):
        """Return ``(loss, dice)`` of ``net`` on one validation batch."""
        logits = net(val_data)
        # NOTE(review): every model is validated with `criterion`, although
        # FCN-8 and SegNet are trained with `criterion_fcn8` /
        # `criterion_segnet`. Confirm this is intended (e.g. for a metric
        # that is comparable across models).
        batch_loss = criterion(logits, val_target).item()
        # Hard mask: sigmoid probability thresholded at 0.5.
        pred = (torch.sigmoid(logits) > 0.5).float()
        return batch_loss, dice_coeff(pred, true_masks).item()

    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        model_fcn8.train()
        model_segnet.train()

        optimizer.zero_grad()
        optimizer_fcn8.zero_grad()
        optimizer_segnet.zero_grad()

        adjust_learning_rate(optimizer, epoch)
        adjust_learning_rate(optimizer_fcn8, epoch)
        adjust_learning_rate(optimizer_segnet, epoch)

        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        output_fcn8 = model_fcn8(data)
        loss_fcn8 = criterion_fcn8(output_fcn8, target)
        loss_fcn8.backward()
        optimizer_fcn8.step()

        output_segnet = model_segnet(data)
        loss_segnet = criterion_segnet(output_segnet, target)
        loss_segnet.backward()
        optimizer_segnet.step()

        # Global step that keeps increasing across epochs.
        step = len(train_loader) * (epoch - 1) + batch_idx
        writer_unet.add_scalar("train_loss for Unet", loss.item(), step)
        writer_fcn8.add_scalar("train_loss for fcn8", loss_fcn8.item(), step)
        writer_segnet.add_scalar("train_loss for segnet", loss_segnet.item(), step)

        if batch_idx % 20 == 0:
            model.eval()
            model_fcn8.eval()
            model_segnet.eval()
            with torch.no_grad():
                val_losses, val_losses_fcn8, val_losses_segnet = 0.0, 0.0, 0.0
                val_dice_coeff, val_dice_coeff_fcn8, val_dice_coeff_segnet = 0.0, 0.0, 0.0
                # Validation batches get their own names so that `data` still
                # holds the training batch when the progress line is printed.
                for val_data, val_target in val_loader:
                    val_data, val_target = val_data.cuda(), val_target.cuda()
                    true_masks = val_target.to(device=device, dtype=torch.long).float()

                    batch_loss, batch_dice = val_batch_metrics(
                        model, val_data, val_target, true_masks)
                    val_losses += batch_loss
                    val_dice_coeff += batch_dice

                    batch_loss, batch_dice = val_batch_metrics(
                        model_fcn8, val_data, val_target, true_masks)
                    val_losses_fcn8 += batch_loss
                    val_dice_coeff_fcn8 += batch_dice

                    batch_loss, batch_dice = val_batch_metrics(
                        model_segnet, val_data, val_target, true_masks)
                    val_losses_segnet += batch_loss
                    val_dice_coeff_segnet += batch_dice
            num_val_batches = len(val_loader)
            mean_dice_coeff = val_dice_coeff / num_val_batches
            mean_val_loss = val_losses / num_val_batches
            mean_dice_coeff_fcn8 = val_dice_coeff_fcn8 / num_val_batches
            mean_val_loss_fcn8 = val_losses_fcn8 / num_val_batches
            mean_dice_coeff_segnet = val_dice_coeff_segnet / num_val_batches
            mean_val_loss_segnet = val_losses_segnet / num_val_batches
            writer_unet.add_scalar("validation/val_loss_unet", mean_val_loss, step)
            writer_unet.add_scalar("validation/dice_coeff_unet", mean_dice_coeff, step)
            writer_fcn8.add_scalar("validation/val_loss_fcn8", mean_val_loss_fcn8, step)
            writer_fcn8.add_scalar("validation/dice_coeff_fcn8", mean_dice_coeff_fcn8, step)
            writer_segnet.add_scalar("validation/val_loss_segnet", mean_val_loss_segnet, step)
            writer_segnet.add_scalar("validation/dice_coeff_segnet", mean_dice_coeff_segnet, step)
            print(
                'Train Epoch: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain_loss for Unet: {:>2.4f}\t    train_loss for '
                'fcn8: {:>2.4f}\t    train_loss for segnet: {:>2.4f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(),
                    loss_fcn8.item(), loss_segnet.item()))
예제 #3
0
# torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

# Evaluate the checkpoint on at most MAX_EVAL_BATCHES batches of train_loader,
# printing the Dice coefficient of each batch and then their average.
MAX_EVAL_BATCHES = 20

model.eval()
with torch.no_grad():
    num_batches = 0
    tot = 0.0
    for data, target in train_loader:
        output = model(data)
        num_batches += 1
        # Hard mask: sigmoid probability thresholded at 0.5.
        pred = (torch.sigmoid(output) > 0.5).float()
        true_masks = target.to(device=device, dtype=torch.long)
        coeff = dice_coeff(pred, true_masks.float()).item()
        tot += coeff
        print("The dice coeff is {:>0.5}".format(coeff))
        if num_batches == MAX_EVAL_BATCHES:
            break
    # Report the average even when the loader yields fewer batches than the cap.
    if num_batches:
        print("The average coeff is {:>0.5}".format(tot / num_batches))

# Visualise the first sample of the last evaluated batch.
# permute(1, 2, 0) moves dim 0 last, e.g. CHW -> HWC as PIL expects
# (assumes `data` is channel-first — TODO confirm).
img = Image.fromarray(
    data[0].detach().cpu().permute(1, 2, 0).numpy().astype(np.uint8))
y = output[0].detach().cpu()
# Per-pixel argmax over dim 0 gives a class-index image, rendered as a palette image.
anno_class_img = Image.fromarray(np.uint8(np.argmax(y.numpy(), axis=0)),
                                 mode="P").convert('RGB')
target = Image.fromarray(np.sum(target[0].numpy() *
                                np.arange(0, 21)[:, np.newaxis, np.newaxis],
예제 #4
0
def train(epoch):
    """Train three U-Net variants side by side for one epoch.

    Each variant gets its own forward/backward/optimizer step on the same
    batch, all with ``criterion``:
    - ``model`` is logged to ``writer_half`` ("Unet_cut_half").
    - ``model_small`` is logged to ``writer_small`` ("simplify input-output layers").
    - ``model_simple`` is logged to ``writer_simple`` ("simplify number of channels").

    Every 20 batches all three are evaluated on ``val_loader``. Their mean
    validation loss and Dice coefficient are logged and printed.

    Uses module-level objects defined elsewhere: the three models and their
    optimizers, ``criterion``, ``train_loader``, ``val_loader``, the three
    writers, ``device``, ``adjust_learning_rate`` and ``dice_coeff``.

    Args:
        epoch: Epoch number, assumed 1-based (the TensorBoard step is
            ``len(train_loader) * (epoch - 1) + batch_idx``).
    """

    def val_batch_metrics(net, val_data, val_target, true_masks):
        """Return ``(loss, dice)`` of ``net`` on one validation batch."""
        logits = net(val_data)
        batch_loss = criterion(logits, val_target).item()
        # Hard mask: sigmoid probability thresholded at 0.5.
        pred = (torch.sigmoid(logits) > 0.5).float()
        return batch_loss, dice_coeff(pred, true_masks).item()

    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        model_small.train()
        model_simple.train()
        optimizer.zero_grad()
        optimizer_small.zero_grad()
        optimizer_simple.zero_grad()

        adjust_learning_rate(optimizer, epoch)
        adjust_learning_rate(optimizer_small, epoch)
        adjust_learning_rate(optimizer_simple, epoch)

        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        output_small = model_small(data)
        loss_small = criterion(output_small, target)
        loss_small.backward()
        optimizer_small.step()

        output_simple = model_simple(data)
        loss_simple = criterion(output_simple, target)
        loss_simple.backward()
        optimizer_simple.step()

        # Global step that keeps increasing across epochs.
        step = len(train_loader) * (epoch - 1) + batch_idx
        writer_half.add_scalar("train_loss for Unet_cut_half", loss.item(), step)
        writer_small.add_scalar("train_loss for Unet simplify input-output layers",
                                loss_small.item(), step)
        writer_simple.add_scalar("train_loss for Unet simplify number of channels",
                                 loss_simple.item(), step)

        if batch_idx % 20 == 0:
            model.eval()
            model_small.eval()
            model_simple.eval()
            with torch.no_grad():
                val_losses, val_losses_small, val_losses_simple = 0.0, 0.0, 0.0
                val_dice_coeff, val_dice_coeff_small, val_dice_coeff_simple = 0.0, 0.0, 0.0
                # Validation batches get their own names so that `data` still
                # holds the training batch when the progress lines are printed.
                for val_data, val_target in val_loader:
                    val_data, val_target = val_data.cuda(), val_target.cuda()
                    true_masks = val_target.to(device=device, dtype=torch.long).float()

                    batch_loss, batch_dice = val_batch_metrics(
                        model, val_data, val_target, true_masks)
                    val_losses += batch_loss
                    val_dice_coeff += batch_dice

                    batch_loss, batch_dice = val_batch_metrics(
                        model_small, val_data, val_target, true_masks)
                    val_losses_small += batch_loss
                    val_dice_coeff_small += batch_dice

                    batch_loss, batch_dice = val_batch_metrics(
                        model_simple, val_data, val_target, true_masks)
                    val_losses_simple += batch_loss
                    val_dice_coeff_simple += batch_dice
            num_val_batches = len(val_loader)
            mean_dice_coeff = val_dice_coeff / num_val_batches
            mean_val_loss = val_losses / num_val_batches
            mean_dice_coeff_small = val_dice_coeff_small / num_val_batches
            mean_val_loss_small = val_losses_small / num_val_batches
            mean_dice_coeff_simple = val_dice_coeff_simple / num_val_batches
            mean_val_loss_simple = val_losses_simple / num_val_batches

            writer_half.add_scalar("validation/val_loss", mean_val_loss, step)
            writer_half.add_scalar("validation/dice_coeff", mean_dice_coeff, step)

            writer_small.add_scalar("validation/val_loss for simplify in-out layer",
                                    mean_val_loss_small, step)
            writer_small.add_scalar("validation/dice_coeff for simplify in-out layer",
                                    mean_dice_coeff_small, step)

            writer_simple.add_scalar("validation/val_loss for simplify number of channels",
                                     mean_val_loss_simple, step)
            writer_simple.add_scalar("validation/dice_coeff for simplify number of channels",
                                     mean_dice_coeff_simple, step)

            # (epoch, samples seen, dataset size, percent of epoch done)
            progress = (epoch, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader))
            print('Train Epoch: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain loss: {:>2.4f}\tmean val loss: {:>2.4f}\tmean '
                  'dice coefficient: {:>2.4f}'.format(
                      *progress, loss.item(), mean_val_loss, mean_dice_coeff))
            print('Train Epoch  for simplify input-output layers: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain loss: '
                  '{:>2.4f}\tmean val loss: {:>2.4f}\tmean dice coefficient: {:>2.4f}'.format(
                      *progress, loss_small.item(), mean_val_loss_small, mean_dice_coeff_small))
            print('Train Epoch  for simplify number of channels: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain loss: '
                  '{:>2.4f}\tmean val loss: {:>2.4f}\tmean dice coefficient: {:>2.4f}'.format(
                      *progress, loss_simple.item(), mean_val_loss_simple, mean_dice_coeff_simple))