Example 1
# NOTE: the snippet is truncated; the imports and the opening of the
# '--gan' argument are reconstructed (assumed) from the later use of args.gan.
import argparse

import pandas as pd

if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument(
        '--gan',
        action='store_true',
        help='whether to use GAN-based training, default False (use L2-based)')
    args.add_argument('--fname',
                      type=str,
                      default='SHO_hypertune.csv',
                      help='path to save results of replications')
    args.add_argument('--nreps',
                      type=int,
                      default=20,
                      help='number of replications to run')
    args.add_argument('--pkey',
                      type=str,
                      default='sho',
                      help='problem key as a string')
    args = args.parse_args()

    _N_REPS = args.nreps
    _PROBLEM = get_problem(args.pkey)

    handle_overwrite(args.fname)

    hyper_space = get_niters(args.pkey, args.gan)
    hyper_space = dict_product(hyper_space)

    if args.gan:
        results = map(gan_exp_with_hypers, hyper_space)
    else:
        results = map(L2_exp_with_hypers, hyper_space)

    pd.DataFrame.from_records(results).to_csv(args.fname)
    print(f'Saved results to {args.fname}')
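
A hypothetical invocation of this driver (the script name hypertune.py is assumed, not taken from the source):

    python hypertune.py --gan --pkey sho --nreps 20 --fname SHO_hypertune.csv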
Example 2
def train_GAN_2D(G,
                 D,
                 problem,
                 method='unsupervised',
                 niters=100,
                 g_lr=1e-3,
                 g_betas=(0.0, 0.9),
                 d_lr=1e-3,
                 d_betas=(0.0, 0.9),
                 lr_schedule=True,
                 gamma=0.999,
                 obs_every=1,
                 d1=1.,
                 d2=1.,
                 G_iters=1,
                 D_iters=1,
                 wgan=True,
                 gp=0.1,
                 conditional=True,
                 log=True,
                 plot=True,
                 save=False,
                 dirname='train_GAN',
                 config=None,
                 save_for_animation=False,
                 **kwargs):
    """
    Train/test GAN method: supervised/semisupervised/unsupervised
    """
    assert method in ['supervised', 'semisupervised',
                      'unsupervised'], f'Method {method} not understood!'

    dirname = os.path.join(this_dir, '../experiments/runs', dirname)
    if plot and save:
        handle_overwrite(dirname)

    # validation: fixed grid/solution
    x, y = problem.get_grid()
    grid = torch.cat((x, y), 1)
    soln = problem.get_solution(x, y)

    # observer mask and masked grid/solution (t_obs/y_obs)
    observers = torch.arange(0, len(grid), obs_every)
    # grid_obs = grid[observers, :]
    # soln_obs = soln[observers, :]

    # labels
    real_label = 1.
    fake_label = -1. if wgan else 0.
    # float labels: BCELoss requires float targets, and building the
    # (N, 1) shape directly avoids the extra reshape
    real_labels = torch.full((len(grid), 1), real_label)
    fake_labels = torch.full((len(grid), 1), fake_label)
    # masked label vectors
    real_labels_obs = real_labels[observers, :]
    fake_labels_obs = fake_labels[observers, :]

    # optimization
    optiG = torch.optim.Adam(G.parameters(), lr=g_lr, betas=g_betas)
    optiD = torch.optim.Adam(D.parameters(), lr=d_lr, betas=d_betas)
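    # beta1 = 0 here matches the Adam setting used for WGAN-GP critics
    # in Gulrajani et al. (2017); it tends to stabilize adversarial training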
    if lr_schedule:
        lr_scheduler_G = torch.optim.lr_scheduler.ExponentialLR(
            optimizer=optiG, gamma=gamma)
        lr_scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
            optimizer=optiD, gamma=gamma)

    # losses
    mse = nn.MSELoss()
    bce = nn.BCELoss()
    wass = lambda y_pred, y_true: torch.mean(y_pred * y_true)
    criterion = wass if wgan else bce
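    # with real_label = +1 and fake_label = -1, the Wasserstein criterion
    # mean(D(x) * label) turns the critic loss below into
    # (mean(D(real)) - mean(D(fake))) / 2, a signed difference of critic
    # means; the BCE path instead expects {0, 1} labels and sigmoid outputs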

    # history
    losses = {'G': [], 'D': []}
    mses = {'train': [], 'val': []}
    preds = {'pred': [], 'soln': []}

    for epoch in range(niters):
        # Train Generator
        for p in D.parameters():
            p.requires_grad = False  # turn off computation for D

        for i in range(G_iters):
            xs, ys = problem.get_grid_sample()
            grid_samp = torch.cat((xs, ys), 1)
            pred = G(grid_samp)
            residuals = problem.get_equation(pred, xs, ys)

            # idea: add noise to relax the target from a Dirac delta at 0
            # to a distribution, e.g. + torch.normal(0, .1/(i+1), size=residuals.shape)
            real = torch.zeros_like(residuals)
            fake = residuals

            optiG.zero_grad()
            g_loss = criterion(D(fake), real_labels)
            # g_loss = criterion(D(fake), torch.ones_like(fake))
            g_loss.backward(retain_graph=True)
            optiG.step()

        # Train Discriminator
        for p in D.parameters():
            p.requires_grad = True  # turn on computation for D

        for i in range(D_iters):
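            # gradient penalty: softly enforces the critic's 1-Lipschitz
            # constraint required by the Wasserstein formulation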
            if wgan:
                norm_penalty = calc_gradient_penalty(D,
                                                     real,
                                                     fake,
                                                     gp,
                                                     cuda=False)
            else:
                norm_penalty = torch.zeros(1)

            # print(real.shape, fake.shape)
            real_loss = criterion(D(real), real_labels)
            # real_loss = criterion(D(real), torch.ones_like(real))
            fake_loss = criterion(D(fake), fake_labels)
            # fake_loss = criterion(D(fake), torch.zeros_like(fake))

            optiD.zero_grad()
            d_loss = (real_loss + fake_loss) / 2 + norm_penalty
            d_loss.backward(retain_graph=True)
            optiD.step()

        losses['D'].append(d_loss.item())
        losses['G'].append(g_loss.item())

        if lr_schedule:
            lr_scheduler_G.step()
            lr_scheduler_D.step()

        # train MSE: grid sample vs true soln
        # grid_samp, sort_ids = torch.sort(grid_samp, axis=0)
        pred = G(grid_samp)
        pred_adj = problem.adjust(pred, xs, ys)['pred']
        sol_samp = problem.get_solution(xs, ys)
        train_mse = mse(pred_adj, sol_samp).item()
        mses['train'].append(train_mse)

        # val MSE: fixed grid vs true soln
        val_pred = G(grid)
        val_pred_adj = problem.adjust(val_pred, x, y)['pred']
        val_mse = mse(val_pred_adj, soln).item()
        mses['val'].append(val_mse)

        # save preds for animation
        preds['pred'].append(val_pred_adj.detach())
        preds['soln'].append(soln.detach())

        try:
            if (epoch + 1) % 10 == 0:
                # mean of val mses for last 10 steps
                track.log(mean_squared_error=np.mean(mses['val'][-10:]))
                # mean of G - D loss for last 10 steps
                # loss_diff = np.mean(np.abs(losses['G'][-10] - losses['D'][-10]))
                # track.log(mean_squared_error=loss_diff)
        except Exception as e:
            # print(f'Caught exception {e}')
            pass

        if log:
            print(
                f'Step {epoch}: G Loss: {g_loss.item():.4e} | D Loss: {d_loss.item():.4e} | Train MSE {train_mse:.4e} | Val MSE {val_mse:.4e}'
            )

    if plot:
        pred_dict, diff_dict = problem.get_plot_dicts(G(grid), x, y, soln)
        plot_results(mses,
                     losses,
                     grid.detach(),
                     pred_dict,
                     diff_dict=diff_dict,
                     save=save,
                     dirname=dirname,
                     logloss=False,
                     alpha=0.7)

    if save:
        write_config(config, os.path.join(dirname, 'config.yaml'))

    if save_for_animation:
        anim_dir = os.path.join(dirname, "animation")
        print(f'Saving animation traces to {anim_dir}')
        # makedirs creates dirname as well if it does not exist yet
        os.makedirs(anim_dir, exist_ok=True)
        np.save(os.path.join(anim_dir, "grid"), grid.detach())
        for k, v in preds.items():
            v = np.hstack(v)
            # TODO: for systems (i.e. multi-dim preds),
            # hstack flattens preds, need to use dstack
            # v = np.dstack(v)
            np.save(os.path.join(anim_dir, k), v)  # writes pred.npy / soln.npy

    return {'mses': mses, 'model': G, 'losses': losses}
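
A minimal usage sketch for train_GAN_2D (the two MLPs and the get_problem helper are assumptions; only train_GAN_2D comes from this example):

    import torch.nn as nn

    # hypothetical generator (2 inputs: x, y) and critic (1 input: residual)
    G = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))
    D = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
    problem = get_problem('sho')  # assumed helper, as used in Example 1

    out = train_GAN_2D(G, D, problem, niters=500, wgan=True, gp=0.1,
                       log=False, plot=False)
    print(out['mses']['val'][-1])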
Example 3
def train_L2_2D(model,
                problem,
                method='unsupervised',
                niters=100,
                lr=1e-3,
                betas=(0, 0.9),
                lr_schedule=True,
                gamma=0.999,
                obs_every=1,
                d1=1,
                d2=1,
                log=True,
                plot=True,
                save=False,
                dirname='train_L2',
                config=None,
                loss_fn=None,
                save_for_animation=False,
                **kwargs):
    """
    Train/test Lagaris method: supervised/semisupervised/unsupervised
    """
    assert method in ['supervised', 'semisupervised',
                      'unsupervised'], f'Method {method} not understood!'

    dirname = os.path.join(this_dir, '../experiments/runs', dirname)
    if plot and save:
        handle_overwrite(dirname)

    # validation: fixed grid/solution
    x, y = problem.get_grid()
    grid = torch.cat((x, y), 1)
    sol = problem.get_solution(x, y)

    # optimizers & loss functions
    opt = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
    if loss_fn:
        # resolve the loss class by name, e.g. loss_fn='SmoothL1Loss'
        mse = getattr(torch.nn, loss_fn)()
    else:
        mse = torch.nn.MSELoss()
    # lr scheduler
    if lr_schedule:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt,
                                                              gamma=gamma)

    loss_trace = []
    mses = {'train': [], 'val': []}
    preds = {'pred': [], 'soln': []}

    for i in range(niters):
        xs, ys = problem.get_grid_sample()
        grid_samp = torch.cat((xs, ys), 1)
        pred = model(grid_samp)
        residuals = problem.get_equation(pred, xs, ys)
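        # unsupervised (Lagaris-style) objective: drive the equation
        # residuals toward zero in the least-squares sense on sampled points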
        loss = mse(residuals, torch.zeros_like(residuals))
        loss_trace.append(loss.item())

        # train MSE: grid sample vs true soln
        # (the model is unchanged since the forward pass above, so pred is reused)
        # grid_samp, sort_ids = torch.sort(grid_samp, axis=0)
        train_mse = float('nan')  # fallback so the logging below never hits an undefined name
        try:
            pred_adj = problem.adjust(pred, xs, ys)['pred']
            sol_samp = problem.get_solution(xs, ys)
            train_mse = mse(pred_adj, sol_samp).item()
        except Exception as e:
            print(f'Exception: {e}')
        mses['train'].append(train_mse)

        # val MSE: fixed grid vs true soln
        val_pred = model(grid)
        val_pred_adj = problem.adjust(val_pred, x, y)['pred']
        val_mse = mse(val_pred_adj, sol).item()
        mses['val'].append(val_mse)

        # store preds for animation
        preds['pred'].append(val_pred_adj.detach())
        preds['soln'].append(sol.detach())

        try:
            if (i + 1) % 10 == 0:
                # mean of val mses for last 10 steps
                track.log(mean_squared_error=np.mean(mses['val'][-10:]))
        except Exception as e:
            # print(f'Caught exception {e}')
            pass

        if log:
            print(
                f'Step {i}: Loss {loss.item():.4e} | Train MSE {train_mse:.4e} | Val MSE {val_mse:.4e}'
            )

        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
        if lr_schedule:
            lr_scheduler.step()

    if plot:
        loss_dict = {}
        if method == 'supervised':
            loss_dict['$L_S$'] = loss_trace
        elif method == 'semisupervised':
            loss_dict['$L_S$'] = [l[0] for l in loss_trace]
            loss_dict['$L_U$'] = [l[1] for l in loss_trace]
        else:
            loss_dict['$L_U$'] = loss_trace

        pred_dict, diff_dict = problem.get_plot_dicts(model(grid), x, y, sol)
        plot_results(mses,
                     loss_dict,
                     grid.detach(),
                     pred_dict,
                     diff_dict=diff_dict,
                     save=save,
                     dirname=dirname,
                     logloss=True,
                     alpha=0.7)

    if save:
        write_config(config, os.path.join(dirname, 'config.yaml'))

    if save_for_animation:
        anim_dir = os.path.join(dirname, "animation")
        print(f'Saving animation traces to {anim_dir}')
        # makedirs creates dirname as well if it does not exist yet
        os.makedirs(anim_dir, exist_ok=True)
        np.save(os.path.join(anim_dir, "grid"), grid.detach())
        for k, v in preds.items():
            v = np.hstack(v)
            # TODO: for systems (i.e. multi-dim preds),
            # hstack flattens preds, need to use dstack
            # v = np.dstack(v)
            np.save(os.path.join(anim_dir, k), v)  # writes pred.npy / soln.npy

    return {'mses': mses, 'model': model, 'losses': loss_trace}
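
A corresponding sketch for the L2 trainer (the model and get_problem are again assumptions):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))
    problem = get_problem('sho')  # assumed helper, as used in Example 1

    out = train_L2_2D(model, problem, niters=500, lr=1e-3,
                      loss_fn='SmoothL1Loss',  # any torch.nn loss class name
                      log=False, plot=False)
    print(out['mses']['val'][-1])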
Example 4
# NOTE: the snippet is truncated; the imports and the opening of the
# '--pkey' argument are reconstructed (assumed) from the later use of args.pkey.
import argparse

if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument(
        '--pkey',
        type=str,
        default='EXP',
        help=
        'problem to run (exp=Exponential, sho=SimpleOscillator, nlo=NonlinearOscillator)'
    )
    args.add_argument('--nreps',
                      type=int,
                      default=10,
                      help='number of random seeds to try')
    args.add_argument('--fname',
                      type=str,
                      default='rand_reps',
                      help='file to save numpy results of MSEs')
    args = args.parse_args()

    handle_overwrite(args.fname)
    handle_overwrite(args.fname + '.npy')

    params = get_config(args.pkey)

    # turn off plotting / logging
    params['training']['log'] = False
    params['training']['plot'] = False
    # turn off saving
    params['training']['save'] = False
    params['training']['save_for_animation'] = False

    # np.random.seed(42)
    # seeds = np.random.randint(int(1e6), size=args.nreps)
    seeds = list(range(args.nreps))
    print("Using seeds: ", seeds)