Example #1
    def __init__(
        self,
        model_builder: Callable,
        optimizer: tf.train.Optimizer,
        task_dists: List[TaskDistribution],
        batch_size: Optional[int] = None,
        num_inner_steps: int = 1,
        inner_optimizer: Optional[dict] = None,
        first_order: bool = False,
        mode: str = common.ModeKeys.TRAIN,
        name: Optional[str] = None,
        **_unused_kwargs,
    ):
        self.batch_size = batch_size
        self.num_inner_steps = num_inner_steps
        self.first_order = first_order

        # Inner loop.
        self.inner_optimizer = optimizers.get(**(inner_optimizer or {}))
        self.inner_adapted_models = []

        super(Maml, self).__init__(
            model_builder=model_builder,
            optimizer=optimizer,
            task_dists=task_dists,
            mode=mode,
            name=(name or self.__class__.__name__),
        )
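
For reference, a direct instantiation of this constructor might look as follows. This is only a sketch: `build_model`, `train_task_dists`, and the keys accepted by `optimizers.get` are placeholders for whatever the surrounding factories (see Examples #3 and #4) actually produce, not the library's documented API.

maml = Maml(
    model_builder=build_model,           # hypothetical: a Callable that builds the model
    optimizer=tf.train.AdamOptimizer(1e-3),
    task_dists=train_task_dists,         # hypothetical: a List[TaskDistribution]
    batch_size=16,
    num_inner_steps=5,
    inner_optimizer={"name": "sgd", "learning_rate": 0.01},  # assumed keys for optimizers.get
    first_order=False,
)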
Example #2
    def __init__(
        self,
        model,
        optimizer,
        tasks,
        batch_size=16,
        inner_optimizer=None,
        first_order=False,
        mode=common.ModeKeys.TRAIN,
        name="Maml",
        **kwargs,
    ):

        # Initialize the base learner via the parent constructor.
        super(Maml, self).__init__(
            model=model,
            optimizer=optimizer,
            tasks=tasks,
            mode=mode,
            name=name,
            **kwargs,
        )
        self._batch_size = batch_size
        self._first_order = first_order

        # Inner loop.
        # Guard against the default `inner_optimizer=None`, which would make
        # the `**` expansion fail (cf. Example #1).
        self._inner_optimizer = optimizers.get(**(inner_optimizer or {}))
        self._adapt_steps_ph = None
        self._adapted_params = None
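
Both constructors expose a `first_order` flag. In MAML this flag conventionally selects the first-order approximation (FOMAML), where meta-gradients are not back-propagated through the inner-loop update. A minimal sketch of one inner SGD step under that assumption; the helper below is illustrative and not part of the snippets above.

import tensorflow as tf

def inner_sgd_step(params, loss, lr=0.01, first_order=False):
    """One SGD adaptation step over `params` w.r.t. `loss`.

    With first_order=True, gradients do not flow through the inner update
    (the FOMAML approximation), avoiding second derivatives in the outer
    loop. Assumes `loss` depends on every tensor in `params`.
    """
    grads = tf.gradients(loss, params)
    if first_order:
        grads = [tf.stop_gradient(g) for g in grads]
    return [p - lr * g for p, g in zip(params, grads)]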
Example #3
def build_and_initialize(cfg, mode=common.ModeKeys.TRAIN):
    """Builds and initializes all parts of the graph.

    Parameters
    ----------
    cfg : OmegaConf
        The experiment configuration.

    mode : str, optional (default: common.ModeKeys.TRAIN)
        Defines the mode of the computation graph (TRAIN or EVAL).
        Note: this is likely to be removed from the API down the line.

    Returns
    -------
    meta_learner
        The constructed meta-learner, with its task distributions built
        and initialized.
    """
    sess = tf.get_default_session()

    # Build the data source.
    data_source = datasets.get_data_source(
        name=cfg.data.name,
        **cfg.data.source,
    ).build()

    # Build meta-datasets.
    meta_datasets = {}
    for set_name in set(task.set_name for task in cfg[mode].tasks):
        meta_datasets[set_name] = datasets.get_meta_dataset(
            name=cfg.data.name,
            data_sources=data_source[set_name],
            **cfg[mode].meta_dataset,
        ).build()

    # Build task distributions.
    task_dists = []
    for task in cfg[mode].tasks:
        task_dist = tasks.get_distribution(
            meta_dataset=meta_datasets[task.set_name],
            name_suffix=f"{task.set_name}_{task.regime}",
            sampler_config=task.sampler,
            **task.config,
        ).build()
        task_dists.append(task_dist)

    # Build model.
    network_builder = networks.get(**cfg.network)
    model_builder = models.get(
        input_shapes=data_source.data_shapes,
        input_types=data_source.data_types,
        num_classes=cfg[mode].meta_dataset.num_classes,
        network_builder=network_builder,
        **cfg[mode].model,
    )

    # Build optimizer.
    optimizer = optimizers.get(**cfg.train.optimizer)

    # Build meta-learner.
    meta_learner = adaptation.get(
        model_builder=model_builder,
        optimizer=optimizer,
        task_dists=task_dists,
        mode=mode,
        **cfg[mode].adapt,
    )

    # Variable initialization.
    if mode == common.ModeKeys.TRAIN:
        # Initialize all the variables in the graph.
        sess.run(tf.global_variables_initializer())
    else:  # mode == common.ModeKeys.EVAL:
        # Initialize only non-trainable variables.
        # Note: Trainable variables must be loaded from a checkpoint.
        #       Being explicit about which variables are initialized prevents
        #       weird side effects from variables that were created unnoticed
        #       and would otherwise be silently initialized at evaluation time.
        sess.run(
            tf.variables_initializer(meta_learner.non_trainable_parameters))

    # Initialize task distributions.
    for task, task_dist in zip(cfg[mode].tasks, task_dists):
        sampler = None
        if task.sampler is not None:
            sampler = samplers.get(**task.sampler)
            sampler.build(task_dist=task_dist, meta_learner=meta_learner)
        task_dist.initialize(sampler=sampler)

    return meta_learner
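
Since this version fetches the session with `tf.get_default_session()`, it must be called inside a default-session scope. A minimal call-site sketch, assuming the configuration is loaded with OmegaConf (the config path is a placeholder):

import tensorflow as tf
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # assumed config file

sess = tf.Session()
with sess.as_default():
    meta_learner = build_and_initialize(cfg, mode=common.ModeKeys.TRAIN)
    # ... meta-training steps using `sess` would go here ...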
Example #4
def build_and_initialize(cfg, sess, categories, mode=common.ModeKeys.TRAIN):
    """Builds and initializes all parts of the graph.

    Parameters
    ----------
    cfg : OmegaConf
        The experiment configuration.

    sess : tf.Session
        The TF session used for executing the computation graph.

    categories : dict of lists of Categories
        Each list of Categories is used to construct meta-datasets.

    mode : str, optional (default: common.ModeKeys.TRAIN)
        Defines the mode of the computation graph (TRAIN or EVAL).
        Note: this is likely to be removed from the API down the line.

    Returns
    -------
    exp : Experiment
        An object that represents the experiment.
        Contains `meta_learners`, `samplers`, and `task_dists`.
    """
    # Build and initialize data pools.
    data_pools = {
        task.set_name: datasets.get_datapool(
            dataset_name=cfg.data.name,
            categories=categories[task.set_name],
            name=f"DP_{task.log_dir.replace('/', '_')}",
        )
        .build(**cfg.data.build_config)
        .initialize(sess)
        for task in cfg[mode].tasks
    }

    # Build meta-dataset.
    meta_datasets = {
        task.set_name: datasets.get_metadataset(
            dataset_name=cfg.data.name,
            data_pool=data_pools[task.set_name],
            batch_size=cfg[mode].meta.batch_size,
            name=f"MD_{task.log_dir.replace('/', '_')}",
            **cfg[mode].dataset,
        ).build()
        for task in cfg[mode].tasks
    }

    # Build model.
    model = models.get(
        dataset_name=cfg.data.name,
        num_classes=cfg[mode].dataset.num_classes,
        **cfg.model,
    )

    # Build optimizer.
    optimizer = optimizers.get(**cfg.train.optimizer)

    # Build task distributions.
    task_dists = [
        tasks.get_distribution(
            meta_dataset=meta_datasets[task.set_name],
            name_suffix=task.log_dir.replace("/", "_"),
            **task.config,
        )
        for task in cfg[mode].tasks
    ]

    # Build meta-learners.
    meta_learners = [
        adaptation.get(
            model=model,
            optimizer=optimizer,
            mode=mode,
            tasks=task_dists[i].task_batch,
            **cfg.adapt,
        )
        for i, task in enumerate(cfg[mode].tasks)
    ]

    # Build samplers.
    samplers_list = [
        samplers.get(
            learner=meta_learners[i], tasks=task_dists[i].task_batch, **task.sampler
        )
        for i, task in enumerate(cfg[mode].tasks)
    ]

    # Run global init.
    sess.run(tf.global_variables_initializer())

    # Initialize task distribution.
    for task_dist, sampler in zip(task_dists, samplers_list):
        task_dist.initialize(sampler=sampler, sess=sess)

    return Experiment(
        meta_learners=meta_learners, samplers=samplers_list, task_dists=task_dists
    )
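
In contrast to Example #3, this variant takes the session and the category splits explicitly and returns the `Experiment` container described in the docstring. A call-site sketch; the split names in `categories` are assumptions, since they depend on the `task.set_name` values in the config:

sess = tf.Session()
exp = build_and_initialize(
    cfg,
    sess=sess,
    categories={"train": train_categories, "test": test_categories},  # assumed set names
    mode=common.ModeKeys.TRAIN,
)
# Each meta-learner is paired with its sampler and task distribution.
for learner, sampler, dist in zip(exp.meta_learners, exp.samplers, exp.task_dists):
    ...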