Пример #1
0
def main():
    """Train SRCNN, or evaluate a saved model when ``opt.test`` is set.

    Parses the command line through the module-level ``parser`` and stores
    the result in the global ``opt``. In training mode the model is
    validated after every epoch and checkpointed on every 10th epoch.

    Raises:
        Exception: if ``--cuda`` was requested but no GPU is available.
    """
    global opt
    opt = parser.parse_args()
    opt.gpuids = [int(gpu_id) for gpu_id in opt.gpuids]

    print(opt)

    cuda_missing = opt.cuda and not torch.cuda.is_available()
    if cuda_missing:
        raise Exception("No GPU found, please run without --cuda")
    cudnn.benchmark = True

    scale = opt.upscale_factor
    train_set = get_training_set(scale, opt.add_noise, opt.noise_std)
    validation_set = get_validation_set(scale)
    test_set = get_test_set(scale)

    def make_loader(dataset, batch_size, shuffle):
        # All three loaders share the worker count; only the batch size
        # and shuffling differ.
        return DataLoader(dataset=dataset,
                          num_workers=opt.threads,
                          batch_size=batch_size,
                          shuffle=shuffle)

    training_data_loader = make_loader(train_set, opt.batch_size, True)
    validating_data_loader = make_loader(validation_set,
                                         opt.test_batch_size, False)
    testing_data_loader = make_loader(test_set, opt.test_batch_size, False)

    model = SRCNN()
    criterion = nn.MSELoss()

    if opt.cuda:
        lead_gpu = opt.gpuids[0]
        torch.cuda.set_device(lead_gpu)
        with torch.cuda.device(lead_gpu):
            model = model.cuda()
            criterion = criterion.cuda()
        model = nn.DataParallel(model,
                                device_ids=opt.gpuids,
                                output_device=lead_gpu)

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    if opt.test:
        # Evaluation replaces the freshly built network with the saved one.
        saved_model = torch.load(join("model", opt.model))
        model = nn.DataParallel(saved_model,
                                device_ids=opt.gpuids,
                                output_device=opt.gpuids[0])
        began = time.time()
        test(model, criterion, testing_data_loader)
        elapsed_time = time.time() - began
        # NOTE(review): the throughput figure assumes 100 test images --
        # confirm against the test set size.
        print("===> average {:.2f} image/sec for processing".format(
            100.0 / elapsed_time))
        return

    for epoch in range(1, opt.epochs + 1):
        train(model, criterion, epoch, optimizer, training_data_loader)
        validate(model, criterion, validating_data_loader)
        if epoch % 10:
            continue
        checkpoint(model, epoch)
Пример #2
0
def main(args):
    """Run SRCNN over the test set and write comparison images to ./result.

    For every test image four BMP files are written: the original image,
    the network input, the ground truth and the network prediction.

    Args:
        args: namespace providing ``image_size``, ``c_dim`` and ``scale``.
    """
    srcnn = SRCNN(
        image_size=args.image_size,
        c_dim=args.c_dim,
        is_training=False)
    X_pre_test, X_test, Y_test, color = load_test(scale=args.scale)

    predicted_list = []
    for img in X_test:
        # process() is fed a batch holding one single-channel image.
        predicted = srcnn.process(
            img.reshape(1, img.shape[0], img.shape[1], 1))
        predicted_list.append(
            predicted.reshape(predicted.shape[1], predicted.shape[2], 1))

    dirname = './result'
    # cv2.imwrite reports failure only through its return value, so a
    # missing output directory would silently drop every image.
    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    def to_bgr(first_channel, remaining_channels):
        # Stack along the channel axis and convert YCrCb -> BGR.
        # presumably `remaining_channels` holds the Cr/Cb planes that
        # belong to `first_channel` -- TODO confirm against load_test.
        merged = np.concatenate((first_channel, remaining_channels), axis=2)
        return cv2.cvtColor(merged, cv2.COLOR_YCrCb2BGR)

    for i, predicted in enumerate(predicted_list):
        imgname = 'image{:02}'.format(i)
        cv2.imwrite(os.path.join(dirname, imgname + '_original.bmp'),
                    X_pre_test[i])
        cv2.imwrite(os.path.join(dirname, imgname + '_input.bmp'),
                    to_bgr(X_test[i], color[i]))
        cv2.imwrite(os.path.join(dirname, imgname + '_answer.bmp'),
                    to_bgr(Y_test[i], color[i]))
        float_predicted = (predicted / 255.0).clip(
            min=0., max=1.).astype(np.float64)
        # Min-max stretch to 8 bit; cv2.normalize drops the trailing
        # channel axis, so it is added back before merging.
        normalized_predicted = np.expand_dims(
            cv2.normalize(float_predicted, None, 255, 0, cv2.NORM_MINMAX,
                          cv2.CV_8UC1),
            axis=2)
        print(normalized_predicted.shape)
        cv2.imwrite(os.path.join(dirname, imgname + '_predicted.bmp'),
                    to_bgr(normalized_predicted, color[i]))
Пример #3
0
def main(_):
    """Build an SRCNN inside a fresh TensorFlow session and train it."""
    with tf.Session() as sess:
        model_kwargs = dict(image_dim=FLAGS.image_dim,
                            label_dim=FLAGS.label_dim,
                            channel=FLAGS.channel)
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
Пример #4
0
def main(_):
    """Create a session, construct SRCNN from the flag sizes and train it."""
    session = tf.Session()
    with session as sess:
        network = SRCNN(
            sess,
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            c_dim=FLAGS.c_dim,
        )
        network.train(FLAGS)
Пример #5
0
def main(args):
    """Build a trainable SRCNN from ``args`` and fit it on the training set.

    Args:
        args: namespace providing image_size, c_dim, learning_rate,
            batch_size, epochs, stride and scale.
    """
    hyper_params = {
        'image_size': args.image_size,
        'c_dim': args.c_dim,
        'is_training': True,
        'learning_rate': args.learning_rate,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
    }
    network = SRCNN(**hyper_params)
    inputs, targets = load_train(image_size=args.image_size,
                                 stride=args.stride,
                                 scale=args.scale)
    network.train(inputs, targets)
Пример #6
0
def main(_):
    """Print the flags, create the output directories and train SRCNN."""
    pp.pprint(flags.FLAGS.__flags)
    # Checkpoint, sample and log directories are created on demand.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir, FLAGS.log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    with tf.Session() as sess:
        network = SRCNN(sess, FLAGS)
        network.train()
Пример #7
0
def main(_):
    """Open a TensorFlow session, build SRCNN and start training."""
    with tf.Session() as sess:
        sizes = {
            'image_size': FLAGS.image_size,
            'label_size': FLAGS.label_size,
            'c_dim': FLAGS.c_dim,
        }
        # Construct the network first, then hand the flags to train().
        network = SRCNN(sess, **sizes)
        network.train(FLAGS)
Пример #8
0
    def train_srcnn(self, iteration):
        """Train the SRCNN graph for ``iteration`` passes over the data.

        Each pass walks the label images in batches of three and, for every
        batch, trains once per input scale directory (X2, X3, X4). The mean
        MSE of the pass is printed and a checkpoint is written to
        ``self.save_path`` after every pass. When ``self.pre_trained`` is
        set, weights are restored from ``self.save_path`` before training.

        Args:
            iteration: number of passes over the training set.
        """
        # images = low resolution, labels = high resolution
        sess = self.sess
        # load data
        train_label_list = sorted(glob.glob('./dataset/training/gray/*.*'))
        # The per-scale listings do not change during training, so they are
        # read once here instead of once per batch and scale.
        scales = range(2, 5)
        train_image_lists = [
            sorted(glob.glob('./dataset/training/X{}/*.*'.format(k)))
            for k in scales
        ]

        num_image = len(train_label_list)

        sr_model = SRCNN(channel_length=self.c_length, image=self.x)
        # v1 and v2 are not needed here: a single Adam optimizer minimizes
        # the loss without an explicit var_list.
        v1, v2, prediction = sr_model.build_model()

        with tf.name_scope("mse_loss"):
            loss = tf.reduce_mean(tf.square(self.y - prediction))
        train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

        batch_size = 3
        # Integer division drops a trailing partial batch.
        num_batch = num_image // batch_size
        # Every batch is trained once per scale, hence the extra factor.
        steps_per_pass = num_batch * len(scales)

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(max_to_keep=1)
        if self.pre_trained:
            saver.restore(sess, self.save_path)

        for i in range(iteration):
            total_mse_loss = 0
            for j in range(num_batch):
                start = j * batch_size
                stop = min((j + 1) * batch_size, num_image)
                for train_image_list in train_image_lists:
                    batch_image, batch_label = preprocess.load_data(
                        train_image_list, train_label_list, start, stop,
                        self.patch_size, self.num_patch_per_image)
                    mse_loss, _ = sess.run([loss, train_op],
                                           feed_dict={
                                               self.x: batch_image,
                                               self.y: batch_label
                                           })
                    total_mse_loss += mse_loss / steps_per_pass

            print('In', i + 1, 'epoch, current loss is',
                  '{:.5f}'.format(total_mse_loss))
            saver.save(sess, save_path=self.save_path)

        print('Train completed')
Пример #9
0
def main(_):
    """Create the checkpoint and sample directories, then train SRCNN."""
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    with tf.Session() as sess:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
Пример #10
0
def main(_):
    """Print the configuration, prepare folders, then train or test SRCNN."""
    # 3. print configurations
    print('tf version:', tf.__version__)
    print('tf setup:')
    for flag_name, flag_value in FLAGS.flag_values_dict().items():
        print(flag_name, flag_value)
    # The TensorBoard directory gets the channel count as a suffix.
    FLAGS.TB_dir = FLAGS.TB_dir + ('_' + str(FLAGS.c_dim))

    # 4. check/create folders
    print("check dirs...")
    for directory in (FLAGS.checkpoint_dir, FLAGS.TB_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # 5. begin tf session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        print("building model...")
        # 6. init srcnn model
        srcnn = SRCNN(sess, FLAGS)
        # 7. start to train/test
        if FLAGS.is_train:
            action = srcnn.train
        elif FLAGS.patch_test:
            action = srcnn.test
        else:
            action = srcnn.test_whole_img
        action()
Пример #11
0
def main():
    """Ensure the checkpoint directory exists and train SRCNN from Config."""
    checkpoint_dir = Config.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    with tf.Session() as sess:
        model_kwargs = {
            'image_size': Config.image_size,
            'label_size': Config.label_size,
            'batch_size': Config.batch_size,
            'c_dim': Config.c_dim,
            'checkpoint_dir': Config.checkpoint_dir,
            'scale': Config.scale,
        }
        network = SRCNN(sess, **model_kwargs)
        network.train(Config)
 def main(self):
     """Train a new SRCNN or evaluate saved weights, per ``args.run_type``.

     With ``run_type == 'train'`` a fresh model is built and handed to
     ``self.train``. With ``run_type == 'test'`` the weights stored at
     ``args.pre_model_dir`` are loaded into a fresh model and handed to
     ``self.test``; a failure message is printed when either the weights
     path or ``args.dir_data_test_lr`` does not exist. The model is also
     published as the module-level ``model``.
     """
     global model
     print("SRCNN ==> Data loading .. ")
     loader = data.Data(self.args)
     print("SRCNN ==> Check run type .. ")
     if self.args.run_type == 'train':
         train_data_loader = loader.loader_train
         test_data_loader = loader.loader_test
         print("SRCNN ==> Load model .. ")
         model = SRCNN.SRCNN()
         # Move the model to the GPU before building the optimizer, as the
         # torch.optim documentation recommends.
         if self.args.cuda:
             model.cuda()
         print("SRCNN ==> Setting optimizer .. [ ", self.args.optimizer, " ] , lr [ ", self.args.lr, " ] , Loss [ MSE ]")
         optimizer = optim.Adam(model.parameters(), self.args.lr)
         self.train(model, optimizer, self.args.epochs, train_data_loader, test_data_loader)
     elif self.args.run_type == 'test':
         print("SRCNN ==> Testing .. ")
         if not os.path.exists(self.args.pre_model_dir):
             print("SRCNN ==> Fail [ Pretrain model directory is not exists ]")
         elif not os.path.exists(self.args.dir_data_test_lr):
             print("SRCNN ==> Fail [ Test model is not exists ]")
         else:
             test_data_loader = loader.loader_test
             # The global `model` is only assigned in the train branch, so
             # a test-only run needs its own instance to load weights into.
             model = SRCNN.SRCNN()
             model.load_state_dict(torch.load(self.args.pre_model_dir))
             if self.args.cuda:
                 model.cuda()
             self.test(self.args, test_data_loader, model)
Пример #13
0
def main(_):
    """Train SRCNN, or test it on enlarged inputs, depending on config.train.

    In test mode the network is built with both sizes grown by the
    difference between the configured image and label sizes.
    """
    with tf.Session() as sess:
        if not config.train:
            margin = config.image_size - config.label_size
            network = SRCNN(sess,
                            image_size=config.image_size + margin,
                            label_size=config.image_size,
                            c=config.c)
            test(network, config)
            return

        network = SRCNN(sess,
                        image_size=config.image_size,
                        label_size=config.label_size,
                        c=config.c)
        train(network, config)
Пример #14
0
def train(trainX, trainY):
    """Fit SRCNN indefinitely, saving a model file after every 5 epochs.

    The loop never terminates on its own; each saved file is named after
    the cumulative epoch count (model_5.h5, model_10.h5, ...).
    """
    network = SRCNN()
    network.compile(Adam(lr=0.0003), 'mse')
    rounds_done = 0
    while True:
        network.fit(trainX, trainY, batch_size=128, nb_epoch=5)
        rounds_done += 1
        epoch_tag = str(rounds_done * 5)
        print("Saving model " + epoch_tag)
        network.save(join('./model_' + epoch_tag + '.h5'))
Пример #15
0
def SRCNN2(
    args, image_file
):
    """Apply pretrained SRCNN weights to the luminance channel of an image.

    The image is converted to YCbCr, the Y channel is passed through the
    network and merged with the untouched Cb/Cr channels, and the result is
    converted back to RGB.

    NOTE(review): the original author's note says ``image_file`` should be
    the already-resized (denoised + resized) image -- confirm with callers.

    Args:
        args: namespace providing ``SR_scale``, which selects the weights
            directory ``SRCNN_outputs/x<SR_scale>`` under the current
            working directory.
        image_file: path of the image to process.

    Returns:
        The processed image as a PIL image.

    Raises:
        FileNotFoundError: if ``best.pth`` is missing from the weights
            directory.
        KeyError: if the weights file holds a parameter the model lacks.
    """
    # 'cuda:0' must not contain a space; torch.device rejects 'cuda: 0'.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = SRCNN().to(device)
    state_dict = model.state_dict()

    # Build the path with os.path.join so it resolves on every platform.
    weights_file = os.path.join(os.getcwd(), 'SRCNN_outputs',
                                'x{}'.format(args.SR_scale), 'best.pth')
    if not os.path.isfile(weights_file):
        raise FileNotFoundError(weights_file + ' not exist')

    # Load onto the CPU first; copy_ moves each tensor to the model device.
    loaded = torch.load(weights_file,
                        map_location=lambda storage, loc: storage)
    for n, p in loaded.items():
        if n in state_dict:
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()  # model set in evaluation mode

    image = pil_image.open(image_file).convert('RGB')
    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # Scale the luminance by 1/255 and add batch and channel axes.
    y = ycbcr[..., 0] / 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # tensor -> numpy, dropping the batch and channel axes again
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # np.array stacks the three planes channel-first; the transpose moves
    # the channel axis last. Assumes the model preserves the spatial size
    # of its input -- TODO confirm.
    output = np.array([preds, ycbcr[..., 1],
                       ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(output)
Пример #16
0
def main(_):
    """Print the flags, prepare output directories and train SRCNN."""
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            is_grayscale=FLAGS.is_grayscale,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
Пример #17
0
def test(self,sess):
    """Evaluate the model on checkpoint/test.h5 and save the merged image.

    Loads the checkpoint from ``config.checkpoint_dir`` (reporting success
    or failure), evaluates the model on the test data, merges the result
    using the ``nx`` by ``ny`` layout returned by ``input_up`` and saves it
    as ``sample/MySRCNN.bmp`` under the current working directory.
    """
    nx, ny = input_up(sess)
    print(nx, ny)
    # Join the components separately so the path also resolves where the
    # separator is not a backslash.
    data_dir = os.path.join(os.getcwd(), "checkpoint", "test.h5")
    test_data, test_label = preprocessing.read_data(data_dir)
    if SRCNN.load(self, config.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")
    print("Testing...")
    result = SRCNN.model(self).eval({self.images: test_data,
                                     self.labels: test_label})
    result = merge(result, [nx, ny])
    result = result.squeeze()  # drop dimensions of size 1
    image_path = os.path.join(os.getcwd(), "sample", "MySRCNN.bmp")
    preprocessing.imsave(image_path, result)
Пример #18
0
def main(_):
    """Create the output directories and run SRCNN in an interactive session."""
    required_dirs = (FLAGS.checkpoint_dir, FLAGS.sample_dir,
                     FLAGS.summaries_dir)
    for directory in required_dirs:
        if not os.path.exists(directory):
            os.makedirs(directory)
    sess = tf.InteractiveSession()

    model_kwargs = {
        'imageSize': FLAGS.image_size,
        'labelSize': FLAGS.label_size,
        'batchSize': FLAGS.batch_size,
        'channel': FLAGS.input_channel,
        'checkpoint_dir': FLAGS.checkpoint_dir,
        'sample_dir': FLAGS.sample_dir,
    }
    network = SRCNN(sess, **model_kwargs)
    network.mode()
Пример #19
0
def main(args):
    """Run SRCNN over the test set and write input/answer images to ./static.

    Every iteration writes to the same two file names, so the files left
    on disk belong to the last test image.
    """
    network = SRCNN(image_size=args.image_size,
                    c_dim=args.c_dim,
                    is_training=False)
    X_pre_test, X_test, Y_test = load_test(scale=args.scale)

    predicted_list = []
    for img in X_test:
        height, width = img.shape[0], img.shape[1]
        predicted = network.process(img.reshape(1, height, width, 1))
        out_h, out_w = predicted.shape[1], predicted.shape[2]
        predicted_list.append(predicted.reshape(out_h, out_w, 1))

    dirname = './static'
    imgname = 'image00'
    input_path = os.path.join(dirname, imgname + '_input.bmp')
    answer_path = os.path.join(dirname, imgname + '_answer.bmp')
    for idx, _ in enumerate(predicted_list):
        cv2.imwrite(input_path, X_test[idx])
        cv2.imwrite(answer_path, Y_test[idx])
Пример #20
0
def main(_):
    """Print the flags, make sure output folders exist and train SRCNN."""
    pp.pprint(flags.FLAGS.__flags)

    checkpoint_dir, sample_dir = FLAGS.checkpoint_dir, FLAGS.sample_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    with tf.Session() as sess:
        network = SRCNN(
            sess,
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        network.train(FLAGS)
Пример #21
0
def main(_):
    """Train SRCNN in a session whose GPU memory grows on demand."""
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    with tf.Session(config=session_config) as sess:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
def main(_):
    """Train SRCNN pinned to the first GPU, with soft placement enabled."""
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.device('/gpu:0'), tf.Session(config=session_config) as sess:
        model_kwargs = {
            'image_size': FLAGS.image_size,
            'label_size': FLAGS.label_size,
            'batch_size': FLAGS.batch_size,
            'c_dim': FLAGS.c_dim,
            'checkpoint_dir': FLAGS.checkpoint_dir,
            'sample_dir': FLAGS.sample_dir,
            'config': FLAGS,
        }
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
Пример #23
0
def main():
    """Train SRCNN on the GPU to map blurred pictures back to originals.

    Uses the module-level ``batch_size``, ``lr`` and ``epoch`` settings.
    The mean loss of the last 10 iterations is printed every 10 iterations
    and the model weights are saved every 10 epochs.
    """
    dataloaders = myDataloader()
    train_loader = dataloaders.getTrainLoader(batch_size)

    model = SRCNN().cuda()
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    log_every = 10
    for ep in range(epoch):
        running_loss = 0.0
        for i, (pic, blurPic, _) in enumerate(train_loader):
            pic = pic.cuda()
            blurPic = blurPic.cuda()
            optimizer.zero_grad()
            out = model(blurPic)
            loss = mse_loss(out, pic)
            loss.backward()
            optimizer.step()

            # Accumulate a plain float; adding the tensor itself would keep
            # a GPU tensor with autograd history alive between reports.
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # Average over exactly the iterations in this window.
                print('[%d %d] loss: %.4f' %
                      (ep + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
        if ep % 10 == 9:
            torch.save(model.state_dict(),
                       f="./result/train/" + str(ep + 1) + "srcnnParms.pth")
    print("finish training")
Пример #24
0
def main(_):
    """Train SRCNN when FLAGS.build_model is set, otherwise test one image.

    In test mode the image path is validated first and the process exits
    when the file does not exist. The total run time is printed at the end.
    """
    started_at = time.time()

    building = FLAGS.build_model
    pp.pprint(building)

    if not building:
        FLAGS.test_img = validate(FLAGS.test_img)
        print("Image path = ", FLAGS.test_img)
        image_found = os.path.isfile(FLAGS.test_img)
        if not image_found:
            print("File does not exist ", FLAGS.test_img)
            sys.exit()

    create_required_directories(FLAGS)

    with tf.compat.v1.Session() as sess:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        network = SRCNN(sess, **model_kwargs)

        action = network.train if FLAGS.build_model else network.test
        action(FLAGS)
    print("\n\nTime taken %4.2f\n\n" % (time.time() - started_at))
Пример #25
0
def train(training_data, dev_data, args):
    """Train SRCNN with an RMSE loss and report train/dev loss per epoch.

    Runs on CUDA when available (wrapping the model in ``DataParallel`` when
    several devices are present) and on the CPU otherwise. The model is
    saved to ``args.o`` every 50 epochs, starting with epoch 0.

    Args:
        training_data: dataset yielding ``(y, x)`` pairs for training.
        dev_data: dataset yielding ``(y, x)`` pairs for evaluation.
        args: namespace providing ``lr``, ``ep`` and ``o``.

    Returns:
        The trained model.
    """
    training_gen = data.DataLoader(training_data, batch_size=2)
    dev_gen = data.DataLoader(dev_data, batch_size=2)
    device = torch.device('cuda' if cuda.is_available() else 'cpu')
    print('Initializing model')
    model = SRCNN()
    loss = RMSE()
    if cuda.device_count() > 1:
        print('Using %d CUDA devices' % cuda.device_count())
        model = nn.DataParallel(
            model, device_ids=list(range(cuda.device_count())))
    model.to(device)
    loss.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    def _train(data, opt=True):
        # One pass over `data`; returns the summed loss. The weights are
        # updated only when `opt` is true.
        total = 0
        # Autograd bookkeeping is only needed when a backward pass follows.
        with torch.set_grad_enabled(opt):
            for y, x in data:
                y, x = y.to(device), x.to(device)
                pred_y = model(x)
                l = loss(pred_y, y)
                total += l.item()
                if opt:
                    optimizer.zero_grad()
                    l.backward()
                    optimizer.step()
        # Synchronizing without CUDA raises, which would break the CPU
        # fallback selected above.
        if cuda.is_available():
            cuda.synchronize()
        return total

    print('Training')
    for ep in range(args.ep):
        train_loss = _train(training_gen)
        dev_loss = _train(dev_gen, opt=False)
        print_flush('Epoch %d: Train %.4f Dev %.4f' %
                    (ep, train_loss, dev_loss))
        if ep % 50 == 0:
            save_model(model, args.o)
    return model
Пример #26
0
def upscaling():
    """Upscale a stored picture with SRCNN and record the result.

    Reads a JSON body with ``picname`` and ``times`` (the scale factor,
    sent as a string). Responds with code 400 when an upscaled version
    already exists or the source picture is unknown; otherwise runs SRCNN,
    stores a new ``Picture`` row and responds with code 200 together with
    the new picture's name, id, url and action.
    """
    data = json.loads(request.get_data(as_text=True))
    times = data['times']
    picname = data['picname']
    # Reject the request if this picture was already upscaled by `times`.
    timespic = picname[0:picname.find('.')] + times + "x_"
    picture = Picture.query.filter(Picture.name.like(timespic + "%")).first()
    if picture is not None:
        return jsonify(code=400, message="The picture has been processed")
    # Path of the source picture
    path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname)
    print(path)
    # Name of the upscaled picture
    picture = Picture.query.filter_by(name=picname).first()
    if picture is None:
        # Without this guard an unknown name fails with an AttributeError
        # on `picture.suffix` below.
        return jsonify(code=400, message="The picture does not exist")
    picname = picname[0:picname.find('.')] + times + 'x_.' + picture.suffix
    # URL of the upscaled picture
    url = picture.url
    url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1]
    # Upscale the picture
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)
        srcnn.upscaling(picname, path, FLAGS, int(times))
    # Save to the database
    action = 'Upscale_' + times + 'X'
    newpic = Picture(picname, url, action, picture.id)
    db.session.add(newpic)
    db.session.flush()
    db.session.commit()

    return jsonify(code=200,
                   message="success upscaling",
                   name=picname,
                   id=newpic.id,
                   url=url,
                   action=action)
Пример #27
0
def main(_):
    """Train SRCNN on GPU 0 with the full set of model flags."""
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    session_config = tf.ConfigProto(allow_soft_placement=True)

    with tf.device('/device:GPU:0'), tf.Session(config=session_config) as sess:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            ci_dim=FLAGS.ci_dim,
            co_dim=FLAGS.co_dim,
            scale_factor=FLAGS.scale_factor,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            model_carac=FLAGS.model_carac,
            train_dir=FLAGS.train_dir,
            test_dir=FLAGS.test_dir,
        )
        network = SRCNN(sess, **model_kwargs)
        network.train(FLAGS)
Пример #28
0
def main(mode='srgan'):
    """Train and test SRGAN; additionally construct SRCNN in 'srcnn' mode.

    NOTE(review): the SRCNN instance is built but never trained or tested
    here -- confirm whether that is intended.
    """
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        if mode == 'srcnn':
            srcnn_kwargs = {
                'image_size': FLAGS.image_size,
                'label_size': FLAGS.label_size,
                'batch_size': FLAGS.batch_size,
                'c_dim': FLAGS.c_dim,
                'checkpoint_dir': FLAGS.checkpoint_dir,
                'sample_dir': FLAGS.sample_dir,
            }
            srcnn = SRCNN(sess, **srcnn_kwargs)
        gan = SRGAN(sess)
        gan.train()
        gan.test()
Пример #29
0
def main(_):
    """Build SRCNN from the flags, then train it or run inference."""
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    model_kwargs = dict(
        image_size=FLAGS.image_size,
        label_size=FLAGS.label_size,
        batch_size=FLAGS.batch_size,
        c_dim=FLAGS.c_dim,
        checkpoint_dir=FLAGS.checkpoint_dir,
        sample_dir=FLAGS.sample_dir,
        FLAGS=FLAGS,
    )
    network = SRCNN(**model_kwargs)
    network.call()
    # The explicit comparison with True is kept as written.
    action = network.train if FLAGS.is_train == True else network.inference
    action()
Пример #30
0
def main(_):
    """Print the TensorFlow version, then train or test SRCNN per the flags."""
    # 3. print configurations
    print('tf version:', tf.__version__)
    print('tf setup:')
    # 5. begin tf session
    with tf.Session() as sess:
        # 6. init srcnn model
        network = SRCNN(sess, FLAGS)
        # 7. start to train/test
        action = network.train if FLAGS.is_train else network.test
        action()