def create_decoder(self, features):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        # Coarse point set; each point carries 3 + 12 channels.
        coarse = mlp(features, [2048, 2048, self.num_coarse * (3 + 12)])
        coarse = tf.reshape(coarse, [-1, self.num_coarse, 3 + 12])

    with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
        # 2-D folding grid, tiled once per coarse point.
        grid = tf.meshgrid(tf.linspace(-0.05, 0.05, self.grid_size),
                           tf.linspace(-0.05, 0.05, self.grid_size))
        grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]), 0)
        grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])

        point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3 + 12])

        global_feat = tf.tile(tf.expand_dims(features, 1), [1, self.num_fine, 1])

        feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
        # feat_1 = mlp(feat, [512, 512, 3])
        # feat_2 = tf.concat([feat_1, point_feat, global_feat], axis=2)

        # Fine points are predicted as residuals around the tiled coarse points.
        center = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        center = tf.reshape(center, [-1, self.num_fine, 3 + 12])

        # with tf.variable_scope('folding_1', reuse=tf.AUTO_REUSE):
        fine = mlp(feat, [512, 512, 3 + 12]) + center
    return coarse, fine
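# The decoders in this file call shared helpers (mlp, mlp_conv) from the repo's
# tf_util utilities, which are not shown here.  The sketch below is an
# assumption of the usual PCN-style form of mlp (stacked dense layers with a
# linear output), not the repository's exact implementation.
def mlp_sketch(features, layer_dims, bn=None, bn_params=None):
    # Hidden layers with ReLU activations, linear output layer.
    for i, num_outputs in enumerate(layer_dims[:-1]):
        features = tf.contrib.layers.fully_connected(
            features, num_outputs, activation_fn=tf.nn.relu, scope='fc_%d' % i)
    return tf.contrib.layers.fully_connected(
        features, layer_dims[-1], activation_fn=None, scope='fc_out')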
def while_body(i, est_pose, est_inputs):
    # PointNet-style encoder: per-point MLP, global max-pool, concatenation of
    # the global feature back onto every point, then a second MLP and max-pool.
    with tf.variable_scope('encoder_0', reuse=tf.AUTO_REUSE):
        features = mlp_conv(est_inputs, [128, 256], bn=self.bn)
        features_global = tf.reduce_max(features, axis=1, keep_dims=True, name='maxpool_0')
        features = tf.concat(
            [features, tf.tile(features_global, [1, tf.shape(inputs)[1], 1])], axis=2)
    with tf.variable_scope('encoder_1', reuse=tf.AUTO_REUSE):
        features = mlp_conv(features, [512, 1024], bn=self.bn)
        features = tf.reduce_max(features, axis=1, name='maxpool_1')

    # Regress a pose update in the chosen rotation representation
    # (quaternion + translation = 7 values, 6D rotation + translation = 9 values).
    with tf.variable_scope('fc', reuse=tf.AUTO_REUSE):
        if self.rot_representation == 'quat':
            est_pose_rep_i = mlp(features, [1024, 1024, 512, 512, 256, 7], bn=self.bn)
        elif self.rot_representation == '6dof':
            est_pose_rep_i = mlp(features, [1024, 1024, 512, 512, 256, 9], bn=self.bn)

    # Convert the regressed representation into a 4x4 homogeneous transform.
    with tf.variable_scope('est', reuse=tf.AUTO_REUSE):
        if self.rot_representation == 'quat':
            t = tf.expand_dims(est_pose_rep_i[:, 4:], axis=2)
            q = est_pose_rep_i[:, 0:4]
            R = quat2rotm_tf(q)
        elif self.rot_representation == '6dof':
            t = tf.expand_dims(est_pose_rep_i[:, 6:], axis=2)
            mat6d = est_pose_rep_i[:, 0:6]
            R = mat6d2rotm_tf(mat6d)
        est_pose_T_i = tf.concat([
            tf.concat([R, t], axis=2),
            tf.concat([tf.zeros([B, 1, 3]), tf.ones([B, 1, 1])], axis=2)], axis=1)

    # Apply the incremental transform to the points and accumulate the total pose.
    est_inputs = transform_tf(est_inputs, est_pose_T_i)
    est_pose = tf.linalg.matmul(est_pose_T_i, est_pose)
    return [tf.add(i, 1), est_pose, est_inputs]
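# A minimal sketch (an assumption, not the original graph construction) of how
# while_body might be driven: iterate the pose regressor a fixed number of
# times, accumulating the estimated 4x4 transform.  `self.num_iterations` is a
# hypothetical attribute; `B` and `inputs` follow the names used in while_body.
i0 = tf.constant(0)
est_pose0 = tf.eye(4, batch_shape=[B])   # start from the identity pose
est_inputs0 = inputs                     # partial point cloud to be aligned
_, est_pose, est_inputs = tf.while_loop(
    lambda i, pose, pts: tf.less(i, self.num_iterations),
    while_body,
    [i0, est_pose0, est_inputs0])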
def decoder(inputs, features, step_ratio=16, num_fine=16 * 1024):
    num_coarse = 1024
    assert num_fine == num_coarse * step_ratio

    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        coarse = mlp(features, [1024, 1024, num_coarse * 3], bn=None, bn_params=None)
        coarse = tf.reshape(coarse, [-1, num_coarse, 3])
        # Mix mirrored FPS samples of the input with an FPS subset of the
        # coarse prediction, keeping num_coarse points in total.
        p1_idx = farthest_point_sample(512, coarse)
        coarse_1 = gather_point(coarse, p1_idx)
        input_fps = symmetric_sample(inputs, int(512 / 2))
        coarse = tf.concat([input_fps, coarse_1], 1)

    with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
        # Use a 2-D folding grid when step_ratio is a perfect square, otherwise a 1-D grid.
        if not step_ratio ** .5 % 1 == 0:
            grid = gen_1d_grid(step_ratio)
        else:
            grid = gen_grid(np.round(np.sqrt(step_ratio)).astype(np.int32))
        grid = tf.expand_dims(grid, 0)
        grid_feat = tf.tile(grid, [features.shape[0], num_coarse, 1])

        point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, step_ratio, 1])
        point_feat = tf.reshape(point_feat, [-1, num_fine, 3])

        global_feat = tf.tile(tf.expand_dims(features, 1), [1, num_fine, 1])

        feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
        fine = mlp_conv(feat, [512, 512, 3], bn=None, bn_params=None) + point_feat
    return coarse, fine
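# gen_grid and gen_1d_grid are folding-grid helpers assumed by decoder() and
# defined elsewhere in the repo.  The sketches below are hypothetical (the real
# versions may use a different coordinate range or ordering); they only
# illustrate the expected output shapes.
def gen_grid_sketch(grid_size):
    # 2-D folding grid: [grid_size ** 2, 2] (u, v) coordinates.
    x = tf.linspace(-0.05, 0.05, grid_size)
    y = tf.linspace(-0.05, 0.05, grid_size)
    grid = tf.meshgrid(x, y)
    return tf.reshape(tf.stack(grid, axis=2), [-1, 2])

def gen_1d_grid_sketch(num):
    # 1-D folding grid: [num, 1], used when step_ratio is not a perfect square.
    return tf.reshape(tf.linspace(-0.05, 0.05, num), [num, 1])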
def create_decoder(self, features):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        coarse = mlp(features, [1024, 1024, self.num_coarse * (3 + 12)])
        coarse = tf.reshape(coarse, [-1, self.num_coarse, 3 + 12])

    with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
        x = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
        y = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
        grid = tf.meshgrid(x, y)
        grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]), 0)
        grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])

        point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3 + 12])

        global_feat = tf.tile(tf.expand_dims(features, 1), [1, self.num_fine, 1])

        feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)

        center = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        center = tf.reshape(center, [-1, self.num_fine, 3 + 12])

        fine = tf.squeeze(
            mlp_conv(tf.expand_dims(feat, -2), [512, 512, 3 + 12])) + center
    return coarse, fine
def create_decoder(self, features):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        coarse = mlp(features, [1024, 1024, self.num_coarse * 3])
        coarse = tf.reshape(coarse, [-1, self.num_coarse, 3])
        print('coarse', coarse)

    with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
        grid = tf.meshgrid(tf.linspace(-0.05, 0.05, self.grid_size),
                           tf.linspace(-0.05, 0.05, self.grid_size))
        print('grid:', grid)
        grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]), 0)
        print('grid:', grid)
        grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])
        print('grid_feat', grid_feat)

        point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3])
        print('point_feat', point_feat)

        global_feat = tf.tile(tf.expand_dims(features, 1), [1, self.num_fine, 1])
        print('global_feat', global_feat)

        feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
        print('feat:', feat)

        center = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
        center = tf.reshape(center, [-1, self.num_fine, 3])
        print('center', center)

        fine = mlp_conv(feat, [512, 512, 3]) + center
        print('fine:', fine)
    return coarse, fine
def create_decoder(code, inputs, step_ratio, num_extract=512, mean_feature=None):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        # Coarse output: 512 points regressed directly from the global code.
        level0 = tf_util.mlp(code, [1024, 1024, 512 * 3], bn=None, bn_params=None,
                             name='coarse')
        level0 = tf.tanh(level0)
        level0 = tf.reshape(level0, [-1, 512, 3])
        coarse = level0

        # Mix mirrored FPS samples of the partial input with the coarse prediction.
        input_fps = symmetric_sample(inputs, int(num_extract / 2))
        level0 = tf.concat([input_fps, level0], 1)
        if num_extract > 512:
            level0 = gather_point(level0, farthest_point_sample(1024, level0))

        # Cascaded 2x upsampling stages until the target resolution is reached.
        for i in range(int(math.log2(step_ratio))):
            num_fine = 2 ** (i + 1) * 1024
            grid = tf_util.gen_grid_up(2 ** (i + 1))
            grid = tf.expand_dims(grid, 0)
            grid_feat = tf.tile(grid, [level0.shape[0], 1024, 1])

            point_feat = tf.tile(tf.expand_dims(level0, 2), [1, 1, 2, 1])
            point_feat = tf.reshape(point_feat, [-1, num_fine, 3])

            global_feat = tf.tile(tf.expand_dims(code, 1), [1, num_fine, 1])

            mean_feature_use = tf.contrib.layers.fully_connected(
                mean_feature, 128, activation_fn=tf.nn.relu, scope='mean_fc')
            mean_feature_use = tf.expand_dims(mean_feature_use, 1)
            mean_feature_use = tf.tile(mean_feature_use, [1, num_fine, 1])

            feat = tf.concat(
                [grid_feat, point_feat, global_feat, mean_feature_use], axis=2)

            feat1 = tf_util.mlp_conv(feat, [128, 64], bn=None, bn_params=None,
                                     name='up_branch')
            feat1 = tf.nn.relu(feat1)
            feat2 = contract_expand_operation(feat1, 2)
            feat = feat1 + feat2

            # Each stage predicts residuals around the duplicated points of the
            # previous level and becomes the input of the next stage.
            fine = tf_util.mlp_conv(
                feat, [512, 512, 3], bn=None, bn_params=None, name='fine') + point_feat
            level0 = fine
    return coarse, fine
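# symmetric_sample mixes farthest-point samples of the partial input with their
# mirror image, exploiting the bilateral symmetry of many shape categories.
# The sketch below is an assumption (reflection about a single axis); the
# repo's version may reflect about a different axis or plane.
def symmetric_sample_sketch(points, num):
    fps = gather_point(points, farthest_point_sample(num, points))
    mirrored = tf.concat([fps[:, :, 0:1], fps[:, :, 1:2], -fps[:, :, 2:3]], axis=2)
    return tf.concat([fps, mirrored], axis=1)  # [B, 2 * num, 3]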
def create_decoder(self, features):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        outputs = mlp(features, [1024, 1024, self.num_output_points * 3])
        outputs = tf.reshape(outputs, [-1, self.num_output_points, 3])
    return outputs
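# Minimal usage sketch (an assumption, not taken from the repo) showing how the
# fully-connected decoder above could be wired up and supervised with the same
# symmetric Chamfer distance that train() below uses.  `model`, the encoder
# call, and the placeholder shapes are illustrative names only.
partial_pl = tf.placeholder(tf.float32, (32, 2048, 3), 'partial')
gt_pl = tf.placeholder(tf.float32, (32, 2048, 3), 'gt')
features = model.create_encoder(partial_pl)     # hypothetical encoder, [B, 1024]
outputs = model.create_decoder(features)        # [B, num_output_points, 3]
dist1, dist2 = chamfer(outputs, gt_pl)
loss = (tf.reduce_mean(tf.sqrt(dist1)) + tf.reduce_mean(tf.sqrt(dist2))) / 2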
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # alpha weights the fine-resolution loss and is ramped up over training.
    alpha = tf.train.piecewise_constant(
        global_step // (args.gen_iter + args.dis_iter), args.fine_step,
        [0.01, 0.1, 0.5, 1.0], 'alpha_op')

    inputs_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_input_points, 3), 'inputs')
    gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'gt')
    complete_feature = tf.placeholder(tf.float32, (args.batch_size, 1024), 'complete_feature')
    complete_feature0 = tf.placeholder(tf.float32, (args.batch_size, 256), 'complete_feature0')
    label_pl = tf.placeholder(tf.int32, shape=(args.batch_size,))

    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    # Load the training split and the pre-computed complete-shape features.
    file_train = h5py.File(args.h5_train, 'r')
    incomplete_pcds_train = file_train['incomplete_pcds'][()]
    complete_pcds_train = file_train['complete_pcds'][()]
    labels_train = file_train['labels'][()]
    if args.num_gt_points == 2048:
        if args.step_ratio == 2:
            complete_features_train = file_train['complete_feature'][()]
            complete_features_train0 = file_train['complete_feature0'][()]
        elif args.step_ratio == 4:
            complete_features_train = file_train['complete_feature1_4'][()]
            complete_features_train0 = file_train['complete_feature0_4'][()]
        elif args.step_ratio == 8:
            complete_features_train = file_train['complete_feature1_8'][()]
            complete_features_train0 = file_train['complete_feature0_8'][()]
        elif args.step_ratio == 16:
            complete_features_train = file_train['complete_feature1_16'][()]
            complete_features_train0 = file_train['complete_feature0_16'][()]
    elif args.num_gt_points == 16384:
        file_train_feature = h5py.File(
            args.pretrain_complete_decoder + '/train_complete_feature.h5', 'r')
        complete_features_train = file_train_feature['complete_feature1'][()]
        complete_features_train0 = file_train_feature['complete_feature0'][()]
        file_train_feature.close()
    file_train.close()

    assert complete_features_train.shape[0] == complete_features_train0.shape[0] \
        == incomplete_pcds_train.shape[0]
    assert complete_features_train.shape[1] == 1024
    assert complete_features_train0.shape[1] == 256

    train_num = incomplete_pcds_train.shape[0]
    BN_DECAY_DECAY_STEP = int(train_num / args.batch_size * args.lr_decay_epochs)

    # Separate exponentially decayed learning rates for discriminator and generator.
    learning_rate_d = tf.train.exponential_decay(
        args.base_lr_d, global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP, args.lr_decay_rate, staircase=True, name='lr_d')
    learning_rate_d = tf.maximum(learning_rate_d, args.lr_clip)
    learning_rate_g = tf.train.exponential_decay(
        args.base_lr_g, global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP, args.lr_decay_rate, staircase=True, name='lr_g')
    learning_rate_g = tf.maximum(learning_rate_g, args.lr_clip)

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        G_optimizers = tf.train.AdamOptimizer(learning_rate_g, beta1=0.9)
        D_optimizers = tf.train.AdamOptimizer(learning_rate_d, beta1=0.5)

    coarse_gpu = []
    fine_gpu = []
    tower_grads_g = []
    tower_grads_d = []
    total_dis_loss_gpu = []
    total_gen_loss_gpu = []
    total_lossReconstruction_gpu = []
    total_lossFeature_gpu = []

    # Build one tower per GPU; variables are shared via AUTO_REUSE.
    for i in range(NUM_GPUS):
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            with tf.device('/gpu:%d' % (i)), tf.name_scope('gpu_%d' % (i)) as scope:
                inputs_pl_batch = tf.slice(inputs_pl, [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                           [int(DEVICE_BATCH_SIZE), -1, -1])
                gt_pl_batch = tf.slice(gt_pl, [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                       [int(DEVICE_BATCH_SIZE), -1, -1])
                complete_feature_batch = tf.slice(
                    complete_feature, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])
                complete_feature_batch0 = tf.slice(
                    complete_feature0, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])

                with tf.variable_scope('generator'):
                    features_partial_0, partial_reconstruct = model_module.encoder(
                        inputs_pl_batch, embed_size=1024)
                    coarse_batch, fine_batch = model_module.decoder(
                        inputs_pl_batch, partial_reconstruct,
                        step_ratio=args.step_ratio, num_fine=args.step_ratio * 1024)

                # Discriminator operates in feature space: it compares the encoded
                # partial feature against the pre-computed complete-shape feature
                # with an MMD critic plus a gradient penalty.
                with tf.variable_scope('discriminator') as dis_scope:
                    errD_fake = mlp(tf.expand_dims(partial_reconstruct, axis=1), [16],
                                    bn=None, bn_params=None)
                    dis_scope.reuse_variables()
                    errD_real = mlp(tf.expand_dims(complete_feature_batch, axis=1), [16],
                                    bn=None, bn_params=None)

                    kernel = getattr(mmd, '_%s_kernel' % args.mmd_kernel)
                    kerGI = kernel(errD_fake[:, 0, :], errD_real[:, 0, :])
                    errG = mmd.mmd2(kerGI)
                    errD = -errG

                    # Gradient penalty on an interpolation between real and fake features.
                    epsilon = tf.random_uniform([], 0.0, 1.0)
                    x_hat = complete_feature_batch * (1 - epsilon) + epsilon * partial_reconstruct
                    d_hat = mlp(tf.expand_dims(x_hat, axis=1), [16], bn=None, bn_params=None)
                    Ekx = lambda yy: tf.reduce_mean(
                        kernel(d_hat[:, 0, :], yy, K_XY_only=True), axis=1)
                    Ekxr, Ekxf = Ekx(errD_real[:, 0, :]), Ekx(errD_fake[:, 0, :])
                    witness = Ekxr - Ekxf
                    gradients = tf.gradients(witness, [x_hat])[0]
                    penalty = tf.reduce_mean(
                        tf.square(mmd.safer_norm(gradients, axis=1) - 1.0))

                errD_loss_batch = penalty * args.gp_weight + errD
                errG_loss_batch = errG

                # Feature-matching loss between the partial and complete encodings.
                feature_loss = tf.reduce_mean(
                    tf.squared_difference(partial_reconstruct, complete_feature_batch)) + \
                    tf.reduce_mean(
                        tf.squared_difference(features_partial_0[:, 0, :], complete_feature_batch0))

                # Chamfer-distance reconstruction losses for coarse and fine outputs.
                dist1_fine, dist2_fine = chamfer(fine_batch, gt_pl_batch)
                dist1_coarse, dist2_coarse = chamfer(coarse_batch, gt_pl_batch)
                total_loss_fine = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                                   tf.reduce_mean(tf.sqrt(dist2_fine))) / 2
                total_loss_coarse = (tf.reduce_mean(tf.sqrt(dist1_coarse)) +
                                     tf.reduce_mean(tf.sqrt(dist2_coarse))) / 2
                total_loss_rec_batch = alpha * total_loss_fine + total_loss_coarse

                t_vars = tf.global_variables()
                gen_tvars = [var for var in t_vars if var.name.startswith("generator")]
                dis_tvars = [var for var in t_vars if var.name.startswith("discriminator")]

                total_gen_loss_batch = (args.feat_weight * feature_loss + errG_loss_batch +
                                        args.rec_weight * total_loss_rec_batch)
                total_dis_loss_batch = errD_loss_batch

                # Calculate the gradients for the batch of data on this tower.
                grads_g = G_optimizers.compute_gradients(total_gen_loss_batch,
                                                         var_list=gen_tvars)
                grads_d = D_optimizers.compute_gradients(total_dis_loss_batch,
                                                         var_list=dis_tvars)

                # Keep track of the gradients across all towers.
                tower_grads_g.append(grads_g)
                tower_grads_d.append(grads_d)
                coarse_gpu.append(coarse_batch)
                fine_gpu.append(fine_batch)
                total_dis_loss_gpu.append(total_dis_loss_batch)
                total_gen_loss_gpu.append(errG_loss_batch)
                total_lossReconstruction_gpu.append(args.rec_weight * total_loss_rec_batch)
                total_lossFeature_gpu.append(args.feat_weight * feature_loss)

    fine = tf.concat(fine_gpu, 0)

    # Average the per-tower gradients and apply them with our optimizers.
    grads_g = average_gradients(tower_grads_g)
    grads_d = average_gradients(tower_grads_d)
    train_G = G_optimizers.apply_gradients(grads_g, global_step=global_step)
    train_D = D_optimizers.apply_gradients(grads_d, global_step=global_step)

    total_dis_loss = tf.reduce_mean(tf.stack(total_dis_loss_gpu, 0))
    total_gen_loss = tf.reduce_mean(tf.stack(total_gen_loss_gpu, 0))
    total_loss_rec = tf.reduce_mean(tf.stack(total_lossReconstruction_gpu, 0))
    total_loss_fea = tf.reduce_mean(tf.stack(total_lossFeature_gpu, 0))
    dist1_eval, dist2_eval = chamfer(fine, gt_pl)

    # Validation split.
    file_validate = h5py.File(args.h5_validate, 'r')
    incomplete_pcds_validate = file_validate['incomplete_pcds'][()]
    complete_pcds_validate = file_validate['complete_pcds'][()]
    labels_validate = file_validate['labels'][()]
    file_validate.close()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)
    saver = tf.train.Saver(max_to_keep=3)

    sess.run(tf.global_variables_initializer())

    # Restore the pre-trained complete-shape decoder, then (optionally) the full model.
    saver_decoder = tf.train.Saver(var_list=[
        var for var in tf.global_variables()
        if (var.name.startswith("generator/decoder")
            or var.name.startswith("generator/folding"))])
    saver_decoder.restore(sess, args.pretrain_complete_decoder)

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))

    init_step = sess.run(global_step) // (args.gen_iter + args.dis_iter)
    epoch = init_step * args.batch_size // train_num + 1
    print('init_step:%d,' % init_step, 'epoch:%d' % epoch,
          'training data number:%d' % train_num)

    train_idxs = np.arange(0, train_num)

    for ep_cnt in range(epoch, args.max_epoch + 1):
        num_batches = train_num // args.batch_size
        np.random.shuffle(train_idxs)

        for batch_idx in range(num_batches):
            init_step += 1
            start_idx = batch_idx * args.batch_size
            end_idx = min(start_idx + args.batch_size, train_num)
            ids_train = list(np.sort(train_idxs[start_idx:end_idx]))

            batch_data = incomplete_pcds_train[ids_train]
            batch_gt = complete_pcds_train[ids_train]
            labels = labels_train[ids_train]
            # partial_feature_input = incomplete_features_train[ids_train]
            complete_feature_input = complete_features_train[ids_train]
            complete_feature_input0 = complete_features_train0[ids_train]

            feed_dict = {
                inputs_pl: batch_data,
                gt_pl: batch_gt,
                is_training_pl: True,
                label_pl: labels,
                complete_feature: complete_feature_input,
                complete_feature0: complete_feature_input0
            }

            # Alternate discriminator and generator updates.
            for i in range(args.dis_iter):
                _, loss_dis = sess.run([train_D, total_dis_loss], feed_dict=feed_dict)
            for i in range(args.gen_iter):
                _, loss_gen, rec_loss, fea_loss = sess.run(
                    [train_G, total_gen_loss, total_loss_rec, total_loss_fea],
                    feed_dict=feed_dict)

            if init_step % args.steps_per_print == 0:
                print('epoch %d step %d gen_loss %.8f rec_loss %.8f fea_loss %.8f dis_loss %.8f' %
                      (ep_cnt, init_step, loss_gen, rec_loss, fea_loss, loss_dis))

            if init_step % args.steps_per_eval == 0:
                # Periodic validation with the Chamfer distance.
                total_loss = 0
                sess.run(tf.local_variables_initializer())
                batch_data = np.zeros(
                    (args.batch_size, incomplete_pcds_validate[0].shape[0], 3), 'f')
                batch_gt = np.zeros((args.batch_size, args.num_gt_points, 3), 'f')
                # partial_feature_input = np.zeros((args.batch_size, 1024), 'f')
                labels = np.zeros((args.batch_size,), dtype=np.int32)
                feature_complete_input = np.zeros(
                    (args.batch_size, 1024)).astype(np.float32)
                feature_complete_input0 = np.zeros(
                    (args.batch_size, 256)).astype(np.float32)

                for batch_idx_eval in range(0, incomplete_pcds_validate.shape[0],
                                            args.batch_size):
                    # start = time.time()
                    start_idx = batch_idx_eval
                    end_idx = min(start_idx + args.batch_size,
                                  incomplete_pcds_validate.shape[0])

                    batch_data[0:end_idx - start_idx] = \
                        incomplete_pcds_validate[start_idx:end_idx]
                    batch_gt[0:end_idx - start_idx] = \
                        complete_pcds_validate[start_idx:end_idx]
                    labels[0:end_idx - start_idx] = labels_validate[start_idx:end_idx]
                    # The feature placeholders are filled from the training feature
                    # arrays here; they do not affect dist1_eval/dist2_eval, which
                    # depend only on inputs_pl and gt_pl.
                    feature_complete_input[0:end_idx - start_idx] = \
                        complete_features_train[start_idx:end_idx]
                    feature_complete_input0[0:end_idx - start_idx] = \
                        complete_features_train0[start_idx:end_idx]

                    feed_dict = {
                        inputs_pl: batch_data,
                        gt_pl: batch_gt,
                        is_training_pl: False,
                        label_pl: labels,
                        complete_feature: feature_complete_input,
                        complete_feature0: feature_complete_input0
                    }
                    dist1_out, dist2_out = sess.run([dist1_eval, dist2_eval],
                                                    feed_dict=feed_dict)

                    if args.loss_type == 'cd_1':
                        total_loss += np.mean(dist1_out[0:end_idx - start_idx]) * (end_idx - start_idx) \
                            + np.mean(dist2_out[0:end_idx - start_idx]) * (end_idx - start_idx)
                    elif args.loss_type == 'cd_2':
                        total_loss += (np.mean(np.sqrt(dist1_out[0:end_idx - start_idx])) * (end_idx - start_idx)
                                       + np.mean(np.sqrt(dist2_out[0:end_idx - start_idx])) * (end_idx - start_idx)) / 2

                # Checkpoint when the validation Chamfer distance improves.
                if total_loss / incomplete_pcds_validate.shape[0] < args.best_loss:
                    args.best_loss = total_loss / incomplete_pcds_validate.shape[0]
                    saver.save(sess, os.path.join(args.log_dir, 'model'), init_step)

                print('epoch %d step %d loss %.8f best_loss %.8f' %
                      (ep_cnt, init_step,
                       total_loss / incomplete_pcds_validate.shape[0], args.best_loss))

    sess.close()
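# average_gradients is assumed to be the standard multi-tower helper (as in the
# TensorFlow multi-GPU examples): average each variable's gradient over the
# per-GPU towers before apply_gradients.  A sketch under that assumption, not
# necessarily the repo's exact implementation:
def average_gradients_sketch(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads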