Example #1
def main():
    conf = get_config()
    ctx = get_extension_context(conf.nnabla_context.context,
                                device_id=conf.nnabla_context.device_id)
    nn.set_default_context(ctx)
    # sort the reference images in order
    refs = sorted(os.listdir(conf.data.ref_path))
    # run inference on the input frames once for each reference image
    for ref in refs:
        colorize_video(conf, ref)
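
The get_config() helper imported from args appears in every example but is not shown. Below is only a minimal sketch of such a helper, assuming a YAML config file and attribute-style access (the real args.get_config may also parse command-line options); the file name config.yaml is a placeholder.

import yaml
from types import SimpleNamespace


def _to_namespace(obj):
    # recursively wrap dicts so values can be read as attributes,
    # e.g. conf.nnabla_context.device_id
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: _to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [_to_namespace(v) for v in obj]
    return obj


def get_config(path="config.yaml"):
    # hypothetical loader: read the YAML file and convert it to nested namespaces
    with open(path) as f:
        return _to_namespace(yaml.safe_load(f))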
Example #2
def data_iterator_sr(num_examples,
                     batch_size,
                     gt_image,
                     lq_image,
                     train,
                     shuffle,
                     rng=None):
    from args import get_config
    conf = get_config()

    def dataset_load_func(i):
        # get images from the list
        scale = conf.train.scale
        gt_size = conf.train.gt_size
        gt_img = read_image(gt_image[i])
        lq_img = read_image(lq_image[i])
        if not train:
            gt_img = modcrop(gt_img, scale)
        gt_img = channel_convert(gt_img.shape[2], gt_img, color="RGB")
        if train:
            # randomly crop
            H, W, C = lq_img.shape
            lq_size = gt_size // scale
            rnd_h = random.randint(0, max(0, H - lq_size))
            rnd_w = random.randint(0, max(0, W - lq_size))
            lq_img = lq_img[rnd_h:rnd_h + lq_size, rnd_w:rnd_w + lq_size, :]
            rnd_h_gt, rnd_w_gt = int(rnd_h * scale), int(rnd_w * scale)
            gt_img = gt_img[rnd_h_gt:rnd_h_gt + gt_size,
                            rnd_w_gt:rnd_w_gt + gt_size, :]
            # horizontal and vertical flipping and rotation
            hflip, rot = [True, True]
            hflip = hflip and random.random() < 0.5
            vflip = rot and random.random() < 0.5
            rot90 = rot and random.random() < 0.5
            lq_img = augment(lq_img, hflip, rot90, vflip)
            gt_img = augment(gt_img, hflip, rot90, vflip)
            lq_img = channel_convert(C, [lq_img], color="RGB")[0]
        # BGR to RGB and HWC to CHW
        if gt_img.shape[2] == 3:
            gt_img = gt_img[:, :, [2, 1, 0]]
            lq_img = lq_img[:, :, [2, 1, 0]]

        gt_img = np.ascontiguousarray(np.transpose(gt_img, (2, 0, 1)))
        lq_img = np.ascontiguousarray(np.transpose(lq_img, (2, 0, 1)))
        return gt_img, lq_img

    return data_iterator_simple(dataset_load_func,
                                num_examples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False,
                                with_memory_cache=False)
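
The augment helper called in the training branch above is defined elsewhere in the project. The sketch below is an assumption about its behavior, matching the (img, hflip, rot90, vflip) call signature used here with the usual flip/rotate augmentation for super-resolution data.

def augment(img, hflip, rot90, vflip):
    # img is an HWC numpy array; the boolean flags are decided by the caller
    if hflip:
        img = img[:, ::-1, :]          # horizontal flip
    if vflip:
        img = img[::-1, :, :]          # vertical flip
    if rot90:
        img = img.transpose(1, 0, 2)   # rotate by swapping the H and W axes
    return img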
Example #3
    def __init__(self):
        conf = get_config()
        self.h5_file = conf.train.vgg_pre_trained_weights
        with nn.parameter_scope("vgg19"):
            print('loading vgg19 parameters')
            nn.load_parameters(self.h5_file)
            # drop all the affine layers for finetuning.
            drop_layers = [
                'classifier/0/affine', 'classifier/3/affine',
                'classifier/6/affine'
            ]
            for layers in drop_layers:
                nn.parameter.pop_parameter((layers + '/W'))
                nn.parameter.pop_parameter((layers + '/b'))
            self.mean = nn.Variable.from_numpy_array(
                np.asarray([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
            self.std = nn.Variable.from_numpy_array(
                np.asarray([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))
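
A small sanity check one might run after constructing the class above, to confirm that only the convolutional VGG19 parameters remain once the classifier affine layers have been popped (assumes the h5 file has already been loaded as in __init__):

import nnabla as nn

# hypothetical check: list what is left under the "vgg19" scope after popping
with nn.parameter_scope("vgg19"):
    remaining = nn.get_parameters()
print(len(remaining), "parameters left under the vgg19 scope")
assert not any("classifier" in name for name in remaining)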
Example #4
def val_save(val_gt, val_lq, val_lq_path, idx, epoch, avg_psnr):
    conf = get_config()
    sr_img = rrdb_net(val_lq, 64, 23)
    real_image = array_to_image(val_gt.data)
    sr_image = array_to_image(sr_img.data)
    img_name = os.path.splitext(os.path.basename(val_lq_path[idx]))[0]
    img_dir = os.path.join(conf.val.save_results, "results", img_name)
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)
    save_img_path = os.path.join(img_dir,
                                 '{:s}_{:d}.png'.format(img_name, epoch))
    cv2.imwrite(save_img_path, sr_image)
    crop_size = conf.train.scale
    cropped_sr_image = sr_image[crop_size:-crop_size, crop_size:-crop_size, :]
    cropped_real_image = real_image[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
    avg_psnr += calculate_psnr(cropped_sr_image, cropped_real_image)
    print("validating", img_name)
    return avg_psnr
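
calculate_psnr is imported from the project's utility code and is not shown here. The function below is only a sketch of a standard PSNR computation for 8-bit images, assuming the same (sr, gt) argument order used above.

import numpy as np


def calculate_psnr(img1, img2, max_val=255.0):
    # mean squared error in float64 to avoid uint8 overflow
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_val / np.sqrt(mse))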
Example #5
    def __init__(self):
        conf = get_config()

        with nn.parameter_scope("vgg19"):
            if not conf.train.checkpoint:
                print("Loading pre-trained vgg19 weights from ",
                      conf.train.vgg_pre_trained_weights)
                nn.load_parameters(conf.train.vgg_pre_trained_weights)

                # drop all the affine layers from pre-trained model for finetuning.
                drop_layers = [
                    'classifier/0/affine', 'classifier/3/affine',
                    'classifier/6/affine'
                ]
                for layers in drop_layers:
                    nn.parameter.pop_parameter((layers + '/W'))
                    nn.parameter.pop_parameter((layers + '/b'))
            self.mean = nn.Variable.from_numpy_array(
                np.asarray([123.68, 116.78, 103.94]).reshape(1, 1, 1, 3))
Example #6
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("#GPU Count: ", comm.n_procs)

    data_iterator_train = jsi_iterator(conf.batch_size, conf, train=True)
    if conf.scaling_factor == 1:
        d_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)

    else:
        d_t = nn.Variable((conf.batch_size, 160 // conf.scaling_factor,
                           160 // conf.scaling_factor, 3),
                          need_grad=True)
        l_t = nn.Variable((conf.batch_size, 160, 160, 3), need_grad=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))

    monitor = Monitor(monitor_path)
    jsi_monitor = setup_monitor(conf, monitor)

    with nn.parameter_scope("jsinet"):
        nn.load_parameters(conf.pre_trained_model)
        net = model(d_t, conf.scaling_factor)
        net.pred.persistent = True
    rec_loss = F.mean(F.squared_error(net.pred, l_t))
    rec_loss.persistent = True
    g_final_loss = rec_loss

    if conf.jsigan:
        net_gan = gan_model(l_t, net.pred, conf)
        d_final_fm_loss = net_gan.d_adv_loss
        d_final_fm_loss.persistent = True
        d_final_detail_loss = net_gan.d_detail_adv_loss
        d_final_detail_loss.persistent = True
        g_final_loss = conf.rec_lambda * rec_loss + conf.adv_lambda * (
            net_gan.g_adv_loss + net_gan.g_detail_adv_loss
        ) + conf.fm_lambda * (net_gan.fm_loss + net_gan.fm_detail_loss)
        g_final_loss.persistent = True

    max_iter = data_iterator_train._size // (conf.batch_size)
    if comm.rank == 0:
        print("max_iter", data_iterator_train._size, max_iter)

    iteration = 0
    if not conf.jsigan:
        start_epoch = 0
        end_epoch = conf.adv_weight_point
        lr = conf.learning_rate * comm.n_procs
    else:
        start_epoch = conf.adv_weight_point
        end_epoch = conf.epoch
        lr = conf.learning_rate * comm.n_procs
        w_d = conf.weight_decay * comm.n_procs

    # Set generator parameters
    with nn.parameter_scope("jsinet"):
        solver_jsinet = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_jsinet.set_parameters(nn.get_parameters())

    if conf.jsigan:
        solver_disc_fm = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_disc_detail = S.Adam(alpha=lr,
                                    beta1=0.9,
                                    beta2=0.999,
                                    eps=1e-08)
        with nn.parameter_scope("Discriminator_FM"):
            solver_disc_fm.set_parameters(nn.get_parameters())
        with nn.parameter_scope("Discriminator_Detail"):
            solver_disc_detail.set_parameters(nn.get_parameters())

    for epoch in range(start_epoch, end_epoch):
        for index in range(max_iter):
            d_t.d, l_t.d = data_iterator_train.next()

            if not conf.jsigan:
                # JSI-net -> Generator
                lr_stair_decay_points = [200, 225]
                lr_net = get_learning_rate(lr, iteration,
                                           lr_stair_decay_points,
                                           conf.lr_decreasing_factor)
                g_final_loss.forward(clear_no_need_grad=True)
                solver_jsinet.zero_grad()
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_net)
                solver_jsinet.update()
            else:
                # GAN part (discriminator + generator)
                lr_gan = lr if epoch < conf.gan_lr_linear_decay_point \
                    else lr * (end_epoch - epoch) / (end_epoch - conf.gan_lr_linear_decay_point)
                lr_gan = lr_gan * conf.gan_ratio

                net.pred.need_grad = False

                # Discriminator_FM
                solver_disc_fm.zero_grad()
                d_final_fm_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_fm_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_fm_loss.backward(clear_buffer=True)
                solver_disc_fm.set_learning_rate(lr_gan)
                solver_disc_fm.weight_decay(w_d)
                solver_disc_fm.update()

                # Discriminator_Detail
                solver_disc_detail.zero_grad()
                d_final_detail_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_detail_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_detail_loss.backward(clear_buffer=True)
                solver_disc_detail.set_learning_rate(lr_gan)
                solver_disc_detail.weight_decay(w_d)
                solver_disc_detail.update()

                # Generator
                net.pred.need_grad = True
                solver_jsinet.zero_grad()
                g_final_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_gan)
                solver_jsinet.update()

            iteration += 1
            if comm.rank == 0:
                train_psnr = compute_psnr(net.pred.d, l_t.d, 1.)
                jsi_monitor['psnr'].add(iteration, train_psnr)
                jsi_monitor['rec_loss'].add(iteration, rec_loss.d.copy())
                jsi_monitor['time'].add(iteration)

            if comm.rank == 0:
                if conf.jsigan:
                    jsi_monitor['g_final_loss'].add(iteration,
                                                    g_final_loss.d.copy())
                    jsi_monitor['g_adv_loss'].add(iteration,
                                                  net_gan.g_adv_loss.d.copy())
                    jsi_monitor['g_detail_adv_loss'].add(
                        iteration, net_gan.g_detail_adv_loss.d.copy())
                    jsi_monitor['d_final_fm_loss'].add(
                        iteration, d_final_fm_loss.d.copy())
                    jsi_monitor['d_final_detail_loss'].add(
                        iteration, d_final_detail_loss.d.copy())
                    jsi_monitor['fm_loss'].add(iteration,
                                               net_gan.fm_loss.d.copy())
                    jsi_monitor['fm_detail_loss'].add(
                        iteration, net_gan.fm_detail_loss.d.copy())
                    jsi_monitor['lr'].add(iteration, lr_gan)

        if comm.rank == 0:
            if not os.path.exists(conf.output_dir):
                os.makedirs(conf.output_dir)
            with nn.parameter_scope("jsinet"):
                nn.save_parameters(
                    os.path.join(conf.output_dir,
                                 "model_param_%04d.h5" % epoch))
Example #7
def main():
    conf = get_config()
    train_gt_path = sorted(glob.glob(conf.DIV2K.gt_train + "/*.png"))
    train_lq_path = sorted(glob.glob(conf.DIV2K.lq_train + "/*.png"))
    val_gt_path = sorted(glob.glob(conf.SET14.gt_val + "/*.png"))
    val_lq_path = sorted(glob.glob(conf.SET14.lq_val + "/*.png"))
    train_samples = len(train_gt_path)
    val_samples = len(val_gt_path)
    lr_g = conf.hyperparameters.lr_g
    lr_d = conf.hyperparameters.lr_d
    lr_steps = conf.train.lr_steps

    random.seed(conf.train.seed)
    np.random.seed(conf.train.seed)

    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # data iterators for train and val data
    from data_loader import data_iterator_sr
    data_iterator_train = data_iterator_sr(
        train_samples, conf.train.batch_size, train_gt_path, train_lq_path, train=True, shuffle=True)
    data_iterator_val = data_iterator_sr(
        val_samples, conf.val.batch_size, val_gt_path, val_lq_path, train=False, shuffle=False)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    train_gt = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    train_lq = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size // conf.train.scale, conf.train.gt_size // conf.train.scale))

    # setting up monitors for logging
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    monitor_pixel_g = MonitorSeries(
        'l_g_pix per iteration', monitor, interval=100)
    monitor_val = MonitorSeries(
        'Validation loss per epoch', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "Training time per epoch", monitor, interval=1)

    with nn.parameter_scope("gen"):
        nn.load_parameters(conf.train.gen_pretrained)
        fake_h = rrdb_net(train_lq, 64, 23)
        fake_h.persistent = True
    pixel_loss = F.mean(F.absolute_error(fake_h, train_gt))
    pixel_loss.persistent = True
    gen_loss = pixel_loss

    if conf.model.esrgan:
        from esrgan_model import get_esrgan_gen, get_esrgan_dis, get_esrgan_monitors
        gen_model = get_esrgan_gen(conf, train_gt, train_lq, fake_h)
        gen_loss = conf.hyperparameters.eta_pixel_loss * pixel_loss + conf.hyperparameters.feature_loss_weight * gen_model.feature_loss + \
            conf.hyperparameters.lambda_gan_loss * gen_model.loss_gan_gen
        dis_model = get_esrgan_dis(fake_h, gen_model.pred_d_real)
        # Set Discriminator parameters
        solver_dis = S.Adam(lr_d, beta1=0.9, beta2=0.99)
        with nn.parameter_scope("dis"):
            solver_dis.set_parameters(nn.get_parameters())
        esr_mon = get_esrgan_monitors()

    # Set generator Parameters
    solver_gen = S.Adam(alpha=lr_g, beta1=0.9, beta2=0.99)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    train_size = int(
        train_samples / conf.train.batch_size / comm.n_procs)
    total_epochs = conf.train.n_epochs
    start_epoch = 0
    current_iter = 0
    if comm.rank == 0:
        print("total_epochs", total_epochs)
        print("train_samples", train_samples)
        print("val_samples", val_samples)
        print("train_size", train_size)

    for epoch in range(start_epoch + 1, total_epochs + 1):
        index = 0
        # Training loop for the PSNR-oriented RRDB model
        while index < train_size:
            current_iter += comm.n_procs
            train_gt.d, train_lq.d = data_iterator_train.next()

            if not conf.model.esrgan:
                lr_g = get_repeated_cosine_annealing_learning_rate(
                    current_iter, conf.hyperparameters.eta_max, conf.hyperparameters.eta_min, conf.train.cosine_period,
                    conf.train.cosine_num_period)

            if conf.model.esrgan:
                lr_g = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_g)
                gen_model.var_ref.d = train_gt.d
                gen_model.pred_d_real.grad.zero()
                gen_model.pred_d_real.forward(clear_no_need_grad=True)
                gen_model.pred_d_real.need_grad = False

            # Generator update
            gen_loss.forward(clear_no_need_grad=True)
            solver_gen.zero_grad()
            # All-reduce gradients every 2MiB parameters during backward computation
            if comm.n_procs > 1:
                with nn.parameter_scope('gen'):
                    all_reduce_callback = comm.get_all_reduce_callback()
                    gen_loss.backward(clear_buffer=True,
                                      communicator_callbacks=all_reduce_callback)
            else:
                gen_loss.backward(clear_buffer=True)
            solver_gen.set_learning_rate(lr_g)
            solver_gen.update()

            # Discriminator update
            if conf.model.esrgan:
                gen_model.pred_d_real.need_grad = True
                lr_d = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_d)
                solver_dis.zero_grad()
                dis_model.l_d_total.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    with nn.parameter_scope('dis'):
                        all_reduce_callback = comm.get_all_reduce_callback()
                    dis_model.l_d_total.backward(
                        clear_buffer=True, communicator_callbacks=all_reduce_callback)
                else:
                    dis_model.l_d_total.backward(clear_buffer=True)
                solver_dis.set_learning_rate(lr_d)
                solver_dis.update()

            index += 1
            if comm.rank == 0:
                monitor_pixel_g.add(
                    current_iter, pixel_loss.d.copy())
                monitor_time.add(epoch * comm.n_procs)
            if comm.rank == 0 and conf.model.esrgan:
                esr_mon.monitor_feature_g.add(
                    current_iter, gen_model.feature_loss.d.copy())
                esr_mon.monitor_gan_g.add(
                    current_iter, gen_model.loss_gan_gen.d.copy())
                esr_mon.monitor_gan_d.add(
                    current_iter, dis_model.l_d_total.d.copy())
                esr_mon.monitor_d_real.add(current_iter, F.mean(
                    gen_model.pred_d_real.data).data)
                esr_mon.monitor_d_fake.add(current_iter, F.mean(
                    gen_model.pred_g_fake.data).data)

        # Validation Loop
        if comm.rank == 0:
            avg_psnr = 0.0
            for idx in range(val_samples):
                val_gt_im, val_lq_im = data_iterator_val.next()
                val_gt = nn.NdArray.from_numpy_array(val_gt_im)
                val_lq = nn.NdArray.from_numpy_array(val_lq_im)
                with nn.parameter_scope("gen"):
                    avg_psnr = val_save(
                        val_gt, val_lq, val_lq_path, idx, epoch, avg_psnr)
            avg_psnr = avg_psnr / val_samples
            monitor_val.add(epoch, avg_psnr)

        # Save generator weights
        if comm.rank == 0:
            if not os.path.exists(conf.train.savemodel):
                os.makedirs(conf.train.savemodel)
            with nn.parameter_scope("gen"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "generator_param_%06d.h5" % epoch))
        # Save discriminator weights
        if comm.rank == 0 and conf.model.esrgan:
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "discriminator_param_%06d.h5" % epoch))
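
get_repeated_cosine_annealing_learning_rate (also used in the Zooming SloMo example further below) is not defined in these snippets. The sketch below is an assumption: cosine decay from eta_max to eta_min within each period, restarted for the configured number of periods, then held at eta_min.

import math


def get_repeated_cosine_annealing_learning_rate(current_iter, eta_max, eta_min,
                                                period, num_periods):
    # hypothetical schedule matching the call sites above
    if current_iter >= period * num_periods:
        return eta_min
    t = current_iter % period
    return eta_min + 0.5 * (eta_max - eta_min) * (1.0 + math.cos(math.pi * t / period))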
Example #8
def inference():
    """
    Inference function to generate high resolution hdr images
    """
    conf = get_config()
    ctx = get_extension_context(conf.nnabla_context.context,
                                device_id=conf.nnabla_context.device_id)
    nn.set_default_context(ctx)

    data, target = read_mat_file(conf.data.lr_sdr_test,
                                 conf.data.hr_hdr_test,
                                 conf.data.d_name_test,
                                 conf.data.l_name_test,
                                 train=False)

    if not os.path.exists(conf.test_img_dir):
        os.makedirs(conf.test_img_dir)

    data_sz = data.shape
    target_sz = target.shape
    PATCH_BOUNDARY = 10  # set patch boundary to reduce edge effect around patch edges
    test_loss_PSNR_list_for_epoch = []
    inf_time = []
    start_time = time.time()

    test_pred_full = np.zeros((target_sz[1], target_sz[2], target_sz[3]))

    print("Loading pre trained model.........", conf.pre_trained_model)
    nn.load_parameters(conf.pre_trained_model)

    for index in range(data_sz[0]):
        ###======== Divide Into Patches ========###
        for p in range(conf.test_patch**2):
            pH = p // conf.test_patch
            pW = p % conf.test_patch
            sH = data_sz[1] // conf.test_patch
            sW = data_sz[2] // conf.test_patch
            H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                get_hw_boundary(
                    PATCH_BOUNDARY, data_sz[1], data_sz[2], pH, sH, pW, sW)
            data_test_p = nn.Variable.from_numpy_array(
                data.d[index, H_low_ind:H_high_ind, W_low_ind:W_high_ind, :])
            data_test_sz = data_test_p.shape
            data_test_p = F.reshape(
                data_test_p,
                (1, data_test_sz[0], data_test_sz[1], data_test_sz[2]))
            st = time.time()
            net = model(data_test_p, conf.scaling_factor)
            net.pred.forward()
            test_pred_temp = net.pred.d
            inf_time.append(time.time() - st)
            test_pred_t = trim_patch_boundary(test_pred_temp, PATCH_BOUNDARY,
                                              data_sz[1], data_sz[2], pH, sH,
                                              pW, sW, conf.scaling_factor)
            #pred_sz = test_pred_t.shape
            test_pred_t = np.squeeze(test_pred_t)
            test_pred_full[pH * sH * conf.scaling_factor:(pH + 1) * sH *
                           conf.scaling_factor,
                           pW * sW * conf.scaling_factor:(pW + 1) * sW *
                           conf.scaling_factor, :] = test_pred_t

        ###======== Compute PSNR & Print Results========###
        test_GT = np.squeeze(target.d[index, :, :, :])
        test_PSNR = compute_psnr(test_pred_full, test_GT, 1.)
        test_loss_PSNR_list_for_epoch.append(test_PSNR)
        print(
            " <Test> [%4d/%4d]-th images, time: %4.4f(minutes), test_PSNR: %.8f[dB]  "
            % (int(index), int(data_sz[0]),
               (time.time() - start_time) / 60, test_PSNR))
        if conf.save_images:
            # comment this line out for faster testing
            save_results_yuv(test_pred_full, index, conf.test_img_dir)
    test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)

    print("######### Average Test PSNR: %.8f[dB]  #########" %
          (test_PSNR_per_epoch))
    print(
        "######### Estimated Inference Time (per 4K frame): %.8f[s]  #########"
        % (np.mean(inf_time) * conf.test_patch * conf.test_patch))
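
The patch bookkeeping above relies on get_hw_boundary and trim_patch_boundary, which are defined elsewhere. Below is only a plausible sketch of the boundary computation, assuming each patch window is extended by PATCH_BOUNDARY pixels on every side and clamped to the image size; the trimming counterpart would then drop that margin again after upscaling.

def get_hw_boundary(patch_boundary, h, w, pH, sH, pW, sW):
    # hypothetical window of the (pH, pW)-th patch of size sH x sW,
    # padded by `patch_boundary` pixels and clamped to the image borders
    h_low = max(pH * sH - patch_boundary, 0)
    h_high = min((pH + 1) * sH + patch_boundary, h)
    w_low = max(pW * sW - patch_boundary, 0)
    w_high = min((pW + 1) * sW + patch_boundary, w)
    return h_low, h_high, w_low, w_high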
Example #9
import sys

import nnabla as nn
import nnabla.functions as F
import numpy as np

from preprocess import uncenter_l
from args import get_config

conf = get_config()


def vgg_pre_process(x):
    x_bgr = F.concatenate(x[:, 2:3, :, :],
                          x[:, 1:2, :, :],
                          x[:, 0:1, :, :],
                          axis=1)
    # tensor_bgr = tensor[:, [2, 1, 0], ...]
    x_sub = F.reshape(
        nn.Variable.from_numpy_array(
            np.array([0.40760392, 0.45795686, 0.48501961])), (1, 3, 1, 1))
    x_bgr_ml = x_bgr - x_sub
    x_rst = x_bgr_ml * 255
    return x_rst
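
A short usage sketch for vgg_pre_process: the constants above suggest it expects an NCHW RGB batch in [0, 1] and returns a mean-subtracted BGR batch scaled to [0, 255] (the input range is an assumption).

import numpy as np
import nnabla as nn

# hypothetical call with a random RGB batch in [0, 1]
rgb = nn.Variable.from_numpy_array(
    np.random.rand(1, 3, 224, 224).astype(np.float32))
vgg_in = vgg_pre_process(rgb)
vgg_in.forward()
print(vgg_in.shape)  # (1, 3, 224, 224)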
Example #10
def main():
    """
    main - driver code to run training for Zooming SloMo
    """
    # Check NNabla version
    if get_nnabla_version_integer() < 11700:
        raise ValueError(
            'This does not work with nnabla versions older than v1.17.0, since the '
            'deformable_conv layer was added in v1.17.0. Please update nnabla.'
        )

    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("comm rank", comm.rank)

    # change max_iter, learning_rate and cosine_period when batch-size or no. of gpu devices change.
    default_batch_size = 12
    train_scale_factor = comm.n_procs * \
        (conf.train.batch_size / default_batch_size)
    max_iter = int(conf.train.max_iter // train_scale_factor)
    learning_rate = conf.train.learning_rate * \
        (conf.train.batch_size / default_batch_size)
    cosine_period = int(conf.train.cosine_period // train_scale_factor)

    # for single-GPU training
    data_iterator_train = data_iterator(conf, shuffle=True)

    # for multi-GPU training
    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    # LR-LFR data for ZoomingSloMo input
    data_lr_lfr = nn.Variable(
        (conf.train.batch_size, (conf.data.n_frames // 2) + 1, 3,
         conf.data.lr_size, conf.data.lr_size))

    # HR-HFR data for ZoomingSloMo ground truth
    data_gt = nn.Variable((conf.train.batch_size, conf.data.n_frames, 3,
                           conf.data.gt_size, conf.data.gt_size))

    if conf.train.only_slomo:
        # High resolution data is used as input to the SloMo-only network for
        # frame interpolation, hence we use a smaller number of frames.
        # LFR data for SloMo input
        slomo_gt = data_gt
        input_to_slomo = slomo_gt[:, 0:conf.data.n_frames:2, :, :, :]

    # setting up monitors for logging
    monitor_path = './nnmonitor'
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries('loss',
                                 monitor,
                                 interval=conf.train.monitor_log_freq)
    monitor_lr = MonitorSeries('learning rate',
                               monitor,
                               interval=conf.train.monitor_log_freq)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=conf.train.monitor_log_freq)

    scope_name = "ZoomingSloMo" if not conf.train.only_slomo else "SloMo"

    with nn.parameter_scope(scope_name):
        if conf.train.only_slomo:
            generated_frame = zooming_slo_mo_network(input_to_slomo,
                                                     conf.train.only_slomo)
            diff = generated_frame - slomo_gt
        else:
            generated_frame = zooming_slo_mo_network(data_lr_lfr,
                                                     conf.train.only_slomo)
            diff = generated_frame - data_gt

    # Charbonnier loss
    loss = F.sum((diff * diff + conf.train.eps)**0.5)

    # Define optimizer
    solver = S.Adam(alpha=learning_rate,
                    beta1=conf.train.beta1,
                    beta2=conf.train.beta2)

    # Set Parameters
    with nn.parameter_scope(scope_name):
        solver.set_parameters(nn.get_parameters())

    solver_dict = {scope_name: solver}

    if comm.rank == 0:
        print("maximum iterations", max_iter)

    start_point = 0
    if conf.train.checkpoint:
        # Load optimizer/solver information and model weights from checkpoint
        print("Loading weights from checkpoint:", conf.train.checkpoint)
        with nn.parameter_scope(scope_name):
            start_point = load_checkpoint(conf.train.checkpoint, solver_dict)

    if not os.path.isdir(conf.data.output_dir):
        os.makedirs(conf.data.output_dir)

    # Training loop.
    for i in range(start_point, max_iter):
        # Get Training Data
        if conf.train.only_slomo:
            _, data_gt.d = data_iterator_train.next()
        else:
            data_lr_lfr.d, data_gt.d = data_iterator_train.next()
        l_rate = get_repeated_cosine_annealing_learning_rate(
            i, learning_rate, conf.train.eta_min, cosine_period,
            conf.train.cosine_num_period)

        # Update
        solver.zero_grad()
        solver.set_learning_rate(l_rate)
        loss.forward(clear_no_need_grad=True)
        if comm.n_procs > 1:
            all_reduce_callback = comm.get_all_reduce_callback()
            loss.backward(clear_buffer=True,
                          communicator_callbacks=all_reduce_callback)
        else:
            loss.backward(clear_buffer=True)
        solver.update()

        if comm.rank == 0:
            monitor_loss.add(i, loss.d.copy())
            monitor_lr.add(i, l_rate)
            monitor_time.add(i)
            if (i % conf.train.save_checkpoint_freq) == 0:
                # Save intermediate check_points
                with nn.parameter_scope(scope_name):
                    save_checkpoint(conf.data.output_dir, i, solver_dict)

    # Save final model parameters
    if comm.rank == 0:
        with nn.parameter_scope(scope_name):
            nn.save_parameters(
                os.path.join(conf.data.output_dir, "final_model.h5"))
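
get_nnabla_version_integer is a small project utility; the sketch below is an assumption consistent with the 11700 threshold used above (major * 10000 + minor * 100 + patch, so "1.17.0" maps to 11700).

import nnabla


def get_nnabla_version_integer():
    # hypothetical encoding of the installed nnabla version as an integer
    def digits(s):
        return int(''.join(ch for ch in s if ch.isdigit()) or 0)

    major, minor, patch = (nnabla.__version__.split('.') + ['0', '0'])[:3]
    return digits(major) * 10000 + digits(minor) * 100 + digits(patch)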
Example #11
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("comm rank", comm.rank)

    # data iterators for train and val data
    from data_loader import data_iterator_sr, get_sample_name_grid, nn_data_gauss_down_quad

    sample_names = get_sample_name_grid(conf)
    num_samples = len(sample_names[0])
    print("No of training samples :", num_samples)

    # target size: crop_size * 4, plus a Gaussian blur margin on both sides
    tar_size = (conf.train.crop_size * 4) + int(1.5 * 3.0) * 2

    data_iterator_train = data_iterator_sr(
        conf, num_samples, sample_names, tar_size, shuffle=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    train_hr = nn.Variable(
        (conf.train.batch_size, conf.train.rnn_n, conf.train.crop_size*4, conf.train.crop_size*4, 3))
    data_hr = nn.Variable(
        (conf.train.batch_size, conf.train.rnn_n, tar_size, tar_size, 3))
    train_lr = nn_data_gauss_down_quad(data_hr.reshape(
        (conf.train.batch_size * conf.train.rnn_n, tar_size, tar_size, 3)))
    train_lr = F.reshape(
        train_lr, (conf.train.batch_size, conf.train.rnn_n, conf.train.crop_size, conf.train.crop_size, 3))

    # setting up monitors for logging
    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    common_monitor = get_common_monitors(monitor)

    # Change max_iter and learning_rate when batch size or no. of gpu devices change.
    div_factor = conf.train.batch_size * comm.n_procs
    max_iter = (conf.train.max_iter * 4) // div_factor
    learning_rate = conf.train.learning_rate * \
        (conf.train.batch_size / 4) * comm.n_procs

    if comm.rank == 0:
        print("maximum iterations", max_iter)

    scope_name = 'frvsr/'
    if conf.train.tecogan:
        scope_name = 'tecogan/'
        if not conf.train.checkpoint:
            print('loading pretrained FRVSR model',
                  conf.train.pre_trained_frvsr_weights)
            with nn.parameter_scope(scope_name):
                nn.load_parameters(conf.train.pre_trained_frvsr_weights)
                params_from_pre_trained_model = []
                for key, val in nn.get_parameters().items():
                    params_from_pre_trained_model.append(scope_name + key)

            network = get_tecogan_model(conf, train_lr, train_hr, scope_name)
            params_from_graph = nn.get_parameters()

            # Set the Generator parameters which are not in FRVSR to zero,
            # as done in the original implementation.
            for key, val in params_from_graph.items():
                if key in params_from_pre_trained_model or key.startswith('vgg') or key.startswith('disc'):
                    continue
                print(key)
                val.data.zero()  # fill with zero

        else:
            network = get_tecogan_model(conf, train_lr, train_hr, scope_name)

        # Define discriminator optimizer/solver
        solver_disc = S.Adam(alpha=learning_rate,
                             beta1=conf.train.beta, eps=conf.train.adameps)
        # Set discriminator Parameters
        with nn.parameter_scope("discriminator"):
            solver_disc.set_parameters(nn.get_parameters())

        # setting up monitors for TecoGAN
        tecogan_monitor = get_tecogan_monitors(monitor)

    else:
        network = get_frvsr_model(conf, train_lr, train_hr, scope_name)

    # Define generator and fnet optimizer/solver
    solver_gen = S.Adam(alpha=learning_rate,
                        beta1=conf.train.beta, eps=conf.train.adameps)
    solver_fnet = S.Adam(alpha=learning_rate,
                         beta1=conf.train.beta, eps=conf.train.adameps)

    # Set generator and fnet Parameters
    with nn.parameter_scope(scope_name + "generator"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope(scope_name + "fnet"):
        solver_fnet.set_parameters(nn.get_parameters())

    if conf.train.tecogan:
        solver_dict = {"gen": solver_gen,
                       "fnet": solver_fnet, "disc": solver_disc}
    else:
        solver_dict = {"gen": solver_gen, "fnet": solver_fnet}

    start_point = 0
    if conf.train.checkpoint:
        # Load optimizer/solver information and model weights from checkpoint
        start_point = load_checkpoint(conf.train.checkpoint, solver_dict)

    # Exponential Moving Average Calculation for tb
    ema = ExponentialMovingAverage(conf.train.decay)
    tb = 0

    # Create output directory if it doesn't exist
    if not os.path.exists(conf.data.output_dir):
        os.makedirs(conf.data.output_dir)

    # Training loop.
    for i in range(start_point, max_iter):
        # Get Training Data
        data_hr.d, train_hr.d = data_iterator_train.next()

        if conf.train.tecogan:
            network.t_discrim_loss.forward(clear_no_need_grad=True)
            if np.less(tb, 0.4):  # update the discriminator only while the balance term is below the threshold
                # Compute grads for discriminator and update
                solver_disc.zero_grad()
                # Stop back-propagation from t_discrim_loss to generator
                network.t_gen_output.need_grad = False
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    network.t_discrim_loss.backward(clear_buffer=True,
                                                    communicator_callbacks=all_reduce_callback)
                else:
                    network.t_discrim_loss.backward(clear_buffer=True)
                solver_disc.update()  # Update grads
                # Enable back propagation from fnet_loss to Generator
                network.t_gen_output.need_grad = True

        # Compute grads for fnet and generator together using fnet_loss
        solver_fnet.zero_grad()
        solver_gen.zero_grad()
        # Apply forward and backward propagation on fnet_loss
        network.fnet_loss.forward(clear_no_need_grad=True)
        if comm.n_procs > 1:
            all_reduce_callback = comm.get_all_reduce_callback()
            network.fnet_loss.backward(clear_buffer=True,
                                       communicator_callbacks=all_reduce_callback)
        else:
            network.fnet_loss.backward(clear_buffer=True)
        # Update grads for fnet and generator
        solver_gen.update()
        solver_fnet.update()

        if conf.train.tecogan:
            if comm.n_procs > 1:
                comm.all_reduce([network.t_discrim_real_loss.data,
                                 network.t_adversarial_loss.data], division=True, inplace=True)
            t_balance = F.mean(network.t_discrim_real_loss.data) + \
                network.t_adversarial_loss.data
            if i == 0:
                ema.register(t_balance)
            else:
                tb = ema(t_balance)
            if comm.rank == 0:
                tecogan_monitor.monitor_pp_loss.add(
                    i, network.pp_loss.d.copy())
                tecogan_monitor.monitor_vgg_loss.add(
                    i, network.vgg_loss.d.copy())
                tecogan_monitor.monitor_sum_layer_loss.add(
                    i, network.sum_layer_loss.d.copy())
                tecogan_monitor.monitor_adv_loss.add(
                    i, network.t_adversarial_loss.d.copy())
                tecogan_monitor.monitor_disc_loss.add(
                    i, network.t_discrim_loss.d.copy())
                tecogan_monitor.monitor_tb.add(i, tb)

        if comm.rank == 0:
            common_monitor.monitor_content_loss.add(
                i, network.content_loss.d.copy())
            common_monitor.monitor_gen_loss.add(i, network.gen_loss.d.copy())
            common_monitor.monitor_warp_loss.add(i, network.warp_loss.d.copy())
            common_monitor.monitor_lr.add(i, learning_rate)
            common_monitor.monitor_time.add(i)
            if (i % conf.train.save_freq) == 0:
                # Save intermediate model parameters
                with nn.parameter_scope(scope_name):
                    nn.save_parameters(os.path.join(
                        conf.data.output_dir, "model_param_%08d.h5" % i))

                # Save intermediate check_points
                save_checkpoint(conf.data.output_dir, i, solver_dict)

    # save final Generator and Fnet network parameters
    if comm.rank == 0:
        with nn.parameter_scope(scope_name):
            nn.save_parameters(os.path.join(
                conf.data.output_dir, "model_param_%08d.h5" % i))
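
ExponentialMovingAverage is used above through register(value) and a call that returns the updated average. The class below is only a sketch matching that interface, assuming the standard decay * average + (1 - decay) * value update.

class ExponentialMovingAverage:
    """Hypothetical EMA tracker matching the register()/__call__() usage above."""

    def __init__(self, decay):
        self.decay = decay
        self.average = None

    def register(self, value):
        # initialise the running average with the first observation
        self.average = value

    def __call__(self, value):
        # standard exponential moving average update; returns the new average
        self.average = self.decay * self.average + (1.0 - self.decay) * value
        return self.average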
Example #12
    score = roc_auc_score(
        np.concatenate(epoch_targets, axis=0),
        np.concatenate(epoch_preds, axis=0),
    )

    mode = "train" if data_loader.is_train else "valid"
    print(
        f"epoch {epoch_idx:02} {mode} score > {score:.4} ({int(timer() - epoch_start)}s)"
    )

    total_loss /= len(data_loader.dataset)
    return score, total_loss


if __name__ == "__main__":
    config = get_config()

    # random seed
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.random.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    model = CharCNNScorer(
        vocab_size=len(data_utils.vocabs),
        char_embed_size=config.char_embed_size,
        filter_sizes=config.filter_sizes,