Code Example #1
def train_datagen(epoch_iter=2000,
                  epoch_num=5,
                  batch_size=128,
                  data_dir=args.train_data):
    print('train_datagen')
    while True:
        n_count = 0
        if n_count == 0:
            print(n_count)
            xs = dg.datagenerator(data_dir)
            print(xs.shape)

            #assert len(xs)%args.batch_size ==0, \

            print('log done')
            xs = xs / 255.0
            xs = xs.astype('float32')

            indices = list(range(xs.shape[0]))
            n_count = 1
        print('if done', range(epoch_num))
        for _ in range(epoch_num):
            np.random.shuffle(indices)  # shuffle
            for i in range(0, len(indices), batch_size):
                batch_x = xs[indices[i:i + batch_size]]
                noise = np.random.normal(0, args.sigma / 255.0,
                                         batch_x.shape)  # noise
                #noise =  K.random_normal(ge_batch_y.shape, mean=0, stddev=args.sigma/255.0)
                batch_y = batch_x + noise
                yield batch_y, batch_x
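Note: dg.datagenerator is project-specific and not reproduced in these listings. As a rough guide only, a minimal sketch of what such a patch generator commonly does (the patch size, stride, and file pattern below are assumptions, not taken from any project above):

# Hypothetical sketch of a dg.datagenerator-style patch extractor (assumed details).
import glob
import cv2
import numpy as np

def datagenerator_sketch(data_dir, patch_size=40, stride=10):
    patches = []
    for path in glob.glob(data_dir + '/*.png'):
        img = cv2.imread(path, 0)  # read as grayscale
        h, w = img.shape
        for i in range(0, h - patch_size + 1, stride):
            for j in range(0, w - patch_size + 1, stride):
                patches.append(img[i:i + patch_size, j:j + patch_size])
    return np.expand_dims(np.array(patches, dtype='uint8'), axis=3)  # (N, H, W, 1)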
Code Example #2
def train_datagen(epoch_iter=1,
                  epoch_num=1,
                  batch_size=128,
                  data_dir=args.train_data,
                  noise_dir=args.noise_data):
    loop = 0
    while True:
        loop = loop + 1
        xs = dg.datagenerator(data_dir)
        assert len(xs)%batch_size ==0, \
            log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!')
        xs = xs.astype('float32')
        indices = list(range(xs.shape[0]))
        print("Augmented patch data shape:", xs.shape)
        np.random.shuffle(indices)  # shuffle
        for i in range(0, len(indices), batch_size):
            ratio1 = np.random.uniform(0.3, 0.7, batch_size)  # mixing ratio in [0.3, 0.7]
            ratio2 = 1.0 - ratio1

            batch_x = xs[indices[i:i + batch_size]]
            noise = extract_noise(batch_x.shape, noise_dir)

            for jj in range(batch_size):
                batch_x[jj] = batch_x[jj] * ratio1[jj]
                noise[jj] = noise[jj] * ratio2[jj]

            batch_y = (batch_x + noise)

            yield batch_y, batch_x
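Note: extract_noise is not shown in this listing. Purely as an illustration, one plausible sketch is a helper that samples stored real-noise patches from noise_dir to match the batch shape (the .npy file format and matching patch size are assumptions):

# Hypothetical sketch of an extract_noise-style helper (assumed behaviour and file format).
import glob
import numpy as np

def extract_noise_sketch(shape, noise_dir):
    # shape is (batch_size, H, W, C); stored noise patches are assumed to share H, W, C.
    pool = np.concatenate([np.load(f) for f in glob.glob(noise_dir + '/*.npy')], axis=0)
    idx = np.random.randint(0, pool.shape[0], size=shape[0])
    return pool[idx].astype('float32')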
Code Example #3
def main():
    global opt, model
    opt = parser.parse_args()
    logger = set_logger(opt.save)
    print(opt)
    print(opt, file=logger)

    # setting gpu and seed
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    # setting dataset
    print("===> Loading dataset")
    patches = datagenerator(data_dir=opt.data_train)
    train_set = DenoisingDataset(patches, sigma=opt.sigma)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      drop_last=True,
                                      batch_size=opt.batchSize,
                                      shuffle=True)

    # setting model and loss
    print("===> Building model")
    model = DUAL_CNN_DENOISE()
    criterion = nn.MSELoss(reduction='sum')  # size_average=False is deprecated; reduction='sum' is equivalent
    model = model.cuda()
    criterion = criterion.cuda()

    # setting optimizer
    print("===> Setting Optimizer")
    kwargs = {'weight_decay': opt.weight_decay}
    optimizer = optim.Adam([{
        "params": model.structure_net.parameters(),
        "lr": opt.srcnn_lr
    }, {
        "params": model.detail_net.parameters(),
        "lr": opt.vdsr_lr
    }], **kwargs)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch, logger)
        model_path = save_checkpoint(model, epoch)
        eval.eval(model_path, opt.save, opt.sigma)
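Note: DenoisingDataset appears in this and several later PyTorch examples. In the widely used DnCNN-PyTorch reference code it simply adds Gaussian noise to a clean patch on access and returns the (noisy, clean) pair; a minimal sketch of that pattern (individual projects may differ):

# Minimal DenoisingDataset-style sketch: add Gaussian noise on access, return (noisy, clean).
import torch
from torch.utils.data import Dataset

class DenoisingDatasetSketch(Dataset):
    def __init__(self, xs, sigma):
        self.xs = xs        # clean patches in [0, 1], shape (N, C, H, W)
        self.sigma = sigma  # noise level on the 0-255 scale

    def __getitem__(self, index):
        batch_x = self.xs[index]
        noise = torch.randn(batch_x.size()).mul_(self.sigma / 255.0)
        return batch_x + noise, batch_x

    def __len__(self):
        return self.xs.size(0)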
Code Example #4
File: main_train.py  Project: csprh/WATERCODE
def train_datagen(epoch_iter=2000,epoch_num=5,batch_size=128,data_dir=args.train_data):
    while True:
        n_count = 0
        if n_count == 0:
            #print(n_count)
            xs, ys = dg.datagenerator(data_dir)
            assert len(xs)%args.batch_size ==0, \
            log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!')
            xs = xs.astype('float32')/255.0
            ys = ys.astype('float32')/255.0
            indices = list(range(xs.shape[0]))
            n_count = 1
        for _ in range(epoch_num):
            np.random.shuffle(indices)    # shuffle
            for i in range(0, len(indices), batch_size):
                batch_x = xs[indices[i:i+batch_size]]
                batch_y = ys[indices[i:i+batch_size]]
                yield batch_y, batch_x
Code Example #5
    def train(self):
        train_images, train_labels = datagenerator(train_data_dir=self.train_images_dir, GT_data_dir=self.train_labels_dir)
        train_images = train_images / 255
        train_labels = train_labels / 255

        self.summary_writer = tf.summary.FileWriter(self.log_dir, graph=tf.get_default_graph())
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = tf.train.AdamOptimizer(self.lr_init).minimize(self.loss)
        tf.global_variables_initializer().run()  # initialize_all_variables() is deprecated in TF1

        counter = 0
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        print("Training...")

        for ep in range(self.epoch):
            if ep % 5 ==0 and ep != 0:
                self.lr_init = self.lr_init / 10
            batch_idxs = len(train_images) // self.batch_size
            for idx in range(0, batch_idxs):
                batch_images = train_images[idx * self.batch_size: (idx + 1) * self.batch_size]
                batch_labels = train_labels[idx * self.batch_size: (idx + 1) * self.batch_size]

                counter += 1

                _, err, psnr = self.sess.run([self.train_op, self.loss, self.metric],
                                             feed_dict={self.images:batch_images, self.labels:batch_labels})

                summary = self.sess.run(self.merged_summary_op,
                                        feed_dict={self.images:batch_images, self.labels:batch_labels})
                self.summary_writer.add_summary(summary, counter)

                if counter % 10 == 0:
                    print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f], PSNR: [%.4f], lr: [%.6f]" % (
                        (ep + 1), counter, time.time() - start_time, err, psnr, self.lr_init))
                if counter % 500 == 0:
                    self.save(self.checkpoint_dir, counter)
Code Example #6
    optimizer_decompose = optim.Adam(decompose_model.parameters(), lr=args.lr)
    scheduler_decompose = MultiStepLR(optimizer_decompose,
                                      milestones=[30, 60, 90],
                                      gamma=0.2)  # learning rates
    # optimizer_compose = optim.Adam(compose_model.parameters(), lr=args.lr)
    # scheduler_compose = MultiStepLR(optimizer_compose, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    for epoch in range(initial_epoch, n_epoch):
        decompose_model.train()
        # compose_model.train()

        scheduler_decompose.step(
            epoch)  # step to the learning rate in this epoch
        # scheduler_compose.step(epoch)  # step to the learning rate in this epoch

        xs = dg.datagenerator(data_dir=args.train_data)
        xs = xs.astype('float32') / 255.0
        xs = torch.from_numpy(xs.transpose(
            (0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW

        DDataset = DenoisingDataset(xs, sigma)
        batch_y, batch_x = DDataset[:238336]

        dataset = torch.cat((batch_x, batch_y), dim=1)
        DLoader = torch.utils.data.DataLoader(dataset=dataset,
                                              num_workers=0,
                                              drop_last=True,
                                              batch_size=batch_size,
                                              shuffle=True)
        epoch_loss = 0
        start_time = time.time()
Code Example #7
File: main_train.py  Project: danielasoucst/wiien
    # DATA FOR VALIDATING THE TRAINING
    # val_xs, val_illums = dg.datagenerator(data_dir=args.val_data, batch_size=args.batch_size, is_validation=True)
    # val_xs = val_xs.astype(np.float32)/255.0 #float = [0,1]
    # val_xs = torch.from_numpy(np.reshape(val_xs,(val_xs.shape[0], val_xs.shape[3], val_xs.shape[1], val_xs.shape[2])))
    # print('val_xs',val_xs.shape)
    # val_DDataset = DenoisingDataset(val_xs, val_illums)
    # val_DLoader = DataLoader(dataset=val_DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True,
    #                          worker_init_fn=worker_init)

    for epoch in range(initial_epoch, n_epoch):
        print('Training epoch %d' % (epoch))
        # print('Training epoch %d with factor %f' % (epoch, fatores[epoch]))
        # scheduler.step(epoch)

        xs, illums = dg.datagenerator(data_dir=args.train_data, batch_size=args.batch_size)

        xs = xs.astype(np.float32) / 255.0  # float = [0,1]

        xs = torch.from_numpy(np.reshape(xs, (
        xs.shape[0], xs.shape[3], xs.shape[1], xs.shape[2])))  # tensor of the clean patches, NXCXHXW

        # xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW
        # xs = xs.transpose((0, 3, 1, 2))

        DDataset = LightingDataset(xs, illums)

        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True,
                             worker_init_fn=worker_init)

        epoch_loss = 0
Code Example #8
File: train.py  Project: nuguziii/ResidualDenoising
def train(batch_size=128,
          n_epoch=300,
          sigma=25,
          lr=1e-4,
          depth=7,
          device="cuda:0",
          data_dir='./data/Train400',
          model_dir='models',
          model_name=None):
    device = torch.device(device)

    if not os.path.exists(
            os.path.join(
                model_dir, "model" + str(sigma) + "m" + str(model_name[1]) +
                "d" + str(depth))):
        os.mkdir(
            os.path.join(
                model_dir, "model" + str(sigma) + "m" + str(model_name[1]) +
                "d" + str(depth)))

    save_dir = os.path.join(
        model_dir,
        "model" + str(sigma) + "m" + str(model_name[1]) + "d" + str(depth))

    from datetime import date
    save_name = "model_mode" + str(model_name[1]) + str(depth) + "_" + "".join(
        str(date.today()).split('-')[1:]) + ".pth"

    f = open(os.path.join(save_dir, save_name.replace(".pth", ".txt")), 'w')

    f.write(('--\t This is end to end model saved as ' + save_name + '\n'))
    f.write(('--\t epoch %4d batch_size %4d sigma %4d\n' %
             (n_epoch, batch_size, sigma)))
    f.write(model_name[0])

    DNet = torch.load(os.path.join(model_dir, model_name[0]))

    DNet.eval()

    modelG = Model(model_dir=model_dir,
                   model_name=model_name)  #guidance='noisy, denoised'
    modelD = discriminator()

    print(modelG)
    f.write(str(modelG))
    f.write('\n\n')

    ngpu = 2
    if (device.type == 'cuda') and (ngpu > 1):
        modelG = nn.DataParallel(modelG, list(range(ngpu)))
        modelD = nn.DataParallel(modelD, list(range(ngpu)))
        DNet = nn.DataParallel(DNet, list(range(ngpu)))

    modelG.apply(weights_init)
    modelD.apply(weights_init)

    criterion_perceptual = vgg_loss(device)
    criterion_l1 = nn.L1Loss(size_average=None, reduce=None, reduction='sum')
    criterion_bce = nn.BCELoss()
    criterion_l2 = sum_squared_error()
    criterion_ssim = SSIM()

    if torch.cuda.is_available():
        modelG.to(device)
        modelD.to(device)
        DNet.to(device)

    optimizerG = optim.Adam(modelG.parameters(),
                            lr=lr,
                            betas=(0.5, 0.999),
                            weight_decay=1e-5)
    optimizerD = optim.Adam(modelD.parameters(),
                            lr=lr,
                            betas=(0.5, 0.999),
                            weight_decay=1e-5)
    scheduler = MultiStepLR(optimizerG, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates

    if sigma == 0:
        sigma_list = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]
    else:
        sigma_list = [sigma]

    for epoch in range(n_epoch):
        for sig in sigma_list:
            x = dg.datagenerator(data_dir=data_dir).astype('float32') / 255.0
            print(x.shape)
            x = torch.from_numpy(x.transpose((0, 3, 1, 2)))
            dataset = DenoisingDataset(x, sigma=sig)
            loader = DataLoader(dataset=dataset,
                                num_workers=4,
                                drop_last=True,
                                batch_size=batch_size,
                                shuffle=True)
            epoch_loss_g = 0
            start_time = time.time()
            n_count = 0
            for cnt, batch_yx in enumerate(loader):
                if torch.cuda.is_available():
                    batch_original, batch_noise = batch_yx[1].to(
                        device), batch_yx[0].to(device)
                '''
                modelD.zero_grad()
                b_size = batch_original.size(0)
                label = torch.full((b_size,), 1, device=device)
                output = modelD(batch_original).view(-1)
                errD_real = criterion_bce(output, label)
                errD_real.backward(retain_graph=True)
                '''
                residual = DNet(batch_noise)
                fake, structure, denoised = modelG(batch_noise, residual)
                '''
                label.fill_(0)
                output = modelD(fake.detach()).view(-1)
                errD_fake = criterion_bce(output, label)
                errD_fake.backward(retain_graph=True)

                d_loss = errD_real + errD_fake
                optimizerD.step()
                '''
                modelG.zero_grad()
                '''
                label = torch.full((b_size,), 1, device=device)
                output = modelD(fake).view(-1)
                gan_loss = criterion_bce(output, label)
                '''
                s_loss = criterion_l2(structure, batch_original - denoised)
                s_loss.backward(retain_graph=True)

                l1_loss = criterion_l1(fake, batch_original)
                perceptual_loss = criterion_perceptual(fake, batch_original, 0)
                #ssim_out = 1-criterion_ssim(fake, batch_original)
                #l2_loss = criterion_l2(fake, batch_original)

                g_loss = l1_loss + 2e-2 * perceptual_loss  #+1e-2*gan_loss
                g_loss.backward(retain_graph=True)
                epoch_loss_g += g_loss.item()
                optimizerG.step()

                if cnt % 100 == 0:
                    line = '%4d %4d / %4d g_loss = %2.4f\t(snet_l2_loss = %2.4f / l1_loss=%2.4f / perceptual_loss=%2.4f)' % (
                        epoch + 1, cnt, x.size(0) // batch_size,
                        g_loss.item() / batch_size, s_loss.item() / batch_size,
                        l1_loss.item() / batch_size,
                        perceptual_loss.item() / batch_size)
                    print(line)
                    f.write(line)
                    f.write('\n')
                n_count += 1

            elapsed_time = time.time() - start_time
            line = 'epoch = %4d, sigma = %4d, loss = %4.4f , time = %4.2f s' % (
                epoch + 1, sig, epoch_loss_g /
                (n_count * batch_size), elapsed_time)
            print(line)
            f.write(line)
            f.write('\n')
            if (epoch + 1) % 20 == 0:
                torch.save(
                    modelG,
                    os.path.join(
                        save_dir,
                        save_name.replace('.pth', '_epoch%03d.pth') %
                        (epoch + 1)))

        torch.save(modelG, os.path.join(save_dir, save_name))
    f.close()
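Note: weights_init is not included in this listing. A common DCGAN-style initializer, which is what modelG.apply(weights_init) usually expects, looks roughly like the sketch below (the layer-name checks and standard deviations are assumptions, not the project's actual code):

# Hypothetical DCGAN-style weight initializer (assumed; not taken from this project).
import torch.nn as nn

def weights_init_sketch(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)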
Code Example #9
File: train.py  Project: nuguziii/ResidualDenoising
def pretrain_DNet(batch_size=128,
                  n_epoch=150,
                  sigma=25,
                  lr=1e-3,
                  depth=17,
                  device="cuda:0",
                  data_dir='./data/Train400',
                  model_dir='models'):
    device = torch.device(device)

    model_dir = os.path.join(model_dir,
                             "DNet_s" + str(sigma) + "d" + str(depth))
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    from datetime import date
    save_name = "DNet_s" + str(sigma) + "d" + str(depth) + "_" + "".join(
        str(date.today()).split('-')[1:]) + ".pth"

    print('\n')
    print('--\t This model is pre-trained DNet saved as ', save_name)
    print('--\t epoch %4d batch_size %4d sigma %4d depth %4d' %
          (n_epoch, batch_size, sigma, depth))
    print('\n')

    model = DNet(depth=depth)

    model.train()

    print(model)
    print("\n")

    criterion = sum_squared_error()

    if torch.cuda.is_available():
        model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates

    if sigma == 0:
        sigma_list = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]
    else:
        sigma_list = [sigma]

    for epoch in range(n_epoch):
        for sig in sigma_list:
            x = dg.datagenerator(data_dir=data_dir).astype('float32') / 255.0
            x = torch.from_numpy(x.transpose((0, 3, 1, 2)))

            dataset = None
            dataset = DenoisingDataset(x, sigma)

            loader = DataLoader(dataset=dataset,
                                num_workers=4,
                                drop_last=True,
                                batch_size=batch_size,
                                shuffle=True)
            epoch_loss = 0
            start_time = time.time()
            n_count = 0
            for cnt, batch_yx in enumerate(loader):
                optimizer.zero_grad()
                if torch.cuda.is_available():
                    batch_original, batch_noise = batch_yx[1].to(
                        device), batch_yx[0].to(device)

                r = model(batch_noise)
                loss = criterion(batch_noise - r, batch_original)
                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()
                if cnt % 100 == 0:
                    print('%4d %4d / %4d loss = %2.4f' %
                          (epoch + 1, cnt, x.size(0) // batch_size,
                           loss.item() / batch_size))
                n_count += 1

            elapsed_time = time.time() - start_time
            print('epoch = %4d , sigma = %4d, loss = %4.4f , time = %4.2f s' %
                  (epoch + 1, sig, epoch_loss / n_count, elapsed_time))
            if (epoch + 1) % 25 == 0:
                torch.save(
                    model,
                    os.path.join(
                        model_dir,
                        save_name.replace('.pth', '_epoch%03d.pth') %
                        (epoch + 1)))
            torch.save(model, os.path.join(model_dir, save_name))

    torch.save(model, os.path.join(model_dir, save_name))
Code Example #10
File: train.py  Project: nuguziii/ResidualDenoising
def pretrain_SNet(batch_size=128,
                  n_epoch=100,
                  sigma=25,
                  lr=1e-4,
                  device="cuda:0",
                  data_dir='./data/Train400',
                  model_dir='models/SNet',
                  model_name=None,
                  model=0):
    device = torch.device(device)
    if not os.path.exists(model_dir):
        os.mkdir(os.path.join(model_dir))

    DNet = torch.load(os.path.join(model_dir, model_name[0]))
    if model == 0:
        model = SNet_jfver1()
        save_name = 'SNet_jfver1'
    elif model == 1:
        model = SNet_dfver1()
        save_name = 'SNet_dfver1'
    elif model == 2:
        model = SNet_dfver2()
        save_name = 'SNet_dfver2'
    elif model == 3:
        model = SNet_texture_ver1()
        save_name = 'SNet_texture_ver1'

    from datetime import date
    save_name = save_name + "_" + "".join(str(
        date.today()).split('-')[1:]) + ".pth"

    f = open(os.path.join(model_dir, save_name.replace(".pth", ".txt")), 'w')

    print('\n')
    print('--\t This model is pre-trained SNet saved as ', save_name)
    print('--\t epoch %4d batch_size %4d sigma %4d' %
          (n_epoch, batch_size, sigma))
    print('\n')

    DNet.eval()
    model.train()

    print(model)
    print("\n")

    criterion = sum_squared_error()

    if torch.cuda.is_available():
        DNet.to(device)
        model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates
    for epoch in range(n_epoch):
        x = dg.datagenerator(data_dir=data_dir).astype('float32') / 255.0
        x = torch.from_numpy(x.transpose((0, 3, 1, 2)))
        dataset = DenoisingDataset(x, sigma)
        loader = DataLoader(dataset=dataset,
                            num_workers=4,
                            drop_last=True,
                            batch_size=batch_size,
                            shuffle=True)
        epoch_loss = 0
        start_time = time.time()
        n_count = 0
        for cnt, batch_yx in enumerate(loader):
            optimizer.zero_grad()
            if torch.cuda.is_available():
                batch_original, batch_noise = batch_yx[1].to(
                    device), batch_yx[0].to(device)

            r = DNet(batch_noise)
            d = batch_noise - r
            #r=1.55*(r+0.5)-0.8
            s = model(r, d)
            #target = 1.8*(batch_original-d+0.5)-0.8
            loss = criterion(s, batch_original - d)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            if cnt % 100 == 0:
                line = '%4d %4d / %4d loss = %2.4f' % (
                    epoch + 1, cnt, x.size(0) // batch_size,
                    loss.item() / batch_size)
                print(line)
                f.write(line)
            n_count += 1

        elapsed_time = time.time() - start_time
        line = 'epoch = %4d , loss = %4.4f , time = %4.2f s' % (
            epoch + 1, epoch_loss / n_count, elapsed_time)
        print(line)
        f.write(line)
        if (epoch + 1) % 1 == 0:
            torch.save(
                model,
                os.path.join(
                    model_dir,
                    save_name.replace('.pth', '_epoch%03d.pth') % (epoch + 1)))

    torch.save(model, os.path.join(model_dir, save_name))
    f.close()
Code Example #11
File: main_train.py  Project: csprh/WATERCODE
                batch_y = ys[indices[i:i+batch_size]]
                yield batch_y, batch_x

# define loss
def sum_squared_error(y_true, y_pred):
    #return K.mean(K.square(y_pred - y_true), axis=-1)
    #return K.sum(K.square(y_pred - y_true), axis=-1)/2
    return K.sum(K.square(y_pred - y_true))/2

if __name__ == '__main__':

    import pudb; pu.db
    # model selection
    model = DnCNN(depth=17,filters=64,image_channels=1,use_bnorm=True)
    model.summary()
    xs, ys = dg.datagenerator(args.train_data)
    # load the last model in matconvnet style
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d'%initial_epoch)
        model = load_model(os.path.join(save_dir,'model_%03d.hdf5'%initial_epoch), compile=False)

    # compile the model
    model.compile(optimizer=Adam(0.001), loss=sum_squared_error)

    # use call back functions
    checkpointer = ModelCheckpoint(os.path.join(save_dir,'model_{epoch:03d}.hdf5'),
                verbose=1, save_weights_only=False, period=args.save_every)
    csv_logger = CSVLogger(os.path.join(save_dir,'log.csv'), append=True, separator=',')
    lr_scheduler = LearningRateScheduler(lr_schedule)
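Note: this listing is cut off before the actual training call. In Keras scripts of this style the callbacks above are typically passed to fit_generator together with the train_datagen generator (see Code Example #4 from the same project); the values below are assumptions, not the original call:

# Hypothetical continuation (assumed epoch count and steps; not shown in the listing above).
model.fit_generator(train_datagen(batch_size=args.batch_size),
                    steps_per_epoch=2000,
                    epochs=args.epoch,
                    initial_epoch=initial_epoch,
                    callbacks=[checkpointer, csv_logger, lr_scheduler])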
Code Example #12
    model.train()
    criterion = sum_squared_error()
    if cuda:
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates

    for epoch in range(initial_epoch, n_epoch):

        for subepoch in range(0, 10):

            logging.info('epoch' + str(epoch))
            scheduler.step(epoch)  # step to the learning rate in this epoch
            xs, xn = dg.datagenerator(data_dir=args.train_data,
                                      data_dir_noise=args.train_data_noise,
                                      batch_size=batch_size)
            listr = list(range(0, xs.shape[0]))
            random.shuffle(listr)
            xs = xs[listr, :, :, :]
            xn = xn[listr, :, :, :]
            xs = xs.astype('float32') / 255.0
            xn = xn.astype('float32') / 255.0
            xs = torch.from_numpy(xs.transpose(
                (0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW
            xn = torch.from_numpy(xn.transpose(
                (0, 3, 1, 2)))  # tensor of the noisy patches, NXCXHXW
            DDataset = DenoisingDataset(xs, xn)
            DLoader = DataLoader(dataset=DDataset,
                                 num_workers=8,
                                 drop_last=True,
Code Example #13
def train_model(config):
    # Define hyper-parameters.
    depth = int(config["DnCNN"]["depth"])
    n_channels = int(config["DnCNN"]["n_channels"])
    img_channel = int(config["DnCNN"]["img_channel"])
    kernel_size = int(config["DnCNN"]["kernel_size"])
    use_bnorm = config.getboolean("DnCNN", "use_bnorm")
    epochs = int(config["DnCNN"]["epoch"])
    batch_size = int(config["DnCNN"]["batch_size"])
    train_data_dir = config["DnCNN"]["train_data_dir"]
    test_data_dir = config["DnCNN"]["test_data_dir"]
    eta_min = float(config["DnCNN"]["eta_min"])
    eta_max = float(config["DnCNN"]["eta_max"])
    dose = float(config["DnCNN"]["dose"])
    model_save_dir = config["DnCNN"]["model_save_dir"]

    # Save logs to txt file.
    log_dir = config["DnCNN"]["log_dir"]
    log_dir = Path(log_dir) / "dose{}".format(str(int(dose * 100)))
    log_file = log_dir / "train_result.txt"

    # Define device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Initiate a DnCNN instance.
    # Load the model to device and set the model to training.
    model = DnCNN(depth=depth, n_channels=n_channels,
                  img_channel=img_channel,
                  use_bnorm=use_bnorm,
                  kernel_size=kernel_size)

    model = model.to(device)
    model.train()

    # Define loss criterion and optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)
    criterion = LossFunc(reduction="mean")

    # Get a validation test set and corrupt with noise for validation performance.
    # For every epoch, use this pre-determined noisy images.
    test_file_list = glob.glob(test_data_dir + "/*.png")
    xs_test = []
    # Can't directly convert the xs_test from list to ndarray because some images are 512*512
    # while the rest are 256*256.
    for i in range(len(test_file_list)):
        img = cv2.imread(test_file_list[i], 0)
        img = np.array(img, dtype="float32") / 255.0
        img = np.expand_dims(img, axis=0)
        img_noisy, _ = nm(img, eta_min, eta_max, dose, t=100)
        xs_test.append((img_noisy, img))

    # Train the model.
    loss_store = []
    epoch_loss_store = []
    psnr_store = []
    ssim_store = []

    psnr_tr_store = []
    ssim_tr_store = []
    
    loss_mse = torch.nn.MSELoss()

    dtype = torch.cuda.FloatTensor
    # load vgg network
    vgg = Vgg16().type(dtype)
    
    
    for epoch in range(epochs):
        # For each epoch, generate clean augmented patches from the training directory.
        # Convert the data from uint8 to float32 then scale them to make it in [0, 1].
        # Then make the patches to be of shape [N, C, H, W],
        # where N is the batch size, C is the number of color channels.
        # H and W are height and width of image patches.
        xs = dg.datagenerator(data_dir=train_data_dir)
        xs = xs.astype("float32") / 255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))

        train_set = dg.DenoisingDatatset(xs, eta_min, eta_max, dose)
        train_loader = DataLoader(dataset=train_set, num_workers=4,
                                  drop_last=True, batch_size=batch_size,
                                  shuffle=True)  # TODO: if drop_last=True, the dropping in the
                                                 # TODO: data_generator is not necessary?

        # train_loader_test = next(iter(train_loader))

        t_start = timer()
        epoch_loss = 0
        for idx, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            img_batch_read = len(inputs)

            optimizer.zero_grad()

            outputs = model(inputs)
            
            # We can use labels for both style and content image
            
                # style image
#             style_transform = transforms.Compose([
#             normalize_tensor_transform()      # normalize with ImageNet values
#             ])
            
#             labels_t = style_transform(labels)
                        
            labels_t = labels.repeat(1, 3, 1, 1)
            outputs_t = outputs.repeat(1, 3, 1, 1)            
            
            y_c_features = vgg(labels_t)
            style_gram = [gram(fmap) for fmap in y_c_features]
            
            y_hat_features = vgg(outputs_t)
            y_hat_gram = [gram(fmap) for fmap in y_hat_features]            
            
            # calculate style loss
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j], style_gram[j][:img_batch_read])
            style_loss = STYLE_WEIGHT*style_loss
            aggregate_style_loss = style_loss

            # calculate content loss (h_relu_2_2)
            recon = y_c_features[1]      
            recon_hat = y_hat_features[1]
            content_loss = CONTENT_WEIGHT*loss_mse(recon_hat, recon)
            aggregate_content_loss = content_loss
            
            loss = aggregate_content_loss + aggregate_style_loss
#             loss = criterion(outputs, labels)
            
            loss_store.append(loss.item())
            epoch_loss += loss.item()

            loss.backward()

            optimizer.step()

            if idx % 100 == 0:
                print("Epoch [{} / {}], step [{} / {}], loss = {:.5f}, lr = {:.6f}, elapsed time = {:.2f}s".format(
                    epoch + 1, epochs, idx, len(train_loader), loss.item(), *scheduler.get_last_lr(), timer()-t_start))

        epoch_loss_store.append(epoch_loss / len(train_loader))

        # At each epoch validate the result.
        model = model.eval()

        # # Firstly validate on training sets. This takes a long time so I commented.
        # tr_psnr = []
        # tr_ssim = []
        # # t_start = timer()
        # with torch.no_grad():
        #     for idx, train_data in enumerate(train_loader):
        #         inputs, labels = train_data
        #         # print(inputs.shape)
        #         # inputs = np.expand_dims(inputs, axis=0)
        #         # inputs = torch.from_numpy(inputs).to(device)
        #         inputs = inputs.to(device)
        #         labels = labels.squeeze().numpy()
        #
        #         outputs = model(inputs)
        #         outputs = outputs.squeeze().cpu().detach().numpy()
        #
        #         tr_psnr.append(peak_signal_noise_ratio(labels, outputs))
        #         tr_ssim.append(structural_similarity(outputs, labels))
        # psnr_tr_store.append(sum(tr_psnr) / len(tr_psnr))
        # ssim_tr_store.append(sum(tr_ssim) / len(tr_ssim))
        # # print("Elapsed time = {}".format(timer() - t_start))
        #
        # print("Validation on train set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
        #     epoch + 1, epochs, psnr_tr_store[-1], ssim_tr_store[-1]))

        # Validate on test set
        val_psnr = []
        val_ssim = []
        with torch.no_grad():
            for idx, test_data in enumerate(xs_test):
                inputs, labels = test_data
                inputs = np.expand_dims(inputs, axis=0)
                inputs = torch.from_numpy(inputs).to(device)
                labels = labels.squeeze()

                outputs = model(inputs)
                outputs = outputs.squeeze().cpu().detach().numpy()

                val_psnr.append(peak_signal_noise_ratio(labels, outputs))
                val_ssim.append(structural_similarity(outputs, labels))

        psnr_store.append(sum(val_psnr) / len(val_psnr))
        ssim_store.append(sum(val_ssim) / len(val_ssim))

        print("Validation on test set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
            epoch + 1, epochs, psnr_store[-1], ssim_store[-1]))

        # Set model to train mode again.
        model = model.train()

        scheduler.step()

        # Save model
        save_model(model, model_save_dir, epoch, dose * 100)

        # Save the loss and validation PSNR, SSIM.

        if not log_dir.exists():
            Path.mkdir(log_dir)
        with open(log_file, "a+") as fh:
            # fh.write("{} Epoch [{} / {}], loss = {:.6f}, train PSNR = {:.2f}dB, train SSIM = {:.4f}, "
            #          "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}".format(
            #          datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
            #          epoch + 1, epochs, epoch_loss_store[-1],
            #          psnr_tr_store[-1], ssim_tr_store[-1],
            #          psnr_store[-1], ssim_store[-1]))
            fh.write("{} Epoch [{} / {}], loss = {:.6f}, "
                     "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}\n".format(
                     datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
                     epoch + 1, epochs, epoch_loss_store[-1],
                     psnr_store[-1], ssim_store[-1]))

        # np.savetxt(log_file, np.hstack((epoch + 1, epoch_loss_store[-1], psnr_store[-1], ssim_store[-1])), fmt="%.6f", delimiter=",  ")

        fig, ax = plt.subplots()
        ax.plot(loss_store[-len(train_loader):])
        ax.set_title("Last 1862 losses")
        ax.set_xlabel("iteration")
        fig.show()
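Note: gram and Vgg16 come from standard neural-style-transfer utilities and are not shown here. The Gram matrix itself is conventional; a minimal sketch for a batch of feature maps (the normalization constant is the usual choice and is assumed here):

# Gram matrix of a feature-map batch: (b, ch, h, w) -> (b, ch, ch).
import torch

def gram_sketch(fmap):
    b, ch, h, w = fmap.size()
    feats = fmap.view(b, ch, h * w)
    return torch.bmm(feats, feats.transpose(1, 2)) / (ch * h * w)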
Code Example #14
def train(batch_size=128,
          n_epoch=150,
          sigma=25,
          lr=1e-3,
          lr2=1e-5,
          depth=17,
          device="cuda:0",
          data_dir='./data/Train400',
          model_dir='models'):
    device = torch.device(device)

    from datetime import date
    model_dir = os.path.join(
        model_dir, "result_" + "".join(str(date.today()).split('-')[1:]))
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    save_name = "net.pth"
    save_name2 = "net2.pth"

    print('\n')
    print('--\t This model is pre-trained DNet saved as ', save_name)
    print('--\t epoch %4d batch_size %4d sigma %4d depth %4d' %
          (n_epoch, batch_size, sigma, depth))
    print('\n')

    model = DnCNN(depth=depth)
    model2 = DnCNN(depth=depth)

    model.train()
    model2.train()

    criterion = sum_squared_error()
    criterion_l1 = nn.L1Loss(size_average=None, reduce=None, reduction='sum')

    if torch.cuda.is_available():
        model.to(device)
        model2.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates
    optimizer2 = optim.Adam(model2.parameters(), lr=lr2, weight_decay=1e-5)
    scheduler2 = MultiStepLR(optimizer2, milestones=[30, 60, 90],
                             gamma=0.2)  # learning rates

    for epoch in range(n_epoch):
        x = dg.datagenerator(data_dir=data_dir).astype('float32') / 255.0
        x = torch.from_numpy(x.transpose((0, 3, 1, 2)))

        dataset = None
        dataset = DenoisingDataset(x, sigma)

        loader = DataLoader(dataset=dataset,
                            num_workers=4,
                            drop_last=True,
                            batch_size=batch_size,
                            shuffle=True)
        epoch_loss = 0
        epoch_loss_first = 0
        start_time = time.time()
        n_count = 0
        for cnt, batch_yx in enumerate(loader):
            optimizer.zero_grad()
            if torch.cuda.is_available():
                batch_original, batch_noise = batch_yx[1].to(
                    device), batch_yx[0].to(device)

            residual = model(batch_noise)
            loss_first = criterion(batch_noise - residual, batch_original)
            loss_first.backward(retain_graph=True)
            optimizer.step()

            structure_residual = model2(residual)
            target = batch_original - (batch_noise - residual)
            structure = residual - structure_residual
            loss = criterion_l1(structure, target)
            loss.backward()
            optimizer2.step()

            epoch_loss_first += loss_first.item()
            epoch_loss += loss.item()

            if cnt % 100 == 0:
                print(
                    '%4d %4d / %4d 1_loss = %2.4f loss = %2.4f' %
                    (epoch + 1, cnt, x.size(0) // batch_size,
                     loss_first.item() / batch_size, loss.item() / batch_size))
            n_count += 1

        elapsed_time = time.time() - start_time
        print(
            'epoch = %4d , sigma = %4d, 1_loss = %4.4f, loss = %4.4f , time = %4.2f s'
            % (epoch + 1, sigma, epoch_loss_first / n_count,
               epoch_loss / n_count, elapsed_time))
        if (epoch + 1) % 25 == 0:
            torch.save(
                model,
                os.path.join(
                    model_dir,
                    save_name.replace('.pth', '_epoch%03d.pth') % (epoch + 1)))
            torch.save(
                model2,
                os.path.join(
                    model_dir,
                    save_name2.replace('.pth', '_epoch%03d.pth') %
                    (epoch + 1)))

    torch.save(model, os.path.join(model_dir, save_name))
    torch.save(model2, os.path.join(model_dir, save_name2))
Code Example #15
        model = torch.load(
            os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
    model.train()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    criterion = sum_squared_error()
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    scheduler = Polyscheduler(optimizer, EPOCH=n_epoch,
                              gamma=0.8)  # learning rates
    # scheduler = LrScheduler.ReduceLROnPlateau(optimizer, 'min')  # learning rates

    writer = SummaryWriter(comment='train')

    DATA = dg.datagenerator(data_dir=filename)
    # data = torch.rand(20, 1, 11, 11)
    DATA = DATA.astype('float32') / 255.0
    DATA = torch.from_numpy(DATA.transpose((0, 3, 1, 2)))
    LABEL = dg.datagenerator(data_dir=filename1)
    # LABEL = torch.rand(20, 1, 11, 11)
    LABEL = LABEL.astype('float32') / 255.0
    LABEL = torch.from_numpy(LABEL.transpose((0, 3, 1, 2)))

    torch_dataset = DATASET(DATA, LABEL)

    for epoch in range(initial_epoch, n_epoch):
        scheduler.step(epoch)  # step to the learning rate in this epoch
        LR = torch.FloatTensor(scheduler.get_lr())
        DLoader = DataLoader(dataset=torch_dataset,
                             num_workers=4,