示例#1
0
def validate(config_file):
    """Run every validation sample through the model once.

    The file at ``config_file`` is parsed with ``parse_configuration``.
    Its ``val_dataset_params`` entry is handed to ``create_dataset`` and its
    ``model_params`` entry to ``create_model``. After ``setup()`` and
    ``eval()`` have been called on the model, each item yielded by the
    dataset is passed to ``set_input`` followed by ``test``. The model's
    ``pre_epoch_callback`` / ``post_epoch_callback`` are invoked before and
    after the loop with the configured ``load_checkpoint`` value.

    Progress messages are printed to standard output; nothing is returned.
    """
    print('Reading config file...')
    settings = parse_configuration(config_file)

    print('Initializing dataset...')
    samples = create_dataset(settings['val_dataset_params'])
    print('The number of validation samples = {0}'.format(len(samples)))

    print('Initializing model...')
    net = create_model(settings['model_params'])
    net.setup()
    net.eval()

    net.pre_epoch_callback(settings['model_params']['load_checkpoint'])

    for batch in samples:
        # Hand the loader output to the model, then run inference on it.
        net.set_input(batch)
        net.test()

    net.post_epoch_callback(settings['model_params']['load_checkpoint'])
示例#2
0
文件: main.py 项目: praveena2j/WSDAOR
# from losses.center_loss import CenterLoss
from losses.LabelSmoothing import LSR

### Importing model libraries
from models.pytorch_i3d_new import InceptionI3d
from models.I3DWSDDA import I3D_WSDDA
#from models.VGG_inflated import VggFace

# Command-line interface: one optional argument, the configuration file path.
_parser = argparse.ArgumentParser(description='DomainAdaptation')
_parser.add_argument(
    '-c', '--config',
    type=str,
    default=None,
    help='config file path (default: None)',
)
args = _parser.parse_args()
configuration = parse_configuration(args.config)

# Visdom-based plotting is currently disabled:
# plotter = utils.exp_utils.VisdomLinePlotter(env_name='praveen_Plots',port=8051)
# vis = Visdom()

# Module-level lists; presumably filled with test metrics by code outside
# this section -- TODO confirm.
TestError, TestAccuracy = [], []
SEED = configuration['SEED']

# The log file name is "<path><timestamp><configured suffix>".
ts = time.time()
path = "MILExperiments/"
Logfile_name = "".join((path, str(ts), configuration['Logfilename']))

# Start from a clean log file if one with this exact name already exists.
if os.path.isfile(Logfile_name):
    os.remove(Logfile_name)
示例#3
0
def main():
    """Print per-method overlap statistics and save a Venn diagram.

    Command-line entry point. The configuration file is parsed with
    ``parse_configuration``; for every (method, aligner) pair the statistics
    files listed there are handed to ``parse_refmaps`` in a worker pool, and
    the three collections it returns are stored under the keys "full",
    "missing" and "fusion". An additional "base" entry is computed once,
    from the first pair for which statistics are available.

    Summary tables are printed to standard output, then the Venn diagram
    for the category chosen with ``--type`` is saved as ``<out>.<format>``.

    Exits with status 0, after a message on standard error, when
    ``len(sets)`` (which counts the "base" entry) is greater than six.
    """

    parser = argparse.ArgumentParser("Script to create the Venn Plots")
    parser.add_argument("-t",
                        "--type",
                        choices=["missing", "full", "fusion"],
                        required=True)
    parser.add_argument("-c",
                        "--configuration",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument("-em",
                        "--exclude-mikado",
                        dest="exclude",
                        action="store_true",
                        default=False,
                        help="Flag. If set, Mikado results will be excluded")
    parser.add_argument("-o",
                        "--out",
                        type=str,
                        help="Output file",
                        required=True)
    parser.add_argument("--format",
                        choices=["svg", "tiff", "png"],
                        default="svg")
    parser.add_argument(
        "-a",
        "--aligner",  # choices=["STAR", "TopHat"],
        required=True,
        nargs="+")
    parser.add_argument(
        "--transcripts",
        action="store_true",
        default=False,
        help="Flag. If set, Venn plotted against transcripts, not genes.")
    parser.add_argument("--dpi", default=300, type=int)
    parser.add_argument("--title", default="Venn Diagram")
    parser.add_argument("--procs", default=1, type=int)
    args = parser.parse_args()

    options = parse_configuration(args.configuration,
                                  exclude_mikado=args.exclude)

    # One entry per "method\n(aligner)" label, in configuration order, plus
    # the special "base" entry appended last.
    sets = OrderedDict.fromkeys([
        "{}\n({})".format(*_)
        for _ in itertools.product(options["methods"], args.aligner)
    ])
    for k in list(sets.keys()) + ["base"]:
        sets[k] = {"full": set(), "fusion": set(), "missing": set()}

    proxies = dict.fromkeys(sets)
    base_submitted = False

    # The context manager guarantees the worker processes are released even
    # if one of the jobs raises; every result is collected before leaving it.
    with multiprocessing.Pool(processes=args.procs) as pool:
        for method in options["methods"]:
            for aligner in args.aligner:
                if options["methods"][method][aligner] is not None:
                    orig_stats, filtered_stats = options["methods"][method][
                        aligner][:2]
                else:
                    orig_stats, filtered_stats = None, None
                if not base_submitted and orig_stats is not None:
                    # Submit the "base" job only once: the flag used to stay
                    # set forever, re-queuing it for every pair.
                    proxies["base"] = pool.apply_async(
                        parse_refmaps,
                        (orig_stats, filtered_stats, args.transcripts, True,
                         True))
                    base_submitted = True
                key = "{}\n({})".format(method, aligner)
                proxies[key] = pool.apply_async(
                    parse_refmaps,
                    (orig_stats, filtered_stats, args.transcripts, True,
                     False))

        for proxy in proxies:
            (sets[proxy]["full"],
             sets[proxy]["missing"],
             sets[proxy]["fusion"]) = proxies[proxy].get()

    print("Loaded RefMaps.", file=sys.stderr)

    # Now use intervene venn

    categories = ["full", "missing", "fusion"]
    labels = dict()
    sums = dict()
    for typ in categories:
        labels[typ] = ivenn.get_labels([list(sets[_][typ]) for _ in sets])
        sums[typ] = dict.fromkeys(range(len(sets) + 1), 0)
        for label in labels[typ]:
            # A label is a string of 0/1 digits; its digit sum minus one is
            # the bucket that receives the label's count.
            bucket = sum(int(_) for _ in label) - 1
            sums[typ][bucket] += int(labels[typ][label])

    print("Sums")
    for num in range(len(sets)):
        print(num, *[sums[_][num] for _ in categories])

    print("Per method")
    plotted = [_ for _ in sets if _ != "base"]
    all_reconstructable = set.union(*[sets[_]["full"] for _ in plotted])

    for ds in plotted:
        missed = len(all_reconstructable - sets[ds]["full"])
        if all_reconstructable:
            missed_pct = round(100 * missed / len(all_reconstructable), 2)
        else:
            # Nothing was reconstructed by any method: avoid dividing by zero.
            missed_pct = 0.0
        row = [" ".join(ds.split("\n")),
               len(sets[ds]["full"]),
               missed,
               missed_pct]
        print(*row, sep="\t")

    # NOTE(review): len(sets) includes "base" while the plot below uses
    # len(sets) - 1 sets, so six real sets are rejected here although venn6
    # exists -- confirm whether that is intended.
    if len(sets) > 6:
        print("Too many sets to intersect ({}), exiting.".format(len(sets)),
              file=sys.stderr)
        sys.exit(0)

    funcs = {
        2: ivenn.venn2,
        3: ivenn.venn3,
        4: ivenn.venn4,
        5: ivenn.venn5,
        6: ivenn.venn6,
    }

    # Recalculate labels without the base
    labels = dict()
    for typ in categories:
        labels[typ] = ivenn.get_labels([list(sets[_][typ]) for _ in plotted])

    if options["colourmap"]["use"] is True:
        color_normalizer = matplotlib.colors.Normalize(0,
                                                       len(options["methods"]))
        color_map = cm.get_cmap(options["colourmap"]["name"])
        cols = [
            color_map(color_normalizer(index))
            for index in range(len(options["methods"]))
        ]
    else:
        # Colours given as "(r, g, b)" text are converted to "#rrggbbaa".
        rgb_pattern = re.compile(r"\(([0-9]*), ([0-9]*), ([0-9]*)\)$")
        cols = [options["methods"][_]["colour"] for _ in options["methods"]]
        for index, colour in enumerate(cols):
            matched = rgb_pattern.match(colour)
            if matched:
                nums = tuple(int(_) for _ in matched.groups())
                if nums == (255, 255, 255):  # Pure white
                    nums = (125, 125, 125)
                cols[index] = "#{0:02x}{1:02x}{2:02x}{3:02x}".format(
                    clamp(nums[0]), clamp(nums[1]), clamp(nums[2]), 80)

    print(labels[args.type])
    print(plotted)
    print(cols)
    fontsize = 18 if len(plotted) == 2 else 20
    fig, ax = funcs[len(plotted)](
        labels[args.type],
        names=plotted,
        colors=cols,
        fontsize=fontsize,
        dpi=args.dpi,
        alpha=0.5,
        figsize=(12, 12))
    fig.savefig("{}.{}".format(args.out, args.format), dpi=args.dpi)
    print("Saved the figure to {}.{}".format(args.out, args.format))
    import time
    time.sleep(3)
示例#4
0
def main():
    """Draw one precision/recall/F1 dot plot per species and save the figure.

    Command-line entry point. ``--species`` is a YAML file that must contain
    a ``names`` list plus one mapping per species providing at least the
    keys ``name``, ``configuration`` and ``folder``. For every name, the
    species configuration is parsed with ``parse_configuration`` and, for
    each (method, division) pair, precision is read from the first
    statistics file and recall from the second, on the line selected by
    ``--level``. Their harmonic mean is plotted as F1.

    The figure is written to ``<out>.<format>``.
    """

    parser = argparse.ArgumentParser(__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # The species should be a configuration file listing
    # for each species the following:
    # - folder
    # - name
    # Required: the file is read unconditionally right after parsing.
    parser.add_argument("--species", type=argparse.FileType("r"), required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--opaque", default=True, action="store_false")
    parser.add_argument("--dpi", default=1000, type=int)
    parser.add_argument("--format", default="svg", choices=["png", "pdf", "ps", "eps", "svg"])
    parser.add_argument("--level", default="transcript",
                        choices=["base", "exon", "intron", "intron_chain", "transcript", "gene"])
    # NOTE(review): --style is accepted but never applied in this function.
    parser.add_argument("-s", "--style", choices=plt.style.available, default="ggplot")
    parser.add_argument("--title", default="Mikado stats for transcript level")
    args = parser.parse_args()

    # safe_load instead of yaml.load: the latter can build arbitrary Python
    # objects from the file and requires an explicit Loader on PyYAML >= 6.
    species = yaml.safe_load(args.species)

    names = species.pop("names")

    # Map each display name back to the key of its entry in the YAML file.
    keys = dict()
    for entry in species:
        if isinstance(species[entry], dict) and species[entry].get("name", None) in names:
            keys[species[entry]["name"]] = entry

    ncols = len(names)

    figure, axes = plt.subplots(
        nrows=1,
        ncols=ncols,
        dpi=args.dpi,
        figsize=(10 / 2 * ncols, 4 / 2 * ncols))
    # With a single species plt.subplots returns a bare Axes, not an array.
    axes = np.atleast_1d(axes)

    figure.suptitle(" ".join(["${}$".format(_) for _ in args.title.split()]),
                    fontsize=20, style="italic", family="serif")
    figure.text(0.5, 0.15, r"${}\ \%$".format(args.level.title()), ha="center", fontsize=15)

    # Dictionary to indicate which line should be taken given the level
    line_correspondence = {"base": 5,
                           "exon": 7,
                           "intron": 8,
                           "intron_chain": 9,
                           "transcript": 12,
                           "gene": 15}
    line_index = line_correspondence[args.level]

    def read_lines(path):
        """Return the right-stripped lines of ``path``, closing the file."""
        with open(path) as handle:
            return [line.rstrip() for line in handle]

    series = (("Precision", "orange"), ("Recall", "lightblue"), ("F1", "black"))

    for yrow, name in enumerate(names):
        plot = axes[yrow]
        plot.set_title("${}$".format(name), fontsize=15)

        with open(species[keys[name]]["configuration"]) as configuration:
            options = parse_configuration(configuration, prefix=species[keys[name]]["folder"])

        # Usually STAR, TopHat
        method_names = sorted(itertools.product(options["methods"], options["divisions"]))

        max_point = float("-inf")
        for counter, (method, division) in enumerate(reversed(method_names), 1):
            try:
                orig, filtered = options["methods"][method][division]
            except TypeError:
                # The entry is not an (orig, filtered) pair, e.g. None.
                warnings.warn("Something went wrong for {}, {}; continuing".format(
                    method, division))
                continue
            orig_lines = read_lines(orig)
            filtered_lines = read_lines(filtered)

            # Light horizontal guide line for this method/division row.
            plot.plot((0, 100), (counter, counter), c="lightgrey")

            precision = float(orig_lines[line_index].split(":")[1].split()[1])
            recall = float(filtered_lines[line_index].split(":")[1].split()[0])
            try:
                f1 = hmean(np.array([precision, recall]))
            except TypeError as exc:
                raise TypeError("\n".join([str(_) for _ in [(precision, type(precision)),
                                                            (recall, type(recall)),
                                                            exc]])) from exc

            # Leave some room to the right of the largest plotted value.
            max_point = max([max_point] + [_ + 7 for _ in (precision, recall, f1)])
            for value, (label, colour) in zip((precision, recall, f1), series):
                plot.scatter(value, counter,
                             label=label,
                             c=colour, marker="o",
                             edgecolor="k", s=[100.0], alpha=1)

        if max_point == float("-inf"):
            # Nothing could be loaded for this species: fall back to the full
            # percentage range rather than passing an infinite axis limit.
            max_point = 100
        max_point = min(100, max_point)
        print(name, max_point)

        current_axes = plot.axes
        current_axes.set_xlim(0, max_point)
        current_axes.set_ylim(0, len(method_names) + 1)
        current_axes.spines["top"].set_visible(False)
        current_axes.spines["right"].set_visible(False)
        if yrow == 0:
            # Only the leftmost panel carries the method labels.
            current_axes.set_yticks(range(1, len(method_names) + 1))
            current_axes.set_yticklabels(["{} ({})".format(*_) for _ in reversed(method_names)],
                                         fontsize=16)
        else:
            current_axes.spines["left"].set_visible(False)
            current_axes.set_yticks([])

        plot.tick_params(axis='x', which='major', labelsize=12)

    # Proxy artists for the figure-level legend.
    div_labels = [
        (mlines.Line2D([], [], marker="o", markersize=20, markerfacecolor=colour, color="white"),
         label)
        for label, colour in series
    ]

    plt.figlegend(handles=[_[0] for _ in div_labels],
                  labels=[_[1] for _ in div_labels],
                  loc="upper center",
                  scatterpoints=1,
                  ncol=3,
                  fontsize=14,
                  framealpha=0.5)

    plt.tight_layout(pad=0.5,
                     h_pad=1,
                     w_pad=1,
                     rect=[0.1,  # Left
                           0.2,  # Bottom
                           0.85,  # Right
                           0.9])  # Top
    # NOTE(review): --out is required, so the interactive branch is currently
    # unreachable; kept in case the option is made optional again.
    if args.out is None:
        plt.ion()
        plt.show(block=True)
    else:
        plt.savefig("{}.{}".format(args.out, args.format),
                    format=args.format,
                    dpi=args.dpi,
                    transparent=args.opaque)
示例#5
0
def main():
    """Plot precision against recall for every method, one panel per level.

    Command-line entry point. The configuration is parsed with
    ``parse_configuration``; for each (method, division) pair precision is
    read from the first statistics file and recall from the second, on the
    line associated with each requested level. Every level gets its own
    panel, F1 contours are drawn with ``plotf1curves`` and the point(s) with
    the best F1 (as computed by ``calc_f1``) are circled.

    ``--format`` and ``--dpi`` override the corresponding configuration
    values when saving. The figure is written to ``<out>.<format>``, where
    any extension given in ``--out`` is stripped first.
    """

    parser = argparse.ArgumentParser(__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--configuration", required=True, type=argparse.FileType("r"))
    parser.add_argument("--out", required=True)
    parser.add_argument("--tight", default=False, action="store_true")
    parser.add_argument("--title", default="Mikado stats")
    parser.add_argument("--format", default=None, choices=["png", "pdf", "ps", "eps", "svg"])
    parser.add_argument("--levels", default=None, choices=["base", "exon", "intron",
                                                           "intron_chain", "transcript", "gene"],
                        nargs="+")
    parser.add_argument("--dpi", type=int, default=None)
    args = parser.parse_args()

    options = parse_configuration(args.configuration)
    options["out"] = os.path.splitext(args.out)[0]

    # level -> (panel title, index of the line to read in the stats files)
    name_corr = OrderedDict()
    name_corr["base"] = ("Base", 5)
    name_corr["exon"] = ("Exon", 7)
    name_corr["intron"] = ("Intron", 8)
    name_corr["intron_chain"] = ("Intron chain", 9)
    name_corr["transcript"] = ("Transcript", 12)
    name_corr["gene"] = ("Gene", 15)

    if args.levels is None:
        selected = list(name_corr)
        nrows, ncols = 2, 3
    else:
        # Keep the canonical order and drop duplicates.
        selected = [key for key in name_corr if key in args.levels]
        if len(selected) <= 3:
            nrows, ncols = 1, len(selected)
        elif len(selected) == 4:
            nrows, ncols = 2, 2
        else:
            nrows, ncols = 2, 3
        assert nrows * ncols >= len(selected)

    indices = [name_corr[key][1] for key in selected]
    titles = [name_corr[key][0] for key in selected]
    # Pad with None so the titles always form a full nrows x ncols grid;
    # a 1-D array used to be indexed with two indices for 4 or 5 levels.
    titles += [None] * (nrows * ncols - len(titles))
    name_ar = np.array(titles, dtype=object).reshape(nrows, ncols)

    # squeeze=False: axes is always a 2-D array, whatever the grid shape.
    figure, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        dpi=options["dpi"],
        figsize=(10, 6),
        squeeze=False)

    figure.suptitle(" ".join(["${}$".format(_) for _ in args.title.split()]),
                    fontsize=20, style="italic", family="serif")

    # Figure-wide axis arrows shared by all the panels.
    Xaxis = mpatches.FancyArrow(0.1, 0.19, 0.89, 0,
                                width=0.001,
                                length_includes_head=True,
                                transform=figure.transFigure, figure=figure,
                                color="k")
    Yaxis = mpatches.FancyArrow(0.1, 0.19, 0, 0.7,
                                width=0.001,
                                length_includes_head=True,
                                transform=figure.transFigure, figure=figure,
                                color="k")
    figure.lines.extend([Xaxis, Yaxis])

    figure.text(0.92, 0.21, "$Recall$", ha="center", fontsize=15)
    figure.text(0.07, 0.8, "$Precision$", va="center", fontsize=15, rotation="vertical")

    color_map = color_normalizer = None
    if options["colourmap"]["use"] is True:
        color_normalizer = matplotlib.colors.Normalize(0, len(options["methods"]))
        color_map = cm.get_cmap(options["colourmap"]["name"])

    rgb_pattern = re.compile(r"\(([0-9]*), ([0-9]*), ([0-9]*)\)$")

    def resolve_colour(method_name):
        """Return the colour used to draw ``method_name``.

        Shared by the scatter points and the legend patches so that both
        are always computed from the same method entry.
        """
        method_options = options["methods"][method_name]
        if options["colourmap"]["use"] is False:
            colour = method_options["colour"]
            matched = rgb_pattern.match(colour)
            if matched:
                # "(r, g, b)" text is converted to "#rrggbb".
                colour = "#{0:02x}{1:02x}{2:02x}".format(
                    *[clamp(int(_)) for _ in matched.groups()])
            return colour
        if method_options["colour"] in ("black", "k"):
            return "black"
        return color_map(color_normalizer(method_options["index"]))

    def read_lines(path):
        """Return the right-stripped lines of ``path``, closing the file."""
        with open(path) as handle:
            return [line.rstrip() for line in handle]

    # Panel title -> Axes, and panel title -> division -> list of
    # (precision, recall, f1) tuples, one per method.
    plots = OrderedDict()
    stats = OrderedDict()
    for xrow in range(nrows):
        for yrow in range(ncols):
            key = name_ar[xrow, yrow]
            if key is None:
                continue
            plot = axes[xrow, yrow]
            plot.set_title("{} level".format(key), fontsize=10)
            plots[key] = plot
            stats[key] = dict((division, []) for division in options["divisions"])

    stat_names = list(stats)
    for method in options["methods"]:
        for aligner in options["divisions"]:
            if options["methods"][method][aligner] is not None:
                orig, filtered = options["methods"][method][aligner]
                orig_lines = read_lines(orig)
                filtered_lines = read_lines(filtered)
            else:
                orig_lines = None
                filtered_lines = None
                print("Aligner {} not found for method {}".format(aligner, method))
            for stat_name, line_index in zip(stat_names, indices):
                if orig_lines is not None:
                    precision = float(orig_lines[line_index].split(":")[1].split()[1])
                    recall = float(filtered_lines[line_index].split(":")[1].split()[0])
                    try:
                        f1 = hmean(np.array([precision, recall]))
                    except TypeError as exc:
                        raise TypeError("\n".join([str(_) for _ in [(precision, type(precision)),
                                                                    (recall, type(recall)),
                                                                    exc]])) from exc
                else:
                    # Placeholder for missing data; filtered out below when
                    # the axis limits are computed.
                    precision = -10
                    recall = -10
                    f1 = -10

                stats[stat_name][aligner].append((precision, recall, f1))

    divisions = sorted(options["divisions"].keys())

    best_marker = "o"
    margin = 1 if args.tight is True else 5

    for stat in stats:

        plot = plots[stat]

        # One array per division, each holding one value per method.
        ys = [np.array([_[0] for _ in stats[stat][division]]) for division in divisions]
        xs = [np.array([_[1] for _ in stats[stat][division]]) for division in divisions]

        # Ignore the -10 placeholders when choosing the axis limits.
        suitable_x = np.array([val for values in xs for val in values if val > -10])
        suitable_y = np.array([val for values in ys for val in values if val > -10])

        if suitable_x.size:
            x_minimum = max(0, floor(suitable_x.min()) - margin)
            x_maximum = min(100, ceil(suitable_x.max()) + margin)
        else:
            # No data at all for this level: show the whole range.
            x_minimum, x_maximum = 0, 100
        if suitable_y.size:
            y_minimum = max(0, floor(suitable_y.min()) - margin)
            y_maximum = min(100, ceil(suitable_y.max()) + margin)
        else:
            y_minimum, y_maximum = 0, 100

        plotf1curves(plot, fstepsize=ceil(min(x_maximum - x_minimum, y_maximum - y_minimum)/10))
        best_f1 = (-1, [])

        for division, division_xs, division_ys in zip(divisions, xs, ys):
            for x, y, label in zip(division_xs, division_ys, options["methods"]):
                f1 = calc_f1(x, y)
                if best_f1[0] < f1:
                    best_f1 = (f1, [(x, y)])
                elif best_f1[0] == f1:
                    best_f1[1].append((x, y))

                # The size has to change in a inversely related way compared to the number of plots
                # Top must be 100
                plot.scatter(x, y,
                             label=label,
                             c=resolve_colour(label),
                             marker=options["divisions"][division]["marker"],
                             edgecolor="k", s=[150/max(nrows, ncols)], alpha=.8)

        # Circle every point that reaches the best F1.
        circle_rad = 30
        for best in best_f1[1]:
            plot.plot(best[0], best[1], "o",
                      label="Best F1",
                      ms=circle_rad,
                      linestyle="-",
                      mec="k",
                      mfc="none")

        current_axes = plot.axes
        current_axes.set_xlim(x_minimum, x_maximum)
        current_axes.set_ylim(y_minimum, y_maximum)
        plot.tick_params(axis='both', which='major', labelsize=8)

    # Now create the legend entries: contours, divisions, best F1, methods.
    div_labels = []

    f1_line = mlines.Line2D([], [], color="gray", linestyle="--")
    div_labels.append((f1_line, "F1 contours"))

    for division in options["divisions"]:
        faux_line = mlines.Line2D([], [], color="white",
                                  marker=options["divisions"][division]["marker"],
                                  markersize=14,
                                  markerfacecolor="black")
        div_labels.append((faux_line, division))

    best_marker_line = mlines.Line2D([], [], color="white",
                                     marker=best_marker, markersize=15,
                                     markerfacecolor="none", markeredgecolor="black")
    div_labels.append((best_marker_line, "Best F1"))

    for method in options["methods"]:
        # The colour is looked up for *this* method; the legend used to read
        # a loop variable left over from the plotting loop above.
        patch = mpatches.Patch(facecolor=resolve_colour(method), linewidth=1, edgecolor="k")
        div_labels.append((patch, method))

    npatches = len(options["divisions"]) + len(options["methods"]) + 2
    plt.figlegend(handles=[_[0] for _ in div_labels],
                  labels=[_[1] for _ in div_labels],
                  loc="lower center",
                  scatterpoints=1,
                  ncol=ceil(npatches/4),
                  fontsize=10,
                  framealpha=0.5)
    # Necessary to pad the superior title
    plt.tight_layout(pad=0.5,
                     h_pad=1,
                     w_pad=1,
                     rect=[0.1,  # Left
                           0.2,  # Bottom
                           0.85,  # Right
                           0.9])  # Top
    if options["out"] is None:
        plt.ion()
        plt.show(block=True)
    else:
        if args.format is not None:
            options["format"] = args.format
        if args.dpi is not None:
            options["dpi"] = args.dpi

        plt.savefig("{}.{}".format(options["out"], options["format"]),
                    format=options["format"],
                    dpi=options["dpi"],
                    transparent=options["opaque"])
def main():
    """Write a tab-separated precision/recall/F1 table for every method.

    Command-line entry point. The configuration file given with ``--conf``
    is parsed with ``parse_configuration``. For each (method, division)
    pair, precision is read from the first statistics file and recall from
    the second, on the line associated with each level; F1 is their harmonic
    mean rounded to two decimals. Pairs without statistics files are
    reported as -10 for all three values.

    The table is written to ``--out`` (standard output by default):
    two header rows followed by one row per "method (division)".
    """

    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("-c", "--conf", required=True)
    parser.add_argument("-o", "--out", default=sys.stdout, type=argparse.FileType("wt"))
    args = parser.parse_args()

    with open(args.conf) as configuration:
        options = parse_configuration(configuration)

    # level -> (column title, index of the line to read in the stats files)
    name_corr = OrderedDict()
    name_corr["base"] = ("Base", 5)
    name_corr["exon"] = ("Exon", 7)
    name_corr["intron"] = ("Intron", 8)
    name_corr["intron_chain"] = ("Intron chain", 9)
    name_corr["transcript"] = ("Transcript", 12)
    name_corr["gene"] = ("Gene", 15)

    def read_lines(path):
        """Return the right-stripped lines of ``path``, closing the file."""
        with open(path) as handle:
            return [line.rstrip() for line in handle]

    # level -> division -> method -> (precision, recall, f1)
    stats = OrderedDict()
    for key in name_corr:
        stats[key] = OrderedDict((division, dict()) for division in options["divisions"])

    for method in options["methods"]:
        for division in options["divisions"]:
            if options["methods"][method][division] is not None:
                orig, filtered = options["methods"][method][division]
                orig_lines = read_lines(orig)
                filtered_lines = read_lines(filtered)
            else:
                orig_lines = None
                filtered_lines = None
                # Diagnostics go to stderr so they cannot end up inside the
                # table when --out is standard output.
                print("Aligner {} not found for method {}".format(division, method),
                      file=sys.stderr)

            for key, (_, line_index) in name_corr.items():
                if orig_lines is not None:
                    precision = float(orig_lines[line_index].split(":")[1].split()[1])
                    recall = float(filtered_lines[line_index].split(":")[1].split()[0])
                    try:
                        f1 = round(hmean(np.array([precision, recall])), 2)
                    except TypeError as exc:
                        raise TypeError("\n".join([str(_) for _ in [(precision, type(precision)),
                                                                    (recall, type(recall)),
                                                                    exc]])) from exc
                else:
                    # Placeholder values for a missing method/division pair.
                    precision = -10
                    recall = -10
                    f1 = -10

                stats[key][division][method] = (precision, recall, f1)

    # First header row: each level title sits above its middle column.
    first_row = ["Level"]
    for name in name_corr:
        first_row.extend(["", name_corr[name][0], ""])
    print(*first_row, sep="\t", file=args.out)

    second_row = [""] + ["Precision", "Recall", "F1"] * len(name_corr)
    print(*second_row, sep="\t", file=args.out)

    for method in options["methods"]:
        for division in options["divisions"]:
            row = ["{} ({})".format(method, division)]
            for name in name_corr:
                row.extend(stats[name][division][method])
            print(*row, file=args.out, sep="\t")
def main():
    """Draw a grid of precision/recall scatter plots, one panel per species.

    The ``--species`` YAML file must contain a ``names`` list, optionally a
    ``categories`` list, and one mapping per species providing ``name``,
    ``folder``, ``configuration`` and (when categories are used) ``category``.
    For every panel the per-species configuration is parsed, the precision
    and recall of each method/division pair are read from its two stats
    files, the points are plotted over F1 contours and the best F1 point(s)
    are circled. The figure is written to ``<out>.<format>``.
    """

    # NOTE: the module docstring is the parser *description*; passing it
    # positionally would set ``prog`` instead.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # The species should be a configuration file listing
    # for each species the following:
    # - folder
    # - name
    parser.add_argument("--species",
                        type=argparse.FileType("r"),
                        required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--opaque", default=True, action="store_false")
    parser.add_argument("--dpi", default=1000, type=int)
    parser.add_argument("--format",
                        default="svg",
                        choices=["png", "pdf", "ps", "eps", "svg"])
    parser.add_argument("--level",
                        default=["transcript"],
                        nargs="+",
                        choices=[
                            "base", "exon", "intron", "intron_chain",
                            "transcript", "gene"
                        ])
    parser.add_argument(
        "-cm",
        "--colour-map",
        dest="colour_map",
        default=None,
        help="Colour map to use. Default: use user-specified colours.")
    parser.add_argument(
        "-cs",
        "--colour-map-size",
        dest="colour_map_size",
        default=0,
        type=int,
        help="Minimum number of slots the colour map is normalised over; "
        "the number of methods is used when it is larger.")
    parser.add_argument(
        "-e",
        "--equal",
        action="store_true",
        default=False,
        help=
        "Flag. If switched on, all subplots will share the same X and Y limits."
    )
    parser.add_argument("-s",
                        "--style",
                        choices=plt.style.available,
                        default="ggplot")
    parser.add_argument("--title", default="Mikado stats for transcript level")
    args = parser.parse_args()

    if len(args.level) > 2:
        warnings.warn(
            "Currently the script can only accept 1 or 2 levels. Exiting.")
        sys.exit(1)

    # safe_load: yaml.load() without an explicit Loader is rejected by
    # PyYAML >= 6 and can build arbitrary objects on older versions.
    species = yaml.safe_load(args.species)

    names = species.pop("names")
    categories = species.pop("categories", None)

    if len(args.level) == 1:
        ncols = ceil(
            len([_ for _ in species if _ not in ("categories", "names")]) / 2)
    else:
        ncols = len(species)

    nrows = len(args.level)

    print(nrows, args.level, ncols, species)

    figure, axes = plt.subplots(nrows=nrows,
                                ncols=ncols,
                                dpi=args.dpi,
                                figsize=(10 / 2 * ncols, 6 / 2 * ncols))
    figure.suptitle(" ".join(["${}$".format(_) for _ in args.title.split()]),
                    fontsize=20,
                    style="italic",
                    family="serif")

    # Figure-wide arrows acting as shared X (recall) and Y (precision) axes.
    Xaxis = mpatches.FancyArrow(0.07,
                                0.19,
                                0.89,
                                0,
                                width=0.003,
                                length_includes_head=True,
                                transform=figure.transFigure,
                                figure=figure,
                                color="k")
    Yaxis = mpatches.FancyArrow(0.07,
                                0.19,
                                0,
                                0.7,
                                width=0.003,
                                length_includes_head=True,
                                transform=figure.transFigure,
                                figure=figure,
                                color="k")
    figure.lines.extend([Xaxis, Yaxis])

    figure.text(0.92, 0.21, "$Recall$", ha="center", fontsize=15)
    figure.text(0.03,
                0.8,
                "$Precision$",
                va="center",
                fontsize=15,
                rotation="vertical")

    # name_ar[row, column] holds the species key plotted in that panel
    # (None for padding cells produced by grouper).
    name_ar = []
    # This is tricky .. this should be divided in 2 if there is more than one level

    if categories:
        if len(args.level) == 1:
            for category in categories:
                for name in names:
                    print(category, name)
                    key = [
                        _ for _ in species.keys()
                        if isinstance(species[_], dict) and species[_]["name"]
                        == name and species[_]["category"] == category
                    ]
                    print(key)
                    assert len(key) == 1
                    key = key.pop()
                    name_ar.append(key)
            name_ar = np.array(
                list(grouper(name_ar, ceil(len(species) / 2), None)))
        else:
            for category in categories:
                for name in names:
                    assert isinstance(species, dict)
                    key = [
                        _ for _ in species.keys()
                        if isinstance(species[_], dict) and species[_]["name"]
                        == name and species[_]["category"] == category
                    ]
                    assert len(key) == 1, (key, name, category)
                    key = key.pop()
                    name_ar.append(key)
            name_ar = np.array(list(grouper(name_ar, 2, None)))
    else:
        for name in names:
            key = [
                _ for _ in species.keys()
                if isinstance(species[_], dict) and species[_]["name"] == name
            ]
            print(key)
            assert len(key) == 1
            key = key.pop()
            name_ar.append(key)
        name_ar = np.array([name_ar for _ in range(len(args.level))])

    # Marker per division and colour per method, accumulated over all panels
    # so that a single figure-wide legend can be built afterwards.
    markers = dict()
    methods = OrderedDict()

    best_marker = "o"

    if categories:
        cat = categories
    else:
        cat = args.level

    print(name_ar)
    for xrow, category in enumerate(cat):
        for yrow, name in enumerate(names):
            key = name_ar[xrow, yrow]
            if key is None:
                continue
            # NOTE(review): axes is indexed as a 2-D array; plt.subplots
            # returns a 1-D array when nrows or ncols is 1 (squeeze=True is
            # the default) - confirm the layouts this is expected to support.
            plot = axes[xrow, yrow]
            if xrow == 0:
                plot.set_title("${}$".format(names[yrow]), fontsize=15)
            if yrow == 0:
                if categories:
                    plot.set_ylabel(categories[xrow].title(), fontsize=15)
                else:
                    plot.set_ylabel(args.level[xrow].title(), fontsize=15)

            # division (as bytes) -> list of (precision, recall, f1), one
            # entry per method, in the iteration order of options["methods"].
            stats = OrderedDict()

            with open(species[key]["configuration"]) as configuration:
                options = parse_configuration(configuration,
                                              prefix=species[key]["folder"])
                if args.colour_map is not None:
                    options["colourmap"]["use"] = True
                    options["colourmap"]["name"] = args.colour_map
                else:
                    options["colourmap"]["use"] = False

                for division in options["divisions"]:
                    markers[division] = options["divisions"][division][
                        "marker"]

                if options["colourmap"]["use"] is True:
                    cm_length = max(args.colour_map_size,
                                    len(options["methods"]))
                    print(cm_length)
                    color_normalizer = matplotlib.colors.Normalize(
                        vmin=0, vmax=cm_length)
                    color_map = cm.get_cmap(options["colourmap"]["name"])

                for division in options["divisions"]:
                    stats[division.encode()] = []

                # The level whose line is read from the stats files.
                if category in args.level:
                    level = category
                else:
                    level = args.level[0]
                line_index = line_correspondence[level]

                for method in options["methods"]:
                    for division in options["divisions"]:
                        try:
                            orig, filtered = options["methods"][method][
                                division]
                        except TypeError:
                            warnings.warn(
                                "Something went wrong for {}, {}; continuing".
                                format(method, division))
                            # Exactly one placeholder per method: the points
                            # are later zipped against options["methods"],
                            # so adding more would shift every later label.
                            stats[division.encode()].append((-10, -10, -10))
                            continue
                        with open(orig) as orig_handle:
                            orig_lines = [
                                line.rstrip() for line in orig_handle
                            ]
                        with open(filtered) as filtered_handle:
                            filtered_lines = [
                                line.rstrip() for line in filtered_handle
                            ]

                        # Precision is the second number after the colon in
                        # the first file, recall the first number after the
                        # colon in the second file.
                        precision = float(
                            orig_lines[line_index].split(":")[1].split()[1])
                        recall = float(
                            filtered_lines[line_index].split(":")[1].split()
                            [0])
                        try:
                            f1 = hmean(np.array([precision, recall]))
                        except TypeError as exc:
                            raise TypeError("\n".join([
                                str(_) for _ in [(precision,
                                                  type(precision)),
                                                 (recall, type(recall)), exc]
                            ])) from exc
                        stats[division.encode()].append(
                            (precision, recall, f1))
                divisions = sorted(options["divisions"].keys())

                ys = [
                    np.array([_[0] for _ in stats[division.encode()]])
                    for division in divisions
                ]
                xs = [
                    np.array([_[1] for _ in stats[division.encode()]])
                    for division in divisions
                ]

                # Axis limits: 5% margin around the data, clipped to
                # [0, 100]; negative placeholders are ignored for the minima.
                x_minimum = max(
                    0,
                    floor(
                        floor(
                            min(
                                np.array([x for x in _ if x >= 0]).min()
                                for _ in xs)) * 0.95))
                y_minimum = max(
                    0,
                    floor(
                        floor(
                            min(
                                np.array([y for y in _ if y >= 0]).min()
                                for _ in ys)) * 0.95))
                x_maximum = min(100,
                                ceil(ceil(max(_.max() for _ in xs)) * 1.05))
                y_maximum = min(100,
                                ceil(ceil(max(_.max() for _ in ys)) * 1.05))

                plotf1curves(
                    plot,
                    fstepsize=ceil(
                        min(x_maximum - x_minimum, y_maximum - y_minimum) /
                        10))
                # (best F1 so far, every (x, y) point reaching it)
                best_f1 = (-1, [])

                for enumerated, division in enumerate(divisions):
                    for x, y, label in zip(xs[enumerated], ys[enumerated],
                                           options["methods"].keys()):
                        f1 = calc_f1(x, y)
                        if best_f1[0] < f1:
                            best_f1 = (f1, [(x, y)])
                        elif best_f1[0] == f1:
                            best_f1[1].append((x, y))

                        # Both Mikado flavours share one legend entry.
                        if label in ("Mikado permissive", "Mikado stringent"):
                            method_name = "Mikado"
                        else:
                            method_name = label
                        if options["colourmap"]["use"] is False:
                            colour = options["methods"][label]["colour"]
                            # "(r, g, b)" strings are converted to "#rrggbb".
                            matched = re.match(
                                r"\(([0-9]*), ([0-9]*), ([0-9]*)\)$", colour)
                            if matched:
                                colour = "#{0:02x}{1:02x}{2:02x}".format(
                                    clamp(int(matched.groups()[0])),
                                    clamp(int(matched.groups()[1])),
                                    clamp(int(matched.groups()[2])))
                        elif options["methods"][label]["colour"] in ("black",
                                                                     "k"):
                            colour = "black"
                        else:
                            colour = color_map(
                                color_normalizer(
                                    options["methods"][label]["index"]))
                        methods[method_name] = colour
                        plot.scatter(
                            x,
                            y,
                            label=label,
                            c=colour,
                            marker=options["divisions"][division]["marker"],
                            edgecolor="k",
                            s=[100.0],
                            alpha=1)
                __axes = plot.axes
                __axes.set_xlim(x_minimum, x_maximum)
                __axes.set_ylim(y_minimum, y_maximum)
                plot.tick_params(axis='both', which='major', labelsize=8)
                # Annotate the best F1
                circle_rad = 20
                for best in best_f1[1]:
                    plot.plot(best[0],
                              best[1],
                              "o",
                              label="Best F1",
                              ms=circle_rad,
                              linestyle="-",
                              mec="k",
                              mfc="none")

    # Not very efficient way to normalize boundaries for the plots
    for col in range(len(names)):
        p1 = axes[0, col]
        p2 = axes[1, col]

        if len(args.level) == 1 and args.equal is True:
            x_min, x_max = min(p1.get_xlim()[0],
                               p2.get_xlim()[0]), max(p1.get_xlim()[1],
                                                      p2.get_xlim()[1])
            p1.set_xlim(x_min, x_max)
            p2.set_xlim(x_min, x_max)
            y_min, y_max = min(p1.get_ylim()[0],
                               p2.get_ylim()[0]), max(p1.get_ylim()[1],
                                                      p2.get_ylim()[1])
            p1.set_ylim(y_min, y_max)
            p2.set_ylim(y_min, y_max)

    # Figure-wide legend: F1 contours, one marker per division, the best-F1
    # circle and one colour patch per method.
    div_labels = []

    f1_line = mlines.Line2D([], [], color="gray", linestyle="--")
    div_labels.append((f1_line, "F1 contours"))

    for division in markers:
        faux_line = mlines.Line2D([], [],
                                  color="white",
                                  marker=markers[division],
                                  markersize=14,
                                  markerfacecolor="black")
        div_labels.append((faux_line, division))

    best_marker_line = mlines.Line2D([], [],
                                     color="white",
                                     marker=best_marker,
                                     markersize=17,
                                     markerfacecolor="none",
                                     markeredgecolor="black")
    div_labels.append((best_marker_line, "Best F1"))

    # NOTE(review): this uses the options of the last panel processed, and
    # looks the methods up by their original label although "Mikado
    # permissive"/"Mikado stringent" were stored under "Mikado" - confirm.
    for method in sorted(options["methods"]):
        print(method)
        colour = methods[method]
        patch = mpatches.Patch(facecolor=colour, linewidth=1, edgecolor="k")
        div_labels.append((patch, method))

    plt.figlegend(handles=[_[0] for _ in div_labels],
                  labels=[_[1] for _ in div_labels],
                  loc="lower center",
                  scatterpoints=1,
                  ncol=ceil((len(methods) + len(markers)) / 4) + 2,
                  fontsize=13,
                  framealpha=0.5)

    plt.tight_layout(
        pad=0.5,
        h_pad=1,
        w_pad=1,
        rect=[
            0.1,  # Left
            0.2,  # Bottom
            0.85,  # Right
            0.9
        ])  # Top
    if args.out is None:
        plt.ion()
        plt.show(block=True)
    else:
        plt.savefig("{}.{}".format(args.out, args.format),
                    format=args.format,
                    dpi=args.dpi,
                    transparent=args.opaque)
def main():
    """Rank every (method, aligner) pair by its summed z-scores.

    For each pair listed in the configuration, precision and recall are read
    from fixed lines of the pair's two stats files for six feature levels
    (base, exon, intron, intron chain, transcript, gene) and F1 is their
    harmonic mean. Within each level the values are converted to z-scores;
    the z-scores are then summed over the levels and the pairs are ranked by
    that sum (rank 1 = highest). A tab-separated table is written to the
    ``--out`` path (with its extension stripped) or to standard output.
    """

    # NOTE: the module docstring is the parser *description*; passing it
    # positionally would set ``prog`` instead.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--configuration", required=True, type=argparse.FileType("r"))
    # NOTE(review): --type is parsed but never read in this function.
    parser.add_argument("-t", "--type", default="F1", choices=["F1", "Precision", "Recall"])
    parser.add_argument("-o", "--out", default=None)
    args = parser.parse_args()

    options = parse_configuration(args.configuration)
    if args.out is not None:
        options["out"] = os.path.splitext(args.out)[0]
    else:
        options["out"] = None

    # level key -> (display name, index of the line to read in a stats file)
    name_corr = OrderedDict()
    name_corr["base"] = ("Base", 5)
    name_corr["exon"] = ("Exon", 7)
    name_corr["intron"] = ("Intron", 8)
    name_corr["intron_chain"] = ("Intron chain", 9)
    name_corr["transcript"] = ("Transcript", 12)
    name_corr["gene"] = ("Gene", 15)

    # level key -> {(method, aligner): (precision, recall, f1)}
    stats = OrderedDict((key, OrderedDict()) for key in name_corr)

    for method in sorted(options["methods"].keys()):
        for aligner in sorted(options["divisions"].keys()):
            if options["methods"][method][aligner] is not None:
                orig, filtered = options["methods"][method][aligner]
                with open(orig) as orig_handle:
                    orig_lines = [line.rstrip() for line in orig_handle]
                with open(filtered) as filtered_handle:
                    filtered_lines = [line.rstrip() for line in filtered_handle]
            else:
                orig_lines = None
                filtered_lines = None
                print("Aligner {} not found for method {}".format(aligner, method))

            # Iterate the levels directly instead of relying on the
            # positional order of the keys of ``stats``.
            for level, (_, line_index) in name_corr.items():
                if orig_lines is not None:
                    # Precision is the second number after the colon in the
                    # first file, recall the first number after the colon in
                    # the second file.
                    precision = float(orig_lines[line_index].split(":")[1].split()[1])
                    recall = float(filtered_lines[line_index].split(":")[1].split()[0])
                    try:
                        f1 = hmean(np.array([precision, recall]))
                    except TypeError as exc:
                        raise TypeError("\n".join([str(_) for _ in [(precision, type(precision)),
                                                                    (recall, type(recall)),
                                                                    exc]])) from exc
                else:
                    # Placeholder for a missing method/aligner combination.
                    precision = -10
                    recall = -10
                    f1 = -10

                stats[level][(method, aligner)] = (precision, recall, f1)

    # feature -> (method, aligner) -> level -> z-score
    zscores = {"precision": defaultdict(OrderedDict),
               "recall": defaultdict(OrderedDict),
               "f1": defaultdict(OrderedDict)}

    for stat in stats:
        for index, feat in enumerate(["precision", "recall", "f1"]):
            vals = np.array([_[index] for _ in stats[stat].values()])
            for key, score in zip(stats[stat].keys(), mstats.zscore(vals)):
                zscores[feat][key][stat] = score

    # Now we have calculated the Zscore for each of the methods for each of the stats
    # Time to sum them up

    zscore_sum = defaultdict(OrderedDict)
    ranks = defaultdict(OrderedDict)
    for feat in ["precision", "recall", "f1"]:
        for key in zscores[feat]:
            zscore_sum[feat][key] = sum(zscores[feat][key].values())
        ordered = sorted(zscore_sum[feat].items(), key=operator.itemgetter(1), reverse=True)
        by_key = {key: (zsum, rank) for rank, (key, zsum) in enumerate(ordered, 1)}
        # Store the ranks in the order in which the pairs were read.
        for key in stats["base"].keys():
            ranks[feat][key] = by_key[key]

    out = open(options["out"], "wt") if options["out"] is not None else sys.stdout
    try:
        print("Method", "Precision", "", "Recall", "", "F1", sep="\t", file=out)
        print("", *["zscore", "rank"] * 3, sep="\t", file=out)
        for key in ranks["precision"].keys():
            row = ["{} ({})".format(*key)]
            for feat in ["precision", "recall", "f1"]:
                row.extend(ranks[feat][key])
            print(*row, sep="\t", file=out)
    finally:
        # Never close sys.stdout; only the file opened above.
        if out is not sys.stdout:
            out.close()
示例#9
0
文件: train.py 项目: yuxgis/RSCD
def train(config_file='./config_segmentation.json', export=True):
    """Train the model described by ``config_file`` and optionally export it.

    The training and validation datasets and the model are created from the
    parsed configuration. Each epoch runs forward/backward on every training
    item, calls ``optimize_parameters`` every ``model_update_freq`` items,
    then saves networks and optimizers and updates the learning rate. The
    visualizer and the per-epoch validation pass are currently disabled.

    Args:
        config_file: path handed to ``parse_configuration``.
        export: when true, feed one sample (batch size forced to 1) to the
            model and call ``model.export()`` after training.

    Returns:
        Whatever ``model.get_hyperparam_result()`` returns.
    """
    print('Reading config file...')
    configuration = parse_configuration(config_file)

    print('Initializing dataset...')
    train_dataset = create_dataset(configuration['train_dataset_params'])
    print('The number of training samples = {0}'.format(len(train_dataset)))

    val_dataset = create_dataset(configuration['val_dataset_params'])
    print('The number of validation samples = {0}'.format(len(val_dataset)))

    print('Initializing model...')
    model_params = configuration['model_params']
    model = create_model(model_params)
    model.setup()

    print('Initializing visualization...')
    # The visualizer is not instantiated at the moment.

    first_epoch = model_params['load_checkpoint'] + 1
    last_epoch = configuration['model_params']['max_epochs']

    for epoch in range(first_epoch, last_epoch):
        started_at = time.time()  # wall-clock start of this epoch
        train_dataset.dataset.pre_epoch_callback(epoch)
        model.pre_epoch_callback(epoch)

        # Only consumed by the (disabled) visualizer calls.
        n_iterations = len(train_dataset)
        batch_size = configuration['train_dataset_params']['loader_params'][
            'batch_size']

        model.train()
        for step, batch in enumerate(train_dataset):
            # Unpack the batch and apply preprocessing.
            model.set_input(batch)
            model.forward()
            model.backward()

            # Update the weights only every model_update_freq steps.
            if not step % configuration['model_update_freq']:
                model.optimize_parameters()

            if not step % configuration['printout_freq']:
                # Fetched for the visualizer, which is currently disabled.
                current_losses = model.get_current_losses()

        # The validation pass and post-epoch callbacks are disabled.
        model.eval()

        print('Saving model at the end of epoch {0}'.format(epoch))
        model.save_networks(epoch)
        model.save_optimizers(epoch)

        elapsed = time.time() - started_at
        print('End of epoch {0} / {1} \t Time Taken: {2} sec'.format(
            epoch, last_epoch, elapsed))

        model.update_learning_rate()  # once per epoch

    if not export:
        return model.get_hyperparam_result()

    print('Exporting model')
    model.eval()
    # Same dict object as in ``configuration``: the batch size is changed in
    # place so that a single sample is used for tracing.
    export_params = configuration['train_dataset_params']
    export_params['loader_params']['batch_size'] = 1
    loader = train_dataset.get_custom_dataloader(export_params)
    model.set_input(next(iter(loader)))
    model.export()

    return model.get_hyperparam_result()
示例#10
0
def main():
    """Plot the overlap between methods of missed, full or fused references.

    For every aligner ("STAR", "TopHat") and every configured method, the
    ``.refmap`` file sitting next to the method's stats file is read and the
    reference genes (or transcripts, with ``--transcripts``) whose class code
    matches the requested ``--type`` are collected. The resulting sets are
    handed to ``pyu.plot`` and the figure is saved to ``--out``.
    """

    parser = argparse.ArgumentParser("Script to create the Venn Plots")
    parser.add_argument("-t",
                        "--type",
                        choices=["missing", "full", "fusion"],
                        required=True)
    parser.add_argument("-c",
                        "--configuration",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument("-em",
                        "--exclude-mikado",
                        dest="exclude",
                        action="store_true",
                        default=False,
                        help="Flag. If set, Mikado results will be excluded")
    parser.add_argument("-o",
                        "--out",
                        type=str,
                        help="Output file",
                        required=True)
    parser.add_argument("--format",
                        choices=["svg", "tiff", "png"],
                        default=None)
    parser.add_argument(
        "--transcripts",
        action="store_true",
        default=False,
        help="Flag. If set, Venn plotted against transcripts, not genes.")
    parser.add_argument("--title", default="Venn Diagram")
    args = parser.parse_args()

    options = parse_configuration(args.configuration,
                                  exclude_mikado=args.exclude)

    # "<method> (<aligner>)" -> selected reference identifiers
    sets = OrderedDict()

    # Refmap columns holding the identifier and the class code.
    if args.transcripts is True:
        colname = "ref_id"
        ccode = "ccode"
    else:
        colname = "ref_gene"
        ccode = "best_ccode"

    aligners = ("STAR", "TopHat")

    for aligner in aligners:
        for method in options["methods"]:
            # Escape the dot so that only a literal ".stats" suffix is
            # replaced.
            refmap = "{}.refmap".format(
                re.sub(r"\.stats$", "",
                       options["methods"][method][aligner][0]))
            with open(refmap) as ref:
                tsv = csv.DictReader(ref, delimiter="\t")
                meth = "{} ({})".format(method, aligner)
                sets[meth] = set()
                for row in tsv:
                    code = row[ccode]
                    if args.type == "missing":
                        selected = code.lower() in ("na", "x", "p", "i", "ri")
                    elif args.type == "full":
                        selected = code in ("=", "_")
                    else:  # "fusion"
                        # startswith() copes with an empty class code,
                        # unlike indexing its first character.
                        selected = code.startswith("f")
                    if selected:
                        sets[meth].add(row[colname])

    # pyu.plot is given one single-column data frame per set.
    for aligner in aligners:
        for method in sorted(options["methods"].keys()):
            set_name = "{} ({})".format(method, aligner)
            sets[set_name] = pd.DataFrame(list(sets[set_name]),
                                          columns=["TID"])

    pyu.plot(
        sets,
        inters_size_bounds=(100, 20000),
    )
    if args.format is None:
        args.format = "svg"
    plt.savefig(args.out, format=args.format)
def main():
    """Print per-method refmap counts and optionally draw them as scatter plots.

    For every (method, division) pair in the configuration, ``parse_refmaps``
    is run in a worker pool on the first two stats files configured for the
    pair (or on ``None, None`` when the pair has no entry).  The results are
    printed as a tab-separated table with the header
    ``Method / Division / Fully / Missed / Fused``.

    If ``--out`` is not given the script exits after printing.  Otherwise a
    figure with three panels (one per result column) is saved to
    ``<out without extension>.<options["format"]>``.

    ``--genes`` is still accepted on the command line, but the level handed to
    ``parse_refmaps`` and used for the Y label is controlled by
    ``--transcripts``.

    Raises:
        ValueError: with ``--division``, if the X coordinates of two
            consecutive divisions overlap.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--configuration",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument("--title", required=False, default="")
    parser.add_argument("--log", action="store_true", default=False)
    parser.add_argument(
        "--out",
        required=False,
        default=None,
        help=
        "Output file. If unspecified, the script will exit after printing the numbers."
    )
    parser.add_argument(
        "--genes",
        default=False,
        action="store_true",
        help=
        "Flag. If switched on, the gene level will be used instead of the transcript level."
    )
    parser.add_argument("-p",
                        "--procs",
                        default=multiprocessing.cpu_count(),
                        type=int)
    parser.add_argument("--transcripts", action="store_true", default=False)
    parser.add_argument("--division", action="store_true", default=False)
    args = parser.parse_args()

    data = OrderedDict()

    options = parse_configuration(args.configuration)

    # The colour map is only built when the configuration asks for it.
    color_normalizer = None
    color_map = None
    if options["colourmap"]["use"] is True:
        color_normalizer = matplotlib.colors.Normalize(0,
                                                       len(options["methods"]))
        color_map = cm.get_cmap(options["colourmap"]["name"])

    rgb_pattern = re.compile(r"\(([0-9]*), ([0-9]*), ([0-9]*)\)$")

    def method_colour(method):
        """Return the plotting colour configured for *method*."""
        if options["colourmap"]["use"] is False:
            colour = options["methods"][method]["colour"]
            matched = rgb_pattern.match(colour)
            if matched:
                # "(r, g, b)" strings are converted to a "#rrggbb" hex code.
                colour = "#{0:02x}{1:02x}{2:02x}".format(
                    clamp(int(matched.groups()[0])),
                    clamp(int(matched.groups()[1])),
                    clamp(int(matched.groups()[2])))
            return colour
        if options["methods"][method]["colour"] in ("black", "k"):
            return "black"
        # The colour-map position belongs to the method, not to the division.
        return color_map(color_normalizer(options["methods"][method]["index"]))

    header = ["Method", "Division", "Fully", "Missed", "Fused"]
    print(*header, sep="\t")

    # The context manager guarantees the worker processes are released.
    with multiprocessing.Pool(processes=args.procs) as pool:
        for label in options["methods"]:
            for aligner in options["divisions"]:
                if options["methods"][label][aligner] is not None:
                    orig_stats, filtered_stats = options["methods"][label][
                        aligner][:2]
                else:
                    orig_stats, filtered_stats = None, None
                data[(label, aligner)] = pool.apply_async(
                    parse_refmaps,
                    (orig_stats, filtered_stats, args.transcripts))

        # Collect the results (in submission order) and print the table.
        for label in options["methods"]:
            for aligner in options["divisions"]:
                data[(label, aligner)] = data[(label, aligner)].get()
                print(label, aligner, *data[(label, aligner)], sep="\t")

    if args.out is None:
        sys.exit(0)

    if args.division is True:
        figsize = (14, 12)
    else:
        figsize = (6, 8)

    figure, axes = plt.subplots(nrows=1, ncols=3, dpi=300, figsize=figsize)
    figure.suptitle(args.title)

    # Horizontal space reserved for each division when --division is used.
    factor = 10

    for pos, ax in enumerate(axes):
        column = [data[key][pos] for key in data]
        max_x = max(column) + 500
        # default=0 keeps this from failing when no count is positive.
        min_x = max(
            0,
            min((value for value in column if value > 0), default=0) - 500)
        ax.set_ylim(min_x, max_x)
        if args.division is False:
            ax.set_xlim(0, 2.5)
        else:
            ax.set_xlim(-2, len(options["divisions"]) * factor)
        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(True)
        ax.spines["bottom"].set_visible(True)
        # Booleans, not "on"/"off": any non-empty string is truthy.
        ax.tick_params(axis="y", left=True, right=False)
        if args.division is False:
            ax.tick_params(axis="x", top=False, bottom=False)
        else:
            ax.tick_params(axis="x", top=False, bottom=True)
        if pos == 0:
            if args.transcripts is True:
                ax.set_ylabel("Number of transcripts", fontsize=16)
            else:
                ax.set_ylabel("Number of genes", fontsize=16)
            ax.set_xlabel("Reconstructed\ngenes", fontsize=16)
        elif pos == 1:
            ax.set_xlabel("Missed\ngenes", fontsize=16)
        else:
            ax.set_xlabel("Fused\ngenes", fontsize=16)

        if args.division is True:
            ax.set_xticks([
                factor * (_ + 1 / 4) for _ in range(len(options["divisions"]))
            ])
            ax.set_xticklabels(options["divisions"], rotation=45, fontsize=14)

        # Gather [value, label, colour, marker] for this panel, either as one
        # flat list or grouped by division.
        if args.division is False:
            points = []
        else:
            points = defaultdict(list)

        for tup in data:
            method, division = tup
            colour = method_colour(method)
            marker = options["divisions"][division]["marker"]
            cat = "{} ({})".format(method, division)
            point = data[tup][pos]
            if args.division is False:
                points.append([point, cat, colour, marker])
            else:
                points[division].append([point, cat, colour, marker])

        if args.division is False:
            points = sorted(points, key=itemgetter(0))
            for index, (point, cat, colour, marker) in enumerate(points):
                # Spread the sorted points over four X positions.
                x_coord = 0.5 + (index % 4 / 2)
                ax.scatter(x_coord,
                           point,
                           alpha=1,
                           label=cat,
                           color=colour,
                           marker=marker,
                           edgecolor="k",
                           s=150)
        else:
            max_last_xcoord = None
            # Offsets of the three X slots inside a division's band.
            offsets = (0, 1 / 4, 1 / 2)

            for dindex, division in enumerate(options["divisions"].keys()):
                max_xcoord = -100
                min_xcoord = 100000
                dpoints = sorted(points[division], key=itemgetter(0))
                for index, (point, cat, colour, marker) in enumerate(dpoints):
                    x_coord = factor * (dindex + offsets[index % 3])

                    max_xcoord = max(x_coord, max_xcoord)
                    min_xcoord = min(x_coord, min_xcoord)
                    ax.scatter(x_coord,
                               point,
                               alpha=1,
                               label=cat,
                               color=colour,
                               marker=marker,
                               edgecolor="k",
                               s=150)
                if max_last_xcoord is not None and min_xcoord < max_last_xcoord:
                    raise ValueError("Overlapping X values")
                max_last_xcoord = max_xcoord

    # Set the axis to log if necessary
    # NOTE(review): plt.xscale acts on the current axes only - confirm that
    # applying it to a single panel is intended.
    if args.log is True:
        plt.xscale("log")

    # NOTE(review): this second pass draws every value again at (value, 1) on
    # its panel.  It looks like a leftover from a horizontal layout; it is
    # kept so the rendered output does not change - confirm before removing.
    for tup in data:
        method, division = tup
        colour = method_colour(method)
        marker = options["divisions"][division]["marker"]
        cat = "{} ({})".format(method, division)
        for pos, point in enumerate(data[tup]):
            axes[pos].scatter(point,
                              1,
                              alpha=1,
                              label=cat,
                              color=colour,
                              marker=marker,
                              edgecolor="k",
                              s=150)

    # Legend: one marker entry per division, one colour patch per method.
    div_labels = []
    for division in options["divisions"]:
        faux_line = mlines.Line2D(
            [], [],
            color="white",
            marker=options["divisions"][division]["marker"],
            markersize=14,
            markerfacecolor="black")
        div_labels.append((faux_line, division))

    for method in options["methods"]:
        patch = mpatches.Patch(facecolor=method_colour(method),
                               linewidth=1,
                               edgecolor="k")
        div_labels.append((patch, method))

    if args.division is True:
        fs = 16
    else:
        fs = 10

    plt.figlegend(handles=[_[0] for _ in div_labels],
                  labels=[_[1] for _ in div_labels],
                  loc="lower center",
                  scatterpoints=1,
                  ncol=min(ceil(len(options["methods"]) * 2 / 4), 3),
                  fontsize=fs,
                  framealpha=0.5)
    plt.tight_layout(
        pad=0.5,
        h_pad=1,
        w_pad=1,
        rect=[
            0.05,  # Left
            0.15,  # Bottom
            0.95,  # Right
            0.95
        ])  # Top
    out = "{}.{}".format(os.path.splitext(args.out)[0], options["format"])
    plt.savefig(out, format=options["format"], transparent=True)
示例#12
0
def main():
    """Build per-aligner TPM tables from a quantification file and plot them.

    Every row of the tab-separated quantification file contributes one entry
    per aligner ("STAR" and "TopHat") holding ``log2(tpm + 1)``, the target
    ID, the aligner name and six integer counters (keys 1-6) initialised to
    zero.  Each entry table is then passed through ``analyse_refmap`` once
    per configured method, using the refmap derived from the method's first
    stats file.  Three data frames (STAR, TopHat, and the element-wise sum of
    both counters) are handed to ``generate_plot`` for the "full" and
    "missed" plots.
    """

    parser = argparse.ArgumentParser("Script to create a violin plot.")
    parser.add_argument("--quant_file",
                        "-q",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument("-c",
                        "--configuration",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument(
        "--out",
        nargs="?",
        type=str,
        default=None,
        help="Optional output file name. Default: None (show to stdout)")
    parser.add_argument("--format",
                        choices=["svg", "png", "pdf"],
                        default="svg")
    parser.add_argument("--title", type=str, default="")
    args = parser.parse_args()

    options = parse_configuration(args.configuration)

    aligners = ("STAR", "TopHat")
    counters = range(1, 7)
    per_aligner = {aligner: dict() for aligner in aligners}

    # Retrieve the TPM values; both aligners start from identical entries.
    for row in csv.DictReader(args.quant_file, delimiter="\t"):
        tid = row["target_id"]
        log_tpm = numpy.log2(float(row["tpm"]) + 1)
        for aligner in aligners:
            entry = {"TPM": log_tpm, "tid": tid, "Aligner": aligner}
            for num in counters:
                entry[num] = 0
            per_aligner[aligner][tid] = entry

    for aligner in aligners:
        for method in options["methods"]:
            # Anchor the pattern so only a trailing ".stats" is replaced,
            # as the sibling scripts do.
            input_file = "{}.refmap".format(
                re.sub(".stats$", "", options["methods"][method][aligner][0]))
            print(input_file)
            # Store the returned table back under the aligner so the result
            # is kept even if analyse_refmap returns a new dictionary.
            per_aligner[aligner] = analyse_refmap(input_file,
                                                  per_aligner[aligner])

    star = per_aligner["STAR"]
    tophat = per_aligner["TopHat"]

    star_data = []
    tophat_data = []
    merged_data = []
    for tid in star:
        star_data.append([tid] + [star[tid]["TPM"]] + ["STAR"] +
                         [star[tid][_] for _ in counters])
        tophat_data.append([tid] + [tophat[tid]["TPM"]] + ["Tophat"] +
                           [tophat[tid][_] for _ in counters])
        merged_data.append(
            [tid] + [tophat[tid]["TPM"]] + ["Either"] +
            [star[tid][_] + tophat[tid][_] for _ in counters])

    print("Generating the final data frames")
    columns = ["tid", "TPM", "Aligner"] + list(counters)
    star_data = pandas.DataFrame(star_data, columns=columns)
    tophat_data = pandas.DataFrame(tophat_data, columns=columns)
    merged_data = pandas.DataFrame(merged_data, columns=columns)

    generate_plot(star_data, tophat_data, merged_data, "full", args, options)
    generate_plot(star_data, tophat_data, merged_data, "missed", args, options)
    return
def main():
    """Merge per-method refmap annotations with TPM values and plot them.

    The quantification file provides one record per target ID with its TPM.
    For every (division, method) pair, ``analyse_refmap`` is called with the
    refmap derived from the method's first stats file and a
    ``"<method> (<division>)"`` label.  Records are then filtered: those that
    still have only the two initial keys, or whose TPM is zero, are dropped.
    The remaining records are turned into a data frame with columns
    ``tid``, ``TPM`` and one column per label, and passed to
    ``generate_plot``.

    Raises:
        KeyError: if a retained record does not have exactly one value for
            every label.
    """

    parser = argparse.ArgumentParser(
        "Script to create the merged ccode/TPM input for Gemy's modified script."
    )
    parser.add_argument("--quant_file",
                        "-q",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument("-c",
                        "--configuration",
                        required=True,
                        type=argparse.FileType("r"))
    parser.add_argument(
        "--out",
        nargs="?",
        type=str,
        default=None,
        help="Optional output file name. Default: None (show to stdout)")
    parser.add_argument("--title", type=str, default="")
    args = parser.parse_args()

    options = parse_configuration(args.configuration)
    values = dict()

    # Retrieve the TPM values
    for row in csv.DictReader(args.quant_file, delimiter="\t"):
        tid = row["target_id"]
        values[tid] = {"TPM": float(row["tpm"]), "tid": tid}

    labels = []
    for aligner in sorted(options["divisions"]):
        for method in options["methods"]:
            label = "{} ({})".format(method, aligner)
            labels.append(label)
            # Anchor the pattern so only a trailing ".stats" is replaced,
            # as the sibling scripts do.
            input_file = "{}.refmap".format(
                re.sub(".stats$", "", options["methods"][method][aligner][0]))
            values = analyse_refmap(input_file, label, values)

    # Iterate over a snapshot of the IDs: entries are deleted along the way.
    tids = set(values.keys())
    # Two initial keys ("TPM", "tid") plus one per label.
    right_total = len(labels) + 2
    count_zero = set()
    for tid in tids:
        if len(values[tid].keys()) == 2:
            del values[tid]
        elif values[tid]["TPM"] == 0:
            del values[tid]
            count_zero.add(tid)
        elif len(values[tid].keys()) != right_total:
            raise KeyError("ID {} has been found only in {}".format(
                tid, values[tid].keys()))
    tids = set.difference(tids, count_zero)
    if tids != set(values.keys()):
        print("Removed {} TIDs due to filtering".format(
            len(tids) - len(set(values.keys()))))

    # Pivot the per-ID records into per-column lists for the data frame.
    data = defaultdict(list)
    for tid in values:
        for key in values[tid]:
            data[key].append(values[tid][key])

    data = pandas.DataFrame(data, columns=["tid", "TPM"] + labels)
    generate_plot(data, args, options, nrows=len(options["divisions"]))

    return
示例#14
0
def main():
    """Draw, per species, stacked bars of how other methods classify genes.

    The ``--species`` YAML file must contain a ``names`` list plus one
    mapping per species with ``name``, ``configuration`` and ``folder``
    keys.  For each species, the original and filtered refmaps of every
    (method, division) pair are merged on ``ref_gene`` and each gene is
    assigned a category: 0 (fully reconstructed), 1 (detected), 2 (fused) or
    3 (missed).

    For every (method, division) pair, the genes whose category equals the
    one selected with ``--type`` are counted by the lowest category that any
    *other* method of the same division assigned to them.  The four counts
    are drawn as one stacked horizontal bar; one panel is drawn per species.
    The figure is saved to ``<out>.<format>``.
    """

    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--species", type=argparse.FileType("r"))
    parser.add_argument("--title", required=True)
    parser.add_argument("--log", action="store_true", default=False)
    parser.add_argument("--out", required=True)
    parser.add_argument(
        "--type",
        required=True,
        choices=["missed", "fused", "detected", "reconstructed"])
    parser.add_argument("--format",
                        default="svg",
                        choices=["png", "pdf", "ps", "eps", "svg"])
    parser.add_argument("--dpi", default=1000, type=int)
    parser.add_argument("--opaque", default=True, action="store_false")
    args = parser.parse_args()

    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary objects from untrusted input and is rejected by recent PyYAML
    # releases; consider yaml.safe_load if the species file is plain data.
    species = yaml.load(args.species)

    names = species.pop("names")

    # Map each species display name to its key in the YAML document.
    name_keys = dict()
    for name in species:
        if isinstance(species[name], dict) and species[name].get(
                "name", None) in names:
            name_keys[species[name]["name"]] = name

    ncols = len(names)

    # squeeze=False always yields a 2D array of axes, so indexing also works
    # when only one species is configured.
    figure, axes = plt.subplots(nrows=1,
                                ncols=ncols,
                                dpi=args.dpi,
                                figsize=(10 / 2 * ncols, 13),
                                squeeze=False)
    axes = axes[0]

    figure.suptitle(" ".join(["${}$".format(_) for _ in args.title.split()]),
                    fontsize=20,
                    style="italic",
                    family="serif")
    figure.text(0.5,
                0.15,
                r"${}\ genes\ by\ method,\ in\ other\ methods$".format(
                    args.type.title()),
                ha="center",
                fontsize=15)

    # Category code looked for in the method under examination.
    type_dict = {
        "missed": 3,
        "fused": 2,
        "detected": 1,
        "reconstructed": 0
    }
    to_search = type_dict[args.type]

    # Bar colours, indexed by category: full, detected, fused, missed.
    colours = ["darkorange", "darkcyan", "purple", "lightcoral"]

    for yrow, name in enumerate(names):
        plot = axes[yrow]
        plot.set_title("${}$".format(names[yrow]), fontsize=15)
        # data[gene][(method, division)] -> category code
        data = defaultdict(dict)
        with open(species[name_keys[name]]["configuration"]) as configuration:
            options = parse_configuration(
                configuration, prefix=species[name_keys[name]]["folder"])

        for label in options["methods"]:
            for aligner in options["divisions"]:
                print(label, aligner)
                if options["methods"][label][aligner] is not None:
                    orig_refmap = "{}.refmap".format(
                        re.sub(".stats$", "",
                               options["methods"][label][aligner][0]))
                    orig_refmap = pandas.read_csv(orig_refmap, delimiter="\t")
                    orig_refmap = orig_refmap[["ref_gene", "best_ccode"
                                               ]].drop_duplicates("ref_gene")
                    orig_refmap.columns = ["ref_gene", "orig_ccode"]
                    orig_refmap.orig_ccode.fillna("NA", inplace=True)
                    filtered_refmap = "{}.refmap".format(
                        re.sub(".stats$", "",
                               options["methods"][label][aligner][1]))
                    filtered_refmap = pandas.read_csv(filtered_refmap,
                                                      delimiter="\t")
                    filtered_refmap = filtered_refmap[[
                        "ref_gene", "best_ccode"
                    ]].drop_duplicates("ref_gene")
                    filtered_refmap.columns = ["ref_gene", "filtered_ccode"]
                    filtered_refmap.filtered_ccode.fillna("NA", inplace=True)
                    conc = pandas.merge(orig_refmap,
                                        filtered_refmap,
                                        how="outer")
                    for row in conc.itertuples():
                        if row.orig_ccode in ("=", "_"):
                            category = 0
                        elif row.orig_ccode[0] == "f":
                            category = 2
                        elif row.filtered_ccode is not np.nan and row.filtered_ccode in (
                                "NA", "p", "P", "i", "I", "ri", "rI", "X",
                                "x"):
                            category = 3
                        else:
                            category = 1

                        data[row.ref_gene][(label, aligner)] = category

        # Now we have to create the numbers
        counts = OrderedDict()

        # Rows are ordered by division, reversed so the first sorts on top.
        keys = list(
            reversed(
                sorted(itertools.product(options["methods"],
                                         options["divisions"]),
                       key=operator.itemgetter(1))))

        max_val = 0
        for rowno, key in enumerate(keys, 1):
            # Full, detected, fused, missed
            counts[key] = [0, 0, 0, 0]
            for gene in data:
                assert isinstance(data[gene], dict), (gene, type(data[gene]))
                assert key in data[gene], (gene, key, data[gene].keys())
                if data[gene][key] == to_search:
                    for pos in range(4):
                        # We have to segregate by division: only the other
                        # methods of the same division are considered, and
                        # the first (lowest) matching category wins.
                        if any(data[gene][_] == pos for _ in data[gene]
                               if _ != key and _[1] == key[1]):
                            counts[key][pos] += 1
                            break

            # Stack the four segments: each one starts where the previous
            # one ended.
            left = 0
            for num, old, colour in zip(counts[key], [0] + counts[key][:-1],
                                        colours):
                print(rowno, num, colour, left)
                left += old
                plot.barh(rowno,
                          num,
                          color=colour,
                          left=left,
                          height=0.5,
                          edgecolor="k")
            max_val = max(max_val, sum(counts[key]))

        plot.axes.set_xlim(0, max_val * 1.1)
        plot.axes.set_ylim(0, len(keys) + 0.5)

        plot.spines["top"].set_visible(False)
        plot.spines["right"].set_visible(False)

        if yrow == 0:
            # Only the first panel carries method names and the division
            # separators drawn outside the axes area.
            plot.axes.set_yticks(range(1, len(keys) + 1))
            plot.axes.set_yticklabels(["{}".format(_[0]) for _ in keys],
                                      fontsize=14)
            x, y = np.array([[-2000, -0.5], [0, 0]])
            line = lines.Line2D(x, y, lw=1., color='k', alpha=1)
            line.set_clip_on(False)
            plot.add_line(line)
            for div_index, div in enumerate(
                    reversed(sorted(options["divisions"])), 1):
                line_pos = len(options["methods"]) * div_index + 0.5
                text_pos = len(options["methods"]) * (div_index - 0.5) + 0.5
                x, y = np.array([[-2000, -0.5], [line_pos, line_pos]])
                line = lines.Line2D(x, y, lw=1., color='k', alpha=1)
                line.set_clip_on(False)
                plot.add_line(line)
                plot.text(-max_val * 0.75,
                          text_pos,
                          div,
                          ha="center",
                          fontsize=15,
                          rotation="vertical")
        else:
            plot.axes.set_yticks([])

    # Now the legend ...

    div_labels = []
    full = mpatches.Patch(facecolor=colours[0], linewidth=1, edgecolor="k")
    detected = mpatches.Patch(facecolor=colours[1], linewidth=1, edgecolor="k")
    fused = mpatches.Patch(facecolor=colours[2], linewidth=1, edgecolor="k")
    missed = mpatches.Patch(facecolor=colours[3], linewidth=1, edgecolor="k")

    div_labels.append((full, "Fully reconstructed by at least another method"))
    div_labels.append((detected, "Detected by at least another method"))
    div_labels.append((fused, "Fused or missed by any other method"))
    div_labels.append((missed, "Missed by all other methods"))

    plt.figlegend(handles=[_[0] for _ in div_labels],
                  labels=[_[1] for _ in div_labels],
                  loc="lower center",
                  scatterpoints=1,
                  ncol=2,
                  fontsize=20,
                  framealpha=0.5)

    plt.tight_layout(
        pad=0.5,
        h_pad=1,
        w_pad=1,
        rect=[
            0.1,  # Left
            0.2,  # Bottom
            0.85,  # Right
            0.9
        ])  # Top
    if args.out is None:
        plt.ion()
        plt.show(block=True)
    else:
        # --opaque stores False, so the figure is transparent by default.
        plt.savefig("{}.{}".format(args.out, args.format),
                    format=args.format,
                    dpi=args.dpi,
                    transparent=args.opaque)
示例#15
0
def main():
    """Print a precision/recall/F1 table for every species, aligner and method.

    Command-line arguments:
        --species: YAML file. It must contain a "names" list, a "categories"
            list and, for every (name, category) pair, exactly one mapping
            with "name", "category", "configuration" and "folder" entries.
        --out: destination of the rendered table (default: standard output).
        --format: any table format known to ``tabulate``.
        --level: key of ``line_correspondence`` selecting which line of the
            statistics files is read.

    Raises:
        AssertionError: if a (name, category) pair does not match exactly one
            entry of the species file.
        IndexError: if no entry can be found for a (category, name) cell.
        TypeError: if the harmonic mean of precision and recall cannot be
            computed.
    """

    parser = argparse.ArgumentParser(
        __doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--species",
                        type=argparse.FileType("r"),
                        required=True)
    parser.add_argument("--out",
                        type=argparse.FileType("w"),
                        default=sys.stdout)
    parser.add_argument("--format",
                        choices=tabulate._table_formats.keys(),
                        default="latex")
    parser.add_argument("--level",
                        default="transcript",
                        choices=line_correspondence.keys())
    args = parser.parse_args()

    # safe_load: yaml.load() without an explicit Loader raises TypeError on
    # PyYAML >= 6 and can build arbitrary Python objects on older releases.
    species = yaml.safe_load(args.species)

    # Names are the SPECIES
    names = species.pop("names")
    # Categories are Real vs Simulated data
    categories = species.pop("categories")

    # We want to create a table of the form

    # Species | Method   || Category              || Category              || ..
    # species | Name     | Prec | Rec | F1        ||Prec | Rec | F1        || ..

    # Collect the species-file keys in category-major order.
    name_ar = []
    for category in categories:
        for name in names:
            matches = [
                _ for _ in species
                if isinstance(species[_], dict) and species[_]["name"] == name
                and species[_]["category"] == category
            ]
            assert len(matches) == 1, (name, category, matches)
            name_ar.append(matches.pop())
    # One row per category, one column per name. The chunk size is the number
    # of names: deriving it from len(species) breaks as soon as the YAML file
    # holds any extra key or the number of categories is not two.
    name_ar = np.array(list(grouper(name_ar, len(names), None)))

    header = [""] + list(categories)

    # NOTE(review): this appends a *list* as one single header cell, so the
    # header row is nested. tabulate has no multi-row headers; the intended
    # layout should be confirmed before changing it.
    header.append(["Species", "Aligner", "Method"] +
                  ["Precision", "Recall", "F1"] * len(categories))

    rows = []

    key = None
    methods = None
    divisions = None

    for yrow, name in enumerate(names):

        # (method, division) -> {category: (precision, recall, f1)}
        new_rows = OrderedDict()

        for xrow, category in enumerate(categories):
            try:
                key = name_ar[xrow, yrow]
            except IndexError:
                raise IndexError(name_ar, xrow, yrow)
            if key is None:
                raise IndexError(name_ar, xrow, category, yrow, name)

            with open(species[key]["configuration"]) as configuration:
                options = parse_configuration(configuration,
                                              prefix=species[key]["folder"])
                # The first configuration read fixes the order of the
                # methods (assemblers) and divisions (aligners) in the table.
                if methods is None:
                    methods = list(options["methods"])
                    divisions = list(options["divisions"])

                for method in options["methods"]:
                    # Aligner
                    for division in options["divisions"]:
                        meth_key = (method, division)
                        if meth_key not in new_rows:
                            new_rows[meth_key] = OrderedDict()
                        try:
                            orig, filtered = options["methods"][method][
                                division]
                        except TypeError:
                            warnings.warn(
                                "Something went wrong for {}, {}; continuing".
                                format(method, division))

                            # Sentinel values marking a missing result.
                            new_rows[meth_key][category] = (-10, -10, -10)
                            continue

                        # Read both statistics files, closing them promptly.
                        with open(orig) as orig_handle:
                            orig_lines = [
                                line.rstrip() for line in orig_handle
                            ]
                        with open(filtered) as filtered_handle:
                            filtered_lines = [
                                line.rstrip() for line in filtered_handle
                            ]

                        line_index = line_correspondence[args.level]
                        # Precision is the second value after the colon in
                        # the original file, recall the first value after
                        # the colon in the filtered file.
                        precision = float(
                            orig_lines[line_index].split(":")[1].split()[1])
                        recall = float(
                            filtered_lines[line_index].split(":")[1]
                            .split()[0])
                        try:
                            f1 = hmean(np.array([precision, recall]))
                        except TypeError as exc:
                            raise TypeError("\n".join([
                                str(_) for _ in [(precision,
                                                  type(precision)),
                                                 (recall, type(recall)), exc]
                            ])) from exc
                        new_rows[meth_key][category] = (precision, recall,
                                                        f1)

        # Emit one table row per (division, method); the species name and the
        # division label are only written on the first row they apply to.
        begun = False
        for division in divisions:
            division_done = False
            for method in methods:
                meth_key = (method, division)
                if not begun:
                    row = [name]
                    begun = True
                else:
                    row = [""]

                if not division_done:
                    row.append(division)
                    division_done = True
                else:
                    row.append("")

                row.append(method)
                for stats in new_rows[meth_key].values():
                    row.extend(stats)

                rows.append(row)

    # Honour --out instead of always writing to standard output.
    print(tabulate.tabulate(rows, headers=header, tablefmt=args.format),
          file=args.out)
    return