def handle_plot_ckpt(do_plot=False):
    """Record the current train/validation losses and optionally plot them.

    Averages ``to_scalar([loss_1, loss_2, loss_3])`` over axis 0 into a train
    loss and appends it, with ``train_cnt``, to ``info['train_losses']`` /
    ``info['train_cnts']``. It then runs ``forward_pass`` once on a batch from
    ``data_loader.validation_data()``. The three validation losses are reduced
    the same way and appended to ``info['test_losses']`` /
    ``info['test_cnts']``. Finally it prints a progress line.

    Args:
        do_plot: when true, also store ``train_cnt`` in ``info['last_plot']``
            and call ``plot_losses`` with a PNG path under
            ``default_base_savedir``.

    NOTE(review): ``loss_1``, ``loss_2``, ``loss_3`` and ``train_cnt`` are not
    parameters. They are looked up as module globals, which suggests this was
    lifted out of a nested function. Confirm they exist at call time. A later
    ``def handle_plot_ckpt`` in this module rebinds the name, so after import
    this definition is shadowed.
    """
    train_loss = np.array(to_scalar([loss_1, loss_2, loss_3])).mean(0)
    info['train_losses'].append(train_loss)
    info['train_cnts'].append(train_cnt)

    vx, vy, _ = data_loader.validation_data()
    # Only the first three outputs (the validation losses) are used here.
    # The explicit 7-way unpack keeps the original arity check.
    vloss_1, vloss_2, vloss_3, _, _, _, _ = forward_pass(vx, vy)
    test_loss = np.array(to_scalar([vloss_1, vloss_2, vloss_3])).mean(0)
    info['test_losses'].append(test_loss)
    info['test_cnts'].append(train_cnt)

    print('examples %010d train loss %03.03f test loss %03.03f' %
          (train_cnt, info['train_losses'][-1], info['test_losses'][-1]))
    if do_plot:
        info['last_plot'] = train_cnt
        plot_name = os.path.join(default_base_savedir,
                                 basename + "_%010dloss.png" % train_cnt)
        print('plotting: %s' % plot_name)
        # Use a rolling length of 3 only once at least 3*3 points exist, and
        # 1 before that. This is the same rule the other handle_plot_ckpt in
        # this module applies.
        rolling = 3
        if len(info['train_losses']) < rolling * 3:
            rolling = 1
        plot_losses(info['train_cnts'],
                    info['train_losses'],
                    info['test_cnts'],
                    info['test_losses'],
                    name=plot_name,
                    rolling_length=rolling)
def handle_plot_ckpt(do_plot, train_cnt, avg_train_loss):
    """Append the latest train/test losses to ``info`` and optionally plot them.

    Args:
        do_plot: when true, store ``train_cnt`` in ``info['last_plot']`` and
            write a loss plot via ``plot_losses``. It is also forwarded to
            ``test_acn``.
        train_cnt: example count for this checkpoint. It is recorded alongside
            both losses and embedded in the plot filename.
        avg_train_loss: training loss value to record for this checkpoint.

    The test loss is whatever ``test_acn(train_cnt, do_plot)`` returns. All
    history lives in the module-global ``info`` mapping.
    """
    # Train-side bookkeeping happens before test_acn is called (same order
    # as before, in case test_acn inspects info).
    info['train_losses'].append(avg_train_loss)
    info['train_cnts'].append(train_cnt)

    latest_test = test_acn(train_cnt, do_plot)
    info['test_losses'].append(latest_test)
    info['test_cnts'].append(train_cnt)

    progress = 'examples %010d train loss %03.03f test loss %03.03f' % (
        train_cnt, info['train_losses'][-1], info['test_losses'][-1])
    print(progress)

    if not do_plot:
        return

    info['last_plot'] = train_cnt
    # Use a rolling length of 3 once the history holds at least 9 points;
    # fall back to 1 for shorter histories.
    window = 3 if len(info['train_losses']) >= 3 * 3 else 1
    print('adding last loss plot', train_cnt)
    loss_png = vae_base_filepath + "_%010dloss.png" % train_cnt
    point_count = len(info['train_cnts'])
    print('plotting loss: %s with %s points' % (loss_png, point_count))
    plot_losses(info['train_cnts'], info['train_losses'],
                info['test_cnts'], info['test_losses'],
                name=loss_png, rolling_length=window)