コード例 #1
0
def use(args):
    '''Transform images with the net (the transformation itself is missing).

    Sets up the model for the dataset pair in ``args.datasets``, requests the
    'test' split from ``data.load_pair``, restores the saved weights and then
    raises, because nothing after the weight loading is implemented.

    Args:
      args: namespace of arguments. Run 'artRecycle use --help' for info.

    Raises:
      NotImplementedError: always, right after the weights are loaded.
    '''
    input_shape = (300, 300, 3)

    # Paths are derived from the name of the dataset pair
    pair_name = '{}|{}'.format(*args.datasets)
    model_dir, log_path, logs_path = _prepare_directories(pair_name,
                                                          resume=True)
    json_file = os.path.join(model_dir, 'keras.json')
    weights_file = os.path.join(model_dir, 'model')

    # 'test' split of the dataset pair
    test_dataset, test_size = data.load_pair(
        *args.datasets, 'test', shape=input_shape, batch=args.batch)

    # Build the network and restore its weights
    keras_model, model_layer = models.define_model(input_shape)
    keras_model.load_weights(weights_file)
    print('> Weights loaded')

    raise NotImplementedError()
コード例 #2
0
ファイル: main.py プロジェクト: BrainGardenAI/TecoGAN-PyTorch
def test(opt):
    """Run `validate` on all `test*` datasets for every generator checkpoint.

    Args:
        opt: nested options dict. Reads `opt['verbose']`,
            `opt['model']['generator']['load_path_lst']` and the keys of
            `opt['dataset']`; overwrites
            `opt['model']['generator']['load_path']` before each model is
            built.
    """
    logger = base_utils.get_logger('base')
    heavy_rule = '=' * 40
    light_rule = '-' * 40

    # dump the configuration only when asked to
    if opt['verbose']:
        logger.info('{} Configurations {}'.format('=' * 20, '=' * 20))
        base_utils.print_options(opt, logger)

    for ckpt_path in opt['model']['generator']['load_path_lst']:
        # announce the checkpoint by its file name without the extension
        logger.info(heavy_rule)
        logger.info('Testing model: {}'.format(
            osp.splitext(osp.basename(ckpt_path))[0]))
        logger.info(heavy_rule)

        # build the model from this checkpoint
        opt['model']['generator']['load_path'] = ckpt_path
        model = define_model(opt)

        # the index handed to `validate` is cut at the FIRST dot of the file
        # name, so it can differ from the logged name for multi-dot names
        stored_path = opt['model']['generator']['load_path']
        model_idx = osp.basename(stored_path).split('.')[0]

        # only datasets whose key starts with `test` are evaluated
        for dataset_idx in sorted(opt['dataset']):
            if dataset_idx.startswith('test'):
                validate(opt, model, logger, dataset_idx, model_idx,
                         compute_metrics=False)
                logger.info(light_rule)

    logger.info('Finish testing')
    logger.info(heavy_rule)
コード例 #3
0
ファイル: torch2onnx.py プロジェクト: Thmen/EGVSR
def test(opt):
    """Infer all `test*` datasets with every generator checkpoint.

    For each sequence yielded by the test loader the model output is
    optionally written to `<res_dir>/<dataset name>/<model index>/<sequence>`.

    Args:
        opt: nested options dict. Reads `opt['verbose']`,
            `opt['model']['generator']['load_path_lst']`, `opt['dataset']`
            and `opt['test']` (`save_res`, `res_dir`); overwrites
            `opt['model']['generator']['load_path']`.
    """
    logger = base_utils.get_logger('base')
    heavy_rule = '=' * 40

    if opt['verbose']:
        logger.info('{} Configurations {}'.format('=' * 20, '=' * 20))
        base_utils.print_options(opt, logger)

    for ckpt_path in opt['model']['generator']['load_path_lst']:
        # the checkpoint file name without extension identifies the model
        model_idx = osp.splitext(osp.basename(ckpt_path))[0]

        logger.info(heavy_rule)
        logger.info('Testing model: {}'.format(model_idx))
        logger.info(heavy_rule)

        # build the model from this checkpoint
        opt['model']['generator']['load_path'] = ckpt_path
        model = define_model(opt)

        # walk the datasets whose key starts with `test`, in sorted order
        test_keys = (
            key for key in sorted(opt['dataset']) if key.startswith('test'))
        for dataset_idx in test_keys:
            ds_name = opt['dataset'][dataset_idx]['name']
            logger.info('Testing on {}: {}'.format(dataset_idx, ds_name))

            test_loader = create_dataloader(opt, dataset_idx=dataset_idx)

            for batch in test_loader:
                # first element of every field is used
                lr_data = batch['lr'][0]
                seq_idx = batch['seq_idx'][0]
                frm_idx = [entry[0] for entry in batch['frm_idx']]

                hr_seq = model.infer(lr_data)  # thwc|rgb|uint8

                # writing the result to disk is optional
                if not opt['test']['save_res']:
                    continue
                model_dir = osp.join(
                    opt['test']['res_dir'], ds_name, model_idx)
                data_utils.save_sequence(
                    osp.join(model_dir, seq_idx), hr_seq, frm_idx,
                    to_bgr=True)

            logger.info('-' * 40)

    logger.info('Finish testing')
    logger.info(heavy_rule)
コード例 #4
0
def debug(args):
    '''Debugging function.

    Builds the model, prints its summary and hands it to a TensorBoard
    callback that writes to the 'debug' directory, without any training.

    Args:
      args: namespace of arguments. Run --help for info.
    '''
    import matplotlib.pyplot as plt

    print('> Debug')

    # Build the model for a fixed input shape and describe it
    input_shape = (300, 300, 3)
    keras_model, model_layer = models.define_model(input_shape)
    keras_model.summary()

    # Attach the model to a TensorBoard callback with graph writing enabled
    tb_callback = tf.keras.callbacks.TensorBoard('debug', write_graph=True)
    tb_callback.set_model(keras_model)
コード例 #5
0
def train(opt):
    """Train the model and periodically log, checkpoint and evaluate it.

    Logging, checkpointing and evaluation are each triggered when the number
    of iterations done in this run is a multiple of the matching frequency
    (a frequency <= 0 disables the action). Checkpoints and result folders
    are labelled with `start_iter` plus that count.

    Args:
        opt: nested options dict. Reads `opt['train']` (`total_iter`,
            `start_iter`), `opt['test']` (`test_freq`, `save_res`,
            `res_dir`, optional `save_json`, `json_dir`), `opt['logger']`
            (`log_freq`, `ckpt_freq`) and `opt['dataset']`.
    """
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('='*20, '='*20))
    base_utils.print_options(opt, logger)

    # training data, downsampling kernel for BD degradation, and the model
    train_loader = create_dataloader(opt, dataset_idx='train')
    kernel = data_utils.create_kernel(opt)
    model = define_model(opt)

    # schedule
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    start_iter = opt['train']['start_iter']
    steps_done = 0  # iterations performed in this run

    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))

    for epoch in range(total_epoch):
        for batch in train_loader:
            steps_done += 1
            curr_iter = start_iter + steps_done
            if steps_done > total_iter:
                logger.info('Finish training')
                break

            # one optimisation step
            model.update_learning_rate()
            batch = prepare_data(opt, batch, kernel)
            model.train(batch)
            model.update_running_log()

            # periodic console log: learning rates first, then running losses
            if log_freq > 0 and steps_done % log_freq == 0:
                lr_part = ''.join(
                    ' | {}: {:.2e}'.format(lr_type, lr)
                    for lr_type, lr
                    in model.get_current_learning_rate().items())
                loss_part = ', '.join(
                    '{}: {:.3e}'.format(k, v)
                    for k, v in model.get_running_log().items())
                logger.info('[epoch: {} | iter: {}{}] {}'.format(
                    epoch, curr_iter, lr_part, loss_part))

            # periodic checkpoint
            if ckpt_freq > 0 and steps_done % ckpt_freq == 0:
                model.save(curr_iter)

            # periodic evaluation (last action of the iteration)
            if not (test_freq > 0 and steps_done % test_freq == 0):
                continue

            model_idx = 'G_iter{}'.format(curr_iter)

            # datasets whose key starts with `test`, in sorted order
            test_keys = (
                key for key in sorted(opt['dataset'])
                if key.startswith('test'))
            for dataset_idx in test_keys:
                ds_name = opt['dataset'][dataset_idx]['name']
                logger.info(
                    'Testing on {}: {}'.format(dataset_idx, ds_name))

                test_loader = create_dataloader(opt, dataset_idx=dataset_idx)
                metric_calculator = MetricCalculator(opt)

                # infer and score every sequence
                for seq_batch in test_loader:
                    lr_data = seq_batch['lr'][0]
                    seq_idx = seq_batch['seq_idx'][0]
                    frm_idx = [entry[0] for entry in seq_batch['frm_idx']]

                    hr_seq = model.infer(lr_data)  # thwc|rgb|uint8

                    # optionally store the inferred sequence
                    if opt['test']['save_res']:
                        model_dir = osp.join(
                            opt['test']['res_dir'], ds_name, model_idx)
                        data_utils.save_sequence(
                            osp.join(model_dir, seq_idx), hr_seq, frm_idx,
                            to_bgr=True)

                    # metrics against the ground-truth sequence directory
                    gt_dir = osp.join(
                        opt['dataset'][dataset_idx]['gt_seq_dir'], seq_idx)
                    metric_calculator.compute_sequence_metrics(
                        seq_idx, gt_dir, '', pred_seq=hr_seq)

                # write the averages to json, or print them
                if opt['test'].get('save_json'):
                    json_path = osp.join(
                        opt['test']['json_dir'],
                        '{}_avg.json'.format(ds_name))
                    metric_calculator.save_results(
                        model_idx, json_path, override=True)
                else:
                    metric_calculator.display_results()
コード例 #6
0
def test(opt):
    """Infer all `test*` datasets with every generator checkpoint.

    Results are optionally saved to disk. Metric computation is best-effort:
    if the metric calculator cannot be created, or a metric call fails, a
    message is printed and inference continues.

    Args:
        opt: nested options dict. Reads `opt['verbose']`,
            `opt['model']['generator']['load_path_lst']`, `opt['dataset']`
            and `opt['test']` (`save_res`, `res_dir`, optional `save_json`,
            `json_dir`); overwrites
            `opt['model']['generator']['load_path']`.
    """
    # logging
    logger = base_utils.get_logger('base')
    if opt['verbose']:
        logger.info('{} Configurations {}'.format('=' * 20, '=' * 20))
        base_utils.print_options(opt, logger)

    # infer and evaluate performance for each model
    for load_path in opt['model']['generator']['load_path_lst']:
        # setup model index (checkpoint file name without extension)
        model_idx = osp.splitext(osp.split(load_path)[-1])[0]

        # log
        logger.info('=' * 40)
        logger.info('Testing model: {}'.format(model_idx))
        logger.info('=' * 40)

        # create model
        opt['model']['generator']['load_path'] = load_path
        model = define_model(opt)

        # for each test dataset
        for dataset_idx in sorted(opt['dataset'].keys()):
            # use dataset with prefix `test`
            if not dataset_idx.startswith('test'):
                continue

            ds_name = opt['dataset'][dataset_idx]['name']
            logger.info('Testing on {}: {}'.format(dataset_idx, ds_name))

            # define metric calculator (best-effort). It is reset to None on
            # failure so that a calculator left over from a previous dataset
            # is never reused, and only `Exception` is caught so that
            # Ctrl-C / SystemExit still stop the run.
            try:
                metric_calculator = MetricCalculator(opt)
            except Exception:
                metric_calculator = None
                print('No metirc need to compute!')

            # create data loader
            test_loader = create_dataloader(opt, dataset_idx=dataset_idx)

            # infer and store results for each sequence
            for data in test_loader:

                # fetch data
                lr_data = data['lr'][0]
                seq_idx = data['seq_idx'][0]
                frm_idx = [frm_idx[0] for frm_idx in data['frm_idx']]

                # infer
                hr_seq = model.infer(lr_data)  # thwc|rgb|uint8

                # save results (optional)
                if opt['test']['save_res']:
                    res_dir = osp.join(opt['test']['res_dir'], ds_name,
                                       model_idx)
                    res_seq_dir = osp.join(res_dir, seq_idx)
                    data_utils.save_sequence(res_seq_dir,
                                             hr_seq,
                                             frm_idx,
                                             to_bgr=True)

                # compute metrics for the current sequence; skipped entirely
                # (including the ground-truth path lookup) without a
                # calculator
                if metric_calculator is None:
                    continue
                true_seq_dir = osp.join(
                    opt['dataset'][dataset_idx]['gt_seq_dir'], seq_idx)
                try:
                    metric_calculator.compute_sequence_metrics(seq_idx,
                                                               true_seq_dir,
                                                               '',
                                                               pred_seq=hr_seq)
                except Exception:
                    print('No metirc need to compute!')

            # save/print metrics
            if metric_calculator is None:
                print('No metirc need to save!')
            else:
                try:
                    if opt['test'].get('save_json'):
                        # save results to json file
                        json_path = osp.join(opt['test']['json_dir'],
                                             '{}_avg.json'.format(ds_name))
                        metric_calculator.save_results(model_idx,
                                                       json_path,
                                                       override=True)
                    else:
                        # print directly
                        metric_calculator.display_results()
                except Exception:
                    print('No metirc need to save!')

            logger.info('-' * 40)

    # logging
    logger.info('Finish testing')
    logger.info('=' * 40)
コード例 #7
0
ファイル: main.py プロジェクト: skycrapers/TecoGAN-PyTorch
def test(opt):
    """Infer all test datasets with every generator checkpoint.

    Sequences are split across processes: each one handles the indices
    `rank, rank + world_size, ...` reported by `dist_utils.get_dist_info()`.
    When a metric calculator is available, per-sequence metrics are computed,
    gathered, and then either saved to json or displayed.

    Args:
        opt: nested options dict. Reads
            `opt['model']['generator']['load_path_lst']`, `opt['dataset']`
            and `opt['test']` (`save_res`, `res_dir`, optional `save_json`,
            `json_dir`); overwrites
            `opt['model']['generator']['load_path']`.
    """
    base_utils.print_options(opt)
    heavy_rule = '=' * 40

    for ckpt_path in opt['model']['generator']['load_path_lst']:
        # the checkpoint file name without extension identifies the model
        model_idx = osp.splitext(osp.basename(ckpt_path))[0]

        base_utils.log_info(heavy_rule)
        base_utils.log_info(f'Testing model: {model_idx}')
        base_utils.log_info(heavy_rule)

        # build the model from this checkpoint
        opt['model']['generator']['load_path'] = ckpt_path
        model = define_model(opt)

        # datasets whose key contains `test`, in sorted order
        test_keys = (key for key in sorted(opt['dataset']) if 'test' in key)
        for dataset_idx in test_keys:
            ds_name = opt['dataset'][dataset_idx]['name']
            base_utils.log_info(f'Testing on {ds_name} dataset')

            # the dataset behind the loader is indexed directly
            test_loader = create_dataloader(opt, phase='test', idx=dataset_idx)
            test_dataset = test_loader.dataset
            num_seq = len(test_dataset)

            metric_calculator = create_metric_calculator(opt)
            with_metrics = metric_calculator is not None

            rank, world_size = dist_utils.get_dist_info()
            for seq_no in range(rank, num_seq, world_size):
                sample = test_dataset[seq_no]

                # infer one sequence
                model.prepare_inference_data(sample)
                hr_seq = model.infer()

                # optionally store the inferred sequence
                if opt['test']['save_res']:
                    model_dir = osp.join(
                        opt['test']['res_dir'], ds_name, model_idx)
                    data_utils.save_sequence(
                        osp.join(model_dir, sample['seq_idx']),
                        hr_seq,
                        sample['frm_idx'],
                        to_bgr=True)

                # score the sequence against its ground truth
                if with_metrics:
                    gt_seq = sample['gt'].numpy()
                    metric_calculator.compute_sequence_metrics(
                        sample['seq_idx'], gt_seq, hr_seq)

            if with_metrics:
                # collect results for every sequence of the dataset
                seq_idx_lst = [item['seq_idx'] for item in test_dataset]
                metric_calculator.gather(seq_idx_lst)

                # write the averages to json, or print them
                if opt['test'].get('save_json'):
                    json_path = osp.join(opt['test']['json_dir'],
                                         f'{ds_name}_avg.json')
                    metric_calculator.save(model_idx, json_path, override=True)
                else:
                    metric_calculator.display()

            base_utils.log_info('-' * 40)
コード例 #8
0
ファイル: main.py プロジェクト: skycrapers/TecoGAN-PyTorch
def train(opt):
    """Train the model and periodically log, checkpoint and evaluate it.

    Logging, checkpointing and evaluation are each triggered when the global
    iteration (`start_iter` plus the iterations done in this run) is a
    multiple of the matching frequency; a frequency <= 0 disables the action.

    Args:
        opt: nested options dict. Reads `opt['train']` (`total_iter`,
            `start_iter`), `opt['test']` (`test_freq`, `save_res`,
            `res_dir`, optional `save_json`, `json_dir`), `opt['logger']`
            (`log_freq`, `ckpt_freq`), `opt['dist']` and `opt['dataset']`.
    """
    base_utils.log_info(f'{20*"-"} Configurations {20*"-"}')
    base_utils.print_options(opt)

    # training data and model
    train_loader = create_dataloader(opt, phase='train', idx='train')
    model = define_model(opt)

    # schedule
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    start_iter = opt['train']['start_iter']
    steps_done = 0  # iterations performed in this run
    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']

    base_utils.log_info(f'Number of the training samples: {total_sample}')
    base_utils.log_info(
        f'{total_epoch} epochs needed for {total_iter} iterations')

    for epoch in range(total_epoch):
        # the sampler is told the epoch number in distributed runs
        if opt['dist']:
            train_loader.sampler.set_epoch(epoch)

        for batch in train_loader:
            steps_done += 1
            curr_iter = start_iter + steps_done
            if steps_done > total_iter:
                break

            # one optimisation step
            model.prepare_training_data(batch)
            model.train()
            model.update_running_log()
            model.update_learning_rate()

            # periodic console log
            if log_freq > 0 and curr_iter % log_freq == 0:
                base_utils.log_info(model.get_format_msg(epoch, curr_iter))

            # periodic checkpoint
            if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
                model.save(curr_iter)

            # periodic evaluation (last action of the iteration)
            if not (test_freq > 0 and curr_iter % test_freq == 0):
                continue

            model_idx = f'G_iter{curr_iter}'

            # datasets whose key contains `test`, in sorted order
            test_keys = (
                key for key in sorted(opt['dataset']) if 'test' in key)
            for dataset_idx in test_keys:
                ds_name = opt['dataset'][dataset_idx]['name']
                base_utils.log_info(f'Testing on {ds_name} dataset')

                # the dataset behind the loader is indexed directly
                test_loader = create_dataloader(
                    opt, phase='test', idx=dataset_idx)
                test_dataset = test_loader.dataset
                num_seq = len(test_dataset)

                metric_calculator = create_metric_calculator(opt)
                with_metrics = metric_calculator is not None

                # each process handles every world_size-th sequence
                rank, world_size = dist_utils.get_dist_info()
                for seq_no in range(rank, num_seq, world_size):
                    sample = test_dataset[seq_no]

                    # infer one sequence
                    model.prepare_inference_data(sample)
                    hr_seq = model.infer()

                    # optionally store the inferred sequence
                    if opt['test']['save_res']:
                        model_dir = osp.join(
                            opt['test']['res_dir'], ds_name, model_idx)
                        data_utils.save_sequence(
                            osp.join(model_dir, sample['seq_idx']),
                            hr_seq,
                            sample['frm_idx'],
                            to_bgr=True)

                    # score the sequence against its ground truth
                    if with_metrics:
                        gt_seq = sample['gt'].numpy()
                        metric_calculator.compute_sequence_metrics(
                            sample['seq_idx'], gt_seq, hr_seq)

                if not with_metrics:
                    continue

                # collect results for every sequence of the dataset
                seq_idx_lst = [item['seq_idx'] for item in test_dataset]
                metric_calculator.gather(seq_idx_lst)

                # write the averages to json, or print them
                if opt['test'].get('save_json'):
                    json_path = osp.join(opt['test']['json_dir'],
                                         f'{ds_name}_avg.json')
                    metric_calculator.save(
                        model_idx, json_path, override=True)
                else:
                    metric_calculator.display()
コード例 #9
0
        model.load_state_dict(checkpoint['state_dict'])

        # predict the characters
        output = model.forward(tensor)
        _, preds = torch.max(output, 1)
        print(preds.item())
        for p in np.array(preds.cpu()):
            result.append(cat_to_class[model.class_to_idx[p]])
        print(preds)
    return result


if __name__ == '__main__':

    if is_train:
        # Build the network together with its loss, optimizer and LR
        # scheduler.
        model_ft, criterion, optimizer_ft, exp_lr_scheduler = models.define_model(
        )
        # Keep the training result under its own name so that the `models`
        # object used above (to reach `define_model`) is not rebound.
        trained_model = train.train_model(model_ft,
                                          criterion,
                                          optimizer_ft,
                                          exp_lr_scheduler,
                                          num_epochs=15)
        if is_save:
            # Store the weights together with the class-index mapping.
            torch.save(
                {
                    'state_dict': model_ft.state_dict(),
                    'class_to_idx': model_ft.class_to_idx
                }, os.path.join(model_path, 'cnnnet_win.pkl'))
    else:
        test.test()
コード例 #10
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# import numpy as np
# import os
# import pickle
# import tensorflow as tf
# import time

from data import bern_emb_data
from models import define_model
from args import parse_args
from utils import make_dir


# Script entry: everything below runs at import time.

# Parse the run arguments (presumably from the command line — see `args.py`).
args = parse_args()

# Build the data object from the parsed arguments.
d = bern_emb_data(args.cs, args.ns, args.fpath, args.dynamic, args.n_epochs)

# Directory derived from the data object's name (see `utils.make_dir`).
dir_name = make_dir(d.name)


# Construct the model from the arguments, the data and the directory.
m = define_model(args, d, dir_name)

# Prepare the model for training, then train the embeddings.
m.initialize_training()

m.train_embeddings()
コード例 #11
0
def train(args):
    '''Training function.

    Prepares the directories, datasets and model for the dataset pair in
    ``args.datasets``, saves the model description as JSON, exports a
    TensorBoard trace of the model, and runs the training loop with periodic
    validation, weight saving and optional image logging.

    Args:
      args: namespace of arguments. Run 'artRecycle train --help' for info.
    '''
    image_shape = (300, 300, 3)

    # Model name and paths
    model_name = '{}|{}'.format(*args.datasets)
    model_path, log_path, logs_path = _prepare_directories(
        model_name, resume=args.cont)
    model_json = os.path.join(model_path, 'keras.json')
    model_checkpoint = os.path.join(model_path, 'model')

    # Summary writer for the training logs
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(log_path, 'train'))

    # Datasets: the batched pair, plus one small sample set per dataset
    train_dataset, train_size = data.load_pair(
        *args.datasets, 'all', shape=image_shape, batch=args.batch)
    train_dataset_it = iter(train_dataset)
    few_samples = [
        data.load_few(name, 'all', image_shape, 1) for name in args.datasets
    ]

    # Keras model
    keras_model, model_layer = models.define_model(image_shape)

    # Save the model description as indented JSON
    pretty_json = json.dumps(json.loads(keras_model.to_json()), indent=2)
    with open(model_json, 'w') as json_file:
        json_file.write(pretty_json)

    # Save TensorBoard graph by tracing one call on the first batch
    @tf.function
    def tracing_model_ops(inputs):
        return model_layer(inputs)

    tf.summary.trace_on()
    tracing_model_ops(next(train_dataset_it))
    with train_summary_writer.as_default():
        tf.summary.trace_export('Model', step=0)

    # Resuming
    if args.cont:
        keras_model.load_weights(model_checkpoint)
        print('> Weights loaded')

    # Step counters
    step_saver = CountersSaver(log_dir=logs_path, log_every=args.logs)

    # An explicit --epoch_steps wins over the size-derived value
    if args.epoch_steps:
        steps_per_epoch = args.epoch_steps
    else:
        steps_per_epoch = int(train_size / args.batch)
    last_epoch_step = steps_per_epoch - 1
    epochs = range(step_saver.epoch, args.epochs)

    # Training tools
    def make_optimizer():
        return tf.optimizers.Adam(args.rate)

    trainer = models.Trainer(keras_model, make_optimizer, train_dataset_it)
    tester = models.Tester(keras_model, train_dataset_it)
    saver = CheckpointSaver(keras_model, model_checkpoint)

    print('> Training.  Epochs:', epochs)

    for _ in epochs:
        print('> Epoch', step_saver.epoch)

        for epoch_step in range(steps_per_epoch):
            print('> Step', step_saver.step, end='\r')

            # Train step
            trainer.step()

            # Validate every `args.logs` steps and at the end of each epoch
            if (step_saver.step % args.logs == 0
                    or epoch_step == last_epoch_step):
                print('\n> Validation')

                # Evaluation
                for _ in range(args.val_steps):
                    tester.step()
                train_metrics = tester.result()

                # Log in console
                print('  Train metrics:', train_metrics)

                # Log in TensorBoard
                with train_summary_writer.as_default():
                    for metric_name in train_metrics:
                        tf.summary.scalar(metric_name,
                                          train_metrics[metric_name],
                                          step=step_saver.step)

                # Save weights, scored by the negated sum of all metrics
                total_loss = sum(train_metrics.values())
                if saver.save(score=-total_loss):
                    print('Weigths saved')

                # Transform the sample images for visualization
                if args.images:
                    fake_A, fake_B, *_ = keras_model(few_samples)
                    fake_A_viz = image_unnormalize(fake_A)
                    fake_B_viz = image_unnormalize(fake_B)

                    # Log images
                    with train_summary_writer.as_default():
                        tf.summary.image(
                            'fake_A', fake_A_viz, step=step_saver.step)
                        tf.summary.image(
                            'fake_B', fake_B_viz, step=step_saver.step)

            # End step
            step_saver.new_step()

        # End epoch
        step_saver.new_epoch()
コード例 #12
0
                        num_workers=threads)
# Loader over the test subset: the samples are chosen by `test_sampler`,
# so shuffling is left off.
test_loader = DataLoader(dataset,
                         batch_size=batch_size,
                         sampler=test_sampler,
                         shuffle=False,
                         num_workers=threads)

# Report how many samples went to each split.
print('--training samples count:', len(train_indices))
print('--validation samples count:', len(val_indices))
print('--test samples count:', len(test_indices))

print('===> Loading model')

# Build the network from the config entries and convert it with `.float()`.
net = define_model(config['input_nc'],
                   config['output_nc'],
                   config['nfg'],
                   n_blocks=config['layers'],
                   gpu_id=device,
                   args=args).float()

# Multi-GPU wrapping is currently disabled:
# if torch.cuda.device_count() > 1:
#   print("--using", torch.cuda.device_count(), "GPUs")
#   net = nn.DataParallel(net)
# net.to(device)
# net_g.to(device)

# Adam optimizer over all network parameters, configured from `config`,
# plus the scheduler returned by `get_scheduler` for it.
optimizer = optim.Adam(net.parameters(),
                       lr=config['adam_lr'],
                       betas=(config['adam_b1'], config['adam_b2']))
net_scheduler = get_scheduler(optimizer, args)

# L1 loss, moved to the selected device.
criterionL1 = nn.L1Loss().to(device)
コード例 #13
0
ファイル: main.py プロジェクト: BrainGardenAI/TecoGAN-PyTorch
def train(opt):
    """Train the model with periodic logging, checkpointing and validation.

    The iteration counter starts at `opt['train']['start_iter']` and training
    stops once it exceeds `opt['train']['total_iter']`. For BD degradation
    the blur sigma is raised by `sigma_inc` (capped at `sigma_max`) every
    `sigma_freq` epochs; the kernel is then rebuilt and the training
    dataset's crop size is updated.

    Args:
        opt: nested options dict. Reads `opt['train']` (`total_iter`,
            `start_iter`), `opt['test']['test_freq']`, `opt['logger']`
            (`log_freq`, `ckpt_freq`) and `opt['dataset']` (including
            `degradation` with `type`, `sigma` and the optional
            `sigma_freq`, `sigma_inc`, `sigma_max`). Updates
            `opt['dataset']['degradation']['sigma']` in place when sigma is
            scheduled.
    """
    # logging
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('='*20, '='*20))
    base_utils.print_options(opt, logger)

    # create data loader
    train_loader = create_dataloader(opt, dataset_idx='train')

    # create downsampling kernels for BD degradation
    kernel = data_utils.create_kernel(opt)

    # create model
    model = define_model(opt)

    # training configs
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    curr_iter = opt['train']['start_iter']

    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']
    sigma_freq = opt['dataset']['degradation'].get('sigma_freq', 0)
    sigma_inc = opt['dataset']['degradation'].get('sigma_inc', 0)
    sigma_max = opt['dataset']['degradation'].get('sigma_max', 10)

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))
    print('device count:', torch.cuda.device_count())

    # Set once the iteration budget is used up. `total_epoch` is computed
    # from `total_iter` alone, so when `start_iter` > 0 the budget runs out
    # before the last epoch; without this flag every leftover epoch would
    # fetch a batch, log 'Finish training' and bump sigma again.
    finished = False

    # train
    for epoch in range(total_epoch):
        for data in tqdm(train_loader):
            # update iter
            curr_iter += 1
            if curr_iter > total_iter:
                logger.info('Finish training')
                finished = True
                break

            # update learning rate
            model.update_learning_rate()

            # prepare data
            data = prepare_data(opt, data, kernel)

            # train for a mini-batch
            model.train(data)

            # update running log
            model.update_running_log()

            # log
            if log_freq > 0 and curr_iter % log_freq == 0:
                # basic info
                msg = '[epoch: {} | iter: {}'.format(epoch, curr_iter)
                for lr_type, lr in model.get_current_learning_rate().items():
                    msg += ' | {}: {:.2e}'.format(lr_type, lr)
                msg += '] '

                # loss info
                log_dict = model.get_running_log()
                msg += ', '.join([
                    '{}: {:.3e}'.format(k, v) for k, v in log_dict.items()])
                if opt['dataset']['degradation']['type'] == 'BD':
                    msg += ' | Sigma: {}'.format(opt['dataset']['degradation']['sigma'])
                logger.info(msg)

            # save model
            if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
                model.save(curr_iter)

            # evaluate performance
            if test_freq > 0 and curr_iter % test_freq == 0:
                # setup model index (current sigma is appended for BD)
                model_idx = 'G_iter{}'.format(curr_iter)
                if opt['dataset']['degradation']['type'] == 'BD':
                    model_idx = model_idx + str(opt['dataset']['degradation']['sigma'])

                # for each validation set
                for dataset_idx in sorted(opt['dataset'].keys()):
                    # use dataset with prefix `validate`
                    if not dataset_idx.startswith('validate'):
                        continue
                    validate(opt, model, logger, dataset_idx, model_idx)

        # schedule sigma
        if opt['dataset']['degradation']['type'] == 'BD':
            if sigma_freq > 0 and (epoch + 1) % sigma_freq == 0:
                current_sigma = opt['dataset']['degradation']['sigma']
                opt['dataset']['degradation']['sigma'] = min(current_sigma + sigma_inc, sigma_max)
                kernel = data_utils.create_kernel(opt)

                # __getitem__ in custom dataset class uses some crop that depends sigma
                # it is crucial to change this cropsize accordingly if sigma is being changed
                train_loader.dataset.change_cropsize(opt['dataset']['degradation']['sigma'])
                print('kernel changed')

        # leave the epoch loop too once the iteration budget is exhausted
        if finished:
            break