Example #1
def build_vae_ops(data_dict, args, scope='vae'):
    """
      builds vae operations that are required for training/inference of vae.

      Args:
        data_dict: dict, contains the tensors for the input to the model.
        args: arguments that are set for training.
        scope: string.
      
      Returns:
        train_op, summary_op, data_dict, logger_dict, global_step
        train_op: tf op for running training.
        summary_op: tf summary op that needs to be run for populating the
            summaries.
        data_dict: dictionary of tensors. Keys are tensor names and values
            are tensors. New keys and tensors will be added to the input
            data_dict.
        logger_dict: dictionary of tensors for printing.
        global_step: tf.Step that keeps the step number of the training.
    """
    losses = None
    summaries = None
    train_op = None
    logger_dict = None
    summary_op = None
    global_step = None
    first_dimension = args.num_objects_per_batch * args.num_grasps_per_object
    is_training = args.is_training

    with tf.variable_scope(scope):
        if is_training:
            assert '{}_pred/samples'.format(scope) not in data_dict
            input_pcs = data_dict['{}_pc'.format(scope)]
            losses = {}
            summaries = {}

            gt_control_points = tf_utils.transform_control_points(
                data_dict['{}_grasp_rt'.format(scope)],
                first_dimension,
                mode='rt')
            gt_control_points = tf.slice(gt_control_points, [0, 0, 0],
                                         [-1, -1, 3])
            data_dict['{}_gt_control_point'.format(scope)] = gt_control_points
            pc_input = tf.slice(input_pcs, [0, 0, 0], [-1, -1, 3])

            if not args.gan:  # Create Encoder.
                latent_input = data_dict['{}_grasp_rt'.format(scope)]
                batch_size = get_shape(pc_input)[0]
                npoints = get_shape(pc_input)[1]
                latent_input = tf.tile(
                    tf.reshape(latent_input, [batch_size, 1, -1]),
                    [1, npoints, 1])

                with tf.variable_scope('encoder'):
                    latent_mean_std = models.model.model_with_confidence(
                        pc_input,
                        latent_input,
                        is_training=tf.constant(is_training),
                        bn_decay=None,
                        is_encoder=True,
                        latent_size=args.latent_size,
                        scale=args.model_scale,
                        merge_pcs=args.merge_pcs_in_vae_encoder,
                        pointnet_radius=args.pointnet_radius,
                        pointnet_nclusters=args.pointnet_nclusters)

                    latent_mean = tf.slice(latent_mean_std, [0, 0],
                                           [-1, args.latent_size])
                    latent_std = tf.slice(latent_mean_std,
                                          [0, args.latent_size],
                                          [-1, args.latent_size])

                with tf.variable_scope('sample_from_latent'):
                    # Reparameterization trick: the encoder's second output
                    # is treated as a log variance, so exp(latent_std / 2)
                    # is the standard deviation of the latent Gaussian.
                    samples = latent_mean + tf.exp(
                        latent_std / 2.0) * tf.random_normal(
                            latent_mean.shape, 0, 1, dtype=tf.float32)
                    data_dict['{}_pred/samples'.format(scope)] = samples

                kl_loss = models.model.kl_divergence(latent_mean, latent_std)
                kl_loss = tf.reduce_mean(kl_loss)
                losses['kl_loss'] = kl_loss * args.kl_loss_weight
                summaries['unscaled_kl_loss'] = kl_loss
            else:  # For the GAN, just sample random latents.
                samples = tf.random.uniform(
                    [first_dimension, args.latent_size], name='gan_latents')
        else:
            input_pcs = data_dict['{}_pc'.format(scope)]
            samples = data_dict['{}_pred/samples'.format(scope)]

        with tf.variable_scope('decoder'):
            pc_input = tf.slice(input_pcs, [0, 0, 0], [-1, -1, 3])

            latent_input = samples
            batch_size = get_shape(pc_input)[0]
            npoints = get_shape(pc_input)[1]
            latent_input = tf.tile(
                tf.reshape(latent_input, [batch_size, 1, -1]), [1, npoints, 1])

            q, t, confidence = models.model.model_with_confidence(
                pc_input,
                latent_input,
                tf.constant(is_training),
                bn_decay=None,
                is_encoder=False,
                latent_size=None,
                scale=args.model_scale,
                pointnet_radius=args.pointnet_radius,
                pointnet_nclusters=args.pointnet_nclusters)
            predicted_qt = tf.concat((q, t), -1)
            data_dict['{}_pred/grasp_qt'.format(scope)] = predicted_qt
            data_dict['{}_pred/confidence'.format(scope)] = confidence

            cp = tf_utils.transform_control_points(
                predicted_qt,
                get_shape(data_dict['{}_pc'.format(scope)])[0],
                scope='transform_predicted_qt')
            data_dict['{}_pred/cps'.format(scope)] = cp

        if is_training:
            if args.gan:
                loss_fn = models.model.min_distance_loss
            else:
                loss_fn = models.model.control_point_l1_loss

            loss_term, confidence_term = loss_fn(
                cp,
                gt_control_points,
                confidence=confidence,
                confidence_weight=args.confidence_weight)
            data_dict['{}_loss'.format(scope)] = loss_term
            loss_key = ('gan_min_dist' if args.gan
                        else 'L1_grasp_reconstruction')
            losses[loss_key] = loss_term
            losses['confidence'] = confidence_term

            for c in CONFIDENCES:
                qkey = 'quality_at_confidence/{}'.format(c)
                rkey = 'ratio_at_confidence/{}'.format(c)
                summary_fn = models.model.control_point_l1_loss_better_than_threshold
                if args.gan:
                    summary_fn = models.model.min_distance_better_than_threshold
                summaries[qkey], summaries[rkey] = summary_fn(
                    cp, gt_control_points, confidence, c)

            global_step = tf.train.get_or_create_global_step()
            total_loss = tf.reduce_sum(tf.stack(list(losses.values())))
            summaries['total_loss'] = total_loss
            learning_rate = tf.constant(args.lr, dtype=tf.float32)

            if args.ngpus > 1:
                optimizer = tf.train.AdamOptimizer(learning_rate * hvd.size())
                optimizer = hvd.DistributedOptimizer(optimizer)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate)

            train_op = optimizer.minimize(total_loss, global_step=global_step)
            summaries['global_step'] = global_step
            for k in losses:
                summaries['loss/{}'.format(k)] = losses[k]

            logger_dict = {}
            for k, v in summaries.items():
                logger_dict[k] = v
                summaries[k] = tf.summary.scalar(k, v)

            summary_op = tf.summary.merge(list(summaries.values()))

        return train_op, summary_op, data_dict, logger_dict, global_step
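
A minimal sketch of how these ops might be consumed in a TF1 training loop. Everything here is illustrative: the `args` namespace only lists fields read by `build_vae_ops` and the data reader (values are placeholders), and `get_vae_data_dict` is a hypothetical reader analogous to `get_evaluator_data_dict` in Example #2.

import argparse
import tensorflow as tf

# Hypothetical hyperparameters; placeholder values, not tuned settings.
args = argparse.Namespace(
    is_training=True, gan=False, num_objects_per_batch=8,
    num_grasps_per_object=64, latent_size=2, model_scale=1,
    merge_pcs_in_vae_encoder=False, pointnet_radius=0.02,
    pointnet_nclusters=128, kl_loss_weight=0.01, confidence_weight=1.0,
    lr=1e-4, ngpus=1, npoints=1024)

# Assumed reader that fills data_dict with 'vae_pc' and 'vae_grasp_rt'.
data_dict = get_vae_data_dict(
    args.num_objects_per_batch * args.num_grasps_per_object,
    args, files, pcreader, scope='vae')

train_op, summary_op, data_dict, logger_dict, global_step = build_vae_ops(
    data_dict, args, scope='vae')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('logs', sess.graph)
    for _ in range(1000):
        _, summary, step = sess.run([train_op, summary_op, global_step])
        writer.add_summary(summary, step)
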
Example #2
def get_evaluator_data_dict(first_dimension,
                            args,
                            files,
                            pcreader,
                            scope='evaluator'):
    """
    Returns dictionary for training evaluator.

    Args:
      first_dimension: int, num_objects_per_batch x num_grasps_per_object.
      args: arguments used for training.
      files: list of string, contains path for the training.
      pcreader: PointCloudReader.
    """
    global current_index, epoch_count, lock, all_poses

    OUTPUT_SHAPES = {
        'pc': [first_dimension, args.npoints, 4],
        'grasp_rt': [first_dimension, 4, 4],
        'label': [first_dimension],  # Binary, success or not
        'grasp_quality': [first_dimension],  # For debugging only
        'pc_pose': [first_dimension, 4, 4],
        'cad_path': [first_dimension],
        'cad_scale': [first_dimension],
    }

    OUTPUT_KEYS = sorted(list(OUTPUT_SHAPES.keys()))
    OUTPUT_TYPES = []
    for k in OUTPUT_KEYS:
        if k == 'cad_path':
            OUTPUT_TYPES.append(tf.string)
        elif k == 'label':
            OUTPUT_TYPES.append(tf.int32)
        else:
            OUTPUT_TYPES.append(tf.float32)

    def get_evaluator_data_func():
        global current_index, epoch_count, lock, all_poses
        with lock:
            output_dict = {k: [] for k in OUTPUT_SHAPES}
            for _ in range(args.num_objects_per_batch):
                while True:
                    file_name = files[current_index]
                    no_positive_grasps = False
                    try:
                        output = pcreader.get_evaluator_data(file_name)
                    except grasp_data_reader.NoPositiveGraspsException:
                        no_positive_grasps = True

                    current_index += 1
                    if current_index == len(files):
                        random.shuffle(files)
                        epoch_count += 1
                        current_index = 0

                    if no_positive_grasps:
                        print('skipping {} because no positive grasps'.format(
                            file_name))
                        continue
                    else:
                        break

                output_dict['pc'].append(output[0])
                output_dict['grasp_rt'].append(output[1])
                output_dict['label'].append(output[2])
                output_dict['grasp_quality'].append(output[3])
                output_dict['pc_pose'].append(output[4])
                output_dict['cad_path'].append(output[5])
                output_dict['cad_scale'].append(output[6])

            for k in output_dict:
                output_dict[k] = np.asarray(output_dict[k])
                try:
                    output_dict[k] = np.reshape(output_dict[k],
                                                OUTPUT_SHAPES[k])
                except Exception as e:
                    # Log the shape mismatch instead of aborting the reader.
                    print('{} =====> {} {}'.format(k, output_dict[k].shape,
                                                   OUTPUT_SHAPES[k]))
                    print(e)

            output_list = []
            for k in OUTPUT_KEYS:
                output_list.append(output_dict[k])

            return output_list

    # Wrap the Python data-generation function into a tf op via tf.py_func.
    data_list = tf.py_func(get_evaluator_data_func, [],
                           OUTPUT_TYPES,
                           stateful=True,
                           name='evaluator_data_reader')
    data_dict = {
        '{}_'.format(scope) + k: v
        for k, v in zip(OUTPUT_KEYS, data_list)
    }
    for k, shape in OUTPUT_SHAPES.items():
        data_dict['{}_'.format(scope) + k].set_shape(shape)

    data_dict['{}_gt_control_points'.format(
        scope)] = tf_utils.transform_control_points(
            data_dict['{}_grasp_rt'.format(scope)], first_dimension, mode='rt')

    return data_dict
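
The tf.py_func pattern above, reduced to a self-contained toy: a Python function that returns numpy arrays is wrapped as a graph op, and the resulting tensors get their static shapes set by hand, since py_func outputs otherwise have unknown shape. All names here are illustrative, not part of the codebase above.

import numpy as np
import tensorflow as tf

BATCH, NPOINTS = 4, 16

def toy_reader():
    # Stand-in for get_evaluator_data_func: returns (pc, label) as numpy.
    pc = np.random.rand(BATCH, NPOINTS, 3).astype(np.float32)
    label = np.random.randint(0, 2, size=[BATCH]).astype(np.int32)
    return pc, label

pc, label = tf.py_func(toy_reader, [], [tf.float32, tf.int32],
                       stateful=True, name='toy_reader')
# Like the OUTPUT_SHAPES loop above: restore static shapes for downstream ops.
pc.set_shape([BATCH, NPOINTS, 3])
label.set_shape([BATCH])

with tf.Session() as sess:
    pc_np, label_np = sess.run([pc, label])
    print(pc_np.shape, label_np)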