def test_save_load_parameters():
    """Check that parameters survive a save/load round trip.

    Builds a small conv+BN graph under scope ``param1``, fills all
    parameters (including non-trainable BN stats) with random values,
    saves them in both HDF5 and protobuf formats, reloads each into a
    fresh scope, and verifies names, values, dtypes and ``need_grad``
    flags are preserved.
    """
    v = nn.Variable([64, 1, 28, 28], need_grad=False)
    with nn.parameter_scope("param1"):
        with nn.parameter_scope("conv1"):
            h = PF.convolution(v, 32, (3, 3))
            b = PF.batch_normalization(h, batch_stat=True)
        with nn.parameter_scope("conv2"):
            h1 = PF.convolution(v, 32, (3, 3))
            b2 = PF.batch_normalization(h1, batch_stat=True)

    # dict.iteritems() is Python 2 only; items() works on Python 3.
    for k, v in nn.get_parameters(grad_only=False).items():
        v.data.cast(np.float32)[...] = np.random.randn(*v.shape)

    with nn.parameter_scope("param1"):
        param1 = nn.get_parameters(grad_only=False)
        nn.save_parameters("tmp.h5")
        nn.save_parameters("tmp.protobuf")

    with nn.parameter_scope("param2"):
        nn.load_parameters('tmp.h5')
        param2 = nn.get_parameters(grad_only=False)

    with nn.parameter_scope("param3"):
        nn.load_parameters('tmp.protobuf')
        param3 = nn.get_parameters(grad_only=False)

    for par2 in [param2, param3]:
        # Compare as lists so insertion order is actually checked:
        # dict keys views compare set-wise, ignoring order.
        assert list(param1.keys()) == list(par2.keys())  # Check order
        for (n1, p1), (n2, p2) in zip(sorted(param1.items()),
                                      sorted(par2.items())):
            assert n1 == n2
            assert np.all(p1.d == p2.d)
            assert p1.data.dtype == p2.data.dtype
            assert p1.need_grad == p2.need_grad
def test():
    """Evaluate a trained MNIST classifier on the full test set."""
    print("Evaluate the trained model with full MNIST test set")

    args = get_args()

    # Select the prediction network matching the trained architecture.
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        mnist_cnn_prediction = mnist_lenet_prediction

    # Test data iterator and input variables.
    args.batch_size = 100
    tdata = data_iterator_mnist(args.batch_size, False)
    timage = nn.Variable([args.batch_size, 1, 28, 28])
    tlabel = nn.Variable([args.batch_size, 1])

    # Restore the trained parameters saved at the final iteration.
    parameter_file = os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.load_parameters(parameter_file)

    # Build the inference graph; ceil-divide to cover the whole set.
    tpred = mnist_cnn_prediction(timage, test=True)
    num_test_iter = int((tdata.size + args.batch_size - 1) / args.batch_size)

    # Accumulate the categorical error over every test batch.
    total_err = 0.0
    for _ in range(num_test_iter):
        timage.d, tlabel.d = tdata.next()
        tpred.forward(clear_buffer=True)
        total_err += categorical_error(tpred.d, tlabel.d)

    print("MNIST test accuracy", 1 - total_err / num_test_iter)
예제 #3
0
    def load_checkpoint(self, args):
        """Load pretrained parameters and solver states.

        Args:
                args (ArgumentParser): To check if tensorflow trained weights are to be used for testing and to get the path of the folder
                                                                from where the parameter and solver states are to be loaded
        """

        if args.use_tf_weights:
            # Download the tf-converted weights on first use.
            if not os.path.isfile(
                    os.path.join(args.weights_path, 'gen_params.h5')):
                os.makedirs(args.weights_path, exist_ok=True)
                print(
                    "Downloading the pretrained tf-converted weights. Please wait..."
                )
                url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
                from nnabla.utils.data_source_loader import download
                download(url, os.path.join(args.weights_path, 'gen_params.h5'),
                         False)
            nn.load_parameters(os.path.join(args.weights_path,
                                            'gen_params.h5'))
            print('Loaded pretrained weights from tensorflow!')

        else:
            try:
                nn.load_parameters(os.path.join(args.weights_path,
                                                'params.h5'))
            # A missing/corrupt file is non-fatal here; catch Exception
            # (not a bare except) so KeyboardInterrupt/SystemExit still
            # propagate.
            except Exception:
                if args.test:
                    warnings.warn(
                        "Testing Model without pretrained weights!!!")
                else:
                    print('No Pretrained weights loaded.')
예제 #4
0
def main(args):
    """Train a NoisyNet-DQN agent on an Atari environment."""
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # Atari environments: episodic for training, continuing for eval.
    env = AtariWrapper(gym.make(args.env), args.seed, episodic=True)
    eval_env = AtariWrapper(gym.make(args.env), args.seed, episodic=False)
    num_actions = env.action_space.n

    # Q-function model; optionally resume from saved parameters.
    model = NoisyNetDQN(q_function, num_actions, args.batch_size, args.gamma,
                        args.lr)
    if args.load is not None:
        nn.load_parameters(args.load)
    model.update_target()

    replay_buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    # NoisyNet explores via its parameter noise, so epsilon stays 0.
    greedy_policy = ConstantEpsilonGreedy(num_actions, 0.0)

    monitor = prepare_monitor(args.logdir)
    update_fn = update(model, replay_buffer, args.target_update_interval)
    eval_fn = evaluate(eval_env, model, render=args.render)

    train(env, model, replay_buffer, greedy_policy, monitor, update_fn,
          eval_fn, args.final_step, args.update_start, args.update_interval,
          args.save_interval, args.evaluate_interval, ['loss'])
    def load_models(epoch_num, gen=True, dis=True):
        """Load pretrained Wave-U-Net generator parameters.

        NOTE(review): ``epoch_num``, ``gen`` and ``dis`` are currently
        unused — the file name is built from the enclosing scope's
        ``args.epoch_from`` instead. Confirm whether ``epoch_num`` was
        meant to select the checkpoint.
        """

        # load generator parameter
        with nn.parameter_scope('Wave-U-Net'):
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             'param_{:04}.h5'.format(args.epoch_from)))
예제 #6
0
def test(args):
    """Run the trained SEGAN generator over the test set and write wavs.

    Loads generator parameters for ``args.epoch``, denoises the whole
    test speech using a zero latent variable, and writes clean / input /
    output wav files at 16 kHz.
    """

    ##  Load data & Create batch
    clean_data, noisy_data = dt.data_loader(test=True, need_length=True)
    # Batch
    #  - Processing speech interval can be adjusted by "start_frame" and "stop_frame".
    #  - "None" -> All speech in test dataset.
    baches_test = dt.create_batch_test(clean_data, noisy_data, start_frame=None, stop_frame=None)
    del clean_data, noisy_data  # free the raw arrays early

    ##  Create network
    # Variables
    noisy_t     = nn.Variable(baches_test.noisy.shape)          # Input
    z           = nn.Variable([baches_test.noisy.shape[0], 1024, 8])  # Random Latent Variable
    # Network (Only Generator)
    output_t = Generator(noisy_t, z)

    ##  Load parameter
    # load generator
    with nn.parameter_scope("gen"):
        print(args.epoch)
        nn.load_parameters(os.path.join(args.model_save_path, "generator_param_{:04}.h5".format(args.epoch)))

    ##  Validation
    noisy_t.d = baches_test.noisy
    # Deterministic inference: zero latent instead of random sampling.
    #z.d = np.random.randn(*z.shape)
    z.d = np.zeros(z.shape)             # zero latent variables

    output_t.forward()

    ##  Create wav files
    dt.wav_write('clean.wav', baches_test.clean.flatten(), fs=16000)
    dt.wav_write('input_segan.wav', baches_test.noisy.flatten(), fs=16000)
    dt.wav_write('output_segan.wav', output_t.d.flatten(), fs=16000)
    print('finish!')
def show():
    """Visualize convolution weights as Filter x Channel heatmaps.

    Loads a trained model and, for each convolution weight tensor,
    reshapes it to (filters, channels*height*width) and displays a
    colorbar heatmap, pausing for a key press between plots.
    """
    args = get_args()

    # Load model
    nn.load_parameters(args.model_load_path)
    params = nn.get_parameters()

    # Show heatmap
    for name, param in params.items():
        # SSL only on convolution weights
        if "conv/W" not in name:
            continue
        print(name)
        n, m, k0, k1 = param.d.shape
        w_matrix = param.d.reshape((n, m * k0 * k1))
        # Filter x Channel heatmap

        fig, ax = plt.subplots()
        ax.set_title(
            "{} with shape {} \n Filter x (Channel x Heigh x Width)".format(
                name, (n, m, k0, k1)))
        heatmap = ax.pcolor(w_matrix)
        fig.colorbar(heatmap)

        plt.pause(0.5)
        # raw_input() was removed in Python 3; input() is the equivalent.
        input("Press Key")
        plt.close()
예제 #8
0
def load_parameters(params_path):
    """Load LPIPS parameters, downloading the pretrained file if absent."""
    if not os.path.isfile(params_path):
        # Fetch the file from the nnabla model zoo, keeping the
        # caller-supplied basename.
        from nnabla.utils.download import download
        filename = params_path.split("/")[-1]
        url = os.path.join(
            "https://nnabla.org/pretrained-models/nnabla-examples/eval_metrics/lpips",
            filename)
        download(url, params_path, False)
    nn.load_parameters(params_path)
예제 #9
0
def main(args):
    """Evaluate a trained DTU-MVS neural rendering model by PSNR.

    Renders every view in the dataset, compares against ground-truth
    images under the object masks, and logs per-view PSNRs and the mean
    PSNR to a monitor placed next to the model file.
    """
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    B = conf.batch_size  # NOTE(review): B, L, D and feature_size are
    R = conf.n_rays      # read but unused in this function.
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)

    # Monitor: log next to the loaded model file.
    monitor_path = "/".join(args.model_load_path.split("/")[0:-1])
    monitor = Monitor(monitor_path)
    monitor_psnrs = MonitorSeries(f"PSNRs", monitor, interval=1)
    monitor_psnr = MonitorSeries(f"PSNR", monitor, interval=1)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Evaluate: render each view with a singleton batch dimension.
    image_list = []
    for pose, intrinsic, mask_obj in zip(ds.poses, ds.intrinsics, ds.masks):
        image = render(pose[np.newaxis, ...], intrinsic[np.newaxis, ...],
                       mask_obj[np.newaxis, ...], conf)
        image_list.append(image)

    metric = psnr(image_list, ds.images, ds.masks, monitor_psnrs)
    monitor_psnr.add(0, metric)
예제 #10
0
def test_transformer(config, netG, train_iterators, monitor, param_file):
    """Run the A2B boundary-map transformer over test data and visualize.

    Builds the inference graph for ``netG_A2B``, loads its parameters
    from ``param_file`` under the same scopes used at construction, then
    generates fake B-domain boundary maps from the source iterator and
    logs (real A, fake B, real B) image triplets to a monitor.
    """

    netG_A2B = netG['netG_A2B']

    train_iterator_src, train_iterator_trg = train_iterators

    # Load boundary image to get Variable shapes
    bod_map_A = train_iterator_src.next()[0]
    bod_map_B = train_iterator_trg.next()[0]
    real_bod_map_A = nn.Variable(bod_map_A.shape)
    real_bod_map_B = nn.Variable(bod_map_B.shape)
    real_bod_map_A.persistent, real_bod_map_B.persistent = True, True

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map_B = netG_A2B(
                real_bod_map_A, test=True,
                norm_type=config["norm_type"])  # (1, 15, 64, 64)
    fake_bod_map_B.persistent = True

    # load parameters of networks
    # (must use the same scope nesting as the graph construction above)
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            nn.load_parameters(param_file)

    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=config["test"]["vis_interval"],
                                  num_images=1,
                                  normalize_method=lambda x: x)

    # Test
    i = 0
    iter_per_epoch = train_iterator_src.size // config["test"]["batch_size"] + 1

    # A falsy config["num_test"] falls back to the full iterator size.
    if config["num_test"]:
        num_test = config["num_test"]
    else:
        num_test = train_iterator_src.size

    for _ in range(iter_per_epoch):
        bod_map_A = train_iterator_src.next()[0]
        bod_map_B = train_iterator_trg.next()[0]
        real_bod_map_A.d, real_bod_map_B.d = bod_map_A, bod_map_B

        # Generate fake images
        fake_bod_map_B.forward(clear_buffer=True)

        i += 1

        images_to_visualize = [
            real_bod_map_A.d, fake_bod_map_B.d, real_bod_map_B.d
        ]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        if i > num_test:
            break
예제 #11
0
def evaluate(path):
    """Embed the MNIST test set and return the 2-D features as a frame.

    Returns:
        pandas.DataFrame with columns ``x``, ``y`` and ``label``.
    """
    nn.load_parameters(os.path.join(path, "params.h5"))

    # Embedding network over a fixed-size batch.
    batch_size = 500
    image = nn.Variable([batch_size, 1, 28, 28])
    feature = I.mnist_lenet_feature(image, test=True)

    # MNIST test iterator with a fixed seed for reproducibility.
    rng = np.random.RandomState(313)
    data = I.data_iterator_mnist(batch_size, train=False, shuffle=True, rng=rng)

    # Run all 10000 test images through the embedding network.
    feature_batches, label_batches = [], []
    for _ in range(10000 // batch_size):
        image_data, label_data = data.next()
        image.d = image_data / 255.0
        feature.forward(clear_buffer=True)
        feature_batches.append(feature.d.copy())
        label_batches.append(label_data.copy())

    df = pd.DataFrame(np.vstack(feature_batches), columns=["x", "y"])
    df["label"] = np.vstack(label_batches)
    return df
예제 #12
0
def test_save_load_parameters():
    """Check that parameters survive a save/load round trip.

    Builds a small conv+BN graph under scope ``param1``, randomizes all
    parameters, saves them as HDF5 and protobuf, reloads each into a
    fresh scope, and verifies names, values, dtypes and ``need_grad``
    flags are preserved.
    """
    v = nn.Variable([64, 1, 28, 28], need_grad=False)
    with nn.parameter_scope("param1"):
        with nn.parameter_scope("conv1"):
            h = PF.convolution(v, 32, (3, 3))
            b = PF.batch_normalization(h, batch_stat=True)
        with nn.parameter_scope("conv2"):
            h1 = PF.convolution(v, 32, (3, 3))
            b2 = PF.batch_normalization(h1, batch_stat=True)

    # Randomize every parameter, including non-trainable BN stats.
    for k, v in iteritems(nn.get_parameters(grad_only=False)):
        v.data.cast(np.float32)[...] = np.random.randn(*v.shape)

    with nn.parameter_scope("param1"):
        param1 = nn.get_parameters(grad_only=False)
        nn.save_parameters("tmp.h5")
        nn.save_parameters("tmp.protobuf")

    with nn.parameter_scope("param2"):
        nn.load_parameters('tmp.h5')
        param2 = nn.get_parameters(grad_only=False)

    with nn.parameter_scope("param3"):
        nn.load_parameters('tmp.protobuf')
        param3 = nn.get_parameters(grad_only=False)

    for par2 in [param2, param3]:
        # Compare as lists so insertion order is actually checked:
        # dict keys views compare set-wise, ignoring order.
        assert list(param1.keys()) == list(par2.keys())  # Check order
        for (n1, p1), (n2, p2) in zip(sorted(param1.items()), sorted(par2.items())):
            assert n1 == n2
            assert np.all(p1.d == p2.d)
            assert p1.data.dtype == p2.data.dtype
            assert p1.need_grad == p2.need_grad
예제 #13
0
def main(args):
    """Train an A2C agent on a batched Atari environment.

    Builds ``args.num_envs`` parallel training environments plus one
    evaluation environment, optionally restores saved parameters, and
    runs the training loop.
    """
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # atari environment
    num_envs = args.num_envs
    envs = [gym.make(args.env) for _ in range(num_envs)]
    batch_env = BatchEnv([AtariWrapper(env, args.seed) for env in envs])
    eval_env = AtariWrapper(gym.make(args.env), 50, episodic=False)
    num_actions = envs[0].action_space.n

    # action-value function built with neural network
    # Learning rate is scheduled over 10^7 steps.
    lr_scheduler = learning_rate_scheduler(args.lr, 10**7)
    model = A2C(num_actions, num_envs, num_envs * args.time_horizon,
                args.v_coeff, args.ent_coeff, lr_scheduler)
    if args.load is not None:
        nn.load_parameters(args.load)

    logdir = prepare_directory(args.logdir)

    eval_fn = evaluate(eval_env, model, args.render)

    # start training loop
    return_fn = compute_returns(args.gamma)
    train_loop(batch_env, model, num_actions, return_fn, logdir, eval_fn, args)
예제 #14
0
def main(args):
    """Render one validation view of a trained DTU-MVS model.

    Loads the model, renders the single view selected by
    ``conf.valid_index`` and writes the image to a monitor placed next
    to the model file.
    """
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    B = conf.batch_size  # NOTE(review): B, L, D and feature_size are
    R = conf.n_rays      # read but unused in this function.
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)

    # Monitor: log next to the loaded model file.
    monitor_path = "/".join(args.model_load_path.split("/")[0:-1])
    monitor = Monitor(monitor_path)
    monitor_image = MonitorImage(
        f"Rendered image synthesis", monitor, interval=1)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Render: slice with valid_index:valid_index+1 to keep a batch dim.
    pose = ds.poses[conf.valid_index:conf.valid_index+1, ...]
    intrinsic = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
    mask_obj = ds.masks[conf.valid_index:conf.valid_index+1, ...]
    image = render(pose, intrinsic, mask_obj, conf)
    monitor_image.add(conf.valid_index, image)
예제 #15
0
def meta_test(args, shape_x, test_data):
    """Evaluate a trained prototypical network on few-shot test episodes.

    Runs ``args.n_episode_for_test`` episodes, computes the mean top-1
    error rate and its 95% confidence half-width, and logs both.

    Returns:
        tuple: (mean error rate, 95% confidence half-width).
    """

    # Build episode generators
    test_episode_generator = EpisodeGenerator(
        test_data[0], test_data[1], args.n_class, args.n_shot, args.n_query)

    # Build prototypical network
    xs_v = nn.Variable((args.n_class * args.n_shot, ) + shape_x)
    xq_v = nn.Variable((args.n_class * args.n_query, ) + shape_x)
    hq_v = net(args.n_class, xs_v, xq_v, args.embedding,
               args.net_type, args.metric, True)
    yq_v = nn.Variable((args.n_class * args.n_query, 1))
    err_v = F.mean(F.top_n_error(hq_v, yq_v, n=1))

    # Load parameters
    nn.load_parameters(args.work_dir + "/params.h5")

    # Evaluate error rate
    v_errs = []
    for k in range(args.n_episode_for_test):
        xs_v.d, xq_v.d, yq_v.d = test_episode_generator.next()
        err_v.forward(clear_no_need_grad=True, clear_buffer=True)
        # np.float was removed in NumPy 1.24; it was an alias for the
        # builtin float anyway.
        v_errs.append(float(err_v.d.copy()))
    v_err_mean = np.mean(v_errs)
    v_err_std = np.std(v_errs)
    # 1.96 * sigma / sqrt(N): 95% confidence half-width.
    v_err_conf = 1.96 * v_err_std / np.sqrt(args.n_episode_for_test)

    # Monitor error rate
    monitor = Monitor(args.work_dir)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)
    monitor_test_err.add(0, v_err_mean * 100)
    monitor_test_conf.add(0, v_err_conf * 100)

    return v_err_mean, v_err_conf
예제 #16
0
def generate(args):
    """Generate image batches from random latents with a trained GAN."""
    # Execution context.
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    latent, maps, batch_size = args.latent, args.maps, args.batch_size

    # Build the inference graph on top of the loaded parameters.
    nn.load_parameters(args.model_load_path)
    z_test = nn.Variable([batch_size, latent])
    x_test = generator(z_test, maps=maps, test=True, up=args.up)

    # Image-tile monitor for the generated batches.
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_test = MonitorImageTile("Image Tile Generated",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Sample a fresh latent batch each iteration and record the output.
    for it in range(args.num_generation):
        z_test.d = np.random.randn(batch_size, latent)
        x_test.forward(clear_buffer=True)
        monitor_image_tile_test.add(it, x_test)
예제 #17
0
def _load_nnp_to_proto(nnp_path):
    """Parse an .nnp archive into an NNablaProtoBuf.

    Network definitions (.nntxt/.prototxt) are merged into the returned
    proto; parameter files (.protobuf/.h5) are loaded into the current
    parameter scope as a side effect. The temporary extraction
    directory is always removed.
    """
    import google.protobuf.text_format as text_format
    import tempfile
    import zipfile
    import shutil

    proto = nnabla_pb2.NNablaProtoBuf()
    tmpdir = tempfile.mkdtemp()
    try:
        with zipfile.ZipFile(nnp_path, "r") as nnp:
            for name in nnp.namelist():
                ext = os.path.splitext(name)[1]
                if name == "nnp_version.txt":
                    # Currently nnp_version.txt is ignored.
                    continue
                if ext in (".nntxt", ".prototxt"):
                    nnp.extract(name, tmpdir)
                    with open(os.path.join(tmpdir, name), "rt") as f:
                        text_format.Merge(f.read(), proto)
                elif ext in (".protobuf", ".h5"):
                    nnp.extract(name, tmpdir)
                    nn.load_parameters(os.path.join(tmpdir, name))
    finally:
        shutil.rmtree(tmpdir)
    return proto
예제 #18
0
 def __init__(self, config):
     """Build the MLP inference graph and load trained parameters.

     Args:
         config: Provides ``columns_size``, ``x_length`` and
             ``mlp_model_params_path``.
     """
     self._cols_size = config.columns_size
     self._x_length = config.x_length
     self._model_params_path = config.mlp_model_params_path
     # Input variable of shape (1, x_length, columns_size).
     self._x = nn.Variable([1, self._x_length, self._cols_size])
     self._pred = self.network(self._x)
     nn.load_parameters(self._model_params_path)
예제 #19
0
def main(args):
    """Train a SAC agent on a continuous-control gym environment."""
    env = gym.make(args.env)
    env.seed(args.seed)
    eval_env = gym.make(args.env)
    eval_env.seed(50)  # fixed evaluation seed
    action_shape = env.action_space.shape

    # GPU
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # NOTE(review): parameters are loaded into the global scope before
    # the model is constructed — confirm SAC picks them up when building
    # its networks.
    if args.load:
        nn.load_parameters(args.load)

    model = SAC(env.observation_space.shape, action_shape[0], args.batch_size,
                args.critic_lr, args.actor_lr, args.temp_lr, args.tau,
                args.gamma)
    model.sync_target()

    buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    monitor = prepare_monitor(args.logdir)

    update_fn = update(model, buffer)

    eval_fn = evaluate(eval_env, model, render=args.render)

    # SAC is off-policy with no action noise (EmptyNoise).
    train(env, model, buffer, EmptyNoise(), monitor, update_fn, eval_fn,
          args.final_step, args.batch_size, 1, args.save_interval,
          args.evaluate_interval, ['critic_loss', 'actor_loss', 'temp_loss'])
예제 #20
0
    def __init__(self, param_path=None):
        """Load pretrained VGG19 weights, downloading them if missing.

        Args:
            param_path (str): Path to the pretrained VGG19 ``.h5`` file.
        """
        self.h5_file = param_path
        if not os.path.exists(self.h5_file):
            print(
                "Pretrained VGG19 parameters not found. Downloading. Please wait..."
            )
            url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/vgg19.h5"
            from nnabla.utils.data_source_loader import download
            # NOTE(review): downloads to the current directory (the URL
            # basename), not to self.h5_file — confirm intent.
            download(url, url.split('/')[-1], False)
        # Check AFTER the download attempt. The original asserted before
        # it, which made the download branch unreachable (and raised
        # TypeError for the default param_path=None).
        assert os.path.isfile(
            self.h5_file), "pretrained VGG19 weights not found."

        with nn.parameter_scope("VGG19"):
            logger.info('loading vgg19 parameters...')
            nn.load_parameters(self.h5_file)
            # drop all the affine layers.
            drop_layers = [
                'classifier/0/affine', 'classifier/3/affine',
                'classifier/6/affine'
            ]
            for layers in drop_layers:
                nn.parameter.pop_parameter((layers + '/W'))
                nn.parameter.pop_parameter((layers + '/b'))
            # ImageNet normalization constants.
            self.mean = nn.Variable.from_numpy_array(
                np.asarray([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
            self.std = nn.Variable.from_numpy_array(
                np.asarray([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))
예제 #21
0
def test_parameter_file_load_save():
    """Round-trip a graph_def save/load and check outputs match.

    Saves a network built from proto variables, verifies a freshly
    initialized twin network disagrees, then loads the saved file into
    the twin's parameter scope and checks both produce equal outputs.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(outputs)

    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Should not equal
        with pytest.raises(AssertionError) as excinfo:
            forward_variable_and_check_equal(outputs, ref_outputs)

        # load to local scope
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(tmp_file)

        another.update_parameter()

        # After loading, the twin must reproduce the saved outputs.
        ref_outputs = another(*variable_inputs)
        forward_variable_and_check_equal(outputs, ref_outputs)
예제 #22
0
def main():
    """Estimate a stereo disparity map with PSMNet and save it as PNG."""
    ctx = get_extension_context('cudnn', device_id=args.gpus)
    nn.set_default_context(ctx)
    image_left = imread(args.left_image)
    image_right = imread(args.right_image)

    # Pick input resolution / preprocessing per dataset.
    # NOTE(review): no else branch — an unrecognized args.dataset leaves
    # var_left / img_left undefined and raises NameError below.
    if args.dataset == 'Kitti':
        var_left = nn.Variable((1, 3, args.im_height_kt, args.im_width_kt))
        var_right = nn.Variable((1, 3, args.im_height_kt, args.im_width_kt))
        img_left, img_right = preprocess_kitti(image_left, image_right)
    elif args.dataset == 'SceneFlow':
        var_left = nn.Variable((1, 3, args.im_height_sf, args.im_width_sf))
        var_right = nn.Variable((1, 3, args.im_height_sf, args.im_width_sf))
        img_left, img_right = preprocess_sceneflow(image_left, image_right)

    var_left.d, var_right.d = img_left, img_right
    if args.loadmodel is not None:
        # Loading CNN pretrained parameters.
        nn.load_parameters(args.loadmodel)
    pred_test = psm_net(var_left, var_right, args.maxdisp, False)
    pred_test.forward(clear_buffer=True)
    pred = pred_test.d
    pred = np.squeeze(pred, axis=1)
    pred = pred[0]
    # Rescale predictions to [-1, 1] for visualization.
    pred = 2*(pred - np.min(pred))/np.ptp(pred)-1
    # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2;
    # consider imageio.imwrite when upgrading SciPy.
    scipy.misc.imsave('stereo_depth.png', pred)

    print("Done")
예제 #23
0
def test_parameter_file_load_save_for_file_object(memory_buffer_format):
    """Save/load parameters through an in-memory file object.

    Uses io.BytesIO with an explicit ``extension`` (the format cannot be
    inferred from a file name), then checks a freshly initialized twin
    network reproduces the original outputs after loading.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)
    extension = memory_buffer_format

    # Should not equal
    with pytest.raises(AssertionError) as excinfo:
        forward_variable_and_check_equal(outputs, ref_outputs)

    with io.BytesIO() as param_file:
        nn.save_parameters(param_file,
                           a_module.get_parameters(),
                           extension=extension)
        # load from file
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(param_file, extension=extension)
        another.update_parameter()

    ref_outputs = another(*variable_inputs)

    # should equal
    forward_variable_and_check_equal(outputs, ref_outputs)
예제 #24
0
def test_parameter_file_load_save_for_files(parameter_file):
    """Save/load parameters through an on-disk file of a given format.

    Saves a module's parameters to a temporary file, loads them into a
    freshly initialized twin network's scope, and checks the twin then
    reproduces the original outputs.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Should not equal
    with pytest.raises(AssertionError) as excinfo:
        forward_variable_and_check_equal(outputs, ref_outputs)

    with create_temp_with_dir(parameter_file) as tmp_file:
        # save to file
        nn.save_parameters(tmp_file, a_module.get_parameters())

        # load from file
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(tmp_file)
    another.update_parameter()

    ref_outputs = another(*variable_inputs)

    # should equal
    forward_variable_and_check_equal(outputs, ref_outputs)
예제 #25
0
def main(args):
    """Train a TD3 agent on a continuous-control gym environment."""
    env = gym.make(args.env)
    env.seed(args.seed)
    eval_env = gym.make(args.env)
    eval_env.seed(50)  # fixed evaluation seed
    action_shape = env.action_space.shape

    # GPU
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # NOTE(review): parameters are loaded into the global scope before
    # the model is constructed — confirm TD3 picks them up when building
    # its networks.
    if args.load:
        nn.load_parameters(args.load)

    model = TD3(env.observation_space.shape, action_shape[0], args.batch_size,
                args.critic_lr, args.actor_lr, args.tau, args.gamma,
                args.target_reg_sigma, args.target_reg_clip)
    model.sync_target()

    # Gaussian exploration noise with per-dimension sigma.
    noise = NormalNoise(np.zeros(action_shape),
                        args.exploration_sigma + np.zeros(action_shape))

    buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    monitor = prepare_monitor(args.logdir)

    update_fn = update(model, buffer, args.update_actor_freq)

    eval_fn = evaluate(eval_env, model, render=args.render)

    train(env, model, buffer, noise, monitor, update_fn, eval_fn,
          args.final_step, args.batch_size, 1, args.save_interval,
          args.evaluate_interval, ['critic_loss', 'actor_loss'])
예제 #26
0
def load_parameters_and_config(path):
    '''
    Load paramters and deduce the configuration
    of memory layout and input channels

    Returns: (channel_last, input_channels)

    '''
    nn.load_parameters(path)
    # get_parameter returns None for a missing key rather than raising,
    # so the original bare try/except could never fire; check explicitly.
    conv1 = nn.parameter.get_parameter('conv1/conv/W')
    if conv1 is None:
        raise ValueError(
            'conv1/conv/W is not found. This parameter configuration deduction works for resnet only.'
        )
    shape = conv1.shape
    assert shape[1] == 7 or shape[
        3] == 7, 'This deduction process assumes that the first convolution has 7x7 filter.'
    # Channel-first weight: (out, in, 7, 7); channel-last: (out, 7, 7, in).
    channel_last = False
    channels = shape[1]
    if shape[1] == 7:
        channel_last = True
        channels = shape[3]
    assert channels in (3, 4), f'channels must be either 3 or 4: {channels}.'
    return channel_last, channels
def main():
    """Convert a parameter file's memory layout and/or input channels.

    Loads parameters from ``args.input``, optionally converts the
    memory layout and forces a 3-channel first convolution, then saves
    to ``args.output`` only if a conversion actually happened.
    """
    args = get_args()

    nn.load_parameters(args.input)
    params = nn.get_parameters(grad_only=False)

    # Tracks whether any conversion happened; skip saving otherwise.
    processed = False

    # Convert memory layout
    layout = get_memory_layout(params)
    if args.memory_layout is None:
        pass
    elif args.memory_layout != layout:
        logger.info(f'Converting memory layout to {args.memory_layout}.')
        convert_memory_layout(params, args.memory_layout)
        processed |= True
    else:
        logger.info('No need to convert memory layout.')

    if args.force_3_channels:
        ret = force_3_channels(params, args.memory_layout)
        if ret:
            logger.info('Converted first conv to 3-channel input.')
        processed |= ret

    if not processed:
        logger.info(
            'No change has been made for the input. Not saving a new parameter file.'
        )
        return
    logger.info(f'Save a new parameter file at {args.output}')
    # Write the converted tensors back into the global scope, then save.
    for key, param in params.items():
        nn.parameter.set_parameter(key, param)
    nn.save_parameters(args.output)
예제 #28
0
    def get_estimates(self,
                       input_path: str,
                       parts,
                       fft_size=4096,
                       hop_size=1024,
                       n_channels=2,
                       apply_mwf_flag=True,
                       ch_flip_average=True):
        """Separate an audio file into per-part source estimates with D3Net.

        Args:
            input_path: Local path or remote URL of the input audio.
            parts: Mapping of part name -> bool; all parts are processed
                (needed for the Wiener filter) but only truthy ones are
                returned.
            fft_size, hop_size, n_channels: STFT parameters.
            apply_mwf_flag: Apply the multichannel Wiener filter.
            ch_flip_average: Average over channel-flipped inputs.

        Returns:
            dict: part name -> time-domain estimate.
        """
        # Set NNabla extention
        ctx = get_extension_context(self.context)
        nn.set_default_context(ctx)

        # Load the model weights
        nn.load_parameters(str(self.model_file_path))

        # Read file locally
        if settings.DEFAULT_FILE_STORAGE == 'api.storage.FileSystemStorage':
            _, inp_stft = generate_data(input_path, fft_size,
                                                  hop_size, n_channels, self.sample_rate)
        else:
            # If remote, download to temp file and load audio
            fd, tmp_path = tempfile.mkstemp()
            try:
                r_get = requests.get(input_path)
                with os.fdopen(fd, 'wb') as tmp:
                    tmp.write(r_get.content)

                _, inp_stft = generate_data(tmp_path, fft_size, hop_size,
                                            n_channels, self.sample_rate)
            finally:
                # Remove temp file
                os.remove(tmp_path)

        out_stfts = {}
        estimates = {}
        inp_stft_contiguous = np.abs(np.ascontiguousarray(inp_stft))

        # Need to compute all parts even for static mix, for mwf?
        for part in parts:
            print(f'Processing {part}...')

            with open('./config/d3net/{}.yaml'.format(part)) as file:
                # Load part specific Hyper parameters
                hparams = yaml.load(file, Loader=yaml.FullLoader)

            # Each part's weights live under their own parameter scope;
            # the output magnitude is recombined with the input phase.
            with nn.parameter_scope(part):
                out_sep = model_separate(
                    inp_stft_contiguous, hparams, ch_flip_average=ch_flip_average)
                out_stfts[part] = out_sep * np.exp(1j * np.angle(inp_stft))

        if apply_mwf_flag:
            out_stfts = apply_mwf(out_stfts, inp_stft)

        # Only return the parts the caller asked for.
        for part, output in out_stfts.items():
            if not parts[part]:
                continue
            estimates[part] = stft2time_domain(output, hop_size, True)

        return estimates
예제 #29
0
def main():
    """
    Inference function to generate SR images.

    Loads trained generator/flow-estimator weights, builds a recurrent
    x4 super-resolution graph, and runs it frame-by-frame over the
    low-resolution inputs in ``args.input_dir_lr``, saving outputs to
    ``args.output_dir``.  Relies on module-level ``args``.
    """
    nn.load_parameters(args.model)
    # Inference data loader
    inference_data = inference_data_loader(args.input_dir_lr)
    # Prepend a batch dimension of 1; layout is channels-last
    # (the output shape below ends with 3 channels).
    input_shape = [
        1,
    ] + list(inference_data.inputs[0].shape)
    output_shape = [1, input_shape[1] * 4, input_shape[2] * 4, 3]
    # oh/ow = H mod 8, W mod 8 — pad amounts applied to the flow field
    # (presumably so the flow matches an 8-aligned network output; confirm).
    oh = input_shape[1] - input_shape[1] // 8 * 8
    ow = input_shape[2] - input_shape[2] // 8 * 8

    # Build the computation graph
    inputs_raw = nn.Variable(input_shape)   # current LR frame
    pre_inputs = nn.Variable(input_shape)   # previous LR frame
    pre_gen = nn.Variable(output_shape)     # previous SR output
    pre_warp = nn.Variable(output_shape)    # previous SR output, warped by flow

    transposed_pre_warp = space_to_depth(pre_warp)
    inputs_all = F.concatenate(inputs_raw, transposed_pre_warp)
    with nn.parameter_scope("generator"):
        gen_output = generator(inputs_all, 3, args.num_resblock)
    # Rescale from [-1, 1] to [0, 1] (assumes generator output range).
    outputs = (gen_output + 1) / 2
    inputs_frames = F.concatenate(pre_inputs, inputs_raw)
    with nn.parameter_scope("fnet"):
        flow_lr = flow_estimator(inputs_frames)
    flow_lr = F.pad(flow_lr, (0, 0, 0, oh, 0, ow, 0, 0), "reflect")
    # Upscale the flow field x4 and scale its magnitudes accordingly.
    flow_hr = upscale_four(flow_lr * 4.0)
    pre_gen_warp = warp_by_flow(pre_gen, flow_hr)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    max_iter = len(inference_data.inputs)
    print('Frame evaluation starts!!')
    # Zero-initialize the recurrent state for the first frame.
    pre_inputs.d, pre_gen.d, pre_warp.d = 0, 0, 0
    for i in range(max_iter):
        inputs_raw.d = np.array([inference_data.inputs[i]]).astype(np.float32)
        if i != 0:
            # Warp the previous SR frame by the estimated flow before
            # feeding it to the generator.
            pre_gen_warp.forward()
            pre_warp.data.copy_from(pre_gen_warp.data)
        outputs.forward()
        output_frame = outputs.d

        if i >= 5:
            name, _ = os.path.splitext(
                os.path.basename(str(inference_data.paths_lr[i])))
            filename = args.output_name + '_' + name
            print('saving image %s' % filename)
            out_path = os.path.join(args.output_dir,
                                    "%s.%s" % (filename, args.output_ext))
            save_img(out_path, output_frame[0])
        else:  # First 5 is a hard-coded symmetric frame padding, ignored but time added!
            print("Warming up %d" % (5 - i))

        # Carry current frame and output into the next iteration's state.
        pre_inputs.data.copy_from(inputs_raw.data)
        pre_gen.data.copy_from(outputs.data)
 def importNetwork(self, fname):
     """Load Q-network weights from ``fname + '.h5'`` and refresh the
     target Q-network, then dump the loaded parameters for inspection.

     Args:
         fname (str): Path of the parameter file without the '.h5' suffix.
     """
     print("Import Q-network from {}".format(fname))
     nn.load_parameters(fname + '.h5')
     # Python-2 print statements converted to print() calls; they were
     # SyntaxErrors under Python 3.
     print("Updating target Q-network")
     self.update_Q_target()
     print('--------------------------------------------------')
     print(nn.get_parameters())
     print('--------------------------------------------------')
예제 #31
0
def evaluate(args):
    """
    Compute an evaluation metric (Inception Score or FID) for one or
    more trained SAGAN generator checkpoints and log one value per
    training iteration.

    Args:
        args: Parsed command-line arguments; uses context/device
            settings, the Inception NNP path, metric selection
            ("IS" or "FID"), checkpoint path(s), and generator
            hyper-parameters (latent, maps, n_classes, not_sn, ...).
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, y = get_input_and_output(nnp, args.batch_size, name=args.variable_name)

    if args.evaluation_metric == "IS":
        is_model = None
        compute_metric = compute_inception_score
    elif args.evaluation_metric == "FID":
        di = data_iterator_imagenet(args.valid_dir,
                                    args.dirname_to_label_path,
                                    batch_size=args.batch_size,
                                    ih=args.image_size,
                                    iw=args.image_size,
                                    shuffle=True,
                                    train=False,
                                    noise=False)
        compute_metric = functools.partial(compute_frechet_inception_distance,
                                           di=di)
    else:
        # Fail fast instead of a NameError on compute_metric later.
        raise ValueError(
            "Unsupported evaluation metric: {}".format(args.evaluation_metric))

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_metric = MonitorSeries("{}".format(args.evaluation_metric),
                                   monitor,
                                   interval=1)

    def cmp_func(path):
        # Extract the iteration number from ".../params_123456.h5".
        # NOTE: the previous str.strip("params_")/rstrip(".h5") version
        # stripped *character sets* (e.g. a count ending in 5 lost its
        # trailing digit) and split on "/", which is not portable.
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.split("_")[-1])

    # Evaluate every checkpoint in a directory (sorted by iteration),
    # or just the single file given.
    model_load_path = sorted(glob.glob("{}/*.h5".format(args.model_load_path)), key=cmp_func) \
        if os.path.isdir(args.model_load_path) else \
        [args.model_load_path]

    for path in model_load_path:
        # Model (SAGAN)
        nn.load_parameters(path)
        z = nn.Variable([batch_size, latent])
        y_fake = nn.Variable([batch_size])
        x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
            .apply(persistent=True)
        # Compute the evaluation metric
        score = compute_metric(z, y_fake, x_fake, x, y, args)
        itr = cmp_func(path)
        monitor_metric.add(itr, score)
예제 #32
0
파일: siamese.py 프로젝트: zwsong/nnabla
def visualize(args):
    """
    Visualizing embedded digits onto 2D space.

    Runs the trained siamese feature extractor over the full MNIST test
    set and scatter-plots the 2D embeddings, one color per digit class.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    batch_size = 500

    # Build a default context (note: never activated via
    # set_default_context in the original code either).
    ctx = nn.Context(backend="cpu|cuda",
                     compute_backend="default|cudnn",
                     array_class="CudaArray",
                     device_id="{}".format(args.device_id))

    # Restore the trained embedder weights.
    nn.load_parameters(os.path.join(args.model_save_path,
                                    'params_%06d.h5' % args.max_iter))

    # Embedder graph: image -> 2D feature.
    image = nn.Variable([batch_size, 1, 28, 28])
    feature = mnist_lenet_feature(image, test=False)

    # Collect embeddings and labels over the whole 10k-image test set.
    rng = np.random.RandomState(313)
    data = data_iterator_mnist(batch_size, train=False, shuffle=True, rng=rng)
    embeddings, targets = [], []
    for _ in range(10000 // batch_size):
        image_data, label_data = data.next()
        image.d = image_data / 255.
        feature.forward(clear_buffer=True)
        embeddings.append(feature.d.copy())
        targets.append(label_data.copy())
    embeddings = np.vstack(embeddings)
    targets = np.vstack(targets)

    # Scatter plot, one color per digit class.
    f = plt.figure(figsize=(16, 9))
    for digit in range(10):
        color = plt.cm.Set1(digit / 10.)
        mask = targets.flat == digit
        plt.plot(embeddings[mask, 0].flatten(),
                 embeddings[mask, 1].flatten(), '.', c=color)
    plt.legend(map(str, range(10)))
    plt.grid()
    plt.savefig(os.path.join(args.monitor_path, "embed.png"))
예제 #33
0
def main():
    """Load a trained MNIST classifier and export it as an NNP file
    for later inference from the C++ runtime."""
    HERE = os.path.dirname(__file__)
    # Make the MNIST example modules importable.
    sys.path.append(
        os.path.realpath(os.path.join(HERE, '..', '..', 'vision', 'mnist')))
    from mnist_data import data_iterator_mnist
    from args import get_args
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    # Pick the prediction network matching the trained model.
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        mnist_cnn_prediction = mnist_lenet_prediction

    # Infer parameter file name and read it.
    model_save_path = os.path.join('../../vision/mnist',
                                   args.model_save_path)
    parameter_file = os.path.join(
        model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    try:
        nn.load_parameters(parameter_file)
    except IOError:
        logger.error("Run classification.py before runnning this script.")
        exit(1)

    # Inference-only computation graph to be saved.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    pred = mnist_cnn_prediction(image, test=True)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_{:06}.nnp'.format(args.net, args.max_iter)
    runtime_contents = {
        'networks': [
            {'name': 'runtime',
             'batch_size': args.batch_size,
             'outputs': {'y': pred},
             'names': {'x': image}}],
        'executors': [
            {'name': 'runtime',
             'network': 'runtime',
             'data': ['x'],
             'output': ['y']}]}
    nn.utils.save.save(nnp_file, runtime_contents)
예제 #34
0
파일: test_graph.py 프로젝트: zwsong/nnabla
def test_graph_clear_buffer(seed):
    """
    Check that gradients are identical regardless of the
    clear_no_need_grad / clear_buffer flag combination, reloading the
    same saved parameters before each run.
    """
    np.random.seed(313)
    rng = np.random.RandomState(seed)
    x = nn.Variable([2, 3, 4, 4])
    t = nn.Variable([2, 1])
    x.d = rng.randn(*x.shape)
    t.d = rng.randint(0, 5, size=t.shape)

    # Network definition
    nn.set_default_context(nn.Context())
    nn.clear_parameters()
    x1 = x + 1
    x2 = x1 - 1
    with nn.parameter_scope('conv1'):
        z = PF.convolution(x2, 3, (2, 2))
        z2 = F.relu(z, inplace=True)
    with nn.parameter_scope('fc2'):
        z3 = PF.affine(z2, 5)
    l = F.softmax_cross_entropy(z3, t, 1)
    L = F.mean(l)

    # Forward/backward under all flag combinations, comparing grads.
    import tempfile
    import os
    import shutil
    tmpd = tempfile.mkdtemp()
    try:
        nn.save_parameters(os.path.join(tmpd, 'parameter.h5'))
        g_ref = None  # gradient from the first run, reference for the rest
        for cnng in [False, True]:
            for cb in [False, True]:
                nn.load_parameters(os.path.join(tmpd, 'parameter.h5'))
                for v in nn.get_parameters().values():
                    v.grad.zero()
                L.forward(clear_no_need_grad=cnng)
                L.backward(clear_buffer=cb)
                g = list(nn.get_parameters().values())[0].g.copy()
                if g_ref is None:
                    g_ref = g
                else:
                    assert np.all(g_ref == g)
    finally:
        # The original leaked the temp directory; always clean it up.
        shutil.rmtree(tmpd)
예제 #35
0
파일: load.py 프로젝트: zwsong/nnabla
def load(filenames, prepare_data_iterator=True):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of file paths. Supported formats are text
            protobuf (any extension containing 'txt'), binary
            '.protobuf', HDF5 '.h5', and zipped '.nnp' archives.
        prepare_data_iterator (bool): Whether to create data iterators
            for datasets found in the files.
    Returns:
        Info: Object whose attributes (global_config, training_config,
        datasets, networks, optimizers, monitors, executors) are set
        only when present in the loaded files.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)
        if 'txt' in ext:
            with open(filename, 'rt') as f:
                text_format.Merge(f.read(), proto)
        elif ext in ['.protobuf', '.h5']:
            nn.load_parameters(filename, proto)
        elif ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        _, ext = os.path.splitext(name)
                        if 'txt' in ext:
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                text_format.Merge(f.read(), proto)
                        elif ext in ['.protobuf', '.h5']:
                            nn.load_parameters(os.path.join(tmpdir, name), proto)
            finally:
                # Remove the extraction dir even if parsing raised;
                # the original leaked it on error.
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
    else:
        default_context = nn.context()

    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if len(proto.dataset) > 0:
        info.datasets = _datasets(proto, prepare_data_iterator)

    # NOTE(review): the branches below read info.networks/info.datasets,
    # which are unset when the proto lacks those sections — an optimizer
    # without a dataset would raise AttributeError (pre-existing behavior).
    if len(proto.network) > 0:
        info.networks = _networks(proto, default_context)

    if len(proto.optimizer) > 0:
        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)

    if len(proto.monitor) > 0:
        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)

    if len(proto.executor) > 0:
        info.executors = _executors(proto, info.networks)

    return info
예제 #36
0
파일: exp016.py 프로젝트: kzky/works
def main(args):
    """
    Semi-supervised CIFAR-10 training: supervised CE + entropy
    regularization, plus stochastic-regularization losses on raw
    unlabeled images and on reconstructions from a frozen pre-trained
    autoencoder loaded from ``args.model_path``.
    """
    # Settings
    device_id = args.device_id
    batch_size = 100
    batch_size_eval = 100
    n_l_train_data = 4000   # labeled training samples
    n_train_data = 50000    # total CIFAR-10 training samples
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    # Integer division: a float here made range(n_iter) raise TypeError
    # on Python 3 (the sibling exp082 script already casts to int).
    iter_epoch = n_train_data // batch_size
    n_iter = n_epoch * iter_epoch
    extension_module = args.context

    # Model
    ## supervised
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred = cnn_model_003(ctx, x_l)
    loss_ce = ce_loss(ctx, pred, y_l)
    loss_er = er_loss(ctx, pred)
    loss_supervised = loss_ce + loss_er

    ## stochastic regularization on two augmented views
    x_u0 = nn.Variable((batch_size, m, h, w), need_grad=False)
    x_u1 = nn.Variable((batch_size, m, h, w), need_grad=False)
    pred_x_u0 = cnn_model_003(ctx, x_u0)
    pred_x_u1 = cnn_model_003(ctx, x_u1)
    loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1)
    loss_er0 = er_loss(ctx, pred_x_u0)
    loss_er1 = er_loss(ctx, pred_x_u1)
    loss_unsupervised = loss_sr + loss_er0 + loss_er1

    ## autoencoder (pre-trained, used as a frozen input transformer)
    path = args.model_path
    nn.load_parameters(path)
    x_u0_rc = cnn_ae_model_000(ctx, x_u0, act=F.relu, test=True)
    x_u1_rc = cnn_ae_model_000(ctx, x_u1, act=F.relu, test=True)
    x_u0_rc.need_grad = False
    x_u1_rc.need_grad = False
    pred_x_u0_rc = cnn_model_003(ctx, x_u0_rc, test=False)
    pred_x_u1_rc = cnn_model_003(ctx, x_u1_rc, test=False)
    loss_sr_rc = sr_loss(ctx, pred_x_u0_rc, pred_x_u1_rc)
    loss_er0_rc = er_loss(ctx, pred_x_u0_rc)
    loss_er1_rc = er_loss(ctx, pred_x_u1_rc)
    loss_unsupervised_rc = loss_sr_rc + loss_er0_rc + loss_er1_rc
    loss_unsupervised += loss_unsupervised_rc

    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval = cnn_model_003(ctx, x_eval, test=True)

    # Solver
    with nn.context_scope(ctx):
        solver = S.Adam(alpha=learning_rate)
        solver.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset into labeled/unlabeled splits
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()

        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train: alternate a supervised and an unsupervised update
        loss_supervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        solver.update()
        loss_unsupervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_unsupervised.backward(clear_buffer=True)
        solver.update()

        # Evaluate once per epoch
        if (i + 1) % iter_epoch == 0:
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = x_data[k:k + batch_size_eval, :]
                label = y_data[k:k + batch_size_eval, :]
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st,
                (1. - ve / iter_val) * 100)
            print(msg)
            st = time.time()
            epoch += 1
예제 #37
0
파일: exp082.py 프로젝트: kzky/works
def main(args):
    """
    Knowledge-transfer training on CIFAR-10: train a ResNet-23 student
    with a supervised cross-entropy loss plus a KL-divergence loss
    against a frozen, pre-trained teacher (cnn_model_003) loaded from
    ``args.model_load_path``.
    """
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000   # labeled training samples
    n_train_data = 50000    # total CIFAR-10 training samples
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    iter_epoch = int(n_train_data / batch_size)
    n_iter = n_epoch * iter_epoch
    extension_module = args.context

    # Model
    ## supervised resnet
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred_res = cifar10_resnet23_prediction(ctx, "resnet", x_l)
    loss_res_ce = ce_loss(ctx, pred_res, y_l)
    loss_res_supervised = loss_res_ce

    ## stochastic regularization
    # Load the pre-trained teacher weights; its predictions are frozen
    # (need_grad=False below) and only used as soft targets.
    nn.load_parameters(args.model_load_path)
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u0.persistent = True
    pred_x_u0, log_var0 = cnn_model_003(ctx, x_u0)
    pred_x_u0.need_grad, log_var0.need_grad = False, False

    ## knowledge transfer for resnet
    pred_res_x_u0 = cifar10_resnet23_prediction(ctx, "resnet", x_u0)
    loss_res_unsupervised = kl_divergence(ctx, pred_res_x_u0, pred_x_u0, log_var0)

    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_res_eval = cifar10_resnet23_prediction(ctx, "resnet", x_eval, test=True)

    # Solver
    # Only the student ("resnet" scope) parameters are optimized; the
    # teacher stays fixed.
    with nn.context_scope(ctx):
        with nn.parameter_scope("resnet"):
            solver_res = S.Adam(alpha=learning_rate)
            solver_res.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset into labeled/unlabeled splits
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                  batch_size=batch_size,
                                  n_cls=n_cls,
                                  da=True,
                                  shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()
        
        x_l.d, _ , y_l.d= x_l0_data, x_l1_data, y_l_data
        x_u0.d, _ = x_u0_data, x_u1_data

        # Train resnet: one joint step on supervised + distillation loss
        loss_res_supervised.forward(clear_no_need_grad=True)
        loss_res_unsupervised.forward(clear_no_need_grad=True)
        solver_res.zero_grad()
        loss_res_supervised.backward(clear_buffer=True)
        loss_res_unsupervised.backward(clear_buffer=True)
        solver_res.update()

        # Evaluate once per epoch
        if int((i+1) % iter_epoch) == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop for resnet
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_res_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_res_eval.d, label)
                iter_val += 1
            msg = "Model:resnet,Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, 
                (1. - ve / iter_val) * 100)
            print(msg)

            st = time.time()
            epoch +=1