Example #1
def main(_):
    runner = beam.runners.DirectRunner()  # must create before flags are used

    equation_kwargs = json.loads(FLAGS.equation_kwargs)
    accuracy_orders = FLAGS.accuracy_orders

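    # Only use the exact filter interval when this equation's baseline is spectral.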
    if (equations.EQUATION_TYPES[FLAGS.equation_name].BASELINE is
            equations.Baseline.SPECTRAL and FLAGS.exact_filter_interval):
        exact_filter_interval = FLAGS.exact_filter_interval
    else:
        exact_filter_interval = None

    def create_equation(seed,
                        name=FLAGS.equation_name,
                        kwargs=equation_kwargs):
        equation_type = equations.CONSERVATIVE_EQUATION_TYPES[name]
        return equation_type(random_seed=seed, **kwargs)

    def integrate_baseline(equation,
                           accuracy_order,
                           times=np.arange(0, FLAGS.time_max,
                                           FLAGS.time_delta),
                           warmup=FLAGS.warmup,
                           integrate_method=FLAGS.integrate_method,
                           exact_filter_interval=exact_filter_interval):
        return integrate.integrate_baseline(equation, times, warmup,
                                            accuracy_order, integrate_method,
                                            exact_filter_interval)

    def create_equation_and_integrate(seed_and_accuracy_order):
        seed, accuracy_order = seed_and_accuracy_order
        equation = create_equation(seed)
        assert equation.CONSERVATIVE
        result = integrate_baseline(equation, accuracy_order)
        result.coords['sample'] = seed
        result.coords['accuracy_order'] = accuracy_order
        return (seed, result)

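    # Fan each sample seed out over all requested accuracy orders, integrate each
    # (seed, accuracy_order) pair, concatenate along 'accuracy_order' per seed and
    # then along 'sample', and write the combined dataset to a single netCDF file.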
    pipeline = (
        beam.Create(list(range(FLAGS.num_samples)))
        | beam.FlatMap(lambda seed: [(seed, accuracy)
                                     for accuracy in accuracy_orders])
        | beam.Map(create_equation_and_integrate)
        | beam.CombinePerKey(xarray_beam.ConcatCombineFn('accuracy_order'))
        | beam.Map(lambda seed_and_ds: seed_and_ds[1].sortby('accuracy_order'))
        | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
        | beam.Map(lambda ds: ds.sortby('sample'))
        | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path))

    runner.run(pipeline)
Example #2
def main(_):
  runner = beam.runners.DirectRunner()  # must create before flags are used

  equation_kwargs = json.loads(FLAGS.equation_kwargs)

  def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs):
    equation_type = equations.FLUX_EQUATION_TYPES[name]
    return equation_type(random_seed=seed, **kwargs)

  def integrate_baseline(
      equation,
      times=np.arange(0, FLAGS.time_max, FLAGS.time_delta),
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method):
    return integrate.integrate_weno(equation, times, warmup, integrate_method)

  def create_equation_and_integrate(seed):
    equation = create_equation(seed)
    result = integrate_baseline(equation)
    result.coords['sample'] = seed
    return result

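  # Integrate the WENO baseline for every sample seed, concatenate the results
  # along 'sample', and write the combined dataset to a netCDF file.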
  pipeline = (
      beam.Create(list(range(FLAGS.num_samples)))
      | beam.Map(create_equation_and_integrate)
      | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | beam.Map(lambda ds: ds.sortby('sample'))
      | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path))

  runner.run(pipeline)
Example #3
def main(_):
    runner = beam.runners.DirectRunner()  # must create before flags are used

    if (equations.EQUATION_TYPES[FLAGS.equation_name].BASELINE is
            equations.Baseline.SPECTRAL and FLAGS.exact_filter_interval):
        exact_filter_interval = FLAGS.exact_filter_interval
    else:
        exact_filter_interval = None

    hparams = training.load_hparams(FLAGS.checkpoint_dir)

    if FLAGS.equation_kwargs:
        hparams.set_hparam('equation_kwargs', FLAGS.equation_kwargs)

    integrate_all = functools.partial(
        integrate.integrate_exact_baseline_and_model,
        FLAGS.checkpoint_dir,
        hparams,
        times=np.arange(0, FLAGS.time_max + FLAGS.time_delta,
                        FLAGS.time_delta),
        warmup=FLAGS.warmup,
        integrate_method=FLAGS.integrate_method,
        exact_filter_interval=exact_filter_interval)

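    # Integrate the exact, baseline, and model solutions for every sample seed,
    # concatenate along 'sample', and write the result into the checkpoint
    # directory.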
    pipeline = (
        beam.Create(list(range(FLAGS.num_samples)))
        | beam.Map(
            count_start_finish(
                lambda seed: integrate_all(seed).assign_coords(sample=seed),
                name='integrate_all'))
        | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
        | beam.Map(lambda ds: ds.sortby('sample'))
        | beam.Map(xarray_beam.write_netcdf,
                   path=os.path.join(FLAGS.checkpoint_dir, FLAGS.output_name)))
    runner.run(pipeline)
Example #4
def main(_, runner=None):
    if runner is None:
        # must create before flags are used
        runner = beam.runners.DirectRunner()

    equation_kwargs = json.loads(FLAGS.equation_kwargs)

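    # Use WENO when explicitly requested, or for the 'exact' solution of
    # Burgers' equation; every other case falls back to the spectral integrator.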
    use_weno = (FLAGS.discretization_method == 'weno'
                or (FLAGS.discretization_method == 'exact'
                    and FLAGS.equation_name == 'burgers'))

    if (not use_weno and FLAGS.exact_filter_interval):
        exact_filter_interval = float(FLAGS.exact_filter_interval)
    else:
        exact_filter_interval = None

    def create_equation(seed,
                        name=FLAGS.equation_name,
                        kwargs=equation_kwargs):
        equation_type = (equations.FLUX_EQUATION_TYPES
                         if use_weno else equations.EQUATION_TYPES)[name]
        return equation_type(random_seed=seed, **kwargs)

    def do_integrate(equation,
                     times=np.arange(0, FLAGS.time_max + FLAGS.time_delta,
                                     FLAGS.time_delta),
                     warmup=FLAGS.warmup,
                     integrate_method=FLAGS.integrate_method):
        integrate_func = (integrate.integrate_weno
                          if use_weno else integrate.integrate_spectral)
        return integrate_func(equation,
                              times,
                              warmup,
                              integrate_method,
                              exact_filter_interval=exact_filter_interval)

    def create_equation_and_integrate(seed):
        equation = create_equation(seed)
        result = do_integrate(equation)
        result.coords['sample'] = seed
        return result

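    # Integrate every sample seed with the chosen scheme, concatenate along
    # 'sample', and write the combined dataset to a netCDF file.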
    pipeline = (beam.Create(list(range(FLAGS.num_samples)))
                | beam.Map(create_equation_and_integrate)
                | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
                | beam.Map(lambda ds: ds.sortby('sample'))
                | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path))

    runner.run(pipeline)
Example #5
def main(_, runner=None):
  if runner is None:
    # must create before flags are used
    runner = beam.runners.DirectRunner()

  hparams = training.load_hparams(FLAGS.checkpoint_dir)

  if FLAGS.equation_kwargs:
    hparams.set_hparam('equation_kwargs', FLAGS.equation_kwargs)

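  # Read the exact solution, resample each initial condition by the model's
  # resample factor, and yield one (seed, y0) pair per sample.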
  def load_initial_conditions(path=FLAGS.exact_solution_path,
                              num_samples=FLAGS.num_samples):
    ds = xarray_beam.read_netcdf(path)
    initial_conditions = duckarray.resample_mean(
        ds['y'].isel(time=0).data, hparams.resample_factor)

    if np.isnan(initial_conditions).any():
      raise ValueError('initial conditions cannot have NaNs')
    if ds.sizes['sample'] != num_samples:
      raise ValueError('invalid number of samples in exact dataset')

    for seed in range(num_samples):
      y0 = initial_conditions[seed, :]
      assert y0.ndim == 1
      yield (seed, y0)

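  # Integrate a single initial condition forward with the trained model's
  # SavedModelDifferentiator and return the solution as an xarray.Dataset
  # tagged with its sample seed.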
  def run_integrate(
      seed_and_initial_condition,
      checkpoint_dir=FLAGS.checkpoint_dir,
      times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta),
      warmup=FLAGS.warmup,
      integrate_method=FLAGS.integrate_method,
  ):
    random_seed, y0 = seed_and_initial_condition
    _, equation_coarse = equations.from_hparams(
        hparams, random_seed=random_seed)
    checkpoint_path = training.checkpoint_dir_to_path(checkpoint_dir)
    differentiator = integrate.SavedModelDifferentiator(
        checkpoint_path, equation_coarse, hparams)
    solution_model, num_evals_model = integrate.odeint(
        y0, differentiator, warmup+times, method=integrate_method)

    results = xarray.Dataset(
        data_vars={'y': (('time', 'x'), solution_model)},
        coords={'time': warmup+times,
                'x': equation_coarse.grid.solution_x,
                'num_evals': num_evals_model,
                'sample': random_seed})
    return results

  samples_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.samples_output_name)
  mae_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.mae_output_name)
  survival_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.survival_output_name)

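  # Write the combined samples, then compute the model's mean absolute error
  # against the exact solution for each stop time and its survival statistics
  # for each quantile.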
  def finalize(
      ds_model,
      exact_path=FLAGS.exact_solution_path,
      stop_times=json.loads(FLAGS.stop_times),
      quantiles=json.loads(FLAGS.quantiles),
  ):
    ds_model = ds_model.sortby('sample')
    xarray_beam.write_netcdf(ds_model, samples_path)

    # build combined dataset
    ds_exact = xarray_beam.read_netcdf(exact_path)
    ds = ds_model.rename({'y': 'y_model', 'x': 'x_low'})
    ds['y_exact'] = ds_exact['y'].rename({'x': 'x_high'})
    unified = analysis.unify_x_coords(ds)

    # calculate MAE
    results = []
    for time_max in stop_times:
      ds_sel = unified.sel(time=slice(None, time_max))
      mae = abs(ds_sel.drop('y_exact') - ds_sel.y_exact).mean(
          ['x', 'time'], skipna=False)
      results.append(mae)
    dim = pandas.Index(stop_times, name='time_max')
    mae_all = xarray.concat(results, dim=dim)
    xarray_beam.write_netcdf(mae_all, mae_path)

    # calculate survival
    survival_all = xarray.concat(
        [analysis.mostly_good_survival(ds, q) for q in quantiles],
        dim=pandas.Index(quantiles, name='quantile'))
    xarray_beam.write_netcdf(survival_all, survival_path)

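  # Load the initial conditions, reshuffle so the integrations can run in
  # parallel, integrate every sample, concatenate along 'sample', and write the
  # samples, MAE, and survival datasets.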
  pipeline = (
      'create' >> beam.Create(range(1))
      | 'load' >> beam.FlatMap(lambda _: load_initial_conditions())
      | 'reshuffle' >> beam.Reshuffle()
      | 'integrate' >> beam.Map(
          count_start_finish(run_integrate, name='run_integrate'))
      | 'combine' >> beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample'))
      | 'finalize' >> beam.Map(finalize)
  )
  runner.run(pipeline)