def __init__(self, networks, network_mapping, data, **kwargs):
  super().__init__(**kwargs)
  self.data = {}
  self.loaders = {}
  self.optimizers = {}
  for name, descriptor in self.step_descriptors.items():
    descriptor.n_steps_value = kwargs.get(
      descriptor.n_steps) or descriptor.n_steps_value
    descriptor.every_value = kwargs.get(
      descriptor.every) or descriptor.every_value
    self.optimizers[name] = []
    self.data[name] = None
    self.loaders[name] = None

  self.nets = {}
  for name, (network, opt, opt_kwargs) in networks.items():
    network = network.to(self.device)
    optimizer = opt(network.parameters(), **opt_kwargs)
    optimizer_name = f"{name}_optimizer"
    setattr(self, name, network)
    setattr(self, optimizer_name, optimizer)
    self.checkpoint_parameters += [
      NetState(name),
      NetState(optimizer_name)
    ]
    self.checkpoint_names.update({name: network})
    self.nets[name] = (network, optimizer)

  for step, names in network_mapping.items():
    self.optimizers[step] = [self.nets[name][1] for name in names]

  for step, data_set in data.items():
    self.data[step] = data_set
    self.loaders[step] = None

  self.current_losses = {}
def __init__(self, classifier, optimizer=torch.optim.Adam,
             classifier_optimizer_kwargs=None):
  self.checkpoint_parameters += [
    NetState("classifier"),
    NetState("classifier_optimizer")
  ]
  if classifier_optimizer_kwargs is None:
    classifier_optimizer_kwargs = {}
  self.classifier = classifier.to(self.device)
  self.classifier_optimizer = optimizer(
    self.classifier.parameters(),
    **classifier_optimizer_kwargs
  )
  self.discriminator_names.append("classifier")
def __init__(self, classifier, fixed=False, autonomous=False,
             optimizer=torch.optim.Adam,
             classifier_optimizer_kwargs=None):
  self.checkpoint_parameters += [
    NetState("classifier"),
    NetState("classifier_optimizer")
  ]
  if classifier_optimizer_kwargs is None:
    classifier_optimizer_kwargs = {}
  self.autonomous = autonomous
  self.fixed = fixed
  self.classifier = classifier.to(self.device)
  self.classifier_optimizer = optimizer(
    self.classifier.parameters(),
    **classifier_optimizer_kwargs
  )
  self.discriminator_names.append("classifier")
  self.checkpoint_names.update(dict(
    classifier=self.classifier
  ))
def __init__(self, encoder, decoder, discriminator, data,
             optimizer=torch.optim.Adam,
             optimizer_kwargs=None,
             **kwargs):
  """Training setup for FactorVAE - VAE with disentangled latent space.

  Args:
    encoder (nn.Module): encoder neural network.
    decoder (nn.Module): decoder neural network.
    discriminator (nn.Module): auxiliary discriminator for approximation
      of latent space total correlation.
    data (Dataset): dataset providing training data.
    c_target (float): target KL-divergence for continuous latent
      variables in nats.
    d_target (float): target KL-divergence for discrete latent
      variables in nats.
    gamma (float): scaling factor for KL-divergence constraints.
    temperature (float): temperature parameter of the concrete
      distribution.
    kwargs (dict): keyword arguments for generic VAE training.
  """
  super(FactorVAETraining, self).__init__(
    encoder, decoder, data,
    optimizer=optimizer,
    **kwargs
  )
  self.discriminator = discriminator.to(kwargs.get("device", "cpu"))
  optimizer_kwargs = optimizer_kwargs or {"lr": 1e-4}
  self.discriminator_optimizer = optimizer(
    self.discriminator.parameters(),
    **optimizer_kwargs
  )
  self.checkpoint_parameters.append(NetState("discriminator"))
  self.checkpoint_names.update(dict(
    discriminator=self.discriminator
  ))
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               report_interval=10,
               checkpoint_interval=1000,
               verbose=False):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__()
    self.verbose = verbose
    self.checkpoint_path = f"{path_prefix}/{network_name}"
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    netlist = []
    self.network_names = []
    for network in networks:
      self.network_names.append(network)
      network_object = networks[network].to(self.device)
      setattr(self, network, network_object)
      netlist.extend(list(network_object.parameters()))

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)

    self.epoch_id = 0
    self.step_id = 0

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}
    self.optimizer = optimizer(netlist, **optimizer_kwargs)

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.optimizer.step()
    self.each_step()
    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def checkpoint(self):
    """Performs a checkpoint for all encoders and decoders."""
    for name in self.network_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8, shuffle=True
      )
      # initialize the validation iterator before its first use in the loop
      valid_iter = None
      if self.valid is not None:
        self.valid_data = DataLoader(
          self.valid, batch_size=self.batch_size, num_workers=8, shuffle=True
        )
        valid_iter = iter(self.valid_data)
      for data in self.train_data:
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        if self.valid is not None and self.step_id % self.report_interval == 0:
          vdata = None
          try:
            vdata = next(valid_iter)
          except StopIteration:
            valid_iter = iter(self.valid_data)
            vdata = next(valid_iter)
          vdata = to_device(vdata, self.device)
          self.validate(vdata)
        self.step_id += 1

    netlist = [getattr(self, name) for name in self.network_names]
    return netlist
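# Usage note: a concrete subclass only has to implement run_networks(), sample()
# and the loss terms. Below is a minimal sketch for a diagonal-Gaussian VAE; the
# GaussianVAETraining name, the (mean, logvar) encoder interface and the MSE
# reconstruction term are illustrative assumptions, not part of the class above.
# It assumes the networks were passed as {"encoder": ..., "decoder": ...}.
import torch
import torch.nn.functional as F

class GaussianVAETraining(AbstractVAETraining):
  def preprocess(self, data):
    # the base step() unpacks `data, *netargs`, so return a one-element tuple
    return (data,)

  def run_networks(self, data, *args):
    # encoder and decoder are attached by the base class from the networks dict
    mean, logvar = self.encoder(data)
    sample = self.sample(mean, logvar)
    reconstruction = self.decoder(sample)
    return reconstruction, data, mean, logvar

  def sample(self, mean, logvar):
    # reparameterized sample from N(mean, exp(logvar))
    return mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)

  def reconstruction_loss(self, reconstruction, target):
    return F.mse_loss(reconstruction, target, reduction="sum") / target.size(0)

  def divergence_loss(self, mean, logvar):
    # KL(N(mean, exp(logvar)) || N(0, 1)), averaged over the batch
    return -0.5 * torch.mean(
      torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1)
    )

  def loss(self, reconstruction, target, mean, logvar):
    ce = self.reconstruction_loss(reconstruction, target)
    kld = self.divergence_loss(mean, logvar)
    self.current_losses["reconstruction"] = float(ce)
    self.current_losses["kullback-leibler-divergence"] = float(kld)
    return ce + kld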
class SupervisedTraining(Training):
  """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data.
    validate_data (Dataset): a dataset providing the validation data.
    optimizer (Optimizer): an optimizer for the network. Defaults to ADAM.
    schedule (Schedule): a learning rate schedule. Defaults to decay when
      stagnated.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    checkpoint_path (str): the path to save network checkpoints.
  """
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetState("net"),
    NetState("optimizer")
  ]

  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               max_epochs=50,
               batch_size=128,
               accumulate=None,
               device="cpu",
               network_name="network",
               path_prefix=".",
               report_interval=10,
               checkpoint_interval=1000,
               valid_callback=lambda training, data, outputs: None):
    super(SupervisedTraining, self).__init__()
    self.valid_callback = valid_callback
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)
    self.device = device
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, patience=10
      )
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(
      train_data, batch_size=batch_size, num_workers=8,
      shuffle=True, drop_last=True
    )
    self.validate_data = DataLoader(
      validate_data, batch_size=batch_size, num_workers=8,
      shuffle=True, drop_last=True
    )
    self.net = net.to(self.device)
    self.max_epochs = max_epochs
    self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint"
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.step_id = 0
    self.epoch_id = 0
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None

  def save_path(self):
    return self.checkpoint_path + "-save.torch"

  def checkpoint(self):
    the_net = self.net
    if isinstance(the_net, torch.nn.DataParallel):
      the_net = the_net.module
    netwrite(
      the_net,
      f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
    )
    self.each_checkpoint()

  def run_networks(self, data):
    inputs, *labels = data
    if not isinstance(inputs, (list, tuple)):
      inputs = [inputs]
    predictions = self.net(*inputs)
    if not isinstance(predictions, (list, tuple)):
      predictions = [predictions]
    return [combined for combined in zip(predictions, labels)]

  def loss(self, inputs):
    loss_val = torch.tensor(0.0).to(self.device)
    for idx, the_input in enumerate(inputs):
      this_loss_val = self.losses[idx](*the_input)
      self.training_losses[idx] = float(this_loss_val)
      loss_val += this_loss_val
    return loss_val

  def valid_loss(self, inputs):
    training_cache = list(self.training_losses)
    loss_val = self.loss(inputs)
    self.validation_losses = self.training_losses
    self.training_losses = training_cache
    return loss_val

  def step(self, data):
    if self.accumulate is None:
      self.optimizer.zero_grad()
    outputs = self.run_networks(data)
    loss_val = self.loss(outputs)
    loss_val.backward()
    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
    if self.accumulate is None:
      self.optimizer.step()
    elif self.step_id % self.accumulate == 0:
      self.optimizer.step()
      self.optimizer.zero_grad()
    self.each_step()

  def validate(self, data):
    with torch.no_grad():
      self.net.eval()
      outputs = self.run_networks(data)
      self.valid_loss(outputs)
      self.each_validate()
      self.valid_callback(
        self, to_device(data, "cpu"), to_device(outputs, "cpu")
      )
      self.net.train()

  def schedule_step(self):
    self.schedule.step(sum(self.validation_losses))

  def each_step(self):
    Training.each_step(self)
    for idx, loss in enumerate(self.training_losses):
      self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
    self.writer.add_scalar("training loss total", sum(self.training_losses), self.step_id)

  def each_validate(self):
    for idx, loss in enumerate(self.validation_losses):
      self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id)
    self.writer.add_scalar("validation loss total", sum(self.validation_losses), self.step_id)

  def train(self):
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      valid_iter = iter(self.validate_data)
      for data in self.train_data:
        data = to_device(data, self.device)
        self.step(data)
        if self.step_id % self.report_interval == 0:
          vdata = None
          try:
            vdata = next(valid_iter)
          except StopIteration:
            valid_iter = iter(self.validate_data)
            vdata = next(valid_iter)
          vdata = to_device(vdata, self.device)
          self.validate(vdata)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        self.step_id += 1
      self.schedule_step()
      self.each_epoch()
    return self.net
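# Hypothetical end-to-end usage of SupervisedTraining on random data. The MLP,
# the tensors and the "mlp-example" run name are placeholders chosen only to
# show the expected (input, label) batch structure and the per-output `losses`
# list; the classes and helpers defined in this module are assumed importable.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

net = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))
train_set = TensorDataset(torch.randn(1024, 16), torch.randint(0, 10, (1024,)))
valid_set = TensorDataset(torch.randn(256, 16), torch.randint(0, 10, (256,)))

training = SupervisedTraining(
  net, train_set, valid_set,
  losses=[nn.CrossEntropyLoss()],  # one loss per (prediction, label) pair
  batch_size=64,
  max_epochs=5,
  device="cpu",
  network_name="mlp-example"
)
trained_net = training.train()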
class VAETraining(AbstractVAETraining):
  """Standard VAE training setup."""
  checkpoint_parameters = AbstractVAETraining.checkpoint_parameters + [
    NetState("prior_target")
  ]

  def __init__(self, encoder, decoder, prior, data,
               prior_mu=0.0,
               generate=True,
               reconstruction_weight=1.0,
               divergence_weight=1.0,
               **kwargs):
    """Standard VAE training setup, training a pair of encoder and decoder
    to maximize the evidence lower bound.

    Args:
      encoder (nn.Module): encoder giving the variational posterior.
      decoder (nn.Module): decoder generating data from latent representations.
      prior (nn.Module): module providing the prior over the latent space.
      data (Dataset): dataset providing training data.
      kwargs (dict): keyword arguments for generic VAE training.
    """
    # placeholders; the actual modules are attached by the superclass
    self.encoder = ...
    self.decoder = ...
    self.prior = ...
    super(VAETraining, self).__init__({
      "encoder": encoder,
      "decoder": decoder,
      "prior": prior,
    }, data, **kwargs)
    self.generate = generate
    self.prior_mu = prior_mu
    self.prior_target = deepcopy(self.prior)
    self.reconstruction_weight = reconstruction_weight
    self.divergence_weight = divergence_weight
    self.checkpoint_names.update(dict(prior_target=self.prior_target))

  def reconstruction_loss(self, reconstruction, target):
    return match(reconstruction, target)

  def divergence_loss(self, posterior, prior):
    return match(posterior, prior)

  def loss(self, posterior, prior, prior_target, reconstruction, target, args):
    ce = self.reconstruction_loss(reconstruction, target)
    kld = self.divergence_loss(posterior, prior_target)
    kld_prior = self.divergence_loss(detach(posterior), prior)
    loss_val = (
      self.reconstruction_weight * ce
      + self.divergence_weight * (kld - kld.detach() + kld_prior)
    )
    self.current_losses["reconstruction-log-likelihood"] = float(ce)
    self.current_losses["kullback-leibler-divergence"] = float(kld_prior)
    return loss_val

  def sample(self, distribution):
    return distribution.rsample()

  def run_networks(self, data, *args):
    posterior, *other = self.encoder(data, *args)
    prior = self.prior(*other, *args)
    with torch.no_grad():
      prior_target = self.prior(*other, *args)
    sample = self.sample(posterior)
    reconstruction = self.decoder(sample, *other, *args)
    return posterior, prior, prior_target, reconstruction, data, args

  def ema(self):
    with torch.no_grad():
      for target, source in zip(
          self.prior_target.parameters(), self.prior.parameters()
      ):
        target *= self.prior_mu
        target += (1 - self.prior_mu) * source

  def each_step(self):
    self.ema()
    super().each_step()

  def shape_adjust(self, data):
    if data.size(1) == 1:
      data = torch.repeat_interleave(data, 3, dim=1)
    return data

  def generate_samples(self):
    sample, args = self.prior.sample(self.batch_size)
    return self.decoder.display(self.decoder(sample, *args))

  def each_generate(self, posterior, prior, prior_target,
                    reconstruction, target, args):
    if self.generate:
      with torch.no_grad():
        generated = self.generate_samples()
        self.writer.add_images("generated", self.shape_adjust(generated), self.step_id)
        self.writer.add_images("target", self.shape_adjust(target), self.step_id)
        self.writer.add_images(
          "reconstruction",
          self.shape_adjust(self.decoder.display(reconstruction)),
          self.step_id
        )
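# The prior_target above is an exponential moving average (EMA) of the prior:
# after every step, target <- prior_mu * target + (1 - prior_mu) * source.
# A standalone sketch of the same parameter update follows; the Linear modules
# and prior_mu = 0.99 are illustrative only (the class defaults prior_mu to 0.0,
# which makes the target track the prior exactly).
import torch
import torch.nn as nn

source = nn.Linear(4, 4)
target = nn.Linear(4, 4)
target.load_state_dict(source.state_dict())
prior_mu = 0.99

with torch.no_grad():
  for t, s in zip(target.parameters(), source.parameters()):
    t.mul_(prior_mu).add_((1 - prior_mu) * s)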
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      gradient_clip (float): maximum norm to which gradients are clipped.
      gradient_skip (float): gradient norm above which the optimizer step
        is skipped entirely.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)
    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size,
        num_workers=8, shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )

    self.checkpoint_names.update(
      self.get_netlist(self.network_names)
    )

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def each_generate(self, *args):
    pass

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    parameters = [
      param
      for key, val in self.get_netlist(self.network_names).items()
      for param in val.parameters()
    ]
    gn = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()
    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def run_report(self):
    if self.valid is not None:
      vdata = None
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.valid_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=8, shuffle=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)
    return netlist
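# Isolated sketch of the clip-and-skip logic used in step() above: gradients are
# clipped to gradient_clip, and the optimizer update is skipped entirely when the
# pre-clipping norm is NaN or reaches gradient_skip. The toy linear model and the
# thresholds below are placeholders.
import torch
import torch.nn as nn

net = nn.Linear(8, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
gradient_clip, gradient_skip = 200.0, 400.0

loss = net(torch.randn(16, 8)).pow(2).mean()
loss.backward()
gn = nn.utils.clip_grad_norm_(net.parameters(), gradient_clip)
if (not torch.isnan(gn).any()) and (gn < gradient_skip).all():
  optimizer.step()
optimizer.zero_grad()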
class SupervisedTraining(Training):
  """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data.
    validate_data (Dataset): a dataset providing the validation data.
    optimizer (Optimizer): an optimizer for the network. Defaults to ADAM.
    schedule (Schedule): a learning rate schedule. Defaults to decay when
      stagnated.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    checkpoint_path (str): the path to save network checkpoints.
  """
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetState("net"),
    NetState("optimizer")
  ]

  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               accumulate=None,
               valid_callback=None,
               **kwargs):
    super(SupervisedTraining, self).__init__(**kwargs)
    self.valid_callback = valid_callback or (lambda x, y, z: None)
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, patience=10
      )
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(
      train_data,
      batch_size=self.batch_size,
      num_workers=self.num_workers,
      shuffle=True,
      drop_last=True,
      prefetch_factor=self.prefetch_factor
    )
    self.valid_iter = None
    if validate_data is not None:
      self.validate_data = DataLoader(
        validate_data,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
        drop_last=True,
        prefetch_factor=self.prefetch_factor
      )
      self.valid_iter = iter(self.validate_data)
    self.net = net.to(self.device)
    self.checkpoint_names = dict(checkpoint=self.net)
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None

  def run_networks(self, data):
    inputs, *labels = data
    if not isinstance(inputs, (list, tuple)):
      inputs = [inputs]
    predictions = self.net(*inputs)
    if not isinstance(predictions, (list, tuple)):
      predictions = [predictions]
    return [combined for combined in zip(predictions, labels)]

  def loss(self, inputs):
    loss_val = torch.tensor(0.0).to(self.device)
    for idx, the_input in enumerate(inputs):
      this_loss_val = self.losses[idx](*the_input)
      self.training_losses[idx] = float(this_loss_val)
      loss_val += this_loss_val
    return loss_val

  def valid_loss(self, inputs):
    training_cache = list(self.training_losses)
    loss_val = self.loss(inputs)
    self.validation_losses = self.training_losses
    self.training_losses = training_cache
    return loss_val

  def chunk(self, data, split):
    if torch.is_tensor(data):
      return data.split(len(data) // split)
    elif isinstance(data, (list, tuple)):
      result = [[] for idx in range(split)]
      for item in data:
        for target, part in zip(result, self.chunk(item, split)):
          target.append(part)
      return result
    elif isinstance(data, dict):
      result = [{} for idx in range(split)]
      for name in data:
        for dd, part in zip(result, self.chunk(data[name], split)):
          dd[name] = part
      return result
    else:
      return data

  def step(self, data):
    self.optimizer.zero_grad()
    if self.accumulate is not None:
      points = self.chunk(data, self.accumulate)
      for point in points:
        outputs = self.run_networks(point)
        loss_val = self.loss(outputs) / self.accumulate
        loss_val.backward()
    else:
      outputs = self.run_networks(data)
      loss_val = self.loss(outputs)
      loss_val.backward()
    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
    self.optimizer.step()
    self.each_step()

  def validate(self, data):
    with torch.no_grad():
      self.net.eval()
      if self.accumulate is not None:
        point = self.chunk(data, self.accumulate)[0]
        outputs = self.run_networks(point)
      else:
        outputs = self.run_networks(data)
      self.valid_loss(outputs)
      self.each_validate()
      self.valid_callback(
        self, to_device(data, "cpu"), to_device(outputs, "cpu")
      )
      self.net.train()

  def schedule_step(self):
    self.schedule.step(sum(self.validation_losses))

  def each_step(self):
    Training.each_step(self)
    for idx, loss in enumerate(self.training_losses):
      self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
    self.writer.add_scalar("training loss total", sum(self.training_losses), self.step_id)

  def each_validate(self):
    for idx, loss in enumerate(self.validation_losses):
      self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id)
    self.writer.add_scalar("validation loss total", sum(self.validation_losses), self.step_id)

  def run_report(self):
    if self.valid_iter is not None:
      vdata = None
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.validate_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    for epoch_id in range(self.max_epochs):
      for data in self.train_data:
        data = to_device(data, self.device)
        self.step(data)
        self.log()
        self.step_id += 1
      self.schedule_step()
      self.each_epoch()
      self.epoch_id += 1
    return self.net
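# The accumulate path in step() above splits each batch into `accumulate` chunks,
# backpropagates loss / accumulate for every chunk, and applies a single optimizer
# step, keeping the effective batch size while lowering per-forward memory. For
# tensors, chunk() reduces to torch.Tensor.split; the sizes below are illustrative.
import torch

batch = torch.randn(128, 3, 32, 32)
pieces = batch.split(len(batch) // 4)  # 4 chunks of 32 samples each
assert len(pieces) == 4 and pieces[0].shape[0] == 32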
class AbstractGANTraining(Training):
  """Abstract base class for GAN training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("generator_names"),
    NetNameListState("discriminator_names"),
    NetState("generator_optimizer"),
    NetState("discriminator_optimizer")
  ]

  def __init__(self, generators, discriminators, data,
               optimizer=torch.optim.Adam,
               generator_optimizer_kwargs=None,
               discriminator_optimizer_kwargs=None,
               n_critic=1,
               n_actor=1,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               verbose=False,
               report_interval=10,
               checkpoint_interval=1000):
    """Generic training setup for generative adversarial networks.

    Args:
      generators (dict): networks used in the generation step.
      discriminators (dict): networks used in the discriminator step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      generator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in generator training.
      discriminator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in discriminator training.
      n_critic (int): number of critic training iterations per step.
      n_actor (int): number of actor training iterations per step.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractGANTraining, self).__init__()
    self.verbose = verbose
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.checkpoint_path = f"{path_prefix}/{network_name}"

    self.n_critic = n_critic
    self.n_actor = n_actor

    generator_netlist = []
    self.generator_names = []
    for network in generators:
      self.generator_names.append(network)
      network_object = generators[network].to(device)
      setattr(self, network, network_object)
      generator_netlist.extend(list(network_object.parameters()))

    discriminator_netlist = []
    self.discriminator_names = []
    for network in discriminators:
      self.discriminator_names.append(network)
      network_object = discriminators[network].to(device)
      setattr(self, network, network_object)
      discriminator_netlist.extend(list(network_object.parameters()))

    self.data = data
    self.train_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)

    self.epoch_id = 0
    self.step_id = 0

    if generator_optimizer_kwargs is None:
      generator_optimizer_kwargs = {"lr": 5e-4}
    if discriminator_optimizer_kwargs is None:
      discriminator_optimizer_kwargs = {"lr": 5e-4}

    self.generator_optimizer = optimizer(
      generator_netlist, **generator_optimizer_kwargs
    )
    self.discriminator_optimizer = optimizer(
      discriminator_netlist, **discriminator_optimizer_kwargs
    )

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def generator_loss(self, *args):
    """Abstract method. Computes the generator loss."""
    raise NotImplementedError("Abstract")

  def generator_step_loss(self, *args):
    """Computes the losses of all generators."""
    return self.generator_loss(*args)

  def discriminator_loss(self, *args):
    """Abstract method. Computes the discriminator loss."""
    raise NotImplementedError("Abstract")

  def discriminator_step_loss(self, *args):
    """Computes the losses of all discriminators."""
    return self.discriminator_loss(*args)

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_generator(self, data):
    """Abstract method. Runs generation at each step."""
    raise NotImplementedError("Abstract")

  def run_discriminator(self, data):
    """Abstract method. Runs discriminator training."""
    raise NotImplementedError("Abstract")

  def each_generate(self, *inputs):
    """Reports on generation."""
    pass

  def discriminator_step(self, data):
    """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
    self.discriminator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_discriminator(data)
    loss_val, *grad_out = self.discriminator_step_loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("discriminator total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.discriminator_optimizer.step()

  def generator_step(self, data):
    """Performs a single step of generator training.

    Args:
      data: data points used for training.
    """
    self.generator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_generator(data)
    loss_val = self.generator_step_loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("generator total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.generator_optimizer.step()

  def step(self, data):
    """Performs a single step of GAN training, comprised of one or more
    steps of discriminator and generator training.

    Args:
      data: data points used for training.
    """
    for _ in range(self.n_critic):
      self.discriminator_step(next(data))
    for _ in range(self.n_actor):
      self.generator_step(next(data))
    self.each_step()

  def checkpoint(self):
    """Performs a checkpoint of all generators and discriminators."""
    for name in self.generator_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    for name in self.discriminator_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def train(self):
    """Trains a GAN until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8,
        shuffle=True, drop_last=True
      )
      batches_per_step = self.n_actor + self.n_critic
      steps_per_episode = len(self.train_data) // batches_per_step
      data = iter(self.train_data)
      for _ in range(steps_per_episode):
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        self.step_id += 1

    generators = [getattr(self, name) for name in self.generator_names]
    discriminators = [getattr(self, name) for name in self.discriminator_names]
    return generators, discriminators
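# A minimal sketch of subclassing AbstractGANTraining for a vanilla GAN with one
# generator and one discriminator. VanillaGANTraining, latent_size and the BCE
# objectives are illustrative assumptions, not part of the class above; note
# that discriminator_loss returns a tuple, since discriminator_step() unpacks
# `loss_val, *grad_out` from discriminator_step_loss().
import torch
import torch.nn.functional as F

class VanillaGANTraining(AbstractGANTraining):
  def __init__(self, generator, discriminator, data, latent_size=64, **kwargs):
    super().__init__(
      {"generator": generator}, {"discriminator": discriminator}, data, **kwargs
    )
    self.latent_size = latent_size

  def sample(self, batch_size):
    return torch.randn(batch_size, self.latent_size, device=self.device)

  def run_generator(self, data):
    # assumes each batch is a single tensor of real samples
    fake = self.generator(self.sample(data.size(0)))
    return fake,

  def run_discriminator(self, data):
    with torch.no_grad():
      fake, = self.run_generator(data)
    return self.discriminator(fake), self.discriminator(data)

  def generator_loss(self, fake):
    result = self.discriminator(fake)
    return F.binary_cross_entropy_with_logits(result, torch.ones_like(result))

  def discriminator_loss(self, fake_result, real_result):
    fake_loss = F.binary_cross_entropy_with_logits(
      fake_result, torch.zeros_like(fake_result)
    )
    real_loss = F.binary_cross_entropy_with_logits(
      real_result, torch.ones_like(real_result)
    )
    return fake_loss + real_loss, fake_result, real_result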
class OffPolicyTraining(Training):
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("auxiliary_names"),
    NetState("policy"),
    NetState("optimizer"),
    NetState("auxiliary_optimizer")
  ]

  def __init__(self, policy, agent, environment,
               auxiliary_networks=None,
               buffer_size=100_000,
               piecewise_append=False,
               policy_steps=1,
               auxiliary_steps=1,
               n_workers=8,
               discount=0.99,
               double=False,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               aux_optimizer=torch.optim.Adam,
               aux_optimizer_kwargs=None,
               **kwargs):
    super().__init__(**kwargs)
    self.policy_steps = policy_steps
    self.auxiliary_steps = auxiliary_steps
    self.current_losses = {}
    self.statistics = ExperienceStatistics()
    self.discount = discount
    self.environment = environment
    self.agent = agent
    self.policy = policy.to(self.device)
    self.collector = EnvironmentCollector(environment, agent, discount=discount)
    self.distributor = DefaultDistributor()
    self.data_collector = ExperienceCollector(
      self.distributor, self.collector,
      n_workers=n_workers, piecewise=piecewise_append
    )
    self.buffer = SchemaBuffer(self.data_collector.schema(), buffer_size, double=double)

    optimizer_kwargs = optimizer_kwargs or {}
    self.optimizer = optimizer(self.policy.parameters(), **optimizer_kwargs)

    # guard against the default of None before iterating auxiliary networks
    auxiliary_networks = auxiliary_networks or {}
    auxiliary_netlist = []
    self.auxiliary_names = []
    for network in auxiliary_networks:
      self.auxiliary_names.append(network)
      network_object = auxiliary_networks[network].to(self.device)
      setattr(self, network, network_object)
      auxiliary_netlist.extend(list(network_object.parameters()))

    aux_optimizer_kwargs = aux_optimizer_kwargs or {}
    self.auxiliary_optimizer = aux_optimizer(auxiliary_netlist, **aux_optimizer_kwargs)

    self.checkpoint_names = dict(
      policy=self.policy,
      **{name: getattr(self, name) for name in self.auxiliary_names}
    )