Example #1
def dataPrepare(date_list, T, N):
    if len(date_list) != T + N + 1:
        gl.logger.error('Date list length is incorrect, please check!')
        raise ValueError('date_list must contain exactly T + N + 1 dates')

    gl.logger.info('Prepare close data and factor data.')
    close_data = om.getCloseData(date_list[-1], date_list[0])
    train_factor_data = om.getFactorData(date_list[-1], date_list[-T])
    gl.logger.info('Close data length: %d, factor data length: %d' % (len(close_data), len(train_factor_data)))

    # Retry until the daily data becomes available.
    while len(close_data) == 0 or len(train_factor_data) == 0:
        gl.logger.warning('Failed to get the daily data, sleeping before retrying!')
        time.sleep(100)
        close_data = om.getCloseData(date_list[-1], date_list[0])
        train_factor_data = om.getFactorData(date_list[-1], date_list[-T])

    # Build the training samples.
    gl.logger.info('Prepare train data...')

    close_df = dp.closeToDf(close_data)
    gl.logger.info('Close data OK!')

    return_df = dp.calculateReturn(T, close_df)
    gl.logger.info('Return data OK!')

    train_factor_df = dp.factorDataToDf(train_factor_data)

    # Test samples: the current day's factor data.
    gl.logger.info('Prepare test data...')
    test_factor_data = om.getFactorData(date_list[0], date_list[0])
    test_factor_df = dp.factorDataToDf(test_factor_data)
    return return_df, train_factor_df, test_factor_df
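
A minimal usage sketch for dataPrepare, assuming om, dp, and gl are the project's data-access, data-preparation, and logging modules, and that date_list is ordered newest-first (as the negative indexing above implies). The date-fetching helper is hypothetical:

# Hypothetical driver: T training days plus an N-day horizon, so the
# date list must hold exactly T + N + 1 trading dates, newest first.
T, N = 20, 5
date_list = om.getTradeDateList(T + N + 1)  # assumed helper, not from the source
return_df, train_factor_df, test_factor_df = dataPrepare(date_list, T, N)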
Example #2
def factorDataNormalized(test_factor_df):
    # Style factor information.
    factor_info = om.getStyleFactorInfo()
    style_factor = factor_info[factor_info[:, 1] != 'S000000', :]
    factor_col = np.append(['date', 'stockcode'], style_factor[:, 0])

    # Style factor values, standardized over the whole market.
    gl.logger.info('Start to normalize the factor data.')
    factor_list = [str(col).strip() for col in factor_col]
    test_factor_df = test_factor_df[factor_list]

    # Fill missing factor values with the column means.
    column_mean = test_factor_df.mean(axis=0)
    for column in test_factor_df.columns:
        if column != 'stockcode':
            test_factor_df[column] = test_factor_df[column].fillna(column_mean[column])

    # Z-score the factor columns and winsorize at +/-3 standard deviations.
    factor_data_scaled = preprocessing.scale(test_factor_df.iloc[:, 2:])
    factor_data_scaled = np.clip(factor_data_scaled, -3, 3)

    # Market-wide standardized values.
    factor_data_scaled_df = pd.DataFrame(factor_data_scaled, columns=style_factor[:, 0])
    factor_normal_data = pd.concat([test_factor_df.iloc[:, :2], factor_data_scaled_df], axis=1)
    data_df = factor_normal_data.copy()

    # Convert to the database storage format.
    factor_normal_info = dp.transferToFactordata(data_df)
    factor_normal_info.iloc[:, 0] = factor_normal_info.iloc[:, 0].astype('int32')
    gl.logger.info('Factor data normalization is done, stock num is: %d, factor num is: %d, record num is: %d'
                   % (len(factor_data_scaled), len(factor_data_scaled_df.columns), len(factor_normal_info)))

    # Insert into the normalized factor data table.
    flag = om.normalDataInsert(factor_normal_info)
    gl.logger.info('Factor normalized data insert is done, flag is: %d!' % flag)
    return flag
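
The core of factorDataNormalized is a market-wide z-score followed by winsorization at +/-3 standard deviations. A self-contained sketch of just that step on synthetic data (shapes and values are illustrative only):

import numpy as np
from sklearn import preprocessing

# Synthetic factor matrix: rows are stocks, columns are style factors.
factors = np.random.randn(100, 3) * [1.0, 5.0, 0.1] + [0.0, 50.0, 1.0]

scaled = preprocessing.scale(factors)  # per-column zero mean, unit variance
clipped = np.clip(scaled, -3, 3)       # cap extreme z-scores at +/-3 sigma
print(clipped.mean(axis=0), clipped.std(axis=0))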
Example #3
def modelPredict(train_factor_df, return_df, test_factor_df):
    # Model section.
    gl.logger.info('Start to produce model sample!')
    train_sample = dp.produceSample(train_factor_df, return_df, 'train')
    test_sample = dp.produceSample(test_factor_df, pd.DataFrame(), 'test')
    gl.logger.info('Train sample number: %d, test sample number: %d' % (len(train_sample), len(test_sample)))
    ypred_df = mxgb.classificationModel(train_sample, test_sample)

    gl.logger.info('Model task is finished, next add the market_value column!')
    # Align predictions and factors on their first two columns (date, stockcode)
    # so that the concat below matches rows correctly.
    ypred_df.index = pd.MultiIndex.from_frame(ypred_df.iloc[:, :2])
    test_factor_df.index = pd.MultiIndex.from_frame(test_factor_df.iloc[:, :2])
    # Predictions plus market value, used later to compute portfolio weights.
    stock_info = pd.concat([ypred_df, test_factor_df.loc[:, 'ln_capital']], axis=1)
    stock_info.loc[stock_info['ln_capital'] < 0, 'ln_capital'] = 1
    stock_info.index = range(len(stock_info))
    test_factor_df.index = range(len(test_factor_df))

    gl.logger.info('Model task is finished, predict data length is: %d!' % len(ypred_df))
    flag = om.stockRatingInsert(stock_info)
    gl.logger.info('Stock rating info insert is done, flag is: %d!' % flag)

    return flag
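
A hypothetical end-to-end run wiring the three functions above together, under the same module assumptions as before:

# Sketch of the daily pipeline: prepare data, persist the normalized
# factors, then train, predict, and store the stock ratings.
return_df, train_factor_df, test_factor_df = dataPrepare(date_list, T=20, N=5)
factorDataNormalized(test_factor_df.copy())  # defensive copy; the function reorders columns
flag = modelPredict(train_factor_df, return_df, test_factor_df)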
Example #4
def train(n_epochs, learning_rate_G, learning_rate_D, batch_size, mid_flag,
          check_num, discriminative):
    beta_G = cfg.TRAIN.ADAM_BETA_G
    beta_D = cfg.TRAIN.ADAM_BETA_D
    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[-1]]
    com_shape = [n_vox[0], n_vox[1], n_vox[2], 2]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    kernel = cfg.NET.KERNEL
    stride = cfg.NET.STRIDE
    dilations = cfg.NET.DILATIONS
    freq = cfg.CHECK_FREQ
    record_vox_num = cfg.RECORD_VOX_NUM

    depvox_gan_model = depvox_gan(batch_size=batch_size,
                                  vox_shape=vox_shape,
                                  com_shape=com_shape,
                                  dim_z=dim_z,
                                  dim=dim,
                                  start_vox_size=start_vox_size,
                                  kernel=kernel,
                                  stride=stride,
                                  dilations=dilations,
                                  discriminative=discriminative,
                                  is_train=True)

    Z_tf, z_part_enc_tf, surf_tf, full_tf, full_gen_tf, surf_dec_tf, full_dec_tf,\
    gen_loss_tf, discrim_loss_tf, recons_ssc_loss_tf, recons_com_loss_tf, recons_sem_loss_tf, encode_loss_tf, refine_loss_tf, summary_tf,\
    part_tf, part_dec_tf, complete_gt_tf, complete_gen_tf, complete_dec_tf, sscnet_tf, scores_tf = depvox_gan_model.build_model()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True
    sess = tf.Session(config=config_gpu)
    saver = tf.train.Saver(max_to_keep=cfg.SAVER_MAX)

    data_paths = scene_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION)
    print('---amount of data:', str(len(data_paths)))
    data_process = DataProcess(data_paths, batch_size, repeat=True)

    enc_sscnet_vars = list(
        filter(lambda x: x.name.startswith('enc_ssc'),
               tf.trainable_variables()))
    enc_sdf_vars = list(
        filter(lambda x: x.name.startswith('enc_x'), tf.trainable_variables()))
    dis_sdf_vars = list(
        filter(lambda x: x.name.startswith('dis_x'), tf.trainable_variables()))
    dis_com_vars = list(
        filter(lambda x: x.name.startswith('dis_g'), tf.trainable_variables()))
    dis_sem_vars = list(
        filter(lambda x: x.name.startswith('dis_y'), tf.trainable_variables()))
    gen_com_vars = list(
        filter(lambda x: x.name.startswith('gen_x'), tf.trainable_variables()))
    gen_sem_vars = list(
        filter(lambda x: x.name.startswith('gen_y'), tf.trainable_variables()))
    gen_sdf_vars = list(
        filter(lambda x: x.name.startswith('gen_z'), tf.trainable_variables()))
    refine_vars = list(
        filter(lambda x: x.name.startswith('gen_y_ref'),
               tf.trainable_variables()))

    lr_VAE = tf.placeholder(tf.float32, shape=[])

    # main optimiser
    train_op_pred_sscnet = tf.train.AdamOptimizer(learning_rate_G,
                                                  beta1=beta_G,
                                                  beta2=0.9).minimize(
                                                      recons_ssc_loss_tf,
                                                      var_list=enc_sscnet_vars)
    train_op_pred_com = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            recons_com_loss_tf,
            var_list=enc_sdf_vars + gen_com_vars + gen_sdf_vars)
    train_op_pred_sem = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            recons_sem_loss_tf,
            var_list=enc_sdf_vars + gen_sem_vars + gen_sdf_vars)

    # refine optimiser
    train_op_refine = tf.train.AdamOptimizer(learning_rate_G,
                                             beta1=beta_G,
                                             beta2=0.9).minimize(
                                                 refine_loss_tf,
                                                 var_list=refine_vars)

    if discriminative:
        train_op_gen_sdf = tf.train.AdamOptimizer(learning_rate_G,
                                                  beta1=beta_G,
                                                  beta2=0.9).minimize(
                                                      gen_loss_tf,
                                                      var_list=gen_sdf_vars)
        train_op_gen_com = tf.train.AdamOptimizer(learning_rate_G,
                                                  beta1=beta_G,
                                                  beta2=0.9).minimize(
                                                      gen_loss_tf,
                                                      var_list=gen_com_vars)
        train_op_gen_sem = tf.train.AdamOptimizer(
            learning_rate_G, beta1=beta_G,
            beta2=0.9).minimize(gen_loss_tf,
                                var_list=gen_sem_vars + gen_com_vars)
        train_op_dis_sdf = tf.train.AdamOptimizer(learning_rate_D,
                                                  beta1=beta_D,
                                                  beta2=0.9).minimize(
                                                      discrim_loss_tf,
                                                      var_list=dis_sdf_vars)
        train_op_dis_com = tf.train.AdamOptimizer(learning_rate_D,
                                                  beta1=beta_D,
                                                  beta2=0.9).minimize(
                                                      discrim_loss_tf,
                                                      var_list=dis_com_vars)
        train_op_dis_sem = tf.train.AdamOptimizer(learning_rate_D,
                                                  beta1=beta_D,
                                                  beta2=0.9).minimize(
                                                      discrim_loss_tf,
                                                      var_list=dis_sem_vars,
                                                      global_step=global_step)

        Z_tf_sample, comp_tf_sample, full_tf_sample, full_ref_tf_sample, part_tf_sample, scores_tf_sample = depvox_gan_model.samples_generator(
            visual_size=batch_size)

        model_path = cfg.DIR.CHECK_POINT_PATH + '-d'
    else:
        model_path = cfg.DIR.CHECK_POINT_PATH

    writer = tf.summary.FileWriter(cfg.DIR.LOG_PATH, sess.graph_def)
    sess.run(tf.global_variables_initializer())

    if mid_flag:
        chckpt_path = model_path + '/checkpoint' + str(check_num)
        saver.restore(sess, chckpt_path)
        Z_var_np_sample = np.load(cfg.DIR.TRAIN_OBJ_PATH +
                                  '/sample_z.npy').astype(np.float32)
        Z_var_np_sample = Z_var_np_sample[:batch_size]
        print('---weights restored')
    else:
        Z_var_np_sample = np.random.normal(size=(batch_size, start_vox_size[0],
                                                 start_vox_size[1],
                                                 start_vox_size[2],
                                                 dim_z)).astype(np.float32)
        np.save(cfg.DIR.TRAIN_OBJ_PATH + '/sample_z.npy', Z_var_np_sample)

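    # Resume bookkeeping: each checkpoint spans `freq` iterations, so the
    # iteration counter and the starting epoch are derived from check_num.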
    ite = check_num * freq + 1
    cur_epochs = int(ite / int(len(data_paths) / batch_size))

    #training
    for epoch in np.arange(cur_epochs, n_epochs):
        epoch_flag = True
        while epoch_flag:
            print(colored('---Iteration:%d, epoch:%d', 'blue') % (ite, epoch))
            db_inds, epoch_flag = data_process.get_next_minibatch()
            batch_tsdf = data_process.get_tsdf(db_inds)
            batch_surf = data_process.get_surf(db_inds)
            batch_voxel = data_process.get_voxel(db_inds)

            # Evaluation masks
            # NOTICE that the target should never have negative values,
            # otherwise the one-hot coding never works for that region
            if cfg.TYPE_TASK == 'scene':
                """
                space_effective = np.where(batch_voxel > -1, 1, 0) * np.where(batch_tsdf > -1, 1, 0)
                batch_voxel *= space_effective
                batch_tsdf *= space_effective
                # occluded region
                """
                batch_tsdf[batch_tsdf < -1] = 0
                batch_surf[batch_surf < 0] = 0
                batch_voxel[batch_voxel < 0] = 0

            lr = learning_rate(cfg.LEARNING_RATE_V, ite)

            batch_z_var = np.random.normal(size=(batch_size, start_vox_size[0],
                                                 start_vox_size[1],
                                                 start_vox_size[2],
                                                 dim_z)).astype(np.float32)

            # updating for the main network
            is_supervised = True
            if is_supervised:
                _ = sess.run(
                    train_op_pred_sscnet,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                        lr_VAE: lr
                    },
                )
                _, _, _ = sess.run(
                    [train_op_pred_com, train_op_pred_sem, train_op_refine],
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                        lr_VAE: lr
                    },
                )
            gen_com_loss_val, gen_sem_loss_val, z_part_enc_val = sess.run(
                [recons_com_loss_tf, recons_sem_loss_tf, z_part_enc_tf],
                feed_dict={
                    Z_tf: batch_z_var,
                    part_tf: batch_tsdf,
                    surf_tf: batch_surf,
                    full_tf: batch_voxel,
                    lr_VAE: lr
                },
            )

            if discriminative:
                discrim_loss_val, gen_loss_val, scores_discrim = sess.run(
                    [discrim_loss_tf, gen_loss_tf, scores_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                if scores_discrim[0] - scores_discrim[1] > 0.3:
                    _ = sess.run(
                        train_op_gen_sdf,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )
                if scores_discrim[2] - scores_discrim[3] > 0.3:
                    _ = sess.run(
                        train_op_gen_com,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )
                if scores_discrim[4] - scores_discrim[5] > 0.3:
                    _ = sess.run(
                        train_op_gen_sem,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )
                _ = sess.run(
                    train_op_dis_sdf,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                _ = sess.run(
                    train_op_dis_com,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                _ = sess.run(
                    train_op_dis_sem,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )

            print('GAN')
            np.set_printoptions(precision=2)
            print('reconstruct-com loss:', gen_com_loss_val)

            print('reconstruct-sem loss:', gen_sem_loss_val)

            if discriminative:
                print(
                    '            gen loss:', "%.2f" % gen_loss_val if
                    ('gen_loss_val' in locals()) else 'None')

                print(
                    '      output discrim:', "%.2f" % discrim_loss_val if
                    ('discrim_loss_val' in locals()) else 'None')

                print(
                    '      scores discrim:',
                    colored("%.2f" % scores_discrim[0], 'green'),
                    colored("%.2f" % scores_discrim[1], 'magenta'),
                    colored("%.2f" % scores_discrim[2], 'green'),
                    colored("%.2f" % scores_discrim[3], 'magenta'),
                    colored("%.2f" % scores_discrim[4], 'green'),
                    colored("%.2f" % scores_discrim[5], 'magenta') if
                    ('scores_discrim' in locals()) else 'None')

            print(
                '     average of code:',
                np.mean(np.mean(z_part_enc_val, 4)) if
                ('z_part_enc_val' in locals()) else 'None')

            print(
                '         std of code:',
                np.mean(np.std(z_part_enc_val, 4)) if
                ('z_part_enc_val' in locals()) else 'None')

            if np.mod(ite, freq) == 0:
                if discriminative:
                    full_models = sess.run(
                        full_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    full_models_cat = np.argmax(full_models, axis=4)
                    record_vox = full_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)
                save_path = saver.save(sess,
                                       model_path + '/checkpoint' +
                                       str(ite // freq),
                                       global_step=None)

            ite += 1
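
One detail worth noting in the loop above: generator updates are gated on the discriminator score gaps, so each generator head is stepped only while its discriminator is winning by a margin. A minimal sketch of that heuristic (the 0.3 margin comes from the code; the helper name is ours):

# Gate a generator step on the real-vs-fake score gap, mirroring the
# scores_discrim[i] - scores_discrim[i + 1] > 0.3 checks above.
def should_update_generator(score_real, score_fake, margin=0.3):
    return score_real - score_fake > margin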
Example #5
def train(n_epochs, learning_rate_G, learning_rate_D, batch_size, mid_flag,
          check_num):
    beta_G = cfg.TRAIN.ADAM_BETA_G
    beta_D = cfg.TRAIN.ADAM_BETA_D
    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[4]]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    kernel = cfg.NET.KERNEL
    stride = cfg.NET.STRIDE
    freq = cfg.CHECK_FREQ
    record_vox_num = cfg.RECORD_VOX_NUM
    refine_ch = cfg.NET.REFINE_CH
    refine_kernel = cfg.NET.REFINE_KERNEL

    refine_start = cfg.SWITCHING_ITE

    fcr_agan_model = FCR_aGAN(
        batch_size=batch_size,
        vox_shape=vox_shape,
        dim_z=dim_z,
        dim=dim,
        start_vox_size=start_vox_size,
        kernel=kernel,
        stride=stride,
        refine_ch=refine_ch,
        refine_kernel=refine_kernel,
    )

    Z_tf, z_enc_tf, vox_tf, vox_gen_tf, vox_gen_decode_tf, vox_refine_dec_tf, vox_refine_gen_tf,\
     recons_loss_tf, code_encode_loss_tf, gen_loss_tf, discrim_loss_tf, recons_loss_refine_tf, gen_loss_refine_tf, discrim_loss_refine_tf,\
      cost_enc_tf, cost_code_tf, cost_gen_tf, cost_discrim_tf, cost_gen_ref_tf, cost_discrim_ref_tf, summary_tf = fcr_agan_model.build_model()

    sess = tf.InteractiveSession()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    saver = tf.train.Saver(max_to_keep=cfg.SAVER_MAX)

    data_paths = scene_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION)
    print('---amount of data:' + str(len(data_paths)))
    data_process = DataProcess(data_paths, batch_size, repeat=True)

    encode_vars = list(
        filter(lambda x: x.name.startswith('enc'), tf.trainable_variables()))
    discrim_vars = list(
        filter(lambda x: x.name.startswith('discrim'), tf.trainable_variables()))
    gen_vars = list(
        filter(lambda x: x.name.startswith('gen'), tf.trainable_variables()))
    code_vars = list(
        filter(lambda x: x.name.startswith('cod'), tf.trainable_variables()))
    refine_vars = list(
        filter(lambda x: x.name.startswith('refine'), tf.trainable_variables()))

    lr_VAE = tf.placeholder(tf.float32, shape=[])
    train_op_encode = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_D, beta2=0.9).minimize(cost_enc_tf,
                                                  var_list=encode_vars)
    train_op_discrim = tf.train.AdamOptimizer(learning_rate_D,
                                              beta1=beta_D,
                                              beta2=0.9).minimize(
                                                  cost_discrim_tf,
                                                  var_list=discrim_vars,
                                                  global_step=global_step)
    train_op_gen = tf.train.AdamOptimizer(learning_rate_G,
                                          beta1=beta_G,
                                          beta2=0.9).minimize(
                                              cost_gen_tf, var_list=gen_vars)
    train_op_code = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_G, beta2=0.9).minimize(cost_code_tf,
                                                  var_list=code_vars)
    train_op_refine = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_G, beta2=0.9).minimize(cost_gen_ref_tf,
                                                  var_list=refine_vars)
    train_op_discrim_refine = tf.train.AdamOptimizer(
        learning_rate_D, beta1=beta_D,
        beta2=0.9).minimize(cost_discrim_ref_tf,
                            var_list=discrim_vars,
                            global_step=global_step)

    Z_tf_sample, vox_tf_sample = fcr_agan_model.samples_generator(
        visual_size=batch_size)
    sample_vox_tf, sample_refine_vox_tf = fcr_agan_model.refine_generator(
        visual_size=batch_size)
    writer = tf.summary.FileWriter(cfg.DIR.LOG_PATH, sess.graph_def)
    tf.global_variables_initializer().run()

    if mid_flag:
        chckpt_path = cfg.DIR.CHECK_PT_PATH + str(check_num) + '-' + str(
            check_num * freq)
        saver.restore(sess, chckpt_path)
        Z_var_np_sample = np.load(cfg.DIR.TRAIN_OBJ_PATH +
                                  '/sample_z.npy').astype(np.float32)
        Z_var_np_sample = Z_var_np_sample[:batch_size]
        print('---weights restored')
    else:
        Z_var_np_sample = np.random.normal(size=(batch_size, start_vox_size[0],
                                                 start_vox_size[1],
                                                 start_vox_size[2],
                                                 dim_z)).astype(np.float32)
        np.save(cfg.DIR.TRAIN_OBJ_PATH + '/sample_z.npy', Z_var_np_sample)

    ite = check_num * freq + 1
    cur_epochs = int(ite / int(len(data_paths) / batch_size))

    #training
    for epoch in np.arange(cur_epochs, n_epochs):
        epoch_flag = True
        while epoch_flag:
            print('=iteration:%d, epoch:%d' % (ite, epoch))
            db_inds, epoch_flag = data_process.get_next_minibatch()
            batch_voxel = data_process.get_voxel(db_inds)
            batch_voxel_train = batch_voxel
            lr = learning_rate(cfg.LEARNING_RATE_V, ite)

            batch_z_var = np.random.normal(size=(batch_size, start_vox_size[0],
                                                 start_vox_size[1],
                                                 start_vox_size[2],
                                                 dim_z)).astype(np.float32)

            if ite < refine_start:
                for s in np.arange(2):
                    _, recons_loss_val, code_encode_loss_val, cost_enc_val = sess.run(
                        [
                            train_op_encode, recons_loss_tf,
                            code_encode_loss_tf, cost_enc_tf
                        ],
                        feed_dict={
                            vox_tf: batch_voxel_train,
                            Z_tf: batch_z_var,
                            lr_VAE: lr
                        },
                    )

                    _, gen_loss_val, cost_gen_val = sess.run(
                        [train_op_gen, gen_loss_tf, cost_gen_tf],
                        feed_dict={
                            Z_tf: batch_z_var,
                            vox_tf: batch_voxel_train,
                            lr_VAE: lr
                        },
                    )

                _, discrim_loss_val, cost_discrim_val = sess.run(
                    [train_op_discrim, discrim_loss_tf, cost_discrim_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train
                    },
                )

                _, cost_code_val, z_enc_val, summary = sess.run(
                    [train_op_code, cost_code_tf, z_enc_tf, summary_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train,
                        lr_VAE: lr
                    },
                )

                print('reconstruction loss:', recons_loss_val)
                print('   code encode loss:', code_encode_loss_val)
                print('           gen loss:', gen_loss_val)
                print('       cost_encoder:', cost_enc_val)
                print('     cost_generator:', cost_gen_val)
                print(' cost_discriminator:', cost_discrim_val)
                print('          cost_code:', cost_code_val)
                print('   average of enc_z:', np.mean(np.mean(z_enc_val, 4)))
                print('       std of enc_z:', np.mean(np.std(z_enc_val, 4)))

                if np.mod(ite, freq) == 0:
                    vox_models = sess.run(
                        vox_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    vox_models_cat = np.argmax(vox_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)
                    save_path = saver.save(sess,
                                           cfg.DIR.CHECK_PT_PATH +
                                           str(ite // freq),
                                           global_step=global_step)

            else:
                _, recons_loss_val, recons_loss_refine_val, gen_loss_refine_val, cost_gen_ref_val = sess.run(
                    [
                        train_op_refine, recons_loss_tf, recons_loss_refine_tf,
                        gen_loss_refine_tf, cost_gen_ref_tf
                    ],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train,
                        lr_VAE: lr
                    },
                )

                _, discrim_loss_refine_val, cost_discrim_ref_val, summary = sess.run(
                    [
                        train_op_discrim_refine, discrim_loss_refine_tf,
                        cost_discrim_ref_tf, summary_tf
                    ],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train
                    },
                )

                print('reconstruction loss:', recons_loss_val)
                print(' recons refine loss:', recons_loss_refine_val)
                print('           gen loss:', gen_loss_refine_val)
                print(' cost_discriminator:', cost_discrim_ref_val)

                if np.mod(ite, freq) == 0:
                    vox_models = sess.run(
                        vox_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    refined_models = sess.run(
                        sample_refine_vox_tf,
                        feed_dict={sample_vox_tf: vox_models})
                    vox_models_cat = np.argmax(vox_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)

                    vox_models_cat = np.argmax(refined_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '_refine.npy', record_vox)
                    save_path = saver.save(sess,
                                           cfg.DIR.CHECK_PT_PATH +
                                           str(ite // freq),
                                           global_step=global_step)

            writer.add_summary(summary, global_step=ite)

            ite += 1
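
A hypothetical invocation of this trainer, assuming the cfg module has been loaded and the checkpoint and log directories exist. Every hyperparameter value below is a placeholder, not taken from the source:

# Placeholder hyperparameters for illustration only.
train(n_epochs=500,
      learning_rate_G=1e-4,
      learning_rate_D=1e-4,
      batch_size=8,
      mid_flag=False,  # train from scratch instead of restoring a checkpoint
      check_num=0)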