# imports assumed by all of the snippets below (the original source files define
# them; they are reproduced here so each example snippet is runnable on its own)
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import animation
from torch.utils.data import DataLoader


def main():
    flag_plot_final_result = True
    flag_write_log = True

    device = torch.device('cpu')
    torch.manual_seed(9999)
    # a and b are used to generate the training dataset
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    xaxisoffset = c  # offset of the center of the ellipse from the origin in this case
    nsamples = 512
    batch_size = 64
    epoch = 100

    # initWa = torch.rand([], device=device, requires_grad=True)
    # initWb = torch.rand([], device=device, requires_grad=True)
    # to initialize Wa and Wb with specific starting values
    initWa = torch.tensor(0.10, device=device, requires_grad=True)
    initWb = torch.tensor(1.8, device=device, requires_grad=True)
    thenet = OrbitRegressionNet(xaxisoffset, initWa, initWb)
    loss_fn = nn.MSELoss(reduction='mean')

    optim_algo = 'Adam'  # 'SGD', 'SGD_Momentum', 'RMSprop' or 'Adam', case sensitive
    if optim_algo == 'SGD':
        # ----- hyper-parameters for the plain SGD optimizer -------------
        init_learning_rate = 0.01
        optimizer = optim.SGD(thenet.parameters(),
                              lr=init_learning_rate,
                              momentum=0.0)  # plain SGD: no momentum
    elif optim_algo == 'SGD_Momentum':
        # ----- hyper-parameters for the SGD_with_Momentum optimizer -----
        init_learning_rate = 0.01
        momentum = 0.9
        dampening = 0.0
        optimizer = optim.SGD(thenet.parameters(),
                              lr=init_learning_rate,
                              momentum=momentum,
                              dampening=dampening)
    elif optim_algo == 'RMSprop':
        # ----- hyper-parameters for the RMSprop optimizer ---------------
        init_learning_rate = 0.02
        alpha = 0.99
        eps = 1e-08
        optimizer = optim.RMSprop(thenet.parameters(),
                                  lr=init_learning_rate,
                                  alpha=alpha,
                                  eps=eps)
    else:
        # ----- hyper-parameters for the Adam optimizer --------
        init_learning_rate = 0.12
        beta1, beta2 = 0.9, 0.999
        eps = 1e-08
        optimizer = optim.Adam(thenet.parameters(),
                               lr=init_learning_rate,
                               betas=(beta1, beta2),
                               eps=eps)

    # instantiate the MultiplicativeLR scheduler; keep exactly one of the following two lambda blocks active
    # lambda_fn_0_8 = lambda epoch: 0.8 if (epoch >= 6 and epoch <= 24 and epoch % 2 == 0) else 1.0
    # scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda_fn_0_8)
    # strLambda = 'lambda_fn_0_8'
    lambda0_5 = lambda epoch: 0.5 if (epoch >= 5 and epoch <= 20 and epoch % 5
                                      == 0) else 1.0
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
                                                          lr_lambda=lambda0_5)
    strLambda = 'lambda_fn_0_5'
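    # note: MultiplicativeLR multiplies the current lr by lr_lambda(k) on every
    # scheduler.step() call; since scheduler.step() is invoked once per minibatch
    # below, the lambda's "epoch" argument actually counts weight updates, not epochs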

    if flag_write_log:
        logfilename = r'results/' + optim_algo + '_lr{}'.format(
            init_learning_rate) + '_Epoch{}'.format(epoch) + '_Schl_{}'.format(
                strLambda) + '_results.log'
        foutput = open(logfilename, "w")
        foutput.write(optim_algo + ' result\n')
        logstr = 'nsamples={}, batch_size={}, epoch={}, lr={}, lr_gamma={}\ninitial Wa={}, Wb={}, lr_milestone={}\n' \
            .format(nsamples, batch_size, epoch, init_learning_rate, strLambda, initWa, initWb,
                    r'lambda0_5 = lambda epoch: 0.5 if (epoch >= 5 and epoch <= 20 and epoch % 5 == 0) else 1.0')
        foutput.write(logstr)

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    updateCount = 0

    for t in range(epoch):

        for i_batch, sample_batched in enumerate(xy_dataloader):
            x, y = sample_batched['x'], sample_batched['y']

            thenet.train()

            # Step 1: Perform forward pass
            y_pred, negativeloc = thenet(x)

            # Step 2: Compute loss
            loss = loss_fn(y_pred[~negativeloc], y[~negativeloc])

            # Step 3: perform back-propagation and calculate the gradients of loss w.r.t. Wa and Wb
            optimizer.zero_grad()
            loss.backward()

            if flag_write_log:
                logstr = 'updates={}, Epoch={}, minibatch={}, loss={:.5f}, Wa={:.4f}, Wb={:.4f}'.format(
                    updateCount + 1, t, i_batch, loss.item(),
                    thenet.Wa.data.numpy(), thenet.Wb.data.numpy())
                logstr1 = ', unitvecWa={:.5f}, unitvecWb={:.5f}, learningrate={:.5f}, ' \
                         'dWa={:.5f}, dWb={:.5f}, VdWa={:.5f}, VdWb={:.5f}, sqrt_SdWa={:.5f}, sqrt_SdWb={:.5f}\n'.format(
                    0, 0, optimizer.param_groups[0]['lr'], thenet.Wa.grad.data.numpy(), thenet.Wb.grad.data.numpy(), 0, 0, 0, 0)
                foutput.write(logstr + logstr1)
                # if t % 10 == 0 and i_batch == 0:
                #     print(logstr)

            # Step 4: finally update weights Wa and Wb using the selected optimizer.
            optimizer.step()
            # Step 4.1: and update the learning rate hyper-parameter based on the scheduler.
            scheduler.step()

            updateCount += 1

    # log the final results
    if flag_write_log:
        logstr = 'The ground truth is A={:.4f}, B={:.4f}\n'.format(a, b)
        logstr += 'PyTorch built-in AutoGradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
            thenet.Wa, thenet.Wb)
        foutput.write(logstr)
        foutput.close()
        print(logstr)

    # plot the results obtained from the training
    if flag_plot_final_result:
        x = xy_dataset[:]['x']
        yfit = thenet.Wb * torch.sqrt(1.0 - (x + c)**2 / thenet.Wa**2)
        yfit[
            yfit !=
            yfit] = 0.0  # take care of the "Nan" at the end-points due to sqrt(negative_value_caused_by_noise)
        plt.plot(x, yfit.detach().numpy(), color="purple", linewidth=2.0)
        strEquation = r'$\frac{{{\left({x+' + '{:.3f}'.format(c) + r'}\right)}^2}}{{' + '{:.3f}'.\
            format(thenet.Wa) + r'^2}}+\frac{y^2}{' + '{:.3f}'.format(thenet.Wb) + r'^2}=1$'
        x0, y0 = x.detach().numpy()[nsamples * 2 //
                                    3], yfit.detach().numpy()[nsamples * 2 //
                                                              3]
        plt.annotate(strEquation,
                     xy=(x0, y0),
                     xycoords='data',
                     xytext=(+0.75, 1.75),
                     textcoords='data',
                     fontsize=16,
                     arrowprops=dict(arrowstyle="->",
                                     connectionstyle="arc3,rad=.2"))
        plt.text(1.0, 1.5, 'Result of ' + optim_algo, color='black', fontsize=12)
        plt.text(-1.9,
                 1.9,
                 'Epoch={}\n'.format(epoch) +
                 'Init Learning Rate={}\n'.format(init_learning_rate) +
                 'MultiplicativeLR w/ {}'.format(strLambda),
                 color='black',
                 fontsize=12,
                 ha='left',
                 va='top')

        figfilename = r'results/' + optim_algo + '_lr{}'.format(
            init_learning_rate) + '_Epoch{}'.format(epoch) + '_Schl_{}'.format(
                strLambda) + '.png'
        plt.savefig(figfilename)
        plt.show()

    print('Done!')
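
# ---------------------------------------------------------------------------
# NOTE (editorial): OrbitRegressionNet and EllipseDataset are defined elsewhere
# in this repository and are not shown in these snippets.  The sketch below is
# NOT the original code; it is a minimal reconstruction inferred from how the
# classes are called above (constructor arguments, the (y_pred, negativeloc)
# return value, dict-style dataset items).  Every detail here is an assumption.
import math

import torch
import torch.nn as nn
from torch.utils.data import Dataset


class OrbitRegressionNet(nn.Module):
    """Half-ellipse model y = Wb*sqrt(1 - (x + xoffset)^2 / Wa^2) with learnable Wa, Wb."""

    def __init__(self, xoffset, initWa, initWb):
        super().__init__()
        self.xoffset = xoffset
        # learnable estimates of the semi-axes
        self.Wa = nn.Parameter(initWa.clone().detach())
        self.Wb = nn.Parameter(initWb.clone().detach())

    def forward(self, x):
        # noise in the data can push the term under the square root below zero,
        # so clamp it and report the affected locations to the caller
        y_pred_sqr = 1.0 - (x + self.xoffset)**2 / self.Wa**2
        negativeloc = y_pred_sqr < 0
        y_pred = torch.sqrt(torch.clamp(y_pred_sqr, min=0.0)) * self.Wb
        return y_pred, negativeloc


class EllipseDataset(Dataset):
    """Noisy samples of the upper half of an ellipse with semi-axes a, b, shifted by -c in x."""

    def __init__(self, nsamples, a, b, noise_scale=0.1):
        c = math.sqrt(a * a - b * b)  # focal distance; puts one focus at the origin
        theta = torch.linspace(0.0, math.pi, nsamples)
        self.x = a * torch.cos(theta) - c
        self.y = b * torch.sin(theta) + noise_scale * torch.randn(nsamples)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return {'x': self.x[idx], 'y': self.y[idx]}
# ---------------------------------------------------------------------------
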
def main():
    flag_manual_implement = True  # True: using our custom implementation; False: using Torch built-in
    flag_plot_final_result = True
    flag_log = True

    device = torch.device('cpu')
    torch.manual_seed(9999)
    # a and b are used to generate the training dataset
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 64
    epoch = 160
    learning_rate = 0.01  # 0.03
    lr_milestones = [16]
    lr_gamma = 1.0  # 1.0 means no LR scheduler

    beta1, beta2 = 0.9, 0.999
    eps = 1e-08

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    # Wa = torch.rand([], device=device, requires_grad=True)
    # Wb = torch.rand([], device=device, requires_grad=True)
    Wa = torch.tensor(0.10, device=device, requires_grad=True)
    Wb = torch.tensor(1.8, device=device, requires_grad=True)

    if flag_log:
        if flag_manual_implement:
            logfilename = 'results/Adam_custom_implement_LR{}'.format(
                learning_rate) + '_results.log'
            foutput = open(logfilename, "w")
            foutput.write('Adam optimization using custom implementation' +
                          '\n')
        else:
            logfilename = 'results/Adam_torch_builtin_LR{}'.format(
                learning_rate) + '_results.log'
            foutput = open(logfilename, "w")
            foutput.write('Adam optimization using torch built-in' + '\n')

        logstr = 'nsamples={}, batch_size={}, epoch={}, lr={}\ninitial Wa={}, Wb={}\n' \
            .format(nsamples, batch_size, epoch, learning_rate, Wa, Wb)
        foutput.write(logstr)
        print(logstr)

    # Adam state: VdW* are the running 1st-moment (momentum) estimates, SdW* the
    # running 2nd-moment estimates; beta*_to_pow_t accumulate beta1^t and beta2^t
    # for the bias-correction terms
    VdWa = 0.0
    VdWb = 0.0
    SdWa = 0.0
    SdWb = 0.0
    beta1_to_pow_t = 1.0
    beta2_to_pow_t = 1.0

    if flag_manual_implement:
        optimizer = None
    else:
        optimizer = optim.Adam([Wa, Wb],
                               lr=learning_rate,
                               betas=(beta1, beta2),
                               eps=eps)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_milestones, gamma=lr_gamma)

    updates = 0
    for t in range(epoch):

        for i_batch, sample_batched in enumerate(xy_dataloader):
            x, y = sample_batched['x'], sample_batched['y']

            # Step 1: Perform forward pass
            y_pred_sqr = 1.0 - (x + c)**2 / Wa**2
            negativeloc = y_pred_sqr < 0  # record where noise pushed y_pred_sqr negative, to be excluded later
            y_pred_sqr[
                negativeloc] = 0  # clamp negative values caused by noise
            y_pred = torch.sqrt(y_pred_sqr) * Wb

            # Step 2: Compute loss
            loss = ((y_pred - y).pow(2))[~negativeloc].mean()
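            # the clamped (negativeloc) samples are excluded, so the loss (and hence the
            # manual gradients below) only sees points where the square root was well-defined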

            if flag_log:

                logstr = 'updates={}, Epoch={}, minibatch={}, loss={:.5f}, Wa={:.4f}, Wb={:.4f}'.format(
                    updates + 1, t, i_batch, loss.item(), Wa.data.numpy(),
                    Wb.data.numpy())
                foutput.write(logstr)
                if t % 10 == 0 and i_batch == 0:
                    print(logstr)

            if flag_manual_implement:  # do the job manually
                if updates in lr_milestones:
                    learning_rate *= lr_gamma

                # Step 3: perform back-propagation and calculate the gradients of loss w.r.t. Wa and Wb
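                # Chain rule, with y_pred = Wb*sqrt(1 - (x+c)^2/Wa^2) and per-sample
                # loss l_i = (y_pred - y)^2:
                #   dl_i/dWa = 2*(y_pred - y) * Wb^2*(x+c)^2 / (Wa^3 * y_pred)
                #   dl_i/dWb = 2*(y_pred - y) * y_pred / Wb
                # (averaged over the samples kept by ~negativeloc below)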
                dWa_via_yi = (2.0 * (y_pred - y) * ((x + c)**2) * (Wb**2) /
                              (Wa**3) / y_pred)
                dWa = dWa_via_yi[~negativeloc].mean()  # sum()

                dWb_via_yi = ((2.0 * (y_pred - y) * y_pred / Wb))
                dWb = dWb_via_yi[~negativeloc].mean()  #.sum()

                # Step 4: Update weights using Adam algorithm.
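                # Standard Adam update (Kingma & Ba, 2015), in this script's names:
                #   VdW = beta1*VdW + (1-beta1)*dW          (1st moment estimate)
                #   SdW = beta2*SdW + (1-beta2)*dW^2        (2nd moment estimate)
                #   W  -= lr * (VdW/(1-beta1^t)) / (sqrt(SdW/(1-beta2^t)) + eps)
                # note sqrt(SdW)/sqrt(1-beta2^t) used below equals sqrt(SdW/(1-beta2^t))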
                with torch.no_grad():
                    beta1_to_pow_t *= beta1
                    beta1_correction = 1.0 - beta1_to_pow_t
                    beta2_to_pow_t *= beta2
                    beta2_correction = 1.0 - beta2_to_pow_t

                    VdWa = beta1 * VdWa + (1.0 - beta1) * dWa
                    SdWa = beta2 * SdWa + (1.0 - beta2) * dWa * dWa
                    Wa -= learning_rate * (VdWa / beta1_correction) / (
                        torch.sqrt(SdWa) / math.sqrt(beta2_correction) + eps)

                    VdWb = beta1 * VdWb + (1.0 - beta1) * dWb
                    SdWb = beta2 * SdWb + (1.0 - beta2) * dWb * dWb
                    Wb -= learning_rate * (VdWb / beta1_correction) / (
                        torch.sqrt(SdWb) / math.sqrt(beta2_correction) + eps)

                    if flag_log:
                        tmp_a = (VdWa / beta1_correction) / (
                            torch.sqrt(SdWa) / math.sqrt(beta2_correction) +
                            eps)
                        tmp_b = (VdWb / beta1_correction) / (
                            torch.sqrt(SdWb) / math.sqrt(beta2_correction) +
                            eps)
                        mag_dWadWb = math.sqrt(tmp_a**2 + tmp_b**2)
                        unitvec_a = tmp_a / mag_dWadWb
                        unitvec_b = tmp_b / mag_dWadWb
                        actual_stepsize = learning_rate * mag_dWadWb
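                        # tmp_a/tmp_b are the bias-corrected Adam step components;
                        # unitvec_* give the step direction in (Wa, Wb) space and
                        # actual_stepsize is lr * |step|, logged below as eff_stepsize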

                        logstr = ', unitvecWa={:.5f}, unitvecWb={:.5f}, eff_stepsize={:.5f}, ' \
                                 'dWa={:.5f}, dWb={:.5f}, VdWa={:.5f}, VdWb={:.5f}, sqrt_SdWa={:.5f}, sqrt_SdWb={:.5f}\n'.format(
                                    unitvec_a, unitvec_b, actual_stepsize, dWa, dWb, VdWa, VdWb, math.sqrt(SdWa), math.sqrt(SdWb))
                        foutput.write(logstr)

            else:  # do the same job using Torch built-in autograd and optim
                optimizer.zero_grad()
                # Step 3: perform back-propagation and calculate the gradients of loss w.r.t. Wa and Wb
                loss.backward()
                # Step 4: Update weights using Adam algorithm.
                optimizer.step()
                scheduler.step()
                if flag_log:
                    foutput.write('\n')  # terminate the log line (the manual branch does this via logstr1)

            updates += 1

    # log the final results
    if flag_log:
        logstr = 'The ground truth is A={:.4f}, B={:.4f}\n'.format(a, b)
        if flag_manual_implement:
            logstr += 'Manually implemented gradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
                Wa, Wb)
        else:
            logstr += 'PyTorch built-in AutoGradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
                Wa, Wb)
        foutput.write(logstr)
        foutput.close()
        print(logstr)

    # plot the results obtained from the training
    if flag_plot_final_result:
        x = xy_dataset[:]['x']
        yfit = Wb * torch.sqrt(1.0 - (x + c)**2 / Wa**2)
        yfit[
            yfit !=
            yfit] = 0.0  # take care of the "Nan" at the end-points due to sqrt(negative_value_caused_by_noise)
        plt.plot(x, yfit.detach().numpy(), color="purple", linewidth=2.0)
        strEquation = r'$\frac{{{\left({x+' + '{:.3f}'.format(
            c) + r'}\right)}^2}}{{' + '{:.3f}'.format(
                Wa) + r'^2}}+\frac{y^2}{' + '{:.3f}'.format(Wb) + r'^2}=1$'
        x0, y0 = x.detach().numpy()[nsamples * 2 //
                                    3], yfit.detach().numpy()[nsamples * 2 //
                                                              3]
        plt.annotate(strEquation,
                     xy=(x0, y0),
                     xycoords='data',
                     xytext=(+0.75, 1.75),
                     textcoords='data',
                     fontsize=16,
                     arrowprops=dict(arrowstyle="->",
                                     connectionstyle="arc3,rad=.2"))
        plt.text(1.0, 1.5, 'Result of Adam', color='black', fontsize=12)
        plt.show()

    print('Done!')
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
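    # read_optimizer_results() is defined elsewhere in the repository; it parses the
    # per-update "updates=..., Epoch=..., minibatch=..., loss=..., Wa=..., Wb=..." log
    # lines written by the training scripts above and returns the columns as arrays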
    WaTraces_SGD_MOMENTUM1, WbTraces_SGD_MOMENTUM1, LossTraces_SGD_MOMENTUM1, EpochTraces1, minibatchTraces1, updateTraces1, lr_SGD_MOMENTUM1, method_SGD_MOMENTUM1, \
        _, _, _, dWa_SGD_MOMENTUM1, dWb_SGD_MOMENTUM1, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.012_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM2, WbTraces_SGD_MOMENTUM2, LossTraces_SGD_MOMENTUM2, EpochTraces2, minibatchTraces2, updateTraces2,  lr_SGD_MOMENTUM2, method_SGD_MOMENTUM2, \
        _, _, _, dWa_SGD_MOMENTUM2, dWb_SGD_MOMENTUM2, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.01_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM3, WbTraces_SGD_MOMENTUM3, LossTraces_SGD_MOMENTUM3, EpochTraces3, minibatchTraces3, updateTraces3,  lr_SGD_MOMENTUM3, method_SGD_MOMENTUM3, \
        _, _, _, dWa_SGD_MOMENTUM3, dWb_SGD_MOMENTUM3, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.005_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM4, WbTraces_SGD_MOMENTUM4, LossTraces_SGD_MOMENTUM4, EpochTraces4, minibatchTraces4, updateTraces4,  lr_SGD_MOMENTUM4, method_SGD_MOMENTUM4, \
        _, _, _, dWa_SGD_MOMENTUM4, dWb_SGD_MOMENTUM4, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.001_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM5, WbTraces_SGD_MOMENTUM5, LossTraces_SGD_MOMENTUM5, EpochTraces5, minibatchTraces5, updateTraces5,  lr_SGD_MOMENTUM5, method_SGD_MOMENTUM5, \
        _, _, _, dWa_SGD_MOMENTUM5, dWb_SGD_MOMENTUM5, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.00015_Epoch100_results.log')
    # WaTraces_SGD_MOMENTUM6, WbTraces_SGD_MOMENTUM6, LossTraces_SGD_MOMENTUM6, _, _, _, lr_SGD_MOMENTUM6, method_SGD_MOMENTUM6, \
    #     _, _, _, dWa_SGD_MOMENTUM6, dWb_SGD_MOMENTUM6, _, _, _, _ = \
    #     read_optimizer_results(r'SGD_Momentum_lr0.00015_Epoch500_results.log')

    nframes = max([
        WaTraces_SGD_MOMENTUM1.size, WaTraces_SGD_MOMENTUM2.size,
        WaTraces_SGD_MOMENTUM3.size, WaTraces_SGD_MOMENTUM4.size,
        WaTraces_SGD_MOMENTUM5.size
    ])
    epoch1, epoch2, epoch3, epoch4, epoch5 = EpochTraces1[
        -1] + 1, EpochTraces2[-1] + 1, EpochTraces3[-1] + 1, EpochTraces4[
            -1] + 1, EpochTraces5[-1] + 1

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    Wa0 = WaTraces_SGD_MOMENTUM1[0]
    Wb0 = WbTraces_SGD_MOMENTUM1[0]

    # nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.0, nWaGrids)
    # WbGrid = np.linspace(0, 2.0, nWbGrids)
    nWaGrids, nWbGrids = 250, 250
    # WaGrid = np.linspace(0, 2.5, nWaGrids)
    # WbGrid = np.linspace(0, 2.5, nWbGrids)
    WaGrid = np.linspace(0, 4.0, nWaGrids)
    WbGrid = np.linspace(-0.5, 3.5, nWbGrids)

    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)
    loss = np.zeros(Wa2d.shape)
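    # batch_size == nsamples, so the loader loop below runs once and simply exposes the
    # full dataset as x, y; the nested loops then evaluate the summed loss on the
    # (Wa, Wb) grid to draw the contour background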

    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']

    for indexb, Wb in enumerate(WbGrid):
        for indexa, Wa in enumerate(WaGrid):
            y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
            y_pred_sqr[
                y_pred_sqr <
                0.00000001] = 0.00000001  # handle negative values caused by noise
            y_pred = torch.sqrt(y_pred_sqr)

            loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)
    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 4.0), ylim=(-0.5, 3.5))
    ax.set_title('SGD_w_Momentum Training Progress w/ Diff LR', fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=4)
    ax.plot(Wa0, Wb0, 'ko', ms=5)
    ax.text(0.08, 1.85, 'Start', color="white", fontsize=14)
    ax.text(0.88, 1.15, 'Target', color="white", fontsize=14)

    walist_SGD_MOMENTUM1 = []
    wblist_SGD_MOMENTUM1 = []
    point_SGD_MOMENTUM1, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_SGD_MOMENTUM1, = ax.plot(
        [], [],
        '-r',
        lw=2,
        label='SGD_Momentum epoch={} '.format(epoch1) + lr_SGD_MOMENTUM1)

    walist_SGD_MOMENTUM2 = []
    wblist_SGD_MOMENTUM2 = []
    point_SGD_MOMENTUM2, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_SGD_MOMENTUM2, = ax.plot(
        [], [],
        '-y',
        lw=2,
        label='SGD_Momentum epoch={} '.format(epoch2) + lr_SGD_MOMENTUM2)

    walist_SGD_MOMENTUM3 = []
    wblist_SGD_MOMENTUM3 = []
    point_SGD_MOMENTUM3, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_SGD_MOMENTUM3, = ax.plot(
        [], [],
        '-m',
        lw=2,
        label='SGD_Momentum epoch={} '.format(epoch3) + lr_SGD_MOMENTUM3)

    walist_SGD_MOMENTUM4 = []
    wblist_SGD_MOMENTUM4 = []
    point_SGD_MOMENTUM4, = ax.plot([], [], 'go', lw=0.5, markersize=4)
    line_SGD_MOMENTUM4, = ax.plot(
        [], [],
        '-g',
        lw=2,
        label='SGD_Momentum epoch={} '.format(epoch4) + lr_SGD_MOMENTUM4)

    walist_SGD_MOMENTUM5 = []
    wblist_SGD_MOMENTUM5 = []
    point_SGD_MOMENTUM5, = ax.plot([], [],
                                   'o',
                                   lw=0.5,
                                   markersize=4,
                                   color='aqua')
    line_SGD_MOMENTUM5, = ax.plot(
        [], [],
        '-',
        lw=2,
        color='aqua',
        label='SGD_Momentum epoch={} '.format(epoch5) + lr_SGD_MOMENTUM5)

    text_update = ax.text(0.03,
                          0.03,
                          '',
                          transform=ax.transAxes,
                          color="white",
                          fontsize=14)

    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_SGD_MOMENTUM1.set_data([], [])
        line_SGD_MOMENTUM1.set_data([], [])

        point_SGD_MOMENTUM2.set_data([], [])
        line_SGD_MOMENTUM2.set_data([], [])

        point_SGD_MOMENTUM3.set_data([], [])
        line_SGD_MOMENTUM3.set_data([], [])

        point_SGD_MOMENTUM4.set_data([], [])
        line_SGD_MOMENTUM4.set_data([], [])

        point_SGD_MOMENTUM5.set_data([], [])
        line_SGD_MOMENTUM5.set_data([], [])

        text_update.set_text('')

        return point_SGD_MOMENTUM1, line_SGD_MOMENTUM1, point_SGD_MOMENTUM2, line_SGD_MOMENTUM2, point_SGD_MOMENTUM3, line_SGD_MOMENTUM3, point_SGD_MOMENTUM4, line_SGD_MOMENTUM4, point_SGD_MOMENTUM5, line_SGD_MOMENTUM5, text_update

    # animation function.  This is called sequentially
    def animate(i):
        if i == 0:

            wblist_SGD_MOMENTUM1[:] = []
            walist_SGD_MOMENTUM1[:] = []

            wblist_SGD_MOMENTUM2[:] = []
            walist_SGD_MOMENTUM2[:] = []

            wblist_SGD_MOMENTUM3[:] = []
            walist_SGD_MOMENTUM3[:] = []

            wblist_SGD_MOMENTUM4[:] = []
            walist_SGD_MOMENTUM4[:] = []

            wblist_SGD_MOMENTUM5[:] = []
            walist_SGD_MOMENTUM5[:] = []

        wa_SGD_MOMENTUM1, wb_SGD_MOMENTUM1 = WaTraces_SGD_MOMENTUM1[
            i], WbTraces_SGD_MOMENTUM1[i]
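        # note: Wa values are appended to the wblist_* lists (and vice versa), but
        # set_data(wblist, walist) then receives them as (x=Wa, y=Wb), so the plotted
        # trajectories are still correct despite the swapped list names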
        wblist_SGD_MOMENTUM1.append(wa_SGD_MOMENTUM1)
        walist_SGD_MOMENTUM1.append(wb_SGD_MOMENTUM1)
        point_SGD_MOMENTUM1.set_data(wa_SGD_MOMENTUM1, wb_SGD_MOMENTUM1)
        line_SGD_MOMENTUM1.set_data(wblist_SGD_MOMENTUM1, walist_SGD_MOMENTUM1)

        wa_SGD_MOMENTUM2, wb_SGD_MOMENTUM2 = WaTraces_SGD_MOMENTUM2[
            i], WbTraces_SGD_MOMENTUM2[i]
        wblist_SGD_MOMENTUM2.append(wa_SGD_MOMENTUM2)
        walist_SGD_MOMENTUM2.append(wb_SGD_MOMENTUM2)
        point_SGD_MOMENTUM2.set_data(wa_SGD_MOMENTUM2, wb_SGD_MOMENTUM2)
        line_SGD_MOMENTUM2.set_data(wblist_SGD_MOMENTUM2, walist_SGD_MOMENTUM2)

        wa_SGD_MOMENTUM3, wb_SGD_MOMENTUM3 = WaTraces_SGD_MOMENTUM3[
            i], WbTraces_SGD_MOMENTUM3[i]
        wblist_SGD_MOMENTUM3.append(wa_SGD_MOMENTUM3)
        walist_SGD_MOMENTUM3.append(wb_SGD_MOMENTUM3)
        point_SGD_MOMENTUM3.set_data(wa_SGD_MOMENTUM3, wb_SGD_MOMENTUM3)
        line_SGD_MOMENTUM3.set_data(wblist_SGD_MOMENTUM3, walist_SGD_MOMENTUM3)

        wa_SGD_MOMENTUM4, wb_SGD_MOMENTUM4 = WaTraces_SGD_MOMENTUM4[
            i], WbTraces_SGD_MOMENTUM4[i]
        wblist_SGD_MOMENTUM4.append(wa_SGD_MOMENTUM4)
        walist_SGD_MOMENTUM4.append(wb_SGD_MOMENTUM4)
        point_SGD_MOMENTUM4.set_data(wa_SGD_MOMENTUM4, wb_SGD_MOMENTUM4)
        line_SGD_MOMENTUM4.set_data(wblist_SGD_MOMENTUM4, walist_SGD_MOMENTUM4)

        wa_SGD_MOMENTUM5, wb_SGD_MOMENTUM5 = WaTraces_SGD_MOMENTUM5[
            i], WbTraces_SGD_MOMENTUM5[i]
        wblist_SGD_MOMENTUM5.append(wa_SGD_MOMENTUM5)
        walist_SGD_MOMENTUM5.append(wb_SGD_MOMENTUM5)
        point_SGD_MOMENTUM5.set_data(wa_SGD_MOMENTUM5, wb_SGD_MOMENTUM5)
        line_SGD_MOMENTUM5.set_data(wblist_SGD_MOMENTUM5, walist_SGD_MOMENTUM5)

        update, epoch, minibatch = updateTraces1[
            i], EpochTraces1[i] + 1, minibatchTraces1[i] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))

        return point_SGD_MOMENTUM1, line_SGD_MOMENTUM1, point_SGD_MOMENTUM2, line_SGD_MOMENTUM2, point_SGD_MOMENTUM3, line_SGD_MOMENTUM3, point_SGD_MOMENTUM4, line_SGD_MOMENTUM4, point_SGD_MOMENTUM5, line_SGD_MOMENTUM5, text_update

    # call the animator.  blit=True means only re-draw the parts that have changed.
    intervalms = 10  # this means 10 ms per frame
    anim = animation.FuncAnimation(fig,
                                   animate,
                                   init_func=init,
                                   frames=nframes,
                                   interval=intervalms,
                                   blit=True)

    # save the animation as an mp4.  This requires ffmpeg or mencoder to be installed.
    anim.save(
        'results/Part5_Fig9_Animation_SGD_Momentum_100Epoch_VariousLR.mp4',
        fps=30,
        bitrate=1800)

    plt.show()

    print('Done!')
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
    WaTraces_Adam1, WbTraces_Adam1, LossTraces_Adam1, EpochTraces1, minibatchTraces1, updateTraces1, lr_Adam1, gamma_Adam1, method_Adam1, \
        lrTraces_Adam1, _, _, dWa_Adam1, dWb_Adam1, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.01_Epoch100_Schl1.00_results.log')
    WaTraces_Adam2, WbTraces_Adam2, LossTraces_Adam2, EpochTraces2, minibatchTraces2, updateTraces2,  lr_Adam2, gamma_Adam2, method_Adam2, \
        lrTraces_Adam2, _, _, dWa_Adam2, dWb_Adam2, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl1.00_results.log')
    WaTraces_Adam3, WbTraces_Adam3, LossTraces_Adam3, EpochTraces3, minibatchTraces3, updateTraces3,  lr_Adam3, gamma_Adam3, method_Adam3, \
        lrTraces_Adam3, _, _, dWa_Adam3, dWb_Adam3, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl_lambda_fn_0_8_results.log')
    WaTraces_Adam4, WbTraces_Adam4, LossTraces_Adam4, EpochTraces4, minibatchTraces4, updateTraces4,  lr_Adam4, gamma_Adam4, method_Adam4, \
        lrTraces_Adam4, _, _, dWa_Adam4, dWb_Adam4, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl_lambda_fn_0_5_results.log')

    nframes = max([
        WaTraces_Adam1.size, WaTraces_Adam2.size, WaTraces_Adam3.size,
        WaTraces_Adam4.size
    ])
    epoch1, epoch2, epoch3, epoch4 = EpochTraces1[-1] + 1, EpochTraces2[
        -1] + 1, EpochTraces3[-1] + 1, EpochTraces4[-1] + 1
    update1, update2, update3, update4 = updateTraces1[-1], updateTraces2[
        -1], updateTraces3[-1], updateTraces4[-1]

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    Wa0 = WaTraces_Adam1[0]
    Wb0 = WbTraces_Adam1[0]

    # nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.0, nWaGrids)
    # WbGrid = np.linspace(0, 2.0, nWbGrids)
    nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.5, nWaGrids)
    # WbGrid = np.linspace(0, 2.5, nWbGrids)
    WaGrid = np.linspace(0, 2.0, nWaGrids)
    WbGrid = np.linspace(0.25, 2.25, nWbGrids)

    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)
    loss = np.zeros(Wa2d.shape)

    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']

    for indexb, Wb in enumerate(WbGrid):
        for indexa, Wa in enumerate(WaGrid):
            y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
            y_pred_sqr[
                y_pred_sqr <
                0.00000001] = 0.00000001  # handle negative values caused by noise
            y_pred = torch.sqrt(y_pred_sqr)

            loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)
    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 2), ylim=(0.25, 2.25))
    ax.set_title('Adam Optimization with MultiplicativeLR Schedulers & Lambda',
                 fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=5)
    ax.plot(Wa0, Wb0, 'ko', ms=5)
    ax.text(0.08, 1.85, 'Start', color="white", fontsize=14)
    ax.text(0.88, 1.15, 'Target', color="white", fontsize=14)

    walist_Adam1 = []
    wblist_Adam1 = []
    point_Adam1, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_Adam1, = ax.plot([], [],
                          '-r',
                          lw=1,
                          label='Adam epoch={}, '.format(epoch1) + lr_Adam1 +
                          ' fixed')

    walist_Adam2 = []
    wblist_Adam2 = []
    point_Adam2, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_Adam2, = ax.plot([], [],
                          '-y',
                          lw=1,
                          label='Adam epoch={}, '.format(epoch2) + lr_Adam2 +
                          ' fixed')

    walist_Adam3 = []
    wblist_Adam3 = []
    point_Adam3, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_Adam3, = ax.plot([], [],
                          '-m',
                          lw=1,
                          label='Adam epoch={}, '.format(epoch3) + lr_Adam3 +
                          ' MultiplicativeLR w/ Lambda fn 0_8')

    walist_Adam4 = []
    wblist_Adam4 = []
    point_Adam4, = ax.plot([], [], 'o', lw=0.5, markersize=4, color='aqua')
    line_Adam4, = ax.plot([], [],
                          '-',
                          lw=1,
                          color='aqua',
                          label='Adam epoch={}, '.format(epoch4) + lr_Adam4 +
                          ' MultiplicativeLR w/ Lambda fn 0_5')

    text_update = ax.text(0.03,
                          0.03,
                          '',
                          transform=ax.transAxes,
                          color="blue",
                          fontsize=14)

    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_Adam1.set_data([], [])
        line_Adam1.set_data([], [])

        point_Adam2.set_data([], [])
        line_Adam2.set_data([], [])

        point_Adam3.set_data([], [])
        line_Adam3.set_data([], [])

        point_Adam4.set_data([], [])
        line_Adam4.set_data([], [])

        text_update.set_text('')

        return point_Adam1, line_Adam1, point_Adam2, line_Adam2, point_Adam3, line_Adam3, point_Adam4, line_Adam4, text_update

    # animation function.  This is called sequentially
    def animate(i):
        if i == 0:

            wblist_Adam1[:] = []
            walist_Adam1[:] = []

            wblist_Adam2[:] = []
            walist_Adam2[:] = []

            wblist_Adam3[:] = []
            walist_Adam3[:] = []

            wblist_Adam4[:] = []
            walist_Adam4[:] = []

        if i < update1:
            wa_Adam1, wb_Adam1 = WaTraces_Adam1[i], WbTraces_Adam1[i]
            wblist_Adam1.append(wa_Adam1)
            walist_Adam1.append(wb_Adam1)
            point_Adam1.set_data(wa_Adam1, wb_Adam1)
            line_Adam1.set_data(wblist_Adam1, walist_Adam1)

        if i < update2:
            wa_Adam2, wb_Adam2 = WaTraces_Adam2[i], WbTraces_Adam2[i]
            wblist_Adam2.append(wa_Adam2)
            walist_Adam2.append(wb_Adam2)
            point_Adam2.set_data(wa_Adam2, wb_Adam2)
            line_Adam2.set_data(wblist_Adam2, walist_Adam2)

        if i < update3:
            wa_Adam3, wb_Adam3 = WaTraces_Adam3[i], WbTraces_Adam3[i]
            wblist_Adam3.append(wa_Adam3)
            walist_Adam3.append(wb_Adam3)
            point_Adam3.set_data(wa_Adam3, wb_Adam3)
            line_Adam3.set_data(wblist_Adam3, walist_Adam3)

        if i < update4:
            wa_Adam4, wb_Adam4 = WaTraces_Adam4[i], WbTraces_Adam4[i]
            wblist_Adam4.append(wa_Adam4)
            walist_Adam4.append(wb_Adam4)
            point_Adam4.set_data(wa_Adam4, wb_Adam4)
            line_Adam4.set_data(wblist_Adam4, walist_Adam4)

        update, epoch, minibatch = updateTraces1[
            i], EpochTraces1[i] + 1, minibatchTraces1[i] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))

        return point_Adam1, line_Adam1, point_Adam2, line_Adam2, point_Adam3, line_Adam3, point_Adam4, line_Adam4, text_update

    # call the animator.  blit=True means only re-draw the parts that have changed.
    intervalms = 10  # this means 10 ms per frame
    anim = animation.FuncAnimation(fig,
                                   animate,
                                   init_func=init,
                                   frames=nframes,
                                   interval=intervalms,
                                   blit=True)

    # save the animation as an mp4.  This requires ffmpeg or mencoder to be installed.
    anim.save(
        r'results/Part5_Fig16_Animation_Adam_w_MultiplicativeLR_Scheduler.mp4',
        fps=30,
        bitrate=1800)

    plt.show()

    # extra: plot the learning-rate traces of the two MultiplicativeLR lambdas
    # to demonstrate the effect of the LR scheduler during training
    flag_generate_Fig15 = True
    if flag_generate_Fig15:
        fig = plt.figure(figsize=(12, 6))
        plt.plot(lrTraces_Adam3[:30], '-', label='lambda factor 0.8', lw=1)
        plt.plot(lrTraces_Adam4[:30], '-r', label='lambda factor 0.5', lw=1)
        plt.xlabel("Number of Updates during Training")
        plt.ylabel("Learning Rate")
        plt.legend()
        plt.savefig(r'results/Part5_Fig15_Demonstrate_effect_of_scheduler.png')
    # -------------------------------------------

    print('Done!')
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
    WaTraces_Adam, WbTraces_Adam, LossTraces_Adam, EpochTraces, minibatchTraces, updateTraces, lr_Adam, method_Adam, \
        stepsize_Adam, unitvecWa_Adam, unitvecWb_Adam, dWa_Adam, dWb_Adam, VdWa_Adam, VdWb_Adam, SdWa_Adam, SdWb_Adam = \
        read_optimizer_results(r'results/Adam_custom_implement_LR0.01_results.log')
    WaTraces_SGD, WbTraces_SGD, LossTraces_SGD, _, _, _, lr_SGD, method_SGD,\
        stepsize_SGD, unitvecWa_SGD, unitvecWb_SGD, dWa_SGD, dWb_SGD, _, _, _, _ = \
        read_optimizer_results(r'results/SGD_custom_implement_LR0.01_results.log')
    WaTraces_SGD_MOMENTUM, WbTraces_SGD_MOMENTUM, LossTraces_SGD_MOMENTUM, _, _, _, lr_SGD_MOMENTUM, method_SGD_MOMENTUM, \
        stepsize_SGD_MOMENTUM, unitvecWa_SGD_MOMENTUM, unitvecWb_SGD_MOMENTUM, dWa_SGD_MOMENTUM, dWb_SGD_MOMENTUM, VdWa_SGD_MOMENTUM, VdWb_SGD_MOMENTUM, _, _ = \
        read_optimizer_results(r'results/SGD_w_Momentum_custom_implement_LR0.01_results.log')
    WaTraces_RMSprop, WbTraces_RMSprop, LossTraces_RMSprop, _, _, _, lr_RMSprop, method_RMSprop, \
        stepsize_RMSprop, unitvecWa_RMSprop, unitvecWb_RMSprop, dWa_RMSprop, dWb_RMSprop, VdWa_RMSprop, VdWb_RMSprop, SdWa_RMSprop, SdWb_RMSprop = \
        read_optimizer_results(r'results/RMSprop_custom_implement_LR0.01_results.log')

    # First plot: stepsize vs update ------------
    if True:
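        # stepsize_* are the eff_stepsize values logged by the custom implementations
        # (see the Adam script above): lr times the magnitude of each (Wa, Wb) update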
        fig = plt.figure(figsize=(12, 6))
        ax = fig.add_subplot(111)
        line_SGD, = ax.plot(updateTraces,
                            stepsize_SGD,
                            '-y',
                            lw=1,
                            label=method_SGD + ' ' + lr_SGD)
        line_SGD_MOMENTUM, = ax.plot(updateTraces,
                                     stepsize_SGD_MOMENTUM,
                                     '-r',
                                     lw=1,
                                     label=method_SGD_MOMENTUM + ' ' +
                                     lr_SGD_MOMENTUM)
        line_RMSprop, = ax.plot(updateTraces,
                                stepsize_RMSprop,
                                '-b',
                                lw=0.5,
                                label=method_RMSprop + ' ' + lr_RMSprop)
        line_Adam, = ax.plot(updateTraces,
                             stepsize_Adam,
                             '-g',
                             lw=1,
                             label=method_Adam + ' ' + lr_Adam)
        leg = ax.legend()
        ax.set(xlim=(0, 600),
               ylim=(0, 0.01 +
                     max(stepsize_SGD.max(), stepsize_SGD_MOMENTUM.max(),
                         stepsize_RMSprop.max(), stepsize_Adam.max())))
        ax.set_title('Effective Step size as a function of Updates',
                     fontsize=16)
        plt.xlabel("# of Updates")
        plt.ylabel("Effective Step Size (AU)")
        fig.tight_layout()
        plt.savefig(r'results/Part5_Fig1_EffStepSize_vs_Updates.png')
        plt.show(block=False)
    # -------------------------------------------

    nframes = len(WaTraces_Adam)
    # nframes = 600  # len(WaTraces_Adam)

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    Wa0 = WaTraces_Adam[0]
    Wb0 = WbTraces_Adam[0]

    # nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.0, nWaGrids)
    # WbGrid = np.linspace(0, 2.0, nWbGrids)
    nWaGrids, nWbGrids = 250, 250
    WaGrid = np.linspace(0, 2.5, nWaGrids)
    WbGrid = np.linspace(0, 2.5, nWbGrids)
    # WaGrid = np.linspace(0, 0.5, nWaGrids)
    # WbGrid = np.linspace(1.5, 2.0, nWbGrids)

    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)
    loss = np.zeros(Wa2d.shape)

    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']

    for indexb, Wb in enumerate(WbGrid):
        for indexa, Wa in enumerate(WaGrid):
            y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
            y_pred_sqr[
                y_pred_sqr <
                0.00000001] = 0.00000001  # handle negative values caused by noise
            y_pred = torch.sqrt(y_pred_sqr)

            loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    # Second plot: plot Steps during the 1st few updates ------------
    if True:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)
        ax.contour(WaGrid,
                   WbGrid,
                   loss,
                   levels=15,
                   linewidths=0.5,
                   colors='gray')
        cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
        fig.colorbar(cntr1, ax=ax, shrink=0.75)
        # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
        # ax.set(xlim=(0, 2.5), ylim=(0, 2.5))
        ax.set(xlim=(0, 0.5), ylim=(1.5, 2.0))
        ax.set_title('First 4 updates of Wa and Wb in training', fontsize=16)
        plt.xlabel("Wa")
        plt.ylabel("Wb")
        ax.set_aspect('equal', adjustable='box')
        ax.plot(a, b, 'yo', ms=3)
        ax.plot(Wa0, Wb0, 'ko', ms=3)
        ax.text(0.08, 1.81, 'Start', color="black", fontsize=14)

        for i in range(4):
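            # one arrow per update for each optimizer; the 0.995 scaling presumably
            # keeps the arrowhead just short of the next trace point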
            dWa_SGD = (WaTraces_SGD[i + 1] - WaTraces_SGD[i]) * 0.995
            dWb_SGD = (WbTraces_SGD[i + 1] - WbTraces_SGD[i]) * 0.995
            plt.arrow(WaTraces_SGD[i],
                      WbTraces_SGD[i],
                      dWa_SGD,
                      dWb_SGD,
                      color='yellow',
                      linewidth=1,
                      head_width=0.005,
                      length_includes_head=True)

            dWa_SGD_M = (WaTraces_SGD_MOMENTUM[i + 1] -
                         WaTraces_SGD_MOMENTUM[i]) * 0.995
            dWb_SGD_M = (WbTraces_SGD_MOMENTUM[i + 1] -
                         WbTraces_SGD_MOMENTUM[i]) * 0.995
            plt.arrow(WaTraces_SGD_MOMENTUM[i],
                      WbTraces_SGD_MOMENTUM[i],
                      dWa_SGD_M,
                      dWb_SGD_M,
                      color='red',
                      linewidth=1,
                      head_width=0.005,
                      length_includes_head=True)

            dWa_RMSprop = (WaTraces_RMSprop[i + 1] -
                           WaTraces_RMSprop[i]) * 0.995
            dWb_RMSprop = (WbTraces_RMSprop[i + 1] -
                           WbTraces_RMSprop[i]) * 0.995
            plt.arrow(WaTraces_RMSprop[i],
                      WbTraces_RMSprop[i],
                      dWa_RMSprop,
                      dWb_RMSprop,
                      color='blue',
                      linewidth=1,
                      head_width=0.005,
                      length_includes_head=True)

            dWa_Adam = (WaTraces_Adam[i + 1] - WaTraces_Adam[i]) * 0.995
            dWb_Adam = (WbTraces_Adam[i + 1] - WbTraces_Adam[i]) * 0.995
            plt.arrow(WaTraces_Adam[i],
                      WbTraces_Adam[i],
                      dWa_Adam,
                      dWb_Adam,
                      color='green',
                      linewidth=1,
                      head_width=0.005,
                      length_includes_head=True)

        line_SGD, = ax.plot([], [],
                            '-y',
                            lw=1,
                            label=method_SGD + ' ' + lr_SGD)
        line_SGD_MOMENTUM, = ax.plot([], [],
                                     '-r',
                                     lw=1,
                                     label=method_SGD_MOMENTUM + ' ' +
                                     lr_SGD_MOMENTUM)
        line_RMSprop, = ax.plot([], [],
                                '-b',
                                lw=0.5,
                                label=method_RMSprop + ' ' + lr_RMSprop)
        line_Adam, = ax.plot([], [],
                             '-g',
                             lw=1,
                             label=method_Adam + ' ' + lr_Adam)

        leg = ax.legend()
        fig.tight_layout()
        plt.savefig('results/Part5_Fig2_FirstFewSteps.png')
        plt.show(block=False)
    # -------------------------------------------

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)
    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 2.5), ylim=(0, 2.5))
    ax.set_title('Different Optimization Paths Using The Same Learning Rate',
                 fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=3)
    ax.plot(Wa0, Wb0, 'ko', ms=3)

    wblist_SGD = []
    walist_SGD = []
    point_SGD, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_SGD, = ax.plot([], [], '-y', lw=2, label=method_SGD + ' ' + lr_SGD)

    wblist_SGD_MOMENTUM = []
    walist_SGD_MOMENTUM = []
    point_SGD_MOMENTUM, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_SGD_MOMENTUM, = ax.plot([], [],
                                 '-r',
                                 lw=2,
                                 label=method_SGD_MOMENTUM + ' ' +
                                 lr_SGD_MOMENTUM)

    wblist_RMSprop = []
    walist_RMSprop = []
    point_RMSprop, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_RMSprop, = ax.plot([], [],
                            '-m',
                            lw=2,
                            label=method_RMSprop + ' ' + lr_RMSprop)

    wblist_Adam = []
    walist_Adam = []
    point_Adam, = ax.plot([], [], 'go', lw=0.5, markersize=4)
    line_Adam, = ax.plot([], [], '-g', lw=2, label=method_Adam + ' ' + lr_Adam)

    text_update = ax.text(0.03,
                          0.03,
                          '',
                          transform=ax.transAxes,
                          color="black",
                          fontsize=14)

    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_SGD.set_data([], [])
        line_SGD.set_data([], [])

        point_SGD_MOMENTUM.set_data([], [])
        line_SGD_MOMENTUM.set_data([], [])

        point_RMSprop.set_data([], [])
        line_RMSprop.set_data([], [])

        point_Adam.set_data([], [])
        line_Adam.set_data([], [])

        text_update.set_text('')

        return point_SGD, line_SGD, point_SGD_MOMENTUM, line_SGD_MOMENTUM, point_RMSprop, line_RMSprop, point_Adam, line_Adam, text_update

    # animation function.  This is called sequentially
    def animate(i):
        if i == 0:

            wblist_SGD[:] = []
            walist_SGD[:] = []

            wblist_SGD_MOMENTUM[:] = []
            walist_SGD_MOMENTUM[:] = []

            wblist_RMSprop[:] = []
            walist_RMSprop[:] = []

            wblist_Adam[:] = []
            walist_Adam[:] = []

        wa_SGD, wb_SGD = WaTraces_SGD[i], WbTraces_SGD[i]
        wblist_SGD.append(wa_SGD)
        walist_SGD.append(wb_SGD)
        point_SGD.set_data(wa_SGD, wb_SGD)
        line_SGD.set_data(wblist_SGD, walist_SGD)

        wa_SGD_MOMENTUM, wb_SGD_MOMENTUM = WaTraces_SGD_MOMENTUM[
            i], WbTraces_SGD_MOMENTUM[i]
        wblist_SGD_MOMENTUM.append(wa_SGD_MOMENTUM)
        walist_SGD_MOMENTUM.append(wb_SGD_MOMENTUM)
        point_SGD_MOMENTUM.set_data(wa_SGD_MOMENTUM, wb_SGD_MOMENTUM)
        line_SGD_MOMENTUM.set_data(wblist_SGD_MOMENTUM, walist_SGD_MOMENTUM)

        wa_RMSprop, wb_RMSprop = WaTraces_RMSprop[i], WbTraces_RMSprop[i]
        wblist_RMSprop.append(wa_RMSprop)
        walist_RMSprop.append(wb_RMSprop)
        point_RMSprop.set_data(wa_RMSprop, wb_RMSprop)
        line_RMSprop.set_data(wblist_RMSprop, walist_RMSprop)

        wa_Adam, wb_Adam = WaTraces_Adam[i], WbTraces_Adam[i]
        wblist_Adam.append(wa_Adam)
        walist_Adam.append(wb_Adam)
        point_Adam.set_data(wa_Adam, wb_Adam)
        line_Adam.set_data(wblist_Adam, walist_Adam)

        update, epoch, minibatch = updateTraces[
            i], EpochTraces[i] + 1, minibatchTraces[i] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))

        return point_SGD, line_SGD, point_SGD_MOMENTUM, line_SGD_MOMENTUM, point_RMSprop, line_RMSprop, point_Adam, line_Adam, text_update

    # call the animator.  blit=True means only re-draw the parts that have changed.
    intervalms = 10  # this means 10 ms per frame
    anim = animation.FuncAnimation(fig,
                                   animate,
                                   init_func=init,
                                   frames=nframes,
                                   interval=intervalms,
                                   blit=True)

    # save the animation as an mp4.  This requires ffmpeg or mencoder to be installed.
    anim.save('results/Part5_Fig3_Animation_of_loss_vs_TrainingUpdates.mp4',
              fps=30,
              bitrate=1800)

    plt.show()

    print('Done!')