Пример #1
0
def train(device, model_path, dataset_path):
    """
    Train a UNet on the dataset at ``dataset_path`` and checkpoint it.

    When ``model_path`` already exists, its weights are loaded before
    training starts. The network's state dict is written back to
    ``model_path`` at the end of each of the 10 epochs.

    Args:
        device: torch device that the network and every batch are moved to.
        model_path: checkpoint file to resume from (if present) and save to.
        dataset_path: location passed to ``GrayColorDataset``.
    """
    network = UNet(1, 3).to(device)
    optimizer = torch.optim.Adam(network.parameters())
    criteria = torch.nn.MSELoss()

    dataset = GrayColorDataset(dataset_path, transform=train_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=16,
                                         shuffle=True,
                                         num_workers=cpu_count())

    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a CUDA device be resumed
        # on whatever device this run uses (including CPU-only hosts).
        network.load_state_dict(torch.load(model_path, map_location=device))
    for _ in tqdm.trange(10, desc="Epoch"):
        network.train()
        for gray, color in tqdm.tqdm(loader, desc="Training", leave=False):
            gray, color = gray.to(device), color.to(device)
            optimizer.zero_grad()
            pred_color = network(gray)
            loss = criteria(pred_color, color)
            loss.backward()
            optimizer.step()
        # Overwrite the checkpoint after every epoch.
        torch.save(network.state_dict(), model_path)
Пример #2
0
def train_val(config):
    """
    Train a segmentation model and validate it after every epoch.

    The model is chosen by ``config.model_type`` (falling back to ``UNet``),
    trained with ``BasLoss`` and evaluated with a confusion matrix over the
    eight label values listed in ``labels``. The whole model object is saved
    to ``config.result_path`` whenever the frequency-weighted IoU improves.
    Scalars and a few sample masks are logged through ``SummaryWriter``.

    Args:
        config: object exposing the attributes read below: ``train_img_dir``,
            ``train_mask_dir``, ``val_img_dir``, ``val_mask_dir``,
            ``batch_size``, ``num_workers``, ``smooth``, ``lr``,
            ``model_type``, ``data_type``, ``output_ch``, ``iscontinue``,
            ``optimizer``, ``num_epochs``, ``num_train``, ``num_val`` and
            ``result_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    # Validation uses a fixed batch size of 4; the image logging further down
    # indexes samples 2 and 3 of a batch.
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=4,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # Model selection; any unrecognised model_type falls back to UNet.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "BASNet":
        model = BASNet(n_classes=8)
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=8,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        # Resume from a hard-coded, fully pickled model. `.module` unwraps the
        # wrapper it was saved in (the save below stores the wrapped model when
        # several GPUs are used). map_location='cpu' keeps the load working on
        # hosts without CUDA; the model is moved to `device` further down.
        model = torch.load("./exp/24_Deeplabv3+_0.7825757691389714.pth",
                           map_location='cpu').module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # Label values and their display names (water, transport structures,
    # buildings, cropland, grassland, woodland, bare soil, other).
    labels = [100, 200, 300, 400, 500, 600, 700, 800]
    objects = ['水体', '交通建筑', '建筑', '耕地', '草地', '林地', '裸土', '其他']

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    criterion = BasLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    # Per-class weights applied to the per-class IoU to obtain FWIoU.
    frequency = np.array(
        [0.1051, 0.0607, 0.1842, 0.1715, 0.0869, 0.1572, 0.0512, 0.1832])
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        cm = np.zeros([8, 8])  # confusion matrix accumulated over validation
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                # NOTE(review): the mask is cast to float16 while the image is
                # float32 - presumably what BasLoss expects; confirm.
                mask = mask.to(device, dtype=torch.float16)

                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            # The summed loss is divided by num_train / batch_size, i.e. the
            # nominal number of batches rather than the number iterated.
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
            torch.cuda.empty_cache()

        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0  # index of the current validation batch
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    # Class indices are taken as the argmax over dim 1.
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # Exactly eight outputs are unpacked; only the first is
                    # used, so the selected model must return an 8-tuple here.
                    pred, _, _, _, _, _, _, _ = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log samples 2 and 3 of the batch with index 25.
                    if locker == 25:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                # The learning-rate schedule advances once per epoch.
                scheduler.step()

                # Keep a checkpoint only when FWIoU improves on the best so far.
                if fw_miou > max_fwiou:
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                # val_loss is the sum of per-batch losses, not an average.
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
                torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
Пример #3
0
def train_val(config):
    """
    Train a segmentation model and validate it after every epoch.

    The model is chosen by ``config.model_type`` (falling back to ``UNet``)
    and the loss by ``config.smooth``. Training sums the loss of the two
    outputs returned by the model. Validation accumulates a 15x15 confusion
    matrix over the label values in ``labels``, and the whole model object is
    saved to ``config.result_path`` after every epoch.

    Args:
        config: object exposing the attributes read below: ``train_img_dir``,
            ``train_mask_dir``, ``val_img_dir``, ``val_mask_dir``,
            ``batch_size``, ``num_workers``, ``smooth``, ``lr``,
            ``model_type``, ``data_type``, ``output_ch``, ``iscontinue``,
            ``optimizer``, ``num_epochs``, ``num_train``, ``num_val`` and
            ``result_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # NOTE: both loaders are rebuilt at the top of every epoch below, so these
    # initial instances are never iterated.
    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # Model selection; any unrecognised model_type falls back to UNet.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        # src = "./pretrained/60_DANet_0.8086.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # pretrained_dict.pop('seg1.1.weight')
        # pretrained_dict.pop('seg1.1.bias')
        model = DANet(backbone='resnext101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "Deeplabv3+":
        # src = "./pretrained/Deeplabv3+.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # # print(pretrained_dict.keys())
        # for key in list(pretrained_dict.keys()):
        #     if key.split('.')[0] == "cbr_last":
        #         pretrained_dict.pop(key)
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=config.output_ch,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    # Resume from a hard-coded, fully pickled model; `.module` unwraps the
    # wrapper object it was saved in.
    if config.iscontinue:
        model = torch.load("./exp/13_Deeplabv3+_0.7619.pth",
                           map_location='cpu').module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # Label values (5 and 6 are absent), their display names, and the
    # per-class weights multiplied with the per-class IoU to obtain FWIoU.
    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    objects = [
        '水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化',
        '自然林', '人工林', '自然裸土', '人为裸土', '其它'
    ]
    frequency = np.array([
        0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107,
        0.1207, 0.0249, 0.1470, 0.0777, 0.0617, 0.0118, 0.0187
    ])

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    # The loss follows the label-smoothing mode in config.smooth.
    if config.smooth == "all":
        criterion = LabelSmoothSoftmaxCE()
    elif config.smooth == "edge":
        criterion = LabelSmoothCE()
    else:
        criterion = nn.CrossEntropyLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        # The random draw is immediately overwritten with 0, so the
        # `seed == 1` (half batch size, 384x384 resize) paths below never run.
        # The call still consumes NumPy global RNG state.
        seed = np.random.randint(0, 2, 1)
        seed = 0
        print("seed is ", seed)
        if seed == 1:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size // 2,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size // 2,
                                        num_workers=config.num_workers)
        else:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size,
                                        num_workers=config.num_workers)

        # Confusion matrix accumulated over this epoch's validation batches.
        cm = np.zeros([15, 15])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)

                if seed == 0:
                    pass
                elif seed == 1:
                    image = F.interpolate(image,
                                          size=(384, 384),
                                          mode='bilinear',
                                          align_corners=True)
                    mask = F.interpolate(mask.float(),
                                         size=(384, 384),
                                         mode='nearest')

                # "edge" smoothing keeps the mask as float32; otherwise the
                # mask is reduced to class indices via argmax over dim 1.
                if config.smooth == "edge":
                    mask = mask.to(device, dtype=torch.float32)
                else:
                    mask = mask.to(device, dtype=torch.long).argmax(dim=1)

                # The model is expected to return two outputs; both are
                # scored against the same mask and the losses are summed.
                aux_out, out = model(image)
                aux_loss = criterion(aux_out, mask)
                seg_loss = criterion(out, mask)
                loss = aux_loss + seg_loss

                # pred = model(image)
                # loss = criterion(pred, mask)

                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            # The summed loss is divided by num_train / batch_size, i.e. the
            # nominal number of batches rather than the number iterated.
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
        torch.cuda.empty_cache()
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0  # index of the current validation batch
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # Only the second of the two model outputs is evaluated.
                    _, pred = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log samples 2 and 3 of the batch with index 5; this
                    # needs at least four samples in that batch.
                    if locker == 5:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                # The learning-rate schedule advances once per epoch.
                scheduler.step()

                # `if True` means a checkpoint is written after every epoch,
                # regardless of whether FWIoU improved; max_fwiou therefore
                # just tracks the latest value.
                if True:
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                # val_loss is the sum of per-batch losses, not an average.
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
            torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
Пример #4
0
    optimizer = Adam(model.parameters(), lr=1e-4)

    criterion_global = LossMulti(num_classes=2, jaccard_weight=0)
    criterion_local = LossMulti(num_classes=2, jaccard_weight=0)

    for epoch in tqdm(
            range(int(epoch_start) + 1,
                  int(epoch_start) + 1 + no_of_epochs)):

        global_step = epoch * len(trainLoader)
        running_loss = 0.0
        running_loss_local = 0.0

        for i, (inputs, targets, coord) in enumerate(tqdm(trainLoader)):

            model.train()
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss_global = criterion_global(outputs, targets)

                local_outputs_ = []
                local_targets_ = []
                for i, coord in enumerate(coord):

                    cx, cy = int(coord[:, 0].item()), int(coord[:, 1].item())
                    targets_numpy = targets[i].detach().cpu().numpy()
Пример #5
0
    label_list=label_list,
    transforms=train_transforms,
    shuffle=True)

# Reader over val_list using eval_transforms.
eval_reader = Reader(data_dir=data_dir,
                     file_list=val_list,
                     label_list=label_list,
                     transforms=eval_transforms)

# Reject unsupported model types first, then build the requested network.
if args.model_type not in ('unet', 'hrnet'):
    raise ValueError(
        "--model_type: {} is set wrong, it shold be one of ('unet', 'hrnet')"
        .format(args.model_type))
if args.model_type == 'unet':
    model = UNet(num_classes=num_classes, input_channel=channel)
else:
    model = HRNet(num_classes=num_classes, input_channel=channel)

# Start training; evaluation uses eval_reader and 'miou' as the best metric.
model.train(num_epochs=num_epochs,
            train_reader=train_reader,
            train_batch_size=train_batch_size,
            eval_reader=eval_reader,
            eval_best_metric='miou',
            save_interval_epochs=5,
            log_interval_steps=10,
            save_dir=save_dir,
            learning_rate=lr,
            use_vdl=True)
def main():
    """
    Build the UNet, data pipeline, loss and optimizer, then run the
    train/validate loop from the current epoch to ``train_args['epoch_num']``.

    When ``train_args['snapshot']`` is non-empty, the network and optimizer
    states are restored from ``ckpt_path/exp_name``. The starting epoch and
    best-record fields are parsed from the underscore-separated snapshot
    file name (fields 1, 3 and 6).
    """
    net = UNet(num_classes=num_classes).cuda()
    if not train_args['snapshot']:
        curr_epoch = 0
    else:
        # Single-argument call form: valid on both Python 2 and Python 3.
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    # Per-channel mean and std used by Normalize / DeNormalize below.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Transforms named "simul" are passed as simul_transform to the dataset.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    # Inverse of img_transform, handed to validate().
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train', simul_transform=train_simul_transform, transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['batch_size'], num_workers=16, shuffle=True)
    val_set = CityScapes('val', simul_transform=val_simul_transform, transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=val_args['batch_size'], num_workers=16, shuffle=False)

    # The last class index (the one ChangeLabel maps ignored_label to) gets
    # zero weight, so it does not contribute to the loss.
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # don't use weight_decay for bias
    # Bias parameters also train at twice the base learning rate.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=0.9, nesterov=True)

    if train_args['snapshot']:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates over the restored state.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch, restore_transform)
Пример #7
0
def train():
    """
    Train the single-channel UNet, checkpointing and evaluating every epoch.

    All settings come from the module-level ``opt`` object. If
    ``opt.load_model_path`` is set, the network weights, optimizer state and
    starting epoch are restored from that checkpoint. Each epoch trains on
    the training loader (BCE loss, plus Dice loss when
    ``opt.use_dice_loss``), saves a checkpoint and plots the validation Dice
    loss. On every tenth epoch it also plots the test-set recall reported by
    ``getRecall``.
    """
    t.cuda.set_device(1)

    # n_channels: medical images are single-channel grayscale;
    # n_classes: binary segmentation.
    net = UNet(n_channels=1, n_classes=1)
    optimizer = t.optim.SGD(net.parameters(),
                            lr=opt.learning_rate,
                            momentum=0.9,
                            weight_decay=0.0005)
    # Binary cross-entropy (suited to masks covering a large image area).
    criterion = t.nn.BCELoss()

    start_epoch = 0
    if opt.load_model_path:
        checkpoint = t.load(opt.load_model_path)

        # Load parameters saved from a multi-GPU (DataParallel) model into
        # the bare model. The `module.` prefix is stripped only where it is
        # present, so checkpoints saved without DataParallel load as well.
        state_dict = checkpoint['net']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)  # restore the model
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer
        start_epoch = checkpoint['epoch']  # restore the epoch counter

    # The learning rate is scaled by gamma at each milestone epoch.
    if start_epoch == 0:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=-1)  # default
        print('从头训练 ,学习率为{}'.format(optimizer.param_groups[0]['lr']))
    else:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=start_epoch)
        print('加载预训练模型{}并从{}轮开始训练,学习率为{}'.format(
            opt.load_model_path, start_epoch, optimizer.param_groups[0]['lr']))

    # Move the network to the GPU(s).
    if opt.use_gpu:
        net = t.nn.DataParallel(net, device_ids=opt.device_ids)
        net.cuda()
        cudnn.benchmark = True

    # Visualisation helper.
    vis = Visualizer(opt.env)

    train_data = NodeDataSet(train=True)
    val_data = NodeDataSet(val=True)
    test_data = NodeDataSet(test=True)

    # Data loaders.
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    test_dataloader = DataLoader(test_data,
                                 opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    for epoch in range(opt.max_epoch - start_epoch):
        print('开始 epoch {}/{}.'.format(start_epoch + epoch + 1, opt.max_epoch))
        epoch_loss = 0

        # Let the scheduler decide whether to change the learning rate.
        # NOTE(review): this runs before optimizer.step(); kept as-is to
        # preserve the existing schedule.
        scheduler.step()

        for ii, (img, mask) in enumerate(train_dataloader):
            # Bind the target unconditionally so the CPU path also works;
            # move the data to the GPU only when requested.
            true_masks = mask
            if opt.use_gpu:
                img = img.cuda()
                true_masks = true_masks.cuda()
            masks_pred = net(img)

            # BCELoss expects probabilities, so apply the sigmoid first.
            masks_probs = t.sigmoid(masks_pred)

            # loss = binary cross-entropy (+ Dice loss when enabled)
            loss = criterion(masks_probs.view(-1), true_masks.view(-1))

            if opt.use_dice_loss:
                loss += dice_loss(masks_probs, true_masks)

            epoch_loss += loss.item()

            if ii % 2 == 0:
                vis.plot('训练集loss', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log the state at the end of the epoch (loss of the last batch).
        vis.log("epoch:{epoch},lr:{lr},loss:{loss}".format(
            epoch=epoch, loss=loss.item(), lr=optimizer.param_groups[0]['lr']))

        # ii is the zero-based index of the last batch, so the batch count is
        # ii + 1.
        vis.plot('每轮epoch的loss均值', epoch_loss / (ii + 1))
        # Save the model, optimizer and current epoch.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        t.save(state, opt.checkpoint_root + '{}_unet.pth'.format(epoch))

        # ============ validation ===================

        net.eval()
        # Metric: Dice loss between the thresholded prediction and the mask.
        tot = 0
        # Inference only: no autograd graph is needed.
        with t.no_grad():
            for jj, (img_val, mask_val) in enumerate(val_dataloader):
                true_mask_val = mask_val
                if opt.use_gpu:
                    img_val = img_val.cuda()
                    true_mask_val = true_mask_val.cuda()

                mask_pred = net(img_val)
                mask_pred = (t.sigmoid(mask_pred) > 0.5).float()  # 0.5 cut-off
                tot += dice_loss(mask_pred, true_mask_val).item()
        val_dice = tot / (jj + 1)
        vis.plot('验证集 Dice损失', val_dice)

        # ============ test-set recall ===================
        # Evaluated once every 10 epochs.
        if epoch % 10 == 0:
            result_test = []
            with t.no_grad():
                for img_test, _ in test_dataloader:
                    # Only the predictions are collected here; the
                    # ground-truth mask is not used.
                    if opt.use_gpu:
                        img_test = img_test.cuda()
                    mask_pred_test = net(img_test)  # [1,1,512,512]

                    probs = t.sigmoid(mask_pred_test).squeeze().squeeze().cpu(
                    ).detach().numpy()  # [512,512]
                    mask = probs > opt.out_threshold
                    result_test.append(mask)

            # Compute recall over all predicted test masks.
            vis.plot('测试集二维召回率', getRecall(result_test).getResult())
        net.train()
Пример #8
0
def train_val(config):
    """Train the model selected by ``config`` and validate it after each epoch.

    Attributes read from ``config``: ``train_img_dir``, ``train_mask_dir``,
    ``val_img_dir``, ``val_mask_dir``, ``batch_size``, ``num_workers``,
    ``lr``, ``model_type``, ``data_type``, ``optimizer``, ``loss``,
    ``num_epochs``, ``num_train``, ``num_val`` and ``result_path``.

    The training loop unpacks eight outputs (``d0`` .. ``d7``) from the model
    and passes all of them plus the mask to the criterion; validation metrics
    are computed on ``d0`` only.  Scalars and sample masks are written to a
    TensorBoard ``SummaryWriter``, and the whole model object is saved to
    ``config.result_path`` whenever the validation Dice improves.

    Returns:
        None.  Returns early (after printing an error) when
        ``config.model_type`` is not in the supported list.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type not in [
            'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
            'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "BASNet"
    ]:
        print('ERROR!! model_type should be selected in supported models')
        print('Choose model %s' % config.model_type)
        # The writer was already opened above; close it so the early exit
        # does not leak the event-file handle.
        writer.close()
        return

    # Names that pass the check above but have no branch here
    # ('R2AUNet', 'SEUNet++', 'DAUNet') fall through to the plain UNet.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=1)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=1)
    elif config.model_type == "BASNet":
        model = BASNet(n_channels=3, n_classes=1)
    else:
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device, dtype=torch.float)

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "bas":
        criterion = BasLoss()
    else:
        criterion = MixLoss()

    # The scheduler is stepped once per epoch below, so the learning rate is
    # multiplied by 0.1 every 40 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        # ---------------------------- training ----------------------------
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                loss = criterion(d0, d1, d2, d3, d4, d5, d6, d7, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
            scheduler.step()

        # --------------------------- validation ---------------------------
        # Every metric is accumulated weighted by the batch size and divided
        # by the number of samples seen, giving a per-sample average.
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        with tqdm(total=config.num_val,
                  desc="Epoch %d / %d validation round" %
                  (epoch + 1, config.num_epochs),
                  unit='img') as val_pbar:
            model.eval()
            locker = 0
            # No gradients are needed for validation; disabling autograd
            # avoids building (and holding on to) a graph for every batch.
            with torch.no_grad():
                for image, mask in val_loader:
                    batch_size = image.shape[0]
                    current_num += batch_size
                    image = image.to(device, dtype=torch.float)
                    mask = mask.to(device, dtype=torch.float)
                    d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                    batch_dice = dice_coeff(mask, d0).item()
                    epoch_dice += batch_dice * batch_size
                    epoch_acc += get_accuracy(pred=d0, true=mask) * batch_size
                    epoch_sen += get_sensitivity(pred=d0,
                                                 true=mask) * batch_size
                    epoch_spe += get_specificity(pred=d0,
                                                 true=mask) * batch_size
                    epoch_pre += get_precision(pred=d0,
                                               true=mask) * batch_size
                    # Sample masks are logged only for the batch with
                    # index 200 of each validation round.
                    if locker == 200:
                        writer.add_images('masks/true', mask, epoch + 1)
                        writer.add_images('masks/pred', d0 > 0.5, epoch + 1)
                    val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                    val_pbar.update(batch_size)
                    locker += 1
            epoch_dice /= float(current_num)
            epoch_acc /= float(current_num)
            epoch_sen /= float(current_num)
            epoch_spe /= float(current_num)
            epoch_pre /= float(current_num)
            epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
            if epoch_dice > best_dice:
                best_dice = epoch_dice
                writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
                torch.save(
                    model, config.result_path + "/%s_%s_%d.pth" %
                    (config.model_type, str(epoch_dice), epoch + 1))
            logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
            print("epoch dice: " + str(epoch_dice))
            writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
            writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
            writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
            writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
            writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
            writer.add_scalar('F1/test', epoch_f1, epoch + 1)

    writer.close()
    print("Training finished")
Пример #9
0
def train_unet(epoch=100):
    """Train a UNet on the images under ``dataset/train`` and keep the best checkpoints.

    The image list is shuffled once and split 90/10 into training and
    validation subsets.  After every epoch the running validation IoU is
    compared with the best value so far; on improvement the model weights are
    written to ``checkpoints/epoch_<n>_<iou>.pth``.  The same IoU drives a
    ``ReduceLROnPlateau`` scheduler in ``max`` mode.

    Args:
        epoch: Number of epochs to train for.
    """
    # Get all images in train set
    image_names = os.listdir('dataset/train/images/')
    image_names = [name for name in image_names if name.endswith(('.jpg', '.JPG', '.png'))]

    # Split into train and validation sets
    np.random.shuffle(image_names)
    split = int(len(image_names) * 0.9)
    train_image_names = image_names[:split]
    val_image_names = image_names[split:]

    # Create the datasets (both read from the same folder, different name lists)
    train_dataset = EggsPansDataset('dataset/train', train_image_names, mode='train')
    val_dataset = EggsPansDataset('dataset/train', val_image_names, mode='val')

    # Create the dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Initialize model and transfer to device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet()
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', verbose=True)
    loss_obj = EggsPansLoss()
    metrics_obj = EggsPansMetricIoU()

    # Keep best IoU and checkpoint
    best_iou = 0.0
    checkpoint_dir = 'checkpoints/'

    # Train epochs
    for epoch_idx in range(epoch):

        print('Epoch: {:2}/{}'.format(epoch_idx + 1, epoch))
        # Reset metrics and loss
        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # Train phase
        model.train()

        # Train epoch
        pbar = tqdm(train_dataloader)
        for imgs, egg_masks, pan_masks in pbar:

            # Convert to device
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            # Zero gradients
            optim.zero_grad()

            # Forward through net, and get the loss
            pred_egg_masks, pred_pan_masks = model(imgs)

            loss = loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
            # Called for its effect on the running IoU shown below
            # (presumably accumulated inside the metric object).
            metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            # Compute gradients and apply them
            loss.backward()
            optim.step()

            # Update metrics
            pbar.set_description('Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                      metrics_obj.get_running_iou()))

        print('Validation: ')

        # Reset metrics and loss
        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # Val phase
        model.eval()

        # Val epoch
        pbar = tqdm(val_dataloader)
        for imgs, egg_masks, pan_masks in pbar:

            # Convert to device
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            with torch.no_grad():
                # Forward through net; the return values of the loss and
                # metric calls are not needed here, only the running values
                # read back below.
                pred_egg_masks, pred_pan_masks = model(imgs)

                loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
                metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            pbar.set_description('Val Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                          metrics_obj.get_running_iou()))

        # Running IoU over the whole validation pass
        val_iou = metrics_obj.get_running_iou()

        # Save best model
        if best_iou < val_iou:
            best_iou = val_iou
            # Make sure the target directory exists; otherwise the first
            # improvement would fail after a full epoch of training.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'epoch_{}_{:.4f}.pth'.format(
                epoch_idx + 1, val_iou)))

        # Reduce learning rate on plateau
        lr_scheduler.step(val_iou)

        print('\n')
        print('-'*100)
Пример #10
0
class Model:
    """Training wrapper around a single-channel ``UNet``.

    Owns the network, the Adam optimizer, a ``StepLR`` scheduler and the
    running history of training F1 / loss values.  Batches are expected to be
    dict-like with the keys ``'MRI'`` (input) and ``'Mask'`` (target).
    """

    def __init__(self, train_dl, val_dl):
        """Build the network, optimizer and scheduler and prepare the ``./pt`` output dir.

        Args:
            train_dl: Iterable of training batches (also needs ``.dataset``
                for progress logging in :meth:`train`).
            val_dl: Iterable of validation batches.
        """
        self.device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_dl = train_dl
        self.val_dl = val_dl

        self.loss = Loss()
        self.net = UNet(1).to(self.device)
        self.net.apply(Model._init_weights)
        self.criterion = self.loss.BCEDiceLoss
        self.optim = None
        self.scheduler = None

        self._init_optim(LR, BETAS)

        # Number of optimizer steps taken so far (across epochs).
        self.cycles = 0
        self.hist = {'train': [], 'val': [], 'loss': []}

        utils.create_dir('./pt')
        utils.log_data_to_txt('train_log', f'\nUsing device {self.device}')

    def _init_optim(self, lr, betas):
        """Create the Adam optimizer and its ``StepLR`` scheduler.

        Args:
            lr: Learning rate for Adam.
            betas: Adam ``betas`` pair.  ``None`` keeps the library default.
                (Previously this argument was accepted but silently ignored.)
        """
        adam_kwargs = {'lr': lr}
        if betas is not None:
            adam_kwargs['betas'] = betas
        self.optim = optim.Adam(utils.filter_gradients(self.net),
                                **adam_kwargs)

        # The scheduler is stepped once per batch in train(), so the learning
        # rate is scaled by 0.75 every 100 optimizer steps.
        self.scheduler = optim.lr_scheduler.StepLR(self.optim,
                                                   step_size=100,
                                                   gamma=.75)

    def _save_models(self):
        """Persist network, optimizer and scheduler state under ``./pt``."""
        utils.save_state_dict(self.net, 'model', './pt')
        utils.save_state_dict(self.optim, 'optim', './pt')
        utils.save_state_dict(self.scheduler, 'scheduler', './pt')

    @staticmethod
    def _binarize(tensor, threshold):
        """Return a NumPy copy of ``tensor`` with values set to 0 below ``threshold`` and 1 otherwise."""
        values = np.copy(tensor.data.cpu().numpy())
        values[np.nonzero(values < threshold)] = 0.
        values[np.nonzero(values >= threshold)] = 1.
        return values

    def train(self, epochs):
        """Run the training loop for ``epochs`` passes over ``self.train_dl``.

        Every 100 optimizer steps the states are checkpointed, the validation
        set is evaluated and a summary is appended to the ``train_log`` file.
        """
        self.net.train()
        for epoch in range(epochs):
            self.net.train()
            for idx, data in enumerate(self.train_dl):
                batch_time = time.time()

                self.cycles += 1
                print(self.cycles)

                image = data['MRI'].to(self.device)
                target = data['Mask'].to(self.device)

                output = self.net(image)

                # Training F1 is measured on predictions thresholded at 0.5.
                output_rounded = Model._binarize(output, 0.5)
                train_f1 = self.loss.F1_metric(output_rounded,
                                               target.data.cpu().numpy())

                loss = self.criterion(output, target)

                self.hist['train'].append(train_f1)
                self.hist['loss'].append(loss.item())

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.scheduler.step()

                if self.cycles % 100 == 0:
                    self._save_models()
                    val_f1 = self.evaluate()
                    utils.log_data_to_txt(
                        'train_log',
                        f'\nEpoch: {epoch}/{epochs} - Batch: {idx * BATCH_SIZE}/{len(self.train_dl.dataset)}'
                        f'\nLoss: {loss.mean().item():.4f}'
                        f'\nTrain F1: {train_f1:.4f} - Val F1: {val_f1}'
                        f'\nTime taken: {time.time() - batch_time:.4f}s')

    def evaluate(self):
        """Return the mean F1 score over all batches of ``self.val_dl``.

        Predictions are thresholded at 0.3 before scoring.  The network is
        switched to evaluation mode for the duration of the pass and its
        previous mode is restored afterwards, so calling this from inside
        :meth:`train` does not disturb training.

        Returns:
            The average of the per-batch F1 values, or ``0.0`` when the
            validation loader yields no batches.
        """
        total_f1 = 0
        num_batches = 0

        was_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                for data in self.val_dl:
                    image, target = data['MRI'], data['Mask']

                    image = image.to(self.device)
                    target = target.to(self.device)

                    outputs = self.net(image)
                    out_thresh = Model._binarize(outputs, .3)

                    total_f1 += self.loss.F1_metric(
                        out_thresh, target.data.cpu().numpy())
                    num_batches += 1
        finally:
            self.net.train(was_training)

        # Divide by the number of batches, not the last index: the old
        # ``loss_v / idx`` was off by one and failed for 0 or 1 batches.
        if num_batches == 0:
            return 0.0
        return total_f1 / num_batches

    @classmethod
    def _init_weights(cls, layer: nn.Module):
        """Initialise 2-D conv and batch-norm layers; meant for ``nn.Module.apply``.

        Layers are matched by class name: names containing both ``Conv`` and
        ``2d`` get N(0, 0.02) weights; names containing ``BatchNorm`` get
        N(1, 0.02) weights and zero bias.
        """
        name = layer.__class__.__name__
        if 'Conv' in name and '2d' in name:
            nn.init.normal_(layer.weight.data, .0, 2e-2)
        if 'BatchNorm' in name:
            nn.init.normal_(layer.weight.data, 1.0, 2e-2)
            nn.init.constant_(layer.bias.data, .0)
Пример #11
0
class Trainer:
    """Trains a ``UNet`` on image/mask pairs with a BCE + IoU loss.

    Batches are addressed by index: the first ``train_size`` batch indices
    are used for training and the remaining ones for validation.
    """

    @classmethod
    def intersection_over_union(cls, y, z):
        """Return the soft IoU ``sum(min(y, z)) / sum(max(y, z))`` of two tensors."""
        iou = (torch.sum(torch.min(y, z))) / (torch.sum(torch.max(y, z)))
        return iou

    @classmethod
    def get_number_of_batches(cls, image_paths, batch_size):
        """Return how many batches of ``batch_size`` cover ``image_paths`` (last one may be partial)."""
        return int(math.ceil(len(image_paths) / batch_size))

    @classmethod
    def evaluate_loss(cls, criterion, output, target):
        """Return ``criterion(output, target) + 0.1 * (1 - IoU(output, target))``."""
        loss_1 = criterion(output, target)
        loss_2 = 1 - Trainer.intersection_over_union(output, target)
        loss = loss_1 + 0.1 * loss_2
        return loss

    def __init__(self, side_length, batch_size, epochs, learning_rate,
                 momentum_parameter, seed, image_paths, state_dict,
                 train_val_split):
        """Store the hyper-parameters and build the model and loader.

        Args:
            side_length: Passed to ``Loader``.
            batch_size: Number of samples per batch.
            epochs: Number of training epochs.
            learning_rate: Learning rate for Adam.
            momentum_parameter: Stored but not used by this class.
            seed: Seed forwarded to ``Loader.get_batch`` and ``set_seed``.
            image_paths: Glob pattern; expanded with ``glob.glob``.
            state_dict: File name (under ``weights/``) for the saved weights.
            train_val_split: Fraction of batches used for training.
        """
        self.side_length = side_length
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.momentum_parameter = momentum_parameter
        self.seed = seed
        self.image_paths = glob.glob(image_paths)
        self.batches = Trainer.get_number_of_batches(self.image_paths,
                                                     self.batch_size)
        self.model = UNet()
        self.loader = Loader(self.side_length)
        self.state_dict = state_dict
        self.train_val_split = train_val_split
        self.train_size = int(np.floor((self.train_val_split * self.batches)))

    def set_cuda(self):
        """Move the model to the GPU when CUDA is available."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def set_seed(self):
        """Seed NumPy's global RNG when a seed was provided."""
        if self.seed is not None:
            np.random.seed(self.seed)

    def process_batch(self, batch):
        """Load batch number ``batch``, run it through the model and return ``(output, target)``.

        ``target`` is the mask channel clamped to [0, 1] so it can be used
        with ``BCELoss``.
        """
        # Grab a batch, shuffled according to the provided seed. Note that
        # i-th image: samples[i][0], i-th mask: samples[i][1]
        samples = Loader.get_batch(self.image_paths, self.batch_size, batch,
                                   self.seed)
        # Cast samples into torch.FloatTensor for interaction with U-Net.
        # (The former ``samples.astype(float)`` discarded its result and only
        # cost an extra copy, so it has been dropped.)
        samples = torch.from_numpy(samples).float()

        # Cast into a CUDA tensor, if GPUs are available
        if torch.cuda.is_available():
            samples = samples.cuda()

        # Isolate images and their masks, adding the channel dimension the
        # network is called with
        samples_images = samples[:, 0].unsqueeze(1)
        samples_masks = samples[:, 1].unsqueeze(1)

        # Run inputs through the model
        output = self.model(samples_images)

        # Clamp the target for proper interaction with BCELoss
        target = torch.clamp(samples_masks, min=0, max=1)

        del samples

        return output, target

    @staticmethod
    def _batch_ious(output, target):
        """Return (and print) the IoU of every sample in a batch against its binarised prediction."""
        scores = []
        for i in range(0, output.shape[0]):
            binary_mask = Editor.make_binary_mask_from_torch(
                output[i, :, :, :], 1.0)
            iou = Trainer.intersection_over_union(
                binary_mask, target[i, :, :, :].cpu())
            scores.append(iou.item())
            print("IoU:", iou.item())
        return scores

    @staticmethod
    def _mean_or_nan(values):
        """Return the arithmetic mean of ``values``, or NaN for an empty list."""
        return sum(values) / len(values) if values else float("nan")

    def train_model(self):
        """Train for ``self.epochs`` epochs, validating after each, then save the weights.

        The weights are written to ``"weights/" + self.state_dict`` once all
        epochs have finished.
        """
        self.model.train()
        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        iteration = 0
        best_iteration = 0
        best_loss = 10**10

        losses_train = []
        losses_val = []

        average_iou_train = []
        average_iou_val = []

        print("BEGIN TRAINING")
        print("TRAINING BATCHES:", self.train_size)
        print("VALIDATION BATCHES:", self.batches - self.train_size)
        print("BATCH SIZE:", self.batch_size)
        print("EPOCHS:", self.epochs)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

        for k in range(0, self.epochs):
            current_epoch = k + 1
            print("EPOCH:", current_epoch)
            print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            # Train.  IoU scores are collected per epoch so the reported
            # average describes this epoch only, not all epochs so far.
            iou_train = []
            for batch in range(0, self.train_size):
                iteration = iteration + 1
                output, target = self.process_batch(batch)
                loss = Trainer.evaluate_loss(criterion, output, target)
                # Report the epoch currently running (not the total count).
                print("EPOCH:", current_epoch)
                print("Batch", batch, "of", self.train_size)

                # Aggregate intersection over union scores for each element in the batch
                iou_train.extend(Trainer._batch_ious(output, target))

                # Clear data to prevent memory overload
                del target
                del output

                # Clear gradients, back-propagate, and update weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Record the loss value
                loss_value = loss.item()
                if best_loss > loss_value:
                    best_loss = loss_value
                    best_iteration = iteration
                losses_train.append(loss_value)

                if batch == self.train_size - 1:
                    print("LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            average_iou = Trainer._mean_or_nan(iou_train)
            print("Average IoU:", average_iou)
            average_iou_train.append(average_iou)
            #Visualizer.save_loss_plot(average_iou_train, "average_iou_train.png")

            # Validate in evaluation mode and without autograd: no weights
            # are updated here, so building a graph only wastes memory.
            iou_val = []
            self.model.eval()
            with torch.no_grad():
                for batch in range(self.train_size, self.batches):
                    output, target = self.process_batch(batch)
                    loss = Trainer.evaluate_loss(criterion, output, target)

                    iou_val.extend(Trainer._batch_ious(output, target))

                    loss_value = loss.item()
                    losses_val.append(loss_value)
                    print("EPOCH:", current_epoch)
                    print("VALIDATION LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")
                    del output
                    del target
            # Back to training mode for the next epoch.
            self.model.train()

            average_iou = Trainer._mean_or_nan(iou_val)
            print("Average IoU:", average_iou)
            average_iou_val.append(average_iou)
            #Visualizer.save_loss_plot(average_iou_val, "average_iou_val.png")

        print("Least loss", best_loss, "at iteration", best_iteration)

        torch.save(self.model.state_dict(), "weights/" + self.state_dict)
Пример #12
0
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', type=float, default=1)
    arg('--root', type=str, default='runs/debug', help='checkpoint root')
    arg('--image-path', type=str, default='data', help='image path')
    arg('--batch-size', type=int, default=2)
    arg('--n-epochs', type=int, default=100)
    arg('--optimizer', type=str, default='Adam', help='Adam or SGD')
    arg('--lr', type=float, default=0.001)
    arg('--workers', type=int, default=10)
    arg('--model',
        type=str,
        default='UNet16',
        choices=[
            'UNet', 'UNet11', 'UNet16', 'LinkNet34', 'FCDenseNet57',
            'FCDenseNet67', 'FCDenseNet103'
        ])
    arg('--model-weight', type=str, default=None)
    arg('--resume-path', type=str, default=None)
    arg('--attribute',
        type=str,
        default='all',
        choices=[
            'pigment_network', 'negative_network', 'streaks',
            'milia_like_cyst', 'globules', 'all'
        ])
    args = parser.parse_args()

    ## folder for checkpoint
    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    image_path = args.image_path

    #print(args)
    if args.attribute == 'all':
        num_classes = 5
    else:
        num_classes = 1
    args.num_classes = num_classes
    ### save initial parameters
    print('--' * 10)
    print(args)
    print('--' * 10)
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    ## load pretrained model
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'FCDenseNet103':
        model = FCDenseNet103(num_classes=num_classes)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    ## multiple GPUs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    ## load pretrained model
    if args.model_weight is not None:
        state = torch.load(args.model_weight)
        #epoch = state['epoch']
        #step = state['step']
        model.load_state_dict(state['model'])
        print('--' * 10)
        print('Load pretrained model', args.model_weight)
        #print('Restored model, epoch {}, step {:,}'.format(epoch, step))
        print('--' * 10)
        ## replace the last layer
        ## although the model and pre-trained weight have differernt size (the last layer is different)
        ## pytorch can still load the weight
        ## I found that the weight for one layer just duplicated for all layers
        ## therefore, the following code is not necessary
        # if args.attribute == 'all':
        #     model = list(model.children())[0]
        #     num_filters = 32
        #     model.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        #     print('--' * 10)
        #     print('Load pretrained model and replace the last layer', args.model_weight, num_classes)
        #     print('--' * 10)
        #     if torch.cuda.device_count() > 1:
        #         model = nn.DataParallel(model)
        #     model.to(device)

    ## model summary
    print_model_summay(model)

    ## define loss
    loss_fn = LossBinary(jaccard_weight=args.jaccard_weight)

    ## It enables benchmark mode in cudnn.
    ## benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the
    ## optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.
    ## But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears,
    ## possibly leading to worse runtime performances.
    cudnn.benchmark = True

    ## get train_test_id
    train_test_id = get_split()

    ## train vs. val
    print('--' * 10)
    print('num train = {}, num_val = {}'.format(
        (train_test_id['Split'] == 'train').sum(),
        (train_test_id['Split'] != 'train').sum()))
    print('--' * 10)

    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    ## define data loader
    train_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=True,
                               shuffle=True,
                               transform=train_transform)
    valid_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=False,
                               shuffle=True,
                               transform=val_transform)

    if True:
        print('--' * 10)
        print('check data')
        train_image, train_mask, train_mask_ind = next(iter(train_loader))
        print('train_image.shape', train_image.shape)
        print('train_mask.shape', train_mask.shape)
        print('train_mask_ind.shape', train_mask_ind.shape)
        print('train_image.min', train_image.min().item())
        print('train_image.max', train_image.max().item())
        print('train_mask.min', train_mask.min().item())
        print('train_mask.max', train_mask.max().item())
        print('train_mask_ind.min', train_mask_ind.min().item())
        print('train_mask_ind.max', train_mask_ind.max().item())
        print('--' * 10)

    valid_fn = validation_binary

    ###########
    ## optimizer
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9)

    ## loss
    criterion = loss_fn
    ## change LR
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.8,
                                  patience=5,
                                  verbose=True)

    ##########
    ## load previous model status
    previous_valid_loss = 10
    model_path = root / 'model.pt'
    if args.resume_path is not None and model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        epoch = 1
        step = 0
        try:
            previous_valid_loss = state['valid_loss']
        except:
            previous_valid_loss = 10
        print('--' * 10)
        print('Restored previous model, epoch {}, step {:,}'.format(
            epoch, step))
        print('--' * 10)
    else:
        epoch = 1
        step = 0

    #########
    ## start training
    log = root.joinpath('train.log').open('at', encoding='utf8')
    writer = SummaryWriter()
    meter = AllInOneMeter()
    #if previous_valid_loss = 10000
    print('Start training')
    print_model_summay(model)
    previous_valid_jaccard = 0
    for epoch in range(epoch, args.n_epochs + 1):
        model.train()
        random.seed()
        #jaccard = []
        start_time = time.time()
        meter.reset()
        w1 = 1.0
        w2 = 0.5
        w3 = 0.5
        try:
            train_loss = 0
            valid_loss = 0
            # if epoch == 1:
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     #set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 5:
            #     w1 = 1.0
            #     w2 = 0.0
            #     w3 = 0.5
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     # set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            #elif epoch == 3:
            #     set_train_layers(model, train_layer_names=['module.dec5.block.0.conv.weight','module.dec5.block.0.conv.bias',
            #                                                'module.dec5.block.1.weight','module.dec5.block.1.bias',
            #                                                'module.dec4.block.0.conv.weight','module.dec4.block.0.conv.bias',
            #                                                'module.dec4.block.1.weight','module.dec4.block.1.bias',
            #                                                'module.dec3.block.0.conv.weight','module.dec3.block.0.conv.bias',
            #                                                'module.dec3.block.1.weight','module.dec3.block.1.bias',
            #                                                'module.dec2.block.0.conv.weight','module.dec2.block.0.conv.bias',
            #                                                'module.dec2.block.1.weight','module.dec2.block.1.bias',
            #                                                'module.dec1.conv.weight','module.dec1.conv.bias',
            #                                                'module.final.weight','module.final.bias'])
            #     print_model_summa zvgf    t5y(model)
            # elif epoch == 50:
            #     set_freeze_layers(model, freeze_layer_names=None)
            #     print_model_summay(model)
            for i, (train_image, train_mask,
                    train_mask_ind) in enumerate(train_loader):
                # inputs, targets = variable(inputs), variable(targets)

                train_image = train_image.permute(0, 3, 1, 2)
                train_mask = train_mask.permute(0, 3, 1, 2)
                train_image = train_image.to(device)
                train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                train_mask_ind = train_mask_ind.to(device).type(
                    torch.cuda.FloatTensor)
                # if args.problem_type == 'binary':
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                # else:
                #     #train_mask = train_mask.to(device).type(torch.cuda.LongTensor)
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)

                outputs, outputs_mask_ind1, outputs_mask_ind2 = model(
                    train_image)
                #print(outputs.size())
                #print(outputs_mask_ind1.size())
                #print(outputs_mask_ind2.size())
                ### note that the last layer in the model is defined differently
                # if args.problem_type == 'binary':
                #     train_prob = F.sigmoid(outputs)
                #     loss = criterion(outputs, train_mask)
                # else:
                #     #train_prob = outputs
                #     train_prob = F.sigmoid(outputs)
                #     loss = torch.tensor(0).type(train_mask.type())
                #     for feat_inx in range(train_mask.shape[1]):
                #         loss += criterion(outputs, train_mask)
                train_prob = F.sigmoid(outputs)
                train_mask_ind_prob1 = F.sigmoid(outputs_mask_ind1)
                train_mask_ind_prob2 = F.sigmoid(outputs_mask_ind2)
                loss1 = criterion(outputs, train_mask)
                #loss1 = F.binary_cross_entropy_with_logits(outputs, train_mask)
                #loss2 = nn.BCEWithLogitsLoss()(outputs_mask_ind1, train_mask_ind)
                #print(train_mask_ind.size())
                #weight = torch.ones_like(train_mask_ind)
                #weight[:, 0] = weight[:, 0] * 1
                #weight[:, 1] = weight[:, 1] * 14
                #weight[:, 2] = weight[:, 2] * 14
                #weight[:, 3] = weight[:, 3] * 4
                #weight[:, 4] = weight[:, 4] * 4
                #weight = weight * train_mask_ind + 1
                #weight = weight.to(device).type(torch.cuda.FloatTensor)
                loss2 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind1, train_mask_ind)
                loss3 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind2, train_mask_ind)
                #loss3 = criterion(outputs_mask_ind2, train_mask_ind)
                loss = loss1 * w1 + loss2 * w2 + loss3 * w3
                #print(loss1.item(), loss2.item(), loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1
                #jaccard += [get_jaccard(train_mask, (train_prob > 0).float()).item()]
                meter.add(train_prob, train_mask, train_mask_ind_prob1,
                          train_mask_ind_prob2, train_mask_ind, loss1.item(),
                          loss2.item(), loss3.item(), loss.item())
                # print(train_mask.data.shape)
                # print(train_mask.data.sum(dim=-2).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).sum(dim=0).shape)
                # intersection = train_mask.data.sum(dim=-2).sum(dim=-1)
                # print(intersection.shape)
                # print(intersection.dtype)
                # print(train_mask.data.shape[0])
                #torch.zeros([2, 4], dtype=torch.float32)
            #########################
            ## at the end of each epoch, evaluate the metrics
            epoch_time = time.time() - start_time
            train_metrics = meter.value()
            train_metrics['epoch_time'] = epoch_time
            train_metrics['image'] = train_image.data
            train_metrics['mask'] = train_mask.data
            train_metrics['prob'] = train_prob.data

            #train_jaccard = np.mean(jaccard)
            #train_auc = str(round(mtr1.value()[0],2))+' '+str(round(mtr2.value()[0],2))+' '+str(round(mtr3.value()[0],2))+' '+str(round(mtr4.value()[0],2))+' '+str(round(mtr5.value()[0],2))
            valid_metrics = valid_fn(model, criterion, valid_loader, device,
                                     num_classes)
            ##############
            ## write events
            write_event(log,
                        step,
                        epoch=epoch,
                        train_metrics=train_metrics,
                        valid_metrics=valid_metrics)
            #save_weights(model, model_path, epoch + 1, step)
            #########################
            ## tensorboard
            write_tensorboard(writer,
                              model,
                              epoch,
                              train_metrics=train_metrics,
                              valid_metrics=valid_metrics)
            #########################
            ## save the best model
            valid_loss = valid_metrics['loss1']
            valid_jaccard = valid_metrics['jaccard']
            if valid_loss < previous_valid_loss:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_loss = valid_loss
                print('Save best model by loss')
            if valid_jaccard > previous_valid_jaccard:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_jaccard = valid_jaccard
                print('Save best model by jaccard')
            #########################
            ## change learning rate
            scheduler.step(valid_metrics['loss1'])

        except KeyboardInterrupt:
            # print('--' * 10)
            # print('Ctrl+C, saving snapshot')
            # save_weights(model, model_path, epoch, step)
            # print('done.')
            # print('--' * 10)
            writer.close()
            #return
    writer.close()
class Runner(object):
    """Train / validate / test driver around a ``UNet(ch_in=2, ch_out=1)``.

    The model output is scored with an element-wise
    ``BCEWithLogitsLoss`` (weighted per class in :meth:`calc_loss`), turned
    into boundary intervals by peak picking (:meth:`predict`) and evaluated
    with ``mir_eval.segment.detection`` (:meth:`evaluate`).
    """

    def __init__(self,
                 hparams,
                 train_size: int,
                 class_weight: Optional[Tensor] = None):
        """
        :param hparams: hyperparameter object; the attributes read here are
            ``model``, ``hop_size``, ``sample_rate``, ``learning_rate``,
            ``weight_decay``, ``batch_size``, ``scheduler``, ``device``,
            ``out_device`` and ``logdir``
        :param train_size: passed to the scheduler as ``epoch_size``
        :param class_weight: indexable with 0 and 1; required by
            :meth:`calc_loss` (element 0 weights ``y == 0``, element 1
            weights ``y > 0``)
        """
        # model, criterion, and prediction
        self.model = UNet(ch_in=2, ch_out=1, **hparams.model)
        self.sigmoid = torch.nn.Sigmoid()
        self.criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.class_weight = class_weight

        # for prediction
        # presumably seconds per frame (hop in samples / samples per second)
        # -- TODO confirm units of hparams.hop_size / hparams.sample_rate
        self.frame2time = hparams.hop_size / hparams.sample_rate
        self.T_6s = round(6 / self.frame2time) - 1
        self.T_12s = round(12 / self.frame2time) - 1
        self.metrics = ('precision', 'recall', 'F1')

        # optimizer and scheduler
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
        )
        self.scheduler = CosineLRWithRestarts(self.optimizer,
                                              batch_size=hparams.batch_size,
                                              epoch_size=train_size,
                                              **hparams.scheduler)
        self.scheduler.step()

        # best validation F1 seen at a scheduler restart (see ``step``)
        self.f1_last_restart = -1

        # device
        device_for_summary = self._init_device(hparams.device,
                                               hparams.out_device)

        # summary
        self.writer = SummaryWriter(logdir=hparams.logdir)
        path_summary = Path(self.writer.logdir, 'summary.txt')
        if not path_summary.exists():
            print_to_file(path_summary, summary,
                          (self.model,
                           (2, 128, 16 * hparams.model['stride'][1]**4)),
                          dict(device=device_for_summary))

        # save hyperparameters
        path_hparam = Path(self.writer.logdir, 'hparams.txt')
        if not path_hparam.exists():
            with path_hparam.open('w') as f:
                for var in vars(hparams):
                    value = getattr(hparams, var)
                    print(f'{var}: {value}', file=f)

    @staticmethod
    def _device_index(device) -> int:
        """Return the CUDA index encoded in ``device``.

        An ``int`` is returned as is; anything else is converted with
        ``str`` and its trailing run of digits is parsed, so ``'cuda:10'``
        gives 10 (taking only the last character would give 0).

        :raises ValueError: if the string does not end with a digit
        """
        if isinstance(device, int):
            return device
        name = str(device)
        digits = name[len(name.rstrip('0123456789')):]
        return int(digits)

    @staticmethod
    def _batch_size(Ts) -> int:
        """Return ``len(Ts)``, or 1 when ``Ts`` has no length (a scalar)."""
        try:
            return len(Ts)
        except TypeError:  # plain int, 0-d tensor, ...
            return 1

    def _init_device(self, device, out_device) -> str:
        """Move model, criterion and sigmoid to the requested device(s).

        Sets ``self.device``, ``self.out_device`` and ``self.str_device``.
        With more than one device the model is wrapped in
        ``nn.DataParallel``.

        :param device: ``'cpu'``, an int index, a string ending in the index
            (e.g. ``'cuda:1'``), or a sequence of either
        :param out_device: int index or anything ``torch.device`` accepts;
            ignored when ``device == 'cpu'``
        :return: ``'cpu'`` or ``'cuda'``
        """
        if device == 'cpu':
            self.device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return 'cpu'

        # device type: normalise to a list of int indices
        if isinstance(device, (int, str)):
            device = [self._device_index(device)]
        else:  # sequence of devices
            device = [self._device_index(d) for d in device]

        # out_device type
        if isinstance(out_device, int):
            out_device = torch.device(f'cuda:{out_device}')
        else:
            out_device = torch.device(out_device)

        self.device = torch.device(f'cuda:{device[0]}')
        self.out_device = out_device

        if len(device) > 1:
            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
        else:
            self.str_device = str(self.device)

        self.model.cuda(device[0])
        self.criterion.cuda(out_device)
        if self.sigmoid is not None:
            self.sigmoid.cuda(device[0])
        torch.cuda.set_device(device[0])
        return 'cuda'

    def calc_loss(self, y: Tensor, out: Tensor, Ts: Union[List[int],
                                                          int]) -> Tensor:
        """Class-weighted BCE-with-logits loss, summed over the batch.

        For every item only the first ``T`` frames are used; the weighted
        element-wise loss is summed and divided by ``T``.

        :param y: (B, T) or (T,)
        :param out: (B, T) or (T,)
        :param Ts: length B list or int
        :return: tensor of shape (1,) on ``self.out_device``
        """
        assert self.class_weight is not None
        weight = (y > 0).float() * self.class_weight[1].item()
        weight += (y == 0).float() * self.class_weight[0].item()

        if y.dim() == 1:  # if batch_size == 1
            # Add a leading batch axis. Wrapping in a tuple would break the
            # multi-axis slicing below with a TypeError.
            y = y.unsqueeze(0)
            out = out.unsqueeze(0)
            weight = weight.unsqueeze(0)
            Ts = (Ts, )

        loss = torch.zeros(1, device=self.out_device)
        for ii, T in enumerate(Ts):
            loss_no_red = self.criterion(out[ii:ii + 1, ..., :T],
                                         y[ii:ii + 1, :T])
            loss += (loss_no_red * weight[ii:ii + 1, :T]).sum() / T

        return loss

    def predict(self, out_np: ndarray, Ts: Union[List[int], int]) \
            -> Tuple[List[ndarray], List]:
        """ peak-picking prediction

        A frame ``idx`` in ``1 .. T-2`` is a candidate when it is the
        maximum of the window ``[idx - T_6s, idx + T_6s]`` (clipped to
        ``[0, T)``). Candidates strictly above the mean candidate value are
        kept as boundaries.

        :param out_np: (B, T) or (T,)
        :param Ts: length B list or int
        :return: boundaries, thresholds
            boundaries: length B list of boundary interval ndarrays
                (rows are ``[start, end]`` scaled by ``frame2time``)
            thresholds: length B list of threshold values
        """
        if out_np.ndim == 1:  # if batch_size == 1
            out_np = (out_np, )
            Ts = (Ts, )

        boundaries = []
        thresholds = []
        for item, T in zip(out_np, Ts):
            candid_idx = []
            for idx in range(1, T - 1):
                i_first = max(idx - self.T_6s, 0)
                i_last = min(idx + self.T_6s + 1, T)
                if item[idx] >= np.amax(item[i_first:i_last]):
                    candid_idx.append(idx)

            # NOTE: with no candidates the mean is NaN and no boundary is
            # selected, leaving the single interval [0, T].
            threshold = np.mean(item[candid_idx])
            boundary_idx = [idx for idx in candid_idx
                            if item[idx] > threshold]

            boundary_interval = np.array(
                [[0] + boundary_idx, boundary_idx + [T]], dtype=np.float64).T
            boundary_interval *= self.frame2time

            boundaries.append(boundary_interval)
            thresholds.append(threshold)

        return boundaries, thresholds

    @staticmethod
    def evaluate(reference: Union[List[ndarray], ndarray],
                 prediction: Union[List[ndarray], ndarray]):
        """Sum ``mir_eval.segment.detection`` results over the batch.

        :param reference: length B list of ndarray or just ndarray
        :param prediction: length B list of ndarray or just ndarray
        :return: (3,) ndarray
        """
        if isinstance(reference, ndarray):  # if batch_size == 1
            reference = (reference, )
            # A bare prediction array is one item too; without wrapping it,
            # zip() would pair the reference with its first row only.
            if isinstance(prediction, ndarray):
                prediction = (prediction, )

        result = np.zeros(3)
        for item_truth, item_pred in zip(reference, prediction):
            mir_result = mir_eval.segment.detection(item_truth,
                                                    item_pred,
                                                    trim=True)
            result += np.array(mir_result)

        return result

    # Running model for train, test and validation.
    def run(self, dataloader, mode: str, epoch: int):
        """Run one pass over ``dataloader``.

        :param dataloader: yields ``(x, y, intervals, Ts, ids)`` batches; its
            ``dataset.normalization.normalize_`` is applied to ``x``
        :param mode: ``'train'``, ``'valid'`` or anything else (handled as
            test: ``<logdir>/<epoch>.pt`` is loaded first and per-item
            results are written to ``<logdir>/test_<epoch>``)
        :param epoch: epoch number used for logging and file names
        :return: (avg_loss, avg_eval), both divided by
            ``len(dataloader.dataset)``
        """
        self.model.train(mode == 'train')  # train(False) is eval()
        if mode == 'test':
            state_dict = torch.load(Path(self.writer.logdir, f'{epoch}.pt'))
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            path_test_result = Path(self.writer.logdir, f'test_{epoch}')
            os.makedirs(path_test_result, exist_ok=True)
        else:
            path_test_result = None

        avg_loss = 0.
        avg_eval = 0.
        all_thresholds = dict()
        print()
        pbar = tqdm(dataloader,
                    desc=f'{mode} {epoch:3d}',
                    postfix='-',
                    dynamic_ncols=True)

        for i_batch, (x, y, intervals, Ts, ids) in enumerate(pbar):
            # data
            # (``hasattr(Ts, 'len')`` was always False, pinning this to 1)
            n_batch = self._batch_size(Ts)
            x = x.to(self.device)  # B, C, F, T
            x = dataloader.dataset.normalization.normalize_(x)
            y = y.to(self.out_device)  # B, T

            # forward
            out = self.model(x)  # B, C, 1, T
            out = out[..., 0, 0, :]  # B, T

            # loss
            if mode != 'test':
                if mode == 'valid':
                    # NOTE(review): anomaly detection is enabled for the
                    # validation loss only -- confirm this is intended.
                    with torch.autograd.detect_anomaly():
                        loss = self.calc_loss(y, out, Ts)
                else:
                    loss = self.calc_loss(y, out, Ts)

            else:
                loss = 0

            out_np = self.sigmoid(out).detach().cpu().numpy()
            prediction, thresholds = self.predict(out_np, Ts)

            eval_result = self.evaluate(intervals, prediction)

            if mode == 'train':
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.batch_step()
                loss = loss.item()
            elif mode == 'valid':
                loss = loss.item()
                if i_batch == 0:  # save only the 0-th data
                    id_0, T_0 = ids[0], Ts[0]
                    out_np_0 = out_np[0, :T_0]
                    pred_0, truth_0 = prediction[0][1:, 0], intervals[0][1:, 0]
                    t_axis = np.arange(T_0) * self.frame2time
                    fig = draw_lineplot(t_axis, out_np_0, pred_0, truth_0,
                                        id_0)
                    self.writer.add_figure(f'{mode}/out', fig, epoch)
                    np.save(Path(self.writer.logdir, f'{id_0}_{epoch}.npy'),
                            out_np_0)
                    np.save(
                        Path(self.writer.logdir, f'{id_0}_{epoch}_pred.npy'),
                        pred_0)
                    if epoch == 0:
                        np.save(Path(self.writer.logdir, f'{id_0}_truth.npy'),
                                truth_0)
            else:
                # save all test data
                for id_, item_truth, item_pred, item_out, threshold, T \
                        in zip(ids, intervals, prediction, out_np, thresholds, Ts):
                    np.save(path_test_result / f'{id_}_truth.npy', item_truth)
                    np.save(path_test_result / f'{id_}.npy', item_out[:T])
                    np.save(path_test_result / f'{id_}_pred.npy', item_pred)
                    all_thresholds[str(id_)] = threshold

            # progress bar shows per-item averages of the current batch
            str_eval = np.array2string(eval_result / n_batch, precision=3)
            pbar.set_postfix_str(f'{loss / n_batch:.3f}, {str_eval}')

            avg_loss += loss
            avg_eval += eval_result

        avg_loss = avg_loss / len(dataloader.dataset)
        avg_eval = avg_eval / len(dataloader.dataset)

        if mode == 'test':
            np.savez(path_test_result / 'thresholds.npz', **all_thresholds)

        return avg_loss, avg_eval

    def step(self, valid_f1: float, epoch: int):
        """Advance the scheduler by one epoch and checkpoint at restarts.

        When ``epoch`` equals the scheduler's (possibly just updated)
        ``last_restart``: if ``valid_f1`` is lower than the F1 stored at the
        previous restart, the previous restart epoch is returned; otherwise
        the stored F1 is updated and the model weights are saved to
        ``<logdir>/<epoch>.pt``.

        :param valid_f1: validation F1 of this epoch
        :param epoch: current epoch
        :return: test epoch or 0
        """
        last_restart = self.scheduler.last_restart
        self.scheduler.step()  # scheduler.last_restart can be updated

        if epoch == self.scheduler.last_restart:
            if valid_f1 < self.f1_last_restart:
                return last_restart

            self.f1_last_restart = valid_f1
            # Save the bare module's weights so ``run(mode='test')`` can load
            # them; the model is only wrapped when several devices are used.
            if isinstance(self.model, nn.DataParallel):
                model = self.model.module
            else:
                model = self.model
            torch.save(model.state_dict(),
                       Path(self.writer.logdir, f'{epoch}.pt'))

        return 0
Пример #14
0
def main():
    """Train generator ``G`` (``UNet``) against discriminator ``D``.

    Reads the module-level ``device``, ``args``, ``dataloader`` and
    ``datasets``. Each batch first updates ``D`` on a real pair ``(A, B)``
    and a generated pair ``(G(B), B)``, then updates ``G`` with the
    adversarial term plus ``args.lambda_recon`` times the L1 distance to
    ``A``. After every epoch one test batch is rendered to
    ``images/<epoch>.png`` and both networks are checkpointed to
    ``args.save_path``.
    """
    # networks
    G = UNet().to(device)
    D = Discriminator().to(device)

    # network weight initialisation
    G.apply(weight_init)
    D.apply(weight_init)

    # load the pretrained model
    if args.reuse:
        assert os.path.isfile(args.save_path), '[!]Pretrained model not found'
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be restored onto the device used for this run.
        checkpoint = torch.load(args.save_path, map_location=device)
        G.load_state_dict(checkpoint['G'])
        D.load_state_dict(checkpoint['D'])
        print('[*]Pretrained model loaded')

    # optimizer
    G_optim = optim.Adam(G.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    D_optim = optim.Adam(D.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    # save_image does not create missing directories
    os.makedirs('images', exist_ok=True)

    for epoch in range(args.num_epoch):
        for i, imgs in enumerate(dataloader['train']):
            A = imgs['A'].to(device)
            B = imgs['B'].to(device)

            # # # # #
            # Discriminator
            # # # # #
            G.eval()
            D.train()

            # Only D is updated in this step, so no graph through G is
            # needed; any G gradients would be zeroed before the G step.
            with torch.no_grad():
                fake = G(B)
            D_fake = D(fake, B)
            D_real = D(A, B)

            # original loss D
            loss_D = -((D_real.log() + (1 - D_fake).log()).mean())

            #            # LSGAN loss D
            #            loss_D = ((D_real - 1)**2).mean() + (D_fake**2).mean()

            D_optim.zero_grad()
            loss_D.backward()
            D_optim.step()

            # # # # #
            # Generator
            # # # # #
            G.train()
            D.eval()

            fake = G(B)
            D_fake = D(fake, B)

            # original loss G
            # NOTE(review): this takes the log of the mean, whereas loss_D
            # takes the mean of the logs -- confirm which is intended.
            loss_G = -(D_fake.mean().log()
                       ) + args.lambda_recon * torch.abs(A - fake).mean()

            #            # LSGAN loss G
            #            loss_G = ((D_fake-1)**2).mean() + args.lambda_recon * torch.abs(A - fake).mean()

            G_optim.zero_grad()
            loss_G.backward()
            G_optim.step()

            # print training progress
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, args.num_epoch, i * args.batch_size,
                   len(datasets['train']), loss_D.item(), loss_G.item()))

        # save images (once per epoch)
        val = next(iter(dataloader['test']))
        real_A = val['A'].to(device)
        real_B = val['B'].to(device)

        with torch.no_grad():
            fake_A = G(real_B)
        save_image(torch.cat([real_A, real_B, fake_A], dim=3),
                   'images/{0:03d}.png'.format(epoch + 1),
                   nrow=2,
                   normalize=True)

        # save the model
        torch.save({
            'G': G.state_dict(),
            'D': D.state_dict(),
        }, args.save_path)
Пример #15
0
                      buffer_size=16,
                      shuffle=True,
                      parallel_method='thread')

# Reader over the validation file list (`val_list`) using the evaluation
# transforms; shuffling is disabled here.
eval_reader = Reader(data_dir=data_dir,
                     file_list=val_list,
                     label_list=label_list,
                     transforms=eval_transforms,
                     num_workers=8,
                     buffer_size=16,
                     shuffle=False,
                     parallel_method='thread')

# Two-class UNet with both the BCE-loss and dice-loss flags switched on.
model = UNet(num_classes=2,
             input_channel=channel,
             use_bce_loss=True,
             use_dice_loss=True)

# Start training. `train_reader` is defined above this block.
# NOTE(review): pretrain_weights=None and optimizer=None presumably select
# the library defaults -- confirm against the UNet.train documentation.
model.train(
    num_epochs=num_epochs,
    train_reader=train_reader,
    train_batch_size=train_batch_size,
    eval_reader=eval_reader,
    save_interval_epochs=5,
    log_interval_steps=10,
    save_dir=save_dir,
    pretrain_weights=None,
    optimizer=None,
    learning_rate=lr,
)