Example #1
    def _train_triplet_step(self, anchor, pos, neg):
        with tf.GradientTape(persistent=False) as tape:
            # embed each leg of the triplet with the shared model
            anchor_emb = get_embeddings(self.model, anchor)
            pos_emb = get_embeddings(self.model, pos)
            neg_emb = get_embeddings(self.model, neg)

            loss = triplet_loss(anchor_emb, pos_emb, neg_emb, self.alpha)

        # back-propagate the triplet loss and update the model weights
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))

        return loss
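
The triplet_loss function called above is not shown in the example. Below is a minimal sketch of a standard margin-based triplet loss in TensorFlow, assuming the embeddings are batched float tensors and alpha is the margin; the body is an assumption, not the project's actual implementation.

import tensorflow as tf

def triplet_loss(anchor_emb, pos_emb, neg_emb, alpha):
    # squared Euclidean distances between anchor/positive and anchor/negative
    pos_dist = tf.reduce_sum(tf.square(anchor_emb - pos_emb), axis=-1)
    neg_dist = tf.reduce_sum(tf.square(anchor_emb - neg_emb), axis=-1)
    # hinge on the margin alpha and average over the batch
    return tf.reduce_mean(tf.maximum(pos_dist - neg_dist + alpha, 0.0))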
Example #2
    def _get_sim_label(self):
        sims = None
        labels = None
        for image1, image2, label in self.data:
            # embed both inputs and score the pair with cosine similarity
            emb1 = get_embeddings(self.model, image1)
            emb2 = get_embeddings(self.model, image2)
            sim = self._cal_cos_sim(emb1, emb2)
            # accumulate similarities and ground-truth labels across batches
            if sims is None:
                sims = sim
            else:
                sims = tf.concat([sims, sim], axis=0)

            if labels is None:
                labels = label
            else:
                labels = tf.concat([labels, label], axis=0)

        return sims, labels
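
_cal_cos_sim is referenced but not defined in this example. A plausible sketch, assuming both arguments are batches of embeddings with matching shapes (the body is an assumption): L2-normalize each embedding and take the row-wise dot product.

    def _cal_cos_sim(self, emb1, emb2):
        # L2-normalize so the dot product equals the cosine similarity
        emb1 = tf.math.l2_normalize(emb1, axis=-1)
        emb2 = tf.math.l2_normalize(emb2, axis=-1)
        return tf.reduce_sum(emb1 * emb2, axis=-1)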
Example #3
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Load the evaluation wav files and their labels
    wav_files, label_ids, label_to_id = get_file_and_labels(os.path.join(FLAGS.data_dir, 'eval_labels'))
    wav_files = [os.path.join(FLAGS.data_dir, 'eval', wav_file) for wav_file in wav_files]

    groups = _get_groups(os.path.join(FLAGS.data_dir, 'groups_config'))
    enrollments = get_enrollments(os.path.join(FLAGS.data_dir, 'enrollment_config'))
    to_be_verified = _get_to_be_verified(os.path.join(FLAGS.data_dir, 'verification_config'))
    to_be_identified = _get_to_be_identified(os.path.join(FLAGS.data_dir, 'identification_config'))
    file_id_to_index = _get_file_id_to_index(wav_files)
    # TODO validate configurations
    # transform the configurations: wav file id --> index, label_id --> label_index
    groups_transformed = dict()
    for group_id in groups:
        group = [label_to_id[i] for i in groups[group_id]]
        groups_transformed[group_id] = group
    groups = groups_transformed

    enrollments = [file_id_to_index[i] for i in enrollments]
    to_be_verified = [(file_id_to_index[i], label_to_id[j]) for i, j in to_be_verified]
    to_be_identified = [(file_id_to_index[i], group_id) for i, group_id in to_be_identified]
    model = create_model(
        model_dir=FLAGS.model_dir,
        params={
            **FLAGS.__dict__
        })

    embeddings = get_embeddings(model,
                                wav_files=wav_files,
                                **FLAGS.__dict__)

    registerations = get_registerations([embeddings[i] for i in enrollments],
                                        [label_ids[i] for i in enrollments])
    fa_rate, fr_rate, error_rate, threshold = _evaluate_verification(embeddings, label_ids, registerations,
                                                                     to_be_verified,
                                                                     FLAGS.threshold)

    eval_msg_template = 'false accept rate:{}\n' + \
                        'false reject rate:{}\n' + \
                        'error rate:{}\n' + \
                        'threshold:{}'

    tf.logging.info('verification performance')
    tf.logging.info(eval_msg_template.format(fa_rate, fr_rate, error_rate, threshold))

    # reuse the threshold corresponding to the verification EER for identification
    fa_rate, fr_rate, error_rate, threshold = _evaluate_identification(embeddings, label_ids, registerations,
                                                                       to_be_identified, groups, threshold)
    tf.logging.info('identification performance')
    tf.logging.info(eval_msg_template.format(fa_rate, fr_rate, error_rate, threshold))
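
get_registerations (spelled as in the source) builds the enrolled reference set from the enrollment embeddings. Its implementation is not part of this example; a hedged sketch, assuming it averages each speaker's enrollment embeddings into a single reference vector keyed by label:

import numpy as np

def get_registerations(embeddings, label_ids):
    # group enrollment embeddings by speaker label
    by_label = {}
    for emb, label in zip(embeddings, label_ids):
        by_label.setdefault(label, []).append(emb)
    # average each speaker's embeddings into one reference embedding
    return {label: np.mean(embs, axis=0) for label, embs in by_label.items()}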
Example #4
def _enroll(model, device_id, user_id, streams):
    _ensure_user_root_path(device_id, user_id)
    enrollment_filenames = _get_enrollment_filenames(device_id, user_id)
    files = []
    user_root_path = _get_user_root_path(device_id, user_id)
    for stream in streams:
        # write each raw PCM stream to a wav file under the user's root path
        output_file, output_filename = _save_pcm_stream_to_wav(
            user_root_path, stream)
        files.append(output_file)
        enrollment_filenames.append(output_filename)
    # compute and save embeddings
    embeddings = get_embeddings(model, files, FLAGS.desired_ms,
                                FLAGS.window_size_ms, FLAGS.window_stride_ms,
                                FLAGS.sample_rate, FLAGS.magnitude_squared,
                                FLAGS.dct_coefficient_count, FLAGS.batch_size)
    for i, file in enumerate(files):
        # store each embedding next to its wav file as a .npy file
        embedding_file = file + '.npy'
        np.save(embedding_file, embeddings[i])
    # update the enrollment config with the newly added filenames
    _save_enrollment_config(device_id, user_id, enrollment_filenames)
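
At verification time, the .npy files written above would typically be read back. A hypothetical loader (the helper name and signature are assumptions, not from the source):

import numpy as np

def _load_enrollment_embeddings(wav_paths):
    # reload the embeddings that _enroll saved next to each enrollment wav file
    return [np.load(path + '.npy') for path in wav_paths]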
Example #5
def _get_embedding(model, filepath):
    # embed a single file; the result is a list containing one embedding
    embeddings = get_embeddings(model, [filepath],
                                batch_size=1,
                                **FLAGS.__dict__)
    return embeddings[0]
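
A hypothetical usage sketch (the file names, the enrolled-embedding path, and the acceptance rule are assumptions): score a probe utterance against a saved enrollment embedding with cosine similarity and the configured threshold.

import numpy as np

probe_emb = _get_embedding(model, 'probe.wav')
enrolled_emb = np.load('enrollment.wav.npy')
# cosine similarity between the probe and the enrolled embedding
score = np.dot(probe_emb, enrolled_emb) / (
    np.linalg.norm(probe_emb) * np.linalg.norm(enrolled_emb))
accepted = score >= FLAGS.threshold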