def test_sampler_geneartor():
    """SamplerGenerator around a Generator1D yields a single (size, 1) batch.

    Also exercises ``__str__`` and ``__repr__`` of the sampler so that a broken
    string conversion fails the test, without printing to the test log.
    """
    # NOTE(review): the function name misspells "generator"; it is kept as-is.
    # If this module defines another `test_sampler_geneartor`, the later
    # definition shadows this one and pytest collects only that one — confirm.
    size = 10
    generator = Generator1D(size)
    sampler_generator = SamplerGenerator(generator)
    examples = sampler_generator.get_examples()
    assert isinstance(examples, (list, tuple))
    # Exactly one tensor is expected; unpacking raises if there are more or fewer.
    (x,) = examples
    assert x.shape == (size, 1)
    # Only a crash here should fail the test.
    str(sampler_generator)
    repr(sampler_generator)
def test_sampler_geneartor():
    """Wrapping a 1-D generator in SamplerGenerator gives one (n, 1) sample batch."""
    n_points = 10
    wrapped = SamplerGenerator(Generator1D(n_points))

    samples = wrapped.get_examples()
    assert isinstance(samples, (list, tuple))

    # Single-element unpack: fails if the sampler returns more than one tensor.
    (only_sample,) = samples
    assert only_sample.shape == (n_points, 1)

    # String conversions must not raise.
    str(wrapped)
    repr(wrapped)
def x():
    """Return one batch of examples from three uniformly-sampled 1-D axes.

    The three axes use sizes ``GRID[0..2]`` with bounds
    ``(R_MIN, R_MAX)``, ``(0, 2*pi)`` and ``(Z_MIN, Z_MAX)`` respectively,
    combined by ``EnsembleGenerator`` and drawn through ``SamplerGenerator``.
    """
    axis_specs = [
        (GRID[0], R_MIN, R_MAX),
        (GRID[1], 0, np.pi * 2),
        (GRID[2], Z_MIN, Z_MAX),
    ]
    ensemble = EnsembleGenerator(
        *(Generator1D(count, low, high, method='uniform') for count, low, high in axis_specs)
    )
    return SamplerGenerator(ensemble).get_examples()
def six_walls():
    """Build a random box, twelve wall networks, and an interior sampler.

    Returns a 19-tuple:
    ``(sampler, x0, y0, z0, x1, y1, z1,
       x0_val, y0_val, z0_val, x1_val, y1_val, z1_val,
       x0_prime, y0_prime, z0_prime, x1_prime, y1_prime, z1_prime)``.
    Lower corner coordinates lie in [0, 1), upper ones in [2, 3).
    The sampler draws equally-spaced points from the box shrunk by 0.3 on every side.
    """
    lower_corner = [random.random() for _ in range(3)]
    upper_corner = [2 + random.random() for _ in range(3)]
    x0, y0, z0 = lower_corner
    x1, y1, z1 = upper_corner

    # Twelve nets built in the same order they are returned:
    # *_val for the lower walls, *_val for the upper walls, then the *_prime ones.
    wall_nets = [FCNNSplitInput(3, 1) for _ in range(12)]

    margin = 0.3
    interior = Generator3D(
        xyz_min=tuple(c + margin for c in lower_corner),
        xyz_max=tuple(c - margin for c in upper_corner),
        method='equally-spaced',
    )
    sampler = SamplerGenerator(StaticGenerator(interior))

    return (sampler, x0, y0, z0, x1, y1, z1, *wall_nets)
def __init__(
        self, diff_eqs, conditions,
        nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
        optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4,
        metrics=None, n_input_units=None, n_output_units=None,
        # deprecated arguments are listed below
        shuffle=None, batch_size=None):
    r"""Set up networks, sample generators, loss, optimizer, metrics and training state.

    :param diff_eqs: The differential equation(s); stored as ``self.diff_eqs``.
    :param conditions: One condition per unknown function; ``len(conditions)`` sets ``self.n_funcs``.
    :param nets: Networks to train. Defaults to ``len(conditions)`` ``FCNN`` instances,
        each with hidden units ``(32, 32)`` and ``nn.Tanh`` activation.
    :param train_generator: Required. Training-point generator, wrapped in ``SamplerGenerator``.
    :param valid_generator: Required. Validation-point generator, wrapped in ``SamplerGenerator``.
    :param analytic_solutions: **Deprecated** (emits ``FutureWarning``). A callable that is turned
        into an ``'analytic_mse'`` metric unless ``metrics`` already has that key.
        Requires ``n_input_units``.
    :param optimizer: Defaults to ``Adam`` over the parameters of all nets.
    :param criterion: ``None`` uses the mean squared residual. An ``nn.modules.loss._Loss``
        instance is applied as ``criterion(r, zeros_like(r))``. Any other value is used
        as a callable on the residuals.
    :param n_batches_train: Number of training batches; stored in ``self.n_batches['train']``.
    :param n_batches_valid: Number of validation batches; stored in ``self.n_batches['valid']``.
    :param metrics: Mapping from metric name to function. It is copied, so the caller's
        mapping is never modified.
    :param n_input_units: Number of input units, used by the default nets and to split
        the arguments of the ``analytic_mse`` metric.
    :param n_output_units: Number of output units, used by the default nets.
    :param shuffle: Deprecated and ignored; a truthy value emits ``FutureWarning``.
    :param batch_size: Deprecated and ignored; any non-None value emits ``FutureWarning``.
    :raises ValueError: If ``train_generator`` or ``valid_generator`` is missing, or if
        ``analytic_solutions`` is given without ``n_input_units``.
    """
    # deprecate argument `shuffle`
    if shuffle:
        warnings.warn(
            "param `shuffle` is deprecated and ignored; shuffling should be performed by generators",
            FutureWarning,
        )
    # deprecate argument `batch_size`
    if batch_size is not None:
        warnings.warn(
            "param `batch_size` is deprecated and ignored; specify n_batches_train and n_batches_valid instead",
            FutureWarning,
        )

    # Validate required arguments before building any networks.
    if train_generator is None:
        raise ValueError("train_generator must be specified")
    if valid_generator is None:
        raise ValueError("valid_generator must be specified")

    self.diff_eqs = diff_eqs
    self.conditions = conditions
    self.n_funcs = len(conditions)
    if nets is None:
        self.nets = [
            FCNN(n_input_units=n_input_units, n_output_units=n_output_units, hidden_units=(32, 32), actv=nn.Tanh)
            for _ in range(self.n_funcs)
        ]
    else:
        self.nets = nets

    # Copy so that registering 'analytic_mse' below never mutates the caller's dict.
    self.metrics_fn = dict(metrics) if metrics else {}
    # For backward compatibility with the legacy `analytic_solutions` argument
    if analytic_solutions:
        warnings.warn(
            'The `analytic_solutions` argument is deprecated and could lead to unstable behavior. '
            'Pass a `metrics` dict instead.',
            FutureWarning,
        )
        # `analytic_mse` slices by `n_input_units`; without it, every metric call would fail
        # with an opaque TypeError, so reject the configuration up front.
        if n_input_units is None:
            raise ValueError("`n_input_units` must be specified when `analytic_solutions` is passed")

        def analytic_mse(*args):
            # The trailing `n_input_units` args are passed to `analytic_solutions`;
            # the leading args are compared against its result.
            x = args[-n_input_units:]
            u_hat = analytic_solutions(*x)
            u = args[:-n_input_units]
            u, u_hat = torch.stack(u), torch.stack(u_hat)
            return ((u - u_hat)**2).mean()

        if 'analytic_mse' in self.metrics_fn:
            warnings.warn(
                "Ignoring `analytic_solutions` in presence of key 'analytic_mse' in `metrics`",
                FutureWarning,
            )
        else:
            self.metrics_fn['analytic_mse'] = analytic_mse

    # metric history, keys will be train_loss, valid_loss, train__<metric_name>, valid__<metric_name>.
    # For compatibility with ode.py and pde.py,
    # double underscore are used between 'train'/'valid' and custom metric names.
    self.metrics_history = {}
    self.metrics_history.update({'train_loss': [], 'valid_loss': []})
    self.metrics_history.update({'train__' + name: [] for name in self.metrics_fn})
    self.metrics_history.update({'valid__' + name: [] for name in self.metrics_fn})

    self.optimizer = optimizer if optimizer else Adam(chain.from_iterable(n.parameters() for n in self.nets))

    if criterion is None:
        self.criterion = lambda r: (r**2).mean()
    elif isinstance(criterion, nn.modules.loss._Loss):
        # torch loss modules take (input, target); the target for a residual is zero.
        self.criterion = lambda r: criterion(r, torch.zeros_like(r))
    else:
        self.criterion = criterion

    def make_pair_dict(train=None, valid=None):
        return {'train': train, 'valid': valid}

    self.generator = make_pair_dict(
        train=SamplerGenerator(train_generator),
        valid=SamplerGenerator(valid_generator),
    )
    # number of batches for training / validation;
    self.n_batches = make_pair_dict(train=n_batches_train, valid=n_batches_valid)
    # current batch of samples, kept for additional_loss term to use
    self._batch_examples = make_pair_dict()
    # current network with lowest loss
    self.best_nets = None
    # current lowest loss
    self.lowest_loss = None
    # local epoch in a `.fit` call, should only be modified inside self.fit()
    self.local_epoch = 0
    # maximum local epochs to run in a `.fit()` call, should only be set inside self.fit()
    self._max_local_epoch = 0
    # controls early stopping, should be set to False at the beginning of a `.fit()` call
    # and optionally set to True by `callbacks` in `.fit()` to support early stopping
    self._stop_training = False
    # the _phase variable is registered for callback functions to access
    self._phase = None