Example #1
 def test_hmc(self):
     # Chen et al. used n_iters=80000, n_leapfrogs=50 and n_chains=1.
     # We verified visually that their configuration works, and use a
     # small-scale config here to save time.
     sampler = zs.HMC(step_size=0.01, n_leapfrogs=10)
     with self.session() as sess:
         e = sample_error_with(sampler, sess, n_chains=100, n_iters=1000)
         self.assertLessEqual(e, 0.030)
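The test above relies on a sample_error_with helper that is not shown. A minimal sketch of what such a helper could look like, assuming a standard-normal target and taking the error to be the absolute deviation of the post-burn-in sample mean (the name and signature come from the test; the body is an assumption, not ZhuSuan's actual test helper):

import numpy as np
import tensorflow as tf

def sample_error_with(sampler, sess, n_chains, n_iters):
    # NOTE: a sketch, not the library's helper. Target: standard normal,
    # so the true mean is 0 and the sample-mean error should be small.
    def log_joint(observed):
        x = observed['x']
        return -0.5 * tf.reduce_sum(tf.square(x), axis=-1)

    x = tf.Variable(tf.random_normal([n_chains, 1]), trainable=False, name='x')
    sample_op, hmc_info = sampler.sample(log_joint, observed={}, latent={'x': x})
    sess.run(tf.global_variables_initializer())
    samples = []
    for i in range(n_iters):
        _, x_sample = sess.run([sample_op, hmc_info.samples['x']])
        if i >= n_iters // 2:  # discard the first half as burn-in
            samples.append(x_sample)
    return np.abs(np.vstack(samples).mean())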
Example #2
def main(hps):
    tf.set_random_seed(hps.seed)
    np.random.seed(hps.seed)

    # Load data
    data_path = os.path.join(hps.data_dir, hps.dataset + '.data')
    data_func = dataset.data_dict()[hps.dataset]
    x_train, y_train, x_valid, y_valid, x_test, y_test = data_func(data_path)
    x_train = np.vstack([x_train, x_valid])
    y_train = np.hstack([y_train, y_valid])
    n_train, x_dim = x_train.shape
    x_train, x_test, mean_x_train, std_x_train = dataset.standardize(
        x_train, x_test)
    y_train, y_test, mean_y_train, std_y_train = dataset.standardize(
        y_train, y_test)

    # Define model parameters
    n_hiddens = hps.layers

    # Build the computation graph
    x = tf.placeholder(tf.float32, shape=[None, x_dim])
    y = tf.placeholder(tf.float32, shape=[None])
    layer_sizes = [x_dim] + n_hiddens + [1]
    w_names = ["w" + str(i) for i in range(len(layer_sizes) - 1)]

    meta_model = build_model(x, layer_sizes, hps.n_particles, hps.fix_variance)

    def log_joint(bn):
        log_pws = bn.cond_log_prob(w_names)
        log_py_xw = bn.cond_log_prob('y')
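        # The batch mean of log p(y|x,w) times n_train is an unbiased
        # estimate of the full-data likelihood term under minibatching.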
        return tf.add_n(log_pws) + tf.reduce_mean(log_py_xw, 1) * n_train

    meta_model.log_joint = log_joint

    latent = {}
    for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        buf = tf.get_variable(
            'buf_' + str(i),
            initializer=init_bnn_weight(hps.n_particles, n_in, n_out))
        latent['w' + str(i)] = buf

    hmc = zs.HMC(step_size=hps.lr, n_leapfrogs=10, adapt_step_size=True)
    sample_op, hmc_info = hmc.sample(meta_model, observed={'y': y}, latent=latent)
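    # Each run of sample_op advances the chains and writes the new weight
    # samples back into the latent buffers; hmc_info carries diagnostics
    # such as the acceptance rate monitored in the training loop below.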

    var_bn = meta_model.observe(**latent)
    log_joint = var_bn.log_joint()
    optimizer = tf.train.AdamOptimizer(learning_rate=hps.lr)
    global_step = tf.get_variable(
        'global_step', initializer=0, trainable=False)
    opt_op = optimizer.minimize(
        -log_joint, var_list=[var_bn.y_logstd], global_step=global_step)
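    # Adam optimizes only the observation-noise parameter y_logstd; the
    # network weights themselves are updated by sample_op, not by gradients.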

    # prediction: rmse & log likelihood
    y_mean = var_bn["y_mean"]
    y_pred = tf.reduce_mean(y_mean, 0)
    rmse = tf.sqrt(tf.reduce_mean((y_pred - y) ** 2)) * std_y_train
    log_py_xw = var_bn.cond_log_prob("y")
    log_likelihood = tf.reduce_mean(zs.log_mean_exp(log_py_xw, 0)) - \
        tf.log(std_y_train)
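    # zs.log_mean_exp averages the predictive density over particles;
    # subtracting log(std_y_train) maps the density back to the original
    # (un-standardized) scale of y.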
    ystd_avg = var_bn.y_logstd

    # Define training/evaluation parameters
    epochs = hps.n_epoch
    batch_size = hps.batch_size
    iters = int(np.ceil(x_train.shape[0] / float(batch_size)))
    test_freq = hps.test_freq

    # Run the inference
    dump_buf = []
    with wrapped_supervisor.create_sv(hps, global_step=global_step) as sv:
        sess = sv.sess_
        for epoch in range(1, epochs + 1):
            lbs = []
            perm = np.arange(x_train.shape[0])
            np.random.shuffle(perm)
            x_train = x_train[perm]
            y_train = y_train[perm]
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                y_batch = y_train[t * batch_size:(t + 1) * batch_size]
                _, _, accr = sess.run(
                    [sample_op, opt_op, hmc_info.acceptance_rate],
                    feed_dict={x: x_batch, y: y_batch})
                lbs.append(accr)
            if epoch % 10 == 0:
                print('Epoch {}: Acceptance rate = {}'.format(epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                test_rmse, test_ll = sess.run(
                    [rmse, log_likelihood],
                    feed_dict={x: x_test, y: y_test})
                print('>> TEST')
                print('>> Test rmse = {}, log_likelihood = {}'
                      .format(test_rmse, test_ll))

            if epoch > epochs // 3 and epoch % hps.dump_freq == 0:
                dump_buf.append(sess.run(var_bn['y_mean'],
                                         feed_dict={x: x_test, y: y_test}))

        if len(hps.dump_pred_dir) > 0:
            import pickle
            pred_out = sess.run([var_bn['y_mean'], var_bn.y_logstd],
                                feed_dict={x: x_test, y: y_test})
            pred_out[0] = np.concatenate(dump_buf, axis=0)
            pred_out[0] = pred_out[0] * std_y_train + mean_y_train
            pred_out[1] = np.exp(pred_out[1])
            f = lambda a, b: [a * std_x_train + mean_x_train,
                              b * std_y_train + mean_y_train]
            todump = pred_out + f(x_test, y_test) + f(x_train, y_train)
            with open(hps.dump_pred_dir, 'wb') as fout:
                pickle.dump(todump, fout)
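The dump written above is a list of six arrays: the de-standardized predictive means stacked over dump epochs, the predictive standard deviation, and the de-standardized test and training sets. A minimal sketch for loading it back (the file path is hypothetical, standing in for whatever hps.dump_pred_dir was set to):

import pickle

with open('pred_dump.pkl', 'rb') as fin:  # hypothetical path
    (pred_means, pred_std, x_test, y_test,
     x_train, y_train) = pickle.load(fin)
print(pred_means.shape, pred_std)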
Example #3
    # Load nips dataset
    data_name = 'nips'
    data_path = os.path.join(conf.data_dir, data_name + '.pkl.gz')
    X, vocab = dataset.load_uci_bow(data_name, data_path)
    X_train = X[:1200, :]
    X_test = X[1200:, :]

    # Define model training/evaluation parameters
    D = 100
    K = 100
    V = X_train.shape[1]
    n_chains = 1

    num_e_steps = 5
    hmc = zs.HMC(step_size=1e-3,
                 n_leapfrogs=20,
                 adapt_step_size=True,
                 target_acceptance_rate=0.6)
    epochs = 100
    learning_rate_0 = 1.0
    t0 = 10

    # Padding
    rem = D - X_train.shape[0] % D
    if rem < D:
        X_train = np.vstack((X_train, np.zeros((rem, V))))

    T = np.sum(X_train)
    iters = X_train.shape[0] // D
    Eta = np.zeros((n_chains, X_train.shape[0], K), dtype=np.float32)
    Eta_mean = np.zeros(K, dtype=np.float32)
    Eta_logstd = np.zeros(K, dtype=np.float32)
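The constants learning_rate_0 and t0 above suggest a Robbins-Monro style decay for the M-step updates. A sketch of one common schedule consistent with those constants (an assumption; the snippet ends before the schedule is actually used):

def learning_rate(t):
    # Starts at learning_rate_0 and decays as O(1/t) for large t.
    return learning_rate_0 * t0 / (t0 + t)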
Example #4
    kernel_width = 0.1
    n_chains = 1000
    n_iters = 200
    burnin = n_iters // 2
    n_leapfrogs = 5

    # Build the computation graph
    def log_joint(observed):
        model = gaussian(observed, n_x, stdev, n_chains)
        return model.local_log_prob('x')

    adapt_step_size = tf.placeholder(tf.bool, shape=[], name="adapt_step_size")
    adapt_mass = tf.placeholder(tf.bool, shape=[], name="adapt_mass")
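    # Feeding these booleans at run time lets step-size and mass adaptation
    # run during burn-in and be switched off afterwards.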
    hmc = zs.HMC(step_size=1e-3,
                 n_leapfrogs=n_leapfrogs,
                 adapt_step_size=adapt_step_size,
                 adapt_mass=adapt_mass,
                 target_acceptance_rate=0.9)
    x = tf.Variable(tf.zeros([n_chains, n_x]), trainable=False, name='x')
    sample_op, hmc_info = hmc.sample(log_joint, {}, {'x': x})

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        samples = []
        print('Sampling...')
        for i in range(n_iters):
            _, x_sample, acc, ss = sess.run([
                sample_op, hmc_info.samples['x'], hmc_info.acceptance_rate,
                hmc_info.updated_step_size
            ],
                feed_dict={adapt_step_size: i < burnin // 2,
                           adapt_mass: i < burnin // 2})
            # assumed completion: the original snippet breaks off at the
            # sess.run call; adaptation is enabled for early burn-in only
            if i >= burnin:
                samples.append(x_sample)
Example #5
    def ais_log_prior(observed):
        z = observed['z']
        ret = vae({'z': z}, n, n_x, n_z, test_n_chains)
        model = ret[0]
        return model.local_log_prob('z')

    ret = vae({}, n, n_x, n_z, test_n_chains)
    model = ret[0]
    pz_samples = model.outputs('z')
    z = tf.Variable(tf.zeros([test_n_chains, test_batch_size, n_z]),
                    name="z",
                    trainable=False)
    hmc = zs.HMC(step_size=1e-6,
                 n_leapfrogs=test_n_leapfrogs,
                 adapt_step_size=True,
                 target_acceptance_rate=0.65,
                 adapt_mass=True)
    ais = AIS(ais_log_prior,
              log_joint, {'z': pz_samples},
              hmc,
              observed={'x': x_obs},
              latent={'z': z},
              n_chains=test_n_chains,
              n_temperatures=test_n_temperatures)
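    # AIS anneals from ais_log_prior to the model's log_joint through
    # n_temperatures intermediate distributions, using the HMC kernel
    # above as the transition operator at each temperature.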

    model_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                       scope="model")
    variational_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             scope="variational")
    saver = tf.train.Saver(max_to_keep=10,
                           var_list=model_var_list + variational_var_list)
    # assumed completion: the snippet breaks off mid-call; restoring the
    # model and variational variables gathered above is the natural choice
Example #6
def main():
    np.random.seed(1234)
    tf.set_random_seed(1237)
    M, N, train_data, valid_data, test_data, user_movie, user_movie_score, \
        movie_user, movie_user_score = dataset.load_movielens1m_mapped(
            os.path.join(conf.data_dir, 'ml-1m.zip'))

    # set configurations and hyper parameters
    N_train = train_data.shape[0]
    N_test = test_data.shape[0]
    N_valid = valid_data.shape[0]
    D = 30
    batch_size = 100000
    test_batch_size = 100000
    valid_batch_size = 100000
    K = 8
    num_steps = 500
    test_freq = 10
    valid_freq = 10
    train_iters = (N_train + batch_size - 1) // batch_size
    valid_iters = (N_valid + valid_batch_size - 1) // valid_batch_size
    test_iters = (N_test + test_batch_size - 1) // test_batch_size

    # Parallelization: round N and M up to multiples of chunk_size so the
    # latent matrices split exactly into equally sized chunks below.
    chunk_size = 50
    N = (N + chunk_size - 1) // chunk_size
    N *= chunk_size
    M = (M + chunk_size - 1) // chunk_size
    M *= chunk_size

    # Selection
    select_u = tf.placeholder(tf.int32, shape=[None], name='s_u')
    select_v = tf.placeholder(tf.int32, shape=[None], name='s_v')
    subselect_u = tf.placeholder(tf.int32, shape=[None], name='ss_u')
    subselect_v = tf.placeholder(tf.int32, shape=[None], name='ss_v')
    alpha_u = 1.0
    alpha_v = 1.0
    alpha_pred = 0.2 / 4.0

    # Define samples as variables
    Us = []
    Vs = []
    for i in range(N // chunk_size):
        ui = tf.get_variable('u_chunk_%d' % i,
                             shape=[K, chunk_size, D],
                             initializer=tf.random_normal_initializer(0, 0.1),
                             trainable=False)
        Us.append(ui)
    for i in range(M // chunk_size):
        vi = tf.get_variable('v_chunk_%d' % i,
                             shape=[K, chunk_size, D],
                             initializer=tf.random_normal_initializer(0, 0.1),
                             trainable=False)
        Vs.append(vi)
    U = tf.concat(Us, axis=1)
    V = tf.concat(Vs, axis=1)

    # Define models for prediction
    true_rating = tf.placeholder(tf.float32, shape=[None], name='true_rating')
    normalized_rating = (true_rating - 1.0) / 4.0
    _, pred_rating = pmf({
        'u': U,
        'v': V
    }, N, M, D, K, select_u, select_v, alpha_u, alpha_v, alpha_pred)
    pred_rating = tf.reduce_mean(pred_rating, axis=0)
    error = pred_rating - normalized_rating
    rmse = tf.sqrt(tf.reduce_mean(error * error)) * 4
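    # The * 4 undoes the (r - 1) / 4 normalization above, so the RMSE is
    # reported on the original 1-5 rating scale.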

    # Define models for HMC
    n = tf.placeholder(tf.int32, shape=[], name='n')
    m = tf.placeholder(tf.int32, shape=[], name='m')

    def log_joint(observed):
        model, _ = pmf(observed, n, m, D, K, subselect_u, subselect_v, alpha_u,
                       alpha_v, alpha_pred)
        log_pu, log_pv = model.local_log_prob(['u', 'v'])  # [K, N], [K, M]
        log_pr = model.local_log_prob('r')  # [K, batch]
        log_pu = tf.reduce_sum(log_pu, axis=1)
        log_pv = tf.reduce_sum(log_pv, axis=1)
        log_pr = tf.reduce_sum(log_pr, axis=1)
        return log_pu + log_pv + log_pr

    hmc_u = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    hmc_v = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    target_u = tf.gather(U, select_u, axis=1)
    target_v = tf.gather(V, select_v, axis=1)

    candidate_sample_u = tf.get_variable(
        'cand_sample_chunk_u',
        shape=[K, chunk_size, D],
        initializer=tf.random_normal_initializer(0, 0.1),
        trainable=True)
    candidate_sample_v = tf.get_variable(
        'cand_sample_chunk_v',
        shape=[K, chunk_size, D],
        initializer=tf.random_normal_initializer(0, 0.1),
        trainable=True)
    sample_u_op, sample_u_info = hmc_u.sample(log_joint, {
        'r': normalized_rating,
        'v': target_v
    }, {'u': candidate_sample_u})
    sample_v_op, sample_v_info = hmc_v.sample(log_joint, {
        'r': normalized_rating,
        'u': target_u
    }, {'v': candidate_sample_v})
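    # Together the two sample ops implement a Gibbs-like sweep: sample a
    # chunk of U with V held fixed, then a chunk of V with U held fixed.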

    candidate_idx_u = tf.placeholder(tf.int32,
                                     shape=[chunk_size],
                                     name='cand_u_chunk')
    candidate_idx_v = tf.placeholder(tf.int32,
                                     shape=[chunk_size],
                                     name='cand_v_chunk')
    candidate_u = tf.gather(U, candidate_idx_u, axis=1)  # [K, chunk_size, D]
    candidate_v = tf.gather(V, candidate_idx_v, axis=1)  # [K, chunk_size, D]

    trans_cand_U = tf.assign(candidate_sample_u, candidate_u)
    trans_cand_V = tf.assign(candidate_sample_v, candidate_v)

    trans_us_cand = []
    for i in range(N // chunk_size):
        trans_us_cand.append(tf.assign(Us[i], candidate_sample_u))
    trans_vs_cand = []
    for i in range(M // chunk_size):
        trans_vs_cand.append(tf.assign(Vs[i], candidate_sample_v))

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for step in range(1, num_steps + 1):
            epoch_time = -time.time()
            for i in range(N // chunk_size):
                nv, sv, tr, ssu, ssv = select_from_corpus(
                    i * chunk_size, (i + 1) * chunk_size, user_movie,
                    user_movie_score)
                _ = sess.run(trans_cand_U,
                             feed_dict={
                                 candidate_idx_u:
                                 list(
                                     range(i * chunk_size,
                                           (i + 1) * chunk_size))
                             })
                _ = sess.run(sample_u_op,
                             feed_dict={
                                 select_v: sv,
                                 true_rating: tr,
                                 subselect_u: ssu,
                                 subselect_v: ssv,
                                 n: chunk_size,
                                 m: nv
                             })
                _ = sess.run(trans_us_cand[i])
            for i in range(M // chunk_size):
                nu, su, tr, ssv, ssu = select_from_corpus(
                    i * chunk_size, (i + 1) * chunk_size, movie_user,
                    movie_user_score)
                _ = sess.run(trans_cand_V,
                             feed_dict={
                                 candidate_idx_v:
                                 list(
                                     range(i * chunk_size,
                                           (i + 1) * chunk_size))
                             })
                _ = sess.run(sample_v_op,
                             feed_dict={
                                 select_u: su,
                                 true_rating: tr,
                                 subselect_u: ssu,
                                 subselect_v: ssv,
                                 n: nu,
                                 m: chunk_size
                             })
                _ = sess.run(trans_vs_cand[i])
            epoch_time += time.time()

            train_rmse = []
            train_sizes = []
            time_train = -time.time()
            for t in range(train_iters):
                ed_pos = min((t + 1) * batch_size, N_train)
                su = train_data[t * batch_size:ed_pos, 0]
                sv = train_data[t * batch_size:ed_pos, 1]
                tr = train_data[t * batch_size:ed_pos, 2]
                re = sess.run(rmse,
                              feed_dict={
                                  select_u: su,
                                  select_v: sv,
                                  true_rating: tr
                              })
                train_rmse.append(re)
                train_sizes.append(ed_pos - t * batch_size)
            time_train += time.time()

            print('Step {}({:.1f}s): rmse ({:.1f}s) = {}'.format(
                step, epoch_time, time_train,
                average_rmse_over_batches(train_rmse, train_sizes)))

            if step % valid_freq == 0:
                valid_rmse = []
                valid_sizes = []
                time_valid = -time.time()
                for t in range(valid_iters):
                    ed_pos = min((t + 1) * valid_batch_size, N_valid)
                    su = valid_data[t * valid_batch_size:ed_pos, 0]
                    sv = valid_data[t * valid_batch_size:ed_pos, 1]
                    tr = valid_data[t * valid_batch_size:ed_pos, 2]
                    re = sess.run(rmse,
                                  feed_dict={
                                      select_u: su,
                                      select_v: sv,
                                      true_rating: tr
                                  })
                    valid_rmse.append(re)
                    valid_sizes.append(ed_pos - t * valid_batch_size)
                time_valid += time.time()
                print('>>> VALID ({:.1f}s)'.format(time_valid))
                print('>> Valid rmse = {}'.format(
                    average_rmse_over_batches(valid_rmse, valid_sizes)))

            if step % test_freq == 0:
                test_rmse = []
                test_sizes = []
                time_test = -time.time()
                for t in range(test_iters):
                    ed_pos = min((t + 1) * test_batch_size, N_test)
                    su = test_data[t * test_batch_size:ed_pos, 0]
                    sv = test_data[t * test_batch_size:ed_pos, 1]
                    tr = test_data[t * test_batch_size:ed_pos, 2]
                    re = sess.run(rmse,
                                  feed_dict={
                                      select_u: su,
                                      select_v: sv,
                                      true_rating: tr
                                  })
                    test_rmse.append(re)
                    test_sizes.append(ed_pos - t * test_batch_size)
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test rmse = {}'.format(
                    average_rmse_over_batches(test_rmse, test_sizes)))
Example #7
    # Define models for HMC
    n = tf.placeholder(tf.int32, shape=[], name='n')
    m = tf.placeholder(tf.int32, shape=[], name='m')

    def log_joint(observed):
        model, _ = pmf(observed, n, m, D, K, subselect_u, subselect_v, alpha_u,
                       alpha_v, alpha_pred)
        log_pu, log_pv = model.local_log_prob(['u', 'v'])  # [K, N], [K, M]
        log_pr = model.local_log_prob('r')  # [K, batch]
        log_pu = tf.reduce_sum(log_pu, axis=1)
        log_pv = tf.reduce_sum(log_pv, axis=1)
        log_pr = tf.reduce_sum(log_pr, axis=1)
        return log_pu + log_pv + log_pr

    hmc_u = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    hmc_v = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    target_u = select_from_axis1(U, select_u)
    target_v = select_from_axis1(V, select_v)

    candidate_sample_u = \
        tf.get_variable('cand_sample_chunk_u', shape=[K, chunk_size, D],
                        initializer=tf.random_normal_initializer(0, 0.1),
                        trainable=True)
    candidate_sample_v = \
        tf.get_variable('cand_sample_chunk_v', shape=[K, chunk_size, D],
                        initializer=tf.random_normal_initializer(0, 0.1),
                        trainable=True)
Example #8
def main():
    np.random.seed(1234)
    tf.set_random_seed(1237)
    M, N, train_data, valid_data, test_data, user_movie, user_movie_score, \
        movie_user, movie_user_score = dataset.load_movielens1m_mapped(
            os.path.join(conf.data_dir, 'ml-1m.zip'))

    # set configurations and hyper parameters
    D = 30
    batch_size = 100000
    K = 8
    n_epochs = 500
    eval_freq = 10

    # Parallelization: round N and M up to multiples of chunk_size so the
    # latent matrices split exactly into equally sized chunks below.
    chunk_size = 50
    N = (N + chunk_size - 1) // chunk_size
    N *= chunk_size
    M = (M + chunk_size - 1) // chunk_size
    M *= chunk_size

    # Selection
    neighbor_u = tf.placeholder(tf.int32, shape=[None], name="neighbor_u")
    neighbor_v = tf.placeholder(tf.int32, shape=[None], name="neighbor_v")
    select_u = tf.placeholder(tf.int32, shape=[None], name="select_u")
    select_v = tf.placeholder(tf.int32, shape=[None], name="select_v")
    alpha_u = 1.0
    alpha_v = 1.0
    alpha_pred = 0.2 / 4.0

    # Define samples as variables
    Us = []
    Vs = []
    for i in range(N // chunk_size):
        ui = tf.get_variable('u_chunk_%d' % i,
                             shape=[K, chunk_size, D],
                             initializer=tf.random_normal_initializer(0, 0.1),
                             trainable=False)
        Us.append(ui)
    for i in range(M // chunk_size):
        vi = tf.get_variable('v_chunk_%d' % i,
                             shape=[K, chunk_size, D],
                             initializer=tf.random_normal_initializer(0, 0.1),
                             trainable=False)
        Vs.append(vi)
    U = tf.concat(Us, axis=1)
    V = tf.concat(Vs, axis=1)

    n = tf.placeholder(tf.int32, shape=[], name='n')
    m = tf.placeholder(tf.int32, shape=[], name='m')
    model = pmf(n, m, D, K, select_u, select_v, alpha_u, alpha_v, alpha_pred)

    # prediction
    true_rating = tf.placeholder(tf.float32, shape=[None], name='true_rating')
    normalized_rating = (true_rating - 1.0) / 4.0
    pred_rating = model.observe(u=U, v=V)["r"]
    pred_rating = tf.reduce_mean(pred_rating, axis=0)
    rmse = tf.sqrt(tf.reduce_mean(
        tf.square(pred_rating - normalized_rating))) * 4

    hmc_u = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    hmc_v = zs.HMC(step_size=1e-3,
                   n_leapfrogs=10,
                   adapt_step_size=None,
                   target_acceptance_rate=0.9)
    target_u = tf.gather(U, neighbor_u, axis=1)
    target_v = tf.gather(V, neighbor_v, axis=1)

    candidate_sample_u = tf.get_variable(
        'cand_sample_chunk_u',
        shape=[K, chunk_size, D],
        initializer=tf.random_normal_initializer(0, 0.1),
        trainable=True)
    candidate_sample_v = tf.get_variable(
        'cand_sample_chunk_v',
        shape=[K, chunk_size, D],
        initializer=tf.random_normal_initializer(0, 0.1),
        trainable=True)

    def log_joint(bn):
        log_pu, log_pv = bn.cond_log_prob(['u', 'v'])  # [K, N], [K, M]
        log_pr = bn.cond_log_prob('r')  # [K, batch]
        log_pu = tf.reduce_sum(log_pu, axis=-1)
        log_pv = tf.reduce_sum(log_pv, axis=-1)
        log_pr = tf.reduce_sum(log_pr, axis=-1)
        return log_pu + log_pv + log_pr

    model.log_joint = log_joint
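    # Assigning log_joint overrides the default log-joint computation the
    # HMC kernels would otherwise derive from the meta-model.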

    sample_u_op, sample_u_info = hmc_u.sample(model, {
        "r": normalized_rating,
        "v": target_v
    }, {"u": candidate_sample_u})
    sample_v_op, sample_v_info = hmc_v.sample(model, {
        "r": normalized_rating,
        "u": target_u
    }, {"v": candidate_sample_v})

    candidate_idx_u = tf.placeholder(tf.int32,
                                     shape=[chunk_size],
                                     name='cand_u_chunk')
    candidate_idx_v = tf.placeholder(tf.int32,
                                     shape=[chunk_size],
                                     name='cand_v_chunk')
    candidate_u = tf.gather(U, candidate_idx_u, axis=1)  # [K, chunk_size, D]
    candidate_v = tf.gather(V, candidate_idx_v, axis=1)  # [K, chunk_size, D]

    trans_cand_U = tf.assign(candidate_sample_u, candidate_u)
    trans_cand_V = tf.assign(candidate_sample_v, candidate_v)

    trans_us_cand = []
    for i in range(N // chunk_size):
        trans_us_cand.append(tf.assign(Us[i], candidate_sample_u))
    trans_vs_cand = []
    for i in range(M // chunk_size):
        trans_vs_cand.append(tf.assign(Vs[i], candidate_sample_v))

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, n_epochs + 1):
            epoch_time = -time.time()
            for i in range(N // chunk_size):
                nv, sv, tr, ssu, ssv = select_from_corpus(
                    i * chunk_size, (i + 1) * chunk_size, user_movie,
                    user_movie_score)
                _ = sess.run(trans_cand_U,
                             feed_dict={
                                 candidate_idx_u:
                                 list(
                                     range(i * chunk_size,
                                           (i + 1) * chunk_size))
                             })
                _ = sess.run(sample_u_op,
                             feed_dict={
                                 neighbor_v: sv,
                                 true_rating: tr,
                                 select_u: ssu,
                                 select_v: ssv,
                                 n: chunk_size,
                                 m: nv
                             })
                _ = sess.run(trans_us_cand[i])
            for i in range(M // chunk_size):
                nu, su, tr, ssv, ssu = select_from_corpus(
                    i * chunk_size, (i + 1) * chunk_size, movie_user,
                    movie_user_score)
                _ = sess.run(trans_cand_V,
                             feed_dict={
                                 candidate_idx_v:
                                 list(
                                     range(i * chunk_size,
                                           (i + 1) * chunk_size))
                             })
                _ = sess.run(sample_v_op,
                             feed_dict={
                                 neighbor_u: su,
                                 true_rating: tr,
                                 select_u: ssu,
                                 select_v: ssv,
                                 n: nu,
                                 m: chunk_size
                             })
                _ = sess.run(trans_vs_cand[i])
            epoch_time += time.time()
            print("Epoch {}: {:.1f}s".format(epoch, epoch_time))

            def _eval(phase, data, batch_size):
                rmses = []
                sizes = []
                time_eval = -time.time()
                n_iters = (data.shape[0] + batch_size - 1) // batch_size
                for t in range(n_iters):
                    su = data[t * batch_size:(t + 1) * batch_size, 0]
                    sv = data[t * batch_size:(t + 1) * batch_size, 1]
                    tr = data[t * batch_size:(t + 1) * batch_size, 2]
                    re = sess.run(rmse,
                                  feed_dict={
                                      select_u: su,
                                      select_v: sv,
                                      n: N,
                                      m: M,
                                      true_rating: tr
                                  })
                    rmses.append(re)
                    sizes.append(tr.shape[0])
                time_eval += time.time()
                print('>>> {} ({:.1f}s): rmse = {}'.format(
                    phase, time_eval, average_rmse_over_batches(rmses, sizes)))

            _eval("Train", train_data, batch_size)

            if epoch % eval_freq == 0:
                _eval("Validation", valid_data, batch_size)
                _eval("Test", test_data, batch_size)