Example #1
# Assumed imports for this snippet; `ops` and `get_feed_dict` are
# project-local helpers (TF1-style graph API).
import math
import time

import h5py
import tensorflow as tf

def train(args, data, params):
    train = data['train']
    valid = data['valid']
    learning_rate = args.learning_rate

    with tf.Graph().as_default():
        input_ph = tf.placeholder(tf.int32,
                                  shape=[args.batch_size, params['gram_size']])
        targ_ph = tf.placeholder(tf.int32, shape=[args.batch_size])
        learning_rate_ph = tf.placeholder(tf.float32, shape=[])

        if args.w2v:
            with h5py.File(args.w2v, 'r') as datafile:
                embeds = datafile['w2v'][:]
            scores, normalize_op, vars = ops.model(input_ph, params, embeds)
        else:
            scores, normalize_op, vars = ops.model(input_ph, params)

        loss = ops.loss(scores, targ_ph)
        train_op, print_op = ops.train(loss, learning_rate_ph, args)

        # sess = tf.Session(config=tf.ConfigProto(
        #     inter_op_parallelism_threads=NUM_THREADS,
        #     intra_op_parallelism_threads=NUM_THREADS))
        sess = tf.Session()
        init = tf.global_variables_initializer()  # initialize variables before use
        saver = tf.train.Saver()
        sess.run(init)
        if args.modelfile:
            saver.restore(sess, args.modelfile)
            print("Model restored from %s" % args.modelfile)

        # Baseline validation pass before any training.
        valid_loss = 0.
        for i in range(valid.nbatches):
            valid_feed_dict = get_feed_dict(valid, i, input_ph, targ_ph,
                                            learning_rate_ph)
            batch_loss = sess.run([loss], feed_dict=valid_feed_dict)[0]
            valid_loss += batch_loss
        last_valid = valid_loss
        print('Initial valid perplexity: %.3f' %
              math.exp(valid_loss / valid.nbatches))

        for epoch in range(args.nepochs):
            print("Training epoch %d with learning rate %.3f" %
                  (epoch + 1, learning_rate))
            vals = sess.run(vars)  # fetch current variable values (result unused)
            start_time = time.time()
            train_loss = 0.
            valid_loss = 0.

            for i in range(train.nbatches):
                train_feed_dict = get_feed_dict(train, i, input_ph, targ_ph,
                                                learning_rate_ph, learning_rate)
                #grads = sess.run(print_op, feed_dict=train_feed_dict)
                _, batch_loss = sess.run([train_op, loss],
                                         feed_dict=train_feed_dict)
                train_loss += batch_loss

            for i in range(valid.nbatches):
                valid_feed_dict = get_feed_dict(valid, i, input_ph, targ_ph,
                                                learning_rate_ph)
                batch_loss = sess.run([loss], feed_dict=valid_feed_dict)[0]
                valid_loss += batch_loss

            if args.normalize:
                _ = sess.run(normalize_op)

            duration = time.time() - start_time
            print("\ttrain ppl = %.3f, valid ppl = %.3f, %.3f s" %
                  (math.exp(train_loss / train.nbatches),
                   math.exp(valid_loss / valid.nbatches), duration))
            # Halve the learning rate when validation worsened; otherwise
            # checkpoint the current model.
            if last_valid < valid_loss:
                learning_rate /= 2.
            elif args.outfile:
                saver.save(sess, args.outfile)
            if epoch >= args.decay_after:
                learning_rate /= 1.2
            last_valid = valid_loss

        # Return the final normalized embeddings.
        return sess.run(normalize_op)
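
The loop above depends on a get_feed_dict helper that is not shown. A minimal sketch of what it presumably does, assuming each data split exposes per-batch inputs and targets arrays (those attribute names are assumptions, not from the source):

def get_feed_dict(split, i, input_ph, targ_ph, learning_rate_ph, learning_rate=0.):
    # Map the i-th batch of a data split onto the graph placeholders.
    # `split.inputs` / `split.targets` are assumed accessors; the real
    # helper may slice its arrays differently.
    return {
        input_ph: split.inputs[i],
        targ_ph: split.targets[i],
        learning_rate_ph: learning_rate,
    }
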
Example #2
# Assumed imports for this snippet; CommunicatorWrapper and the project
# helpers (get_config, jsi_iterator, model, gan_model, setup_monitor,
# get_learning_rate, compute_psnr) come from the surrounding codebase.
import os
import datetime

import nnabla as nn
import nnabla.functions as F
import nnabla.solvers as S
from nnabla.ext_utils import get_extension_context
from nnabla.monitor import Monitor

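
The training loop in main() below calls a get_learning_rate helper that is not shown. A plausible stair-decay sketch; whether the decay points are compared against iterations or epochs is an assumption, as is the multiplicative form:

def get_learning_rate(base_lr, step, decay_points, decay_factor):
    # Stair decay: scale the base rate by `decay_factor` once for every
    # decay point that `step` has already passed.
    lr = base_lr
    for point in decay_points:
        if step >= point:
            lr *= decay_factor
    return lr
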
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("#GPU Count: ", comm.n_procs)

    data_iterator_train = jsi_iterator(conf.batch_size, conf, train=True)
    if conf.scaling_factor == 1:
        d_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)

    else:
        d_t = nn.Variable((conf.batch_size, 160 // conf.scaling_factor,
                           160 // conf.scaling_factor, 3),
                          need_grad=True)  # // keeps the shape integral
        l_t = nn.Variable((conf.batch_size, 160, 160, 3), need_grad=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))

    monitor = Monitor(monitor_path)
    jsi_monitor = setup_monitor(conf, monitor)

    with nn.parameter_scope("jsinet"):
        nn.load_parameters(conf.pre_trained_model)
        net = model(d_t, conf.scaling_factor)
        net.pred.persistent = True
    rec_loss = F.mean(F.squared_error(net.pred, l_t))
    rec_loss.persistent = True
    g_final_loss = rec_loss

    if conf.jsigan:
        net_gan = gan_model(l_t, net.pred, conf)
        d_final_fm_loss = net_gan.d_adv_loss
        d_final_fm_loss.persistent = True
        d_final_detail_loss = net_gan.d_detail_adv_loss
        d_final_detail_loss.persistent = True
        g_final_loss = conf.rec_lambda * rec_loss + conf.adv_lambda * (
            net_gan.g_adv_loss + net_gan.g_detail_adv_loss
        ) + conf.fm_lambda * (net_gan.fm_loss + net_gan.fm_detail_loss)
        g_final_loss.persistent = True

    max_iter = data_iterator_train._size // conf.batch_size
    if comm.rank == 0:
        print("dataset size:", data_iterator_train._size, "max_iter:", max_iter)

    iteration = 0
    if not conf.jsigan:
        start_epoch = 0
        end_epoch = conf.adv_weight_point
        lr = conf.learning_rate * comm.n_procs
    else:
        start_epoch = conf.adv_weight_point
        end_epoch = conf.epoch
        lr = conf.learning_rate * comm.n_procs
        w_d = conf.weight_decay * comm.n_procs

    # Set generator parameters
    with nn.parameter_scope("jsinet"):
        solver_jsinet = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_jsinet.set_parameters(nn.get_parameters())

    if conf.jsigan:
        solver_disc_fm = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_disc_detail = S.Adam(alpha=lr,
                                    beta1=0.9,
                                    beta2=0.999,
                                    eps=1e-08)
        with nn.parameter_scope("Discriminator_FM"):
            solver_disc_fm.set_parameters(nn.get_parameters())
        with nn.parameter_scope("Discriminator_Detail"):
            solver_disc_detail.set_parameters(nn.get_parameters())

    for epoch in range(start_epoch, end_epoch):
        for index in range(max_iter):
            d_t.d, l_t.d = data_iterator_train.next()

            if not conf.jsigan:
                # JSI-net -> Generator
                lr_stair_decay_points = [200, 225]
                lr_net = get_learning_rate(lr, iteration,
                                           lr_stair_decay_points,
                                           conf.lr_decreasing_factor)
                g_final_loss.forward(clear_no_need_grad=True)
                solver_jsinet.zero_grad()
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_net)
                solver_jsinet.update()
            else:
                # GAN part (discriminator + generator)
                lr_gan = lr if epoch < conf.gan_lr_linear_decay_point \
                    else lr * (end_epoch - epoch) / (end_epoch - conf.gan_lr_linear_decay_point)
                lr_gan = lr_gan * conf.gan_ratio

                # Freeze the generator output while the discriminators update.
                net.pred.need_grad = False

                # Discriminator_FM
                solver_disc_fm.zero_grad()
                d_final_fm_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_fm_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_fm_loss.backward(clear_buffer=True)
                solver_disc_fm.set_learning_rate(lr_gan)
                solver_disc_fm.weight_decay(w_d)
                solver_disc_fm.update()

                # Discriminator_Detail
                solver_disc_detail.zero_grad()
                d_final_detail_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_detail_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_detail_loss.backward(clear_buffer=True)
                solver_disc_detail.set_learning_rate(lr_gan)
                solver_disc_detail.weight_decay(w_d)
                solver_disc_detail.update()

                # Generator (unfreeze its output for the generator update)
                net.pred.need_grad = True
                solver_jsinet.zero_grad()
                g_final_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_gan)
                solver_jsinet.update()

            iteration += 1
            if comm.rank == 0:
                train_psnr = compute_psnr(net.pred.d, l_t.d, 1.)
                jsi_monitor['psnr'].add(iteration, train_psnr)
                jsi_monitor['rec_loss'].add(iteration, rec_loss.d.copy())
                jsi_monitor['time'].add(iteration)

                if conf.jsigan:
                    jsi_monitor['g_final_loss'].add(iteration,
                                                    g_final_loss.d.copy())
                    jsi_monitor['g_adv_loss'].add(iteration,
                                                  net_gan.g_adv_loss.d.copy())
                    jsi_monitor['g_detail_adv_loss'].add(
                        iteration, net_gan.g_detail_adv_loss.d.copy())
                    jsi_monitor['d_final_fm_loss'].add(
                        iteration, d_final_fm_loss.d.copy())
                    jsi_monitor['d_final_detail_loss'].add(
                        iteration, d_final_detail_loss.d.copy())
                    jsi_monitor['fm_loss'].add(iteration,
                                               net_gan.fm_loss.d.copy())
                    jsi_monitor['fm_detail_loss'].add(
                        iteration, net_gan.fm_detail_loss.d.copy())
                    jsi_monitor['lr'].add(iteration, lr_gan)

        if comm.rank == 0:
            if not os.path.exists(conf.output_dir):
                os.makedirs(conf.output_dir)
            with nn.parameter_scope("jsinet"):
                nn.save_parameters(
                    os.path.join(conf.output_dir,
                                 "model_param_%04d.h5" % epoch))
Example #3
# Assumed imports for this snippet; get_config, read_mat_file, model,
# get_hw_boundary, trim_patch_boundary, compute_psnr, and save_results_yuv
# are project-local helpers.
import os
import time

import numpy as np
import nnabla as nn
import nnabla.functions as F
from nnabla.ext_utils import get_extension_context

def inference():
    """
    Inference function to generate high-resolution HDR images.
    """
    conf = get_config()
    ctx = get_extension_context(conf.nnabla_context.context,
                                device_id=conf.nnabla_context.device_id)
    nn.set_default_context(ctx)

    data, target = read_mat_file(conf.data.lr_sdr_test,
                                 conf.data.hr_hdr_test,
                                 conf.data.d_name_test,
                                 conf.data.l_name_test,
                                 train=False)

    if not os.path.exists(conf.test_img_dir):
        os.makedirs(conf.test_img_dir)

    data_sz = data.shape
    target_sz = target.shape
    PATCH_BOUNDARY = 10  # set patch boundary to reduce edge effect around patch edges
    test_loss_PSNR_list_for_epoch = []
    inf_time = []
    start_time = time.time()

    test_pred_full = np.zeros((target_sz[1], target_sz[2], target_sz[3]))

    print("Loading pre-trained model:", conf.pre_trained_model)
    nn.load_parameters(conf.pre_trained_model)

    for index in range(data_sz[0]):
        ###======== Divide Into Patches ========###
        for p in range(conf.test_patch**2):
            pH = p // conf.test_patch
            pW = p % conf.test_patch
            sH = data_sz[1] // conf.test_patch
            sW = data_sz[2] // conf.test_patch
            H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                get_hw_boundary(
                    PATCH_BOUNDARY, data_sz[1], data_sz[2], pH, sH, pW, sW)
            data_test_p = nn.Variable.from_numpy_array(
                data.d[index, H_low_ind:H_high_ind, W_low_ind:W_high_ind, :])
            data_test_sz = data_test_p.shape
            data_test_p = F.reshape(
                data_test_p,
                (1, data_test_sz[0], data_test_sz[1], data_test_sz[2]))
            st = time.time()
            net = model(data_test_p, conf.scaling_factor)
            net.pred.forward()
            test_pred_temp = net.pred.d
            inf_time.append(time.time() - st)
            test_pred_t = trim_patch_boundary(test_pred_temp, PATCH_BOUNDARY,
                                              data_sz[1], data_sz[2], pH, sH,
                                              pW, sW, conf.scaling_factor)
            #pred_sz = test_pred_t.shape
            test_pred_t = np.squeeze(test_pred_t)
            test_pred_full[pH * sH * conf.scaling_factor:(pH + 1) * sH *
                           conf.scaling_factor,
                           pW * sW * conf.scaling_factor:(pW + 1) * sW *
                           conf.scaling_factor, :] = test_pred_t

        ###======== Compute PSNR & Print Results ========###
        test_GT = np.squeeze(target.d[index, :, :, :])
        test_PSNR = compute_psnr(test_pred_full, test_GT, 1.)
        test_loss_PSNR_list_for_epoch.append(test_PSNR)
        print(
            " <Test> [%4d/%4d]-th images, time: %4.4f (minutes), test_PSNR: %.8f[dB]"
            % (int(index), int(data_sz[0]),
               (time.time() - start_time) / 60, test_PSNR))
        if conf.save_images:
            # comment this call out for faster testing
            save_results_yuv(test_pred_full, index, conf.test_img_dir)
    test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)

    print("######### Average Test PSNR: %.8f[dB]  #########" %
          (test_PSNR_per_epoch))
    print(
        "######### Estimated Inference Time (per 4K frame): %.8f[s]  #########"
        % (np.mean(inf_time) * conf.test_patch * conf.test_patch))
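
The patch loop relies on helpers that are not shown. First, a sketch of get_hw_boundary consistent with how its return values are used above: each patch is extended by PATCH_BOUNDARY pixels on every side, clipped to the image, so that convolution edge effects can be trimmed off afterwards (trim_patch_boundary would undo this padding at the scaled resolution):

def get_hw_boundary(patch_boundary, h, w, p_h, s_h, p_w, s_w):
    # Extended (inclusive-low, exclusive-high) index ranges for patch
    # (p_h, p_w) of size (s_h, s_w), clipped to the (h, w) image.
    h_low = max(p_h * s_h - patch_boundary, 0)
    h_high = min((p_h + 1) * s_h + patch_boundary, h)
    w_low = max(p_w * s_w - patch_boundary, 0)
    w_high = min((p_w + 1) * s_w + patch_boundary, w)
    return h_low, h_high, w_low, w_high

Second, compute_psnr (also used in Example #2) is presumably the standard peak signal-to-noise ratio; a minimal NumPy sketch, assuming both arrays share the value range [0, max_value]:

import numpy as np

def compute_psnr(pred, ground_truth, max_value):
    # PSNR in dB: 10 * log10(MAX^2 / MSE).
    mse = np.mean((pred - ground_truth) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_value ** 2 / mse)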