def main(_):
  """Integrate baseline solutions of a conservative equation at several orders.

  For every seed in ``range(FLAGS.num_samples)`` and every accuracy order in
  ``FLAGS.accuracy_orders``:
    * builds an equation from ``equations.CONSERVATIVE_EQUATION_TYPES``;
    * integrates it with ``integrate.integrate_baseline``.
  The per-order results are concatenated along ``accuracy_order``, then the
  per-seed results along ``sample``. The combined dataset is written as
  netCDF to ``FLAGS.output_path``.

  Args:
    _: unused positional argument (presumably argv from the app runner;
      confirm).

  Raises:
    ValueError: if a constructed equation is not conservative.
  """
  runner = beam.runners.DirectRunner()  # must create before flags are used

  equation_kwargs = json.loads(FLAGS.equation_kwargs)
  accuracy_orders = FLAGS.accuracy_orders

  # The exact filter interval is only forwarded when the equation's baseline
  # is spectral and the flag is set; otherwise integrate without one.
  if (equations.EQUATION_TYPES[FLAGS.equation_name].BASELINE
      is equations.Baseline.SPECTRAL and FLAGS.exact_filter_interval):
    exact_filter_interval = FLAGS.exact_filter_interval
  else:
    exact_filter_interval = None

  # Flag values are bound as default arguments at definition time, presumably
  # so the functions Beam serializes do not need parsed FLAGS when they run.
  def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs):
    equation_type = equations.CONSERVATIVE_EQUATION_TYPES[name]
    return equation_type(random_seed=seed, **kwargs)

  # NOTE(review): the stop value is time_max (excluded by np.arange), unlike
  # sibling scripts that use time_max + time_delta -- confirm intended.
  def integrate_baseline(equation, accuracy_order,
                         times=np.arange(0, FLAGS.time_max, FLAGS.time_delta),
                         warmup=FLAGS.warmup,
                         integrate_method=FLAGS.integrate_method,
                         exact_filter_interval=exact_filter_interval):
    return integrate.integrate_baseline(equation, times, warmup,
                                        accuracy_order, integrate_method,
                                        exact_filter_interval)

  def create_equation_and_integrate(seed_and_accuracy_order):
    """Integrate one (seed, accuracy_order) pair; return (seed, dataset)."""
    seed, accuracy_order = seed_and_accuracy_order
    equation = create_equation(seed)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not equation.CONSERVATIVE:
      raise ValueError('equation type {} is not conservative'.format(
          type(equation).__name__))
    result = integrate_baseline(equation, accuracy_order)
    result.coords['sample'] = seed
    result.coords['accuracy_order'] = accuracy_order
    # Keyed by seed so CombinePerKey can merge all orders of one sample.
    return (seed, result)

  pipeline = (
      beam.Create(list(range(FLAGS.num_samples)))
      | beam.FlatMap(
          lambda seed: [(seed, accuracy) for accuracy in accuracy_orders])
      | beam.Map(create_equation_and_integrate)
      | beam.CombinePerKey(xarray_beam.ConcatCombineFn('accuracy_order'))
      | beam.Map(lambda seed_and_ds: seed_and_ds[1].sortby('accuracy_order'))
      | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | beam.Map(lambda ds: ds.sortby('sample'))
      | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path))
  runner.run(pipeline)
def main(_):
  """Integrate WENO baselines of a flux equation for many random seeds.

  For each seed in ``range(FLAGS.num_samples)``:
    * an equation is built from ``equations.FLUX_EQUATION_TYPES``;
    * it is integrated with ``integrate.integrate_weno`` and tagged with a
      ``sample`` coordinate.
  The per-seed results are concatenated along ``sample``, sorted, and written
  as netCDF to ``FLAGS.output_path``.

  Args:
    _: unused positional argument (presumably argv from the app runner).
  """
  # must create before flags are used
  runner = beam.runners.DirectRunner()
  equation_kwargs = json.loads(FLAGS.equation_kwargs)

  # Default arguments capture flag values when each helper is defined.
  def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs):
    return equations.FLUX_EQUATION_TYPES[name](random_seed=seed, **kwargs)

  # Time grid is np.arange(0, time_max, time_delta): time_max is excluded.
  def integrate_weno_baseline(
      equation,
      times=np.arange(0, FLAGS.time_max, FLAGS.time_delta),
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method):
    return integrate.integrate_weno(equation, times, warmup, integrate_method)

  def create_equation_and_integrate(seed):
    dataset = integrate_weno_baseline(create_equation(seed))
    dataset.coords['sample'] = seed
    return dataset

  seeds = list(range(FLAGS.num_samples))
  pipeline = (
      beam.Create(seeds)
      | beam.Map(create_equation_and_integrate)
      | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | beam.Map(lambda ds: ds.sortby('sample'))
      | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path)
  )
  runner.run(pipeline)
def main(_):
  """Integrate exact, baseline and model solutions for many random seeds.

  Hparams are loaded from ``FLAGS.checkpoint_dir``; ``equation_kwargs`` is
  overridden when the flag is non-empty. Each seed in
  ``range(FLAGS.num_samples)`` is passed to
  ``integrate.integrate_exact_baseline_and_model``. The results are
  concatenated and sorted along ``sample``, then written as netCDF to
  ``FLAGS.output_name`` inside the checkpoint directory.

  Args:
    _: unused positional argument (presumably argv from the app runner).
  """
  runner = beam.runners.DirectRunner()  # must create before flags are used

  # A filter interval is forwarded only for spectral baselines, and only when
  # the flag is set.
  is_spectral = (equations.EQUATION_TYPES[FLAGS.equation_name].BASELINE
                 is equations.Baseline.SPECTRAL)
  exact_filter_interval = (
      FLAGS.exact_filter_interval
      if (is_spectral and FLAGS.exact_filter_interval) else None)

  hparams = training.load_hparams(FLAGS.checkpoint_dir)
  if FLAGS.equation_kwargs:
    hparams.set_hparam('equation_kwargs', FLAGS.equation_kwargs)

  # Stop at time_max + time_delta so that time_max itself lands on the grid
  # (np.arange excludes its stop value; subject to floating-point rounding).
  times = np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta)
  integrate_all = functools.partial(
      integrate.integrate_exact_baseline_and_model,
      FLAGS.checkpoint_dir,
      hparams,
      times=times,
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method,
      exact_filter_interval=exact_filter_interval)

  def integrate_sample(seed):
    # Tag each per-seed dataset so the global combine can concatenate on it.
    return integrate_all(seed).assign_coords(sample=seed)

  output_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.output_name)
  pipeline = (
      beam.Create(list(range(FLAGS.num_samples)))
      | beam.Map(count_start_finish(integrate_sample, name='integrate_all'))
      | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | beam.Map(lambda ds: ds.sortby('sample'))
      | beam.Map(xarray_beam.write_netcdf, path=output_path))
  runner.run(pipeline)
def main(_, runner=None):
  """Integrate reference solutions for many random seeds.

  WENO integration (``integrate.integrate_weno`` with
  ``equations.FLUX_EQUATION_TYPES``) is used in two cases:
    * ``FLAGS.discretization_method`` is ``'weno'``;
    * it is ``'exact'`` and the equation is ``'burgers'``.
  Otherwise spectral integration (``integrate.integrate_spectral`` with
  ``equations.EQUATION_TYPES``) is used. Results are concatenated and sorted
  along ``sample`` and written as netCDF to ``FLAGS.output_path``.

  Args:
    _: unused positional argument (presumably argv from the app runner).
    runner: Beam runner executing the pipeline; a ``DirectRunner`` is created
      when ``None``.
  """
  # must create before flags are used
  runner = beam.runners.DirectRunner() if runner is None else runner

  equation_kwargs = json.loads(FLAGS.equation_kwargs)

  method = FLAGS.discretization_method
  use_weno = (method == 'weno'
              or (method == 'exact' and FLAGS.equation_name == 'burgers'))

  # Filtering only applies on the spectral path, and only when requested.
  if use_weno or not FLAGS.exact_filter_interval:
    exact_filter_interval = None
  else:
    exact_filter_interval = float(FLAGS.exact_filter_interval)

  # Default arguments capture flag values when each helper is defined.
  def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs):
    if use_weno:
      registry = equations.FLUX_EQUATION_TYPES
    else:
      registry = equations.EQUATION_TYPES
    return registry[name](random_seed=seed, **kwargs)

  def do_integrate(
      equation,
      times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta),
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method):
    if use_weno:
      integrate_func = integrate.integrate_weno
    else:
      integrate_func = integrate.integrate_spectral
    return integrate_func(equation, times, warmup, integrate_method,
                          exact_filter_interval=exact_filter_interval)

  def create_equation_and_integrate(seed):
    dataset = do_integrate(create_equation(seed))
    dataset.coords['sample'] = seed
    return dataset

  pipeline = (
      beam.Create(list(range(FLAGS.num_samples)))
      | beam.Map(create_equation_and_integrate)
      | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | beam.Map(lambda ds: ds.sortby('sample'))
      | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path))
  runner.run(pipeline)
def main(_, runner=None):
  """Integrate a saved model from exact initial conditions and evaluate it.

  The pipeline:
    1. Reads initial conditions from the first time step of the exact
       solution at ``FLAGS.exact_solution_path``.
    2. Integrates the model saved in ``FLAGS.checkpoint_dir`` from each one.
    3. Writes three netCDF files into the checkpoint directory:
       * the sorted model samples (``FLAGS.samples_output_name``);
       * the mean absolute error against the exact solution for each stop
         time (``FLAGS.mae_output_name``);
       * survival statistics for each quantile
         (``FLAGS.survival_output_name``).

  Args:
    _: unused positional argument (presumably argv from the app runner).
    runner: Beam runner executing the pipeline; a ``DirectRunner`` is created
      when ``None``.
  """
  if runner is None:
    # must create before flags are used
    runner = beam.runners.DirectRunner()

  hparams = training.load_hparams(FLAGS.checkpoint_dir)
  if FLAGS.equation_kwargs:
    hparams.set_hparam('equation_kwargs', FLAGS.equation_kwargs)

  def load_initial_conditions(path=FLAGS.exact_solution_path,
                              num_samples=FLAGS.num_samples):
    """Yield ``(seed, y0)`` pairs of initial conditions for integration.

    ``y`` at ``time=0`` in the exact dataset is resampled with
    ``duckarray.resample_mean`` using ``hparams.resample_factor``.

    Raises:
      ValueError: if the initial conditions contain NaNs, the dataset does
        not hold exactly ``num_samples`` samples, or a per-sample initial
        condition is not one-dimensional.
    """
    ds = xarray_beam.read_netcdf(path)
    initial_conditions = duckarray.resample_mean(
        ds['y'].isel(time=0).data, hparams.resample_factor)
    if np.isnan(initial_conditions).any():
      raise ValueError('initial conditions cannot have NaNs')
    if ds.sizes['sample'] != num_samples:
      raise ValueError('invalid number of samples in exact dataset')
    for seed in range(num_samples):
      # Assumes the sample axis is the first axis of `y` -- TODO confirm.
      y0 = initial_conditions[seed, :]
      # Explicit check (consistent with the checks above) rather than
      # `assert`, which is stripped under `python -O`.
      if y0.ndim != 1:
        raise ValueError(
            'expected a 1-D initial condition for sample {}, got shape {}'
            .format(seed, y0.shape))
      yield (seed, y0)

  # Flag values are bound as default arguments at definition time.
  def run_integrate(
      seed_and_initial_condition,
      checkpoint_dir=FLAGS.checkpoint_dir,
      times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta),
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method,
  ):
    """Integrate the saved model from one ``(seed, y0)`` pair.

    Returns:
      Dataset with ``y`` over ``('time', 'x')`` and coordinates ``time``
      (``warmup + times``), ``x``, ``num_evals`` and ``sample``.
    """
    random_seed, y0 = seed_and_initial_condition
    _, equation_coarse = equations.from_hparams(
        hparams, random_seed=random_seed)
    checkpoint_path = training.checkpoint_dir_to_path(checkpoint_dir)
    differentiator = integrate.SavedModelDifferentiator(
        checkpoint_path, equation_coarse, hparams)
    solution_model, num_evals_model = integrate.odeint(
        y0, differentiator, warmup+times, method=integrate_method)

    results = xarray.Dataset(
        data_vars={'y': (('time', 'x'), solution_model)},
        coords={'time': warmup+times,
                'x': equation_coarse.grid.solution_x,
                'num_evals': num_evals_model,
                'sample': random_seed})
    return results

  samples_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.samples_output_name)
  mae_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.mae_output_name)
  survival_path = os.path.join(FLAGS.checkpoint_dir,
                               FLAGS.survival_output_name)

  def finalize(
      ds_model,
      exact_path=FLAGS.exact_solution_path,
      stop_times=json.loads(FLAGS.stop_times),
      quantiles=json.loads(FLAGS.quantiles),
  ):
    """Write model samples, MAE per stop time and survival per quantile."""
    ds_model = ds_model.sortby('sample')
    xarray_beam.write_netcdf(ds_model, samples_path)

    # build combined dataset: model on 'x_low', exact solution on 'x_high'
    ds_exact = xarray_beam.read_netcdf(exact_path)
    ds = ds_model.rename({'y': 'y_model', 'x': 'x_low'})
    ds['y_exact'] = ds_exact['y'].rename({'x': 'x_high'})
    unified = analysis.unify_x_coords(ds)

    # calculate MAE over x and time up to each stop time (label-based slice,
    # so time_max is inclusive); skipna=False lets NaNs propagate.
    results = []
    for time_max in stop_times:
      ds_sel = unified.sel(time=slice(None, time_max))
      mae = abs(ds_sel.drop('y_exact') - ds_sel.y_exact).mean(
          ['x', 'time'], skipna=False)
      results.append(mae)
    dim = pandas.Index(stop_times, name='time_max')
    mae_all = xarray.concat(results, dim=dim)
    xarray_beam.write_netcdf(mae_all, mae_path)

    # calculate survival
    # NOTE(review): this uses `ds` (separate x_low/x_high coords), not
    # `unified` -- confirm this is what mostly_good_survival expects.
    survival_all = xarray.concat(
        [analysis.mostly_good_survival(ds, q) for q in quantiles],
        dim=pandas.Index(quantiles, name='quantile'))
    xarray_beam.write_netcdf(survival_all, survival_path)

  pipeline = (
      # A single element triggers loading inside the pipeline.
      'create' >> beam.Create(range(1))
      | 'load' >> beam.FlatMap(lambda _: load_initial_conditions())
      # Presumably redistributes samples so integrations can run in parallel.
      | 'reshuffle' >> beam.Reshuffle()
      | 'integrate' >> beam.Map(
          count_start_finish(run_integrate, name='run_integrate'))
      | 'combine' >> beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | 'finalize' >> beam.Map(finalize)
  )
  runner.run(pipeline)