Пример #1
0
def train(device, model_path, dataset_path, epochs=10):
    """
    Train a grayscale-to-color UNet on the images under ``dataset_path``.

    If ``model_path`` already exists its weights are loaded first, so training
    resumes from the previous run. The state dict is written back to
    ``model_path`` after every epoch.

    Args:
        device: torch device (or device string) to train on.
        model_path: checkpoint path to resume from and to save to.
        dataset_path: root directory passed to ``GrayColorDataset``.
        epochs: number of passes over the dataset (default 10, the value that
            used to be hard-coded).
    """
    network = UNet(1, 3).to(device)
    optimizer = torch.optim.Adam(network.parameters())
    criteria = torch.nn.MSELoss()

    dataset = GrayColorDataset(dataset_path, transform=train_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=16,
                                         shuffle=True,
                                         num_workers=cpu_count())

    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU be resumed on the
        # CPU (or on another GPU) instead of failing during deserialization.
        network.load_state_dict(torch.load(model_path, map_location=device))
    for _ in tqdm.trange(epochs, desc="Epoch"):
        network.train()
        for gray, color in tqdm.tqdm(loader, desc="Training", leave=False):
            gray, color = gray.to(device), color.to(device)
            optimizer.zero_grad()
            pred_color = network(gray)
            loss = criteria(pred_color, color)
            loss.backward()
            optimizer.step()
        # Checkpoint after every epoch so an interrupted run can resume.
        torch.save(network.state_dict(), model_path)
Пример #2
0
def train(device, gen_model, disc_model, real_dataset_path, epochs):
    """
    Train a UNet generator against a critic with the WGAN-GP objective.

    Each outer iteration performs ``n_critic`` critic updates followed by one
    generator update. Both state dicts are saved after every outer iteration.

    Args:
        device: torch device to train on.
        gen_model: generator checkpoint path (loaded if present, always written).
        disc_model: critic checkpoint path (loaded if present, always written).
        real_dataset_path: root directory passed to ``ColorDataset``.
        epochs: number of generator updates (outer iterations), not full
            passes over the dataset.
    """
    train_transform = tv.transforms.Compose([
        tv.transforms.Resize((224, 224)),
        tv.transforms.RandomHorizontalFlip(0.5),
        tv.transforms.RandomVerticalFlip(0.5),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, ), (0.5, ))
    ])

    realdataset = ColorDataset(real_dataset_path, transform=train_transform)
    realloader = torch.utils.data.DataLoader(realdataset,
                                             batch_size=20,
                                             shuffle=True,
                                             num_workers=cpu_count(),
                                             drop_last=True)
    realiter = iter(realloader)

    discriminator = discriminator_model(3, 1024).to(device)
    disc_optimizer = torch.optim.Adam(discriminator.parameters(),
                                      lr=0.0001,
                                      betas=(0, 0.9))
    if os.path.exists(disc_model):
        discriminator.load_state_dict(
            torch.load(disc_model, map_location=device))

    generator = UNet(1, 3).to(device)
    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=0.0001,
                                     betas=(0, 0.9))
    if os.path.exists(gen_model):
        generator.load_state_dict(torch.load(gen_model, map_location=device))

    n_critic = 5
    lam = 10
    for _ in tqdm.trange(epochs, desc="Epochs"):
        # ---- Critic updates ----
        for param in discriminator.parameters():
            param.requires_grad = True

        for _ in range(n_critic):
            real_data, realiter = try_iter(realiter, realloader)
            real_data = real_data.to(device)
            batch_size = real_data.shape[0]

            disc_optimizer.zero_grad()

            # Critic maximises E[D(real)]: minimise its negation.
            real_cost = torch.mean(discriminator(real_data))
            (-real_cost).backward()

            # Generated images are produced without autograd: the critic
            # step must neither build nor accumulate gradients into the
            # generator graph.
            noise = torch.randn(batch_size, 1, 224, 224).to(device)
            with torch.no_grad():
                fake_images = generator(noise)
            fake_cost = torch.mean(discriminator(fake_images))
            fake_cost.backward()

            # The penalty interpolates between real and *generated* images
            # (same shape); the previous code passed the 1-channel noise.
            gradient_penalty = calc_gp(device, discriminator, real_data,
                                       fake_images, lam)
            gradient_penalty.backward()

            disc_optimizer.step()

        # ---- Generator update ----
        # Freeze the critic so only generator gradients are computed.
        for param in discriminator.parameters():
            param.requires_grad = False
        gen_optimizer.zero_grad()

        noise = torch.randn(batch_size, 1, 224, 224).to(device)
        gen_cost = discriminator(generator(noise)).mean()
        # Generator maximises E[D(G(z))]: minimise its negation.
        (-gen_cost).backward()
        gen_optimizer.step()

        torch.save(generator.state_dict(), gen_model)
        torch.save(discriminator.state_dict(), disc_model)
Пример #3
0
def train():
    """
    Train the spectral U-Net for ``n_epochs`` epochs.

    Uses module-level settings (``batch_size``, ``input_dim``, ``label_dim``,
    ``device``, ``lr``, ``step_size``, ``n_epochs``, ``display_step``,
    ``criterion``). Every second epoch a checkpoint is written to
    ``./checkpoints/checkpoint_<epoch>.pth``; the learning rate is multiplied
    by 0.4 every ``step_size`` epochs via ``StepLR``.
    """
    transform = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    dataset = SpectralDataSet(
        root_dir=
        '/mnt/liguanlin/DataSets/lowlight_hyperspectral_datasets/band_splited_dataset',
        type_name='train',
        transform=transform)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    unet = UNet(input_dim, label_dim).to(device)
    unet.init_weight()
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(unet_opt, step_size, gamma=0.4)
    cur_step = 0

    for epoch in range(n_epochs):
        train_l_sum, batch_count = 0.0, 0

        for real, labels in tqdm(dataloader):
            real = real.to(device)
            labels = labels.to(device)

            # Update U-Net
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()

            train_l_sum += unet_loss.item()
            batch_count += 1

            if cur_step % display_step == 0:
                print(
                    f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}"
                )
            cur_step += 1

        if (epoch + 1) % 2 == 0:
            torch.save(unet.state_dict(),
                       './checkpoints/checkpoint_{}.pth'.format(epoch + 1))

        # Advance the learning-rate schedule once per epoch. The previous code
        # called unet_opt.step() here, which re-applied the last batch's
        # gradients and never decayed the learning rate.
        scheduler.step()

        # max(..., 1) avoids ZeroDivisionError for an empty dataset.
        print('epoch %d, train loss %.4f' %
              (epoch + 1, train_l_sum / max(batch_count, 1)))
Пример #4
0
    # Training branch: for each epoch, call train() and then test().
    # loss_list is passed into train(), which presumably appends loss values
    # to it — confirm against train().
    loss_list = []
    for i in tqdm(range(args.epochs)):
        train(i, exp_lr_scheduler, loss_list)
        test()

    # Plot the collected losses and save the figure under plots/, with the
    # run hyper-parameters encoded in the filename.
    plt.plot(loss_list)
    plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size, args.epochs,
                                                args.lr))
    plt.xlabel("Number of iterations")
    plt.ylabel("Average DICE loss per batch")
    plt.savefig("plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
        args.save, args.batch_size, args.epochs, args.lr))

    # Persist the raw loss values as a .npy array next to the plot.
    np.save(
        'npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
            args.save, args.batch_size, args.epochs, args.lr),
        np.asarray(loss_list))

    # Save the final weights; the filename is prefixed with SAVE_MODEL_NAME
    # and carries no file extension.
    torch.save(
        model.state_dict(),
        '{}unetsmall-final-{}-{}-{}'.format(SAVE_MODEL_NAME, args.batch_size,
                                            args.epochs, args.lr))

# elif args.pred:
#     predict()

# Inference-only branch: restore weights from --load and run predict().
elif args.load is not None:
    model.load_state_dict(torch.load(args.load))
    #test()
    predict()
Пример #5
0
def _build_parser():
    """Build the command-line parser for the BraTS UNet train/test script."""
    parser = argparse.ArgumentParser(
        description='UNet + BDCLSTM for BraTS Dataset')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--train',
                        action='store_true',
                        default=False,
                        help='Argument to train model (default: False)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training (default: False)')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1,
                        metavar='N',
                        help='batches to wait before logging training status')
    parser.add_argument('--size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='imsize')
    parser.add_argument('--load',
                        type=str,
                        default=None,
                        metavar='str',
                        help='weight file to load (default: None)')
    parser.add_argument('--data',
                        type=str,
                        default='./Data/',
                        metavar='str',
                        help='folder that contains data')
    parser.add_argument('--save',
                        type=str,
                        default='OutMasks',
                        metavar='str',
                        help='Identifier to save npy arrays with')
    parser.add_argument('--modality',
                        type=str,
                        default='flair',
                        metavar='str',
                        help='Modality to use for training (default: flair)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='SGD',
                        choices=['SGD', 'ADAM'],
                        metavar='str',
                        help='Optimizer (default: SGD)')
    # Adam betas: previously read as args.beta1/args.beta2 but never defined,
    # so --optimizer ADAM crashed with AttributeError.
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        metavar='B1',
                        help='Adam beta1 (default: 0.9)')
    parser.add_argument('--beta2',
                        type=float,
                        default=0.999,
                        metavar='B2',
                        help='Adam beta2 (default: 0.999)')
    return parser


def start():
    """
    Entry point: parse CLI arguments, then either train and evaluate a 3-class
    UNet (``--train``) or run inference with weights from ``--load``.

    Training holds out 10% of the dataset for validation, saves the loss curve
    (PNG under ./plots/, NPY under ./npy-files/loss-files/) and the final
    weights, then runs ``test_only`` on the images in ./pdf_data/.
    """
    args = _build_parser().parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()

    DATA_FOLDER = args.data

    # Multiclass model (use num_classes=2 for the binary variant).
    model = UNet(num_channels=1, num_classes=3)

    if args.cuda:
        model.cuda()

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.99)
    elif args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    else:
        # Unreachable from the CLI (argparse enforces choices); fail loudly
        # rather than with a NameError on `optimizer` later.
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    # Defining Loss Function
    criterion = DICELossMultiClass()

    if args.train:
        # Load the dataset and split it 90/10 into train/validation.
        full_dataset = BraTSDatasetUnet(DATA_FOLDER,
                                        im_size=[args.size, args.size],
                                        transform=tr.ToTensor())

        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, validation_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size])

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        validation_loader = DataLoader(validation_dataset,
                                       batch_size=args.test_batch_size,
                                       shuffle=False,
                                       num_workers=1)

        print("Training Data : ", len(train_loader.dataset))
        print("Validation Data : ", len(validation_loader.dataset))

        loss_list = []
        # Named start_time so it does not shadow this function's name.
        start_time = timer()
        for i in tqdm(range(args.epochs)):
            train(model, i, loss_list, train_loader, optimizer, criterion,
                  args)
            test(model, validation_loader, criterion, args, validation=True)
        end_time = timer()
        print("Training completed in {:0.2f}s".format(end_time - start_time))

        plt.plot(loss_list)
        plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
                                                    args.epochs, args.lr))
        plt.xlabel("Number of iterations")
        plt.ylabel("Average DICE loss per batch")
        plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
            args.save, args.batch_size, args.epochs, args.lr))

        np.save(
            './npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
                args.save, args.batch_size, args.epochs, args.lr),
            np.asarray(loss_list))
        print("Testing Validation")
        test(model, validation_loader, criterion, args, save_output=True)
        torch.save(
            model.state_dict(),
            'unet-multiclass-model-{}-{}-{}'.format(args.batch_size,
                                                    args.epochs, args.lr))

        print("Testing PDF images")
        test_dataset = TestDataset('./pdf_data/',
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        test_only(model, test_loader, criterion, args)

    elif args.load is not None:
        test_dataset = TestDataset(DATA_FOLDER,
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        # Map GPU-saved weights onto the CPU when CUDA is not in use.
        model.load_state_dict(
            torch.load(args.load, map_location=None if args.cuda else 'cpu'))
        test_only(model, test_loader, criterion, args)
Пример #6
0
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``module.`` removed from each key.

    ``nn.DataParallel`` saves parameters as ``module.<name>``. Stripping the
    prefix lets such a checkpoint load into the bare network. Keys without
    the prefix are kept as-is; the previous code sliced 7 characters off
    every key, which corrupted checkpoints saved without DataParallel.
    """
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v
    return new_state_dict


def train():
    """
    Train the single-channel, binary-mask UNet configured by the global ``opt``.

    Optionally resumes network, optimizer and epoch from
    ``opt.load_model_path``. Each epoch does the following:

    * trains on the training split (BCE, plus Dice loss when
      ``opt.use_dice_loss``);
    * writes ``<checkpoint_root><epoch>_unet.pth``;
    * plots the Dice loss of thresholded predictions on the validation split;
    * every 10th epoch, plots 2-D recall on the test split via ``getRecall``.

    The checkpoint's ``epoch`` entry is the number of completed epochs, so a
    resumed run continues with the next epoch.
    """
    if opt.use_gpu:
        t.cuda.set_device(1)

    # n_channels: medical images are single-channel grayscale; n_classes: binary
    net = UNet(n_channels=1, n_classes=1)
    if opt.use_gpu:
        # Parameters must already be on the GPU when the optimizer state is
        # restored: Optimizer.load_state_dict casts buffers (e.g. momentum)
        # to each parameter's current device.
        net.cuda()
    optimizer = t.optim.SGD(net.parameters(),
                            lr=opt.learning_rate,
                            momentum=0.9,
                            weight_decay=0.0005)
    # Binary cross-entropy (suits masks that cover a large part of the image)
    criterion = t.nn.BCELoss()

    start_epoch = 0
    if opt.load_model_path:
        # Load on the CPU; load_state_dict copies values onto the parameters'
        # device, so GPU-saved checkpoints also load on CPU-only machines.
        checkpoint = t.load(opt.load_model_path, map_location='cpu')
        # Multi-GPU (DataParallel) weights are loaded into the single model.
        net.load_state_dict(_strip_module_prefix(checkpoint['net']))
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    # Decay the learning rate by 10x at each milestone epoch.
    # NOTE(review): the scheduler state itself is not checkpointed; verify the
    # LR after resuming matches an uninterrupted run.
    if start_epoch == 0:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=-1)  # default is -1
        print('从头训练 ,学习率为{}'.format(optimizer.param_groups[0]['lr']))
    else:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=start_epoch)
        print('加载预训练模型{}并从{}轮开始训练,学习率为{}'.format(
            opt.load_model_path, start_epoch, optimizer.param_groups[0]['lr']))

    # Wrap for multi-GPU data parallelism
    if opt.use_gpu:
        net = t.nn.DataParallel(net, device_ids=opt.device_ids)
        net.cuda()
        cudnn.benchmark = True

    # Visualization handle
    vis = Visualizer(opt.env)

    train_data = NodeDataSet(train=True)
    val_data = NodeDataSet(val=True)
    test_data = NodeDataSet(test=True)

    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    test_dataloader = DataLoader(test_data,
                                 opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)

    # Absolute epoch indices, so resumed runs save and name checkpoints
    # consistently instead of restarting the count at 0.
    for epoch in range(start_epoch, opt.max_epoch):
        print('开始 epoch {}/{}.'.format(epoch + 1, opt.max_epoch))
        net.train()
        epoch_loss = 0

        # Check once per epoch whether the learning rate should change.
        scheduler.step()

        # ============ training ============
        for ii, (img, mask) in enumerate(train_dataloader):
            # Defined on both paths; previously undefined when use_gpu was False.
            true_masks = mask
            if opt.use_gpu:
                img = img.cuda()
                true_masks = true_masks.cuda()
            masks_pred = net(img)

            masks_probs = t.sigmoid(masks_pred)

            # loss = binary cross-entropy (+ optional Dice loss)
            loss = criterion(masks_probs.view(-1), true_masks.view(-1))
            if opt.use_dice_loss:
                loss += dice_loss(masks_probs, true_masks)

            epoch_loss += loss.item()

            if ii % 2 == 0:
                vis.plot('训练集loss', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        vis.log("epoch:{epoch},lr:{lr},loss:{loss}".format(
            epoch=epoch, loss=loss.item(), lr=optimizer.param_groups[0]['lr']))

        # ii is 0-based, so the batch count is ii + 1.
        vis.plot('每轮epoch的loss均值', epoch_loss / (ii + 1))

        # Save model, optimizer and number of completed epochs.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1
        }
        t.save(state, opt.checkpoint_root + '{}_unet.pth'.format(epoch))

        # ============ validation (Dice on 0.5-thresholded predictions) ============
        net.eval()
        tot = 0
        with t.no_grad():
            for jj, (img_val, mask_val) in enumerate(val_dataloader):
                true_mask_val = mask_val
                if opt.use_gpu:
                    img_val = img_val.cuda()
                    true_mask_val = true_mask_val.cuda()

                mask_pred = net(img_val)
                mask_pred = (t.sigmoid(mask_pred) > 0.5).float()
                tot += dice_loss(mask_pred, true_mask_val).item()
        val_dice = tot / (jj + 1)
        vis.plot('验证集 Dice损失', val_dice)

        # ============ test-set recall, every 10 epochs ============
        if epoch % 10 == 0:
            result_test = []
            with t.no_grad():
                # Only segmentation output is scored here; the ground-truth
                # mask from the loader is not used.
                for img_test, _ in test_dataloader:
                    if opt.use_gpu:
                        img_test = img_test.cuda()
                    mask_pred_test = net(img_test)

                    probs = t.sigmoid(mask_pred_test).squeeze().cpu().numpy()
                    result_test.append(probs > opt.out_threshold)

            vis.plot('测试集二维召回率', getRecall(result_test).getResult())
Пример #7
0
def train_unet(epoch=100):
    """
    Train the two-headed (egg / pan mask) UNet and keep the best checkpoints.

    Images in ``dataset/train/images/`` (.jpg/.JPG/.png) are shuffled and split
    90/10 into training and validation sets. After every epoch the model is
    saved under ``checkpoints/`` if validation IoU improved, and the learning
    rate is reduced when validation IoU plateaus.

    Args:
        epoch: number of epochs to train (default 100).
    """
    # Collect the image files of the training set
    image_names = [name for name in os.listdir('dataset/train/images/')
                   if name.endswith(('.jpg', '.JPG', '.png'))]

    # Random 90/10 train/validation split
    np.random.shuffle(image_names)
    split = int(len(image_names) * 0.9)
    train_image_names = image_names[:split]
    val_image_names = image_names[split:]

    train_dataset = EggsPansDataset('dataset/train', train_image_names, mode='train')
    val_dataset = EggsPansDataset('dataset/train', val_image_names, mode='val')

    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet()
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', verbose=True)
    loss_obj = EggsPansLoss()
    metrics_obj = EggsPansMetricIoU()

    best_iou = 0.0
    # Create the checkpoint directory up front; otherwise the first save
    # fails only after a full epoch of training.
    os.makedirs('checkpoints/', exist_ok=True)

    for epoch_idx in range(epoch):

        print('Epoch: {:2}/{}'.format(epoch_idx + 1, epoch))
        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # ---- Train phase ----
        model.train()
        pbar = tqdm(train_dataloader)
        for imgs, egg_masks, pan_masks in pbar:
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            optim.zero_grad()

            pred_egg_masks, pred_pan_masks = model(imgs)

            # Both calls also update the running statistics shown below.
            loss = loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
            metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            loss.backward()
            optim.step()

            pbar.set_description('Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                      metrics_obj.get_running_iou()))

        print('Validation: ')

        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # ---- Validation phase ----
        model.eval()
        pbar = tqdm(val_dataloader)
        for imgs, egg_masks, pan_masks in pbar:
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            with torch.no_grad():
                pred_egg_masks, pred_pan_masks = model(imgs)
                # Called for their effect on the running loss / IoU.
                loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
                metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            pbar.set_description('Val Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                          metrics_obj.get_running_iou()))

        val_iou = metrics_obj.get_running_iou()

        # Save a checkpoint whenever validation IoU improves
        if best_iou < val_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), os.path.join('checkpoints/', 'epoch_{}_{:.4f}.pth'.format(
                epoch_idx + 1, val_iou)))

        # Reduce learning rate on plateau (mode='max': higher IoU is better)
        lr_scheduler.step(val_iou)

        print('\n')
        print('-'*100)
Пример #8
0
                              shuffle=True,
                              drop_last=True)
    # Evaluation loader. With shuffle=True and drop_last=True, a final partial
    # batch of evalset is skipped on every pass.
    eval_loader = DataLoader(evalset,
                             batch_size=batch,
                             num_workers=1,
                             shuffle=True,
                             drop_last=True)

    # Single-channel input, single-channel output UNet.
    model = UNet(n_channels=1, n_classes=1)
    model.to(device=device)

    # MSE is the active loss; the other criteria are left commented out.
    # criterion = nn.CrossEntropyLoss()
    criterion = nn.MSELoss()
    # criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Loss histories; presumably filled by train()/evaluate() through the
    # enclosing scope — confirm against those functions.
    training_loss = []
    eval_loss = []

    for epoch in range(num_epoch):
        train()
        evaluate()
        plot_losses()
        # Every 10th epoch (including epoch 0) keeps a numbered snapshot;
        # all other epochs overwrite a single rolling checkpoint file.
        if (epoch % 10) == 0:
            torch.save(model.state_dict(),
                       os.path.join(models_path, f"unet_{attempt}_{epoch}.pt"))
        else:
            torch.save(model.state_dict(),
                       os.path.join(models_path, f"unet_{attempt}.pt"))
    print("Done!")
Пример #9
0
class Trainer:
    """Batch-wise trainer for a single-channel UNet segmentation model.

    Images are loaded with ``Loader``. In every epoch the first ``train_size``
    batches are used for training and the remaining ones for validation. The
    loss is ``BCELoss + 0.1 * (1 - IoU)``.
    """

    @classmethod
    def intersection_over_union(cls, y, z):
        """Return the soft IoU ``sum(min(y, z)) / sum(max(y, z))``.

        For 0/1 masks this is intersection over union. The result is NaN
        when both inputs are all zeros.
        """
        return torch.sum(torch.min(y, z)) / torch.sum(torch.max(y, z))

    @classmethod
    def get_number_of_batches(cls, image_paths, batch_size):
        """Return how many batches of ``batch_size`` cover ``image_paths`` (ceiling division)."""
        # Integer ceil-division; avoids float rounding of len/batch_size.
        return -(-len(image_paths) // batch_size)

    @classmethod
    def evaluate_loss(cls, criterion, output, target):
        """Return ``criterion(output, target) + 0.1 * (1 - IoU(output, target))``."""
        loss_1 = criterion(output, target)
        loss_2 = 1 - Trainer.intersection_over_union(output, target)
        return loss_1 + 0.1 * loss_2

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or NaN when it is empty."""
        return sum(values) / len(values) if values else float('nan')

    @classmethod
    def _batch_ious(cls, output, target):
        """Print and return the IoU of each binarized prediction against its target mask."""
        ious = []
        for i in range(output.shape[0]):
            binary_mask = Editor.make_binary_mask_from_torch(
                output[i, :, :, :], 1.0)
            iou = cls.intersection_over_union(
                binary_mask, target[i, :, :, :].cpu()).item()
            print("IoU:", iou)
            ious.append(iou)
        return ious

    def __init__(self, side_length, batch_size, epochs, learning_rate,
                 momentum_parameter, seed, image_paths, state_dict,
                 train_val_split):
        """
        Args:
            side_length: image side length passed to ``Loader``.
            batch_size: images per batch.
            epochs: number of training epochs.
            learning_rate: Adam learning rate.
            momentum_parameter: stored; not used by the code in this class.
            seed: forwarded to ``Loader.get_batch`` and ``np.random.seed``.
            image_paths: glob pattern expanded to the list of image files.
            state_dict: filename (under ``weights/``) for the final weights.
            train_val_split: fraction of batches used for training.
        """
        self.side_length = side_length
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.momentum_parameter = momentum_parameter
        self.seed = seed
        self.image_paths = glob.glob(image_paths)
        self.batches = Trainer.get_number_of_batches(self.image_paths,
                                                     self.batch_size)
        self.model = UNet()
        self.loader = Loader(self.side_length)
        self.state_dict = state_dict
        self.train_val_split = train_val_split
        self.train_size = int(np.floor((self.train_val_split * self.batches)))

    def set_cuda(self):
        """Move the model to the GPU when CUDA is available."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def set_seed(self):
        """Seed NumPy's global RNG with ``self.seed`` (if one was given)."""
        if self.seed is not None:
            np.random.seed(self.seed)

    def process_batch(self, batch):
        """Run batch number ``batch`` through the model and return ``(output, target)``.

        ``target`` is the mask clamped to [0, 1] so it is a valid BCELoss target.
        """
        # Batch shuffled according to self.seed; samples[i][0] is the i-th
        # image and samples[i][1] its mask.
        samples = Loader.get_batch(self.image_paths, self.batch_size, batch,
                                   self.seed)
        # Convert explicitly to float32 (the previous `samples.astype(float)`
        # discarded its result).
        samples = torch.from_numpy(np.asarray(samples, dtype=np.float32))

        if torch.cuda.is_available():
            samples = samples.cuda()

        # Split images from masks and add the channel dimension: (B, ...) -> (B, 1, ...)
        samples_images = samples[:, 0].unsqueeze(1)
        samples_masks = samples[:, 1].unsqueeze(1)

        output = self.model(samples_images)

        target = torch.clamp(samples_masks, min=0, max=1)

        del samples

        return output, target

    def train_model(self):
        """Train for ``self.epochs`` epochs, validating after each, then save the weights.

        Prints per-sample IoU and per-batch losses, tracks the lowest training
        loss, and writes the final state dict to ``weights/<self.state_dict>``.
        """
        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        iteration = 0
        best_iteration = 0
        best_loss = 10**10

        losses_train = []
        losses_val = []
        average_iou_train = []
        average_iou_val = []

        print("BEGIN TRAINING")
        print("TRAINING BATCHES:", self.train_size)
        print("VALIDATION BATCHES:", self.batches - self.train_size)
        print("BATCH SIZE:", self.batch_size)
        print("EPOCHS:", self.epochs)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

        for k in range(self.epochs):
            print("EPOCH:", k + 1)
            print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            # ---- Train ----
            self.model.train()
            # Reset each epoch so the reported average covers this epoch only
            # (previously IoUs accumulated across all epochs).
            iou_train = []
            for batch in range(self.train_size):
                iteration += 1
                output, target = self.process_batch(batch)
                loss = Trainer.evaluate_loss(criterion, output, target)
                # Current epoch (previously printed the total epoch count).
                print("EPOCH:", k + 1)
                print("Batch", batch, "of", self.train_size)

                iou_train.extend(self._batch_ious(output, target))

                # Drop references early to limit memory use; the autograd
                # graph held by `loss` keeps what backward() needs.
                del target
                del output

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                if best_loss > loss_value:
                    best_loss = loss_value
                    best_iteration = iteration
                losses_train.append(loss_value)

                if batch == self.train_size - 1:
                    print("LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            average_iou = self._mean(iou_train)
            print("Average IoU:", average_iou)
            average_iou_train.append(average_iou)

            # ---- Validate: eval mode, no autograd graph ----
            self.model.eval()
            iou_val = []
            with torch.no_grad():
                for batch in range(self.train_size, self.batches):
                    output, target = self.process_batch(batch)
                    loss = Trainer.evaluate_loss(criterion, output, target)

                    iou_val.extend(self._batch_ious(output, target))

                    loss_value = loss.item()
                    losses_val.append(loss_value)
                    print("EPOCH:", k + 1)
                    print("VALIDATION LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")
                    del output
                    del target

            average_iou = self._mean(iou_val)
            print("Average IoU:", average_iou)
            average_iou_val.append(average_iou)

        print("Least loss", best_loss, "at iteration", best_iteration)

        torch.save(self.model.state_dict(), "weights/" + self.state_dict)
Пример #10
0
        print("path exists")

    # Per-fold checkpoint path (model_root is presumably a pathlib.Path —
    # confirm against the caller).
    model_path = model_root / 'model_{fold}.pt'.format(fold=fold)

    # Resume model weights, epoch and step if a checkpoint for this fold exists.
    if model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        # Fresh start: the epoch counter starts at 1 here.
        epoch = 1
        step = 0

    # Checkpoint helper. `step` is looked up when the lambda is called (late
    # binding), so it saves whatever value `step` has at that moment.
    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
    }, str(model_path))

    report_each = 10
    valid_each = 4
    # Append-mode UTF-8 log for this fold; it is not closed within this view.
    log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8')
    valid_losses = []

    # Loss selection: MultiDiceLoss when add_log == False, otherwise LossMulti
    # with jaccard_weight=0.5. Note add_log=None takes the else branch.
    if(add_log == False):
        criterion = MultiDiceLoss(num_classes=11)
    else:
        criterion = LossMulti(num_classes=11, jaccard_weight=0.5)
    class_color_table = read_json(json_file_name)
    first_time = True
Пример #11
0
    # 4-channel input, single-output UNet with bilinear upsampling.
    net = UNet(n_channels=4, n_classes=1, bilinear=True)
    # In test mode, load weights from --load_checkpoint, mapped to the CPU
    # first and moved to `device` below.
    if args.test:
        net.load_state_dict(
            torch.load(args.load_checkpoint, map_location='cpu'))
    net.to(device=device)
    try:
        if not args.test:
            train_net(net=net,
                      args=args,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      lr=args.lr,
                      device=device)
        else:
            # Evaluate one sample at a time on BasicDataset(args), then run
            # skeleton().
            test = BasicDataset(args)
            test_loader = DataLoader(test,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True,
                                     drop_last=True)
            eval_net(args, net, test_loader, device)
            skeleton()

    except KeyboardInterrupt:
        # On Ctrl-C, keep the current weights before exiting.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        # sys.exit raises SystemExit, which is caught immediately, so this
        # always ends in os._exit(0): a hard exit that skips cleanup handlers
        # and does not flush stdio buffers.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Пример #12
0
# Script entry: either train (then save loss plot/array and final weights) or
# load weights from --load and evaluate.
if args.train:
    # loss_list is passed into train(), which presumably appends loss values
    # to it — confirm against train().
    loss_list = []
    for i in tqdm(range(args.epochs)):
        train(i, loss_list)
        test()

    # Save the loss curve under ./plots/ with hyper-parameters in the filename.
    plt.plot(loss_list)
    plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
                                                args.epochs, args.lr))
    plt.xlabel("Number of iterations")
    plt.ylabel("Average DICE loss per batch")
    plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(args.save,
                                                                    args.batch_size,
                                                                    args.epochs,
                                                                    args.lr))

    # Persist the raw loss values as a .npy array.
    np.save('./npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(args.save,
                                                                               args.batch_size,
                                                                               args.epochs,
                                                                               args.lr),
            np.asarray(loss_list))

    # Final weights; the filename has no extension.
    torch.save(model.state_dict(), 'unet-final-{}-{}-{}'.format(args.batch_size,
                                                                args.epochs,
                                                                args.lr))
else:
    # Evaluation only: requires --load to point at saved weights.
    model.load_state_dict(torch.load(args.load))
    test(save_output=True)
    test(train_accuracy=True)
Пример #13
0
def main():
    """Train a conditional GAN: generator ``G`` maps ``B`` to ``A``, and
    discriminator ``D`` scores ``(image, B)`` pairs.

    Relies on names defined elsewhere in the module: ``args``, ``device``,
    ``dataloader``, ``datasets``, ``UNet``, ``Discriminator``, ``weight_init``
    and ``save_image``.

    Each iteration first updates ``D`` on real ``(A, B)`` versus generated
    ``(G(B), B)``. It then updates ``G`` with the non-saturating adversarial
    loss plus an L1 reconstruction term weighted by ``args.lambda_recon``.
    After every epoch, a sample from ``dataloader['test']`` is written to
    ``images/<epoch>.png``, and both state dicts are saved to ``args.save_path``.
    """
    # Keeps log() finite if D's output reaches exactly 0 or 1
    # (the log(D) / log(1 - D) losses below assume D outputs values in [0, 1]).
    eps = 1e-8

    # Networks
    G = UNet().to(device)
    D = Discriminator().to(device)

    # Network weight initialization
    G.apply(weight_init)
    D.apply(weight_init)

    # Optionally resume from a previously saved checkpoint
    if args.reuse:
        assert os.path.isfile(args.save_path), '[!]Pretrained model not found'
        checkpoint = torch.load(args.save_path)
        G.load_state_dict(checkpoint['G'])
        D.load_state_dict(checkpoint['D'])
        print('[*]Pretrained model loaded')

    # Optimizers
    G_optim = optim.Adam(G.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    D_optim = optim.Adam(D.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    # save_image below writes into images/; make sure it exists.
    os.makedirs('images', exist_ok=True)

    for epoch in range(args.num_epoch):
        for i, imgs in enumerate(dataloader['train']):
            A = imgs['A'].to(device)
            B = imgs['B'].to(device)

            # # # # #
            # Discriminator
            # # # # #
            G.eval()
            D.train()

            # D's update must not back-propagate into G: generate without a graph.
            with torch.no_grad():
                fake = G(B)
            D_fake = D(fake, B)
            D_real = D(A, B)

            # original loss D
            loss_D = -((D_real + eps).log() + (1 - D_fake + eps).log()).mean()

            #            # LSGAN loss D
            #            loss_D = ((D_real - 1)**2).mean() + (D_fake**2).mean()

            D_optim.zero_grad()
            loss_D.backward()
            D_optim.step()

            # # # # #
            # Generator
            # # # # #
            G.train()
            D.eval()

            fake = G(B)
            D_fake = D(fake, B)

            # Non-saturating generator loss -E[log D(G(B))] (mean of the log,
            # matching loss_D above) plus weighted L1 reconstruction.
            loss_G = -(D_fake + eps).log().mean() \
                + args.lambda_recon * torch.abs(A - fake).mean()

            #            # LSGAN loss G
            #            loss_G = ((D_fake-1)**2).mean() + args.lambda_recon * torch.abs(A - fake).mean()

            G_optim.zero_grad()
            loss_G.backward()
            G_optim.step()

            # Print training progress
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, args.num_epoch, i * args.batch_size,
                   len(datasets['train']), loss_D.item(), loss_G.item()))

        # Save a sample image grid (once per epoch). G is still in train()
        # mode here, as in the generator step above.
        val = next(iter(dataloader['test']))
        real_A = val['A'].to(device)
        real_B = val['B'].to(device)

        with torch.no_grad():
            fake_A = G(real_B)
        save_image(torch.cat([real_A, real_B, fake_A], dim=3),
                   'images/{0:03d}.png'.format(epoch + 1),
                   nrow=2,
                   normalize=True)

        # Save both models
        torch.save({
            'G': G.state_dict(),
            'D': D.state_dict(),
        }, args.save_path)
Пример #14
0
        t_l = train()
        v_l, viou = validate()
        train_losses.append(t_l)
        valid_losses.append(v_l)
        valid_ious.append(viou)

        # write the losses to a text file
        #with open('../logs/losses_{}_{}_{}.txt'.format(args.model_name,
        #                                               args.exp_name,
        #                                               k), 'a') as logfile:
        #    logfile.write('{},{},{},{}'.format(e, t_l, v_l, v_a) + "\n")

        # save the model everytime we get a new best valid loss
        if v_l < best_val_loss:
            torch.save(net.state_dict(), MODEL_CKPT)
            best_val_loss = v_l

        # if the validation loss gets worse increment 1 to the patience values
        #if v_l > best_val_loss:
        #     valid_patience += 1
        #     lr_patience += 1

        # if the model stops improving by a certain number epochs, stop
        #if valid_patience == args.es_patience:
        #     break
        if e in LR_SCHED:
            for params in optimizer.param_groups:
                params['lr'] = LR_SCHED[e]

        print('Time: {}'.format(time.time() - start))
Пример #15
0
def get_denoised_mat(mat,
                     model=None,
                     out_channels=[64, 64, 64],
                     ndim=3,
                     num_epochs=12,
                     num_iters=600,
                     print_every=300,
                     batch_size=2,
                     mask_prob=0.05,
                     frame_depth=6,
                     frame_weight=None,
                     movie_start_idx=250,
                     movie_end_idx=750,
                     save_folder='.',
                     save_intermediate_results=False,
                     normalize=True,
                     last_out_channels=None,
                     loss_reg_fn=nn.MSELoss(),
                     loss_history=None,
                     loss_threshold=0,
                     optimizer_fn=torch.optim.AdamW,
                     optimizer_fn_args=None,
                     lr_scheduler=None,
                     batch_size_eval=10,
                     kernel_size_unet=3,
                     features=None,
                     fps=60,
                     save_filename='denoised_mat',
                     window_size_row=None,
                     window_size_col=None,
                     weight=None,
                     verbose=False,
                     return_model=False,
                     half_precision=False,
                     device=torch.device('cuda')):
    """Run the Noise2Self pipeline: train the model, save checkpoints, and denoise ``mat``.

    Args:
        mat: torch.Tensor of shape (nframe, nrow, ncol).
        model: default None, construct a UNet internally (see ``out_channels``
            and ``ndim``).
        out_channels: per-level channel counts of the internally built UNet;
            its length is the encoder depth. Unused when ``model`` is given.
        ndim: default 3, construct a 3D UNet; if ndim == 2, construct a 2D UNet.
            Any other value raises ValueError when ``model`` is None.
        num_epochs: int, maximum number of epochs.
        num_iters: int, number of batches to train per epoch.
        loss_history: optional list that every per-iteration loss is appended
            to; a fresh list is created when None. Its length is also used to
            number saved checkpoints and videos.
        loss_threshold: if > 0, stop early once the mean loss of the latest
            epoch differs from the previous epoch's mean by less than this.
        optimizer_fn: optimizer class, called with the trainable parameters.
        optimizer_fn_args: keyword arguments for ``optimizer_fn``; defaults to
            {'lr': 1e-3, 'weight_decay': 1e-2}. The dict is copied, so the
            caller's dict is never modified.
        lr_scheduler: optional dict with keys 'lr_fn' and 'lr_fn_args'; when
            given, the optimizer is rebuilt at every epoch with
            lr = lr_fn(epoch, **lr_fn_args). Rebuilding resets optimizer state.
        features: default None, do not use "global features"; if not None,
            either provide features as a torch.Tensor or set features=True.
            If features is True, features = torch.stack([mat.mean(0), mat.std(0)], dim=0)
            is computed internally on the (possibly normalized) mat.
        save_filename: file name (without ``.npy``) for the final denoised
            array inside ``save_folder``; if None, a name containing the frame
            range and step count is used.
        return_model: if True, also return the trained model.

    Returns:
        denoised_mat: torch.Tensor, rescaled back to the input's mean/std,
        if return_model is False; otherwise the tuple (denoised_mat, model).

    Raises:
        ValueError: if ``model`` is None and ``ndim`` is neither 2 nor 3.
    """
    # Per-call containers. Shared mutable defaults would leak losses and
    # scheduler-set learning rates from one call into the next.
    if loss_history is None:
        loss_history = []
    if optimizer_fn_args is None:
        optimizer_fn_args = {'lr': 1e-3, 'weight_decay': 1e-2}
    else:
        optimizer_fn_args = dict(optimizer_fn_args)
    if half_precision:
        mat = mat.half()
    if normalize:
        mean_mat = mat.mean()
        std_mat = mat.std()
        mat = (mat - mean_mat) / std_mat
    else:
        mean_mat = 0
        std_mat = 1
    if features is True:
        features = torch.stack([mat.mean(0), mat.std(0)], dim=0)
    if model is None:
        encoder_depth = len(out_channels)
        kernel_size = kernel_size_unet
        assert kernel_size % 2 == 1
        # "same" padding for an odd kernel
        padding = (kernel_size - 1) // 2
        nrow, ncol = mat.shape[-2:]
        # Pool by prime factors of the spatial size so down/up-sampling divides exactly.
        pool_kernel_size_row = get_prime_factors(nrow)[:encoder_depth]
        pool_kernel_size_col = get_prime_factors(ncol)[:encoder_depth]
        if ndim == 2:
            # Neighbouring frames (frame_depth on each side plus the centre
            # frame) are stacked as channels, plus any global feature maps.
            in_channels = frame_depth * 2 + 1
            if features is not None:
                in_channels += features.shape[0]
            model = UNet(in_channels=in_channels,
                         num_classes=1,
                         out_channels=out_channels,
                         num_conv=2,
                         n_dim=2,
                         kernel_size=kernel_size,
                         padding=padding,
                         pool_kernel_size=[(pool_kernel_size_row[i],
                                            pool_kernel_size_col[i])
                                           for i in range(encoder_depth)],
                         transpose_kernel_size=[
                             (pool_kernel_size_row[i], pool_kernel_size_col[i])
                             for i in reversed(range(encoder_depth))
                         ],
                         transpose_stride=[
                             (pool_kernel_size_row[i], pool_kernel_size_col[i])
                             for i in reversed(range(encoder_depth))
                         ],
                         last_out_channels=last_out_channels,
                         use_adaptive_pooling=False,
                         same_shape=True,
                         padding_mode='replicate',
                         normalization='layer_norm',
                         activation=nn.LeakyReLU(negative_slope=0.01,
                                                 inplace=True)).to(device)
        elif ndim == 3:
            # Each encoder level consumes k - 1 frames along the time axis
            # (no temporal padding), so frame_depth must split evenly across levels.
            assert frame_depth % encoder_depth == 0
            k = frame_depth // encoder_depth + 1
            model = UNet(in_channels=1,
                         num_classes=1,
                         out_channels=out_channels,
                         num_conv=2,
                         n_dim=3,
                         kernel_size=[kernel_size] +
                         [(k, kernel_size, kernel_size)] * encoder_depth +
                         [(1, kernel_size, kernel_size)] * encoder_depth,
                         padding=[padding] +
                         [(0, padding, padding)] * encoder_depth * 2,
                         pool_kernel_size=[(1, pool_kernel_size_row[i],
                                            pool_kernel_size_col[i])
                                           for i in range(encoder_depth)],
                         transpose_kernel_size=[
                             (1, pool_kernel_size_row[i],
                              pool_kernel_size_col[i])
                             for i in reversed(range(encoder_depth))
                         ],
                         transpose_stride=[
                             (1, pool_kernel_size_row[i],
                              pool_kernel_size_col[i])
                             for i in reversed(range(encoder_depth))
                         ],
                         last_out_channels=last_out_channels,
                         use_adaptive_pooling=True,
                         same_shape=False,
                         padding_mode='zeros',
                         normalization='layer_norm',
                         activation=nn.LeakyReLU(negative_slope=0.01,
                                                 inplace=True)).to(device)
        else:
            raise ValueError(f'ndim must be 2 or 3, got {ndim}')
    if mat.dtype == torch.float16:
        model = model.half()
    optimizer = optimizer_fn(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_fn_args)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if save_intermediate_results:
        # Reference video of the raw (de-normalized) input, written once.
        movie_tyx = mat[movie_start_idx:movie_end_idx] * std_mat + mean_mat
        if not os.path.exists(
                f'{save_folder}/movie_frame{movie_start_idx}to{movie_end_idx}_raw.avi'
        ):
            make_video_ffmpeg(
                movie_tyx,
                save_path=
                f'{save_folder}/movie_frame{movie_start_idx}to{movie_end_idx}_raw.avi',
                fps=fps)
    initial_start_time = time.time()
    for epoch in range(num_epochs):
        if lr_scheduler is not None and len(lr_scheduler) >= 2:
            # Rebuild the optimizer with this epoch's learning rate
            # (this discards accumulated optimizer state such as Adam moments).
            lr = lr_scheduler['lr_fn'](epoch, **lr_scheduler['lr_fn_args'])
            optimizer_fn_args['lr'] = lr
            optimizer = optimizer_fn(
                filter(lambda p: p.requires_grad, model.parameters()),
                **optimizer_fn_args)
            if verbose:
                print(f'Epoch {epoch+1} set learning rate to be {lr:.2e}')
        start_time = time.time()
        for i in range(num_iters):
            batch_data = get_noise2self_train_data(
                mat,
                ndim=ndim,
                batch_size=batch_size,
                frame_depth=frame_depth,
                frame_weight=frame_weight,
                mask_prob=mask_prob,
                features=features,
                window_size_row=window_size_row,
                window_size_col=window_size_col,
                weight=weight,
                return_frame_indices=False)
            x, y_true, mask = batch_data['x'], batch_data[
                'y_true'], batch_data['mask']
            if weight is not None:
                batch_weight = batch_data['weight']
            y_pred = model(x)
            if ndim == 2:
                y_pred = y_pred.squeeze(1)
            elif ndim == 3:
                y_pred = y_pred.squeeze(1).squeeze(1)
            # Loss only on masked pixels (the Noise2Self self-supervision target).
            if weight is not None:
                loss = loss_reg_fn((y_pred * batch_weight)[mask],
                                   (y_true * batch_weight)[mask])
            else:
                loss = loss_reg_fn(y_pred[mask], y_true[mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_history.append(loss.item())
            if verbose and (i == 0 or i == num_iters - 1 or
                            (i + 1) % print_every == 0):
                print(f'i={i+1}, loss={loss.item()}')
        end_time = time.time()
        if verbose:
            print(f'Epoch {epoch+1} time: {end_time - start_time}')
            plt.plot(loss_history[-num_iters:], 'o-', markersize=3)
            plt.ylabel('loss')
            plt.xlabel(f'iteration (epoch {epoch+1})')
            plt.show()
        if save_intermediate_results:
            torch.save(model.state_dict(),
                       f'{save_folder}/model_step{len(loss_history)}.pt')
            np.save(f'{save_folder}/loss__denoise.npy', loss_history)
            torch.cuda.empty_cache()
            # Denoise the preview window plus frame_depth context frames on
            # each side (no padding, so the output covers the window itself).
            denoised_mat = model_denoise(
                mat[movie_start_idx - frame_depth:movie_end_idx + frame_depth],
                model,
                ndim=ndim,
                frame_depth=frame_depth,
                features=features,
                batch_size=batch_size_eval,
                normalize=False,
                replicate_pad=False)
            movie_tyx = denoised_mat * std_mat + mean_mat
            make_video_ffmpeg(
                movie_tyx,
                save_path=
                f'{save_folder}/denoised_movie_frame{movie_start_idx}to{movie_end_idx}_step{len(loss_history)}.avi',
                fps=fps)
            np.save(
                f'{save_folder}/denoised_movie_frame{movie_start_idx}to{movie_end_idx}_step{len(loss_history)}.npy',
                movie_tyx.cpu().numpy())
            del denoised_mat, movie_tyx
        torch.cuda.empty_cache()
        # Early stop when the epoch-mean loss has plateaued.
        if (loss_threshold > 0 and epoch > 0 and np.abs(
                np.array(loss_history[-num_iters:]).mean() -
                np.array(loss_history[-2 * num_iters:-num_iters]).mean()) <
                loss_threshold):
            break
    torch.save(model.state_dict(),
               f'{save_folder}/model_step{len(loss_history)}.pt')
    np.save(f'{save_folder}/loss__denoise.npy', loss_history)
    if verbose:
        plt.plot(loss_history, 'o-', markersize=3)
        plt.ylabel('loss')
        plt.xlabel('iteration')
        plt.title(
            f'loss_history[{len(loss_history)}] = ({loss_history[-1]:.3e})')
        plt.show()
    # Record the frame count now: `mat` is deleted below to free memory, but
    # the fallback file name still needs it.
    num_frames = mat.shape[0]
    denoised_mat = model_denoise(mat,
                                 model,
                                 ndim=ndim,
                                 frame_depth=frame_depth,
                                 features=features,
                                 batch_size=batch_size_eval,
                                 normalize=False,
                                 replicate_pad=True)
    del mat
    torch.cuda.empty_cache()
    denoised_mat = denoised_mat * std_mat + mean_mat
    if save_filename is None:
        np.save(
            f'{save_folder}/denoised_movie_frame0to{num_frames}_step{len(loss_history)}.npy',
            denoised_mat.cpu().numpy())
    else:
        np.save(f'{save_folder}/{save_filename}.npy',
                denoised_mat.cpu().numpy())
    if verbose:
        end_time = time.time()
        print(f'Total time spent: {end_time - initial_start_time}')
    if return_model:
        return denoised_mat, model
    else:
        return denoised_mat
Пример #16
0
def train_val(config):
    """Train a segmentation model and validate it after every epoch.

    The architecture is chosen by ``config.model_type``; unknown names fall
    back to ``UNet``. When ``config.iscontinue`` is set, a fixed checkpoint
    replaces that choice.

    Each epoch does the following:
    - Trains on ``config.train_img_dir``/``config.train_mask_dir`` with ``RendLoss``.
    - Evaluates on the validation split, accumulating a confusion matrix over
      the 15 label ids in ``labels``.
    - Computes per-class IoU and a frequency-weighted mIoU.
    - Steps the cosine-annealing scheduler.
    - Saves the whole model to ``config.result_path`` whenever the
      frequency-weighted mIoU improves.

    Metrics and sample masks are logged to TensorBoard via ``SummaryWriter``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir, mask_dir=config.train_mask_dir, mode="train",
                                  batch_size=config.batch_size, num_workers=config.num_workers, smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir, mask_dir=config.val_mask_dir, mode="val",
                                batch_size=config.batch_size, num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" % (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RendDANet":
        # Stage 2: start from stage-1 DANet weights and freeze the backbone,
        # head and seg1 so only the remaining modules are trained.
        src = "./exp/24_DANet_0.7585.pth"
        pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        print("load pretrained params from stage 1: " + src)
        model = RendDANet(nclass=15, backbone="resnet101", norm_layer=nn.BatchNorm2d)
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        for param in model.pretrained.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = False
        for param in model.seg1.parameters():
            param.requires_grad = False
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=config.output_ch, pretrained=True, norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3, num_classes=8, backend='resnet101', os=16, pretrained=True, norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        model = torch.load("./exp/30_RendDANet_0.7774.pth").module

    for m in model.modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    objects = ['水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化',
               '自然林', '人工林', '自然裸土', '人为裸土', '其它']
    # Per-class weights used to turn per-class IoU into frequency-weighted mIoU.
    frequency = np.array([0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107, 0.1207, 0.0249,
                          0.1470, 0.0777, 0.0617, 0.0118, 0.0187])

    # Every optimizer only sees trainable parameters (RendDANet freezes some).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if config.optimizer == "sgd":
        optimizer = SGD(trainable_params, lr=config.lr, weight_decay=1e-4, momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(trainable_params, lr=config.lr)
    else:
        optimizer = torch.optim.Adam(trainable_params, lr=config.lr)

    criterion = RendLoss()

    # Stepped once per epoch, after validation.
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        cm = np.zeros([len(labels), len(labels)])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train, desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img', ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                mask = mask.to(device, dtype=torch.float32)

                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1

            # Approximate mean loss per batch.
            print("\ntraining epoch loss: " + str(epoch_loss / (float(config.num_train) / (float(config.batch_size)))))

        # Summed (not averaged) cross-entropy over validation batches.
        val_loss = 0
        with tqdm(total=config.num_val, desc="Epoch %d / %d validation round" % (epoch + 1, config.num_epochs),
                  unit='img', ncols=100) as val_pbar:
            model.eval()
            locker = 0
            # Validation is inference only: do not build autograd graphs.
            with torch.no_grad():
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    pred = model(image)['fine']
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log a few sample masks from one fixed batch (assumes batch size >= 4).
                    if locker == 25:
                        writer.add_images('mask_a/true', mask[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_a/pred', pred[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/true', mask[3, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/pred', pred[3, :, :], epoch + 1, dataformats='HW')
                    locker += 1

            miou = get_miou(cm)
            fw_miou = (miou * frequency).sum()
            scheduler.step()

            if fw_miou > max_fwiou:
                # Compare the base version so builds like "1.6.0+cu101" also
                # save in the legacy (non-zipfile) format.
                if torch.__version__.split("+")[0] == "1.6.0":
                    torch.save(model,
                               config.result_path + "/%d_%s_%.4f.pth" % (epoch + 1, config.model_type, fw_miou),
                               _use_new_zipfile_serialization=False)
                else:
                    torch.save(model,
                               config.result_path + "/%d_%s_%.4f.pth" % (epoch + 1, config.model_type, fw_miou))
                max_fwiou = fw_miou
            print("\n")
            print(miou)
            print("testing epoch loss: " + str(val_loss), "FWmIoU = %.4f" % fw_miou)
            writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
            writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
            writer.add_scalar('loss/val', val_loss, epoch + 1)
            for idx, name in enumerate(objects):
                writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
    writer.close()
    print("Training finished")