Пример #1
0
def train_model1(sess, train_op, n_epoch, opt, model1, X_dataset, Y_dataset, Z_dataset, qvars, pvars, dvars, qvars_y, pvars_y, dvars_y):
    """Train the x/z and y/z branches of ``model1`` for ``n_epoch`` epochs.

    For every mini-batch, first the x branch and then the y branch each run:

    * one step of the discriminator training op;
    * five steps of the generator training op.

    The discriminator loss and the generator fetches of the last step of each
    batch are recorded and handed to two fresh ``DG`` objects.

    Args:
        sess: TensorFlow session used for every ``sess.run`` call.
        train_op: provides the four training ops through its getters.
        n_epoch: number of passes over the datasets.
        opt: not used in this function.
        model1: provides the placeholders and loss tensors through getters.
        X_dataset, Y_dataset, Z_dataset: shuffled at the start of every epoch
            and batched together by ``iter_data``.
        qvars, pvars, dvars, qvars_y, pvars_y, dvars_y: not used in this
            function.

    Returns:
        ``(sess, DG_xz, DG_yz, model1)``, where the two ``DG`` objects hold the
        recorded x-branch and y-branch losses respectively.

    Note:
        The mini-batch size is read from the module-level ``batch_size``.
    """
    gen_op_x = train_op.get_train_gen_op_x()
    disc_op_x = train_op.get_train_disc_op_x()
    gen_op_y = train_op.get_train_gen_op_y()
    disc_op_y = train_op.get_train_disc_op_y()

    x, y, z = model1.get_x(), model1.get_y(), model1.get_z()
    d1, d2 = model1.get_d1(), model1.get_d2()

    gen_hist_x, gen_hist_y = [], []
    disc_hist_x, disc_hist_y = [], []
    dg_xz = DG()
    dg_xz.initial()
    dg_yz = DG()
    dg_yz.initial()

    # 128 x 2 masks: d1 receives all ones, d2 all zeros.
    # NOTE(review): 128 is hard-coded; presumably it must equal batch_size — confirm.
    mask_ones = [[1, 1]] * 128
    mask_zeros = [[0, 0]] * 128

    for epoch in tqdm(range(n_epoch), total=n_epoch):
        X_dataset = shuffle(X_dataset)
        Y_dataset = shuffle(Y_dataset)
        Z_dataset = shuffle(Z_dataset)

        i = 0  # remains 0 when iter_data yields no batch
        batches = iter_data(X_dataset, Y_dataset, Z_dataset, size=batch_size)
        for i, (xmb, ymb, zmb) in enumerate(batches, start=1):
            feed = {x: xmb, y: ymb, z: zmb, d1: mask_ones, d2: mask_zeros}

            # x branch: one discriminator step, then five generator steps;
            # only the final generator fetch of the batch is recorded.
            f_d_x, _ = sess.run([model1.get_disc_loss_x(), disc_op_x], feed_dict=feed)
            for _ in range(5):
                f_g_x, _ = sess.run(
                    [[model1.get_gen_loss_x(), model1.get_gen_loss_xz(),
                      model1.get_cost_x(), model1.get_cost_xz()], gen_op_x],
                    feed_dict=feed)
            gen_hist_x.append(f_g_x)
            disc_hist_x.append(f_d_x)

            # y branch: identical schedule with the y-side ops and losses.
            f_d_y, _ = sess.run([model1.get_disc_loss_y(), disc_op_y], feed_dict=feed)
            for _ in range(5):
                f_g_y, _ = sess.run(
                    [[model1.get_gen_loss_y(), model1.get_gen_loss_yz(),
                      model1.get_cost_y(), model1.get_cost_yz()], gen_op_y],
                    feed_dict=feed)
            gen_hist_y.append(f_g_y)
            disc_hist_y.append(f_d_y)

        # Report the losses of the last batch of the epoch.
        print_xz(epoch, i, f_d_x, *f_g_x)
        print_yz(epoch, i, f_d_y, *f_g_y)

    dg_xz.set_FD(disc_hist_x)
    dg_xz.set_FG(gen_hist_x)
    dg_yz.set_FD(disc_hist_y)
    dg_yz.set_FG(gen_hist_y)

    return sess, dg_xz, dg_yz, model1
Пример #2
0
def train_model2(train_op, n_epoch_2, opt, model1, sess, X_dataset, Y_dataset, W_dataset, Z_dataset):
    """Train the joint x/y/w objective of ``model1`` for ``n_epoch_2`` epochs.

    Each mini-batch runs one step of the xyw discriminator training op and
    then five steps of the xyw generator training op.  The discriminator loss
    and the generator fetches of the last step of each batch are recorded in
    a fresh ``DG`` object.

    Args:
        train_op: provides the xyw training ops through its getters.
        n_epoch_2: number of passes over the datasets.
        opt: not used in this function.
        model1: provides the placeholders and loss tensors through getters.
        sess: TensorFlow session used for every ``sess.run`` call.
        X_dataset, Y_dataset, W_dataset, Z_dataset: shuffled at the start of
            every epoch and batched together by ``iter_data``.

    Returns:
        ``(sess, DG_xyw, model1)``.

    Note:
        The mini-batch size is read from the module-level ``batch_size``.
    """
    gen_op = train_op.get_train_gen_op_xyw()
    disc_op = train_op.get_train_disc_op_xyw()

    disc_hist, gen_hist = [], []
    dg = DG()
    dg.initial()

    # 128 x 2 masks: d1 receives all zeros, d2 all ones.
    # NOTE(review): 128 is hard-coded; presumably it must equal batch_size — confirm.
    mask_zeros = [[0, 0]] * 128
    mask_ones = [[1, 1]] * 128

    x, y, w, z = model1.get_x(), model1.get_y(), model1.get_w(), model1.get_z()
    d1, d2 = model1.get_d1(), model1.get_d2()

    for epoch in tqdm(range(n_epoch_2), total=n_epoch_2):
        X_dataset = shuffle(X_dataset)
        Y_dataset = shuffle(Y_dataset)
        W_dataset = shuffle(W_dataset)
        Z_dataset = shuffle(Z_dataset)

        i = 0  # remains 0 when iter_data yields no batch
        batches = iter_data(X_dataset, Y_dataset, W_dataset, Z_dataset, size=batch_size)
        for i, (xmb, ymb, wmb, zmb) in enumerate(batches, start=1):
            feed = {x: xmb, y: ymb, w: wmb, z: zmb, d1: mask_zeros, d2: mask_ones}

            f_d_xyw, _ = sess.run([model1.get_disc_loss_xyw(), disc_op], feed_dict=feed)
            # NOTE(review): the generator fetches include cost_x/cost_xz rather
            # than xyw-specific costs — confirm this is intended.
            for _ in range(5):
                f_g_xyw, _ = sess.run(
                    [[model1.get_gen_loss_xyw(), model1.get_gen_loss_x(),
                      model1.get_cost_x(), model1.get_cost_xz()], gen_op],
                    feed_dict=feed)
            gen_hist.append(f_g_xyw)
            disc_hist.append(f_d_xyw)

        # Report the losses of the last batch of the epoch.
        print_xy(epoch, i, f_d_xyw, *f_g_xyw)

    dg.set_FD(disc_hist)
    dg.set_FG(gen_hist)

    return sess, dg, model1
Пример #3
0
# Script-level training loop for the x branch.
# NOTE(review): this snippet is truncated in the visible source — the generator
# sess.run call at the bottom is cut off mid-expression.
# NOTE(review): FD_xy is created here but not used in the visible lines.
FD_xy = []

# Load the datasets from the data_x / data_y / data_z objects.
X_dataset = data_x.get_dataset()
Y_dataset = data_y.get_dataset()
Z_dataset = data_z.get_dataset()

# Hyper-parameters read from the `option` object.
n_epoch = option.n_epoch
batch_size = option.batch_size
n_epoch_2 = option.n_epoch_2
for epoch in tqdm(range(n_epoch), total=n_epoch):
    # Reassign each dataset to its shuffled version at the start of the epoch.
    X_dataset = shuffle(X_dataset)
    Y_dataset = shuffle(Y_dataset)
    Z_dataset = shuffle(Z_dataset)
    i = 0  # number of batches processed in this epoch
    for xmb, ymb, zmb in iter_data(X_dataset,
                                   Y_dataset,
                                   Z_dataset,
                                   size=batch_size):
        i = i + 1
        # One step of train_disc_op_x per batch ...
        for _ in range(1):
            f_d_x, _ = sess.run([disc_loss_x, train_disc_op_x],
                                feed_dict={
                                    x: xmb,
                                    y: ymb,
                                    z: zmb
                                })
        # ... followed by five steps of train_gen_op_x.
        for _ in range(5):
            f_g_x, _ = sess.run(
                [[gen_loss_x, gen_loss_xz, cost_x, cost_xz], train_gen_op_x],
                feed_dict={
                    x: xmb,
                    y: ymb,
Пример #4
0
""" training """
# Session limited to 10% of GPU memory; every graph variable is initialized
# before the first update.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

# FG collects [gen_loss, gen_loss_xz, cost_x, cost_z] from the final generator
# step of each batch; FD collects the discriminator loss of the same batch.
FG, FD = [], []

for epoch in tqdm(range(n_epoch), total=n_epoch):
    X_dataset = shuffle(X_dataset)
    Z_dataset = shuffle(Z_dataset)

    i = 0  # remains 0 when iter_data yields no batch this epoch
    for i, (xmb, zmb) in enumerate(iter_data(X_dataset, Z_dataset, size=batch_size), start=1):
        _train_feed = {x: xmb, z: zmb}
        # A single train_disc_op step ...
        f_d, _ = sess.run([disc_loss, train_disc_op], feed_dict=_train_feed)
        # ... then five train_gen_op steps; only the last step's fetch is kept.
        for _ in range(5):
            f_g, _ = sess.run([[gen_loss, gen_loss_xz, cost_x, cost_z], train_gen_op],
                              feed_dict=_train_feed)
        FG.append(f_g)
        FD.append(f_d)

    print("epoch %d iter %d: discloss %f genloss %f adv_x %f recons_x %f recons_z %f"
          % ((epoch, i, f_d) + tuple(f_g)))

# tmpx, tmpz = sess.run([dvars_x, dvars_z])
# pdb.set_trace()

""" plot the results """
Пример #5
0
def test(option, X_np_data_test, Y_np_data_test, W_np_data_test, Z_np_data_test, model, n_viz,sess, dmb1, dmb2):
    """Evaluate the three-domain (x, y, w) model on test data and collect outputs.

    The test set is swept ``n_viz`` times.  For each mini-batch the ``q_*``,
    ``p_*`` and ``rec_*`` tensors are evaluated, some with earlier outputs
    fed back in place of the real ``x``/``y``/``w`` batch.  Each output is
    concatenated over batches along axis 0.

    Args:
        option: options object; only ``option.batch_size`` is read.
        X_np_data_test, Y_np_data_test, W_np_data_test, Z_np_data_test:
            test arrays batched together by ``iter_data``.
        model: model exposing placeholders and output tensors via getters.
        n_viz: number of full passes over the test set.
        sess: TensorFlow session used to evaluate the tensors.
        dmb1, dmb2: values fed to the ``d1``/``d2`` placeholders.  They are
            swapped when producing the ``imzx``/``imzy``/``imzw`` samples.

    Returns:
        A ``Result`` filled via ``set_z_domian`` and ``set_all``.  An output
        that received no batches is ``np.array([])``.
    """
    q_xz = model.get_q_xz()
    rec_xz = model.get_rec_wzxz()
    p_x = model.get_p_wzx()

    q_yz = model.get_q_yz()
    rec_yz = model.get_rec_xzyz()
    p_y = model.get_p_xzy()

    rec_zx = model.get_rec_xzx()
    rec_zy = model.get_rec_yzy()
    rec_zw = model.get_rec_wzw()

    q_wz = model.get_q_wz()
    rec_wz = model.get_rec_yzwz()
    p_w = model.get_p_yzw()

    x = model.get_x()
    y = model.get_y()
    w = model.get_w()
    z = model.get_z()

    d1 = model.get_d1()
    d2 = model.get_d2()

    # Masks used (swapped) for the imz* samples.
    dmbb1 = dmb2
    dmbb2 = dmb1

    # Per-output list of batch arrays.  Stacking once at the end avoids the
    # O(n^2) copying of calling np.vstack on the growing result every batch.
    parts = {}

    def keep(name, arr):
        """Record one batch of output ``name`` and return it unchanged."""
        parts.setdefault(name, []).append(arr)
        return arr

    def feed_with(base, overrides):
        """Return a copy of feed dict ``base`` updated with ``overrides``."""
        feed = dict(base)
        feed.update(overrides)
        return feed

    def stacked(name):
        """Concatenate the recorded batches of ``name`` along axis 0."""
        batches = parts.get(name)
        if not batches:
            return np.array([])
        if len(batches) == 1:
            # A lone batch is returned as-is, as the incremental version did.
            return batches[0]
        return np.vstack(batches)

    batch_size = option.batch_size
    for _ in range(n_viz):
        for xmb, wmb, ymb, zmb in iter_data(X_np_data_test, W_np_data_test, Y_np_data_test,  Z_np_data_test, size=batch_size):
            feed = {x: xmb, y: ymb, w: wmb, z: zmb, d1: dmb1, d2: dmb2}

            # q_* outputs on the real batch.
            keep("imwz", sess.run(q_wz, feed_dict=feed))
            keep("imxz", sess.run(q_xz, feed_dict=feed))
            keep("imyz", sess.run(q_yz, feed_dict=feed))

            # p_* outputs on the real batch.
            imy = keep("imxzy", sess.run(p_y, feed_dict=feed))
            imw = keep("imyzw", sess.run(p_w, feed_dict=feed))
            imx = keep("imwzx", sess.run(p_x, feed_dict=feed))

            # rec_* outputs with a p_* output substituted for its real input.
            keep("rmxzyz", sess.run(rec_yz, feed_dict=feed_with(feed, {y: imy})))
            keep("rmyzwz", sess.run(rec_wz, feed_dict=feed_with(feed, {w: imw})))
            keep("rmwzxz", sess.run(rec_xz, feed_dict=feed_with(feed, {x: imx})))

            keep("rmxzyzx", sess.run(rec_zx, feed_dict=feed_with(feed, {y: imy})))
            keep("rmxzx", sess.run(rec_zx, feed_dict=feed))
            keep("rmyzwzy", sess.run(rec_zy, feed_dict=feed_with(feed, {w: imw})))
            keep("rmyzy", sess.run(rec_zy, feed_dict=feed))
            keep("rmwzxzw", sess.run(rec_zw, feed_dict=feed_with(feed, {x: imx})))
            keep("rmwzw", sess.run(rec_zw, feed_dict=feed))

            # p_* outputs with the d1/d2 masks swapped ...
            swapped = feed_with(feed, {d1: dmbb1, d2: dmbb2})
            imzy = keep("imzy", sess.run(p_y, feed_dict=swapped))
            imzw = keep("imzw", sess.run(p_w, feed_dict=swapped))
            imzx = keep("imzx", sess.run(p_x, feed_dict=swapped))

            # ... substituted back in under the original masks.
            keep("rmzyz", sess.run(rec_yz, feed_dict=feed_with(feed, {y: imzy})))
            keep("rmzwz", sess.run(rec_wz, feed_dict=feed_with(feed, {w: imzw})))
            keep("rmzxz", sess.run(rec_xz, feed_dict=feed_with(feed, {x: imzx})))

    result = Result()
    result.set_z_domian(*[stacked(name) for name in (
        "imzx", "imzy", "imzw", "rmzxz", "rmzyz", "rmzwz")])
    result.set_all(*[stacked(name) for name in (
        "imxz", "imyz", "imwz", "imxzy", "imyzw", "imwzx",
        "rmxzyz", "rmyzwz", "rmwzxz", "rmxzyzx", "rmyzy", "rmyzwzy",
        "rmxzx", "rmwzxzw", "rmwzw")])
    return result
Пример #6
0
def test(option, X_np_data_test, Y_np_data_test, Z_np_data_test, model, n_viz,
         sess):
    """Evaluate the two-domain (x, y) model on test data and collect outputs.

    The test set is swept ``n_viz`` times.  For each mini-batch the ``q_*``,
    ``p_*`` and ``rec_*`` tensors are evaluated, some with earlier outputs
    fed back in place of the real ``x``/``y``/``z`` batch.  Each output is
    concatenated over batches along axis 0.

    Args:
        option: options object; only ``option.batch_size`` is read.
        X_np_data_test, Y_np_data_test, Z_np_data_test: test arrays batched
            together by ``iter_data``.
        model: model exposing placeholders and output tensors via getters.
        n_viz: number of full passes over the test set.
        sess: TensorFlow session used to evaluate the tensors.

    Returns:
        ``(im, rm)`` where ``im`` is
        ``[imxz, imzx, imyzx, imyz, imzy, imxzy]`` and ``rm`` is
        ``[rmxzyzx, rmyzy, rmxzyz, rmyzxzy, rmxzx, rmyzxz, rmzxz, rmzyz]``.
        An output that received no batches is ``np.array([])``.
    """
    q_xz = model.get_q_xz()
    rec_xz = model.get_rec_xz()
    p_x = model.get_p_x()
    q_yz = model.get_q_yz()
    rec_yz = model.get_rec_yz()
    p_y = model.get_p_y()
    rec_zx = model.get_rec_zx()
    rec_zy = model.get_rec_zy()
    x = model.get_x()
    y = model.get_y()
    z = model.get_z()
    d1 = model.get_d1()
    d2 = model.get_d2()

    # 128 x 2 masks: all ones / all zeros.
    # NOTE(review): 128 is hard-coded; presumably it must equal batch_size — confirm.
    dmb1 = [[1] * 2] * 128
    dmb2 = [[0] * 2] * 128

    # Per-output list of batch arrays.  Stacking once at the end avoids the
    # O(n^2) copying of calling np.vstack on the growing result every batch.
    parts = {}

    def keep(name, arr):
        """Record one batch of output ``name`` and return it unchanged."""
        parts.setdefault(name, []).append(arr)
        return arr

    def feed_with(base, overrides):
        """Return a copy of feed dict ``base`` updated with ``overrides``."""
        feed = dict(base)
        feed.update(overrides)
        return feed

    def stacked(name):
        """Concatenate the recorded batches of ``name`` along axis 0."""
        batches = parts.get(name)
        if not batches:
            return np.array([])
        if len(batches) == 1:
            # A lone batch is returned as-is, as the incremental version did.
            return batches[0]
        return np.vstack(batches)

    batch_size = option.batch_size

    for _ in range(n_viz):
        for xmb, ymb, zmb in iter_data(X_np_data_test,
                                       Y_np_data_test,
                                       Z_np_data_test,
                                       size=batch_size):
            # "plain" feeds d1=ones / d2=zeros; "swapped" feeds the reverse.
            plain = {x: xmb, y: ymb, z: zmb, d1: dmb1, d2: dmb2}
            swapped = {x: xmb, y: ymb, z: zmb, d1: dmb2, d2: dmb1}

            keep("imzx", sess.run(p_x, feed_dict=plain))
            keep("imzy", sess.run(p_y, feed_dict=plain))

            # q_* outputs, each fed as z into the opposite p_* tensor.
            imxz = keep("imxz", sess.run(q_xz, feed_dict=swapped))
            imy = keep("imxzy", sess.run(p_y, feed_dict=feed_with(swapped, {z: imxz})))
            imyz = keep("imyz", sess.run(q_yz, feed_dict=swapped))
            imx = keep("imyzx", sess.run(p_x, feed_dict=feed_with(swapped, {z: imyz})))

            # Chain starting from the predicted y.
            rmz = keep("rmxzyz", sess.run(rec_yz, feed_dict=feed_with(swapped, {y: imy})))
            keep("rmxzyzx", sess.run(rec_zx, feed_dict=feed_with(swapped, {y: imy, z: rmz})))
            # NOTE(review): rmyzy is fed the z from the predicted-y chain (rmxzyz),
            # not a z derived from the real y — confirm this is intended.
            keep("rmyzy", sess.run(rec_zy, feed_dict=feed_with(swapped, {z: rmz})))

            # Chain starting from the predicted x.
            rmz = keep("rmyzxz", sess.run(rec_xz, feed_dict=feed_with(swapped, {x: imx})))
            keep("rmyzxzy", sess.run(rec_zy, feed_dict=feed_with(swapped, {x: imx, z: rmz})))
            # NOTE(review): rmxzx is fed the z from the predicted-x chain (rmyzxz),
            # not a z derived from the real x — confirm this is intended.
            keep("rmxzx", sess.run(rec_zx, feed_dict=feed_with(swapped, {z: rmz})))

            keep("rmzxz", sess.run(rec_xz, feed_dict=plain))
            keep("rmzyz", sess.run(rec_yz, feed_dict=plain))

    im = [stacked(name) for name in
          ("imxz", "imzx", "imyzx", "imyz", "imzy", "imxzy")]
    rm = [stacked(name) for name in
          ("rmxzyzx", "rmyzy", "rmxzyz", "rmyzxzy",
           "rmxzx", "rmyzxz", "rmzxz", "rmzyz")]

    return im, rm
Пример #7
0
# Session limited to 10% of GPU memory; every graph variable is initialized
# before the first update.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

# FG collects [gen_loss, cost_x, cost_y] per batch; FD the discriminator loss.
FG, FD = [], []

for epoch in tqdm(range(n_epoch), total=n_epoch):
    X_dataset = shuffle(X_dataset)
    Y_dataset = shuffle(Y_dataset)
    e1_dataset = shuffle(e1_dataset)
    e2_dataset = shuffle(e2_dataset)

    i = 0  # remains 0 when iter_data yields no batch this epoch
    for i, (xmb, ymb, e1mb, e2mb) in enumerate(
            iter_data(X_dataset, Y_dataset, e1_dataset, e2_dataset, size=batch_size),
            start=1):
        _train_feed = {x: xmb, y: ymb, e1: e1mb, e2: e2mb}
        # One train_disc_op step followed by one train_gen_op step per batch.
        f_d, _ = sess.run([disc_loss, train_disc_op], feed_dict=_train_feed)
        f_g, _ = sess.run([[gen_loss, cost_x, cost_y], train_gen_op], feed_dict=_train_feed)
        FG.append(f_g)
        FD.append(f_d)

    print("epoch %d iter %d: discloss %f genloss %f recons_x %f recons_y %f"
          % ((epoch, i, f_d) + tuple(f_g)))

""" plot the results """

# test dataset