Пример #1
0
    def add_student_loss(self,
                         inputs,
                         outputs,
                         min_loss,
                         summary_writer,
                         add_summary,
                         global_step=0):
        """Distill the selected pose candidate (teacher) into the student head.

        For every batch element ``b`` the candidate ``outputs["poses"][b,
        min_loss[b]]`` is used as a fixed (detached) target for
        ``outputs["pose_student"][b]``.

        Args:
            inputs: dict of inputs; ``inputs["valid_samples"]`` is used as
                per-sample loss weights when ``cfg.variable_num_views`` is set.
            outputs: dict with ``"poses"`` (reshaped here to
                ``(-1, num_candidates, 4)``) and ``"pose_student"``.
            min_loss: 1-D tensor holding, per batch element, the index of the
                selected candidate (despite its name, used as indices).
            summary_writer: object with ``add_scalar(tag, value, step)`` or
                ``None``.
            add_summary: whether to log the (unweighted) student loss.
            global_step: step passed to ``add_scalar``.

        Returns:
            Scalar tensor: the student loss divided by the batch size and
            scaled by ``cfg.pose_predictor_student_loss_weight``.
        """
        cfg = self.cfg()
        num_candidates = cfg.pose_predict_num_candidates

        student = outputs["pose_student"]
        teachers = outputs["poses"].reshape(-1, num_candidates, 4)

        # Select teachers[b, min_loss[b]] for every b. The index tensors are
        # created on the teachers' own device instead of guessing
        # 'cuda'/'cpu', so CPU tensors on a CUDA-capable machine and
        # non-default GPUs both work.
        batch_size = teachers.shape[0]
        batch_indices = torch.arange(batch_size, device=teachers.device)
        candidate_indices = min_loss.long().to(teachers.device)
        # use teachers only as ground truth (no gradient flows into them)
        teachers = teachers[batch_indices, candidate_indices].detach()

        if cfg.variable_num_views:
            weights = inputs["valid_samples"]
        else:
            weights = 1.0

        if cfg.pose_student_align_loss:
            # Compare a reference point cloud rotated by teacher vs. student.
            # Written with torch ops: TensorFlow calls cannot consume torch
            # tensors. Assumes _pc_for_alignloss is (num_points, C) -- TODO
            # confirm.
            ref_pc = torch.as_tensor(self._pc_for_alignloss,
                                     dtype=teachers.dtype,
                                     device=teachers.device)
            num_ref_points = ref_pc.shape[0]
            ref_pc_all = ref_pc.unsqueeze(0).repeat(batch_size, 1, 1)
            pc_1 = q_rotate(ref_pc_all, teachers)
            pc_2 = q_rotate(ref_pc_all, student)
            # Same value as tf.nn.l2_loss: sum(x ** 2) / 2.
            student_loss = ((pc_1 - pc_2) ** 2).sum() / 2 / num_ref_points
        else:
            q_diff = q_norm(q_mul(teachers, q_conj(student)))
            angle_diff = q_diff[:, 0]
            student_loss = ((1.0 - angle_diff ** 2) * weights).sum()

        num_samples = min_loss.shape[0]
        student_loss = student_loss / num_samples

        if add_summary and summary_writer is not None:
            summary_writer.add_scalar("losses/pose_predictor_student_loss",
                                      student_loss, global_step)
        # The logged value above is intentionally the unweighted loss.
        return student_loss * cfg.pose_predictor_student_loss_weight
Пример #2
0
    def add_student_loss(self, inputs, outputs, min_loss, add_summary):
        """Distill the selected pose candidate (teacher) into the student head.

        For every batch element ``b`` the candidate ``outputs["poses"][b,
        min_loss[b]]`` is used as a fixed (stop-gradient) target for
        ``outputs["pose_student"][b]``.

        Args:
            inputs: dict of inputs; ``inputs["valid_samples"]`` is used as
                per-sample loss weights when ``cfg.variable_num_views`` is set.
            outputs: dict with ``"poses"`` (reshaped here to
                ``[-1, num_candidates, 4]``) and ``"pose_student"``.
            min_loss: 1-D tensor holding, per batch element, the index of the
                selected candidate (despite its name, used as indices).
            add_summary: whether to record the (unweighted) student loss via
                ``tf.contrib.summary.scalar``.

        Returns:
            Scalar tensor: the student loss divided by the batch size and
            scaled by ``cfg.pose_predictor_student_loss_weight``.
        """
        cfg = self.cfg()
        num_candidates = cfg.pose_predict_num_candidates

        student = outputs["pose_student"]
        teachers = tf.reshape(outputs["poses"], [-1, num_candidates, 4])

        # Prefer the static batch size (keeps static shape information), but
        # fall back to the dynamic one when the batch dimension is unknown at
        # graph-construction time; the static value is then None.
        batch_size = teachers.shape.as_list()[0]
        if batch_size is None:
            batch_size = tf.shape(teachers)[0]

        # Build [b, min_loss[b]] index pairs. Both columns are cast to int64
        # so concat/gather_nd also accept int32 indices.
        batch_indices = tf.cast(tf.range(batch_size), tf.int64)
        candidate_indices = tf.cast(min_loss, tf.int64)
        indices = tf.stack([batch_indices, candidate_indices], axis=1)
        # use teachers only as ground truth (no gradient flows into them)
        teachers = tf.stop_gradient(tf.gather_nd(teachers, indices))

        if cfg.variable_num_views:
            weights = inputs["valid_samples"]
        else:
            weights = 1.0

        if cfg.pose_student_align_loss:
            # Compare a reference point cloud rotated by teacher vs. student.
            ref_pc = self._pc_for_alignloss
            num_ref_points = ref_pc.shape.as_list()[0]
            ref_pc_all = tf.tile(tf.expand_dims(ref_pc, axis=0),
                                 [batch_size, 1, 1])
            pc_1 = q_rotate(ref_pc_all, teachers)
            pc_2 = q_rotate(ref_pc_all, student)
            student_loss = tf.nn.l2_loss(pc_1 - pc_2) / num_ref_points
        else:
            q_diff = q_norm(q_mul(teachers, q_conj(student)))
            angle_diff = q_diff[:, 0]
            student_loss = tf.reduce_sum(
                (1.0 - tf.square(angle_diff)) * weights)

        # One selected teacher per sample, so the sample count equals the
        # batch size. Cast to the loss dtype (the deprecated tf.to_float
        # always produced float32).
        student_loss /= tf.cast(batch_size, student_loss.dtype)

        if add_summary:
            # Records the unweighted loss.
            tf.contrib.summary.scalar("losses/pose_predictor_student_loss",
                                      student_loss)
        student_loss *= cfg.pose_predictor_student_loss_weight

        return student_loss
Пример #3
0
def align_predictions(outputs, alignment):
    """Apply a global alignment rotation to the predictions, in place.

    ``outputs["points_1"]`` is rotated by ``alignment`` with ``q_rotate``.
    Both ``outputs["poses"]`` and ``outputs["pose_student"]`` are
    right-multiplied by the conjugate of ``alignment`` with ``q_mul``.

    Args:
        outputs: dict holding the ``"points_1"``, ``"poses"`` and
            ``"pose_student"`` entries; it is modified in place.
        alignment: quaternion(s) passed to ``q_rotate`` / ``q_conj``.

    Returns:
        The same ``outputs`` dict object.
    """
    outputs["points_1"] = q_rotate(outputs["points_1"], alignment)
    for pose_key in ("poses", "pose_student"):
        outputs[pose_key] = q_mul(outputs[pose_key], q_conj(alignment))
    return outputs