# NOTE(review): this physical line looks like several source lines joined with
# their newlines and indentation stripped. As written, Python treats everything
# after "# pylint: disable=no-member" below as a comment. Restore the original
# line breaks before editing.
# Contents, in order:
#   * Tail of a session-configuration method whose header is not visible here
#     (presumably `configure_session`, the attribute the registry below
#     validates for; confirm). It converts an int `visible_devices` to str.
#     It then logs and exports CUDA_VISIBLE_DEVICES (the branch nesting is not
#     recoverable from this line) and only afterwards imports tensorflow
#     (presumably so the variable is set before TF initializes GPUs; confirm).
#     It builds a ConfigProto when `self.allow_growth is not None`. With
#     `self.eager` it enables eager execution with that config; otherwise, if
#     a config exists, it installs a tf.compat.v1.Session built from it as the
#     Keras session.
#   * BUG(review): `config.gpu_options.allow_growth = True` ignores the value of
#     `self.allow_growth`, so allow_growth=False still turns memory growth on.
#     The likely intent is `= self.allow_growth` (or `if self.allow_growth:`).
#   * get_config(): returns allow_growth, visible_devices and eager as a dict.
#   * `session_options`: a registry named 'session_options' whose validator
#     requires a 'configure_session' attribute. get/deserialize/serialize/
#     register are module-level aliases of its methods.
elif isinstance(visible_devices, int): visible_devices = str(visible_devices) logging.info( 'Setting CUDA_VISIBLE_DEVICES={}'.format(visible_devices)) os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices import tensorflow as tf if self.allow_growth is not None: config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True # pylint: disable=no-member else: config = None if self.eager: tf.compat.v1.enable_eager_execution(config=config) elif config is not None: tf.keras.backend.set_session(tf.compat.v1.Session(config=config)) def get_config(self): return dict(allow_growth=self.allow_growth, visible_devices=self.visible_devices, eager=self.eager) session_options = registry.Registry( 'session_options', registry.has_attrs_validator('configure_session')) get = session_options.get deserialize = session_options.deserialize serialize = session_options.serialize register = session_options.register
# NOTE(review): this physical line looks like several source lines joined with
# their newlines and indentation stripped. Restore the line breaks before
# editing.
# Contents, in order:
#   * Tail of a method (header not visible) that maps self._examples_per_epoch
#     over `split` with tf.nest.map_structure.
#   * get_base_dataset(split=tfds.Split.TRAIN): maps self._get_base_dataset over
#     `split`, so `split` may be a single split or a nested structure of splits.
#   * get_config(): serializes loss and metrics with the Keras serializers. It
#     serializes the objective via its own get_config() (None if unset) and the
#     input/output specs via get_input_spec_config.
#   * `problems`: a registry that accepts subclasses of Problem.
#   * get_input_spec_config(input_spec): returns None for None. Otherwise it
#     returns a dict of the spec's dtype, shape, ndim, max_ndim, min_ndim and
#     axes.
#     BUG(review): `repr(input_spec.dtype)[3:]` drops the first three characters
#     of the repr. That yields e.g. 'float32' only when the repr looks like
#     'tf.float32'. A None dtype gives 'e' (repr(None)[3:]). If dtype is stored
#     as a plain string (presumably some tf.keras versions do this; confirm),
#     the result is garbled. Consider
#     `None if dtype is None else tf.as_dtype(dtype).name`.
#   * Start of get_input_spec(identifier). Its body continues beyond this line.
return tf.nest.map_structure(self._examples_per_epoch, split) def get_base_dataset(self, split=tfds.Split.TRAIN): return tf.nest.map_structure(self._get_base_dataset, split) def get_config(self): objective = self.objective return dict( loss=_keras.losses.serialize(self.loss), metrics=[_keras.metrics.serialize(m) for m in self.metrics], objective=None if objective is None else objective.get_config(), input_spec=get_input_spec_config(self.input_spec), output_spec=get_input_spec_config(self.output_spec)) problems = registry.Registry('problems', registry.subclass_validator(Problem)) def get_input_spec_config(input_spec): if input_spec is None: return None return dict(dtype=repr(input_spec.dtype)[3:], shape=input_spec.shape, ndim=input_spec.ndim, max_ndim=input_spec.max_ndim, min_ndim=input_spec.min_ndim, axes=input_spec.axes) def get_input_spec(identifier): if identifier is None or isinstance(identifier, tf.keras.layers.InputSpec):
# NOTE(review): this physical line looks like several source lines joined with
# their newlines and indentation stripped. Restore the line breaks before
# editing.
# Contents, in order:
#   * Tail of a training call whose start is not visible (presumably
#     model.fit; confirm). It passes steps_per_epoch, validation_steps and
#     initial_epoch and is followed by `return history`.
#   * evaluate(checkpoint=LATEST, verbose=True): builds the validation dataset
#     and step count plus the model and callbacks. It restores `checkpoint`
#     through a BetterModelCheckpoint on self.chkpt_dir, then runs
#     model.evaluate.
#     NOTE(review): the value returned by model.evaluate is discarded, so this
#     method returns None; consider returning it to callers.
#   * `trainers`: a registry that accepts subclasses of Trainer; Trainer itself
#     is registered.
#   * _config_path(log_dir, epoch=0): returns
#     os.path.join(log_dir, 'config-<epoch>.yaml').
#   * get/deserialize/serialize/register: module-level aliases of the
#     `trainers` registry methods.
steps_per_epoch=train_steps, validation_steps=val_steps, initial_epoch=initial_epoch, ) return history def evaluate(self, checkpoint=LATEST, verbose=True): val_ds, val_steps = self._dataset_and_steps('validation') model, callbacks = self._get_model_and_callbacks() chkpt_callback = BetterModelCheckpoint(self.chkpt_dir) chkpt_callback.set_model(model) chkpt_callback.restore(checkpoint) model.evaluate(val_ds, steps=val_steps, callbacks=callbacks, verbose=verbose) trainers = registry.Registry('trainers', registry.subclass_validator(Trainer)) trainers.register(Trainer) def _config_path(log_dir, epoch=0): return os.path.join(log_dir, 'config-{}.yaml'.format(epoch)) get = trainers.get deserialize = trainers.deserialize serialize = trainers.serialize register = trainers.register
# NOTE(review): this physical line looks like several source lines joined with
# their newlines and indentation stripped. As written, everything after
# "# pylint: disable=not-context-manager" below is a comment to Python.
# Restore the line breaks before editing.
# Contents, in order. The methods belong to a class whose header is not
# visible (presumably Pipeline; confirm):
#   * __call__(dataset): delegates to self.preprocess_dataset.
#   * get_generator(dataset_fn): calls dataset_fn() and preprocesses the result
#     inside a fresh tf.Graph, then returns tfds.as_numpy(dataset, graph=graph).
#   * get_config(): returns the pipeline settings as a dict. map_fn and
#     output_spec_fn are serialized with functions.serialize, and output_spec
#     with get_input_spec_config.
#   * `pipelines`: a registry that accepts subclasses of Pipeline. Pipeline
#     itself is registered, and get/deserialize/serialize/register alias its
#     methods.
# NOTE(review): other lines in this view bind the same four alias names to
# different registries. If they share one module, the last binding wins.
# Confirm these fragments come from separate modules.
def __call__(self, dataset): return self.preprocess_dataset(dataset) def get_generator(self, dataset_fn): graph = tf.Graph() with graph.as_default(): # pylint: disable=not-context-manager dataset = self.preprocess_dataset(dataset_fn()) return tfds.as_numpy(dataset, graph=graph) def get_config(self): return dict(batch_size=self.batch_size, repeats=self.repeats, shuffle_buffer=self.shuffle_buffer, map_fn=functions.serialize(self.map_fn), prefetch_buffer=self.prefetch_buffer, num_parallel_calls=self.num_parallel_calls, output_spec=get_input_spec_config(self.output_spec), output_spec_fn=functions.serialize(self.output_spec_fn), drop_remainder=self.drop_remainder) pipelines = registry.Registry('pipelines', registry.subclass_validator(Pipeline)) pipelines.register(Pipeline) get = pipelines.get deserialize = pipelines.deserialize serialize = pipelines.serialize register = pipelines.register
# NOTE(review): this physical line looks like several source lines joined with
# their newlines and indentation stripped. Because the line begins with "#",
# Python treats the ENTIRE line as a comment. That includes the
# `functions = registry.Registry('functions')` assignment and the
# ConfigurableFunction class. Restore the line breaks before editing.
# Contents, in order:
#   * Commented-out remainder of a FunctionRegistry class. Its visible
#     deserialize tail wraps functools.partial objects and rejects plain
#     strings; its serialize wraps plain functions and partials. It is followed
#     by a commented-out `functions = FunctionRegistry('functions')`. Delete or
#     restore this code rather than keeping it as dead commentary.
#   * `functions`: a plain registry named 'functions'.
#   * ConfigurableFunction(func), registered with `functions`. __init__ raises
#     ValueError unless func has both __name__ and __module__. It also rejects
#     lambdas (__name__ == '<lambda>'). The wrapped function is stored on
#     self._func.
#   * Start of get_config(). It continues beyond this line.
# elif isinstance(identifier, functools.partial): # return ConfigurablePartial(identifier) # elif isinstance(identifier, six.string_types): # raise ValueError('Cannot get function with just a string') # else: # return super(FunctionRegistry, self).deserialize(identifier) # def serialize(self, instance): # if isinstance(instance, types.FunctionType): # instance = ConfigurableFunction(instance) # elif isinstance(instance, functools.partial): # instance = ConfigurablePartial(instance) # return super(FunctionRegistry, self).serialize(instance) # functions = FunctionRegistry('functions') functions = registry.Registry('functions') @functions.register class ConfigurableFunction(Configurable): def __init__(self, func): for attr in ('__name__', '__module__'): if not hasattr(func, attr): raise ValueError( 'func must have attr {} to be configurable'.format(attr)) if func.__name__ == '<lambda>': raise ValueError('Cannot wrap lambda functions as configurables') self._func = func def get_config(self): return dict(