# ===== Example 1 =====
def stylize(args):
    """Stylize a single content image with a trained model and save the result.

    Loads the content image (pre-upsampled when using SRCNN), runs it through
    the selected network (SRCNN or TransformerNet), and writes the output as
    a BGR image file.

    Args:
        args: parsed CLI namespace; uses content_image, srcnn, upsample,
            cuda, arch, model and output_image attributes.
    """
    # SRCNN expects a bicubic-upsampled input, hence the scale factor.
    if args.srcnn:
        content_image = utils.tensor_load_rgbimage(args.content_image,
                                                   scale=args.upsample)
    else:
        content_image = utils.tensor_load_rgbimage(args.content_image)
    content_image.unsqueeze_(0)  # add batch dimension
    if args.cuda:
        content_image = content_image.cuda()
    # volatile=True: legacy (pre-0.4 PyTorch) way of disabling autograd
    # for inference, consistent with the rest of this file.
    content_image = Variable(utils.preprocess_batch(content_image),
                             volatile=True)

    if args.srcnn:
        style_model = SRCNN()
    else:
        style_model = TransformerNet(args.arch)
    style_model.load_state_dict(torch.load(args.model))

    if args.cuda:
        style_model.cuda()

    output = style_model(content_image)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
# ===== Example 2 =====
 def __init__(self,input_path,nImgs,nEpochs,batch_size,print_every):
     """Set up SRCNN training state and index the training files.

     Args:
         input_path: directory containing Data/ and Labels/ subfolders of .bmp files.
         nImgs: number of images to use for training.
         nEpochs: number of passes over the training set.
         batch_size: minibatch size.
         print_every: iteration interval between progress reports.
     """
     self.input_path = input_path  # Path where data and labels are stored in folders Data/ and Labels/
     self.nImgs = nImgs
     self.nEpochs = nEpochs
     self.batch_size = batch_size
     self.print_every = print_every
     # Floor division: the iteration count must be an integer (the original
     # Python-2 `/` truncated; `//` preserves that behavior on Python 3).
     self.nIters = self.nImgs * self.nEpochs // self.batch_size
     self.data_files = glob.glob(self.input_path + '/Data/*.bmp')[0:nImgs]
     self.label_files = glob.glob(self.input_path + '/Labels/*.bmp')[0:nImgs]
     # print() call form works on both Python 2 and 3 (the original used the
     # Python-2 print statement, which is a syntax error on Python 3).
     print("{0} files loaded".format(len(self.data_files)))
     # Dataset pixel statistics, precomputed offline.
     self.mean = 113.087154872
     self.stddev = 69.7176496121
     self.model = SRCNN()
     self.imgs = []
     self.target = []
# ===== Example 3 =====
def test_image(config_file, image):
    """Run a trained SRCNN over one image and return the network prediction.

    Scales the image to [0, 1] float32, restores the network parameters from
    `config_file`, compiles a Theano inference function, and evaluates it.
    """
    scaled = np.array(image, dtype='float32') / 255.
    inp = T.tensor4('input', theano.config.floatX)

    model = SRCNN(config_file)
    model.load_params()

    patches = model.inference(inp)
    run_inference = model.compile([inp],
                                  patches,
                                  name='test_srcnn',
                                  allow_input_downcast=True)
    return run_inference(scaled)
def main(argv):
    """Entry point: configure flags, restore the SRCNN model, then train or test."""
    # Parse and display the run configuration.
    FLAGS = setupFlags()
    printer = pp.PrettyPrinter()
    printer.pprint(FLAGS.__flags)

    with tf.Session() as sess:
        # Build the SRCNN model: CNN layers, loss function and related variables.
        srcnn = SRCNN(sess, FLAGS)

        # Restore weights from a saved checkpoint, reporting success/failure.
        print(" [*] Loading model...")
        if srcnn.load_model():
            print(" [*] Model successfully loaded.")
        else:
            print(" [*] Error loading model.")

        # Dispatch to training or evaluation based on the flag.
        if FLAGS.is_train:
            srcnn.train_model()
        else:
            srcnn.test_model()
# ===== Example 5 =====
    # NOTE(review): fragment of a larger CLI/main function — `parser`, `device`,
    # IN_CH, OUT_CH, REFINE_PATHS, setup_global_var and get_data are defined
    # outside this view.
    parser.add_argument('--load', help='load model')
    parser.add_argument('--batch_size', default='4')  # string default — presumably converted to int downstream; verify
    parser.add_argument('--refine', default=None)

    args = parser.parse_args()
    setup_global_var(args)

    # Resume from an explicit checkpoint when --load is given.
    checkpoint_path = None
    if args.load:
        checkpoint_path = args.load

    level = int(args.level)

    # RefineNet setup
    # When refining, load the SRCNN checkpoint of the previous level and
    # target the checkpoint path of the current level.
    if args.refine is not None:
        refineNet = SRCNN(IN_CH, OUT_CH).to(device)
        # NOTE(review): assigning requires_grad on the Module object (rather
        # than on each parameter) does not freeze the weights in PyTorch —
        # presumably intended to disable training of refineNet; confirm.
        refineNet.requires_grad = False
        refineCheckpoint = torch.load(REFINE_PATHS[level - 1])
        refineNet.load_state_dict(refineCheckpoint['state_dict'])
        checkpoint_path = REFINE_PATHS[level]

    # Setup dataflow
    # First 5 samples are held out as the test split; the rest train.
    X_data, y_data = get_data()
    X_train, y_train = X_data[5:], y_data[5:]
    X_test, y_test = X_data[:5], y_data[:5]

    print("Train-set size: ", len(X_train))
    print("Test-set size: ", len(X_test))

    if args.refine is not None:
        print("Use refine-dataset")
# ===== Example 6 =====
# Data-preparation pipeline (disabled): build LR/HR image pairs, align them,
# tile into patches, and split off a test set.
# sample_images(Raw_folder, LR_folder,0.5)
# sample_images(Raw_folder, HR_folder,1)
# align_images(LR_folder, HR_folder, LR_folder)
# sample_images(LR_folder, Inputs_folder_train,1)
# sample_images(HR_folder, Labels_folder_train,1)

# # Then cut the images into patches
# size=64
# cut_images(Inputs_folder_train, size, size//1)
# cut_images(Labels_folder_train, size, size//1)

# # Randomly move a fraction of the files into the test set
# random_move(Inputs_folder_train,Labels_folder_train,Inputs_folder_test,Labels_folder_test,0.1)

# Training hyperparameters
net=SRCNN()
lr, num_epochs = 0.03, 400
batch_size = 32
my_optim=torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=0.0001)
# Step-decay learning-rate schedule: multiply lr by 0.1 every 40 epochs
scheduler = torch.optim.lr_scheduler.StepLR(my_optim,step_size=40,gamma = 0.1)
loss = torch.nn.MSELoss()

# Load the datasets (paired input/label image folders; `_y` suffix presumably
# means the luminance channel — verify against ImagePairDataset_y)
train_dataset=ImagePairDataset_y(Inputs_folder_train,Labels_folder_train)
test_dataset=ImagePairDataset_y(Inputs_folder_test,Labels_folder_test)
train_iter = DataLoader(train_dataset, batch_size, shuffle=True)
test_iter = DataLoader(test_dataset, 1, shuffle=True)
print('Datasets loaded!')

# Training
# ===== Example 7 =====
def train(args):
    """Train a super-resolution network (SRCNN or TransformerNet) with pixel MSE.

    Builds paired low-resolution (LR) and high-resolution (HR) data loaders
    over the same image folder, optimizes the model with Adam on the per-pixel
    MSE between the network output and the HR target, logs aggregate losses,
    and finally saves the trained weights under args.save_model_dir.

    Args:
        args: parsed CLI namespace; uses seed, cuda, srcnn, image_size,
            upsample, dataset, batch_size, arch, lr, pix_weight,
            content_weight, epochs, log_interval and save_model_dir.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs1 = {'num_workers': 0, 'pin_memory': False}
        kwargs2 = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs1 = {}
        kwargs2 = {}
    """ 
    ### for SR the transdataset should be be HR and LR. 
    transform = transforms.Compose([transforms.Scale(args.image_size), transforms.CenterCrop(args.image_size), transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size = args.batch_size, **kwargs)
    """

    # Build paired transforms. The LR pipeline degrades the image with
    # iu.AddMyGauss(); for SRCNN the LR input keeps the full image_size
    # (SRCNN works on pre-upsampled input), otherwise it is downscaled by
    # args.upsample. Pixel values are scaled to [0, 255].
    transform_LR = None
    transform_HR = None
    if args.srcnn:
        transform_HR = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size),
                              interpolation=3),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        transform_LR = transforms.Compose([
            iu.AddMyGauss(),
            transforms.Resize((int(args.image_size), int(args.image_size)),
                              interpolation=3),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
    else:
        transform_HR = transforms.Compose([
            transforms.CenterCrop((args.image_size, args.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        transform_LR = transforms.Compose([
            iu.AddMyGauss(),
            transforms.CenterCrop((args.image_size, args.image_size)),
            transforms.Resize((int(args.image_size / args.upsample),
                               int(args.image_size / args.upsample)),
                              interpolation=3),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
    #transform_LR = transforms.Compose([iu.AddMyGauss(),transforms.Resize((int(args.image_size/args.upsample),int(args.image_size/args.upsample)),interpolation=3), transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])

    #transform_HR = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    # Two loaders over the same folder; neither passes shuffle=True, so the
    # default (no shuffling) keeps zip() yielding matching LR/HR pairs below.
    train_dataset_HR = datasets.ImageFolder(args.dataset, transform_HR)
    train_loader_HR = DataLoader(train_dataset_HR,
                                 batch_size=args.batch_size,
                                 **kwargs1)
    #print(args.image_size/args.upsample)
    train_dataset_LR = datasets.ImageFolder(args.dataset, transform_LR)
    train_loader_LR = DataLoader(train_dataset_LR,
                                 batch_size=args.batch_size,
                                 **kwargs2)

    transformer = None
    if args.srcnn:
        transformer = SRCNN()
    else:
        transformer = TransformerNet(args.arch)
    #transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), lr=args.lr)
    mse_loss = torch.nn.MSELoss()
    """
    vggmodel = torchvision.models.vgg.vgg16(pretrained=True)
    if args.cuda:
        vggmodel=vggmodel.cuda()
    vgg = LossNetwork(vggmodel)
    vgg.eval()
    del vggmodel
    """
    if args.cuda:
        transformer.cuda()
        #vgg.cuda()

    #style = utils.tensor_load_rgbimage(args.style_image, size = args.style_size)
    #style = style.repeat(args.batch_size, 1, 1, 1)
    #style = utils.preprocess_batch(style)

    #if args.cuda:
    #    style=style.cuda()
    #style_v = utils.subtract_imagenet_mean_batch(Variable(style, volatile = True))
    #if args.cuda:
    #    style_v=style_v.cuda()
    #features_style = vgg(style_v)
    #gram_style = [utils.gram_matrix(y) for y in features_style]
    #log_msg = "pix_weight = "+args.pix_weight+"   content_weight = "+args.content_weight

    for e in range(args.epochs):
        log_msg = "pix_weight = " + str(
            args.pix_weight) + "   content_weight = " + str(
                args.content_weight)
        print(log_msg)
        transformer.train()
        agg_content_loss = 0
        agg_style_loss = 0
        count = 0
        # x is the LR input batch, style is the corresponding HR target batch;
        # the x_/y_ labels from ImageFolder are unused.
        for batch_id, ((x, x_), (style, y_)) in enumerate(
                zip(train_loader_LR, train_loader_HR)):
            #(y,y_) = train_loader_HR[batch_id]

            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            #style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
            #style = style.repeat(args.batch_size, 1, 1, 1)
            #gram_style = [utils.gram_matrix(y) for y in features_style]

            x = utils.preprocess_batch(x)
            style = utils.preprocess_batch(style)

            # Variable wrappers: legacy pre-0.4 PyTorch autograd API.
            x = Variable(x)

            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            yy = Variable(style)

            if args.cuda:
                yy = yy.cuda()

            # Only the weighted per-pixel MSE is optimized; the perceptual
            # (VGG) and style (Gram) terms below are disabled.
            pix_loss = args.pix_weight * mse_loss(y, yy)
            """
            xc = Variable(style.clone())
            if(args.cuda):
                xc = xc.cuda()
            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data)

            content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)
            """
            #style_loss = 0;
            """
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad = False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
            """

            total_loss = pix_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += pix_loss.data[0]  # legacy scalar access (pre-0.4 .item())
            #agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(\
                time.ctime(), e + 1, count, len(train_dataset_LR),\
                agg_content_loss / (batch_id + 1),\
                agg_style_loss / (batch_id + 1),\
                (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # Save the trained weights on CPU with a timestamped, self-describing name.
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_SRCNN_" + str(args.srcnn) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
# ===== Example 8 =====
def train_srcnn(config_file, **kwargs):
    """Train an SRCNN (Theano) with periodic validation and best-PSNR saving.

    Compiles train/test functions over shared-variable minibatch placeholders,
    iterates for net.n_epochs epochs, validates every net.validation_frequency
    iterations on cost/PSNR/MS-SSIM, saves parameters whenever validation PSNR
    improves, and optionally plots the training/validation curves.

    Args:
        config_file: path to the SRCNN configuration used to build the network.
        **kwargs: accepted but unused in this body.
    """
    net = SRCNN(config_file)

    X = T.tensor4('input', theano.config.floatX)
    Y = T.tensor4('gt', theano.config.floatX)

    # NOTE(review): mean_X is computed but never used, and X_centered is just
    # X — input centering appears to be disabled here.
    mean_X = T.mean(X)
    X_centered = X
    # Shared variables holding the current minibatch; the compiled functions
    # read them via `givens` instead of taking explicit inputs.
    placeholder_x = theano.shared(np.zeros((net.batch_size,) + net.input_shape[1:], 'float32'), 'patch_placeholder')
    placeholder_y = theano.shared(np.zeros((net.batch_size, net.output_shape[2], net.output_shape[0], net.output_shape[1]),
                                           'float32'), 'img_placeholder')

    output = net.inference(X_centered)
    up_img = output
    # cropped = 9 // 2 + 5 // 2
    cost = net.build_cost(up_img, Y, **{'params': net.regularizable})
    # cost = net.build_cost(up_img, Y[:, :, cropped:-cropped, cropped:-cropped], **{'params': net.regularizable})
    updates = net.build_updates(cost, net.trainable)
    train_network = net.compile([], cost, updates=updates, givens={X: placeholder_x, Y: placeholder_y}, name='train_srcnn')

    # Validation metrics: PSNR on RGB, MS-SSIM on the grayscale conversion.
    psnr_loss = psnr(up_img, Y)
    # psnr_loss = psnr(up_img, Y[:, :, cropped:-cropped, cropped:-cropped])
    msssim_loss = MS_SSIM(rgb2gray(up_img), rgb2gray(Y))
    # msssim_loss = MS_SSIM(up_img, Y[:, :, cropped:-cropped, cropped:-cropped])
    test_network = net.compile([], [cost, psnr_loss, msssim_loss], givens={X: placeholder_x, Y: placeholder_y}, name='test_srcnn')

    epoch = 0
    # NOTE(review): vote_to_terminate is incremented on non-improving
    # validations but never checked — early stopping is effectively disabled.
    vote_to_terminate = 0
    best_psnr = 0.
    best_epoch = 0
    if net.display_cost:
        training_cost_to_plot = []
        validation_cost_to_plot = []

    data_manager = DataManager4(net.batch_size, (placeholder_x, placeholder_y), True, False)
    num_training_batches = data_manager.train_data_shape[0] // net.batch_size
    num_validation_batches = data_manager.test_data_shape[0] // net.validation_batch_size

    print('Training... %d training batches, %d developing batches' % (num_training_batches, num_validation_batches))
    start_training_time = time.time()
    while epoch < net.n_epochs:
        epoch += 1
        training_cost = 0.
        start_epoch_time = time.time()
        batches = data_manager.get_batches(epoch=epoch, num_epochs=net.n_epochs)
        idx = 0
        for b in batches:
            # Global iteration index across epochs, used for the validation cadence.
            iteration = (epoch - 1.) * num_training_batches + idx + 1
            data_manager.update_input(b)
            training_cost += train_network()
            if np.isnan(training_cost):
                raise ValueError('Training failed due to NaN cost')

            if iteration % net.validation_frequency == 0:
                # Full pass over the validation set; average cost/PSNR/MS-SSIM.
                batch_valid = data_manager.get_batches(stage='test')
                validation_cost = 0.
                validation_psnr = 0.
                validation_msssim = 0.
                for b_valid in batch_valid:
                    data_manager.update_input(b_valid)
                    c, p, s = test_network()
                    validation_cost += c
                    validation_psnr += p
                    validation_msssim += s
                validation_cost /= num_validation_batches
                validation_psnr /= num_validation_batches
                validation_msssim /= num_validation_batches
                print('\tvalidation cost: %.4f' % validation_cost)
                print('\tvalidation PSNR: %.4f' % validation_psnr)
                print('\tvalidation MSSSIM: %.4f' % validation_msssim)
                # Checkpoint on best validation PSNR (not cost).
                if validation_psnr > best_psnr:
                    best_epoch = epoch
                    best_psnr = validation_psnr
                    vote_to_terminate = 0
                    print('\tbest validation PSNR: %.4f' % best_psnr)
                    if net.extract_params:
                        net.save_params()
                else:
                    vote_to_terminate += 1

                if net.display_cost:
                    # Live plot of running training cost vs. validation cost.
                    training_cost_to_plot.append(training_cost / (idx + 1))
                    validation_cost_to_plot.append(validation_cost)
                    plt.clf()
                    plt.plot(training_cost_to_plot)
                    plt.plot(validation_cost_to_plot)
                    plt.show(block=False)
                    plt.pause(1e-5)
            idx += 1
        training_cost /= num_training_batches
        print('\tepoch %d took %.2f mins' % (epoch, (time.time() - start_epoch_time) / 60.))
        print('\ttraining cost: %.4f' % training_cost)
    if net.display_cost:
        plt.savefig('%s/training_curve.png' % net.save_path)
    print('Best validation PSNR: %.4f' % best_psnr)
    print('Training took %.2f hours' % ((time.time() - start_training_time) / 3600))