Example #1
class AbstractEnergyTraining(Training):
    """Abstract base class for GAN training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(), NetNameListState("names")
    ]

    def __init__(self,
                 scores,
                 data,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 num_workers=8,
                 **kwargs):
        """Generic training setup for energy/score based models.

    Args:
      scores (dict): mapping of names to score networks.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in score function training.
      num_workers (int): number of worker processes for data loading.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractEnergyTraining, self).__init__(**kwargs)

        netlist = []
        self.names = []
        for network in scores:
            self.names.append(network)
            network_object = scores[network].to(self.device)
            setattr(self, network, network_object)
            netlist.extend(list(network_object.parameters()))

        self.num_workers = num_workers
        self.data = data
        self.train_data = None

        self.current_losses = {}

        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 5e-4}

        self.optimizer = optimizer(netlist, **optimizer_kwargs)

        self.checkpoint_names = {
            name: getattr(self, name)
            for name in self.names
        }

    def energy_loss(self, *args):
        """Abstract method. Computes the score function loss."""
        raise NotImplementedError("Abstract")

    def loss(self, *args):
        return self.energy_loss(*args)

    def prepare(self, *args, **kwargs):
        """Abstract method. Prepares an initial state for sampling."""
        raise NotImplementedError("Abstract")

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the Boltzmann distribution."""
        raise NotImplementedError("Abstract")

    def run_energy(self, data):
        """Abstract method. Runs score at each step."""
        raise NotImplementedError("Abstract")

    def each_generate(self, *inputs):
        """Reports on generation."""
        pass

    def energy_step(self, data):
        """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
        self.optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_energy(data)
        loss_val = self.loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("discriminator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.optimizer.step()

    def step(self, data):
        """Performs a single step of GAN training, comprised of
    one or more steps of discriminator and generator training.

    Args:
      data: data points used for training."""
        self.energy_step(data)
        self.each_step()

    def train(self):
        """Trains an EBM until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=self.num_workers,
                                         shuffle=True,
                                         drop_last=True)

            for data in self.train_data:
                self.step(data)
                self.log()
                self.step_id += 1
            self.epoch_id += 1

        scores = [getattr(self, name) for name in self.names]

        return scores
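
A minimal usage sketch for the class above. ToyEnergy, the noise-based sampler, and the contrastive-divergence-style objective are illustrative assumptions, not part of the library; only the overridden hook signatures (prepare, sample, run_energy, energy_loss) and the batch_size/device attributes inherited from Training come from the code above.

import torch
import torch.nn as nn

class ToyEnergy(nn.Module):
    """Hypothetical score network mapping 2-D points to scalar energies."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x):
        return self.net(x)

class ToyEnergyTraining(AbstractEnergyTraining):
    def prepare(self):
        # Initial noise state for sampling (illustrative choice).
        return torch.randn(self.batch_size, 2)

    def sample(self):
        # Placeholder sampler; a real EBM would refine prepare() with
        # Langevin dynamics or another MCMC scheme.
        return self.prepare().to(self.device)

    def run_energy(self, data):
        # Assumes batches are bare tensors; "energy" is the attribute
        # registered from the scores dict passed to __init__.
        return self.energy(data), self.energy(self.sample())

    def energy_loss(self, real, fake):
        # Contrastive-divergence-style objective: lower energy on data,
        # higher energy on samples.
        return real.mean() - fake.mean()

# Hypothetical wiring; `dataset` stands in for any map-style Dataset:
#   training = ToyEnergyTraining({"energy": ToyEnergy()}, dataset)
#   training.train()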
Example #2
class AbstractVAETraining(Training):
    """Abstract base class for VAE training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("network_names"),
        NetState("optimizer")
    ]

    def __init__(self,
                 networks,
                 data,
                 valid=None,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 max_epochs=50,
                 batch_size=128,
                 device="cpu",
                 path_prefix=".",
                 network_name="network",
                 report_interval=10,
                 checkpoint_interval=1000,
                 verbose=False):
        """Generic training setup for variational autoencoders.

    Args:
      networks (dict): mapping of names to networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractVAETraining, self).__init__()

        self.verbose = verbose
        self.checkpoint_path = f"{path_prefix}/{network_name}"
        self.report_interval = report_interval
        self.checkpoint_interval = checkpoint_interval

        self.data = data
        self.valid = valid
        self.train_data = None
        self.valid_data = None
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.device = device

        netlist = []
        self.network_names = []
        for network in networks:
            self.network_names.append(network)
            network_object = networks[network].to(self.device)
            setattr(self, network, network_object)
            netlist.extend(list(network_object.parameters()))

        self.current_losses = {}
        self.network_name = network_name
        self.writer = SummaryWriter(network_name)

        self.epoch_id = 0
        self.step_id = 0

        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 5e-4}

        self.optimizer = optimizer(netlist, **optimizer_kwargs)

    def save_path(self):
        return f"{self.checkpoint_path}-save.torch"

    def divergence_loss(self, *args):
        """Abstract method. Computes the divergence loss."""
        raise NotImplementedError("Abstract")

    def reconstruction_loss(self, *args):
        """Abstract method. Computes the reconstruction loss."""
        raise NotImplementedError("Abstract")

    def loss(self, *args):
        """Abstract method. Computes the training loss."""
        raise NotImplementedError("Abstract")

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the latent distribution."""
        raise NotImplementedError("Abstract")

    def run_networks(self, data, *args):
        """Abstract method. Runs neural networks at each step."""
        raise NotImplementedError("Abstract")

    def preprocess(self, data):
        """Takes and partitions input data into VAE data and args."""
        return data

    def step(self, data):
        """Performs a single step of VAE training.

    Args:
      data: data points used for training."""
        self.optimizer.zero_grad()
        data = to_device(data, self.device)
        data, *netargs = self.preprocess(data)
        args = self.run_networks(data, *netargs)

        loss_val = self.loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("total loss", float(loss_val), self.step_id)

        loss_val.backward()
        self.optimizer.step()
        self.each_step()

        return float(loss_val)

    def valid_step(self, data):
        """Performs a single step of VAE validation.

    Args:
      data: data points used for validation."""
        with torch.no_grad():
            if isinstance(data, (list, tuple)):
                data = [point.to(self.device) for point in data]
            elif isinstance(data, dict):
                data = {key: data[key].to(self.device) for key in data}
            else:
                data = data.to(self.device)
            args = self.run_networks(data)

            loss_val = self.loss(*args)

        return float(loss_val)

    def checkpoint(self):
        """Performs a checkpoint for all encoders and decoders."""
        for name in self.network_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        self.each_checkpoint()

    def validate(self, data):
        loss = self.valid_step(data)
        self.writer.add_scalar("valid loss", loss, self.step_id)
        self.each_validate()

    def train(self):
        """Trains a VAE until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            if self.valid is not None:
                self.valid_data = DataLoader(self.valid,
                                             batch_size=self.batch_size,
                                             num_workers=8,
                                             shuffle=True)
                valid_iter = iter(self.valid_data)
            for data in self.train_data:
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                if self.valid is not None and self.step_id % self.report_interval == 0:
                    try:
                        vdata = next(valid_iter)
                    except StopIteration:
                        valid_iter = iter(self.valid_data)
                        vdata = next(valid_iter)
                    vdata = to_device(vdata, self.device)
                    self.validate(vdata)
                self.step_id += 1

        netlist = [getattr(self, name) for name in self.network_names]

        return netlist
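
A minimal sketch of a concrete subclass of the class above. The "encoder"/"decoder" attribute names, the Gaussian reparameterization, and the loss terms are illustrative assumptions; the overridden hooks follow the signatures defined in AbstractVAETraining, and the hypothetical encoder is assumed to return a (mean, logvar) pair.

import torch
import torch.nn.functional as F

class ToyVAETraining(AbstractVAETraining):
    def preprocess(self, data):
        # The base step unpacks `data, *netargs = self.preprocess(data)`,
        # so return a 1-tuple when the dataset yields bare tensors.
        return (data,)

    def run_networks(self, data):
        # "encoder" and "decoder" are attributes registered from the
        # networks dict passed to __init__.
        mean, logvar = self.encoder(data)
        sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
        reconstruction = self.decoder(sample)
        return reconstruction, data, mean, logvar

    def reconstruction_loss(self, reconstruction, target):
        return F.mse_loss(reconstruction, target)

    def divergence_loss(self, mean, logvar):
        # Closed-form KL(q(z|x) || N(0, I)).
        return -0.5 * torch.mean(
            torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1))

    def loss(self, reconstruction, target, mean, logvar):
        rec = self.reconstruction_loss(reconstruction, target)
        div = self.divergence_loss(mean, logvar)
        self.current_losses["reconstruction"] = float(rec)
        self.current_losses["divergence"] = float(div)
        return rec + div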
Example #3
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]
  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): mapping of names to networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
      gradient_clip (float): norm threshold at which gradients are clipped.
      gradient_skip (float): gradient norm above which the update is skipped.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)

    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr" : 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )
    self.checkpoint_names.update(
      self.get_netlist(self.network_names)
    )

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def each_generate(self, *args):
    """Reports on generation."""
    pass

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training."""
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)

    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    parameters = [
      param
      for key, val in self.get_netlist(self.network_names).items()
      for param in val.parameters()
    ]
    gn = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()

    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation."""
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [
          point.to(self.device)
          for point in data
        ]
      elif isinstance(data, dict):
        data = {
          key : data[key].to(self.device)
          for key in data
        }
      else:
        data = data.to(self.device)
      args = self.run_networks(data)

      loss_val = self.loss(*args)

    return float(loss_val)

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def run_report(self):
    if self.valid is not None:
      vdata = None
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.valid_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)

    return netlist
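
The clip-then-skip update used in step() above, isolated as a runnable sketch: gradients are clipped to gradient_clip, and the optimizer step is dropped entirely when the pre-clip norm is NaN or exceeds gradient_skip. The toy model, loss, and thresholds are placeholders.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the total norm measured before clipping.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=200.0)
if not torch.isnan(grad_norm).any() and (grad_norm < 400.0).all():
  optimizer.step()  # apply the update only for well-behaved gradients
optimizer.zero_grad()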
Example #4
class AbstractGANTraining(Training):
    """Abstract base class for GAN training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("generator_names"),
        NetNameListState("discriminator_names"),
        NetState("generator_optimizer"),
        NetState("discriminator_optimizer")
    ]

    def __init__(self,
                 generators,
                 discriminators,
                 data,
                 optimizer=torch.optim.Adam,
                 generator_optimizer_kwargs=None,
                 discriminator_optimizer_kwargs=None,
                 n_critic=1,
                 n_actor=1,
                 max_epochs=50,
                 batch_size=128,
                 device="cpu",
                 path_prefix=".",
                 network_name="network",
                 verbose=False,
                 report_interval=10,
                 checkpoint_interval=1000):
        """Generic training setup for generative adversarial networks.

    Args:
      generators (dict): mapping of names to networks used in the generation step.
      discriminators (dict): mapping of names to networks used in the discriminator step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      generator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in generator training.
      discriminator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in discriminator training.
      n_critic (int): number of critic training iterations per step.
      n_actor (int): number of actor training iterations per step.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractGANTraining, self).__init__()

        self.verbose = verbose
        self.report_interval = report_interval
        self.checkpoint_interval = checkpoint_interval
        self.checkpoint_path = f"{path_prefix}/{network_name}"

        self.n_critic = n_critic
        self.n_actor = n_actor

        generator_netlist = []
        self.generator_names = []
        for network in generators:
            self.generator_names.append(network)
            network_object = generators[network].to(device)
            setattr(self, network, network_object)
            generator_netlist.extend(list(network_object.parameters()))

        discriminator_netlist = []
        self.discriminator_names = []
        for network in discriminators:
            self.discriminator_names.append(network)
            network_object = discriminators[network].to(device)
            setattr(self, network, network_object)
            discriminator_netlist.extend(list(network_object.parameters()))

        self.data = data
        self.train_data = None
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.device = device

        self.current_losses = {}
        self.network_name = network_name
        self.writer = SummaryWriter(network_name)

        self.epoch_id = 0
        self.step_id = 0

        if generator_optimizer_kwargs is None:
            generator_optimizer_kwargs = {"lr": 5e-4}
        if discriminator_optimizer_kwargs is None:
            discriminator_optimizer_kwargs = {"lr": 5e-4}

        self.generator_optimizer = optimizer(generator_netlist,
                                             **generator_optimizer_kwargs)
        self.discriminator_optimizer = optimizer(
            discriminator_netlist, **discriminator_optimizer_kwargs)

    def save_path(self):
        return f"{self.checkpoint_path}-save.torch"

    def generator_loss(self, *args):
        """Abstract method. Computes the generator loss."""
        raise NotImplementedError("Abstract")

    def generator_step_loss(self, *args):
        """Computes the losses of all generators."""
        return self.generator_loss(*args)

    def discriminator_loss(self, *args):
        """Abstract method. Computes the discriminator loss."""
        raise NotImplementedError("Abstract")

    def discriminator_step_loss(self, *args):
        """Computes the losses of all discriminators."""
        return self.discriminator_loss(*args)

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the latent distribution."""
        raise NotImplementedError("Abstract")

    def run_generator(self, data):
        """Abstract method. Runs generation at each step."""
        raise NotImplementedError("Abstract")

    def run_discriminator(self, data):
        """Abstract method. Runs discriminator training."""
        raise NotImplementedError("Abstract")

    def each_generate(self, *inputs):
        """Reports on generation."""
        pass

    def discriminator_step(self, data):
        """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
        self.discriminator_optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_discriminator(data)
        loss_val, *grad_out = self.discriminator_step_loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("discriminator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.discriminator_optimizer.step()

    def generator_step(self, data):
        """Performs a single step of generator training.

    Args:
      data: data points used for training.
    """
        self.generator_optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_generator(data)
        loss_val = self.generator_step_loss(*args)

        if self.verbose:
            if self.step_id % self.report_interval == 0:
                self.each_generate(*args)
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("generator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.generator_optimizer.step()

    def step(self, data):
        """Performs a single step of GAN training, comprised of
    one or more steps of discriminator and generator training.

    Args:
      data: data points used for training."""
        for _ in range(self.n_critic):
            self.discriminator_step(next(data))
        for _ in range(self.n_actor):
            self.generator_step(next(data))
        self.each_step()

    def checkpoint(self):
        """Performs a checkpoint of all generators and discriminators."""
        for name in self.generator_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        for name in self.discriminator_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        self.each_checkpoint()

    def train(self):
        """Trains a GAN until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True,
                                         drop_last=True)

            batches_per_step = self.n_actor + self.n_critic
            steps_per_episode = len(self.train_data) // batches_per_step

            data = iter(self.train_data)
            for _ in range(steps_per_episode):
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                self.step_id += 1

        generators = [getattr(self, name) for name in self.generator_names]

        discriminators = [
            getattr(self, name) for name in self.discriminator_names
        ]

        return generators, discriminators
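
A minimal sketch of a concrete GAN subclass. The latent dimensionality, the non-saturating BCE objectives, and the "generator"/"discriminator" attribute names are illustrative assumptions; note that discriminator_step above unpacks a tuple from discriminator_step_loss, so the override below returns a 1-tuple.

import torch
import torch.nn.functional as F

class ToyGANTraining(AbstractGANTraining):
    def sample(self):
        # Latent prior; the dimensionality is a placeholder.
        return torch.randn(self.batch_size, 128, device=self.device)

    def run_generator(self, data):
        # "generator" is the attribute registered from the generators dict.
        return (self.generator(self.sample()),)

    def run_discriminator(self, data):
        # Assumes batches are bare tensors of real samples.
        with torch.no_grad():
            fake = self.generator(self.sample())
        return self.discriminator(data), self.discriminator(fake)

    def generator_loss(self, fake):
        logits = self.discriminator(fake)
        # Non-saturating generator objective.
        return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

    def discriminator_loss(self, real_logits, fake_logits):
        real = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
        fake = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        return real + fake

    def discriminator_step_loss(self, real_logits, fake_logits):
        # The base step unpacks `loss_val, *grad_out`, so return a 1-tuple.
        return (self.discriminator_loss(real_logits, fake_logits),)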
Example #5
class AbstractContrastiveTraining(Training):
  """Abstract base class for contrastive training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]
  def __init__(self, networks, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for energy/score based models.

    Args:
      networks (list): networks used for contrastive learning.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in score function training.
    """
    super().__init__(**kwargs)

    self.names, netlist = self.collect_netlist(networks)

    self.data = data
    self.train_data = None
    self.num_workers = num_workers

    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr" : 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )

    self.checkpoint_names = self.get_netlist(self.names)

  def contrastive_loss(self, *args):
    """Abstract method. Computes the contrastive loss."""
    raise NotImplementedError("Abstract")

  def regularization(self, *args):
    """Computes network regularization."""
    return 0.0

  def loss(self, *args):
    contrastive = self.contrastive_loss(*args)
    regularization = self.regularization(*args)
    return contrastive + regularization

  def run_networks(self, data):
    """Abstract method. Runs networks at each step."""
    raise NotImplementedError("Abstract")

  def visualize(self, data):
    pass

  def contrastive_step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    if self.step_id % self.report_interval == 0:
      self.visualize(data)

    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    make_differentiable(data)
    args = self.run_networks(data)
    loss_val = self.loss(*args)

    self.log_statistics(loss_val, name="total loss")

    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    self.contrastive_step(data)
    self.each_step()

  def train(self):
    """Runs contrastive training until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=self.num_workers,
        shuffle=True, drop_last=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    return self.get_netlist(self.names)
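
A minimal sketch of a concrete subclass of the contrastive class above, using a SimCLR-style NT-Xent objective. The "encoder" attribute name, the two-view batch layout, and the temperature value are illustrative assumptions; only the hook signatures come from AbstractContrastiveTraining.

import torch
import torch.nn.functional as F

class ToyContrastiveTraining(AbstractContrastiveTraining):
  temperature = 0.1

  def run_networks(self, data):
    # Assumes each batch is a pair of augmented views of the same samples.
    view_a, view_b = data
    return self.encoder(view_a), self.encoder(view_b)

  def contrastive_loss(self, features_a, features_b):
    # NT-Xent over the concatenated views: each embedding's positive is
    # the other view of the same sample; self-similarity is masked out.
    features = F.normalize(torch.cat([features_a, features_b], dim=0), dim=1)
    logits = features @ features.t() / self.temperature
    logits.fill_diagonal_(float("-inf"))
    n = features_a.size(0)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(logits, targets.to(logits.device))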
class OffPolicyTraining(Training):
    """Abstract base class for off-policy reinforcement learning training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("auxiliary_names"),
        NetState("policy"),
        NetState("optimizer"),
        NetState("auxiliary_optimizer")
    ]

    def __init__(self,
                 policy,
                 agent,
                 environment,
                 auxiliary_networks=None,
                 buffer_size=100_000,
                 piecewise_append=False,
                 policy_steps=1,
                 auxiliary_steps=1,
                 n_workers=8,
                 discount=0.99,
                 double=False,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 aux_optimizer=torch.optim.Adam,
                 aux_optimizer_kwargs=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.policy_steps = policy_steps
        self.auxiliary_steps = auxiliary_steps

        self.current_losses = {}

        self.statistics = ExperienceStatistics()
        self.discount = discount
        self.environment = environment
        self.agent = agent
        self.policy = policy.to(self.device)
        self.collector = EnvironmentCollector(environment,
                                              agent,
                                              discount=discount)
        self.distributor = DefaultDistributor()
        self.data_collector = ExperienceCollector(self.distributor,
                                                  self.collector,
                                                  n_workers=n_workers,
                                                  piecewise=piecewise_append)
        self.buffer = SchemaBuffer(self.data_collector.schema(),
                                   buffer_size,
                                   double=double)

        optimizer_kwargs = optimizer_kwargs or {}
        self.optimizer = optimizer(self.policy.parameters(),
                                   **optimizer_kwargs)

        auxiliary_netlist = []
        self.auxiliary_names = []
        auxiliary_networks = auxiliary_networks or {}
        for network in auxiliary_networks:
            self.auxiliary_names.append(network)
            network_object = auxiliary_networks[network].to(self.device)
            setattr(self, network, network_object)
            auxiliary_netlist.extend(list(network_object.parameters()))

        aux_optimizer_kwargs = aux_optimizer_kwargs or {}
        self.auxiliary_optimizer = aux_optimizer(auxiliary_netlist,
                                                 **aux_optimizer_kwargs)

        self.checkpoint_names = dict(
            policy=self.policy,
            **{name: getattr(self, name)
               for name in self.auxiliary_names})
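
The dict-of-networks registration pattern used throughout these classes (including the auxiliary-network loop above), isolated as a runnable sketch: each named module is attached to the trainer as an attribute while its parameters are pooled into a single optimizer. The Trainer class and module choices here are placeholders, not library code.

import torch
import torch.nn as nn

class Trainer:
    def __init__(self, networks, device="cpu"):
        netlist = []
        self.names = []
        for name in networks:
            self.names.append(name)
            module = networks[name].to(device)
            setattr(self, name, module)  # accessible as e.g. self.encoder
            netlist.extend(module.parameters())
        self.optimizer = torch.optim.Adam(netlist, lr=5e-4)

trainer = Trainer({"encoder": nn.Linear(8, 4), "decoder": nn.Linear(4, 8)})
print(trainer.names)  # ['encoder', 'decoder']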