def run_loss_map_discriminative_sum_with_prior():
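  """
  Plots the loss and gradient map of the loss with prior (L_hybrid)
  over the two parameters (theta_a, theta_b).
  """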
  # n<=3 is not peaky. n>=4 is peaky.
  n = 10
  plot_loss_grad_map(
    value_range=(-5., 5.), steps=11,
    loss_type="sum_with_prior", model_type="model_free",
    xlabel=r"$\theta_a$", ylabel=r"$\theta_b$",
    num_classes=1, target_seq=get_std_fsa_1label(), input_seq=[1] * n + [0] * 2 * n + [1] * n)


def run_loss_map_generative_sum():
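  """
  Plots the loss and gradient map of the loss with the generative model (L_generative)
  over the two parameters (theta_a, theta_b).
  """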
  # n<=3 is not peaky. n>=4 is peaky.
  n = 10
  plot_loss_grad_map(
    value_range=(-5., 5.), steps=21,
    loss_type="gen_sum", model_type="gen_model_free",
    xlabel=r"$\theta_a$", ylabel=r"$\theta_b$",
    num_classes=1, target_seq=get_std_fsa_1label(), input_seq=[1] * n + [0] * 2 * n + [1] * n)


def run_ffnn_sum_with_prior():
  """
  Section 7, loss with prior (L_hybrid) does not have peaky behavior.
  Simulation 7.5.
  """
  n = 4
  opts = dict(
    loss_type="sum_with_prior",
    model_type={"bias": False},
    num_frames=None,
    num_classes=1, target_seq=get_std_fsa_1label(),
    input_seq=[1] * n + [0] * 2 * n + [1] * n,
    init_type="zero")
  run(**opts)


def run_generative_sum():
  """
  Section 7, loss with generative model (L_generative) does not have peaky behavior.
  Simulation 7.13.

  Note that the output can be a bit confusing,
  as the Baum-Welch stdout has a wrong normalization
  (that is a different computation than what is being used for the loss),
  and also the softmax stdout does not make sense.
  """
  n = 4
  opts = dict(
    loss_type="gen_sum", model_type="gen_model_free",
    num_frames=None,
    num_classes=1, target_seq=get_std_fsa_1label(),
    input_seq=[1] * n + [0] * 2 * n + [1] * n,
    init_type="zero")
  run(**opts)


def test_count_all_paths_with_label_seq_partly_dominated_inefficient():
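  """
  Checks count_all_paths_with_label_seq_partly_dominated_inefficient from fst_utils
  against the Baum-Welch soft alignments computed via get_loss_grad_single.
  """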
  fsa = get_std_fsa_1label()  # same as target_seq=[0]
  n = 4
  opts = dict(
    loss_type="sum", model_type="model_free",
    num_classes=1, target_seq=fsa, input_seq=[1] * n + [0] * 2 * n + [1] * n)
  alpha = 0.5
  prob_dom = numpy.exp(alpha) / (numpy.exp(alpha) + numpy.exp(-alpha))
  print("prob dom:", prob_dom)
  _, grad0, bw0 = get_loss_grad_single(param_value=[0., 0.], **opts)
  _, grad1, bw1 = get_loss_grad_single(param_value=[-alpha, 0.], **opts)
  _, grad2, bw2 = get_loss_grad_single(param_value=[0., alpha], **opts)
  # print(bw1)
  # print(bw2)
  # print(grad1)
  # print(grad2)
  from fst_utils import count_all_paths_with_label_in_frame
  from fst_utils import count_all_paths_with_label_seq_partly_dominated_inefficient
  from fst_utils import Label1StrTemplate, BlankLabel, Label1
  num_frames = n * len(Label1StrTemplate)
  assert bw0.shape == bw1.shape == bw2.shape == (num_frames, 2)

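  # Symbolic per-frame path counts through the FSA:
  # c_a_sym counts the paths with Label1 in frame t, c_b_sym those with BlankLabel.
  # At the uniform parameter point, the Baum-Welch soft alignment in each frame
  # must equal the normalized path counts.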
  num_frames_sym, t_sym, c_a_sym = count_all_paths_with_label_in_frame(fsa=fsa, label=Label1)
  num_frames_sym_, t_sym_, c_b_sym = count_all_paths_with_label_in_frame(fsa=fsa, label=BlankLabel)
  for t in range(num_frames):
    c_a = int(c_a_sym.subs(num_frames_sym, num_frames).subs(t_sym, t).doit())
    c_b = int(c_b_sym.subs(num_frames_sym_, num_frames).subs(t_sym_, t).doit())
    z = c_a + c_b
    soft = [float(c_a) / z, float(c_b) / z]
    numpy.testing.assert_allclose(bw0[t], soft, rtol=1e-5)

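  # Unnormalized case (prob_dom=0.5): for each input label, accumulate bw0 and the
  # symbolic counts over the frames whose input label differs, and compare against
  # the raw counts returned by the counting function.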
  res_ = count_all_paths_with_label_seq_partly_dominated_inefficient(
    fsa=fsa, label_seq_template=Label1StrTemplate, dom_label=BlankLabel, n=n,
    prob_dom=0.5, normalized=False, verbosity=1)
  for input_label in [Label1, BlankLabel]:
    c = 0
    res_by_label = numpy.zeros([2])
    c_a = c_b = 0
    for i, input_label_ in enumerate(Label1StrTemplate):
      if input_label_ != input_label:
        c += n
        for j in range(i * n, i * n + n):
          res_by_label += bw0[j]
          c_a += int(c_a_sym.subs(num_frames_sym, num_frames).subs(t_sym, j).doit())
          c_b += int(c_b_sym.subs(num_frames_sym_, num_frames).subs(t_sym_, j).doit())
    res_by_label /= c

    res_by_label_ = res_[(input_label, {Label1: BlankLabel, BlankLabel: Label1}[input_label])]
    assert c_a == res_by_label_[Label1] and c_b == res_by_label_[BlankLabel]
    res_by_label_ = numpy.array([res_by_label_[Label1], res_by_label_[BlankLabel]], dtype="float32")
    res_by_label_ /= sum(res_by_label_)
    numpy.testing.assert_allclose(res_by_label, res_by_label_)

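  # Dominated case: average the alpha-shifted Baum-Welch soft alignments
  # (bw1 for Label1, bw2 for BlankLabel) over the same frames and compare against
  # the counting function with prob_dom.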
  bws = {Label1: bw1, BlankLabel: bw2}
  res = {}
  for input_label in [Label1, BlankLabel]:
    c = 0
    res_by_label = {Label1: 0.0, BlankLabel: 0.0}
    for i, input_label_ in enumerate(Label1StrTemplate):
      if input_label_ != input_label:
        c += n
        for j in range(i * n, i * n + n):
          res_by_label[Label1] += bws[input_label][j][0]
          res_by_label[BlankLabel] += bws[input_label][j][1]
    res_by_label = {k: v / c for (k, v) in res_by_label.items()}
    res[(input_label, {Label1: BlankLabel, BlankLabel: Label1}[input_label])] = res_by_label
  print(res)
  res_ = count_all_paths_with_label_seq_partly_dominated_inefficient(
    fsa=fsa, label_seq_template=Label1StrTemplate, dom_label=BlankLabel, n=n, prob_dom=prob_dom)
  print(res_)
  assert set(res.keys()) == set(res_.keys())
  for key, res_by_label_ in res_.items():
    res_by_label = res[key]
    assert set(res_by_label.keys()) == set(res_by_label_.keys()) == {Label1, BlankLabel}
    for label in [Label1, BlankLabel]:
      numpy.testing.assert_allclose(res_by_label[label], res_by_label_[label])