def plot_sudden_markov_multinutr_environments(input_fname, plot_fname, label):
    """
    Plot sudden multinutrient environments.

    For each of a fixed set of 3x3 Markov transition matrices over the
    nutrients Glu/Gal/Mal, draw the matrix as an annotated grayscale heatmap
    (left column) next to one nutrient trajectory sampled from that matrix
    (right two columns). The figure is saved to ``plot_fname``.

    Args:
      input_fname: not used by this function; kept so the signature matches
        the other plotting functions (NOTE(review): presumably required by
        the caller's dispatch -- confirm).
      plot_fname: output path passed to ``plt.savefig``.
      label: not used by this function; kept for signature compatibility.
    """
    print("plotting: %s" % os.path.basename(plot_fname))
    # Fixed seed so the sampled trajectories are the same on every run.
    np.random.seed(2)
    plt.figure(figsize=(4.5, 4.05))
    # Order of nutrients along the rows/columns of every transition matrix;
    # state index i of a sampled chain corresponds to nutr_order[i].
    nutr_order = ["Glu", "Gal", "Mal"]
    # One transition matrix per panel row (every row sums to 1).
    trans_params = [np.array([[0.9, 0.05, 0.05],
                              [0.05, 0.9, 0.05],
                              [0.05, 0.05, 0.9]]),
                    np.array([[0.05, 0.9, 0.05],
                              [0.9, 0.05, 0.05],
                              [0.05, 0.05, 0.9]]),
                    np.array([[0.05, 0.9, 0.05],
                              [0.1, 0.6, 0.3],
                              [0.05, 0.05, 0.9]])]
    # Initial state probabilities: always start in Glu (index 0).
    init_probs = np.array([1., 0., 0.])
    time_obj = time_unit.Time(0, 100, step_size=1)
    num_points = len(time_obj.t)
    num_plots = len(trans_params)
    # Map sampled state indices to nutrient names, derived from nutr_order so
    # the two can never disagree: {0: "Glu", 1: "Gal", 2: "Mal"}.
    data_to_labels = dict(enumerate(nutr_order))
    set1 = sns.color_palette("Set1")
    labels_to_colors = {"Glu": set1[2],
                        "Gal": set1[1],
                        "Mal": plot_utils.orange}
    # Style is set after the figure but before any axes are created, so it
    # applies to all panels below.
    sns.set_style("ticks")
    # Each panel row spans 3 grid rows: the heatmap fills all 3 in column 0,
    # the trajectory occupies only the middle one across columns 1-2.
    rows_per_plot = 3
    gs = gridspec.GridSpec(num_plots * rows_per_plot, 3)
    gs.update(hspace=0.6)
    last_plot = num_plots - 1
    for n, trans_mat in enumerate(trans_params):
        row = n * rows_per_plot
        trans_df = pandas.DataFrame(trans_mat,
                                    index=nutr_order,
                                    columns=nutr_order)
        ax1 = plt.subplot(gs[row:row + rows_per_plot, 0])
        sns.heatmap(trans_df, annot=True, cbar=False,
                    cmap=plt.cm.gray_r,
                    linewidth=0.1,
                    linecolor="k",
                    annot_kws={"fontsize": 8},
                    ax=ax1)
        ax1.set_aspect("equal")
        ax1.tick_params(axis="both", which="major", pad=0.01, length=2.5,
                        labelsize=8)
        # Only the bottom heatmap keeps its axes (nutrient tick labels).
        if n != last_plot:
            ax1.axis("off")
        ax2 = plt.subplot(gs[row + 1:row + 2, 1:3])
        # Sample one trajectory from the current transition matrix.
        samples = prob_utils.sample_markov_chain(num_points,
                                                 init_probs,
                                                 trans_mat)
        plot_utils.plot_sudden_switches(time_obj, samples,
                                        data_to_labels=data_to_labels,
                                        labels_to_colors=labels_to_colors,
                                        ax=ax2,
                                        pad=0.01,
                                        despine=True)
        # Only the bottom trajectory gets an x-axis label.
        if n == last_plot:
            ax2.set_xlabel("Time step", fontsize=8)
    plt.savefig(plot_fname)
def plot_sudden_ssm_filtering(input_fname, plot_fname, label):
    """
    Plot SSM filtering predictions.

    Loads results with ``simulation.load_data(input_fname)`` and, for every
    entry of its ``"model"`` mapping, draws two stacked panels:

    * the lagged prediction ``P(C_{t+1} = Glu | C_{0:t})``, taken as
      ``p[0]`` of each element of ``data_set["preds_with_lag"]``;
    * beneath it, the observed Glu/Gal sequence ``params["data"]``.

    The figure is saved to ``plot_fname``.

    Args:
      input_fname: file read by ``simulation.load_data``. Each value of its
        ``"model"`` mapping must provide ``"params"`` (with ``t_start``,
        ``t_end``, ``step_size``, ``decision_lag_time`` and ``data``) and
        ``"preds_with_lag"``.
      plot_fname: output path passed to ``plt.savefig``.
      label: not used by this function; kept for signature compatibility.

    Raises:
      ValueError: if the ``"model"`` mapping is empty (nothing to plot).
    """
    print("plotting: %s" % os.path.basename(plot_fname))
    fig = plt.figure(figsize=(7, 5))
    sns.set_style("ticks")
    all_results = simulation.load_data(input_fname)["model"]
    num_plots = len(all_results)
    if num_plots == 0:
        # Previously this failed obscurely (0-row GridSpec, or `ax1` unbound
        # after the loop); fail with a clear message instead.
        raise ValueError("no model results to plot in %s" % input_fname)
    # Two grid rows per model: a tall prediction panel and a short data strip.
    gs = gridspec.GridSpec(num_plots * 2, 1,
                           height_ratios=[1, 0.2] * num_plots)
    # Observed data is coded 0 -> Glu, 1 -> Gal; identical for every model.
    data_to_labels = {0: "Glu", 1: "Gal"}
    labels_to_colors = {"Glu": plot_utils.green, "Gal": plot_utils.blue}
    # Padding added beyond the first/last time step on the x range.
    x_pad = 0.8
    for n, data_set in enumerate(all_results.values()):
        row = 2 * n
        params = data_set["params"]
        time_obj = time_unit.Time(params["t_start"],
                                  params["t_end"],
                                  step_size=params["step_size"])
        # Label every 4th time step.
        x_axis = time_obj.t[0::4]
        xlims = [time_obj.t[0] - x_pad, time_obj.t[-1] + x_pad]
        # --- Prediction panel ---
        ax1 = plt.subplot(gs[row, 0])
        pred_probs = [p[0] for p in data_set["preds_with_lag"]]
        plt.plot(time_obj.t, pred_probs, "-o", color=plot_utils.red,
                 label="Prediction",
                 clip_on=False,
                 zorder=100)
        plt.xlabel(r"Time step")
        plt.ylabel(r"$P(C_{t+1} =\ \mathsf{Glu} \mid  C_{0:t})$",
                   fontsize=11)
        plt.title("lag = %d" %(params["decision_lag_time"]), fontsize=8)
        plt.legend(loc="lower right")
        plt.xticks(x_axis, fontsize=8)
        plt.xlim(xlims)
        plt.ylim([0, 1])
        plt.yticks(np.arange(0, 1 + 0.2, 0.2))
        # --- Observed-data strip (no y axis) ---
        ax2 = plt.subplot(gs[row + 1, 0])
        ax2.get_yaxis().set_visible(False)
        ax2.set_yticks([])
        ax2.spines["left"].set_visible(False)
        plot_utils.plot_sudden_switches(time_obj, params["data"],
                                        data_to_labels=data_to_labels,
                                        labels_to_colors=labels_to_colors,
                                        box_height=0.025,
                                        y_val=0.02,
                                        ax=ax2,
                                        despine=False,
                                        with_legend=True,
                                        legend_outside=(0,1))
        # NOTE(review): relies on ax2 still being the current axes after
        # plot_sudden_switches returns.
        plt.xticks(x_axis, fontsize=8)
        plt.xlim(xlims)
        sns.despine(trim=True, left=True, ax=ax2)
        sns.despine(trim=True, ax=ax1)
    # Re-asserts the left spine on the last prediction panel only.
    ax1.spines["left"].set_visible(True)
    fig.set_tight_layout(True)
    plt.savefig(plot_fname)