Example #1
class AbstractVAETraining(Training):
    """Abstract base class for VAE training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("network_names"),
        NetState("optimizer")
    ]

    def __init__(self,
                 networks,
                 data,
                 valid=None,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 max_epochs=50,
                 batch_size=128,
                 device="cpu",
                 path_prefix=".",
                 network_name="network",
                 report_interval=10,
                 checkpoint_interval=1000,
                 verbose=False):
        """Generic training setup for variational autoencoders.

    Args:
      networks (dict): named networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractVAETraining, self).__init__()

        self.verbose = verbose
        self.checkpoint_path = f"{path_prefix}/{network_name}"
        self.report_interval = report_interval
        self.checkpoint_interval = checkpoint_interval

        self.data = data
        self.valid = valid
        self.train_data = None
        self.valid_data = None
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.device = device

        netlist = []
        self.network_names = []
        for network in networks:
            self.network_names.append(network)
            network_object = networks[network].to(self.device)
            setattr(self, network, network_object)
            netlist.extend(list(network_object.parameters()))

        self.current_losses = {}
        self.network_name = network_name
        self.writer = SummaryWriter(network_name)

        self.epoch_id = 0
        self.step_id = 0

        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 5e-4}

        self.optimizer = optimizer(netlist, **optimizer_kwargs)

    def save_path(self):
        return f"{self.checkpoint_path}-save.torch"

    def divergence_loss(self, *args):
        """Abstract method. Computes the divergence loss."""
        raise NotImplementedError("Abstract")

    def reconstruction_loss(self, *args):
        """Abstract method. Computes the reconstruction loss."""
        raise NotImplementedError("Abstract")

    def loss(self, *args):
        """Abstract method. Computes the training loss."""
        raise NotImplementedError("Abstract")

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the latent distribution."""
        raise NotImplementedError("Abstract")

    def run_networks(self, data, *args):
        """Abstract method. Runs neural networks at each step."""
        raise NotImplementedError("Abstract")

    def preprocess(self, data):
        """Takes and partitions input data into VAE data and args."""
        return data

    def step(self, data):
        """Performs a single step of VAE training.

    Args:
      data: data points used for training."""
        self.optimizer.zero_grad()
        data = to_device(data, self.device)
        data, *netargs = self.preprocess(data)
        args = self.run_networks(data, *netargs)

        loss_val = self.loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("total loss", float(loss_val), self.step_id)

        loss_val.backward()
        self.optimizer.step()
        self.each_step()

        return float(loss_val)

    def valid_step(self, data):
        """Performs a single step of VAE validation.

    Args:
      data: data points used for validation."""
        with torch.no_grad():
            if isinstance(data, (list, tuple)):
                data = [point.to(self.device) for point in data]
            elif isinstance(data, dict):
                data = {key: data[key].to(self.device) for key in data}
            else:
                data = data.to(self.device)
            args = self.run_networks(data)

            loss_val = self.loss(*args)

        return float(loss_val)

    def checkpoint(self):
        """Performs a checkpoint for all encoders and decoders."""
        for name in self.network_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        self.each_checkpoint()

    def validate(self, data):
        loss = self.valid_step(data)
        self.writer.add_scalar("valid loss", loss, self.step_id)
        self.each_validate()

    def train(self):
        """Trains a VAE until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            if self.valid is not None:
                self.valid_data = DataLoader(self.valid,
                                             batch_size=self.batch_size,
                                             num_workers=8,
                                             shuffle=True)
                valid_iter = iter(self.valid_data)
            for data in self.train_data:
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                if self.valid is not None and self.step_id % self.report_interval == 0:
                    vdata = None
                    try:
                        vdata = next(valid_iter)
                    except StopIteration:
                        valid_iter = iter(self.valid_data)
                        vdata = next(valid_iter)
                    vdata = to_device(vdata, self.device)
                    self.validate(vdata)
                self.step_id += 1

        netlist = [getattr(self, name) for name in self.network_names]

        return netlist
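
For orientation, a minimal concrete subclass might look like the following sketch. It assumes networks={"encoder": ..., "decoder": ...} with a Gaussian encoder returning a mean and log-variance; all names and shapes here are illustrative assumptions, not part of the base class.

class GaussianVAETraining(AbstractVAETraining):
    """Illustrative sketch of a concrete VAE training subclass."""

    def run_networks(self, data):
        # Hypothetical Gaussian encoder/decoder pair with reparameterization.
        mean, logvar = self.encoder(data)
        std = torch.exp(0.5 * logvar)
        sample = mean + std * torch.randn_like(std)
        return self.decoder(sample), data, mean, logvar

    def reconstruction_loss(self, reconstruction, target):
        return torch.nn.functional.mse_loss(reconstruction, target)

    def divergence_loss(self, mean, logvar):
        # KL divergence between N(mean, std) and the standard normal prior.
        return -0.5 * torch.mean(1 + logvar - mean ** 2 - logvar.exp())

    def loss(self, reconstruction, target, mean, logvar):
        return (self.reconstruction_loss(reconstruction, target)
                + self.divergence_loss(mean, logvar))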
Example #2
class AbstractEnergyTraining(Training):
    """Abstract base class for GAN training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(), NetNameListState("names")
    ]

    def __init__(self,
                 scores,
                 data,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 num_workers=8,
                 **kwargs):
        """Generic training setup for energy/score based models.

    Args:
      scores (dict): named networks used for scoring.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in score function training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractEnergyTraining, self).__init__(**kwargs)

        netlist = []
        self.names = []
        for network in scores:
            self.names.append(network)
            network_object = scores[network].to(self.device)
            setattr(self, network, network_object)
            netlist.extend(list(network_object.parameters()))

        self.num_workers = num_workers
        self.data = data
        self.train_data = None

        self.current_losses = {}

        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 5e-4}

        self.optimizer = optimizer(netlist, **optimizer_kwargs)

        self.checkpoint_names = {
            name: getattr(self, name)
            for name in self.names
        }

    def energy_loss(self, *args):
        """Abstract method. Computes the score function loss."""
        raise NotImplementedError("Abstract")

    def loss(self, *args):
        return self.energy_loss(*args)

    def prepare(self, *args, **kwargs):
        """Abstract method. Prepares an initial state for sampling."""
        raise NotImplementedError("Abstract")

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the Boltzmann distribution."""
        raise NotImplementedError("Abstract")

    def run_energy(self, data):
        """Abstract method. Runs score at each step."""
        raise NotImplementedError("Abstract")

    def each_generate(self, *inputs):
        """Reports on generation."""
        pass

    def energy_step(self, data):
        """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
        self.optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_energy(data)
        loss_val = self.loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("discriminator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.optimizer.step()

    def step(self, data):
        """Performs a single step of GAN training, comprised of
    one or more steps of discriminator and generator training.

    Args:
      data: data points used for training."""
        self.energy_step(data)
        self.each_step()

    def train(self):
        """Trains an EBM until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=self.num_workers,
                                         shuffle=True,
                                         drop_last=True)

            for data in self.train_data:
                self.step(data)
                self.log()
                self.step_id += 1
            self.epoch_id += 1

        scores = [getattr(self, name) for name in self.names]

        return scores
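
As a point of reference, a toy subclass could wire run_energy and energy_loss together as below; the single score network name and the (real, fake) batch layout are assumptions for illustration only.

class ToyEnergyTraining(AbstractEnergyTraining):
    """Illustrative sketch; assumes scores={"score": net}."""

    def run_energy(self, data):
        # Assumes the dataset yields (real, fake) sample pairs.
        real, fake = data
        return self.score(real), self.score(fake)

    def energy_loss(self, real_energy, fake_energy):
        # Contrastive-divergence-style objective: lower the energy of real
        # samples and raise the energy of fake ones.
        loss = real_energy.mean() - fake_energy.mean()
        self.current_losses["energy"] = float(loss)
        return loss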
Example #3
class MultistepTraining(Training, metaclass=StepRegistry):
    """Abstract base class for multi-step training."""
    checkpoint_parameters = Training.checkpoint_parameters + [TrainingState()]
    step_descriptors = {}
    step_order = []

    def __init__(self, networks, network_mapping, data, **kwargs):
        super().__init__(**kwargs)

        self.data = {}
        self.loaders = {}
        self.optimizers = {}
        for name, descriptor in self.step_descriptors.items():
            descriptor.n_steps_value = kwargs.get(
                descriptor.n_steps) or descriptor.n_steps_value
            descriptor.every_value = kwargs.get(
                descriptor.every) or descriptor.every_value
            self.optimizers[name] = []
            self.data[name] = None
            self.loaders[name] = None

        self.nets = {}
        for name, (network, opt, opt_kwargs) in networks.items():
            network = network.to(self.device)
            optimizer = opt(network.parameters(), **opt_kwargs)
            optimizer_name = f"{name}_optimizer"
            setattr(self, name, network)
            setattr(self, optimizer_name, optimizer)
            self.checkpoint_parameters += [
                NetState(name), NetState(optimizer_name)
            ]
            self.checkpoint_names.update({name: network})
            self.nets[name] = (network, optimizer)
        for step, names in network_mapping.items():
            self.optimizers[step] = [self.nets[name][1] for name in names]

        for step, data_set in data.items():
            self.data[step] = data_set
            self.loaders[step] = None

        self.current_losses = {}

    def get_data(self, step):
        if self.data[step] is None:
            return None
        data = self.loaders[step]
        if data is None:
            data = iter(
                DataLoader(self.data[step],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=True,
                           drop_last=True))
            self.loaders[step] = data
        try:
            data_point = to_device(next(data), self.device)
        except StopIteration:
            data = iter(
                DataLoader(self.data[step],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=True,
                           drop_last=True))
            self.loaders[step] = data
            data_point = to_device(next(data), self.device)
        return data_point

    def step(self):
        total_loss = 0.0
        for step_name in self.step_order:
            descriptor = self.step_descriptors[step_name]
            if self.step_id % descriptor.every_value == 0:
                for _ in range(descriptor.n_steps_value):
                    data = self.get_data(step_name)
                    for optimizer in self.optimizers[step_name]:
                        optimizer.zero_grad()
                    loss = descriptor.step(self, data)
                    if isinstance(loss, (list, tuple)):
                        loss, *_ = loss
                    if loss is not None:
                        loss.backward()
                        total_loss += float(loss)
                    for optimizer in self.optimizers[step_name]:
                        optimizer.step()
        self.log_statistics(total_loss)
        self.each_step()

    def train(self):
        for _ in range(self.max_steps):
            self.step()
            self.log()
            self.step_id += 1

        return self.nets
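
For concreteness, a hedged wiring sketch follows; MyMultistepTraining, the step names and the dataset variables are hypothetical placeholders, and the step names must match descriptors registered through the StepRegistry metaclass.

training = MyMultistepTraining(
    networks={
        "generator": (generator_net, torch.optim.Adam, {"lr": 1e-4}),
        "critic": (critic_net, torch.optim.Adam, {"lr": 4e-4})
    },
    network_mapping={
        "generator_step": ["generator"],  # which optimizers each step updates
        "critic_step": ["critic"]
    },
    data={
        "generator_step": train_set,      # which dataset each step draws from
        "critic_step": train_set
    }
)
nets = training.train()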
Example #4
class SupervisedTraining(Training):
    """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data; wrapped
                           in a :class:`DataLoader` internally.
    validate_data (Dataset): a dataset providing the validation data;
                              wrapped in a :class:`DataLoader` internally.
    optimizer (Optimizer): an optimizer for the network. Defaults to Adam.
    schedule (Schedule): a learning rate schedule. Defaults to decay when
                          stagnated.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    path_prefix (str): path prefix under which network checkpoints are saved.
  """
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetState("net"),
        NetState("optimizer")
    ]

    def __init__(self,
                 net,
                 train_data,
                 validate_data,
                 losses,
                 optimizer=torch.optim.Adam,
                 schedule=None,
                 max_epochs=50,
                 batch_size=128,
                 accumulate=None,
                 device="cpu",
                 network_name="network",
                 path_prefix=".",
                 report_interval=10,
                 checkpoint_interval=1000,
                 valid_callback=lambda x: None):
        super(SupervisedTraining, self).__init__()
        self.valid_callback = valid_callback
        self.network_name = network_name
        self.writer = SummaryWriter(network_name)
        self.device = device
        self.accumulate = accumulate
        self.optimizer = optimizer(net.parameters())
        if schedule is None:
            self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=10)
        else:
            self.schedule = schedule
        self.losses = losses
        self.train_data = DataLoader(train_data,
                                     batch_size=batch_size,
                                     num_workers=8,
                                     shuffle=True,
                                     drop_last=True)
        self.validate_data = DataLoader(validate_data,
                                        batch_size=batch_size,
                                        num_workers=8,
                                        shuffle=True,
                                        drop_last=True)
        self.net = net.to(self.device)
        self.max_epochs = max_epochs
        self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint"
        self.report_interval = report_interval
        self.checkpoint_interval = checkpoint_interval
        self.step_id = 0
        self.epoch_id = 0
        self.validation_losses = [0 for _ in range(len(self.losses))]
        self.training_losses = [0 for _ in range(len(self.losses))]
        self.best = None

    def save_path(self):
        return self.checkpoint_path + "-save.torch"

    def checkpoint(self):
        the_net = self.net
        if isinstance(the_net, torch.nn.DataParallel):
            the_net = the_net.module
        netwrite(
            the_net,
            f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
        )
        self.each_checkpoint()

    def run_networks(self, data):
        inputs, *labels = data
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        predictions = self.net(*inputs)
        if not isinstance(predictions, (list, tuple)):
            predictions = [predictions]
        return [combined for combined in zip(predictions, labels)]

    def loss(self, inputs):
        loss_val = torch.tensor(0.0).to(self.device)
        for idx, the_input in enumerate(inputs):
            this_loss_val = self.losses[idx](*the_input)
            self.training_losses[idx] = float(this_loss_val)
            loss_val += this_loss_val
        return loss_val

    def valid_loss(self, inputs):
        training_cache = list(self.training_losses)
        loss_val = self.loss(inputs)
        self.validation_losses = self.training_losses
        self.training_losses = training_cache
        return loss_val

    def step(self, data):
        if self.accumulate is None:
            self.optimizer.zero_grad()
        outputs = self.run_networks(data)
        loss_val = self.loss(outputs)
        loss_val.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
        if self.accumulate is None:
            self.optimizer.step()
        elif self.step_id % self.accumulate == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        self.each_step()

    def validate(self, data):
        with torch.no_grad():
            self.net.eval()
            outputs = self.run_networks(data)
            self.valid_loss(outputs)
            self.each_validate()
            self.valid_callback(self, to_device(data, "cpu"),
                                to_device(outputs, "cpu"))
            self.net.train()

    def schedule_step(self):
        self.schedule.step(sum(self.validation_losses))

    def each_step(self):
        Training.each_step(self)
        for idx, loss in enumerate(self.training_losses):
            self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
        self.writer.add_scalar(f"training loss total",
                               sum(self.training_losses), self.step_id)

    def each_validate(self):
        for idx, loss in enumerate(self.validation_losses):
            self.writer.add_scalar(f"validation loss {idx}", loss,
                                   self.step_id)
        self.writer.add_scalar(f"validation loss total",
                               sum(self.validation_losses), self.step_id)

    def train(self):
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            valid_iter = iter(self.validate_data)
            for data in self.train_data:
                data = to_device(data, self.device)
                self.step(data)
                if self.step_id % self.report_interval == 0:
                    vdata = None
                    try:
                        vdata = next(valid_iter)
                    except StopIteration:
                        valid_iter = iter(self.validate_data)
                        vdata = next(valid_iter)
                    vdata = to_device(vdata, self.device)
                    self.validate(vdata)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                self.step_id += 1
            self.schedule_step()
            self.each_epoch()
        return self.net
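
Typical usage might look like the following sketch; the network, the datasets and the loss list are placeholders.

net = MyClassifier()
training = SupervisedTraining(
    net, train_set, valid_set,
    losses=[torch.nn.CrossEntropyLoss()],
    batch_size=64,
    max_epochs=20,
    network_name="classifier-run"
)
trained_net = training.train()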
Example #5
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]
  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): named networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)

    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr" : 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )
    self.checkpoint_names.update(
      self.get_netlist(self.network_names)
    )

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def each_generate(self, *args):
    """Reports on generation."""
    pass

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training."""
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)

    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    parameters = [
      param
      for key, val in self.get_netlist(self.network_names).items()
      for param in val.parameters()
    ]
    gn = torch.nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()

    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation."""
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [
          point.to(self.device)
          for point in data
        ]
      elif isinstance(data, dict):
        data = {
          key : data[key].to(self.device)
          for key in data
        }
      else:
        data = data.to(self.device)
      args = self.run_networks(data)

      loss_val = self.loss(*args)

    return float(loss_val)

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def run_report(self):
    if self.valid is not None:
      vdata = None
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.valid_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)

    return netlist
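
preprocess exists so that datasets yielding extra conditioning information can split it off before run_networks; below is a sketch, assuming batches of (x, condition) pairs and conditional encoder/decoder networks (the loss methods are omitted, and the validation path above would need matching handling).

class ConditionalVAETraining(AbstractVAETraining):
  """Illustrative sketch of conditioning via preprocess."""

  def preprocess(self, data):
    x, condition = data
    # x is the VAE input; condition is forwarded to run_networks.
    return x, condition

  def run_networks(self, data, condition):
    mean, logvar = self.encoder(data, condition)
    std = torch.exp(0.5 * logvar)
    sample = mean + std * torch.randn_like(std)
    return self.decoder(sample, condition), data, mean, logvar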
Example #6
class SupervisedTraining(Training):
  """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data; wrapped
                           in a :class:`DataLoader` internally.
    validate_data (Dataset): a dataset providing the validation data;
                              wrapped in a :class:`DataLoader` internally.
    optimizer (Optimizer): an optimizer for the network. Defaults to Adam.
    schedule (Schedule): a learning rate schedule. Defaults to decay when
                          stagnated.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    path_prefix (str): path prefix under which network checkpoints are saved.
  """
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetState("net"),
    NetState("optimizer")
  ]
  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               accumulate=None,
               valid_callback=None,
               **kwargs):
    super(SupervisedTraining, self).__init__(**kwargs)
    self.valid_callback = valid_callback or (lambda x, y, z: None)
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10)
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(
      train_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
      prefetch_factor=self.prefetch_factor
    )
    self.valid_iter = None
    if validate_data is not None:
      self.validate_data = DataLoader(
        validate_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
        prefetch_factor=self.prefetch_factor
      )
      self.valid_iter = iter(self.validate_data)
    
    self.net = net.to(self.device)

    self.checkpoint_names = dict(checkpoint=self.net)
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None

  def run_networks(self, data):
    inputs, *labels = data
    if not isinstance(inputs, (list, tuple)):
      inputs = [inputs]
    predictions = self.net(*inputs)
    if not isinstance(predictions, (list, tuple)):
      predictions = [predictions]
    return [combined for combined in zip(predictions, labels)]

  def loss(self, inputs):
    loss_val = torch.tensor(0.0).to(self.device)
    for idx, the_input in enumerate(inputs):
      this_loss_val = self.losses[idx](*the_input)
      self.training_losses[idx] = float(this_loss_val)
      loss_val += this_loss_val
    return loss_val

  def valid_loss(self, inputs):
    training_cache = list(self.training_losses)
    loss_val = self.loss(inputs)
    self.validation_losses = self.training_losses
    self.training_losses = training_cache
    return loss_val

  def chunk(self, data, split):
    if torch.is_tensor(data):
      return data.split(len(data) // split)
    elif isinstance(data, (list, tuple)):
      result = [
        [] for idx in range(split)
      ]
      for item in data:
        for target, part in zip(result, self.chunk(item, split)):
          target.append(part)
      return result
    elif isinstance(data, dict):
      result = [
        {}
        for idx in range(split)
      ]
      for name in data:
        for dd, part in zip(result, self.chunk(data[name], split)):
          dd[name] = part
      return result
    else:
      return data

  def step(self, data):
    self.optimizer.zero_grad()
    if self.accumulate is not None:
      points = self.chunk(data, self.accumulate)
      for point in points:
        outputs = self.run_networks(point)
        loss_val = self.loss(outputs) / self.accumulate
        loss_val.backward()
    else:
      outputs = self.run_networks(data)
      loss_val = self.loss(outputs)
      loss_val.backward()
    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
    self.optimizer.step()
    self.each_step()

  def validate(self, data):
    with torch.no_grad():
      self.net.eval()
      if self.accumulate is not None:
        point = self.chunk(data, self.accumulate)[0]
        outputs = self.run_networks(point)
      else:
        outputs = self.run_networks(data)
      self.valid_loss(outputs)
      self.each_validate()
      self.valid_callback(
        self, to_device(data, "cpu"), to_device(outputs, "cpu")
      )
      self.net.train()

  def schedule_step(self):
    self.schedule.step(sum(self.validation_losses))

  def each_step(self):
    Training.each_step(self)
    for idx, loss in enumerate(self.training_losses):
      self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
    self.writer.add_scalar(f"training loss total", sum(self.training_losses), self.step_id)

  def each_validate(self):
    for idx, loss in enumerate(self.validation_losses):
      self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id)
    self.writer.add_scalar(f"validation loss total", sum(self.validation_losses), self.step_id)

  def run_report(self):
    if self.valid_iter is not None:
      vdata = None
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.validate_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    for epoch_id in range(self.max_epochs):
      for data in self.train_data:
        data = to_device(data, self.device)
        self.step(data)
        self.log()
        self.step_id += 1
      self.schedule_step()
      self.each_epoch()
      self.epoch_id += 1
    return self.net
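
The accumulate path relies on chunk recursing through tensors, tuples and dicts to form micro-batches whose scaled losses sum to the full-batch loss; a quick behavioural sketch (training stands for an already constructed SupervisedTraining instance):

batch = (torch.randn(8, 3), torch.randint(0, 10, (8,)))
micro_batches = training.chunk(batch, 4)
# Four micro-batches, each an [inputs, labels] pair of two samples; step()
# scales each micro-batch loss by 1 / accumulate before backward().
assert len(micro_batches) == 4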
Example #7
class AbstractGANTraining(Training):
    """Abstract base class for GAN training."""
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("generator_names"),
        NetNameListState("discriminator_names"),
        NetState("generator_optimizer"),
        NetState("discriminator_optimizer")
    ]

    def __init__(self,
                 generators,
                 discriminators,
                 data,
                 optimizer=torch.optim.Adam,
                 generator_optimizer_kwargs=None,
                 discriminator_optimizer_kwargs=None,
                 n_critic=1,
                 n_actor=1,
                 max_epochs=50,
                 batch_size=128,
                 device="cpu",
                 path_prefix=".",
                 network_name="network",
                 verbose=False,
                 report_interval=10,
                 checkpoint_interval=1000):
        """Generic training setup for generative adversarial networks.

    Args:
      generators (dict): named networks used in the generation step.
      discriminators (dict): named networks used in the discriminator step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      generator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in generator training.
      discriminator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in discriminator training.
      n_critic (int): number of critic training iterations per step.
      n_actor (int): number of actor training iterations per step.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
        super(AbstractGANTraining, self).__init__()

        self.verbose = verbose
        self.report_interval = report_interval
        self.checkpoint_interval = checkpoint_interval
        self.checkpoint_path = f"{path_prefix}/{network_name}"

        self.n_critic = n_critic
        self.n_actor = n_actor

        generator_netlist = []
        self.generator_names = []
        for network in generators:
            self.generator_names.append(network)
            network_object = generators[network].to(device)
            setattr(self, network, network_object)
            generator_netlist.extend(list(network_object.parameters()))

        discriminator_netlist = []
        self.discriminator_names = []
        for network in discriminators:
            self.discriminator_names.append(network)
            network_object = discriminators[network].to(device)
            setattr(self, network, network_object)
            discriminator_netlist.extend(list(network_object.parameters()))

        self.data = data
        self.train_data = None
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.device = device

        self.current_losses = {}
        self.network_name = network_name
        self.writer = SummaryWriter(network_name)

        self.epoch_id = 0
        self.step_id = 0

        if generator_optimizer_kwargs is None:
            generator_optimizer_kwargs = {"lr": 5e-4}
        if discriminator_optimizer_kwargs is None:
            discriminator_optimizer_kwargs = {"lr": 5e-4}

        self.generator_optimizer = optimizer(generator_netlist,
                                             **generator_optimizer_kwargs)
        self.discriminator_optimizer = optimizer(
            discriminator_netlist, **discriminator_optimizer_kwargs)

    def save_path(self):
        return f"{self.checkpoint_path}-save.torch"

    def generator_loss(self, *args):
        """Abstract method. Computes the generator loss."""
        raise NotImplementedError("Abstract")

    def generator_step_loss(self, *args):
        """Computes the losses of all generators."""
        return self.generator_loss(*args)

    def discriminator_loss(self, *args):
        """Abstract method. Computes the discriminator loss."""
        raise NotImplementedError("Abstract")

    def discriminator_step_loss(self, *args):
        """Computes the losses of all discriminators."""
        return self.discriminator_loss(*args)

    def sample(self, *args, **kwargs):
        """Abstract method. Samples from the latent distribution."""
        raise NotImplementedError("Abstract")

    def run_generator(self, data):
        """Abstract method. Runs generation at each step."""
        raise NotImplementedError("Abstract")

    def run_discriminator(self, data):
        """Abstract method. Runs discriminator training."""
        raise NotImplementedError("Abstract")

    def each_generate(self, *inputs):
        """Reports on generation."""
        pass

    def discriminator_step(self, data):
        """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
        self.discriminator_optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_discriminator(data)
        loss_val, *grad_out = self.discriminator_step_loss(*args)

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("discriminator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.discriminator_optimizer.step()

    def generator_step(self, data):
        """Performs a single step of generator training.

    Args:
      data: data points used for training.
    """
        self.generator_optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_generator(data)
        loss_val = self.generator_step_loss(*args)

        if self.verbose:
            if self.step_id % self.report_interval == 0:
                self.each_generate(*args)
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)
        self.writer.add_scalar("generator total loss", float(loss_val),
                               self.step_id)

        loss_val.backward()
        self.generator_optimizer.step()

    def step(self, data):
        """Performs a single step of GAN training, comprised of
    one or more steps of discriminator and generator training.

    Args:
      data: data points used for training."""
        for _ in range(self.n_critic):
            self.discriminator_step(next(data))
        for _ in range(self.n_actor):
            self.generator_step(next(data))
        self.each_step()

    def checkpoint(self):
        """Performs a checkpoint of all generators and discriminators."""
        for name in self.generator_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        for name in self.discriminator_names:
            the_net = getattr(self, name)
            if isinstance(the_net, torch.nn.DataParallel):
                the_net = the_net.module
            netwrite(
                the_net,
                f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        self.each_checkpoint()

    def train(self):
        """Trains a GAN until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True,
                                         drop_last=True)

            batches_per_step = self.n_actor + self.n_critic
            steps_per_epoch = len(self.train_data) // batches_per_step

            data = iter(self.train_data)
            for _ in range(steps_per_epoch):
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                self.step_id += 1

        generators = [getattr(self, name) for name in self.generator_names]

        discriminators = [
            getattr(self, name) for name in self.discriminator_names
        ]

        return generators, discriminators
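
A WGAN-flavoured toy subclass might fill in the abstract methods as below; the network names, the latent size of 64 and the returned tuple layouts are illustrative assumptions. Note that discriminator_step unpacks its loss as loss_val, *grad_out, so discriminator_step_loss has to return a tuple.

class ToyGANTraining(AbstractGANTraining):
    """Illustrative sketch; assumes generators={"generator": g} and
    discriminators={"discriminator": d}."""

    def sample(self, batch_size):
        return torch.randn(batch_size, 64, device=self.device)

    def run_generator(self, data):
        return (self.generator(self.sample(data.size(0))),)

    def run_discriminator(self, data):
        fake, = self.run_generator(data)
        return self.discriminator(fake.detach()), self.discriminator(data)

    def generator_loss(self, fake):
        return -self.discriminator(fake).mean()

    def discriminator_loss(self, fake_score, real_score):
        return fake_score.mean() - real_score.mean()

    def discriminator_step_loss(self, *args):
        # discriminator_step expects an unpackable "loss_val, *grad_out".
        return (self.discriminator_loss(*args),)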
Example #8
class AbstractContrastiveTraining(Training):
  """Abstract base class for contrastive training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]
  def __init__(self, networks, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for energy/score based models.

    Args:
      networks (dict): named networks used for contrastive learning.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
    """
    super().__init__(**kwargs)

    self.names, netlist = self.collect_netlist(networks)

    self.data = data
    self.train_data = None
    self.num_workers = num_workers

    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr" : 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )

    self.checkpoint_names = self.get_netlist(self.names)

  def contrastive_loss(self, *args):
    """Abstract method. Computes the contrastive loss."""
    raise NotImplementedError("Abstract")

  def regularization(self, *args):
    """Computes network regularization."""
    return 0.0

  def loss(self, *args):
    contrastive = self.contrastive_loss(*args)
    regularization = self.regularization(*args)
    return contrastive + regularization

  def run_networks(self, data):
    """Abstract method. Runs networks at each step."""
    raise NotImplementedError("Abstract")

  def visualize(self, data):
    """Hook for visualizing data at each report interval."""
    pass

  def contrastive_step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    if self.step_id % self.report_interval == 0:
      self.visualize(data)

    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    make_differentiable(data)
    args = self.run_networks(data)
    loss_val = self.loss(*args)

    self.log_statistics(loss_val, name="total loss")

    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    self.contrastive_step(data)
    self.each_step()

  def train(self):
    """Runs contrastive training until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=self.num_workers,
        shuffle=True, drop_last=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    return self.get_netlist(self.names)
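
A SimCLR-like subclass could implement the two abstract hooks as follows; the single encoder, the (view_a, view_b) batch layout and the temperature of 0.07 are illustrative assumptions.

class ToyContrastiveTraining(AbstractContrastiveTraining):
  """Illustrative sketch; assumes networks={"encoder": net}."""

  def run_networks(self, data):
    view_a, view_b = data
    return self.encoder(view_a), self.encoder(view_b)

  def contrastive_loss(self, features_a, features_b):
    # InfoNCE objective over in-batch negatives.
    features_a = torch.nn.functional.normalize(features_a, dim=1)
    features_b = torch.nn.functional.normalize(features_b, dim=1)
    logits = features_a @ features_b.t() / 0.07
    labels = torch.arange(logits.size(0), device=logits.device)
    return torch.nn.functional.cross_entropy(logits, labels)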
Example #9
class OffPolicyTraining(Training):
    checkpoint_parameters = Training.checkpoint_parameters + [
        TrainingState(),
        NetNameListState("auxiliary_names"),
        NetState("policy"),
        NetState("optimizer"),
        NetState("auxiliary_optimizer")
    ]

    def __init__(self,
                 policy,
                 agent,
                 environment,
                 auxiliary_networks=None,
                 buffer_size=100_000,
                 piecewise_append=False,
                 policy_steps=1,
                 auxiliary_steps=1,
                 n_workers=8,
                 discount=0.99,
                 double=False,
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs=None,
                 aux_optimizer=torch.optim.Adam,
                 aux_optimizer_kwargs=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.policy_steps = policy_steps
        self.auxiliary_steps = auxiliary_steps

        self.current_losses = {}

        self.statistics = ExperienceStatistics()
        self.discount = discount
        self.environment = environment
        self.agent = agent
        self.policy = policy.to(self.device)
        self.collector = EnvironmentCollector(environment,
                                              agent,
                                              discount=discount)
        self.distributor = DefaultDistributor()
        self.data_collector = ExperienceCollector(self.distributor,
                                                  self.collector,
                                                  n_workers=n_workers,
                                                  piecewise=piecewise_append)
        self.buffer = SchemaBuffer(self.data_collector.schema(),
                                   buffer_size,
                                   double=double)

        optimizer_kwargs = optimizer_kwargs or {}
        self.optimizer = optimizer(self.policy.parameters(),
                                   **optimizer_kwargs)

        auxiliary_networks = auxiliary_networks or {}
        auxiliary_netlist = []
        self.auxiliary_names = []
        for network in auxiliary_networks:
            self.auxiliary_names.append(network)
            network_object = auxiliary_networks[network].to(self.device)
            setattr(self, network, network_object)
            auxiliary_netlist.extend(list(network_object.parameters()))

        aux_optimizer_kwargs = aux_optimizer_kwargs or {}
        self.auxiliary_optimizer = aux_optimizer(auxiliary_netlist,
                                                 **aux_optimizer_kwargs)

        self.checkpoint_names = dict(
            policy=self.policy,
            **{name: getattr(self, name)
               for name in self.auxiliary_names})