コード例 #1
0
ファイル: train.py プロジェクト: JoeyBallentine/BasicSR
def get_resume_state(opt):
    """Load the saved training state to resume from, if one is configured.

    Args:
        opt: Options mapping. ``opt['path']['resume_state']`` is either a
            path to a state file or a directory holding ``*.state`` files;
            ``opt['gpu_ids']`` decides whether the state is loaded with the
            default device mapping or remapped onto the CPU.

    Returns:
        The object returned by ``torch.load`` for the chosen state file, or
        ``None`` when ``resume_state`` is not set (training from scratch).

    Raises:
        FileNotFoundError: If ``resume_state`` is a directory that contains
            no ``*.state`` files.
    """
    logger = util.get_root_logger()

    resume_target = opt['path']['resume_state']
    if not resume_target:  # training from scratch
        return None

    if os.path.isdir(resume_target):
        candidates = glob.glob(os.path.join(resume_target, '*.state'))
        if not candidates:
            # Without this guard the [-1] below fails with a bare IndexError
            # that does not say which directory was empty.
            raise FileNotFoundError(
                'No *.state files found in resume directory: {}'.format(
                    resume_target))
        # Pick the last file in util.sorted_nicely order
        # (presumably the most recent checkpoint - confirm against util).
        resume_state_path = util.sorted_nicely(candidates)[-1]
    else:
        resume_state_path = resume_target

    if opt['gpu_ids']:
        resume_state = torch.load(resume_state_path)
    else:
        # No GPU ids configured: remap any stored tensors onto the CPU.
        resume_state = torch.load(resume_state_path,
                                  map_location=torch.device('cpu'))

    logger.info('Set [resume_state] to {}'.format(resume_state_path))
    logger.info('Resuming training from epoch: {}, iter: {}.'.format(
        resume_state['epoch'], resume_state['iter']))
    options.check_resume(opt)  # check resume options
    return resume_state
コード例 #2
0
def iterate_get_graphs(dir, graph_data, nodes, fname_to_gid_func,
                       check_connected=False, natts=(), eatts=()):
    """Read the ``*.gexf`` files under ``dir`` and wrap each one as a Graph.

    Args:
        dir: Directory that is searched for ``*.gexf`` files.
        graph_data: Mapping indexed by file stem; the entry is handed to
            ``add_graph_data_to_nxgraph`` together with the loaded graph.
        nodes: Container of accepted graph ids; files whose id is not in it
            are skipped.
        fname_to_gid_func: Callable turning a file stem into a graph id.
        check_connected: When true, a disconnected graph raises instead of
            only being recorded.
        natts: Passed with each node's data dict to
            ``remove_entries_from_dict``.
        eatts: Passed with each edge's data dict to
            ``remove_entries_from_dict``.

    Returns:
        Dict mapping graph id to ``Graph`` for every loaded graph that has at
        least one edge.

    Raises:
        ValueError: If ``check_connected`` is set and a graph is not
            connected, or if no graph was loaded at all.
    """
    loaded = {}
    disconnected_names = []
    edgeless_names = []
    skipped_names = []  # collected only; never reported below
    gexf_paths = sorted_nicely(glob(join(dir, '*.gexf')))
    for path in tqdm(gexf_paths):
        fname = basename(path).split('.')[0]
        gid = fname_to_gid_func(fname)
        if gid not in nodes:
            skipped_names.append(fname)
            continue

        g = nx.read_gexf(path)
        g.graph['gid'] = gid
        if not nx.is_connected(g):
            if check_connected:
                raise ValueError('{} not connected'.format(gid))
            disconnected_names.append(fname)

        # Record the original-label -> integer mapping before relabeling;
        # it relies on convert_node_labels_to_integers(ordering="sorted")
        # numbering the nodes in the same sorted order.
        g.graph['node_label_mapping'] = {
            label: index for index, label in enumerate(sorted(g.nodes()))
        }
        add_graph_data_to_nxgraph(g, graph_data[fname])
        g = nx.convert_node_labels_to_integers(g, ordering="sorted")
        if len(g.edges) == 0:
            edgeless_names.append(fname)
            continue

        # Node ids are integers by now, so a plain sort is numeric. Sorting
        # string ids that stand for integers would instead give
        # ['0', '1', '10', '2', ...] and would need sorted_nicely.
        for i, (n, ndata) in enumerate(sorted(g.nodes(data=True))):
            BiGNNDataset.assert_valid_nid(n, g)
            assert i == n
            remove_entries_from_dict(ndata, natts)
        for n1, n2, edata in sorted(g.edges(data=True)):
            BiGNNDataset.assert_valid_nid(n1, g)
            BiGNNDataset.assert_valid_nid(n2, g)
            remove_entries_from_dict(edata, eatts)
        loaded[gid] = Graph(g)

    print("total graphs with edges: {}\nnon connected graphs: {}"
          .format(len(loaded), len(disconnected_names)))
    print("not connected ids: ", disconnected_names)
    print("num no edges: ", len(edgeless_names),
          "\nno edges ids: ", edgeless_names)
    if not loaded:
        raise ValueError('Loaded 0 graphs from {}\n'
                         'Please download the gexf-formated dataset'
                         ' from Google Drive and extract under:\n{}'.
                         format(dir, get_data_path()))
    return loaded
コード例 #3
0
ファイル: train.py プロジェクト: canhnht/BasicSR
def main():
    """Run training as configured by the JSON option file given via ``-opt``.

    Parses the options, optionally restores a saved training state, sets up
    the loggers, builds the train/val dataloaders and the model, then runs
    the training loop with periodic logging, checkpointing and validation
    (PSNR, SSIM and LPIPS).

    Raises:
        FileNotFoundError: If ``opt['path']['resume_state']`` is a directory
            that contains no ``*.state`` files.
        NotImplementedError: If ``opt['datasets']`` has a phase other than
            ``'train'`` or ``'val'``.
        AssertionError: If no ``'train'`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    # Convert to NoneDict, which return None for missing key.
    opt = option.dict_to_nonedict(opt)
    pytorch_ver = get_pytorch_ver()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            state_files = glob.glob(
                os.path.normpath(opt['path']['resume_state']) + '/*.state')
            if not state_files:
                # Fail with the directory name instead of a bare IndexError.
                raise FileNotFoundError(
                    'No *.state files found in resume directory: {}'.format(
                        opt['path']['resume_state']))
            resume_state_path = util.sorted_nicely(state_files)[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        # rename old folder if exists
        util.mkdir_and_rename(opt['path']['experiments_root'])
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            # for version tensorboardX >= 1.7
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])
        except Exception:
            # for version tensorboardX < 1.6
            # (not a bare except, so Ctrl-C / SystemExit still propagate)
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The attribute is spelled `benchmark`; a misspelled name would only
    # create an unused attribute and leave the cuDNN setting untouched.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Both start as None so a missing phase is detected explicitly instead
    # of surfacing later as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
        # updated schedulers in case JSON configuration has changed
        model.update_schedulers(opt['train'])
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for n, train_data in enumerate(train_loader, start=1):
            current_step += 1
            if current_step > total_iters:
                break

            if pytorch_ver == "pre":  # Order for PyTorch ver < 1.1.0
                # update learning rate
                model.update_learning_rate(current_step - 1)
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
            elif pytorch_ver == "post":  # Order for PyTorch ver > 1.1.0
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
                # update learning rate
                model.update_learning_rate(current_step - 1)
            else:
                print('Error identifying PyTorch version. ', torch.__version__)
                break

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states (changed to save models before validation)
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                model.save(current_step)
                # When this was the last batch of the epoch, store the next
                # epoch index (the bool adds 1).
                model.save_training_state(epoch + (n >= len(train_loader)),
                                          current_step)
                logger.info('Models and training states saved.')

            # validation (skipped when no 'val' dataset is configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0):
                avg_psnr = 0.0
                avg_ssim = 0.0
                avg_lpips = 0.0
                idx = 0
                val_sr_imgs_list = []
                val_gt_imgs_list = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()

                    if opt['datasets']['train'][
                            'znorm']:  # If the image range is [-1,1]
                        sr_img = util.tensor2img(visuals['SR'],
                                                 min_max=(-1, 1))  # uint8
                        gt_img = util.tensor2img(visuals['HR'],
                                                 min_max=(-1, 1))  # uint8
                    else:  # Default: Image range is [0,1]
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR, SSIM and LPIPS distance
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.

                    # For training models with only one channel ndim==2, if RGB ndim==3, etc.
                    if gt_img.ndim == 2:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                    if sr_img.ndim == 2:
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:  # Default: RGB images
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]

                    # Collected so LPIPS is calculated once for all images
                    val_gt_imgs_list.append(cropped_gt_img)
                    val_sr_imgs_list.append(cropped_sr_img)

                    # LPIPS only works for RGB images
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)
                    avg_ssim += util.calculate_ssim(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                # LPIPS is calculated only once for all images
                avg_lpips = lpips.calculate_lpips(val_sr_imgs_list,
                                                  val_gt_imgs_list)

                # log
                logger.info(
                    '# Validation # PSNR: {:.5g}, SSIM: {:.5g}, LPIPS: {:.5g}'.
                    format(avg_psnr, avg_ssim, avg_lpips))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.5g}, ssim: {:.5g}, lpips: {:.5g}'
                    .format(epoch, current_step, avg_psnr, avg_ssim,
                            avg_lpips))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('ssim', avg_ssim, current_step)
                    tb_logger.add_scalar('lpips', avg_lpips, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
コード例 #4
0
def main():
    """Run training as configured by the JSON option file given via ``-opt``.

    Parses the options, optionally restores a saved training state, sets up
    the loggers, builds the train/val dataloaders and the model, then runs
    the training loop. ``virtual_step`` counts every batch fed to the model,
    while ``current_step`` only advances when
    ``virtual_step * batch_size`` is a multiple of ``virtual_batch_size``
    (presumably so the model can accumulate several batches per optimizer
    step - confirm against ``model.optimize_parameters``). Logging,
    learning-rate updates, checkpointing and validation are tied to
    ``current_step``. A ``KeyboardInterrupt`` during training saves the
    current models and training state before returning.

    Raises:
        FileNotFoundError: If ``opt['path']['resume_state']`` is a directory
            that contains no ``*.state`` files.
        NotImplementedError: If ``opt['datasets']`` has a phase other than
            ``'train'`` or ``'val'``.
        AssertionError: If no ``'train'`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            state_files = glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state')
            if not state_files:
                # Fail with the directory name instead of a bare IndexError.
                raise FileNotFoundError(
                    'No *.state files found in resume directory: {}'.format(
                        opt['path']['resume_state']))
            resume_state_path = util.sorted_nicely(state_files)[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])  # for version tensorboardX >= 1.7
        except Exception:
            # for version tensorboardX < 1.6
            # (not a bare except, so Ctrl-C / SystemExit still propagate)
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # train_loader starts as None so a missing 'train' phase trips the assert
    # below instead of raising NameError.
    train_loader = None
    val_loader = False
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            # the virtual batch is never smaller than the real batch
            virtual_batch_size = virtual_batch_size if virtual_batch_size > batch_size else batch_size
            train_size = int(math.ceil(len(train_set) / batch_size))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        virtual_step = current_step * virtual_batch_size / batch_size \
            if virtual_batch_size and virtual_batch_size > batch_size else current_step
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train'])  # updated schedulers in case JSON configuration has changed
        del resume_state
        # start the iteration time when resuming
        t0 = time.time()
    else:
        current_step = 0
        virtual_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    # Bound before the try so the KeyboardInterrupt handler can always build
    # the training state, even if the interrupt lands before the first batch
    # has been fetched.
    epoch = start_epoch
    n = 0
    try:
        for epoch in range(start_epoch, total_epochs*(virtual_batch_size//batch_size)):
            for n, train_data in enumerate(train_loader, start=1):

                if virtual_step == 0:
                    # first iteration start time
                    t0 = time.time()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    t1 = time.time()

                    logs = model.get_current_log()
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, i_time: {:.4f} sec.> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), (t1 - t0))
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save models and training states (changed to save models before validation)
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    # When this was the last batch of the epoch, store the
                    # next epoch index (the bool adds 1).
                    model.save_training_state(epoch + (n >= len(train_loader)), current_step, opt['logger']['overwrite_chkp'])
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_sr_imgs_list = []
                    val_gt_imgs_list = []
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    for val_data in val_loader:
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test(val_data)

                        # Get Visuals
                        visuals = model.get_current_visuals()
                        sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])

                        # Save SR images for reference
                        if opt['train']['overwrite_val_imgs']:
                            save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                        else:
                            save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(
                                img_name, current_step))

                        # save single images or lr / sr comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp(lr_img, sr_img, save_img_path)
                        else:
                            util.save_img(sr_img, save_img_path)

                        # Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        crop_size = opt['scale']
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=crop_size)  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        formatted_res = r['name'].upper()+': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res

                    # [:-2] drops the trailing ', ' separator
                    logger.info('# Validation # '+logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step)+logger_m[:-2])

                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)

                if (current_step % opt['logger']['print_freq'] == 0 and take_step) or \
                        (val_loader and current_step % opt['train']['val_freq'] == 0 and take_step):
                    # reset time for next iteration to skip the validation time from calculation
                    t0 = time.time()

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
コード例 #5
0
ファイル: train_ppon.py プロジェクト: liuguoyou/BasicSR
def main():
    """Train a PPON model from a JSON option file, with periodic validation.

    Steps, in order:
      1. Parse the ``-opt`` command-line argument and load the options.
      2. Either load a resume state (a file, or the last ``*.state`` file of
         a directory as ordered by ``util.sorted_nicely``) or prepare fresh
         experiment directories.
      3. Configure the loggers, the optional tensorboard writer and the
         random seed.
      4. Build the train/val dataloaders and the model.
      5. Run the training loop with periodic logging, checkpointing and
         validation (PSNR of the ``img_c``, ``img_s`` and ``img_p`` outputs
         against ``HR``).
      6. Save the final model as ``'latest'``.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        ValueError: if no ``'train'`` dataset is configured.
        RuntimeError: if ``get_pytorch_ver()`` returns neither ``"pre"`` nor
            ``"post"``.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.
    pytorch_ver = get_pytorch_ver()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            # a directory was given: pick the last state file in "nice" sorted order
            resume_state_path = util.sorted_nicely(
                glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger (evaluated once; the same condition guards every later use)
    use_tb_logger = bool(opt['use_tb_logger'] and 'debug' not in opt['name'])
    tb_logger = None
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])  # for version tensorboardX >= 1.7
        except TypeError:
            # older tensorboardX releases name the keyword `log_dir`; a wrong
            # keyword surfaces as TypeError, so only that is retried
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])  # for version tensorboardX < 1.6

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # was misspelled `benckmark`, which silently set an unused attribute
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError("No 'train' dataset is configured in opt['datasets'].")

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train'])  # updated schedulers in case JSON configuration has changed
    else:
        current_step = 0
        start_epoch = 0

    # Fail fast on an unrecognized version marker. Previously this was checked
    # inside the loop and only broke out of the inner loop, so the script kept
    # cycling through epochs and then saved an untrained "final" model.
    if pytorch_ver not in ('pre', 'post'):
        raise RuntimeError('Error identifying PyTorch version. ' + str(torch.__version__))

    # (visuals key suffix, file name suffix) for each model output; note that
    # the 'img_p' output is written to disk with a '_d' suffix
    val_outputs = (('c', 'c'), ('s', 's'), ('p', 'd'))

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for n, train_data in enumerate(train_loader, start=1):
            current_step += 1
            if current_step > total_iters:
                break

            if pytorch_ver == "pre":  # Order for PyTorch ver < 1.1.0
                # update learning rate
                model.update_learning_rate(current_step - 1)
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
            else:  # "post": Order for PyTorch ver > 1.1.0
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
                # update learning rate
                model.update_learning_rate(current_step - 1)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states (changed to save models before validation)
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                model.save(current_step)
                # if this was the last batch of the epoch, record the next epoch
                model.save_training_state(epoch + (n >= len(train_loader)), current_step)
                logger.info('Models and training states saved.')

            # validation (skipped entirely when no 'val' dataset is configured)
            if val_loader is not None and current_step % opt['train']['val_freq'] == 0:
                psnr_sums = {key: 0.0 for key, _ in val_outputs}
                idx = 0
                crop_size = opt['scale']
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    gt_img = util.tensor2img(visuals['HR'])  # uint8
                    gt_img = gt_img / 255.
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]

                    for key, file_suffix in val_outputs:
                        out_img = util.tensor2img(visuals['img_' + key])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}_{:s}.png'.format(img_name, current_step, file_suffix))
                        util.save_img(out_img, save_img_path)

                        # calculate PSNR on the border-cropped images
                        sr_img = out_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                        psnr_sums[key] += util.calculate_psnr(cropped_sr_img * 255,
                                                              cropped_gt_img * 255)

                if idx == 0:
                    # avoid a ZeroDivisionError when the loader yields nothing
                    logger.warning('Validation loader is empty; skipping PSNR computation.')
                else:
                    avg_psnr_c = psnr_sums['c'] / idx
                    avg_psnr_s = psnr_sums['s'] / idx
                    avg_psnr_p = psnr_sums['p'] / idx

                    # log
                    logger.info('# Validation # PSNR_c: {:.5g}'.format(avg_psnr_c))
                    logger.info('# Validation # PSNR_s: {:.5g}'.format(avg_psnr_s))
                    logger.info('# Validation # PSNR_p: {:.5g}'.format(avg_psnr_p))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format(
                        epoch, current_step, avg_psnr_c, avg_psnr_s, avg_psnr_p))
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar('psnr_c', avg_psnr_c, current_step)
                        tb_logger.add_scalar('psnr_s', avg_psnr_s, current_step)
                        tb_logger.add_scalar('psnr_p', avg_psnr_p, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')