import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# NOTE: the remaining helpers used below (Training, TrainingState, NetState,
# NetNameListState, to_device, netwrite, make_differentiable and the
# experience-collection utilities) are assumed to be provided by the
# surrounding training library.


class AbstractEnergyTraining(Training):
  """Abstract base class for energy-based model training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]

  def __init__(self, scores, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for energy/score based models.

    Args:
      scores (dict): networks used for scoring.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in score function training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractEnergyTraining, self).__init__(**kwargs)

    netlist = []
    self.names = []
    for network in scores:
      self.names.append(network)
      network_object = scores[network].to(self.device)
      setattr(self, network, network_object)
      netlist.extend(list(network_object.parameters()))

    self.num_workers = num_workers
    self.data = data
    self.train_data = None
    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}

    self.optimizer = optimizer(netlist, **optimizer_kwargs)

    self.checkpoint_names = {
      name: getattr(self, name)
      for name in self.names
    }

  def energy_loss(self, *args):
    """Abstract method. Computes the score function loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    return self.energy_loss(*args)

  def prepare(self, *args, **kwargs):
    """Abstract method. Prepares an initial state for sampling."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the Boltzmann distribution."""
    raise NotImplementedError("Abstract")

  def run_energy(self, data):
    """Abstract method. Runs the score networks at each step."""
    raise NotImplementedError("Abstract")

  def each_generate(self, *inputs):
    """Reports on generation."""
    pass

  def energy_step(self, data):
    """Performs a single step of energy model training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_energy(data)
    loss_val = self.loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
      self.writer.add_scalar("energy total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of energy-based model training.

    Args:
      data: data points used for training.
    """
    self.energy_step(data)
    self.each_step()

  def train(self):
    """Trains an EBM until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=self.num_workers, shuffle=True, drop_last=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1
      self.epoch_id += 1

    scores = [getattr(self, name) for name in self.names]
    return scores
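
# A minimal sketch of a concrete subclass, assuming a single score network
# registered under the name "score" and a contrastive-divergence style
# objective; the class name, data layout (plain 2D tensor batches) and the
# noise sampler below are illustrative assumptions, not part of the library.
class ToyEnergyTraining(AbstractEnergyTraining):
  def energy_loss(self, real_energy, fake_energy):
    # Push energies down on data and up on samples.
    loss = real_energy.mean() - fake_energy.mean()
    self.current_losses["energy"] = float(loss)
    return loss

  def run_energy(self, data):
    fake = self.sample(data.size(0))
    return self.score(data), self.score(fake)

  def prepare(self, batch_size):
    return torch.randn(batch_size, 2, device=self.device)

  def sample(self, batch_size):
    # Stand-in sampler: plain noise instead of Langevin dynamics.
    return self.prepare(batch_size)

# Hypothetical usage, assuming a dataset of 2D points and that batch_size
# and max_epochs are handled by the Training base class:
#   training = ToyEnergyTraining({"score": score_net}, dataset, batch_size=64)
#   (score_net,) = training.train()
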
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               report_interval=10,
               checkpoint_interval=1000,
               verbose=False):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): networks used in the training step.
      data (Dataset): provider of training data.
      valid (Dataset): provider of validation data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__()

    self.verbose = verbose
    self.checkpoint_path = f"{path_prefix}/{network_name}"
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval

    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    netlist = []
    self.network_names = []
    for network in networks:
      self.network_names.append(network)
      network_object = networks[network].to(self.device)
      setattr(self, network, network_object)
      netlist.extend(list(network_object.parameters()))

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)

    self.epoch_id = 0
    self.step_id = 0

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}

    self.optimizer = optimizer(netlist, **optimizer_kwargs)

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
      self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.optimizer.step()
    self.each_step()

    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def checkpoint(self):
    """Performs a checkpoint for all encoders and decoders."""
    for name in self.network_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8, shuffle=True
      )
      valid_iter = None
      if self.valid is not None:
        self.valid_data = DataLoader(
          self.valid, batch_size=self.batch_size, num_workers=8, shuffle=True
        )
        # Initialize the iterator up front; the original code only created
        # it inside the StopIteration handler, raising a NameError on the
        # first validation step.
        valid_iter = iter(self.valid_data)

      for data in self.train_data:
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        if self.valid is not None and self.step_id % self.report_interval == 0:
          try:
            vdata = next(valid_iter)
          except StopIteration:
            valid_iter = iter(self.valid_data)
            vdata = next(valid_iter)
          vdata = to_device(vdata, self.device)
          self.validate(vdata)
        self.step_id += 1

    netlist = [getattr(self, name) for name in self.network_names]
    return netlist
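
# A minimal Gaussian-VAE sketch on top of the class above, assuming networks
# named "encoder" (returning mean and log-variance) and "decoder", and plain
# tensor batches in [0, 1]; the class name and loss choices are illustrative
# assumptions, not the library's own VAE implementation.
import torch.nn.functional as func

class GaussianVAETraining(AbstractVAETraining):
  def preprocess(self, data):
    # Wrap the raw tensor batch so the base-class unpacking yields the
    # batch itself with no extra network arguments.
    return (data,)

  def run_networks(self, data):
    mean, logvar = self.encoder(data)
    latent = self.sample(mean, logvar)
    reconstruction = self.decoder(latent)
    return mean, logvar, reconstruction, data

  def sample(self, mean, logvar):
    # Reparameterization trick: z = mu + sigma * eps.
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(std)

  def divergence_loss(self, mean, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    return -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

  def reconstruction_loss(self, reconstruction, target):
    return func.binary_cross_entropy_with_logits(
      reconstruction, target, reduction="mean"
    )

  def loss(self, mean, logvar, reconstruction, target):
    divergence = self.divergence_loss(mean, logvar)
    reconstruction_error = self.reconstruction_loss(reconstruction, target)
    self.current_losses["divergence"] = float(divergence)
    self.current_losses["reconstruction"] = float(reconstruction_error)
    return reconstruction_error + divergence
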
# Revised variant of AbstractVAETraining, built on the shared Training
# plumbing (collect_netlist / get_netlist) and adding gradient clipping
# and gradient skipping.
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): networks used in the training step.
      data (Dataset): provider of training data.
      valid (Dataset): provider of validation data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      gradient_clip (float): maximum gradient norm; gradients are
        clipped to this value.
      gradient_skip (float): gradient norm above which the update
        step is skipped entirely.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)

    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size,
        num_workers=8, shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}

    self.optimizer = optimizer(netlist, **optimizer_kwargs)

    self.checkpoint_names.update(self.get_netlist(self.network_names))

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def each_generate(self, *args):
    pass

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
      self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()

    # Clip gradients to gradient_clip and skip the update entirely if the
    # pre-clipping gradient norm is NaN or exceeds gradient_skip.
    parameters = [
      param
      for key, val in self.get_netlist(self.network_names).items()
      for param in val.parameters()
    ]
    gn = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()

    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def run_report(self):
    if self.valid is not None:
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.valid_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8, shuffle=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)
    return netlist
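
# Standalone sketch of the clip-then-maybe-skip update used in step() above:
# gradients are clipped to a maximum norm, and the optimizer step is skipped
# outright when the pre-clipping norm is non-finite or too large. The helper
# name and thresholds are illustrative.
def clipped_step_example(model, optimizer, loss,
                         gradient_clip=200.0, gradient_skip=400.0):
  loss.backward()
  grad_norm = nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
  if torch.isfinite(grad_norm) and grad_norm < gradient_skip:
    optimizer.step()
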
class AbstractGANTraining(Training):
  """Abstract base class for GAN training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("generator_names"),
    NetNameListState("discriminator_names"),
    NetState("generator_optimizer"),
    NetState("discriminator_optimizer")
  ]

  def __init__(self, generators, discriminators, data,
               optimizer=torch.optim.Adam,
               generator_optimizer_kwargs=None,
               discriminator_optimizer_kwargs=None,
               n_critic=1,
               n_actor=1,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               verbose=False,
               report_interval=10,
               checkpoint_interval=1000):
    """Generic training setup for generative adversarial networks.

    Args:
      generators (dict): networks used in the generation step.
      discriminators (dict): networks used in the discriminator step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      generator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in generator training.
      discriminator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in discriminator training.
      n_critic (int): number of critic training iterations per step.
      n_actor (int): number of actor training iterations per step.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractGANTraining, self).__init__()

    self.verbose = verbose
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.checkpoint_path = f"{path_prefix}/{network_name}"

    self.n_critic = n_critic
    self.n_actor = n_actor

    generator_netlist = []
    self.generator_names = []
    for network in generators:
      self.generator_names.append(network)
      network_object = generators[network].to(device)
      setattr(self, network, network_object)
      generator_netlist.extend(list(network_object.parameters()))

    discriminator_netlist = []
    self.discriminator_names = []
    for network in discriminators:
      self.discriminator_names.append(network)
      network_object = discriminators[network].to(device)
      setattr(self, network, network_object)
      discriminator_netlist.extend(list(network_object.parameters()))

    self.data = data
    self.train_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)

    self.epoch_id = 0
    self.step_id = 0

    if generator_optimizer_kwargs is None:
      generator_optimizer_kwargs = {"lr": 5e-4}
    if discriminator_optimizer_kwargs is None:
      discriminator_optimizer_kwargs = {"lr": 5e-4}

    self.generator_optimizer = optimizer(
      generator_netlist, **generator_optimizer_kwargs
    )
    self.discriminator_optimizer = optimizer(
      discriminator_netlist, **discriminator_optimizer_kwargs
    )

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def generator_loss(self, *args):
    """Abstract method. Computes the generator loss."""
    raise NotImplementedError("Abstract")

  def generator_step_loss(self, *args):
    """Computes the losses of all generators."""
    return self.generator_loss(*args)

  def discriminator_loss(self, *args):
    """Abstract method. Computes the discriminator loss."""
    raise NotImplementedError("Abstract")

  def discriminator_step_loss(self, *args):
    """Computes the losses of all discriminators."""
    return self.discriminator_loss(*args)

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_generator(self, data):
    """Abstract method. Runs generation at each step."""
    raise NotImplementedError("Abstract")

  def run_discriminator(self, data):
    """Abstract method. Runs discriminator training."""
    raise NotImplementedError("Abstract")

  def each_generate(self, *inputs):
    """Reports on generation."""
    pass

  def discriminator_step(self, data):
    """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
    self.discriminator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_discriminator(data)
    # The default step loss returns a single tensor; the original
    # "loss_val, *grad_out = ..." unpacking would fail on it.
    loss_val = self.discriminator_step_loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
      self.writer.add_scalar(
        "discriminator total loss", float(loss_val), self.step_id
      )

    loss_val.backward()
    self.discriminator_optimizer.step()

  def generator_step(self, data):
    """Performs a single step of generator training.

    Args:
      data: data points used for training.
    """
    self.generator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_generator(data)
    loss_val = self.generator_step_loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
      self.writer.add_scalar(
        "generator total loss", float(loss_val), self.step_id
      )

    loss_val.backward()
    self.generator_optimizer.step()

  def step(self, data):
    """Performs a single step of GAN training, comprised of one or more
    steps of discriminator and generator training.

    Args:
      data: iterator over data points used for training.
    """
    for _ in range(self.n_critic):
      self.discriminator_step(next(data))
    for _ in range(self.n_actor):
      self.generator_step(next(data))
    self.each_step()

  def checkpoint(self):
    """Performs a checkpoint of all generators and discriminators."""
    for name in self.generator_names + self.discriminator_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def train(self):
    """Trains a GAN until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=8, shuffle=True, drop_last=True
      )

      # Each training step consumes n_critic batches for the discriminator
      # and n_actor batches for the generator.
      batches_per_step = self.n_actor + self.n_critic
      steps_per_episode = len(self.train_data) // batches_per_step

      data = iter(self.train_data)
      for _ in range(steps_per_episode):
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        self.step_id += 1

    generators = [getattr(self, name) for name in self.generator_names]
    discriminators = [
      getattr(self, name) for name in self.discriminator_names
    ]
    return generators, discriminators
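
# A minimal sketch of a concrete subclass, assuming a single generator and
# discriminator with the standard non-saturating GAN losses; the class name,
# latent size and data layout (plain tensor batches) are illustrative
# assumptions.
import torch.nn.functional as func

class ToyGANTraining(AbstractGANTraining):
  def __init__(self, generator, discriminator, data, latent_size=32, **kwargs):
    super().__init__(
      {"generator": generator},
      {"discriminator": discriminator},
      data, **kwargs
    )
    self.latent_size = latent_size

  def sample(self, batch_size):
    return torch.randn(batch_size, self.latent_size, device=self.device)

  def run_generator(self, data):
    fake = self.generator(self.sample(data.size(0)))
    return (self.discriminator(fake),)

  def generator_loss(self, fake_logits):
    # Non-saturating loss: make the discriminator call fakes real.
    return func.binary_cross_entropy_with_logits(
      fake_logits, torch.ones_like(fake_logits)
    )

  def run_discriminator(self, data):
    with torch.no_grad():
      fake = self.generator(self.sample(data.size(0)))
    return self.discriminator(data), self.discriminator(fake)

  def discriminator_loss(self, real_logits, fake_logits):
    real_loss = func.binary_cross_entropy_with_logits(
      real_logits, torch.ones_like(real_logits)
    )
    fake_loss = func.binary_cross_entropy_with_logits(
      fake_logits, torch.zeros_like(fake_logits)
    )
    return real_loss + fake_loss
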
class AbstractContrastiveTraining(Training):
  """Abstract base class for contrastive training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]

  def __init__(self, networks, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for contrastive representation learning.

    Args:
      networks (dict): networks used for contrastive learning.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
    """
    super().__init__(**kwargs)

    self.names, netlist = self.collect_netlist(networks)

    self.data = data
    self.train_data = None
    self.num_workers = num_workers
    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}

    self.optimizer = optimizer(netlist, **optimizer_kwargs)

    self.checkpoint_names = self.get_netlist(self.names)

  def contrastive_loss(self, *args):
    """Abstract method. Computes the contrastive loss."""
    raise NotImplementedError("Abstract")

  def regularization(self, *args):
    """Computes network regularization."""
    return 0.0

  def loss(self, *args):
    contrastive = self.contrastive_loss(*args)
    regularization = self.regularization(*args)
    return contrastive + regularization

  def run_networks(self, data):
    """Abstract method. Runs networks at each step."""
    raise NotImplementedError("Abstract")

  def visualize(self, data):
    pass

  def contrastive_step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    if self.step_id % self.report_interval == 0:
      self.visualize(data)
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    make_differentiable(data)
    args = self.run_networks(data)
    loss_val = self.loss(*args)
    self.log_statistics(loss_val, name="total loss")

    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    self.contrastive_step(data)
    self.each_step()

  def train(self):
    """Runs contrastive training until the maximum number of
    epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=self.num_workers, shuffle=True, drop_last=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    return self.get_netlist(self.names)
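
# A minimal sketch of a concrete subclass implementing an InfoNCE/NT-Xent
# style loss over two augmented views; the network name ("encoder"), the
# noise augmentation and the temperature are illustrative assumptions.
import torch.nn.functional as func

class ToyContrastiveTraining(AbstractContrastiveTraining):
  def __init__(self, encoder, data, temperature=0.1, **kwargs):
    super().__init__({"encoder": encoder}, data, **kwargs)
    self.temperature = temperature

  def augment(self, data):
    # Stand-in augmentation: additive Gaussian noise.
    return data + 0.1 * torch.randn_like(data)

  def run_networks(self, data):
    first = func.normalize(self.encoder(self.augment(data)), dim=1)
    second = func.normalize(self.encoder(self.augment(data)), dim=1)
    return first, second

  def contrastive_loss(self, first, second):
    # Each view's positive is the matching index in the other view;
    # all other batch entries act as negatives.
    logits = first @ second.t() / self.temperature
    labels = torch.arange(first.size(0), device=first.device)
    return func.cross_entropy(logits, labels)
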
class OffPolicyTraining(Training):
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("auxiliary_names"),
    NetState("policy"),
    NetState("optimizer"),
    NetState("auxiliary_optimizer")
  ]

  def __init__(self, policy, agent, environment,
               auxiliary_networks=None,
               buffer_size=100_000,
               piecewise_append=False,
               policy_steps=1,
               auxiliary_steps=1,
               n_workers=8,
               discount=0.99,
               double=False,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               aux_optimizer=torch.optim.Adam,
               aux_optimizer_kwargs=None,
               **kwargs):
    super().__init__(**kwargs)
    self.policy_steps = policy_steps
    self.auxiliary_steps = auxiliary_steps
    self.current_losses = {}
    self.statistics = ExperienceStatistics()

    self.discount = discount
    self.environment = environment
    self.agent = agent
    self.policy = policy.to(self.device)

    self.collector = EnvironmentCollector(
      environment, agent, discount=discount
    )
    self.distributor = DefaultDistributor()
    self.data_collector = ExperienceCollector(
      self.distributor, self.collector,
      n_workers=n_workers, piecewise=piecewise_append
    )
    self.buffer = SchemaBuffer(
      self.data_collector.schema(), buffer_size, double=double
    )

    optimizer_kwargs = optimizer_kwargs or {}
    self.optimizer = optimizer(self.policy.parameters(), **optimizer_kwargs)

    # Guard against the default of None before iterating; at least one
    # auxiliary network is expected in practice.
    auxiliary_networks = auxiliary_networks or {}
    auxiliary_netlist = []
    self.auxiliary_names = []
    for network in auxiliary_networks:
      self.auxiliary_names.append(network)
      network_object = auxiliary_networks[network].to(self.device)
      setattr(self, network, network_object)
      auxiliary_netlist.extend(list(network_object.parameters()))

    aux_optimizer_kwargs = aux_optimizer_kwargs or {}
    self.auxiliary_optimizer = aux_optimizer(
      auxiliary_netlist, **aux_optimizer_kwargs
    )

    self.checkpoint_names = dict(
      policy=self.policy,
      **{name: getattr(self, name) for name in self.auxiliary_names}
    )
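
# Minimal illustration of the auxiliary-network bookkeeping above, using
# plain torch modules; the network names ("value", "target") and shapes
# are hypothetical:
def auxiliary_optimizer_example():
  value_net = torch.nn.Linear(4, 1)
  target_net = torch.nn.Linear(4, 1)
  auxiliary_networks = {"value": value_net, "target": target_net}
  # All auxiliary parameters are gathered into one list and share a single
  # optimizer, mirroring the loop in OffPolicyTraining.__init__.
  auxiliary_netlist = [
    param
    for net in auxiliary_networks.values()
    for param in net.parameters()
  ]
  return torch.optim.Adam(auxiliary_netlist, lr=1e-4)
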