def plt_divergence(p_hist, J_hist, x_train, y_train):
    """Plot a gradient-descent run against the cost curve and cost surface.

    Builds one figure with two panels:

    * left: cost as a function of w (b fixed at 100) with the recorded
      (w, cost) path drawn on top;
    * right: the cost surface over a (w, b) grid with the recorded
      (w, b, cost) path drawn on top.

    Args:
        p_hist: Sequence of parameter records; element 0 of each record is
            plotted on the w axis and element 1 on the b axis.
        J_hist: Sequence of cost values, read at the same indices as p_hist.
        x_train: Training inputs, passed unchanged to compute_cost.
        y_train: Training targets, passed unchanged to compute_cost.

    Returns:
        None. The figure is created but not shown here.
    """
    # Unpack the recorded path; only the first len(p_hist) costs are read.
    n_steps = len(p_hist)
    w_path = np.array([float(p[0]) for p in p_hist])
    b_path = np.array([float(p[1]) for p in p_hist])
    cost_path = np.array([float(J_hist[i]) for i in range(n_steps)])

    fig = plt.figure(figsize=(12, 5))
    fig.subplots_adjust(wspace=0)
    gs = fig.add_gridspec(1, 5)
    fig.suptitle("Cost escalates when learning rate is too large")

    # ===============
    #  First subplot: cost vs w with b held fixed
    # ===============
    ax = fig.add_subplot(gs[:2])

    fix_b = 100
    w_array = np.arange(-70000, 70000, 1000)
    # w_array is an integer grid, so the cost buffer must be built as float
    # explicitly; np.zeros_like(w_array) would truncate every cost to an int.
    cost = np.array(
        [compute_cost(x_train, y_train, tmp_w, fix_b) for tmp_w in w_array],
        dtype=float,
    )

    ax.plot(w_array, cost)
    ax.plot(w_path, cost_path, c=dlmagenta)
    ax.set_title("Cost vs w, b set to 100")
    ax.set_ylabel('Cost')
    ax.set_xlabel('w')
    ax.xaxis.set_major_locator(MaxNLocator(2))

    # ===============
    # Second subplot: cost surface over (w, b)
    # ===============
    tmp_b, tmp_w = np.meshgrid(np.arange(-35000, 35000, 500),
                               np.arange(-70000, 70000, 500))
    # Same integer-grid caveat as above: allocate the surface as float.
    z = np.zeros(tmp_b.shape, dtype=float)
    for idx in np.ndindex(tmp_w.shape):
        z[idx] = compute_cost(x_train, y_train, tmp_w[idx], tmp_b[idx])

    ax = fig.add_subplot(gs[2:], projection='3d')
    ax.plot_surface(tmp_w, tmp_b, z, alpha=0.3, color=dlblue)
    ax.xaxis.set_major_locator(MaxNLocator(2))
    ax.yaxis.set_major_locator(MaxNLocator(2))

    ax.set_xlabel('w', fontsize=16)
    ax.set_ylabel('b', fontsize=16)
    ax.set_zlabel('\ncost', fontsize=16)
    # Title the 3D axes explicitly instead of relying on pyplot's current axes.
    ax.set_title('Cost vs (b, w)')
    # Customize the view angle
    ax.view_init(elev=20., azim=-65)
    ax.plot(w_path, b_path, cost_path, c=dlmagenta)
def plt_contour_wgrad(x,
                      y,
                      hist,
                      ax,
                      w_range=[-100, 500, 5],
                      b_range=[-500, 500, 5],
                      contours=[0.1, 50, 1000, 5000, 10000, 25000, 50000],
                      resolution=5,
                      w_final=200,
                      b_final=100,
                      step=10):
    """Draw cost contours over (w, b) on ``ax`` and overlay a parameter path.

    Args:
        x: Training inputs, passed unchanged to compute_cost.
        y: Training targets, passed unchanged to compute_cost.
        hist: Sequence of parameter records; element 0 of each record is
            used as the w coordinate and element 1 as the b coordinate.
        ax: Matplotlib axes that receives the contour plot, guide lines and
            arrows.
        w_range: ``(start, stop, step)`` arguments for ``np.arange`` on the
            w axis.
        b_range: ``(start, stop, step)`` arguments for ``np.arange`` on the
            b axis.
        contours: Contour levels handed to ``ax.contour``.
        resolution: Minimum distance between the previous arrow tip and a
            sampled point before a new arrow is drawn.
        w_final: w coordinate marked with dotted guide lines.
        b_final: b coordinate marked with dotted guide lines.
        step: Stride used when sampling ``hist``.

    Returns:
        None.

    Note:
        The list defaults are only read, never mutated, so sharing them
        between calls is harmless.
    """
    b0, w0 = np.meshgrid(np.arange(*b_range), np.arange(*w_range))
    # The default ranges are integer grids; allocate the cost grid as float
    # so that costs are not truncated (levels such as 0.1 need fractions).
    z = np.zeros(b0.shape, dtype=float)
    for idx in np.ndindex(w0.shape):
        z[idx] = compute_cost(x, y, w0[idx], b0[idx])

    CS = ax.contour(w0,
                    b0,
                    z,
                    contours,
                    linewidths=2,
                    colors=[dlblue, dlorange, dldarkred, dlmagenta, dlpurple])
    ax.clabel(CS, inline=True, fmt='%1.0f', fontsize=10)
    ax.set_xlabel("w")
    ax.set_ylabel("b")
    ax.set_title(
        'Contour plot of cost J(w,b), vs b,w with path of gradient descent')

    # Dotted guide lines from the axes edges to the marked (w_final, b_final).
    ax.hlines(b_final, ax.get_xlim()[0], w_final,
              lw=2, color=dlpurple, ls='dotted')
    ax.vlines(w_final, ax.get_ylim()[0], b_final,
              lw=2, color=dlpurple, ls='dotted')

    # Nothing to trace when no history was recorded.
    if len(hist) == 0:
        return

    last = hist[-1]
    base = hist[0]
    for point in hist[0::step]:
        edist = np.sqrt((base[0] - point[0])**2 + (base[1] - point[1])**2)
        # np.array_equal gives a single bool for lists and arrays alike,
        # whereas ``point == last`` is ambiguous for numpy records.
        if edist > resolution or np.array_equal(point, last):
            if inbounds(point, base, ax.get_xlim(), ax.get_ylim()):
                # Annotate the axes that was passed in, not whichever axes
                # pyplot currently considers active.
                ax.annotate('',
                            xy=point,
                            xytext=base,
                            xycoords='data',
                            arrowprops={
                                'arrowstyle': '->',
                                'color': 'r',
                                'lw': 3
                            },
                            va='center',
                            ha='center')
            # The arrow origin advances even when the arrow was out of bounds.
            base = point
    return
def plt_intuition(x_train, y_train):
    """Set up an interactive view of cost versus w with b fixed at 100.

    Pre-computes the cost curve for w in ``np.arange(0, 400, 5)`` and
    registers ``func`` with ``interact``; each call of ``func`` draws a
    two-panel figure: the model fit with cost lines on the left and the cost
    curve with the currently selected w highlighted on the right.

    Args:
        x_train: Training inputs, passed to compute_cost and the plot helpers.
        y_train: Training targets, passed to compute_cost and the plot helpers.
    """
    w_range = np.array([200 - 200, 200 + 200])
    tmp_b = 100

    w_array = np.arange(*w_range, 5)
    # w_array is an integer grid; build the cost curve as float explicitly,
    # since np.zeros_like(w_array) would truncate every cost to an integer.
    cost = np.array(
        [compute_cost(x_train, y_train, tmp_w, tmp_b) for tmp_w in w_array],
        dtype=float,
    )

    # NOTE(review): continuous_update is passed to interact as a keyword, but
    # func has no parameter of that name — confirm it has the intended effect.
    @interact(w=(*w_range, 10), continuous_update=False)
    def func(w=150):
        """Draw the model fit and the cost curve for the selected w."""
        f_wb = np.dot(x_train, w) + tmp_b

        fig, ax = plt.subplots(1, 2, constrained_layout=True, figsize=(8, 4))
        fig.canvas.toolbar_position = 'bottom'

        # Left panel: data, prediction line and per-point cost lines.
        mk_cost_lines(x_train, y_train, w, tmp_b, ax[0])
        plt_house_x(x_train, y_train, f_wb=f_wb, ax=ax[0])

        # Right panel: pre-computed cost curve plus a marker at the current w.
        ax[1].plot(w_array, cost)
        cur_cost = compute_cost(x_train, y_train, w, tmp_b)
        ax[1].scatter(w,
                      cur_cost,
                      s=100,
                      color=dldarkred,
                      zorder=10,
                      label=f"cost at w={w}")
        # Dotted guides from the axes edges to the marker.
        ax[1].hlines(cur_cost,
                     ax[1].get_xlim()[0],
                     w,
                     lw=4,
                     color=dlpurple,
                     ls='dotted')
        ax[1].vlines(w,
                     ax[1].get_ylim()[0],
                     cur_cost,
                     lw=4,
                     color=dlpurple,
                     ls='dotted')
        ax[1].set_title("Cost vs. w, (b fixed at 100)")
        ax[1].set_ylabel('Cost')
        ax[1].set_xlabel('w')
        ax[1].legend(loc='upper center')
        fig.suptitle(f"Minimize Cost: Current Cost = {cur_cost:0.0f}",
                     fontsize=12)
        plt.show()
    def __call__(self, event):
        """Redraw the panels for the (w, b) position carried by ``event``.

        Events whose ``inaxes`` is not ``self.ax[1]`` are ignored. Otherwise
        the event's data coordinates are taken as (w, b): ``self.ax[0]`` is
        cleared and redrawn, the artists in ``self.dyn_items`` are removed,
        and fresh markers are added to ``self.ax[1]`` and ``self.ax[2]``.
        Presumably registered as a matplotlib click callback — confirm
        against the code that connects it.
        """
        if not (event.inaxes == self.ax[1]):
            return

        w_sel, b_sel = event.xdata, event.ydata
        cost_sel = compute_cost(self.x_train, self.y_train, w_sel, b_sel)
        model_ax, contour_ax, surface_ax = self.ax[0], self.ax[1], self.ax[2]

        # Rebuild the model panel from scratch for the selected parameters.
        model_ax.clear()
        prediction = np.dot(self.x_train, w_sel) + b_sel
        mk_cost_lines(self.x_train, self.y_train, w_sel, b_sel, model_ax)
        plt_house_x(self.x_train, self.y_train, f_wb=prediction, ax=model_ax)

        # Drop the markers left by the previous selection.
        for stale in self.dyn_items:
            stale.remove()

        guide_style = {'lw': 4, 'color': dlpurple, 'ls': 'dotted'}
        marker = contour_ax.scatter(w_sel,
                                    b_sel,
                                    s=100,
                                    color=dlblue,
                                    zorder=10,
                                    label="cost with \ncurrent w,b")
        h_guide = contour_ax.hlines(b_sel,
                                    contour_ax.get_xlim()[0],
                                    w_sel,
                                    **guide_style)
        v_guide = contour_ax.vlines(w_sel,
                                    contour_ax.get_ylim()[0],
                                    b_sel,
                                    **guide_style)
        cost_note = contour_ax.annotate(f"Cost: {cost_sel:.0f}",
                                        xy=(w_sel, b_sel),
                                        xytext=(4, 4),
                                        textcoords='offset points',
                                        bbox=dict(facecolor='white'),
                                        size=10)

        # Matching point on the 3D surface panel.
        surface_pt = surface_ax.scatter3D(w_sel,
                                          b_sel,
                                          cost_sel,
                                          marker='X',
                                          s=100)

        # Remember the new artists so the next call can remove them.
        self.dyn_items = [marker, h_guide, v_guide, cost_note, surface_pt]
        self.fig.canvas.draw()
    def func(w=150):
        """Draw the model fit and the cost curve for the given w.

        Reads ``x_train``, ``y_train``, ``tmp_b``, ``w_array`` and ``cost``
        from the enclosing scope. Creates a new two-panel figure and shows it.
        """
        prediction = np.dot(x_train, w) + tmp_b

        fig, (model_ax, cost_ax) = plt.subplots(1,
                                                2,
                                                constrained_layout=True,
                                                figsize=(8, 4))
        fig.canvas.toolbar_position = 'bottom'

        # Left panel: data, prediction line and per-point cost lines.
        mk_cost_lines(x_train, y_train, w, tmp_b, model_ax)
        plt_house_x(x_train, y_train, f_wb=prediction, ax=model_ax)

        # Right panel: cost curve with the current w highlighted.
        cost_ax.plot(w_array, cost)
        cur_cost = compute_cost(x_train, y_train, w, tmp_b)
        cost_ax.scatter(w,
                        cur_cost,
                        s=100,
                        color=dldarkred,
                        zorder=10,
                        label=f"cost at w={w}")

        # Dotted guides from the axes edges to the highlighted point.
        guide_style = {'lw': 4, 'color': dlpurple, 'ls': 'dotted'}
        cost_ax.hlines(cur_cost, cost_ax.get_xlim()[0], w, **guide_style)
        cost_ax.vlines(w, cost_ax.get_ylim()[0], cur_cost, **guide_style)

        cost_ax.set_title("Cost vs. w, (b fixed at 100)")
        cost_ax.set_ylabel('Cost')
        cost_ax.set_xlabel('w')
        cost_ax.legend(loc='upper center')
        fig.suptitle(f"Minimize Cost: Current Cost = {cur_cost:0.0f}",
                     fontsize=12)
        plt.show()
def plt_stationary(x_train, y_train):
    """Build the static three-panel cost figure for an initial (w, b).

    Panels: the model fit with cost lines (top left), a contour plot of
    log-cost over (w, b) with the initial point marked (top right), and a
    3D surface of the cost over the same grid (bottom).

    Args:
        x_train: Training inputs, passed to compute_cost and the plot helpers.
        y_train: Training targets, passed to compute_cost and the plot helpers.

    Returns:
        Tuple ``(fig, ax, artists)`` where ``ax`` is a numpy array holding the
        three axes and ``artists`` is ``[cscat, chline, cvline]``: the scatter
        marker and the two dotted guide lines drawn on the contour panel.
    """
    # setup figure
    fig = plt.figure(figsize=(9, 8))
    fig.set_facecolor('#ffffff')  # white
    fig.canvas.toolbar_position = 'top'
    gs = GridSpec(2, 2, figure=fig)
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1])
    ax2 = fig.add_subplot(gs[1, :], projection='3d')
    ax = np.array([ax0, ax1, ax2])

    # setup useful ranges and common linspaces
    w_range = np.array([200 - 300., 200 + 300])
    b_range = np.array([50 - 300., 50 + 300])
    b_space = np.linspace(*b_range, 100)
    w_space = np.linspace(*w_range, 100)

    # get cost for w,b ranges for contour and 3D
    tmp_b, tmp_w = np.meshgrid(b_space, w_space)
    z = np.zeros_like(tmp_b)
    for idx in np.ndindex(tmp_w.shape):
        z[idx] = compute_cost(x_train, y_train, tmp_w[idx], tmp_b[idx])
    # Replace exact zeros so that np.log(z) below stays finite.
    z[z == 0] = 1e-6

    w0 = 200
    b = -100  # initial point

    ### plot model w cost ###
    f_wb = np.dot(x_train, w0) + b
    mk_cost_lines(x_train, y_train, w0, b, ax[0])
    plt_house_x(x_train, y_train, f_wb=f_wb, ax=ax[0])

    ### plot contour ###
    # Contours are drawn on log(cost); the surface below uses the raw cost.
    CS = ax[1].contour(tmp_w,
                       tmp_b,
                       np.log(z),
                       levels=12,
                       linewidths=2,
                       alpha=0.7,
                       colors=dlcolors)
    ax[1].set_title('Cost(w,b)')
    ax[1].set_xlabel('w', fontsize=10)
    ax[1].set_ylabel('b', fontsize=10)
    ax[1].set_xlim(w_range)
    ax[1].set_ylim(b_range)
    cscat = ax[1].scatter(w0,
                          b,
                          s=100,
                          color=dlblue,
                          zorder=10,
                          label="cost with \ncurrent w,b")
    chline = ax[1].hlines(b,
                          ax[1].get_xlim()[0],
                          w0,
                          lw=4,
                          color=dlpurple,
                          ls='dotted')
    cvline = ax[1].vlines(w0,
                          ax[1].get_ylim()[0],
                          b,
                          lw=4,
                          color=dlpurple,
                          ls='dotted')
    ax[1].text(0.5,
               0.95,
               "Click to choose w,b",
               bbox=dict(facecolor='white', ec='black'),
               fontsize=10,
               transform=ax[1].transAxes,
               verticalalignment='center',
               horizontalalignment='center')

    # Surface plot of the cost function J(w,b)
    ax[2].plot_surface(tmp_w, tmp_b, z, cmap=dlcm, alpha=0.3, antialiased=True)
    ax[2].plot_wireframe(tmp_w, tmp_b, z, color='k', alpha=0.1)
    # Label the 3D axes explicitly, consistent with the other panels, rather
    # than relying on pyplot's notion of the current axes.
    ax[2].set_xlabel("$w$")
    ax[2].set_ylabel("$b$")
    ax[2].zaxis.set_rotate_label(False)
    ax[2].xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax[2].yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax[2].zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax[2].set_zlabel("J(w, b)\n\n", rotation=90)
    ax[2].set_title("Cost(w,b) \n [You can rotate this figure]", size=12)
    ax[2].view_init(30, -120)

    return fig, ax, [cscat, chline, cvline]