Ejemplo n.º 1
0
def process_track(file_struct, boundaries_id, labels_id, config):
    """Segment one track and write its estimations to disk.

    Parameters
    ----------
    file_struct : FileStruct
        Object with the file paths of the current track (``ref_file``,
        ``audio_file``, ``features_file`` and ``est_file`` are used here).
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    config : dict
        Algorithm configuration. ``config["annot_beats"]`` enables the
        annotated-beats check; the dict is also forwarded to
        ``run_algorithms`` and ``io.save_estimations``.

    Returns
    -------
    tuple or None
        ``(est_times, est_labels)`` as returned by ``run_algorithms``, or
        ``None`` when the track is skipped for lacking annotated beats.
    """
    # Only analyze files with annotated beats
    if config["annot_beats"]:
        jam = jams2.load(file_struct.ref_file)
        has_beats = len(jam.beats) > 0 and len(jam.beats[0].data) > 0
        if not has_beats:
            logging.warning("No beat information in file %s",
                            file_struct.ref_file)
            return None

    logging.info("Segmenting %s", file_struct.audio_file)

    # Compute features if needed
    if not os.path.isfile(file_struct.features_file):
        featextract.compute_all_features(file_struct)

    # Get estimations
    est_times, est_labels = run_algorithms(file_struct.audio_file,
                                           boundaries_id, labels_id, config)

    # Save
    logging.info("Writing results in: %s", file_struct.est_file)
    # NOTE(review): boundaries are converted to intervals before being handed
    # to io.save_estimations; confirm the save_estimations in use expects
    # intervals and does not convert times to intervals a second time.
    est_inters = utils.times_to_intervals(est_times)
    io.save_estimations(file_struct.est_file, est_inters, est_labels,
                        boundaries_id, labels_id, **config)

    return est_times, est_labels
Ejemplo n.º 2
0
def plot_one_track(plot_name, file_struct, est_times, est_labels, boundaries_id, labels_id,
		ds_prefix, title=None):
	"""Plots the results of one track, with ground truth if it exists.

	Parameters
	----------
	plot_name : str
		Output name forwarded to ``_plot_formatting``.
	file_struct : FileStruct
		Object with the file paths of the track (``ref_file`` is read for the
		ground truth, the basename of ``audio_file`` is used in the plot).
	est_times : np.array
		Estimated boundary times.
	est_labels : np.array
		Estimated labels (only drawn when ``labels_id`` is not None).
	boundaries_id : str
		Boundary algorithm identifier.
	labels_id : str or None
		Labels algorithm identifier.
	ds_prefix : str
		Dataset prefix, looked up in ``msaf.prefix_dict`` to obtain the
		context of the reference sections ("function" if not found).
	title : str
		Title of the plot.
	"""
	# Get context
	context = msaf.prefix_dict.get(ds_prefix, "function")

	# Set up the boundaries id
	bid_lid = boundaries_id
	if labels_id is not None:
		bid_lid += " + " + labels_id
	try:
		# Read file
		ref_inter, ref_labels = jams2.converters.load_jams_range(
			file_struct.ref_file, "sections", annotator=0, context=context)

		# To times
		ref_times = utils.intervals_to_times(ref_inter)
		all_boundaries = [ref_times, est_times]
		all_labels = [ref_labels, est_labels]
		algo_ids = ["GT", bid_lid]
	except Exception:
		# Missing or unreadable reference: plot the estimation only
		logging.warning("No references found in %s. Not plotting groundtruth",
						file_struct.ref_file)
		all_boundaries = [est_times]
		all_labels = [est_labels]
		algo_ids = [bid_lid]

	N = len(all_boundaries)

	# Index the labels to normalize them
	all_labels = [mir_eval.util.index_labels(labels)[0] for labels in all_labels]

	# Get color map. The label index is divided by the largest index; clamp it
	# to 1 so a track whose labels all index to 0 does not divide by zero.
	cm = plt.get_cmap('gist_rainbow')
	max_label = max(1, max(max(labels) for labels in all_labels))

	figsize = (8, 4)
	plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
	for i, boundaries in enumerate(all_boundaries):
		ymin, ymax = i / float(N), (i + 1) / float(N)
		# First row in green (the ground truth when one was loaded), others blue
		color = "g" if i == 0 else "b"
		for b in boundaries:
			plt.axvline(b, ymin, ymax, color=color)
		if labels_id is not None:
			inters = utils.times_to_intervals(boundaries)
			for label, inter in zip(all_labels[i], inters):
				plt.axvspan(inter[0], inter[1], ymin=ymin, ymax=ymax,
							alpha=0.6, color=cm(label / float(max_label)))
		plt.axhline(ymin, color="k", linewidth=1)

	# Format plot
	_plot_formatting(title, os.path.basename(file_struct.audio_file), algo_ids,
					 all_boundaries[0][-1], N, plot_name)
Ejemplo n.º 3
0
def plot_labels(all_labels, gt_times, est_file, algo_ids=None, title=None,
                output_file=None):
    """Plots all the labels.

    Parameters
    ----------
    all_labels: list
        A list of np.arrays containing the labels of the boundaries, one array
        for each algorithm. The list itself is not modified.
    gt_times: np.array
        Array with the ground truth boundaries.
    est_file: str
        Path to the estimated file (JSON file)
    algo_ids : list
        List of algorithm ids to read boundaries from.
        If None, all algorithm ids are read. The list itself is not modified.
    title : str
        Title of the plot. If None, the name of the file is printed instead.
    output_file : str
        Output path forwarded to ``_plot_formatting``.
    """
    import matplotlib.pyplot as plt
    N = len(all_labels)  # Number of lists of labels
    if algo_ids is None:
        algo_ids = io.get_algo_ids(est_file)

    # Translate ids into a new list (the caller's list is left untouched)
    algo_ids = ["GT"] + [translate_ids[algo_id] for algo_id in algo_ids]

    # Index the labels to normalize them (new list, input left untouched)
    all_labels = [mir_eval.util.index_labels(labels)[0]
                  for labels in all_labels]

    # Get color map. The label index is divided by the largest index; clamp it
    # to 1 so that labels all indexing to 0 do not divide by zero.
    cm = plt.get_cmap('gist_rainbow')
    max_label = max(1, max(max(labels) for labels in all_labels))

    # To intervals
    gt_inters = utils.times_to_intervals(gt_times)

    # Plot labels, one horizontal band per list of labels
    figsize = (6, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, labels in enumerate(all_labels):
        ymin, ymax = i / float(N), (i + 1) / float(N)
        for label, inter in zip(labels, gt_inters):
            plt.axvspan(inter[0], inter[1], ymin=ymin, ymax=ymax, alpha=0.6,
                        color=cm(label / float(max_label)))
        plt.axhline(ymin, color="k", linewidth=1)

    # Draw the boundary lines
    for bound in gt_times:
        plt.axvline(bound, color="g")

    # Format plot
    _plot_formatting(title, est_file, algo_ids, gt_times[-1], N,
                     output_file)
Ejemplo n.º 4
0
def plot_labels(all_labels, gt_times, est_file, algo_ids=None, title=None,
                output_file=None):
    """Plots all the labels.

    Parameters
    ----------
    all_labels: list
        A list of np.arrays containing the labels of the boundaries, one array
        for each algorithm. The list itself is not modified.
    gt_times: np.array
        Array with the ground truth boundaries.
    est_file: str
        Path to the estimated file (JSON file)
    algo_ids : list
        List of algorithm ids to read boundaries from.
        If None, all algorithm ids are read. The list itself is not modified.
    title : str
        Title of the plot. If None, the name of the file is printed instead.
    output_file : str
        Output path forwarded to ``_plot_formatting``.
    """
    import matplotlib.pyplot as plt
    N = len(all_labels)  # Number of lists of labels
    if algo_ids is None:
        algo_ids = io.get_algo_ids(est_file)

    # Translate ids into a new list (the caller's list is left untouched)
    algo_ids = ["GT"] + [translate_ids[algo_id] for algo_id in algo_ids]

    # Index the labels to normalize them (new list, input left untouched)
    all_labels = [mir_eval.util.index_labels(labels)[0]
                  for labels in all_labels]

    # Get color map. The label index is divided by the largest index; clamp it
    # to 1 so that labels all indexing to 0 do not divide by zero.
    cm = plt.get_cmap('gist_rainbow')
    max_label = max(1, max(max(labels) for labels in all_labels))

    # To intervals
    gt_inters = utils.times_to_intervals(gt_times)

    # Plot labels, one horizontal band per list of labels
    figsize = (6, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, labels in enumerate(all_labels):
        ymin, ymax = i / float(N), (i + 1) / float(N)
        for label, inter in zip(labels, gt_inters):
            plt.axvspan(inter[0], inter[1], ymin=ymin, ymax=ymax, alpha=0.6,
                        color=cm(label / float(max_label)))
        plt.axhline(ymin, color="k", linewidth=1)

    # Draw the boundary lines
    for bound in gt_times:
        plt.axvline(bound, color="g")

    # Format plot
    _plot_formatting(title, est_file, algo_ids, gt_times[-1], N,
                     output_file)
Ejemplo n.º 5
0
def plot_one_track(file_struct, est_times, est_labels, boundaries_id, labels_id,
                   title=None, output_file=None):
    """Plots the results of one track, with ground truth if it exists.

    Parameters
    ----------
    file_struct : FileStruct
        Object with the file paths of the track (``ref_file`` is read for the
        ground truth, the basename of ``audio_file`` is used in the plot).
    est_times : np.array
        Estimated boundary times.
    est_labels : np.array
        Estimated labels (only drawn when ``labels_id`` is not None).
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str or None
        Labels algorithm identifier.
    title : str
        Title of the plot.
    output_file : str
        Output path forwarded to ``_plot_formatting``.
    """
    import matplotlib.pyplot as plt
    # Set up the boundaries id
    bid_lid = boundaries_id
    if labels_id is not None:
        bid_lid += " + " + labels_id
    try:
        # Read file: the first annotation in a "segment_*" namespace
        jam = jams.load(file_struct.ref_file)
        ann = jam.search(namespace='segment_.*')[0]
        ref_inters, ref_labels = ann.to_interval_values()

        # To times
        ref_times = utils.intervals_to_times(ref_inters)
        all_boundaries = [ref_times, est_times]
        all_labels = [ref_labels, est_labels]
        algo_ids = ["GT", bid_lid]
    except Exception:
        # Missing or unreadable reference: plot the estimation only
        logging.warning("No references found in %s. Not plotting groundtruth",
                        file_struct.ref_file)
        all_boundaries = [est_times]
        all_labels = [est_labels]
        algo_ids = [bid_lid]

    N = len(all_boundaries)

    # Index the labels to normalize them
    all_labels = [mir_eval.util.index_labels(labels)[0]
                  for labels in all_labels]

    # Get color map. The label index is divided by the largest index; clamp it
    # to 1 so a track whose labels all index to 0 does not divide by zero.
    cm = plt.get_cmap('gist_rainbow')
    max_label = max(1, max(max(labels) for labels in all_labels))

    figsize = (8, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, boundaries in enumerate(all_boundaries):
        ymin, ymax = i / float(N), (i + 1) / float(N)
        # First row in green (the ground truth when one was loaded), others blue
        color = "g" if i == 0 else "b"
        for b in boundaries:
            plt.axvline(b, ymin, ymax, color=color)
        if labels_id is not None:
            inters = utils.times_to_intervals(boundaries)
            for label, inter in zip(all_labels[i], inters):
                plt.axvspan(inter[0], inter[1], ymin=ymin, ymax=ymax,
                            alpha=0.6, color=cm(label / float(max_label)))
        plt.axhline(ymin, color="k", linewidth=1)

    # Format plot
    _plot_formatting(title, os.path.basename(file_struct.audio_file), algo_ids,
                     all_boundaries[0][-1], N, output_file)
Ejemplo n.º 6
0
def plot_one_track(file_struct, est_times, est_labels, boundaries_id, labels_id,
                   title=None, output_file=None):
    """Plots the results of one track, with ground truth if it exists.

    Parameters
    ----------
    file_struct : FileStruct
        Object with the file paths of the track (``ref_file`` is read for the
        ground truth, the basename of ``audio_file`` is used in the plot).
    est_times : np.array
        Estimated boundary times.
    est_labels : np.array
        Estimated labels (only drawn when ``labels_id`` is not None).
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str or None
        Labels algorithm identifier.
    title : str
        Title of the plot.
    output_file : str or None
        Output path forwarded to ``_plot_formatting`` (None by default, as
        before this parameter existed).
    """
    import matplotlib.pyplot as plt
    # Set up the boundaries id
    bid_lid = boundaries_id
    if labels_id is not None:
        bid_lid += " + " + labels_id
    try:
        # Read file: the first annotation in a "segment_*" namespace
        jam = jams.load(file_struct.ref_file)
        ann = jam.search(namespace='segment_.*')[0]
        ref_inters, ref_labels = ann.to_interval_values()

        # To times
        ref_times = utils.intervals_to_times(ref_inters)
        all_boundaries = [ref_times, est_times]
        all_labels = [ref_labels, est_labels]
        algo_ids = ["GT", bid_lid]
    except Exception:
        # Missing or unreadable reference: plot the estimation only
        logging.warning("No references found in %s. Not plotting groundtruth",
                        file_struct.ref_file)
        all_boundaries = [est_times]
        all_labels = [est_labels]
        algo_ids = [bid_lid]

    N = len(all_boundaries)

    # Index the labels to normalize them
    all_labels = [mir_eval.util.index_labels(labels)[0]
                  for labels in all_labels]

    # Get color map. The label index is divided by the largest index; clamp it
    # to 1 so a track whose labels all index to 0 does not divide by zero.
    cm = plt.get_cmap('gist_rainbow')
    max_label = max(1, max(max(labels) for labels in all_labels))

    figsize = (8, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, boundaries in enumerate(all_boundaries):
        ymin, ymax = i / float(N), (i + 1) / float(N)
        # First row in green (the ground truth when one was loaded), others blue
        color = "g" if i == 0 else "b"
        for b in boundaries:
            plt.axvline(b, ymin, ymax, color=color)
        if labels_id is not None:
            inters = utils.times_to_intervals(boundaries)
            for label, inter in zip(all_labels[i], inters):
                plt.axvspan(inter[0], inter[1], ymin=ymin, ymax=ymax,
                            alpha=0.6, color=cm(label / float(max_label)))
        plt.axhline(ymin, color="k", linewidth=1)

    # Format plot
    _plot_formatting(title, os.path.basename(file_struct.audio_file), algo_ids,
                     all_boundaries[0][-1], N, output_file)
Ejemplo n.º 7
0
def compute_gt_results(est_file,
                       ref_file,
                       boundaries_id,
                       labels_id,
                       config,
                       bins=251,
                       annotator_id=0):
    """Computes the results by using the ground truth dataset identified by
    the annotator parameter.

    Parameters
    ----------
    est_file : str
        Path to the estimations file.
    ref_file : str
        Path to the reference JAMS file.
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    config : dict
        Configuration forwarded to ``io.read_estimations``;
        ``config["hier"]`` selects hierarchical or flat evaluation.
    bins : int
        Forwarded to ``compute_results`` (flat evaluation only).
    annotator_id : int
        Index of the reference annotation to evaluate against (used in both
        the flat and the hierarchical evaluation).

    Return
    ------
    results : dict
        Dictionary of the results (see function compute_results). Empty if
        the references or the estimations could not be read.
    """
    try:
        if config["hier"]:
            ref_times, ref_labels, _ = \
                msaf.io.read_hier_references(
                    ref_file, annotation_id=annotator_id,
                    exclude_levels=["segment_salami_function"])
        else:
            jam = jams.load(ref_file, validate=False)
            ann = jam.search(namespace='segment_.*')[annotator_id]
            ref_inter, ref_labels = ann.data.to_interval_values()
    except Exception:
        logging.warning("No references for file: %s", ref_file)
        return {}

    # Read estimations with correct configuration
    est_inter, est_labels = io.read_estimations(est_file, boundaries_id,
                                                labels_id, **config)
    if len(est_inter) == 0:
        logging.warning("No estimations for file: %s", est_file)
        return {}

    # Compute the results and return
    logging.info("Evaluating %s", os.path.basename(est_file))
    if not config["hier"]:
        # Flat
        return compute_results(ref_inter, est_inter, ref_labels, est_labels,
                               bins, est_file)

    # Hierarchical (labels are not used by this evaluation --yet--)
    assert len(est_inter) == len(est_labels), "Same number of levels " \
        "are required in the boundaries and labels for the hierarchical " \
        "evaluation."

    # Sort levels by how many segments they have (fewest first), then to times
    est_inter = sorted(est_inter, key=len)
    est_times = [msaf.utils.intervals_to_times(inter) for inter in est_inter]

    # Align the times
    utils.align_end_hierarchies(est_times, ref_times, thres=1)

    # To intervals
    est_hier = [utils.times_to_intervals(times) for times in est_times]
    ref_hier = [utils.times_to_intervals(times) for times in ref_times]

    # Compute evaluations. mir_eval's tmeasure returns
    # (precision, recall, measure), in that order.
    res = {}
    res["t_precision10"], res["t_recall10"], res["t_measure10"] = \
        mir_eval.hierarchy.tmeasure(ref_hier, est_hier, window=10)
    res["t_precision15"], res["t_recall15"], res["t_measure15"] = \
        mir_eval.hierarchy.tmeasure(ref_hier, est_hier, window=15)

    # Track id: file name without its 5-character extension (e.g. ".jams")
    res["track_id"] = os.path.basename(est_file)[:-5]
    return res
Ejemplo n.º 8
0
Archivo: eval.py Proyecto: beckgom/msaf
def compute_gt_results(est_file, ref_file, boundaries_id, labels_id, config, bins=251, annotator_id=0):
    """Computes the results by using the ground truth dataset identified by
    the annotator parameter.

    Parameters
    ----------
    est_file : str
        Path to the estimations file.
    ref_file : str
        Path to the reference JAMS file.
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    config : dict
        Configuration forwarded to ``io.read_estimations``;
        ``config["hier"]`` selects hierarchical or flat evaluation.
    bins : int
        Forwarded to ``compute_results`` (flat evaluation only).
    annotator_id : int
        Index of the reference annotation to evaluate against (used in both
        the flat and the hierarchical evaluation).

    Return
    ------
    results : dict
        Dictionary of the results (see function compute_results). Empty if
        the references or the estimations could not be read.
    """
    try:
        if config["hier"]:
            ref_times, ref_labels, _ = msaf.io.read_hier_references(
                ref_file, annotation_id=annotator_id, exclude_levels=["segment_salami_function"]
            )
        else:
            jam = jams.load(ref_file, validate=False)
            ann = jam.search(namespace="segment_.*")[annotator_id]
            ref_inter, ref_labels = ann.data.to_interval_values()
    except Exception:
        logging.warning("No references for file: %s", ref_file)
        return {}

    # Read estimations with correct configuration
    est_inter, est_labels = io.read_estimations(est_file, boundaries_id, labels_id, **config)
    if len(est_inter) == 0:
        logging.warning("No estimations for file: %s", est_file)
        return {}

    # Compute the results and return
    logging.info("Evaluating %s", os.path.basename(est_file))
    if not config["hier"]:
        # Flat
        return compute_results(ref_inter, est_inter, ref_labels, est_labels, bins, est_file)

    # Hierarchical (labels are not used by this evaluation --yet--)
    assert len(est_inter) == len(est_labels), (
        "Same number of levels " "are required in the boundaries and labels for the hierarchical " "evaluation."
    )

    # Sort levels by how many segments they have (fewest first), then to times
    est_inter = sorted(est_inter, key=len)
    est_times = [msaf.utils.intervals_to_times(inter) for inter in est_inter]

    # Align the times
    utils.align_end_hierarchies(est_times, ref_times, thres=1)

    # To intervals
    est_hier = [utils.times_to_intervals(times) for times in est_times]
    ref_hier = [utils.times_to_intervals(times) for times in ref_times]

    # Compute evaluations. mir_eval's tmeasure returns (precision, recall, measure), in that order.
    res = {}
    res["t_precision10"], res["t_recall10"], res["t_measure10"] = mir_eval.hierarchy.tmeasure(
        ref_hier, est_hier, window=10
    )
    res["t_precision15"], res["t_recall15"], res["t_measure15"] = mir_eval.hierarchy.tmeasure(
        ref_hier, est_hier, window=15
    )

    # Track id: file name without its 5-character extension (e.g. ".jams")
    res["track_id"] = os.path.basename(est_file)[:-5]
    return res
Ejemplo n.º 9
0
def save_estimations(file_struct, times, labels, boundaries_id, labels_id,
                     **params):
    """Saves the segment estimations in a JAMS file.

    An estimation already stored in ``file_struct.est_file`` that
    ``find_estimation`` matches is overwritten; otherwise a new annotation is
    appended, creating the file if it does not exist.

    Parameters
    ----------
    file_struct : FileStruct
        Object with the different file paths of the current file
        (``features_file`` and ``est_file`` are used here).
    times : np.array or list
        Estimated boundary times.
        If `list`, estimated hierarchical boundaries (one array per level).
    labels : np.array or list or None
        Estimated labels, one per boundary interval (one array per level in
        the hierarchical case). None in case we are only storing boundary
        estimations (a single level may also be None in the hierarchical
        case); the placeholder label -1 is then used for those segments.
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    params : dict
        Dictionary with additional parameters for both algorithms.
        ``params["hier"]`` selects the "multi_segment" namespace (otherwise
        "segment_open"). A "features" entry is dropped before the parameters
        are stored in the annotation sandbox.
    """
    # Remove features if they exist
    params.pop("features", None)

    # Get duration
    dur = get_duration(file_struct.features_file)

    # Convert to intervals. Everything is kept as a list of levels (a single
    # one for flat estimations) to simplify the writing process later.
    if isinstance(times, np.ndarray):
        inters = [utils.times_to_intervals(times)]
        labels = [labels]
    else:
        inters = [utils.times_to_intervals(level_times)
                  for level_times in times]
        # Copy so the caller's list of labels is never modified
        labels = [None] * len(inters) if labels is None else list(labels)

    # Fill in placeholder labels where missing, then sanity check
    for level, level_inters in enumerate(inters):
        if labels[level] is None:
            labels[level] = np.ones(len(level_inters)) * -1
        assert len(level_inters) == len(labels[level]), \
            "Number of boundary intervals (%d) and labels (%d) do not " \
            "match" % (len(level_inters), len(labels[level]))

    # Create new estimation
    namespace = "multi_segment" if params["hier"] else "segment_open"
    ann = jams.Annotation(namespace=namespace)

    # Find estimation in file
    if os.path.isfile(file_struct.est_file):
        jam = jams.load(file_struct.est_file, validate=False)
        curr_ann = find_estimation(jam, boundaries_id, labels_id, params)
        if curr_ann is not None:
            curr_ann.data = ann.data  # cleanup all data
            ann = curr_ann  # This will overwrite the existing estimation
        else:
            jam.annotations.append(ann)
    else:
        # Create new JAMS if it doesn't exist
        jam = jams.JAMS()
        jam.file_metadata.duration = dur
        jam.annotations.append(ann)

    # Save metadata and parameters
    ann.annotation_metadata.version = msaf.__version__
    ann.annotation_metadata.data_source = "MSAF"
    sandbox = {}
    sandbox["boundaries_id"] = boundaries_id
    sandbox["labels_id"] = labels_id
    sandbox["timestamp"] = \
        datetime.datetime.today().strftime("%Y/%m/%d %H:%M:%S")
    sandbox.update(params)
    ann.sandbox = sandbox

    # Save actual data. Numeric labels are written as letters
    # (0 -> "A", 1 -> "B", ...; the -1 placeholder becomes "@").
    for i, (level_inters, level_labels) in enumerate(zip(inters, labels)):
        for bound_inter, label in zip(level_inters, level_labels):
            seg_dur = float(bound_inter[1]) - float(bound_inter[0])
            label = chr(int(label) + 65)
            if params["hier"]:
                value = {"label": label, "level": i}
            else:
                value = label
            ann.append(time=bound_inter[0], duration=seg_dur, value=value)

    # Write results
    jam.save(file_struct.est_file)
Ejemplo n.º 10
0
def save_estimations(file_struct, times, labels, boundaries_id, labels_id, **params):
    """Saves the segment estimations in a JAMS file.

    An estimation already stored in ``file_struct.est_file`` that
    ``find_estimation`` matches is overwritten; otherwise a new annotation is
    appended, creating the file if it does not exist.

    Parameters
    ----------
    file_struct : FileStruct
        Object with the different file paths of the current file
        (``features_file`` and ``est_file`` are used here).
    times : np.array or list
        Estimated boundary times.
        If `list`, estimated hierarchical boundaries (one array per level).
    labels : np.array or list or None
        Estimated labels, one per boundary interval (one array per level in
        the hierarchical case). None in case we are only storing boundary
        estimations (a single level may also be None in the hierarchical
        case); the placeholder label -1 is then used for those segments.
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    params : dict
        Dictionary with additional parameters for both algorithms.
        ``params["hier"]`` selects the "multi_segment" namespace (otherwise
        "segment_open"). A "features" entry is dropped before the parameters
        are stored in the annotation sandbox.
    """
    # Remove features if they exist
    params.pop("features", None)

    # Get duration
    dur = get_duration(file_struct.features_file)

    # Convert to intervals. Everything is kept as a list of levels (a single
    # one for flat estimations) to simplify the writing process later.
    if isinstance(times, np.ndarray):
        inters = [utils.times_to_intervals(times)]
        labels = [labels]
    else:
        inters = [utils.times_to_intervals(level_times) for level_times in times]
        # Copy so the caller's list of labels is never modified
        labels = [None] * len(inters) if labels is None else list(labels)

    # Fill in placeholder labels where missing, then sanity check
    for level, level_inters in enumerate(inters):
        if labels[level] is None:
            labels[level] = np.ones(len(level_inters)) * -1
        assert len(level_inters) == len(labels[level]), (
            "Number of boundary intervals (%d) and labels (%d) do not "
            "match" % (len(level_inters), len(labels[level]))
        )

    # Create new estimation
    namespace = "multi_segment" if params["hier"] else "segment_open"
    ann = jams.Annotation(namespace=namespace)

    # Find estimation in file
    if os.path.isfile(file_struct.est_file):
        jam = jams.load(file_struct.est_file, validate=False)
        curr_ann = find_estimation(jam, boundaries_id, labels_id, params)
        if curr_ann is not None:
            curr_ann.data = ann.data  # cleanup all data
            ann = curr_ann  # This will overwrite the existing estimation
        else:
            jam.annotations.append(ann)
    else:
        # Create new JAMS if it doesn't exist
        jam = jams.JAMS()
        jam.file_metadata.duration = dur
        jam.annotations.append(ann)

    # Save metadata and parameters
    ann.annotation_metadata.version = msaf.__version__
    ann.annotation_metadata.data_source = "MSAF"
    sandbox = {}
    sandbox["boundaries_id"] = boundaries_id
    sandbox["labels_id"] = labels_id
    sandbox["timestamp"] = datetime.datetime.today().strftime("%Y/%m/%d %H:%M:%S")
    sandbox.update(params)
    ann.sandbox = sandbox

    # Save actual data. Only the label itself is converted to text: in the
    # hierarchical case the value must stay a {"label", "level"} mapping
    # rather than the string form of that dict.
    for i, (level_inters, level_labels) in enumerate(zip(inters, labels)):
        for bound_inter, label in zip(level_inters, level_labels):
            seg_dur = float(bound_inter[1]) - float(bound_inter[0])
            label_text = six.text_type(int(label))
            if params["hier"]:
                value = {"label": label_text, "level": i}
            else:
                value = label_text
            ann.append(time=bound_inter[0], duration=seg_dur, value=value)

    # Write results
    jam.save(file_struct.est_file)
Ejemplo n.º 11
0
def save_estimations(out_file, times, labels, boundaries_id, labels_id,
                     **params):
    """Saves the segment estimations in a JAMS file.

    If ``out_file`` exists and ``find_estimation`` locates a matching
    estimation, that estimation is replaced; otherwise a new section
    annotation is created (creating the file if needed).

    Parameters
    ----------
    out_file : str
        Path to the output JAMS file in which to save the estimations.
    times : np.array or list
        Estimated boundary times.
        If `list`, estimated hierarchical boundaries (one array per level).
    labels : np.array or list or None
        Estimated labels, one per boundary interval (one array per level in
        the hierarchical case). None in case we are only storing boundary
        estimations (a single level may also be None in the hierarchical
        case); the placeholder label -1 is then used for those segments.
    boundaries_id : str
        Boundary algorithm identifier.
    labels_id : str
        Labels algorithm identifier.
    params : dict
        Dictionary with additional parameters for both algorithms.
    """
    # Convert to intervals. Everything is kept as a list of levels (a single
    # one for flat estimations) to simplify the writing process later.
    if isinstance(times, np.ndarray):
        inters = [utils.times_to_intervals(times)]
        labels = [labels]
    else:
        inters = [utils.times_to_intervals(level_times)
                  for level_times in times]
        # Copy so the caller's list of labels is never modified
        labels = [None] * len(inters) if labels is None else list(labels)

    # Fill in placeholder labels where missing, then sanity check
    for level, level_inters in enumerate(inters):
        if labels[level] is None:
            labels[level] = np.ones(len(level_inters)) * -1
        assert len(level_inters) == len(labels[level]), \
            "Number of boundary intervals (%d) and labels (%d) do not " \
            "match" % (len(level_inters), len(labels[level]))

    curr_estimation = None
    curr_i = -1

    # Find estimation in file
    if os.path.isfile(out_file):
        jam = jams2.load(out_file)
        all_estimations = jam.sections
        curr_estimation, curr_i = find_estimation(all_estimations,
                                                  boundaries_id, labels_id,
                                                  params, out_file)
    else:
        # Create new JAMS if it doesn't exist
        jam = jams2.Jams()
        jam.metadata.title = os.path.basename(out_file).replace(
            msaf.Dataset.estimations_ext, "")

    # Create new annotation if needed
    if curr_estimation is None:
        curr_estimation = jam.sections.create_annotation()

    # Save metadata and parameters
    curr_estimation.annotation_metadata.attribute = "sections"
    curr_estimation.annotation_metadata.version = msaf.__version__
    curr_estimation.annotation_metadata.origin = "MSAF"
    sandbox = {}
    sandbox["boundaries_id"] = boundaries_id
    sandbox["labels_id"] = labels_id
    sandbox["timestamp"] = \
        datetime.datetime.today().strftime("%Y/%m/%d %H:%M:%S")
    sandbox.update(params)
    curr_estimation.sandbox = sandbox

    # Save actual data (any previous datapoints are discarded); the level of
    # each segment is recorded in its label context.
    curr_estimation.data = []
    for i, (level_inters, level_labels) in enumerate(zip(inters, labels)):
        for bound_inter, label in zip(level_inters, level_labels):
            segment = curr_estimation.create_datapoint()
            segment.start.value = float(bound_inter[0])
            segment.start.confidence = 0.0
            segment.end.value = float(bound_inter[1])
            segment.end.confidence = 0.0
            segment.label.value = label
            segment.label.confidence = 0.0
            segment.label.context = "level_%d" % i

    # Place estimation in its place
    if curr_i != -1:
        jam.sections[curr_i] = curr_estimation

    # Write file and do not let users interrupt it
    my_thread = Thread(target=safe_write, args=(jam, out_file))
    my_thread.start()
    my_thread.join()
Ejemplo n.º 12
0
def save_estimations(out_file, times, labels, boundaries_id, labels_id,
		**params):
	"""Saves the segment estimations in a JAMS file.

	If ``out_file`` exists and ``find_estimation`` locates a matching
	estimation, that estimation is replaced; otherwise a new section
	annotation is created (creating the file if needed).

	Parameters
	----------
	out_file : str
		Path to the output JAMS file in which to save the estimations.
	times : np.array or list
		Estimated boundary times.
		If `list`, estimated hierarchical boundaries (one array per level).
	labels : np.array or list or None
		Estimated labels, one per boundary interval (one array per level in
		the hierarchical case). None in case we are only storing boundary
		estimations (a single level may also be None in the hierarchical
		case); the placeholder label -1 is then used for those segments.
	boundaries_id : str
		Boundary algorithm identifier.
	labels_id : str
		Labels algorithm identifier.
	params : dict
		Dictionary with additional parameters for both algorithms.
	"""
	# Convert to intervals. Everything is kept as a list of levels (a single
	# one for flat estimations) to simplify the writing process later.
	if isinstance(times, np.ndarray):
		inters = [utils.times_to_intervals(times)]
		labels = [labels]
	else:
		inters = [utils.times_to_intervals(level_times)
				  for level_times in times]
		# Copy so the caller's list of labels is never modified
		labels = [None] * len(inters) if labels is None else list(labels)

	# Fill in placeholder labels where missing, then sanity check
	for level, level_inters in enumerate(inters):
		if labels[level] is None:
			labels[level] = np.ones(len(level_inters)) * -1
		assert len(level_inters) == len(labels[level]), \
			"Number of boundary intervals (%d) and labels (%d) do not " \
			"match" % (len(level_inters), len(labels[level]))

	curr_estimation = None
	curr_i = -1

	# Find estimation in file
	if os.path.isfile(out_file):
		jam = jams2.load(out_file)
		all_estimations = jam.sections
		curr_estimation, curr_i = find_estimation(
			all_estimations, boundaries_id, labels_id, params, out_file)
	else:
		# Create new JAMS if it doesn't exist
		jam = jams2.Jams()
		jam.metadata.title = os.path.basename(out_file).replace(
			msaf.Dataset.estimations_ext, "")

	# Create new annotation if needed
	if curr_estimation is None:
		curr_estimation = jam.sections.create_annotation()

	# Save metadata and parameters
	curr_estimation.annotation_metadata.attribute = "sections"
	curr_estimation.annotation_metadata.version = msaf.__version__
	curr_estimation.annotation_metadata.origin = "MSAF"
	sandbox = {}
	sandbox["boundaries_id"] = boundaries_id
	sandbox["labels_id"] = labels_id
	sandbox["timestamp"] = \
		datetime.datetime.today().strftime("%Y/%m/%d %H:%M:%S")
	sandbox.update(params)
	curr_estimation.sandbox = sandbox

	# Save actual data (any previous datapoints are discarded); the level of
	# each segment is recorded in its label context.
	curr_estimation.data = []
	for i, (level_inters, level_labels) in enumerate(zip(inters, labels)):
		for bound_inter, label in zip(level_inters, level_labels):
			segment = curr_estimation.create_datapoint()
			segment.start.value = float(bound_inter[0])
			segment.start.confidence = 0.0
			segment.end.value = float(bound_inter[1])
			segment.end.confidence = 0.0
			segment.label.value = label
			segment.label.confidence = 0.0
			segment.label.context = "level_%d" % i

	# Place estimation in its place
	if curr_i != -1:
		jam.sections[curr_i] = curr_estimation

	# Write file and do not let users interrupt it
	my_thread = Thread(target=safe_write, args=(jam, out_file))
	my_thread.start()
	my_thread.join()