Example no. 1
def set_model(self, model_selected=None):
    self.clean_model()
    if model_selected is not None:
        # presumably delegates to a module-level set_model helper rather than recursing into this method
        self.model_num = set_model(self, model_selected)
Example no. 2
def run_train(opt, training_data_loader):
    # check gpu setting with opt arguments
    opt = set_gpu(opt)

    print('Initialize networks for training')
    net = set_model(opt)
    print(net)

    if opt.use_cuda:
        net = net.to(opt.device)

    print("Setting Optimizer")
    if opt.optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2),
                               eps=1e-8,
                               weight_decay=0)
        print("===> Use Adam optimizer")
    else:
        # fail early instead of reaching the undefined `optimizer` further below
        raise ValueError("Specify optimizer correctly (adam)")

    if opt.resume:
        opt.start_epoch, net, optimizer = load_model(opt,
                                                     net,
                                                     optimizer=optimizer)
    else:
        set_checkpoint_dir(opt)

    if opt.multi_gpu:
        net = nn.DataParallel(net)

    if not os.path.exists(opt.checkpoint_dir):
        os.makedirs(opt.checkpoint_dir)
    log_file = os.path.join(opt.checkpoint_dir, opt.model + "_log.csv")
    opt_file = os.path.join(opt.checkpoint_dir, opt.model + "_opt.txt")

    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=5,
                                  mode='min')
    # scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
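    # the plateau scheduler is stepped with the validation loss after each validation pass (see scheduler.step(valid_loss) below)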

    # Create log file when training start
    if opt.start_epoch == 1:
        with open(log_file, mode='w') as f:
            f.write("epoch,train_loss,valid_loss\n")
        save_config(opt)

    data_loader = {
        'train': training_data_loader,
    }
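    # only the training loader is registered here; validation data is handled inside run_valid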
    modes = ['train', 'valid']

    l2_criterion = nn.MSELoss()
    l1_criterion = nn.L1Loss()
    if opt.use_cuda:
        l2_criterion = l2_criterion.to(opt.device)
        l1_criterion = l1_criterion.to(opt.device)

    if opt.content_loss == 'l2':
        content_loss_criterion = l2_criterion
    elif opt.content_loss == 'l1':
        content_loss_criterion = l1_criterion
    else:
        raise ValueError("Specify content loss correctly (l1, l2)")

    if opt.style_loss == 'l2':
        style_loss_criterion = l2_criterion
    elif opt.style_loss == 'l1':
        style_loss_criterion = l1_criterion
    else:
        raise ValueError("Specify style loss correctly (l1, l2)")

    if opt.ll_loss == 'l2':
        ll_loss_criterion = l2_criterion
    elif opt.ll_loss == 'l1':
        ll_loss_criterion = l1_criterion
    else:
        raise ValueError("Specify style loss correctly (l1, l2)")

    nc = opt.n_channels
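    # fixed, reproducible shuffled index order, passed to run_valid (presumably to pick a validation subset)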
    np.random.seed(1024)
    sq = np.arange(1024)
    np.random.shuffle(sq)

    for epoch in range(opt.start_epoch, opt.n_epochs):
        opt.epoch_num = epoch
        for phase in modes:
            if phase == 'train':
                total_loss = 0.0
                total_psnr = 0.0
                total_iteration = 0

                net.train()

                mode = "Training"
                print("*** %s ***" % mode)
                start_time = time.time()

                for iteration, batch in enumerate(data_loader[phase], 1):
                    # (_, x), (_, target) = batch[0], batch[1]
                    x, target = batch[0], batch[1]
                    x_img, target_img = batch[3], batch[4]
                    lr_approx = batch[5]
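                    # x_img, target_img and lr_approx are unpacked here but not used in this training step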

                    if opt.use_cuda:
                        x = x.to(opt.device)
                        target = target.to(opt.device)

                    optimizer.zero_grad()

                    # epoch_loss = 0.
                    with torch.set_grad_enabled(phase == 'train'):
                        out = net(x)

                        # norm_target = normalize_coeffs(target, ch_min=opt.ch_min, ch_max=opt.ch_max)
                        std_target = standarize_coeffs(target,
                                                       ch_mean=opt.ch_mean,
                                                       ch_std=opt.ch_std)
                        # norm_out = normalize_coeffs(out, ch_min=opt.ch_min, ch_max=opt.ch_max)
                        std_out = standarize_coeffs(out,
                                                    ch_mean=opt.ch_mean,
                                                    ch_std=opt.ch_std)

                        ll_target = std_target[:, 0:nc, :, :]
                        ll_out = std_out[:, 0:nc, :, :]
                        high_target = std_target[:, nc:, :, :]
                        high_out = std_out[:, nc:, :, :]

                        # log_channel_loss(std_out, std_target, content_loss_criterion)
                        ll_content_loss = content_loss_criterion(
                            ll_target, ll_out)
                        ll_style_loss = 0
                        # content_loss = content_loss_criterion(norm_target, norm_out)
                        high_content_loss = content_loss_criterion(
                            high_target, high_out)
                        high_style_loss = 0

                        ll_loss = ll_content_loss + ll_style_loss
                        high_loss = high_content_loss + high_style_loss
                        # opt.ll_weight balances the low-frequency (LL) sub-band loss
                        # against the high-frequency sub-band loss
                        epoch_loss = opt.ll_weight * ll_loss + (
                            1 - opt.ll_weight) * high_loss

                        # L1 loss for wavelet coefficients (kept at zero here)
                        l1_loss = 0

                        total_loss += epoch_loss.item()

                        epoch_loss.backward()
                        optimizer.step()

                    # metric only: compute PSNR outside the autograd graph (signal peak assumed to be 1.0)
                    with torch.no_grad():
                        mse_loss = l2_criterion(out, target)
                    psnr = 10 * math.log10(1 / mse_loss.item())
                    total_psnr += psnr

                    print(
                        "High Content Loss: {:5f}, High Style Loss: {:5f}, LL Content Loss: {:5f}, LL Style Loss:{:5f}"
                        .format(high_content_loss, high_style_loss,
                                ll_content_loss, ll_style_loss))
                    print(
                        "{} {:4f}s => Epoch[{}/{}]({}/{}): Epoch Loss: {:5f} High Loss: {:5f} LL Loss: {:5f} L1 Loss: {:5f} PSNR: {:5f}"
                        .format(mode,
                                time.time() - start_time,
                                opt.epoch_num, opt.n_epochs, iteration,
                                len(data_loader[phase]), epoch_loss.item(),
                                high_loss.item(), ll_loss.item(), l1_loss,
                                psnr))

                    total_iteration = iteration

                total_loss = total_loss / total_iteration
                total_psnr = total_psnr / total_iteration

                train_loss = total_loss
                train_psnr = total_psnr

            else:
                net.eval()
                mode = "Validation"
                print("*** %s ***" % mode)
                valid_loss, valid_psnr = run_valid(opt, net,
                                                   content_loss_criterion, sq)
                scheduler.step(valid_loss)

        with open(log_file, mode='a') as f:
            f.write("%d,%08f,%08f,%08f,%08f\n" %
                    (epoch, train_loss, train_psnr, valid_loss, valid_psnr))

        save_checkpoint(opt, net, optimizer, epoch, valid_loss)
Example no. 3
def prep_Result(opt):

    net = set_model(opt)
    _, net, _ = load_model(opt, net)

    if opt.n_channels == 1:
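        # note: skimage.external.tifffile has been removed from newer scikit-image releases;
        # the standalone tifffile package can be used instead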
        from skimage.external.tifffile import imsave, imread
    else:
        from skimage.io import imsave, imread

    opt = set_gpu(opt)

    if opt.use_cuda:
        net = net.to(opt.device)

    if opt.multi_gpu:
        net = nn.DataParallel(net)

    set_test_dir(opt)
    if not os.path.exists(opt.test_result_dir):
        os.makedirs(opt.test_result_dir)

    res_img_dir = os.path.join(opt.test_result_dir, 'result_img_dir')
    if not os.path.exists(res_img_dir):
        os.makedirs(res_img_dir)

    # create results directory
    res_dir = os.path.join(opt.test_result_dir, 'res_dir')
    os.makedirs(res_dir, exist_ok=True)

    print('\ntest_result_dir : ', opt.test_result_dir)
    print('\nresult_img_dir : ', res_img_dir)
    print('\nresult_dir : ', res_dir)

    # total_psnr = 0
    # total_ssim = 0

    loss_criterion = nn.MSELoss()
    total_psnr = 0.0

    if opt.n_channels == 1:
        # load noisy images
        noisy_fn = 'siddplus_test_noisy_raw.mat'
        noisy_key = 'siddplus_test_noisy_raw'
        noisy_mat = loadmat(os.path.join(opt.test_dir, opt.dataset,
                                         noisy_fn))[noisy_key]

        # denoise
        n_im, h, w = noisy_mat.shape
        results = noisy_mat.copy()

        start_time = time.time()
        for i in range(n_im):
            print('\n[*]PROCESSING..{}/{}'.format(i, n_im))

            noisy = np.reshape(noisy_mat[i, :, :], (h, w))
            denoised = denoiser(opt, net, noisy)
            results[i, :, :] = denoised

            result_name = str(i) + '.tiff'
            concat_img = np.concatenate((denoised, noisy), axis=1)
            imsave(os.path.join(res_img_dir, result_name), concat_img)

            denoised = torch.Tensor(denoised)
            noisy = torch.Tensor(noisy)
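            # no clean reference is available at test time, so the PSNR below is computed
            # between the denoised output and its noisy input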

            mse_loss = loss_criterion(denoised, noisy)
            psnr = 10 * math.log10(1 / mse_loss.item())
            total_psnr += psnr
            print('%.5fs .. [%d/%d] psnr : %.5f, avg_psnr : %.5f' %
                  (time.time() - start_time, i, n_im, psnr, total_psnr /
                   (i + 1)))

    else:
        # load noisy images
        noisy_fn = 'siddplus_test_noisy_srgb.mat'
        noisy_key = 'siddplus_test_noisy_srgb'
        noisy_mat = loadmat(os.path.join(opt.test_dir, opt.dataset,
                                         noisy_fn))[noisy_key]

        # denoise
        n_im, h, w, c = noisy_mat.shape
        results = noisy_mat.copy()

        start_time = time.time()
        for i in range(n_im):
            print('\n[*]PROCESSING..{}/{}'.format(i, n_im))

            noisy = np.reshape(noisy_mat[i, :, :, :], (h, w, c))
            denoised = denoiser(opt, net, noisy)
            results[i, :, :, :] = denoised

            result_name = str(i) + '.png'
            concat_img = np.concatenate((denoised, noisy), axis=1)
            imsave(os.path.join(res_img_dir, result_name), concat_img)

            denoised = torch.Tensor(denoised).float() / 255.0
            noisy = torch.Tensor(noisy).float() / 255.0

            mse_loss = loss_criterion(noisy, denoised)
            psnr = 10 * math.log10(1 / mse_loss.item())
            total_psnr += psnr
            print('%.5fs .. [%d/%d] psnr : %.5f, avg_psnr : %.5f' %
                  (time.time() - start_time, i, n_im, psnr, total_psnr /
                   (i + 1)))

    print("****total avg psnr : %.10f", total_psnr / (n_im))
    # save denoised images in a .mat file with dictionary key "results"
    res_fn = os.path.join(res_dir, 'results.mat')
    res_key = 'results'  # Note: do not change this key, the evaluation code will look for this key
    savemat(res_fn, {res_key: results})

    runtime = 0.0  # seconds / megapixel
    cpu_or_gpu = 0  # 0: GPU, 1: CPU
    use_metadata = 0  # 0: no use of metadata, 1: metadata used
    other = '(optional) any additional description or information'

    # prepare and save readme file
    readme_fn = os.path.join(res_dir,
                             'readme.txt')  # Note: do not change 'readme.txt'
    with open(readme_fn, 'w') as readme_file:
        readme_file.write('Runtime (seconds / megapixel): %s\n' % str(runtime))
        readme_file.write('CPU[1] / GPU[0]: %s\n' % str(cpu_or_gpu))
        readme_file.write('Metadata[1] / No Metadata[0]: %s\n' %
                          str(use_metadata))
        readme_file.write('Other description: %s\n' % str(other))

    # compress results directory
    res_zip_fn = 'results_dir'
    shutil.make_archive(os.path.join(opt.test_result_dir, res_zip_fn), 'zip',
                        res_dir)
Example no. 4
def main():
    temp_args = get_arguments()
    assert temp_args.snapshot is not None, 'snapshot must be selected!'
    set_seed()
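    # rebuild the training-time arguments from the saved model_desc.yml below, then
    # override the runtime-specific fields (snapshot, paths, gpus, batch_size)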

    args = argparse.ArgumentParser().parse_args(args=[])
    tmp = yaml.full_load(
        open(
            f'{temp_args.result_path}/'
            f'{temp_args.dataset}/'
            f'{temp_args.stamp}/'
            'model_desc.yml', 'r'))

    for k, v in tmp.items():
        setattr(args, k, v)

    args.snapshot = temp_args.snapshot
    args.src_path = temp_args.src_path
    args.data_path = temp_args.data_path
    args.result_path = temp_args.result_path
    args.gpus = temp_args.gpus
    args.batch_size = 1

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")

    ##########################
    # Dataset
    ##########################
    _, valset = set_dataset(args.dataset, args.classes, args.data_path)
    validation_steps = len(valset)

    logger.info("TOTAL STEPS OF DATASET FOR EVALUATION")
    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {validation_steps}")

    ##########################
    # Model & Generator
    ##########################
    model = set_model(args.backbone, args.dataset, args.classes)
    model.load_weights(args.snapshot)
    logger.info(f"Load weights at {args.snapshot}")

    # note: batch_size, xe_loss, cls_loss, cls_lambda and temperature are not standard
    # tf.keras Model.compile() arguments; this presumably relies on a custom compile override
    model.compile(loss=args.loss,
                  batch_size=args.batch_size,
                  optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
                  metrics=[
                      tf.keras.metrics.TopKCategoricalAccuracy(k=1,
                                                               name='acc1'),
                      tf.keras.metrics.TopKCategoricalAccuracy(k=5,
                                                               name='acc5')
                  ],
                  xe_loss=tf.keras.losses.categorical_crossentropy,
                  cls_loss=tf.keras.losses.KLD,
                  cls_lambda=args.loss_weight,
                  temperature=args.temperature)

    val_generator = DataLoader(loss='crossentropy',
                               mode='val',
                               datalist=valset,
                               dataset=args.dataset,
                               classes=args.classes,
                               batch_size=args.batch_size,
                               shuffle=False).dataloader()

    ##########################
    # Evaluation
    ##########################
    print(
        model.evaluate(val_generator, steps=validation_steps,
                       return_dict=True))
Example no. 5
def main():
    # parse arguments
    cfg = parse_arguments(funcs=[add_arguments])

    # get the name of a model
    arch_name = models.utils.set_arch_name(cfg)

    # set a logger
    logger = utils.Logger(cfg, arch_name)

    # construct a model
    logger.print('Building a model ...')
    model, image_size = models.set_model(cfg)

    # profile the model
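    # (the `profile` helper is assumed to come from a FLOPs counter such as thop)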
    input = torch.randn(1, 3, image_size, image_size)
    macs, params = profile(model, inputs=(input, ), verbose=False)
    logger.print(
        f'Name: {arch_name}    (Params: {int(params)}, FLOPs: {int(macs)})')

    # set other options
    criterion = nn.CrossEntropyLoss()
    optimizer = set_optimizer(model, cfg)
    lr_scheduler = set_lr_scheduler(optimizer, cfg)

    # load dataset
    loaders = datasets.set_dataset(cfg, image_size)

    # set a trainer
    trainer = Trainer(cfg=cfg,
                      model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      lr_scheduler=lr_scheduler,
                      loaders=loaders,
                      logger=logger)

    # set device
    trainer.set_device()

    # run
    if cfg.run_type == 'train':
        # set hooks
        if cfg.load is not None:
            if not cfg.resume:
                trainer.register_hooks(loc='before_train', func=[load_init])
            else:
                trainer.register_hooks(loc='before_train', func=[load_resume])
        if cfg.step_location == 'epoch':
            trainer.register_hooks(loc='after_epoch', func=[step_lr_epoch])
        else:
            trainer.register_hooks(loc='after_batch', func=[step_lr_batch])
        trainer.register_hooks(loc='after_epoch',
                               func=[save_train, summarize_reports])

        trainer.train()

    elif cfg.run_type == 'validate':
        # set hooks
        trainer.register_hooks(loc='before_epoch', func=[load_valid])
        trainer.register_hooks(loc='after_epoch', func=[summarize_reports])

        trainer.validate()

    elif cfg.run_type == 'test':
        # set hooks
        trainer.register_hooks(loc='before_epoch', func=[load_valid])
        trainer.register_hooks(loc='after_epoch', func=[save_pred])

        trainer.test()

    elif cfg.run_type == 'analyze':
        # set hooks
        trainer.register_hooks(loc='before_epoch', func=[load_valid])
        # extract features
        from utils import FeatureExtractor
        extractor = FeatureExtractor()
        trainer.register_hooks(loc='before_epoch', func=[extractor.initialize])
        trainer.register_hooks(loc='after_batch',
                               func=[extractor.check_feature])
        trainer.register_hooks(loc='after_epoch',
                               func=[extractor.save_feature])

        trainer.analyze()
Example no. 6
def main():
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")


    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info(f"{strategy.__class__.__name__} : {num_workers}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
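    # args.steps, when non-zero, overrides the per-epoch step count derived from the dataset size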
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')],
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=num_workers,
            run_eagerly=True)


    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train', 
        datalist=trainset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val', 
        datalist=valset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=False).dataloader()


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,)