def train(n_epochs, learning_rate_G, learning_rate_D, batch_size, mid_flag,
          check_num, discriminative):
    beta_G = cfg.TRAIN.ADAM_BETA_G
    beta_D = cfg.TRAIN.ADAM_BETA_D

    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[-1]]
    com_shape = [n_vox[0], n_vox[1], n_vox[2], 2]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    kernel = cfg.NET.KERNEL
    stride = cfg.NET.STRIDE
    dilations = cfg.NET.DILATIONS
    freq = cfg.CHECK_FREQ
    record_vox_num = cfg.RECORD_VOX_NUM

    depvox_gan_model = depvox_gan(batch_size=batch_size,
                                  vox_shape=vox_shape,
                                  com_shape=com_shape,
                                  dim_z=dim_z,
                                  dim=dim,
                                  start_vox_size=start_vox_size,
                                  kernel=kernel,
                                  stride=stride,
                                  dilations=dilations,
                                  discriminative=discriminative,
                                  is_train=True)

    Z_tf, z_part_enc_tf, surf_tf, full_tf, full_gen_tf, surf_dec_tf, full_dec_tf,\
    gen_loss_tf, discrim_loss_tf, recons_ssc_loss_tf, recons_com_loss_tf, recons_sem_loss_tf, encode_loss_tf, refine_loss_tf, summary_tf,\
    part_tf, part_dec_tf, complete_gt_tf, complete_gen_tf, complete_dec_tf, sscnet_tf, scores_tf = depvox_gan_model.build_model()

    global_step = tf.Variable(0, name='global_step', trainable=False)

    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True
    sess = tf.Session(config=config_gpu)
    saver = tf.train.Saver(max_to_keep=cfg.SAVER_MAX)

    data_paths = scene_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION)
    print('---amount of data:', str(len(data_paths)))
    data_process = DataProcess(data_paths, batch_size, repeat=True)

    # Group trainable variables by name prefix so that each sub-network
    # (SSC encoder, SDF encoder, discriminators, generators, refiner)
    # can be updated by its own optimiser.
    enc_sscnet_vars = list(
        filter(lambda x: x.name.startswith('enc_ssc'),
               tf.trainable_variables()))
    enc_sdf_vars = list(
        filter(lambda x: x.name.startswith('enc_x'), tf.trainable_variables()))
    dis_sdf_vars = list(
        filter(lambda x: x.name.startswith('dis_x'), tf.trainable_variables()))
    dis_com_vars = list(
        filter(lambda x: x.name.startswith('dis_g'), tf.trainable_variables()))
    dis_sem_vars = list(
        filter(lambda x: x.name.startswith('dis_y'), tf.trainable_variables()))
    gen_com_vars = list(
        filter(lambda x: x.name.startswith('gen_x'), tf.trainable_variables()))
    gen_sem_vars = list(
        filter(lambda x: x.name.startswith('gen_y'), tf.trainable_variables()))
    gen_sdf_vars = list(
        filter(lambda x: x.name.startswith('gen_z'), tf.trainable_variables()))
    refine_vars = list(
        filter(lambda x: x.name.startswith('gen_y_ref'),
               tf.trainable_variables()))

    lr_VAE = tf.placeholder(tf.float32, shape=[])

    # main optimisers
    train_op_pred_sscnet = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            recons_ssc_loss_tf, var_list=enc_sscnet_vars)
    train_op_pred_com = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            recons_com_loss_tf,
            var_list=enc_sdf_vars + gen_com_vars + gen_sdf_vars)
    train_op_pred_sem = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            recons_sem_loss_tf,
            var_list=enc_sdf_vars + gen_sem_vars + gen_sdf_vars)

    # refine optimiser
    train_op_refine = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
            refine_loss_tf, var_list=refine_vars)

    if discriminative is True:
        # adversarial optimisers for the SDF, completion and semantic branches
        train_op_gen_sdf = tf.train.AdamOptimizer(
            learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
                gen_loss_tf, var_list=gen_sdf_vars)
        train_op_gen_com = tf.train.AdamOptimizer(
            learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
                gen_loss_tf, var_list=gen_com_vars)
        train_op_gen_sem = tf.train.AdamOptimizer(
            learning_rate_G, beta1=beta_G, beta2=0.9).minimize(
                gen_loss_tf, var_list=gen_sem_vars + gen_com_vars)
        train_op_dis_sdf = tf.train.AdamOptimizer(
            learning_rate_D, beta1=beta_D, beta2=0.9).minimize(
                discrim_loss_tf, var_list=dis_sdf_vars)
        train_op_dis_com = tf.train.AdamOptimizer(
            learning_rate_D, beta1=beta_D, beta2=0.9).minimize(
                discrim_loss_tf, var_list=dis_com_vars)
        train_op_dis_sem = tf.train.AdamOptimizer(
            learning_rate_D, beta1=beta_D, beta2=0.9).minimize(
                discrim_loss_tf,
                var_list=dis_sem_vars,
                global_step=global_step)

        Z_tf_sample, comp_tf_sample, full_tf_sample, full_ref_tf_sample, part_tf_sample, scores_tf_sample = depvox_gan_model.samples_generator(
            visual_size=batch_size)
        model_path = cfg.DIR.CHECK_POINT_PATH + '-d'
    else:
        model_path = cfg.DIR.CHECK_POINT_PATH

    writer = tf.summary.FileWriter(cfg.DIR.LOG_PATH, sess.graph_def)
    tf.global_variables_initializer().run(session=sess)

    if mid_flag:
        # resume from an existing checkpoint and reuse the saved latent codes
        chckpt_path = model_path + '/checkpoint' + str(check_num)
        saver.restore(sess, chckpt_path)
        Z_var_np_sample = np.load(cfg.DIR.TRAIN_OBJ_PATH +
                                  '/sample_z.npy').astype(np.float32)
        Z_var_np_sample = Z_var_np_sample[:batch_size]
        print('---weights restored')
    else:
        # fixed latent codes used to render sample volumes at every checkpoint
        Z_var_np_sample = np.random.normal(
            size=(batch_size, start_vox_size[0], start_vox_size[1],
                  start_vox_size[2], dim_z)).astype(np.float32)
        np.save(cfg.DIR.TRAIN_OBJ_PATH + '/sample_z.npy', Z_var_np_sample)

    ite = check_num * freq + 1
    cur_epochs = int(ite / int(len(data_paths) / batch_size))

    # training
    for epoch in np.arange(cur_epochs, n_epochs):
        epoch_flag = True
        while epoch_flag:
            print(colored('---Iteration:%d, epoch:%d', 'blue') % (ite, epoch))
            db_inds, epoch_flag = data_process.get_next_minibatch()
            batch_tsdf = data_process.get_tsdf(db_inds)
            batch_surf = data_process.get_surf(db_inds)
            batch_voxel = data_process.get_voxel(db_inds)

            # Evaluation masks
            # NOTICE that the target should never have negative values,
            # otherwise the one-hot coding never works for that region
            if cfg.TYPE_TASK == 'scene':
                """
                space_effective = np.where(batch_voxel > -1, 1, 0) * np.where(
                    batch_tsdf > -1, 1, 0)
                batch_voxel *= space_effective
                batch_tsdf *= space_effective
                # occluded region
                """
                batch_tsdf[batch_tsdf < -1] = 0
                batch_surf[batch_surf < 0] = 0
                batch_voxel[batch_voxel < 0] = 0

            lr = learning_rate(cfg.LEARNING_RATE_V, ite)

            batch_z_var = np.random.normal(
                size=(batch_size, start_vox_size[0], start_vox_size[1],
                      start_vox_size[2], dim_z)).astype(np.float32)

            # updating for the main network
            is_supervised = True
            if is_supervised is True:
                _ = sess.run(
                    train_op_pred_sscnet,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                        lr_VAE: lr
                    },
                )
                _, _, _ = sess.run(
                    [train_op_pred_com, train_op_pred_sem, train_op_refine],
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                        lr_VAE: lr
                    },
                )
                gen_com_loss_val, gen_sem_loss_val, z_part_enc_val = sess.run(
                    [recons_com_loss_tf, recons_sem_loss_tf, z_part_enc_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                        lr_VAE: lr
                    },
                )

            if discriminative is True:
                discrim_loss_val, gen_loss_val, scores_discrim = sess.run(
                    [discrim_loss_tf, gen_loss_tf, scores_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )

                # Update a generator branch only when its pair of discriminator
                # scores differs by more than 0.3 (a balancing heuristic).
                if scores_discrim[0] - scores_discrim[1] > 0.3:
                    _ = sess.run(
                        train_op_gen_sdf,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )
                if scores_discrim[2] - scores_discrim[3] > 0.3:
                    _ = sess.run(
                        train_op_gen_com,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )
                if scores_discrim[4] - scores_discrim[5] > 0.3:
                    _ = sess.run(
                        train_op_gen_sem,
                        feed_dict={
                            Z_tf: batch_z_var,
                            part_tf: batch_tsdf,
                            surf_tf: batch_surf,
                            full_tf: batch_voxel,
                            lr_VAE: lr
                        },
                    )

                # discriminator updates
                _ = sess.run(
                    train_op_dis_sdf,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                _ = sess.run(
                    train_op_dis_com,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                _ = sess.run(
                    train_op_dis_sem,
                    feed_dict={
                        Z_tf: batch_z_var,
                        part_tf: batch_tsdf,
                        surf_tf: batch_surf,
                        full_tf: batch_voxel,
                    },
                )
                print('GAN')

            np.set_printoptions(precision=2)
            print('reconstruct-com loss:', gen_com_loss_val)
            print('reconstruct-sem loss:', gen_sem_loss_val)
            if discriminative is True:
                print('            gen loss:',
                      "%.2f" % gen_loss_val if
                      ('gen_loss_val' in locals()) else 'None')
                print('      output discrim:',
                      "%.2f" % discrim_loss_val if
                      ('discrim_loss_val' in locals()) else 'None')
                print('      scores discrim:',
                      colored("%.2f" % scores_discrim[0], 'green'),
                      colored("%.2f" % scores_discrim[1], 'magenta'),
                      colored("%.2f" % scores_discrim[2], 'green'),
                      colored("%.2f" % scores_discrim[3], 'magenta'),
                      colored("%.2f" % scores_discrim[4], 'green'),
                      colored("%.2f" % scores_discrim[5], 'magenta') if
                      ('scores_discrim' in locals()) else 'None')
            print('     average of code:',
                  np.mean(np.mean(z_part_enc_val, 4)) if
                  ('z_part_enc_val' in locals()) else 'None')
            print('         std of code:',
                  np.mean(np.std(z_part_enc_val, 4)) if
                  ('z_part_enc_val' in locals()) else 'None')

            if np.mod(ite, freq) == 0:
                if discriminative is True:
                    # render sample volumes from the fixed latent codes
                    full_models = sess.run(
                        full_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    full_models_cat = np.argmax(full_models, axis=4)
                    record_vox = full_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)

                save_path = saver.save(sess,
                                       model_path + '/checkpoint' +
                                       str(ite // freq),
                                       global_step=None)

            ite += 1
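# ---------------------------------------------------------------------------
# Both training loops call learning_rate(cfg.LEARNING_RATE_V, ite) to pick the
# learning rate fed through lr_VAE, but that helper is defined elsewhere in
# the repository. The function below is only an illustrative sketch of a
# piecewise-constant schedule under an assumed input format (a list of
# (boundary_iteration, value) pairs). It is deliberately given a different
# name so it does not shadow the project's own helper, whose behaviour may
# differ.
def example_piecewise_learning_rate(schedule, iteration):
    """Hypothetical step schedule: return the value of the first entry whose
    boundary has not yet been reached, otherwise the last value."""
    for boundary, value in schedule:
        if iteration < boundary:
            return value
    return schedule[-1][1]
# Illustrative values only:
#   example_piecewise_learning_rate([(20000, 1e-4), (40000, 5e-5)], 25000)
#   returns 5e-5.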
def train(n_epochs, learning_rate_G, learning_rate_D, batch_size, mid_flag,
          check_num):
    beta_G = cfg.TRAIN.ADAM_BETA_G
    beta_D = cfg.TRAIN.ADAM_BETA_D

    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[4]]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    kernel = cfg.NET.KERNEL
    stride = cfg.NET.STRIDE
    freq = cfg.CHECK_FREQ
    record_vox_num = cfg.RECORD_VOX_NUM
    refine_ch = cfg.NET.REFINE_CH
    refine_kernel = cfg.NET.REFINE_KERNEL
    refine_start = cfg.SWITCHING_ITE

    fcr_agan_model = FCR_aGAN(
        batch_size=batch_size,
        vox_shape=vox_shape,
        dim_z=dim_z,
        dim=dim,
        start_vox_size=start_vox_size,
        kernel=kernel,
        stride=stride,
        refine_ch=refine_ch,
        refine_kernel=refine_kernel,
    )

    Z_tf, z_enc_tf, vox_tf, vox_gen_tf, vox_gen_decode_tf, vox_refine_dec_tf, vox_refine_gen_tf,\
    recons_loss_tf, code_encode_loss_tf, gen_loss_tf, discrim_loss_tf, recons_loss_refine_tf, gen_loss_refine_tf, discrim_loss_refine_tf,\
    cost_enc_tf, cost_code_tf, cost_gen_tf, cost_discrim_tf, cost_gen_ref_tf, cost_discrim_ref_tf, summary_tf = fcr_agan_model.build_model()

    sess = tf.InteractiveSession()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    saver = tf.train.Saver(max_to_keep=cfg.SAVER_MAX)

    data_paths = scene_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION)
    print('---amount of data:' + str(len(data_paths)))
    data_process = DataProcess(data_paths, batch_size, repeat=True)

    # variable groups for the separate optimisers: encoder, discriminator,
    # generator, code network and refiner
    encode_vars = list(
        filter(lambda x: x.name.startswith('enc'), tf.trainable_variables()))
    discrim_vars = list(
        filter(lambda x: x.name.startswith('discrim'),
               tf.trainable_variables()))
    gen_vars = list(
        filter(lambda x: x.name.startswith('gen'), tf.trainable_variables()))
    code_vars = list(
        filter(lambda x: x.name.startswith('cod'), tf.trainable_variables()))
    refine_vars = list(
        filter(lambda x: x.name.startswith('refine'),
               tf.trainable_variables()))

    lr_VAE = tf.placeholder(tf.float32, shape=[])

    train_op_encode = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_D, beta2=0.9).minimize(cost_enc_tf,
                                                  var_list=encode_vars)
    train_op_discrim = tf.train.AdamOptimizer(
        learning_rate_D, beta1=beta_D, beta2=0.9).minimize(
            cost_discrim_tf, var_list=discrim_vars, global_step=global_step)
    train_op_gen = tf.train.AdamOptimizer(
        learning_rate_G, beta1=beta_G, beta2=0.9).minimize(cost_gen_tf,
                                                           var_list=gen_vars)
    train_op_code = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_G, beta2=0.9).minimize(cost_code_tf,
                                                  var_list=code_vars)
    train_op_refine = tf.train.AdamOptimizer(
        lr_VAE, beta1=beta_G, beta2=0.9).minimize(cost_gen_ref_tf,
                                                  var_list=refine_vars)
    train_op_discrim_refine = tf.train.AdamOptimizer(
        learning_rate_D, beta1=beta_D, beta2=0.9).minimize(
            cost_discrim_ref_tf,
            var_list=discrim_vars,
            global_step=global_step)

    Z_tf_sample, vox_tf_sample = fcr_agan_model.samples_generator(
        visual_size=batch_size)
    sample_vox_tf, sample_refine_vox_tf = fcr_agan_model.refine_generator(
        visual_size=batch_size)

    writer = tf.summary.FileWriter(cfg.DIR.LOG_PATH, sess.graph_def)
    tf.global_variables_initializer().run()

    if mid_flag:
        # resume from an existing checkpoint and reuse the saved latent codes
        chckpt_path = cfg.DIR.CHECK_PT_PATH + str(check_num) + '-' + str(
            check_num * freq)
        saver.restore(sess, chckpt_path)
        Z_var_np_sample = np.load(cfg.DIR.TRAIN_OBJ_PATH +
                                  '/sample_z.npy').astype(np.float32)
        Z_var_np_sample = Z_var_np_sample[:batch_size]
        print('---weights restored')
    else:
        # fixed latent codes used to render sample volumes at every checkpoint
        Z_var_np_sample = np.random.normal(
            size=(batch_size, start_vox_size[0], start_vox_size[1],
                  start_vox_size[2], dim_z)).astype(np.float32)
        np.save(cfg.DIR.TRAIN_OBJ_PATH + '/sample_z.npy', Z_var_np_sample)

    ite = check_num * freq + 1
    cur_epochs = int(ite / int(len(data_paths) / batch_size))

    # training
    for epoch in np.arange(cur_epochs, n_epochs):
        epoch_flag = True
        while epoch_flag:
            print('=iteration:%d, epoch:%d' % (ite, epoch))
            db_inds, epoch_flag = data_process.get_next_minibatch()
            batch_voxel = data_process.get_voxel(db_inds)
            batch_voxel_train = batch_voxel

            lr = learning_rate(cfg.LEARNING_RATE_V, ite)

            batch_z_var = np.random.normal(
                size=(batch_size, start_vox_size[0], start_vox_size[1],
                      start_vox_size[2], dim_z)).astype(np.float32)

            if ite < refine_start:
                # stage 1: train the VAE-GAN (the encoder is stepped twice,
                # then generator, discriminator and code network once each)
                for s in np.arange(2):
                    _, recons_loss_val, code_encode_loss_val, cost_enc_val = sess.run(
                        [
                            train_op_encode, recons_loss_tf,
                            code_encode_loss_tf, cost_enc_tf
                        ],
                        feed_dict={
                            vox_tf: batch_voxel_train,
                            Z_tf: batch_z_var,
                            lr_VAE: lr
                        },
                    )
                _, gen_loss_val, cost_gen_val = sess.run(
                    [train_op_gen, gen_loss_tf, cost_gen_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train,
                        lr_VAE: lr
                    },
                )
                _, discrim_loss_val, cost_discrim_val = sess.run(
                    [train_op_discrim, discrim_loss_tf, cost_discrim_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train
                    },
                )
                _, cost_code_val, z_enc_val, summary = sess.run(
                    [train_op_code, cost_code_tf, z_enc_tf, summary_tf],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train,
                        lr_VAE: lr
                    },
                )

                print('reconstruction loss:', recons_loss_val)
                print('  code encode loss:', code_encode_loss_val)
                print('          gen loss:', gen_loss_val)
                print('      cost_encoder:', cost_enc_val)
                print('    cost_generator:', cost_gen_val)
                print('cost_discriminator:', cost_discrim_val)
                print('         cost_code:', cost_code_val)
                print('  average of enc_z:', np.mean(np.mean(z_enc_val, 4)))
                print('      std of enc_z:', np.mean(np.std(z_enc_val, 4)))

                if np.mod(ite, freq) == 0:
                    # render sample volumes from the fixed latent codes
                    vox_models = sess.run(
                        vox_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    vox_models_cat = np.argmax(vox_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)

                    save_path = saver.save(sess,
                                           cfg.DIR.CHECK_PT_PATH +
                                           str(ite // freq),
                                           global_step=global_step)
            else:
                # stage 2: after cfg.SWITCHING_ITE, train the refiner and its
                # discriminator
                _, recons_loss_val, recons_loss_refine_val, gen_loss_refine_val, cost_gen_ref_val = sess.run(
                    [
                        train_op_refine, recons_loss_tf, recons_loss_refine_tf,
                        gen_loss_refine_tf, cost_gen_ref_tf
                    ],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train,
                        lr_VAE: lr
                    },
                )
                _, discrim_loss_refine_val, cost_discrim_ref_val, summary = sess.run(
                    [
                        train_op_discrim_refine, discrim_loss_refine_tf,
                        cost_discrim_ref_tf, summary_tf
                    ],
                    feed_dict={
                        Z_tf: batch_z_var,
                        vox_tf: batch_voxel_train
                    },
                )

                print('reconstruction loss:', recons_loss_val)
                print(' recons refine loss:', recons_loss_refine_val)
                print('           gen loss:', gen_loss_refine_val)
                print(' cost_discriminator:', cost_discrim_ref_val)

                if np.mod(ite, freq) == 0:
                    # render coarse and refined sample volumes
                    vox_models = sess.run(
                        vox_tf_sample,
                        feed_dict={Z_tf_sample: Z_var_np_sample},
                    )
                    refined_models = sess.run(
                        sample_refine_vox_tf,
                        feed_dict={sample_vox_tf: vox_models})

                    vox_models_cat = np.argmax(vox_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '.npy', record_vox)

                    vox_models_cat = np.argmax(refined_models, axis=4)
                    record_vox = vox_models_cat[:record_vox_num]
                    np.save(
                        cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                        '_refine.npy', record_vox)

                    save_path = saver.save(sess,
                                           cfg.DIR.CHECK_PT_PATH +
                                           str(ite // freq),
                                           global_step=global_step)

            writer.add_summary(summary, global_step=ite)

            ite += 1
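# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): the train() functions above are normally driven
# by an entry-point script, with most settings coming from the shared cfg
# object. The call below targets the FCR_aGAN variant defined directly above;
# the depvox_gan variant additionally takes a `discriminative` flag. All
# numeric values are illustrative placeholders, not the project's defaults.
if __name__ == '__main__':
    train(n_epochs=500,           # placeholder epoch budget
          learning_rate_G=1e-4,   # generator Adam learning rate (placeholder)
          learning_rate_D=1e-4,   # discriminator Adam learning rate (placeholder)
          batch_size=4,           # also fixes the batch size built into the graph
          mid_flag=False,         # set True to resume from checkpoint `check_num`
          check_num=0)            # checkpoint index used when resuming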