예제 #1
0
def main(_):
    """Build the ESPCN model from ``FLAGS`` and run it inside a TF session.

    Train modes 3 (bicubic), 5 and 6 (multi-directory testing for modes 2
    and 1) have no training phase. Requesting training for one of them
    prints an error message and returns before any session is created.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).
    """
    session_config = tf.ConfigProto()
    # Grow GPU memory on demand rather than reserving it all up front.
    session_config.gpu_options.allow_growth = True

    # Error text keyed by the train modes that cannot be trained.
    untrainable_modes = {
        3: 'Error: Bicubic Mode does not require training',
        5: 'Error: Multi-Dir testing mode for Mode 2 does not require training',
        6: 'Error: Multi-Dir testing mode for Mode 1 does not require training',
    }
    refusal = untrainable_modes.get(FLAGS.train_mode)
    if refusal is not None and FLAGS.is_train:
        print(refusal)
        return

    with tf.Session(config=session_config) as sess:
        espcn = ESPCN(
            sess,
            image_size=FLAGS.image_size,
            is_train=FLAGS.is_train,
            train_mode=FLAGS.train_mode,
            scale=FLAGS.scale,
            c_dim=FLAGS.c_dim,
            batch_size=FLAGS.batch_size,
            load_existing_data=FLAGS.load_existing_data,
            config=session_config,
        )
        # Called for every mode; presumably ESPCN.train dispatches on
        # FLAGS.is_train / FLAGS.train_mode internally (not visible here).
        espcn.train(FLAGS)
예제 #2
0
def main(_):
    """Create an ESPCN model configured from ``FLAGS`` and call its ``train``.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).
    """
    with tf.Session() as sess:
        # Constructor options, all read from FLAGS.
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            is_train=FLAGS.is_train,
            scale=FLAGS.scale,
            c_dim=FLAGS.c_dim,
            batch_size=FLAGS.batch_size,
            test_img=FLAGS.test_img,
            test_path=FLAGS.test_path,
        )
        espcn = ESPCN(sess, **model_kwargs)

        # train() is invoked even when FLAGS.is_train is false; presumably
        # ESPCN switches to testing internally (not visible here).
        espcn.train(FLAGS)
예제 #3
0
def main(_):
    """Run ESPCN in a ``tf.compat.v1`` session configured from ``FLAGS``.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).
    """
    # log_device_placement=True makes TensorFlow log the device each op is
    # placed on. NOTE(review): very verbose; likely a debugging leftover.
    session_config = tf.compat.v1.ConfigProto(log_device_placement=True)
    with tf.compat.v1.Session(config=session_config) as sess:
        model_kwargs = {
            'image_size': FLAGS.image_size,
            'is_train': FLAGS.is_train,
            'scale': FLAGS.scale,
            'c_dim': FLAGS.c_dim,
            'batch_size': FLAGS.batch_size,
            'test_img': FLAGS.test_img,
        }
        espcn = ESPCN(sess, **model_kwargs)

        espcn.train(FLAGS)
예제 #4
0
파일: main.py 프로젝트: zz-Jade/ESPCN
import tensorflow as tf
from model import ESPCN
import config

if __name__ == '__main__':
    with tf.Session() as sess:
        # Model hyper-parameters come from the project's `config` module.
        model_kwargs = {
            name: getattr(config, name)
            for name in ('image_size', 'scale', 'c_dim', 'batch_size')
        }
        espcn = ESPCN(sess, **model_kwargs)
        espcn.train()
예제 #5
0
    # Training data: Train dataset built from args.training_set with the
    # configured upscale factor and patch size, reshuffled by the loader.
    train_set = Train(args.training_set, scale=args.scale, patch_size=args.patch_size)
    trainloader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers, pin_memory=True)

    # Validation data, one sample per batch.
    # NOTE(review): unlike Train, Validation is not given args.scale; confirm
    # it does not need it. shuffle=True is also unusual for validation and
    # makes per-sample ordering non-reproducible.
    val_set = Validation(args.val_set)
    valloader = DataLoader(val_set, batch_size=1,
                            shuffle=True, num_workers=args.num_workers, pin_memory=True)
    # Bookkeeping for best-model selection and for loss/PSNR history.
    best_epoch = 0
    best_PSNR = 0.0
    loss_plot = []
    psnr_plot = []
    # Snapshot of the current weights; presumably replaced when validation
    # PSNR improves (that code is outside this view).
    best_weights = copy.deepcopy(net.state_dict())


    for epoch in range(args.epoch):
        # Put the network in training mode (presumably a torch nn.Module).
        net.train()
        # Running average of the training loss for this epoch.
        epoch_loss = AverageMeter()

        for data in trainloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = net(inputs)

            # Tile dimension 1 three times. Assumes NCHW single-channel
            # tensors that feat_ext expects as 3-channel input; TODO confirm.
            preds_3d = preds.repeat(1, 3, 1, 1)
            labels_3d = labels.repeat(1, 3, 1, 1)

            # `k` is defined outside this view; presumably the weight of a
            # feature-space loss term, with 0 disabling it.
            if k != 0:
                feature_preds = feat_ext(preds_3d)
                feature_labels = feat_ext(labels_3d)
예제 #6
0
        start_iter = ckpt['iter']
        best_val_psnr = ckpt['best_val_psnr']

    print('===> Start training')
    sys.stdout.flush()

    still_training = True
    i = start_iter
    val_loss_meter = AverageMeter()
    val_psnr_meter = AverageMeter()

    while i <= config['training']['iterations'] and still_training:
        for batch in train_dataloader:
            i += 1
            scheduler.step()
            model.train()

            input, target = batch[0].to(device), batch[1].to(device)

            output = model(input)
            optimizer.zero_grad()

            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            if i % config['training']['print_interval'] == 0:
                format_str = 'Iter [{:d}/{:d}] Loss: {:.6f}'
                print_str = format_str.format(i,
                                              config['training']['iterations'],