コード例 #1
0
ファイル: session_options.py プロジェクト: jackd/keras-config
            elif isinstance(visible_devices, int):
                visible_devices = str(visible_devices)
            logging.info(
                'Setting CUDA_VISIBLE_DEVICES={}'.format(visible_devices))
            os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices

        import tensorflow as tf
        if self.allow_growth is not None:
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
        else:
            config = None
        if self.eager:
            tf.compat.v1.enable_eager_execution(config=config)
        elif config is not None:
            tf.keras.backend.set_session(tf.compat.v1.Session(config=config))

    def get_config(self):
        """Return this object's session settings as a plain dict.

        Keys are ``allow_growth``, ``visible_devices`` and ``eager``, each
        copied straight from the attribute of the same name.
        """
        config = {
            'allow_growth': self.allow_growth,
            'visible_devices': self.visible_devices,
            'eager': self.eager,
        }
        return config


# Registry of session-option configurables; the validator presumably
# requires each entry to define a `configure_session` attribute.
session_options = registry.Registry(
    'session_options',
    registry.has_attrs_validator('configure_session'),
)

# Module-level shortcuts to the registry's bound methods.
get, deserialize, serialize, register = (
    session_options.get,
    session_options.deserialize,
    session_options.serialize,
    session_options.register,
)
コード例 #2
0
        return tf.nest.map_structure(self._examples_per_epoch, split)

    def get_base_dataset(self, split=tfds.Split.TRAIN):
        """Apply ``self._get_base_dataset`` to every leaf of ``split``.

        ``split`` may be a single split or a nested structure of splits; the
        result mirrors that structure (via ``tf.nest.map_structure``).
        """
        base_dataset_fn = self._get_base_dataset
        return tf.nest.map_structure(base_dataset_fn, split)

    def get_config(self):
        """Serialize this problem's loss, metrics, objective and specs.

        Returns:
            dict with keys ``loss``, ``metrics``, ``objective``,
            ``input_spec`` and ``output_spec``. ``objective`` is ``None``
            when the problem has no objective.
        """
        objective = self.objective
        loss_config = _keras.losses.serialize(self.loss)
        metric_configs = [
            _keras.metrics.serialize(metric) for metric in self.metrics
        ]
        if objective is None:
            objective_config = None
        else:
            objective_config = objective.get_config()
        input_spec_config = get_input_spec_config(self.input_spec)
        output_spec_config = get_input_spec_config(self.output_spec)
        return {
            'loss': loss_config,
            'metrics': metric_configs,
            'objective': objective_config,
            'input_spec': input_spec_config,
            'output_spec': output_spec_config,
        }


# Registry of problem definitions; the validator presumably accepts only
# subclasses of `Problem`.
problems = registry.Registry(
    'problems',
    registry.subclass_validator(Problem),
)


def get_input_spec_config(input_spec):
    """Return a plain-dict description of ``input_spec``.

    Args:
        input_spec: an ``InputSpec``-like object exposing ``dtype``,
            ``shape``, ``ndim``, ``max_ndim``, ``min_ndim`` and ``axes``,
            or ``None``.

    Returns:
        ``None`` if ``input_spec`` is ``None``. Otherwise, a dict with the six
        attributes above. ``dtype`` is reduced to its name (e.g.
        ``'float32'``), or ``None`` when the spec has no dtype.
    """
    if input_spec is None:
        return None
    dtype = input_spec.dtype
    if dtype is not None:
        # tf.DType (and numpy dtype) objects expose `.name`. Newer Keras
        # InputSpecs already store the dtype as a name string, so keep those
        # as-is. The former `repr(dtype)[3:]` trick only handled tf.DType and
        # produced 'e' for None and "oat32'" for the string 'float32'.
        dtype = getattr(dtype, 'name', dtype)
    return dict(dtype=dtype,
                shape=input_spec.shape,
                ndim=input_spec.ndim,
                max_ndim=input_spec.max_ndim,
                min_ndim=input_spec.min_ndim,
                axes=input_spec.axes)


def get_input_spec(identifier):
    if identifier is None or isinstance(identifier, tf.keras.layers.InputSpec):
コード例 #3
0
            steps_per_epoch=train_steps,
            validation_steps=val_steps,
            initial_epoch=initial_epoch,
        )
        return history

    def evaluate(self, checkpoint=LATEST, verbose=True):
        """Restore a checkpoint and evaluate the model on the validation split.

        Args:
            checkpoint: which checkpoint to restore; forwarded to
                ``BetterModelCheckpoint.restore``. Defaults to ``LATEST``.
            verbose: forwarded to ``model.evaluate``.

        Returns:
            Whatever ``model.evaluate`` returns (per Keras docs, the loss or a
            list of loss and metric values). Previously this was discarded.
        """
        val_ds, val_steps = self._dataset_and_steps('validation')
        model, callbacks = self._get_model_and_callbacks()
        # Load weights saved under self.chkpt_dir into the freshly built model.
        chkpt_callback = BetterModelCheckpoint(self.chkpt_dir)
        chkpt_callback.set_model(model)
        chkpt_callback.restore(checkpoint)
        return model.evaluate(val_ds,
                              steps=val_steps,
                              callbacks=callbacks,
                              verbose=verbose)


# Registry of trainer classes; the validator presumably accepts only
# subclasses of `Trainer`. The base `Trainer` itself is registered up front.
trainers = registry.Registry(
    'trainers',
    registry.subclass_validator(Trainer),
)
trainers.register(Trainer)


def _config_path(log_dir, epoch=0):
    return os.path.join(log_dir, 'config-{}.yaml'.format(epoch))


# Module-level shortcuts to the trainers registry's bound methods.
get, deserialize, serialize, register = (
    trainers.get,
    trainers.deserialize,
    trainers.serialize,
    trainers.register,
)
コード例 #4
0
ファイル: pipelines.py プロジェクト: jackd/keras-config
    def __call__(self, dataset):
        return self.preprocess_dataset(dataset)

    def get_generator(self, dataset_fn):
        """Return ``tfds.as_numpy`` over the preprocessed dataset.

        ``dataset_fn`` is called, and its result preprocessed, inside a
        freshly created ``tf.Graph``; that same graph is passed to
        ``tfds.as_numpy``.
        """
        isolated_graph = tf.Graph()
        with isolated_graph.as_default():  # pylint: disable=not-context-manager
            preprocessed = self.preprocess_dataset(dataset_fn())
        return tfds.as_numpy(preprocessed, graph=isolated_graph)

    def get_config(self):
        """Return the pipeline's settings as a plain dict.

        ``map_fn`` and ``output_spec_fn`` go through ``functions.serialize``.
        ``output_spec`` goes through ``get_input_spec_config``. Every other
        entry is the raw attribute value.
        """
        batch_size = self.batch_size
        repeats = self.repeats
        shuffle_buffer = self.shuffle_buffer
        map_fn_config = functions.serialize(self.map_fn)
        prefetch_buffer = self.prefetch_buffer
        num_parallel_calls = self.num_parallel_calls
        output_spec_config = get_input_spec_config(self.output_spec)
        output_spec_fn_config = functions.serialize(self.output_spec_fn)
        drop_remainder = self.drop_remainder
        return {
            'batch_size': batch_size,
            'repeats': repeats,
            'shuffle_buffer': shuffle_buffer,
            'map_fn': map_fn_config,
            'prefetch_buffer': prefetch_buffer,
            'num_parallel_calls': num_parallel_calls,
            'output_spec': output_spec_config,
            'output_spec_fn': output_spec_fn_config,
            'drop_remainder': drop_remainder,
        }


# Registry of pipeline classes; the validator presumably accepts only
# subclasses of `Pipeline`. The base `Pipeline` itself is registered up front.
pipelines = registry.Registry(
    'pipelines',
    registry.subclass_validator(Pipeline),
)
pipelines.register(Pipeline)

# Module-level shortcuts to the registry's bound methods.
get, deserialize, serialize, register = (
    pipelines.get,
    pipelines.deserialize,
    pipelines.serialize,
    pipelines.register,
)
コード例 #5
0
ファイル: functions.py プロジェクト: jackd/keras-config
#         elif isinstance(identifier, functools.partial):
#             return ConfigurablePartial(identifier)
#         elif isinstance(identifier, six.string_types):
#             raise ValueError('Cannot get function with just a string')
#         else:
#             return super(FunctionRegistry, self).deserialize(identifier)

#     def serialize(self, instance):
#         if isinstance(instance, types.FunctionType):
#             instance = ConfigurableFunction(instance)
#         elif isinstance(instance, functools.partial):
#             instance = ConfigurablePartial(instance)
#         return super(FunctionRegistry, self).serialize(instance)

# functions = FunctionRegistry('functions')
# Plain registry (no validator) holding function configurables.
functions = registry.Registry(
    'functions',
)


@functions.register
class ConfigurableFunction(Configurable):
    def __init__(self, func):
        """Wrap ``func`` so it can be treated as a configurable.

        Raises:
            ValueError: if ``func`` lacks ``__name__`` or ``__module__``
                (checked in that order), or if it is a lambda.
        """
        missing_attr = next(
            (attr for attr in ('__name__', '__module__')
             if not hasattr(func, attr)),
            None)
        if missing_attr is not None:
            raise ValueError(
                'func must have attr {} to be configurable'.format(
                    missing_attr))
        # NOTE(review): lambdas are presumably rejected because every lambda
        # shares the name '<lambda>' and so cannot be identified by name.
        is_lambda = func.__name__ == '<lambda>'
        if is_lambda:
            raise ValueError('Cannot wrap lambda functions as configurables')
        self._func = func

    def get_config(self):
        return dict(