# Shared imports assumed by all of the scripts in this section (the original
# import headers are not shown here): PyTorch, its DataLoader, NumPy,
# Matplotlib, and the standard-library math module.
import math

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


def main():
    flag_plot_final_result = True
    flag_write_log = True
    device = torch.device('cpu')
    torch.manual_seed(9999)

    # a and b are used to generate the training dataset
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    xaxisoffset = c  # offset of the center of the ellipse from the origin in this case

    nsamples = 512
    batch_size = 64
    epoch = 100

    # initWa = torch.rand([], device=device, requires_grad=True)
    # initWb = torch.rand([], device=device, requires_grad=True)
    # to initialize Wa and Wb with specific starting values:
    initWa = torch.tensor(0.10, device=device, requires_grad=True)
    initWb = torch.tensor(1.8, device=device, requires_grad=True)
    thenet = OrbitRegressionNet(xaxisoffset, initWa, initWb)
    loss_fn = nn.MSELoss(reduction='mean')

    optim_algo = 'Adam'  # 'SGD', 'SGD_Momentum', 'RMSprop' or 'Adam', case sensitive
    if optim_algo == 'SGD':
        # ----- hyper-parameters for the SGD optimizer -------------------
        init_learning_rate = 0.01
        optimizer = optim.SGD(thenet.parameters(), lr=init_learning_rate, momentum=0.0)
    elif optim_algo == 'SGD_Momentum':
        # ----- hyper-parameters for the SGD_with_Momentum optimizer -----
        init_learning_rate = 0.01
        momentum = 0.9
        dampening = 0.0
        optimizer = optim.SGD(thenet.parameters(), lr=init_learning_rate,
                              momentum=momentum, dampening=dampening)
    elif optim_algo == 'RMSprop':
        # ----- hyper-parameters for the RMSprop optimizer ---------------
        init_learning_rate = 0.02
        alpha = 0.99
        eps = 1e-08
        optimizer = optim.RMSprop(thenet.parameters(), lr=init_learning_rate,
                                  alpha=alpha, eps=eps)
    else:
        # ----- hyper-parameters for the Adam optimizer ------------------
        init_learning_rate = 0.12
        beta1, beta2 = 0.9, 0.999
        eps = 1e-08
        optimizer = optim.Adam(thenet.parameters(), lr=init_learning_rate,
                               betas=(beta1, beta2), eps=eps)

    # Instantiate the MultiplicativeLR scheduler; comment out one of the
    # following two blocks. Note that scheduler.step() is called once per
    # minibatch below, so the lambda's argument is the update count, not the
    # epoch number.
    # lambda_fn_0_8 = lambda step: 0.8 if (step >= 6 and step <= 24 and step % 2 == 0) else 1.0
    # scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda_fn_0_8)
    # strLambda = 'lambda_fn_0_8'
    lambda0_5 = lambda step: 0.5 if (step >= 5 and step <= 20 and step % 5 == 0) else 1.0
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda0_5)
    strLambda = 'lambda_fn_0_5'

    if flag_write_log:
        logfilename = r'results/' + optim_algo + '_lr{}'.format(
            init_learning_rate) + '_Epoch{}'.format(epoch) + '_Schl_{}'.format(
                strLambda) + '_results.log'
        foutput = open(logfilename, "w")
        foutput.write(optim_algo + ' result\n')
        # the 'lr_gamma' and 'lr_milestone' key names are kept for
        # compatibility with the log parser, although here they carry the
        # scheduler lambda's name and definition
        logstr = 'nsamples={}, batch_size={}, epoch={}, lr={}, lr_gamma={}\ninitial Wa={}, Wb={}, lr_milestone={}\n' \
            .format(nsamples, batch_size, epoch, init_learning_rate, strLambda, initWa, initWb,
                    r'lambda0_5 = lambda step: 0.5 if (step >= 5 and step <= 20 and step % 5 == 0) else 1.0')
        foutput.write(logstr)

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    updateCount = 0
    for t in range(epoch):
        for i_batch, sample_batched in enumerate(xy_dataloader):
            x, y = sample_batched['x'], sample_batched['y']
            thenet.train()

            # Step 1: perform the forward pass
            y_pred, negativeloc = thenet(x)

            # Step 2: compute the loss, excluding the points where noise
            # pushed the operand of sqrt() negative
            loss = loss_fn(y_pred[~negativeloc], y[~negativeloc])

            # Step 3: perform back-propagation and calculate the gradients of
            # loss w.r.t. Wa and Wb
            optimizer.zero_grad()
            loss.backward()

            if flag_write_log:
                logstr = 'updates={}, Epoch={}, minibatch={}, loss={:.5f}, Wa={:.4f}, Wb={:.4f}'.format(
                    updateCount + 1, t, i_batch, loss.item(), thenet.Wa.item(), thenet.Wb.item())
                logstr1 = ', unitvecWa={:.5f}, unitvecWb={:.5f}, learningrate={:.5f}, ' \
                          'dWa={:.5f}, dWb={:.5f}, VdWa={:.5f}, VdWb={:.5f}, sqrt_SdWa={:.5f}, sqrt_SdWb={:.5f}\n'.format(
                              0, 0, optimizer.param_groups[0]['lr'],
                              thenet.Wa.grad.item(), thenet.Wb.grad.item(), 0, 0, 0, 0)
                foutput.write(logstr + logstr1)
                # if t % 10 == 0 and i_batch == 0:
                #     print(logstr)

            # Step 4: finally update the weights Wa and Wb with the optimizer
            optimizer.step()
            # Step 4.1: and update the learning rate based on the scheduler
            # (called per minibatch, so the lambda sees the update count)
            scheduler.step()
            updateCount += 1

    # log the final results
    if flag_write_log:
        logstr = 'The ground truth is A={:.4f}, B={:.4f}\n'.format(a, b)
        logstr += 'PyTorch built-in AutoGradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
            thenet.Wa.item(), thenet.Wb.item())
        foutput.write(logstr)
        foutput.close()
        print(logstr)

    # plot the results obtained from the training
    if flag_plot_final_result:
        x = xy_dataset[:]['x']
        yfit = thenet.Wb * torch.sqrt(1.0 - (x + c)**2 / thenet.Wa**2)
        # take care of the NaNs at the end-points caused by
        # sqrt(negative_value_caused_by_noise)
        yfit[yfit != yfit] = 0.0
        plt.plot(x, yfit.detach().numpy(), color="purple", linewidth=2.0)
        strEquation = r'$\frac{{{\left({x+' + '{:.3f}'.format(c) + r'}\right)}^2}}{{' + \
            '{:.3f}'.format(thenet.Wa.item()) + r'^2}}+\frac{y^2}{' + \
            '{:.3f}'.format(thenet.Wb.item()) + r'^2}=1$'
        x0, y0 = x.detach().numpy()[nsamples * 2 // 3], yfit.detach().numpy()[nsamples * 2 // 3]
        plt.annotate(strEquation, xy=(x0, y0), xycoords='data',
                     xytext=(+0.75, 1.75), textcoords='data', fontsize=16,
                     arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
        plt.text(1.0, 1.5, 'Result of Adam', color='black', fontsize=12)
        plt.text(-1.9, 1.9,
                 'Epoch={}\n'.format(epoch) +
                 'Init Learning Rate={}\n'.format(init_learning_rate) +
                 'MultiplicativeLR w/ {}'.format(strLambda),
                 color='black', fontsize=12, ha='left', va='top')
        figfilename = r'results/' + optim_algo + '_lr{}'.format(
            init_learning_rate) + '_Epoch{}'.format(epoch) + '_Schl_{}'.format(
                strLambda) + '.png'
        plt.savefig(figfilename)
        plt.show()

    print('Done!')
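# ---------------------------------------------------------------------------
# The helpers EllipseDataset and OrbitRegressionNet used above are not defined
# in this section. The following is a minimal sketch of what they could look
# like, inferred only from how they are called (dict samples with 'x' and 'y',
# slicing support, a forward pass returning (y_pred, negativeloc)); the
# original implementations may differ in detail.
# ---------------------------------------------------------------------------
class EllipseDataset(torch.utils.data.Dataset):
    """Noisy samples from the upper half of an ellipse with one focus at the origin."""

    def __init__(self, nsamples, a, b, noise_scale=0.1):
        c = math.sqrt(a * a - b * b)
        # x spans the ellipse, shifted so the center sits at (-c, 0)
        self.x = torch.linspace(-a, a, nsamples) - c
        y = b * torch.sqrt(torch.clamp(1.0 - (self.x + c)**2 / a**2, min=0.0))
        self.y = y + noise_scale * torch.randn(nsamples)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # supports both integer indexing (DataLoader) and slicing (plot code)
        return {'x': self.x[idx], 'y': self.y[idx]}


class OrbitRegressionNet(nn.Module):
    """Two-parameter model y = Wb * sqrt(1 - (x + c)^2 / Wa^2)."""

    def __init__(self, xaxisoffset, initWa, initWb):
        super().__init__()
        self.c = xaxisoffset
        self.Wa = nn.Parameter(initWa.detach().clone())
        self.Wb = nn.Parameter(initWb.detach().clone())

    def forward(self, x):
        y_pred_sqr = 1.0 - (x + self.c)**2 / self.Wa**2
        negativeloc = y_pred_sqr < 0  # noise can push points outside the ellipse
        y_pred = torch.sqrt(y_pred_sqr.clamp(min=0.0)) * self.Wb
        return y_pred, negativeloc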
def main():
    flag_manual_implement = True  # True: use our custom implementation; False: use the Torch built-in
    flag_plot_final_result = True
    flag_log = True
    device = torch.device('cpu')
    torch.manual_seed(9999)

    # a and b are used to generate the training dataset
    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)

    nsamples = 512
    batch_size = 64
    epoch = 160
    learning_rate = 0.01  # 0.03
    lr_milestones = [16]
    lr_gamma = 1.0  # 1.0 means no LR scheduling
    beta1, beta2 = 0.9, 0.999
    eps = 1e-08

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Wa = torch.rand([], device=device, requires_grad=True)
    # Wb = torch.rand([], device=device, requires_grad=True)
    Wa = torch.tensor(0.10, device=device, requires_grad=True)
    Wb = torch.tensor(1.8, device=device, requires_grad=True)

    if flag_log:
        if flag_manual_implement:
            logfilename = 'results/Adam_custom_implement_LR{}'.format(learning_rate) + '_results.log'
            foutput = open(logfilename, "w")
            foutput.write('Adam optimization using custom implementation' + '\n')
        else:
            # note: the original used the same 'custom_implement' file name in
            # both branches; this branch is renamed to keep the two logs apart
            logfilename = 'results/Adam_torch_builtin_LR{}'.format(learning_rate) + '_results.log'
            foutput = open(logfilename, "w")
            foutput.write('Adam optimization using torch built-in' + '\n')
        logstr = 'nsamples={}, batch_size={}, epoch={}, lr={}\ninitial Wa={}, Wb={}\n' \
            .format(nsamples, batch_size, epoch, learning_rate, Wa, Wb)
        foutput.write(logstr)
        print(logstr)

    # Adam state: first- and second-moment estimates, plus the running powers
    # of beta1/beta2 used for bias correction
    VdWa = 0.0
    VdWb = 0.0
    SdWa = 0.0
    SdWb = 0.0
    beta1_to_pow_t = 1.0
    beta2_to_pow_t = 1.0

    if flag_manual_implement:
        optimizer = None
    else:
        optimizer = optim.Adam([Wa, Wb], lr=learning_rate, betas=(beta1, beta2), eps=eps)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=lr_milestones,
                                                         gamma=lr_gamma)

    updates = 0
    for t in range(epoch):
        for i_batch, sample_batched in enumerate(xy_dataloader):
            x, y = sample_batched['x'], sample_batched['y']

            # Step 1: perform the forward pass
            y_pred_sqr = 1.0 - (x + c)**2 / Wa**2
            negativeloc = y_pred_sqr < 0  # record the negative y_pred_sqr elements, to be used later
            y_pred_sqr[negativeloc] = 0   # handle negative values caused by noise
            y_pred = torch.sqrt(y_pred_sqr) * Wb

            # Step 2: compute the loss over the valid points only
            loss = ((y_pred - y).pow(2))[~negativeloc].mean()

            if flag_log:
                logstr = 'updates={}, Epoch={}, minibatch={}, loss={:.5f}, Wa={:.4f}, Wb={:.4f}'.format(
                    updates + 1, t, i_batch, loss.item(), Wa.item(), Wb.item())
                foutput.write(logstr)
                if t % 10 == 0 and i_batch == 0:
                    print(logstr)

            if flag_manual_implement:  # do the job manually
                if updates in lr_milestones:
                    learning_rate *= lr_gamma

                # Step 3: calculate the gradients of loss w.r.t. Wa and Wb
                # analytically (the masked elements produce inf/nan here, but
                # they are dropped before the mean)
                dWa_via_yi = 2.0 * (y_pred - y) * ((x + c)**2) * (Wb**2) / (Wa**3) / y_pred
                dWa = dWa_via_yi[~negativeloc].mean()
                dWb_via_yi = 2.0 * (y_pred - y) * y_pred / Wb
                dWb = dWb_via_yi[~negativeloc].mean()

                # Step 4: update the weights using the Adam algorithm
                with torch.no_grad():
                    beta1_to_pow_t *= beta1
                    beta1_correction = 1.0 - beta1_to_pow_t
                    beta2_to_pow_t *= beta2
                    beta2_correction = 1.0 - beta2_to_pow_t

                    VdWa = beta1 * VdWa + (1.0 - beta1) * dWa
                    SdWa = beta2 * SdWa + (1.0 - beta2) * dWa * dWa
                    Wa -= learning_rate * (VdWa / beta1_correction) / (
                        torch.sqrt(SdWa) / math.sqrt(beta2_correction) + eps)

                    VdWb = beta1 * VdWb + (1.0 - beta1) * dWb
                    SdWb = beta2 * SdWb + (1.0 - beta2) * dWb * dWb
                    Wb -= learning_rate * (VdWb / beta1_correction) / (
                        torch.sqrt(SdWb) / math.sqrt(beta2_correction) + eps)

                if flag_log:
                    tmp_a = (VdWa / beta1_correction) / (
                        torch.sqrt(SdWa) / math.sqrt(beta2_correction) + eps)
                    tmp_b = (VdWb / beta1_correction) / (
                        torch.sqrt(SdWb) / math.sqrt(beta2_correction) + eps)
                    mag_dWadWb = math.sqrt(tmp_a**2 + tmp_b**2)
                    unitvec_a = tmp_a / mag_dWadWb
                    unitvec_b = tmp_b / mag_dWadWb
                    actual_stepsize = learning_rate * mag_dWadWb
                    logstr = ', unitvecWa={:.5f}, unitvecWb={:.5f}, eff_stepsize={:.5f}, ' \
                             'dWa={:.5f}, dWb={:.5f}, VdWa={:.5f}, VdWb={:.5f}, sqrt_SdWa={:.5f}, sqrt_SdWb={:.5f}\n'.format(
                                 unitvec_a, unitvec_b, actual_stepsize, dWa, dWb, VdWa, VdWb,
                                 math.sqrt(SdWa), math.sqrt(SdWb))
                    foutput.write(logstr)
            else:  # do the same job using the Torch built-in autograd and optim
                optimizer.zero_grad()
                # Step 3: perform back-propagation to get the gradients of loss w.r.t. Wa and Wb
                loss.backward()
                # Step 4: update the weights using the Adam algorithm
                optimizer.step()
                scheduler.step()
                if flag_log:
                    foutput.write('\n')  # terminate this update's log line

            updates += 1

    # log the final results
    if flag_log:
        logstr = 'The ground truth is A={:.4f}, B={:.4f}\n'.format(a, b)
        if flag_manual_implement:
            logstr += 'Manually implemented gradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
                Wa.item(), Wb.item())
        else:
            logstr += 'PyTorch built-in AutoGradient+optimizer result: Final estimated Wa={:.4f}, Wb={:.4f}\n'.format(
                Wa.item(), Wb.item())
        foutput.write(logstr)
        foutput.close()
        print(logstr)

    # plot the results obtained from the training
    if flag_plot_final_result:
        x = xy_dataset[:]['x']
        yfit = Wb * torch.sqrt(1.0 - (x + c)**2 / Wa**2)
        # take care of the NaNs at the end-points caused by
        # sqrt(negative_value_caused_by_noise)
        yfit[yfit != yfit] = 0.0
        plt.plot(x, yfit.detach().numpy(), color="purple", linewidth=2.0)
        strEquation = r'$\frac{{{\left({x+' + '{:.3f}'.format(c) + r'}\right)}^2}}{{' + \
            '{:.3f}'.format(Wa.item()) + r'^2}}+\frac{y^2}{' + \
            '{:.3f}'.format(Wb.item()) + r'^2}=1$'
        x0, y0 = x.detach().numpy()[nsamples * 2 // 3], yfit.detach().numpy()[nsamples * 2 // 3]
        plt.annotate(strEquation, xy=(x0, y0), xycoords='data',
                     xytext=(+0.75, 1.75), textcoords='data', fontsize=16,
                     arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
        plt.text(1.0, 1.5, 'Result of Adam', color='black', fontsize=12)
        plt.show()

    print('Done!')
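# ---------------------------------------------------------------------------
# A quick sanity check (not part of the original scripts) that one hand-rolled
# Adam update, using the same bias corrections as the training loop above,
# matches torch.optim.Adam for a single scalar weight at step t = 1. It relies
# on the shared imports at the top of this section.
# ---------------------------------------------------------------------------
def check_adam_first_step(lr=0.1, beta1=0.9, beta2=0.999, eps=1e-08):
    w = torch.tensor(0.5, requires_grad=True)
    opt = optim.Adam([w], lr=lr, betas=(beta1, beta2), eps=eps)

    loss = (w - 2.0)**2
    loss.backward()
    g = w.grad.item()

    # manual Adam update at t = 1
    V = (1.0 - beta1) * g             # first-moment estimate
    S = (1.0 - beta2) * g * g         # second-moment estimate
    V_hat = V / (1.0 - beta1**1)      # bias-corrected first moment
    S_hat = S / (1.0 - beta2**1)      # bias-corrected second moment
    w_manual = w.item() - lr * V_hat / (math.sqrt(S_hat) + eps)

    opt.step()
    return abs(w.item() - w_manual)   # expected: ~0, float rounding only


# example: print(check_adam_first_step()) should show a difference below 1e-6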
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)

    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
    WaTraces_SGD_MOMENTUM1, WbTraces_SGD_MOMENTUM1, LossTraces_SGD_MOMENTUM1, EpochTraces1, minibatchTraces1, updateTraces1, lr_SGD_MOMENTUM1, method_SGD_MOMENTUM1, \
        _, _, _, dWa_SGD_MOMENTUM1, dWb_SGD_MOMENTUM1, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.012_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM2, WbTraces_SGD_MOMENTUM2, LossTraces_SGD_MOMENTUM2, EpochTraces2, minibatchTraces2, updateTraces2, lr_SGD_MOMENTUM2, method_SGD_MOMENTUM2, \
        _, _, _, dWa_SGD_MOMENTUM2, dWb_SGD_MOMENTUM2, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.01_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM3, WbTraces_SGD_MOMENTUM3, LossTraces_SGD_MOMENTUM3, EpochTraces3, minibatchTraces3, updateTraces3, lr_SGD_MOMENTUM3, method_SGD_MOMENTUM3, \
        _, _, _, dWa_SGD_MOMENTUM3, dWb_SGD_MOMENTUM3, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.005_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM4, WbTraces_SGD_MOMENTUM4, LossTraces_SGD_MOMENTUM4, EpochTraces4, minibatchTraces4, updateTraces4, lr_SGD_MOMENTUM4, method_SGD_MOMENTUM4, \
        _, _, _, dWa_SGD_MOMENTUM4, dWb_SGD_MOMENTUM4, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.001_Epoch100_results.log')
    WaTraces_SGD_MOMENTUM5, WbTraces_SGD_MOMENTUM5, LossTraces_SGD_MOMENTUM5, EpochTraces5, minibatchTraces5, updateTraces5, lr_SGD_MOMENTUM5, method_SGD_MOMENTUM5, \
        _, _, _, dWa_SGD_MOMENTUM5, dWb_SGD_MOMENTUM5, _, _, _, _ = \
        read_optimizer_results(r'results/Sect2.2_SGD_Momentum_lr0.00015_Epoch100_results.log')
    # WaTraces_SGD_MOMENTUM6, WbTraces_SGD_MOMENTUM6, LossTraces_SGD_MOMENTUM6, _, _, _, lr_SGD_MOMENTUM6, method_SGD_MOMENTUM6, \
    #     _, _, _, dWa_SGD_MOMENTUM6, dWb_SGD_MOMENTUM6, _, _, _, _ = \
    #     read_optimizer_results(r'SGD_Momentum_lr0.00015_Epoch500_results.log')

    nframes = max([
        WaTraces_SGD_MOMENTUM1.size, WaTraces_SGD_MOMENTUM2.size,
        WaTraces_SGD_MOMENTUM3.size, WaTraces_SGD_MOMENTUM4.size,
        WaTraces_SGD_MOMENTUM5.size
    ])
    epoch1, epoch2, epoch3, epoch4, epoch5 = EpochTraces1[-1] + 1, EpochTraces2[-1] + 1, \
        EpochTraces3[-1] + 1, EpochTraces4[-1] + 1, EpochTraces5[-1] + 1

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    Wa0 = WaTraces_SGD_MOMENTUM1[0]
    Wb0 = WbTraces_SGD_MOMENTUM1[0]

    # nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.0, nWaGrids)
    # WbGrid = np.linspace(0, 2.0, nWbGrids)
    nWaGrids, nWbGrids = 250, 250
    # WaGrid = np.linspace(0, 2.5, nWaGrids)
    # WbGrid = np.linspace(0, 2.5, nWbGrids)
    WaGrid = np.linspace(0, 4.0, nWaGrids)
    WbGrid = np.linspace(-0.5, 3.5, nWbGrids)
    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)  # fixed: the original passed WaGrid twice
    loss = np.zeros(Wa2d.shape)
    # evaluate the loss surface over the (Wa, Wb) grid; batch_size equals
    # nsamples, so this outer loop runs exactly once over the whole dataset
    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']
        for indexb, Wb in enumerate(WbGrid):
            for indexa, Wa in enumerate(WaGrid):
                y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
                y_pred_sqr[y_pred_sqr < 1e-08] = 1e-08  # handle negative values caused by noise
                y_pred = torch.sqrt(y_pred_sqr)
                loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    # fixed: the contour() call originally had its grid arguments in (Wb, Wa)
    # order, inconsistent with the contourf() call below
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)

    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 4.0), ylim=(-0.5, 3.5))
    ax.set_title('SGD_w_Momentum Training Progress w/ Diff LR', fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=4)
    ax.plot(Wa0, Wb0, 'ko', ms=5)
    ax.text(0.08, 1.85, 'Start', color="white", fontsize=14)
    ax.text(0.88, 1.15, 'Target', color="white", fontsize=14)

    walist_SGD_MOMENTUM1 = []
    wblist_SGD_MOMENTUM1 = []
    point_SGD_MOMENTUM1, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_SGD_MOMENTUM1, = ax.plot([], [], '-r', lw=2,
                                  label='SGD_Momentum epoch={} '.format(epoch1) + lr_SGD_MOMENTUM1)
    walist_SGD_MOMENTUM2 = []
    wblist_SGD_MOMENTUM2 = []
    point_SGD_MOMENTUM2, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_SGD_MOMENTUM2, = ax.plot([], [], '-y', lw=2,
                                  label='SGD_Momentum epoch={} '.format(epoch2) + lr_SGD_MOMENTUM2)
    walist_SGD_MOMENTUM3 = []
    wblist_SGD_MOMENTUM3 = []
    point_SGD_MOMENTUM3, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_SGD_MOMENTUM3, = ax.plot([], [], '-m', lw=2,
                                  label='SGD_Momentum epoch={} '.format(epoch3) + lr_SGD_MOMENTUM3)
    walist_SGD_MOMENTUM4 = []
    wblist_SGD_MOMENTUM4 = []
    point_SGD_MOMENTUM4, = ax.plot([], [], 'go', lw=0.5, markersize=4)
    line_SGD_MOMENTUM4, = ax.plot([], [], '-g', lw=2,
                                  label='SGD_Momentum epoch={} '.format(epoch4) + lr_SGD_MOMENTUM4)
    walist_SGD_MOMENTUM5 = []
    wblist_SGD_MOMENTUM5 = []
    point_SGD_MOMENTUM5, = ax.plot([], [], 'o', lw=0.5, markersize=4, color='aqua')
    line_SGD_MOMENTUM5, = ax.plot([], [], '-', lw=2, color='aqua',
                                  label='SGD_Momentum epoch={} '.format(epoch5) + lr_SGD_MOMENTUM5)
    text_update = ax.text(0.03, 0.03, '', transform=ax.transAxes, color="white", fontsize=14)
    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_SGD_MOMENTUM1.set_data([], [])
        line_SGD_MOMENTUM1.set_data([], [])
        point_SGD_MOMENTUM2.set_data([], [])
        line_SGD_MOMENTUM2.set_data([], [])
        point_SGD_MOMENTUM3.set_data([], [])
        line_SGD_MOMENTUM3.set_data([], [])
        point_SGD_MOMENTUM4.set_data([], [])
        line_SGD_MOMENTUM4.set_data([], [])
        point_SGD_MOMENTUM5.set_data([], [])
        line_SGD_MOMENTUM5.set_data([], [])
        text_update.set_text('')
        return (point_SGD_MOMENTUM1, line_SGD_MOMENTUM1, point_SGD_MOMENTUM2, line_SGD_MOMENTUM2,
                point_SGD_MOMENTUM3, line_SGD_MOMENTUM3, point_SGD_MOMENTUM4, line_SGD_MOMENTUM4,
                point_SGD_MOMENTUM5, line_SGD_MOMENTUM5, text_update)

    # animation function, called sequentially for each frame. The wa/wb list
    # names were swapped in the original; they are renamed here so that walist
    # holds Wa (x-axis) and wblist holds Wb (y-axis), which is what is drawn.
    def animate(i):
        if i == 0:
            walist_SGD_MOMENTUM1[:] = []
            wblist_SGD_MOMENTUM1[:] = []
            walist_SGD_MOMENTUM2[:] = []
            wblist_SGD_MOMENTUM2[:] = []
            walist_SGD_MOMENTUM3[:] = []
            wblist_SGD_MOMENTUM3[:] = []
            walist_SGD_MOMENTUM4[:] = []
            wblist_SGD_MOMENTUM4[:] = []
            walist_SGD_MOMENTUM5[:] = []
            wblist_SGD_MOMENTUM5[:] = []

        # the five traces can have different lengths (nframes is the longest),
        # so guard each indexing operation
        if i < WaTraces_SGD_MOMENTUM1.size:
            wa, wb = WaTraces_SGD_MOMENTUM1[i], WbTraces_SGD_MOMENTUM1[i]
            walist_SGD_MOMENTUM1.append(wa)
            wblist_SGD_MOMENTUM1.append(wb)
            point_SGD_MOMENTUM1.set_data([wa], [wb])
            line_SGD_MOMENTUM1.set_data(walist_SGD_MOMENTUM1, wblist_SGD_MOMENTUM1)
        if i < WaTraces_SGD_MOMENTUM2.size:
            wa, wb = WaTraces_SGD_MOMENTUM2[i], WbTraces_SGD_MOMENTUM2[i]
            walist_SGD_MOMENTUM2.append(wa)
            wblist_SGD_MOMENTUM2.append(wb)
            point_SGD_MOMENTUM2.set_data([wa], [wb])
            line_SGD_MOMENTUM2.set_data(walist_SGD_MOMENTUM2, wblist_SGD_MOMENTUM2)
        if i < WaTraces_SGD_MOMENTUM3.size:
            wa, wb = WaTraces_SGD_MOMENTUM3[i], WbTraces_SGD_MOMENTUM3[i]
            walist_SGD_MOMENTUM3.append(wa)
            wblist_SGD_MOMENTUM3.append(wb)
            point_SGD_MOMENTUM3.set_data([wa], [wb])
            line_SGD_MOMENTUM3.set_data(walist_SGD_MOMENTUM3, wblist_SGD_MOMENTUM3)
        if i < WaTraces_SGD_MOMENTUM4.size:
            wa, wb = WaTraces_SGD_MOMENTUM4[i], WbTraces_SGD_MOMENTUM4[i]
            walist_SGD_MOMENTUM4.append(wa)
            wblist_SGD_MOMENTUM4.append(wb)
            point_SGD_MOMENTUM4.set_data([wa], [wb])
            line_SGD_MOMENTUM4.set_data(walist_SGD_MOMENTUM4, wblist_SGD_MOMENTUM4)
        if i < WaTraces_SGD_MOMENTUM5.size:
            wa, wb = WaTraces_SGD_MOMENTUM5[i], WbTraces_SGD_MOMENTUM5[i]
            walist_SGD_MOMENTUM5.append(wa)
            wblist_SGD_MOMENTUM5.append(wb)
            point_SGD_MOMENTUM5.set_data([wa], [wb])
            line_SGD_MOMENTUM5.set_data(walist_SGD_MOMENTUM5, wblist_SGD_MOMENTUM5)

        j = min(i, updateTraces1.size - 1)  # clamp in case trace 1 is not the longest
        update, epoch, minibatch = updateTraces1[j], EpochTraces1[j] + 1, minibatchTraces1[j] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))
        return (point_SGD_MOMENTUM1, line_SGD_MOMENTUM1, point_SGD_MOMENTUM2, line_SGD_MOMENTUM2,
                point_SGD_MOMENTUM3, line_SGD_MOMENTUM3, point_SGD_MOMENTUM4, line_SGD_MOMENTUM4,
                point_SGD_MOMENTUM5, line_SGD_MOMENTUM5, text_update)

    # call the animator; blit=True means only re-draw the parts that have changed
    intervalms = 10  # 10 ms per frame
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=nframes, interval=intervalms, blit=True)
    # save the animation as an mp4; this requires ffmpeg or mencoder to be installed
    anim.save('results/Part5_Fig9_Animation_SGD_Momentum_100Epoch_VariousLR.mp4',
              fps=30, bitrate=1800)
    plt.show()
    print('Done!')
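# ---------------------------------------------------------------------------
# The `read_optimizer_results` helper used by the visualization scripts is not
# shown in this section; the real one returns a long fixed tuple (17 or 18
# values, depending on the log variant) including strings such as the learning
# rate and method name. The sketch below is a hypothetical reconstruction that
# only parses the per-update 'key=value' lines written by the training scripts
# and returns them as a dict of numpy arrays; counters like 'updates', 'Epoch'
# and 'minibatch' would still need casting to int before use with '{:d}'.
# ---------------------------------------------------------------------------
import re


def read_optimizer_results_sketch(filename):
    fields = {}
    pair = re.compile(r'(\w+)=(-?\d+\.?\d*(?:[eE]-?\d+)?)')
    with open(filename) as fin:
        for line in fin:
            if not line.startswith('updates='):
                continue  # skip the header and final-summary lines
            for key, value in pair.findall(line):
                fields.setdefault(key, []).append(float(value))
    return {key: np.asarray(values) for key, values in fields.items()}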
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)

    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
    WaTraces_Adam1, WbTraces_Adam1, LossTraces_Adam1, EpochTraces1, minibatchTraces1, updateTraces1, lr_Adam1, gamma_Adam1, method_Adam1, \
        lrTraces_Adam1, _, _, dWa_Adam1, dWb_Adam1, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.01_Epoch100_Schl1.00_results.log')
    WaTraces_Adam2, WbTraces_Adam2, LossTraces_Adam2, EpochTraces2, minibatchTraces2, updateTraces2, lr_Adam2, gamma_Adam2, method_Adam2, \
        lrTraces_Adam2, _, _, dWa_Adam2, dWb_Adam2, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl1.00_results.log')
    WaTraces_Adam3, WbTraces_Adam3, LossTraces_Adam3, EpochTraces3, minibatchTraces3, updateTraces3, lr_Adam3, gamma_Adam3, method_Adam3, \
        lrTraces_Adam3, _, _, dWa_Adam3, dWb_Adam3, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl_lambda_fn_0_8_results.log')
    WaTraces_Adam4, WbTraces_Adam4, LossTraces_Adam4, EpochTraces4, minibatchTraces4, updateTraces4, lr_Adam4, gamma_Adam4, method_Adam4, \
        lrTraces_Adam4, _, _, dWa_Adam4, dWb_Adam4, _, _, _, _ = \
        read_optimizer_results(r'results/Adam_lr0.12_Epoch100_Schl_lambda_fn_0_5_results.log')

    nframes = max([
        WaTraces_Adam1.size, WaTraces_Adam2.size,
        WaTraces_Adam3.size, WaTraces_Adam4.size
    ])
    epoch1, epoch2, epoch3, epoch4 = EpochTraces1[-1] + 1, EpochTraces2[-1] + 1, \
        EpochTraces3[-1] + 1, EpochTraces4[-1] + 1
    update1, update2, update3, update4 = updateTraces1[-1], updateTraces2[-1], \
        updateTraces3[-1], updateTraces4[-1]

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    Wa0 = WaTraces_Adam1[0]
    Wb0 = WbTraces_Adam1[0]

    nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.5, nWaGrids)
    # WbGrid = np.linspace(0, 2.5, nWbGrids)
    WaGrid = np.linspace(0, 2.0, nWaGrids)
    WbGrid = np.linspace(0.25, 2.25, nWbGrids)
    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)  # fixed: the original passed WaGrid twice
    loss = np.zeros(Wa2d.shape)
    # evaluate the loss surface over the (Wa, Wb) grid; batch_size equals
    # nsamples, so this outer loop runs exactly once over the whole dataset
    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']
        for indexb, Wb in enumerate(WbGrid):
            for indexa, Wa in enumerate(WaGrid):
                y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
                y_pred_sqr[y_pred_sqr < 1e-08] = 1e-08  # handle negative values caused by noise
                y_pred = torch.sqrt(y_pred_sqr)
                loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    # fixed: the contour() call originally had its grid arguments in (Wb, Wa)
    # order, inconsistent with the contourf() call below
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)

    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 2), ylim=(0.25, 2.25))
    ax.set_title('Adam Optimization with MultiplicativeLR Schedulers & Lambda', fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=5)
    ax.plot(Wa0, Wb0, 'ko', ms=5)
    ax.text(0.08, 1.85, 'Start', color="white", fontsize=14)
    ax.text(0.88, 1.15, 'Target', color="white", fontsize=14)

    walist_Adam1 = []
    wblist_Adam1 = []
    point_Adam1, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_Adam1, = ax.plot([], [], '-r', lw=1,
                          label='Adam epoch={}, '.format(epoch1) + lr_Adam1 + ' fixed')
    walist_Adam2 = []
    wblist_Adam2 = []
    point_Adam2, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_Adam2, = ax.plot([], [], '-y', lw=1,
                          label='Adam epoch={}, '.format(epoch2) + lr_Adam2 + ' fixed')
    walist_Adam3 = []
    wblist_Adam3 = []
    point_Adam3, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_Adam3, = ax.plot([], [], '-m', lw=1,
                          label='Adam epoch={}, '.format(epoch3) + lr_Adam3 +
                                ' MultiplicativeLR w/ Lambda fn 0_8')
    walist_Adam4 = []
    wblist_Adam4 = []
    point_Adam4, = ax.plot([], [], 'o', lw=0.5, markersize=4, color='aqua')
    line_Adam4, = ax.plot([], [], '-', lw=1, color='aqua',
                          label='Adam epoch={}, '.format(epoch4) + lr_Adam4 +
                                ' MultiplicativeLR w/ Lambda fn 0_5')
    text_update = ax.text(0.03, 0.03, '', transform=ax.transAxes, color="blue", fontsize=14)
    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_Adam1.set_data([], [])
        line_Adam1.set_data([], [])
        point_Adam2.set_data([], [])
        line_Adam2.set_data([], [])
        point_Adam3.set_data([], [])
        line_Adam3.set_data([], [])
        point_Adam4.set_data([], [])
        line_Adam4.set_data([], [])
        text_update.set_text('')
        return (point_Adam1, line_Adam1, point_Adam2, line_Adam2,
                point_Adam3, line_Adam3, point_Adam4, line_Adam4, text_update)

    # animation function, called sequentially for each frame. As above, the
    # wa/wb lists are named so that walist holds Wa (x) and wblist holds Wb (y).
    def animate(i):
        if i == 0:
            walist_Adam1[:] = []
            wblist_Adam1[:] = []
            walist_Adam2[:] = []
            wblist_Adam2[:] = []
            walist_Adam3[:] = []
            wblist_Adam3[:] = []
            walist_Adam4[:] = []
            wblist_Adam4[:] = []
        if i < update1:
            wa_Adam1, wb_Adam1 = WaTraces_Adam1[i], WbTraces_Adam1[i]
            walist_Adam1.append(wa_Adam1)
            wblist_Adam1.append(wb_Adam1)
            point_Adam1.set_data([wa_Adam1], [wb_Adam1])
            line_Adam1.set_data(walist_Adam1, wblist_Adam1)
        if i < update2:
            wa_Adam2, wb_Adam2 = WaTraces_Adam2[i], WbTraces_Adam2[i]
            walist_Adam2.append(wa_Adam2)
            wblist_Adam2.append(wb_Adam2)
            point_Adam2.set_data([wa_Adam2], [wb_Adam2])
            line_Adam2.set_data(walist_Adam2, wblist_Adam2)
        if i < update3:
            wa_Adam3, wb_Adam3 = WaTraces_Adam3[i], WbTraces_Adam3[i]
            walist_Adam3.append(wa_Adam3)
            wblist_Adam3.append(wb_Adam3)
            point_Adam3.set_data([wa_Adam3], [wb_Adam3])
            line_Adam3.set_data(walist_Adam3, wblist_Adam3)
        if i < update4:
            wa_Adam4, wb_Adam4 = WaTraces_Adam4[i], WbTraces_Adam4[i]
            walist_Adam4.append(wa_Adam4)
            wblist_Adam4.append(wb_Adam4)
            point_Adam4.set_data([wa_Adam4], [wb_Adam4])
            line_Adam4.set_data(walist_Adam4, wblist_Adam4)

        j = min(i, updateTraces1.size - 1)  # clamp in case trace 1 is not the longest
        update, epoch, minibatch = updateTraces1[j], EpochTraces1[j] + 1, minibatchTraces1[j] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))
        return (point_Adam1, line_Adam1, point_Adam2, line_Adam2,
                point_Adam3, line_Adam3, point_Adam4, line_Adam4, text_update)

    # call the animator; blit=True means only re-draw the parts that have changed
    intervalms = 10  # 10 ms per frame
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=nframes, interval=intervalms, blit=True)
    # save the animation as an mp4; this requires ffmpeg or mencoder to be installed
    anim.save(r'results/Part5_Fig16_Animation_Adam_w_MultiplicativeLR_Scheduler.mp4',
              fps=30, bitrate=1800)
    plt.show()

    # extra: plot the learning-rate traces of the two scheduled runs over the
    # first 30 updates, to demonstrate the effect of the MultiplicativeLR
    # scheduler (the original comment here was a stale copy-paste about
    # momentum and gradients)
    flag_generate_Fig15 = True
    if flag_generate_Fig15:
        fig = plt.figure(figsize=(12, 6))
        plt.plot(lrTraces_Adam3[:30], '-', label='lr_gamma=0.80', lw=1)
        plt.plot(lrTraces_Adam4[:30], '-r', label='lr_gamma=0.50', lw=1)
        plt.xlabel("Number of Updates during Training")
        plt.ylabel("Learning Rate")
        plt.legend()
        plt.savefig(r'results/Part5_Fig15_Demonstrate_effect_of_scheduler.png')
    # -------------------------------------------
    print('Done!')
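# ---------------------------------------------------------------------------
# A small standalone check (not part of the original scripts) of the
# learning-rate sequence that lambda0_5 produces when scheduler.step() is
# called once per minibatch update, as in the training loops above. The LR
# halves when the scheduler's internal counter hits 5, 10, 15 and 20, which is
# the staircase shape plotted in Fig15.
# ---------------------------------------------------------------------------
def show_multiplicative_lr_sequence():
    w = torch.zeros(1, requires_grad=True)          # dummy parameter
    opt = optim.SGD([w], lr=0.12)
    lambda0_5 = lambda step: 0.5 if (step >= 5 and step <= 20 and step % 5 == 0) else 1.0
    sched = torch.optim.lr_scheduler.MultiplicativeLR(opt, lr_lambda=lambda0_5)

    lrs = []
    for _ in range(30):
        opt.step()       # no gradient is set; SGD simply skips the parameter
        sched.step()     # one scheduler step per "update", as in training
        lrs.append(opt.param_groups[0]['lr'])
    return lrs           # 0.12 -> 0.06 -> 0.03 -> 0.015 -> 0.0075


# example: print(show_multiplicative_lr_sequence())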
def main():
    device = torch.device('cpu')
    torch.manual_seed(9999)

    a = 1.261845
    b = 1.234378
    c = math.sqrt(a * a - b * b)
    nsamples = 512
    batch_size = 512

    # load previously generated results
    WaTraces_Adam, WbTraces_Adam, LossTraces_Adam, EpochTraces, minibatchTraces, updateTraces, lr_Adam, method_Adam, \
        stepsize_Adam, unitvecWa_Adam, unitvecWb_Adam, dWa_Adam, dWb_Adam, VdWa_Adam, VdWb_Adam, SdWa_Adam, SdWb_Adam = \
        read_optimizer_results(r'results/Adam_custom_implement_LR0.01_results.log')
    WaTraces_SGD, WbTraces_SGD, LossTraces_SGD, _, _, _, lr_SGD, method_SGD, \
        stepsize_SGD, unitvecWa_SGD, unitvecWb_SGD, dWa_SGD, dWb_SGD, _, _, _, _ = \
        read_optimizer_results(r'results/SGD_custom_implement_LR0.01_results.log')
    # fixed: the original unpacked VdWa_SGD_MOMENTUM twice; the second slot is VdWb
    WaTraces_SGD_MOMENTUM, WbTraces_SGD_MOMENTUM, LossTraces_SGD_MOMENTUM, _, _, _, lr_SGD_MOMENTUM, method_SGD_MOMENTUM, \
        stepsize_SGD_MOMENTUM, unitvecWa_SGD_MOMENTUM, unitvecWb_SGD_MOMENTUM, dWa_SGD_MOMENTUM, dWb_SGD_MOMENTUM, VdWa_SGD_MOMENTUM, VdWb_SGD_MOMENTUM, _, _ = \
        read_optimizer_results(r'results/SGD_w_Momentum_custom_implement_LR0.01_results.log')
    WaTraces_RMSprop, WbTraces_RMSprop, LossTraces_RMSprop, _, _, _, lr_RMSprop, method_RMSprop, \
        stepsize_RMSprop, unitvecWa_RMSprop, unitvecWb_RMSprop, dWa_RMSprop, dWb_RMSprop, VdWa_RMSprop, VdWb_RMSprop, SdWa_RMSprop, SdWb_RMSprop = \
        read_optimizer_results(r'results/RMSprop_custom_implement_LR0.01_results.log')

    # First plot: effective step size vs update ------------
    if True:
        fig = plt.figure(figsize=(12, 6))
        ax = fig.add_subplot(111)
        line_SGD, = ax.plot(updateTraces, stepsize_SGD, '-y', lw=1,
                            label=method_SGD + ' ' + lr_SGD)
        line_SGD_MOMENTUM, = ax.plot(updateTraces, stepsize_SGD_MOMENTUM, '-r', lw=1,
                                     label=method_SGD_MOMENTUM + ' ' + lr_SGD_MOMENTUM)
        line_RMSprop, = ax.plot(updateTraces, stepsize_RMSprop, '-b', lw=0.5,
                                label=method_RMSprop + ' ' + lr_RMSprop)
        line_Adam, = ax.plot(updateTraces, stepsize_Adam, '-g', lw=1,
                             label=method_Adam + ' ' + lr_Adam)
        leg = ax.legend()
        ax.set(xlim=(0, 600),
               ylim=(0, 0.01 + max(stepsize_SGD.max(), stepsize_SGD_MOMENTUM.max(),
                                   stepsize_RMSprop.max(), stepsize_Adam.max())))
        ax.set_title('Effective Step Size as a Function of Updates', fontsize=16)
        plt.xlabel("# of Updates")
        plt.ylabel("Effective Step Size (AU)")
        fig.tight_layout()
        plt.savefig(r'results/Part5_Fig1_EffStepSize_vs_Updates.png')
        plt.show(block=False)
    # -------------------------------------------

    nframes = len(WaTraces_Adam)
    # nframes = 600  # len(WaTraces_Adam)

    xy_dataset = EllipseDataset(nsamples, a, b, noise_scale=0.1)
    xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    Wa0 = WaTraces_Adam[0]
    Wb0 = WbTraces_Adam[0]

    # nWaGrids, nWbGrids = 200, 200
    # WaGrid = np.linspace(0, 2.0, nWaGrids)
    # WbGrid = np.linspace(0, 2.0, nWbGrids)
    nWaGrids, nWbGrids = 250, 250
    WaGrid = np.linspace(0, 2.5, nWaGrids)
    WbGrid = np.linspace(0, 2.5, nWbGrids)
    # WaGrid = np.linspace(0, 0.5, nWaGrids)
    # WbGrid = np.linspace(1.5, 2.0, nWbGrids)
    Wa2d, Wb2d = np.meshgrid(WaGrid, WbGrid)  # fixed: the original passed WaGrid twice
    loss = np.zeros(Wa2d.shape)
    # evaluate the loss surface over the (Wa, Wb) grid; batch_size equals
    # nsamples, so this outer loop runs exactly once over the whole dataset
    for i_batch, sample_batched in enumerate(xy_dataloader):
        x, y = sample_batched['x'], sample_batched['y']
        for indexb, Wb in enumerate(WbGrid):
            for indexa, Wa in enumerate(WaGrid):
                y_pred_sqr = Wb**2 * (1.0 - (x + c)**2 / Wa**2)
                y_pred_sqr[y_pred_sqr < 1e-08] = 1e-08  # handle negative values caused by noise
                y_pred = torch.sqrt(y_pred_sqr)
                loss[indexb, indexa] = (y_pred - y).pow(2).sum()

    # Second plot: the steps taken during the first few updates ------------
    if True:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)
        # fixed: the contour() call originally had its grid arguments in
        # (Wb, Wa) order, inconsistent with the contourf() call below
        ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
        cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
        fig.colorbar(cntr1, ax=ax, shrink=0.75)
        # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
        # ax.set(xlim=(0, 2.5), ylim=(0, 2.5))
        ax.set(xlim=(0, 0.5), ylim=(1.5, 2.0))
        ax.set_title('First 4 updates of Wa and Wb in training', fontsize=16)
        plt.xlabel("Wa")
        plt.ylabel("Wb")
        ax.set_aspect('equal', adjustable='box')
        ax.plot(a, b, 'yo', ms=3)
        ax.plot(Wa0, Wb0, 'ko', ms=3)
        ax.text(0.08, 1.81, 'Start', color="black", fontsize=14)
        # draw one arrow per update; the 0.995 factor shortens each arrow
        # slightly so its head does not cover the next trace point
        for i in range(4):
            stepWa_SGD = (WaTraces_SGD[i + 1] - WaTraces_SGD[i]) * 0.995
            stepWb_SGD = (WbTraces_SGD[i + 1] - WbTraces_SGD[i]) * 0.995
            plt.arrow(WaTraces_SGD[i], WbTraces_SGD[i], stepWa_SGD, stepWb_SGD,
                      color='yellow', linewidth=1, head_width=0.005,
                      length_includes_head=True)
            stepWa_SGD_M = (WaTraces_SGD_MOMENTUM[i + 1] - WaTraces_SGD_MOMENTUM[i]) * 0.995
            stepWb_SGD_M = (WbTraces_SGD_MOMENTUM[i + 1] - WbTraces_SGD_MOMENTUM[i]) * 0.995
            plt.arrow(WaTraces_SGD_MOMENTUM[i], WbTraces_SGD_MOMENTUM[i],
                      stepWa_SGD_M, stepWb_SGD_M, color='red', linewidth=1,
                      head_width=0.005, length_includes_head=True)
            stepWa_RMSprop = (WaTraces_RMSprop[i + 1] - WaTraces_RMSprop[i]) * 0.995
            stepWb_RMSprop = (WbTraces_RMSprop[i + 1] - WbTraces_RMSprop[i]) * 0.995
            plt.arrow(WaTraces_RMSprop[i], WbTraces_RMSprop[i],
                      stepWa_RMSprop, stepWb_RMSprop, color='blue', linewidth=1,
                      head_width=0.005, length_includes_head=True)
            stepWa_Adam = (WaTraces_Adam[i + 1] - WaTraces_Adam[i]) * 0.995
            stepWb_Adam = (WbTraces_Adam[i + 1] - WbTraces_Adam[i]) * 0.995
            plt.arrow(WaTraces_Adam[i], WbTraces_Adam[i], stepWa_Adam, stepWb_Adam,
                      color='green', linewidth=1, head_width=0.005,
                      length_includes_head=True)
        line_SGD, = ax.plot([], [], '-y', lw=1, label=method_SGD + ' ' + lr_SGD)
        line_SGD_MOMENTUM, = ax.plot([], [], '-r', lw=1,
                                     label=method_SGD_MOMENTUM + ' ' + lr_SGD_MOMENTUM)
        line_RMSprop, = ax.plot([], [], '-b', lw=0.5, label=method_RMSprop + ' ' + lr_RMSprop)
        line_Adam, = ax.plot([], [], '-g', lw=1, label=method_Adam + ' ' + lr_Adam)
        leg = ax.legend()
        fig.tight_layout()
        plt.savefig('results/Part5_Fig2_FirstFewSteps.png')
        plt.show(block=False)
    # -------------------------------------------

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.contour(WaGrid, WbGrid, loss, levels=15, linewidths=0.5, colors='gray')
    cntr1 = ax.contourf(WaGrid, WbGrid, loss, levels=100, cmap="RdBu_r")
    fig.colorbar(cntr1, ax=ax, shrink=0.75)
    # ax.set(xlim=(0, 2.0), ylim=(0, 2.0))
    ax.set(xlim=(0, 2.5), ylim=(0, 2.5))
    ax.set_title('Different Optimization Paths Using The Same Learning Rate', fontsize=16)
    plt.xlabel("Wa")
    plt.ylabel("Wb")
    ax.set_aspect('equal', adjustable='box')
    ax.plot(a, b, 'yo', ms=3)
    ax.plot(Wa0, Wb0, 'ko', ms=3)

    walist_SGD = []
    wblist_SGD = []
    point_SGD, = ax.plot([], [], 'yo', lw=0.5, markersize=4)
    line_SGD, = ax.plot([], [], '-y', lw=2, label=method_SGD + ' ' + lr_SGD)
    walist_SGD_MOMENTUM = []
    wblist_SGD_MOMENTUM = []
    point_SGD_MOMENTUM, = ax.plot([], [], 'ro', lw=0.5, markersize=4)
    line_SGD_MOMENTUM, = ax.plot([], [], '-r', lw=2,
                                 label=method_SGD_MOMENTUM + ' ' + lr_SGD_MOMENTUM)
    walist_RMSprop = []
    wblist_RMSprop = []
    point_RMSprop, = ax.plot([], [], 'mo', lw=0.5, markersize=4)
    line_RMSprop, = ax.plot([], [], '-m', lw=2, label=method_RMSprop + ' ' + lr_RMSprop)
    walist_Adam = []
    wblist_Adam = []
    point_Adam, = ax.plot([], [], 'go', lw=0.5, markersize=4)
    line_Adam, = ax.plot([], [], '-g', lw=2, label=method_Adam + ' ' + lr_Adam)
    text_update = ax.text(0.03, 0.03, '', transform=ax.transAxes, color="black", fontsize=14)
    leg = ax.legend()
    fig.tight_layout()
    plt.show(block=False)

    # initialization function: plot the background of each frame
    def init():
        point_SGD.set_data([], [])
        line_SGD.set_data([], [])
        point_SGD_MOMENTUM.set_data([], [])
        line_SGD_MOMENTUM.set_data([], [])
        point_RMSprop.set_data([], [])
        line_RMSprop.set_data([], [])
        point_Adam.set_data([], [])
        line_Adam.set_data([], [])
        text_update.set_text('')
        return (point_SGD, line_SGD, point_SGD_MOMENTUM, line_SGD_MOMENTUM,
                point_RMSprop, line_RMSprop, point_Adam, line_Adam, text_update)

    # animation function, called sequentially for each frame. As above, the
    # wa/wb lists are named so that walist holds Wa (x) and wblist holds Wb (y).
    def animate(i):
        if i == 0:
            walist_SGD[:] = []
            wblist_SGD[:] = []
            walist_SGD_MOMENTUM[:] = []
            wblist_SGD_MOMENTUM[:] = []
            walist_RMSprop[:] = []
            wblist_RMSprop[:] = []
            walist_Adam[:] = []
            wblist_Adam[:] = []

        # guard the indexing in case the four traces differ in length
        if i < len(WaTraces_SGD):
            wa_SGD, wb_SGD = WaTraces_SGD[i], WbTraces_SGD[i]
            walist_SGD.append(wa_SGD)
            wblist_SGD.append(wb_SGD)
            point_SGD.set_data([wa_SGD], [wb_SGD])
            line_SGD.set_data(walist_SGD, wblist_SGD)
        if i < len(WaTraces_SGD_MOMENTUM):
            wa_SGD_MOMENTUM, wb_SGD_MOMENTUM = WaTraces_SGD_MOMENTUM[i], WbTraces_SGD_MOMENTUM[i]
            walist_SGD_MOMENTUM.append(wa_SGD_MOMENTUM)
            wblist_SGD_MOMENTUM.append(wb_SGD_MOMENTUM)
            point_SGD_MOMENTUM.set_data([wa_SGD_MOMENTUM], [wb_SGD_MOMENTUM])
            line_SGD_MOMENTUM.set_data(walist_SGD_MOMENTUM, wblist_SGD_MOMENTUM)
        if i < len(WaTraces_RMSprop):
            wa_RMSprop, wb_RMSprop = WaTraces_RMSprop[i], WbTraces_RMSprop[i]
            walist_RMSprop.append(wa_RMSprop)
            wblist_RMSprop.append(wb_RMSprop)
            point_RMSprop.set_data([wa_RMSprop], [wb_RMSprop])
            line_RMSprop.set_data(walist_RMSprop, wblist_RMSprop)
        if i < len(WaTraces_Adam):
            wa_Adam, wb_Adam = WaTraces_Adam[i], WbTraces_Adam[i]
            walist_Adam.append(wa_Adam)
            wblist_Adam.append(wb_Adam)
            point_Adam.set_data([wa_Adam], [wb_Adam])
            line_Adam.set_data(walist_Adam, wblist_Adam)

        update, epoch, minibatch = updateTraces[i], EpochTraces[i] + 1, minibatchTraces[i] + 1
        text_update.set_text('Epoch={:d}, minibatch={:d}, Updates={:d}'.format(
            epoch, minibatch, update))
        return (point_SGD, line_SGD, point_SGD_MOMENTUM, line_SGD_MOMENTUM,
                point_RMSprop, line_RMSprop, point_Adam, line_Adam, text_update)

    # call the animator; blit=True means only re-draw the parts that have changed
    intervalms = 10  # 10 ms per frame
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=nframes, interval=intervalms, blit=True)
    # save the animation as an mp4; this requires ffmpeg or mencoder to be installed
    anim.save('results/Part5_Fig3_Animation_of_loss_vs_TrainingUpdates.mp4',
              fps=30, bitrate=1800)
    plt.show()
    print('Done!')
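# ---------------------------------------------------------------------------
# The 'effective step size' plotted in Fig1 is read from the training logs,
# where the custom optimizer loops computed it as
# learning_rate * ||update direction|| (see `actual_stepsize` in the manual
# Adam script above), i.e. the length of the (dWa, dWb) step actually applied.
# When only the (Wa, Wb) traces are available, the same quantity can be
# recovered (a sketch, not the original code path) as the Euclidean distance
# between consecutive trace points:
# ---------------------------------------------------------------------------
def effective_stepsizes(WaTraces, WbTraces):
    """Per-update step size ||(dWa, dWb)|| recovered from parameter traces."""
    return np.hypot(np.diff(WaTraces), np.diff(WbTraces))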