class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               report_interval=10,
               checkpoint_interval=1000,
               verbose=False):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): dictionary of named networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__()
    self.verbose = verbose
    self.checkpoint_path = f"{path_prefix}/{network_name}"
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    netlist = []
    self.network_names = []
    for network in networks:
      self.network_names.append(network)
      network_object = networks[network].to(self.device)
      setattr(self, network, network_object)
      netlist.extend(list(network_object.parameters()))

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)
    self.epoch_id = 0
    self.step_id = 0

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}
    self.optimizer = optimizer(netlist, **optimizer_kwargs)

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.optimizer.step()
    self.each_step()
    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def checkpoint(self):
    """Performs a checkpoint for all encoders and decoders."""
    for name in self.network_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(self.data, batch_size=self.batch_size,
                                   num_workers=8, shuffle=True)
      # initialize the validation iterator before its first use below
      valid_iter = None
      if self.valid is not None:
        self.valid_data = DataLoader(self.valid, batch_size=self.batch_size,
                                     num_workers=8, shuffle=True)
        valid_iter = iter(self.valid_data)

      for data in self.train_data:
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        if self.valid is not None and self.step_id % self.report_interval == 0:
          try:
            vdata = next(valid_iter)
          except StopIteration:
            valid_iter = iter(self.valid_data)
            vdata = next(valid_iter)
          vdata = to_device(vdata, self.device)
          self.validate(vdata)
        self.step_id += 1

    netlist = [getattr(self, name) for name in self.network_names]
    return netlist

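# Example (hypothetical sketch, not part of the library): a minimal concrete
# subclass of AbstractVAETraining for a Gaussian VAE. The encoder/decoder
# modules and the (image, label) data layout below are assumptions used purely
# for illustration.
class ExampleVAETraining(AbstractVAETraining):
  def __init__(self, encoder, decoder, data, **kwargs):
    super().__init__({"encoder": encoder, "decoder": decoder}, data, **kwargs)

  def preprocess(self, data):
    # the example dataset is assumed to yield (image, label) pairs;
    # only the image is fed to the VAE, so return it as a 1-tuple
    inputs, _ = data
    return (inputs,)

  def run_networks(self, data):
    mean, logvar = self.encoder(data)
    sample = self.sample(mean, logvar)
    reconstruction = self.decoder(sample)
    return reconstruction, data, mean, logvar

  def sample(self, mean, logvar):
    # reparameterization trick
    return mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)

  def divergence_loss(self, mean, logvar):
    # KL divergence between N(mean, sigma) and N(0, 1)
    return -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

  def reconstruction_loss(self, reconstruction, target):
    return torch.nn.functional.mse_loss(reconstruction, target)

  def loss(self, reconstruction, target, mean, logvar):
    rec = self.reconstruction_loss(reconstruction, target)
    kld = self.divergence_loss(mean, logvar)
    self.current_losses["reconstruction"] = float(rec)
    self.current_losses["kullback-leibler"] = float(kld)
    return rec + kld
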
class AbstractEnergyTraining(Training):
  """Abstract base class for energy-based model training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]

  def __init__(self, scores, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for energy/score based models.

    Args:
      scores (dict): dictionary of networks used for scoring.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in score function training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractEnergyTraining, self).__init__(**kwargs)

    netlist = []
    self.names = []
    for network in scores:
      self.names.append(network)
      network_object = scores[network].to(self.device)
      setattr(self, network, network_object)
      netlist.extend(list(network_object.parameters()))

    self.num_workers = num_workers
    self.data = data
    self.train_data = None
    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}
    self.optimizer = optimizer(netlist, **optimizer_kwargs)

    self.checkpoint_names = {
      name: getattr(self, name)
      for name in self.names
    }

  def energy_loss(self, *args):
    """Abstract method. Computes the score function loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    return self.energy_loss(*args)

  def prepare(self, *args, **kwargs):
    """Abstract method. Prepares an initial state for sampling."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the Boltzmann distribution."""
    raise NotImplementedError("Abstract")

  def run_energy(self, data):
    """Abstract method. Runs the score function at each step."""
    raise NotImplementedError("Abstract")

  def each_generate(self, *inputs):
    """Reports on generation."""
    pass

  def energy_step(self, data):
    """Performs a single step of energy function training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_energy(data)
    loss_val = self.loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("energy total loss", float(loss_val), self.step_id)

    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of energy-based model training.

    Args:
      data: data points used for training.
    """
    self.energy_step(data)
    self.each_step()

  def train(self):
    """Trains an EBM until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.train_data = None
      self.train_data = DataLoader(self.data, batch_size=self.batch_size,
                                   num_workers=self.num_workers,
                                   shuffle=True, drop_last=True)
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1
      self.epoch_id += 1

    scores = [getattr(self, name) for name in self.names]
    return scores

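# Example (hypothetical sketch): a concrete energy trainer with short-run
# Langevin sampling and a contrastive-divergence style loss. It assumes a
# single score network registered under the name "score"
# (i.e. scores={"score": ...}) and a dataset yielding (image, label) pairs;
# all shapes and step sizes are illustrative only.
class ExampleEnergyTraining(AbstractEnergyTraining):
  def prepare(self, batch_size):
    # initial negative samples drawn from a uniform prior over pixel space
    return torch.rand(batch_size, 1, 28, 28, device=self.device)

  def sample(self, batch_size, steps=10, step_size=10.0, noise=0.01):
    # short-run Langevin dynamics on the current energy landscape
    samples = self.prepare(batch_size).requires_grad_(True)
    for _ in range(steps):
      energy = self.score(samples).sum()
      grad, = torch.autograd.grad(energy, samples)
      samples = samples - step_size * grad + noise * torch.randn_like(samples)
      samples = samples.detach().requires_grad_(True)
    return samples.detach()

  def run_energy(self, data):
    real, *_ = data
    fake = self.sample(real.size(0))
    return self.score(real), self.score(fake)

  def energy_loss(self, real_energy, fake_energy):
    # push energy down on data and up on negative samples
    loss = real_energy.mean() - fake_energy.mean()
    self.current_losses["energy"] = float(loss)
    return loss
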
class MultistepTraining(Training, metaclass=StepRegistry):
  """Abstract base class for multi-step training."""
  checkpoint_parameters = Training.checkpoint_parameters + [TrainingState()]
  step_descriptors = {}
  step_order = []

  def __init__(self, networks, network_mapping, data, **kwargs):
    super().__init__(**kwargs)
    self.data = {}
    self.loaders = {}
    self.optimizers = {}
    for name, descriptor in self.step_descriptors.items():
      descriptor.n_steps_value = kwargs.get(
        descriptor.n_steps) or descriptor.n_steps_value
      descriptor.every_value = kwargs.get(
        descriptor.every) or descriptor.every_value
      self.optimizers[name] = []
      self.data[name] = None
      self.loaders[name] = None

    self.nets = {}
    for name, (network, opt, opt_kwargs) in networks.items():
      network = network.to(self.device)
      optimizer = opt(network.parameters(), **opt_kwargs)
      optimizer_name = f"{name}_optimizer"
      setattr(self, name, network)
      setattr(self, optimizer_name, optimizer)
      self.checkpoint_parameters += [
        NetState(name),
        NetState(optimizer_name)
      ]
      self.checkpoint_names.update({name: network})
      self.nets[name] = (network, optimizer)

    for step, names in network_mapping.items():
      self.optimizers[step] = [self.nets[name][1] for name in names]

    for step, data_set in data.items():
      self.data[step] = data_set
      self.loaders[step] = None

    self.current_losses = {}

  def get_data(self, step):
    if self.data[step] is None:
      return None
    data = self.loaders[step]
    if data is None:
      data = iter(
        DataLoader(self.data[step], batch_size=self.batch_size,
                   num_workers=self.num_workers,
                   shuffle=True, drop_last=True))
      self.loaders[step] = data
    try:
      data_point = to_device(next(data), self.device)
    except StopIteration:
      # the loader is exhausted: rebuild it and start a fresh pass
      data = iter(
        DataLoader(self.data[step], batch_size=self.batch_size,
                   num_workers=self.num_workers,
                   shuffle=True, drop_last=True))
      self.loaders[step] = data
      data_point = to_device(next(data), self.device)
    return data_point

  def step(self):
    total_loss = 0.0
    for step_name in self.step_order:
      descriptor = self.step_descriptors[step_name]
      if self.step_id % descriptor.every_value == 0:
        for _ in range(descriptor.n_steps_value):
          data = self.get_data(step_name)
          for optimizer in self.optimizers[step_name]:
            optimizer.zero_grad()
          loss = descriptor.step(self, data)
          if isinstance(loss, (list, tuple)):
            loss, *_ = loss
          if loss is not None:
            loss.backward()
            for optimizer in self.optimizers[step_name]:
              optimizer.step()
            total_loss += float(loss)
    self.log_statistics(total_loss)
    self.each_step()

  def train(self):
    for step_id in range(self.max_steps):
      self.step_id += 1
      self.step()
      self.log()
    return self.nets

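# Example (hypothetical sketch): how the constructor arguments of a concrete
# MultistepTraining subclass fit together. The toy modules and dataset below
# are assumptions for illustration; the per-step logic itself is registered on
# the subclass through the StepRegistry metaclass and is not shown here.
def example_multistep_arguments():
  toy_generator = torch.nn.Sequential(torch.nn.Linear(64, 784), torch.nn.Tanh())
  toy_discriminator = torch.nn.Sequential(torch.nn.Linear(784, 1))
  toy_images = torch.utils.data.TensorDataset(torch.randn(1024, 784))

  networks = {
    # name -> (module, optimizer class, optimizer keyword arguments);
    # each network gets its own optimizer, exposed as f"{name}_optimizer".
    "generator": (toy_generator, torch.optim.Adam, {"lr": 1e-4}),
    "discriminator": (toy_discriminator, torch.optim.Adam, {"lr": 4e-4}),
  }
  network_mapping = {
    # step name -> networks whose optimizers are zeroed and stepped in that step.
    "generator_step": ["generator"],
    "discriminator_step": ["discriminator"],
  }
  data = {
    # step name -> dataset feeding that step (None: the step provides its own data).
    "generator_step": None,
    "discriminator_step": toy_images,
  }
  # A concrete subclass would then be constructed as:
  #   training = MyMultistepTraining(networks, network_mapping, data,
  #                                  batch_size=64, device="cpu")
  return networks, network_mapping, data
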
class SupervisedTraining(Training):
  """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data; it is
      wrapped in a :class:`DataLoader` internally.
    validate_data (Dataset): a dataset providing the validation data; it is
      wrapped in a :class:`DataLoader` internally.
    optimizer (Optimizer): an optimizer for the network. Defaults to Adam.
    schedule (Schedule): a learning rate schedule. Defaults to decaying the
      learning rate when the validation loss plateaus.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    checkpoint_path (str): the path to save network checkpoints.
  """
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetState("net"),
    NetState("optimizer")
  ]

  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               max_epochs=50,
               batch_size=128,
               accumulate=None,
               device="cpu",
               network_name="network",
               path_prefix=".",
               report_interval=10,
               checkpoint_interval=1000,
               valid_callback=lambda training, data, outputs: None):
    super(SupervisedTraining, self).__init__()
    self.valid_callback = valid_callback
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)
    self.device = device
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, patience=10)
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(train_data, batch_size=batch_size,
                                 num_workers=8, shuffle=True, drop_last=True)
    self.validate_data = DataLoader(validate_data, batch_size=batch_size,
                                    num_workers=8, shuffle=True, drop_last=True)
    self.net = net.to(self.device)
    self.max_epochs = max_epochs
    self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint"
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.step_id = 0
    self.epoch_id = 0
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None

  def save_path(self):
    return self.checkpoint_path + "-save.torch"

  def checkpoint(self):
    the_net = self.net
    if isinstance(the_net, torch.nn.DataParallel):
      the_net = the_net.module
    netwrite(
      the_net,
      f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
    )
    self.each_checkpoint()

  def run_networks(self, data):
    inputs, *labels = data
    if not isinstance(inputs, (list, tuple)):
      inputs = [inputs]
    predictions = self.net(*inputs)
    if not isinstance(predictions, (list, tuple)):
      predictions = [predictions]
    return [combined for combined in zip(predictions, labels)]

  def loss(self, inputs):
    loss_val = torch.tensor(0.0).to(self.device)
    for idx, the_input in enumerate(inputs):
      this_loss_val = self.losses[idx](*the_input)
      self.training_losses[idx] = float(this_loss_val)
      loss_val += this_loss_val
    return loss_val

  def valid_loss(self, inputs):
    training_cache = list(self.training_losses)
    loss_val = self.loss(inputs)
    self.validation_losses = self.training_losses
    self.training_losses = training_cache
    return loss_val

  def step(self, data):
    if self.accumulate is None:
      self.optimizer.zero_grad()
    outputs = self.run_networks(data)
    loss_val = self.loss(outputs)
    loss_val.backward()
    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
    if self.accumulate is None:
      self.optimizer.step()
    elif self.step_id % self.accumulate == 0:
      self.optimizer.step()
      self.optimizer.zero_grad()
    self.each_step()

  def validate(self, data):
    with torch.no_grad():
      self.net.eval()
      outputs = self.run_networks(data)
      self.valid_loss(outputs)
      self.each_validate()
      self.valid_callback(self, to_device(data, "cpu"),
                          to_device(outputs, "cpu"))
      self.net.train()

  def schedule_step(self):
    self.schedule.step(sum(self.validation_losses))

  def each_step(self):
    Training.each_step(self)
    for idx, loss in enumerate(self.training_losses):
      self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
    self.writer.add_scalar("training loss total",
                           sum(self.training_losses), self.step_id)

  def each_validate(self):
    for idx, loss in enumerate(self.validation_losses):
      self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id)
    self.writer.add_scalar("validation loss total",
                           sum(self.validation_losses), self.step_id)

  def train(self):
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      valid_iter = iter(self.validate_data)
      for data in self.train_data:
        data = to_device(data, self.device)
        self.step(data)
        if self.step_id % self.report_interval == 0:
          try:
            vdata = next(valid_iter)
          except StopIteration:
            valid_iter = iter(self.validate_data)
            vdata = next(valid_iter)
          vdata = to_device(vdata, self.device)
          self.validate(vdata)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        self.step_id += 1
      self.schedule_step()
      self.each_epoch()
    return self.net

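# Example (hypothetical sketch): fitting a small MNIST classifier with
# SupervisedTraining. torchvision, the data directory and the network below
# are assumptions for illustration and are not part of this module.
def example_supervised_mnist_training():
  from torchvision import datasets, transforms

  net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
  )
  train_set = datasets.MNIST("data/", train=True, download=True,
                             transform=transforms.ToTensor())
  valid_set = datasets.MNIST("data/", train=False, download=True,
                             transform=transforms.ToTensor())
  training = SupervisedTraining(
    net, train_set, valid_set,
    losses=[torch.nn.CrossEntropyLoss()],  # one loss per (prediction, label) pair
    batch_size=64,
    max_epochs=5,
    network_name="mnist-classifier"
  )
  return training.train()
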
class AbstractVAETraining(Training):
  """Abstract base class for VAE training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("network_names"),
    NetState("optimizer")
  ]

  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (dict): dictionary of named networks used in the training step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)
    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size,
        num_workers=8, shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}
    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )
    self.checkpoint_names.update(
      self.get_netlist(self.network_names)
    )

  def divergence_loss(self, *args):
    """Abstract method. Computes the divergence loss."""
    raise NotImplementedError("Abstract")

  def reconstruction_loss(self, *args):
    """Abstract method. Computes the reconstruction loss."""
    raise NotImplementedError("Abstract")

  def loss(self, *args):
    """Abstract method. Computes the training loss."""
    raise NotImplementedError("Abstract")

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_networks(self, data, *args):
    """Abstract method. Runs neural networks at each step."""
    raise NotImplementedError("Abstract")

  def preprocess(self, data):
    """Takes and partitions input data into VAE data and args."""
    return data

  def each_generate(self, *args):
    pass

  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    parameters = [
      param
      for key, val in self.get_netlist(self.network_names).items()
      for param in val.parameters()
    ]
    # skip the update entirely if the gradient norm is NaN or too large
    gn = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()
    return float(loss_val)

  def valid_step(self, data):
    """Performs a single step of VAE validation.

    Args:
      data: data points used for validation.
    """
    with torch.no_grad():
      if isinstance(data, (list, tuple)):
        data = [point.to(self.device) for point in data]
      elif isinstance(data, dict):
        data = {key: data[key].to(self.device) for key in data}
      else:
        data = data.to(self.device)
      args = self.run_networks(data)
      loss_val = self.loss(*args)
    return float(loss_val)

  def validate(self, data):
    loss = self.valid_step(data)
    self.writer.add_scalar("valid loss", loss, self.step_id)
    self.each_validate()

  def run_report(self):
    if self.valid is not None:
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.valid_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=8, shuffle=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)
    return netlist

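# Example (hypothetical sketch): a partial subclass overriding each_generate to
# log a few reconstructions to TensorBoard every report_interval steps (when
# verbose=True). It assumes the concrete run_networks returns
# (reconstruction, target, ...) with image-shaped tensors in [0, 1]; the
# remaining abstract methods would still have to be implemented.
class ImageLoggingVAETraining(AbstractVAETraining):
  def each_generate(self, reconstruction, target, *args):
    self.writer.add_images("target", target[:8].clamp(0, 1), self.step_id)
    self.writer.add_images("reconstruction",
                           reconstruction[:8].clamp(0, 1), self.step_id)
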
class SupervisedTraining(Training):
  """Standard supervised training process.

  Args:
    net (Module): a trainable network module.
    train_data (Dataset): a dataset providing the training data; it is
      wrapped in a :class:`DataLoader` internally.
    validate_data (Dataset): a dataset providing the validation data; it is
      wrapped in a :class:`DataLoader` internally.
    optimizer (Optimizer): an optimizer for the network. Defaults to Adam.
    schedule (Schedule): a learning rate schedule. Defaults to decaying the
      learning rate when the validation loss plateaus.
    max_epochs (int): the maximum number of epochs to train.
    device (str): the device to run on.
    checkpoint_path (str): the path to save network checkpoints.
  """
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetState("net"),
    NetState("optimizer")
  ]

  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               accumulate=None,
               valid_callback=None,
               **kwargs):
    super(SupervisedTraining, self).__init__(**kwargs)
    self.valid_callback = valid_callback or (lambda x, y, z: None)
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, patience=10)
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(
      train_data, batch_size=self.batch_size,
      num_workers=self.num_workers,
      shuffle=True, drop_last=True,
      prefetch_factor=self.prefetch_factor
    )
    self.valid_iter = None
    if validate_data is not None:
      self.validate_data = DataLoader(
        validate_data, batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True, drop_last=True,
        prefetch_factor=self.prefetch_factor
      )
      self.valid_iter = iter(self.validate_data)
    self.net = net.to(self.device)
    self.checkpoint_names = dict(checkpoint=self.net)
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None

  def run_networks(self, data):
    inputs, *labels = data
    if not isinstance(inputs, (list, tuple)):
      inputs = [inputs]
    predictions = self.net(*inputs)
    if not isinstance(predictions, (list, tuple)):
      predictions = [predictions]
    return [combined for combined in zip(predictions, labels)]

  def loss(self, inputs):
    loss_val = torch.tensor(0.0).to(self.device)
    for idx, the_input in enumerate(inputs):
      this_loss_val = self.losses[idx](*the_input)
      self.training_losses[idx] = float(this_loss_val)
      loss_val += this_loss_val
    return loss_val

  def valid_loss(self, inputs):
    training_cache = list(self.training_losses)
    loss_val = self.loss(inputs)
    self.validation_losses = self.training_losses
    self.training_losses = training_cache
    return loss_val

  def chunk(self, data, split):
    if torch.is_tensor(data):
      return data.split(len(data) // split)
    elif isinstance(data, (list, tuple)):
      result = [[] for idx in range(split)]
      for item in data:
        for target, part in zip(result, self.chunk(item, split)):
          target.append(part)
      return result
    elif isinstance(data, dict):
      result = [{} for idx in range(split)]
      for name in data:
        for dd, part in zip(result, self.chunk(data[name], split)):
          dd[name] = part
      return result
    else:
      return data

  def step(self, data):
    self.optimizer.zero_grad()
    if self.accumulate is not None:
      points = self.chunk(data, self.accumulate)
      for point in points:
        outputs = self.run_networks(point)
        loss_val = self.loss(outputs) / self.accumulate
        loss_val.backward()
    else:
      outputs = self.run_networks(data)
      loss_val = self.loss(outputs)
      loss_val.backward()
    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 5.0)
    self.optimizer.step()
    self.each_step()

  def validate(self, data):
    with torch.no_grad():
      self.net.eval()
      if self.accumulate is not None:
        point = self.chunk(data, self.accumulate)[0]
        outputs = self.run_networks(point)
      else:
        outputs = self.run_networks(data)
      self.valid_loss(outputs)
      self.each_validate()
      self.valid_callback(
        self, to_device(data, "cpu"),
        to_device(outputs, "cpu")
      )
      self.net.train()

  def schedule_step(self):
    self.schedule.step(sum(self.validation_losses))

  def each_step(self):
    Training.each_step(self)
    for idx, loss in enumerate(self.training_losses):
      self.writer.add_scalar(f"training loss {idx}", loss, self.step_id)
    self.writer.add_scalar("training loss total",
                           sum(self.training_losses), self.step_id)

  def each_validate(self):
    for idx, loss in enumerate(self.validation_losses):
      self.writer.add_scalar(f"validation loss {idx}", loss, self.step_id)
    self.writer.add_scalar("validation loss total",
                           sum(self.validation_losses), self.step_id)

  def run_report(self):
    if self.valid_iter is not None:
      try:
        vdata = next(self.valid_iter)
      except StopIteration:
        self.valid_iter = iter(self.validate_data)
        vdata = next(self.valid_iter)
      vdata = to_device(vdata, self.device)
      self.validate(vdata)

  def train(self):
    for epoch_id in range(self.max_epochs):
      for data in self.train_data:
        data = to_device(data, self.device)
        self.step(data)
        self.log()
        self.step_id += 1
      self.schedule_step()
      self.each_epoch()
      self.epoch_id += 1
    return self.net

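# Example (hypothetical sketch): gradient accumulation with the chunk-based
# step above. With batch_size=128 and accumulate=4, every loaded batch is
# split into four 32-sample chunks whose scaled losses are backpropagated
# before a single optimizer step, so the effective batch size stays 128 while
# peak activation memory corresponds to 32 samples. `net`, `train_set` and
# `valid_set` are assumed to be defined as in the earlier MNIST sketch.
def example_accumulated_training(net, train_set, valid_set):
  training = SupervisedTraining(
    net, train_set, valid_set,
    losses=[torch.nn.CrossEntropyLoss()],
    batch_size=128,
    accumulate=4,
    max_epochs=5
  )
  return training.train()
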
class AbstractGANTraining(Training):
  """Abstract base class for GAN training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("generator_names"),
    NetNameListState("discriminator_names"),
    NetState("generator_optimizer"),
    NetState("discriminator_optimizer")
  ]

  def __init__(self, generators, discriminators, data,
               optimizer=torch.optim.Adam,
               generator_optimizer_kwargs=None,
               discriminator_optimizer_kwargs=None,
               n_critic=1,
               n_actor=1,
               max_epochs=50,
               batch_size=128,
               device="cpu",
               path_prefix=".",
               network_name="network",
               verbose=False,
               report_interval=10,
               checkpoint_interval=1000):
    """Generic training setup for generative adversarial networks.

    Args:
      generators (dict): dictionary of networks used in the generation step.
      discriminators (dict): dictionary of networks used in the
        discriminator step.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      generator_optimizer_kwargs (dict): keyword arguments for the optimizer
        used in generator training.
      discriminator_optimizer_kwargs (dict): keyword arguments for the
        optimizer used in discriminator training.
      n_critic (int): number of critic training iterations per step.
      n_actor (int): number of actor training iterations per step.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractGANTraining, self).__init__()
    self.verbose = verbose
    self.report_interval = report_interval
    self.checkpoint_interval = checkpoint_interval
    self.checkpoint_path = f"{path_prefix}/{network_name}"

    self.n_critic = n_critic
    self.n_actor = n_actor

    generator_netlist = []
    self.generator_names = []
    for network in generators:
      self.generator_names.append(network)
      network_object = generators[network].to(device)
      setattr(self, network, network_object)
      generator_netlist.extend(list(network_object.parameters()))

    discriminator_netlist = []
    self.discriminator_names = []
    for network in discriminators:
      self.discriminator_names.append(network)
      network_object = discriminators[network].to(device)
      setattr(self, network, network_object)
      discriminator_netlist.extend(list(network_object.parameters()))

    self.data = data
    self.train_data = None
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.device = device

    self.current_losses = {}
    self.network_name = network_name
    self.writer = SummaryWriter(network_name)
    self.epoch_id = 0
    self.step_id = 0

    if generator_optimizer_kwargs is None:
      generator_optimizer_kwargs = {"lr": 5e-4}
    if discriminator_optimizer_kwargs is None:
      discriminator_optimizer_kwargs = {"lr": 5e-4}

    self.generator_optimizer = optimizer(generator_netlist,
                                         **generator_optimizer_kwargs)
    self.discriminator_optimizer = optimizer(discriminator_netlist,
                                             **discriminator_optimizer_kwargs)

  def save_path(self):
    return f"{self.checkpoint_path}-save.torch"

  def generator_loss(self, *args):
    """Abstract method. Computes the generator loss."""
    raise NotImplementedError("Abstract")

  def generator_step_loss(self, *args):
    """Computes the losses of all generators."""
    return self.generator_loss(*args)

  def discriminator_loss(self, *args):
    """Abstract method. Computes the discriminator loss."""
    raise NotImplementedError("Abstract")

  def discriminator_step_loss(self, *args):
    """Computes the losses of all discriminators."""
    return self.discriminator_loss(*args)

  def sample(self, *args, **kwargs):
    """Abstract method. Samples from the latent distribution."""
    raise NotImplementedError("Abstract")

  def run_generator(self, data):
    """Abstract method. Runs generation at each step."""
    raise NotImplementedError("Abstract")

  def run_discriminator(self, data):
    """Abstract method. Runs discriminator training."""
    raise NotImplementedError("Abstract")

  def each_generate(self, *inputs):
    """Reports on generation."""
    pass

  def discriminator_step(self, data):
    """Performs a single step of discriminator training.

    Args:
      data: data points used for training.
    """
    self.discriminator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_discriminator(data)
    loss_val, *grad_out = self.discriminator_step_loss(*args)

    if self.verbose:
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("discriminator total loss",
                           float(loss_val), self.step_id)

    loss_val.backward()
    self.discriminator_optimizer.step()

  def generator_step(self, data):
    """Performs a single step of generator training.

    Args:
      data: data points used for training.
    """
    self.generator_optimizer.zero_grad()
    data = to_device(data, self.device)
    args = self.run_generator(data)
    loss_val = self.generator_step_loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("generator total loss",
                           float(loss_val), self.step_id)

    loss_val.backward()
    self.generator_optimizer.step()

  def step(self, data):
    """Performs a single step of GAN training, consisting of one or more
    steps of discriminator and generator training.

    Args:
      data: iterator over data points used for training.
    """
    for _ in range(self.n_critic):
      self.discriminator_step(next(data))
    for _ in range(self.n_actor):
      self.generator_step(next(data))
    self.each_step()

  def checkpoint(self):
    """Performs a checkpoint of all generators and discriminators."""
    for name in self.generator_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    for name in self.discriminator_names:
      the_net = getattr(self, name)
      if isinstance(the_net, torch.nn.DataParallel):
        the_net = the_net.module
      netwrite(
        the_net,
        f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
      )
    self.each_checkpoint()

  def train(self):
    """Trains a GAN until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(self.data, batch_size=self.batch_size,
                                   num_workers=8, shuffle=True, drop_last=True)
      # each training step consumes one batch per critic and actor iteration
      batches_per_step = self.n_actor + self.n_critic
      steps_per_episode = len(self.train_data) // batches_per_step
      data = iter(self.train_data)
      for _ in range(steps_per_episode):
        self.step(data)
        if self.step_id % self.checkpoint_interval == 0:
          self.checkpoint()
        self.step_id += 1

    generators = [getattr(self, name) for name in self.generator_names]
    discriminators = [getattr(self, name)
                      for name in self.discriminator_names]
    return generators, discriminators

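# Example (hypothetical sketch): a minimal non-saturating GAN built on
# AbstractGANTraining. The generator/discriminator modules, latent size and
# (image, label) data layout are assumptions for illustration.
class ExampleGANTraining(AbstractGANTraining):
  def __init__(self, generator, discriminator, data, latent_size=64, **kwargs):
    super().__init__({"generator": generator},
                     {"discriminator": discriminator},
                     data, **kwargs)
    self.latent_size = latent_size

  def sample(self, batch_size):
    return torch.randn(batch_size, self.latent_size, device=self.device)

  def run_generator(self, data):
    real, *_ = data
    fake = self.generator(self.sample(real.size(0)))
    return (self.discriminator(fake),)

  def run_discriminator(self, data):
    real, *_ = data
    with torch.no_grad():
      fake = self.generator(self.sample(real.size(0)))
    return self.discriminator(real), self.discriminator(fake)

  def generator_loss(self, fake_logits):
    # non-saturating generator loss
    return torch.nn.functional.binary_cross_entropy_with_logits(
      fake_logits, torch.ones_like(fake_logits))

  def discriminator_loss(self, real_logits, fake_logits):
    real_loss = torch.nn.functional.binary_cross_entropy_with_logits(
      real_logits, torch.ones_like(real_logits))
    fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(
      fake_logits, torch.zeros_like(fake_logits))
    self.current_losses["real"] = float(real_loss)
    self.current_losses["fake"] = float(fake_loss)
    # returned as a 1-tuple because discriminator_step unpacks (loss, *extras)
    return (real_loss + fake_loss,)
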
class AbstractContrastiveTraining(Training):
  """Abstract base class for contrastive training."""
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("names")
  ]

  def __init__(self, networks, data,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               num_workers=8,
               **kwargs):
    """Generic training setup for contrastive learning.

    Args:
      networks (dict): dictionary of networks used for contrastive learning.
      data (Dataset): provider of training data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the optimizer
        used in network training.
    """
    super().__init__(**kwargs)
    self.names, netlist = self.collect_netlist(networks)

    self.data = data
    self.train_data = None
    self.num_workers = num_workers

    self.current_losses = {}

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr": 5e-4}
    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )
    self.checkpoint_names = self.get_netlist(self.names)

  def contrastive_loss(self, *args):
    """Abstract method. Computes the contrastive loss."""
    raise NotImplementedError("Abstract")

  def regularization(self, *args):
    """Computes network regularization."""
    return 0.0

  def loss(self, *args):
    contrastive = self.contrastive_loss(*args)
    regularization = self.regularization(*args)
    return contrastive + regularization

  def run_networks(self, data):
    """Abstract method. Runs networks at each step."""
    raise NotImplementedError("Abstract")

  def visualize(self, data):
    pass

  def contrastive_step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    if self.step_id % self.report_interval == 0:
      self.visualize(data)
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    make_differentiable(data)
    args = self.run_networks(data)
    loss_val = self.loss(*args)
    self.log_statistics(loss_val, name="total loss")
    loss_val.backward()
    self.optimizer.step()

  def step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    self.contrastive_step(data)
    self.each_step()

  def train(self):
    """Runs contrastive training until the maximum number of epochs
    is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True, drop_last=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1
    return self.get_netlist(self.names)

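# Example (hypothetical sketch): a SimCLR-style subclass computing an
# InfoNCE-type loss over two augmented views. The encoder, the two-view data
# layout and the temperature are assumptions for illustration; only cross-view
# negatives are used to keep the sketch short.
class ExampleContrastiveTraining(AbstractContrastiveTraining):
  def __init__(self, encoder, data, temperature=0.1, **kwargs):
    super().__init__({"encoder": encoder}, data, **kwargs)
    self.temperature = temperature

  def run_networks(self, data):
    # the dataset is assumed to yield two augmented views of each sample
    view_a, view_b = data
    return self.encoder(view_a), self.encoder(view_b)

  def contrastive_loss(self, features_a, features_b):
    a = torch.nn.functional.normalize(features_a, dim=1)
    b = torch.nn.functional.normalize(features_b, dim=1)
    logits = a @ b.t() / self.temperature
    # matching views sit on the diagonal of the similarity matrix
    targets = torch.arange(a.size(0), device=a.device)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    self.current_losses["contrastive"] = float(loss)
    return loss
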
class OffPolicyTraining(Training):
  checkpoint_parameters = Training.checkpoint_parameters + [
    TrainingState(),
    NetNameListState("auxiliary_names"),
    NetState("policy"),
    NetState("optimizer"),
    NetState("auxiliary_optimizer")
  ]

  def __init__(self, policy, agent, environment,
               auxiliary_networks=None,
               buffer_size=100_000,
               piecewise_append=False,
               policy_steps=1,
               auxiliary_steps=1,
               n_workers=8,
               discount=0.99,
               double=False,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               aux_optimizer=torch.optim.Adam,
               aux_optimizer_kwargs=None,
               **kwargs):
    super().__init__(**kwargs)
    self.policy_steps = policy_steps
    self.auxiliary_steps = auxiliary_steps
    self.current_losses = {}
    self.statistics = ExperienceStatistics()
    self.discount = discount
    self.environment = environment
    self.agent = agent
    self.policy = policy.to(self.device)
    self.collector = EnvironmentCollector(environment, agent,
                                          discount=discount)
    self.distributor = DefaultDistributor()
    self.data_collector = ExperienceCollector(self.distributor,
                                              self.collector,
                                              n_workers=n_workers,
                                              piecewise=piecewise_append)
    self.buffer = SchemaBuffer(self.data_collector.schema(), buffer_size,
                               double=double)

    optimizer_kwargs = optimizer_kwargs or {}
    self.optimizer = optimizer(self.policy.parameters(), **optimizer_kwargs)

    # guard against the default None before iterating over the mapping
    auxiliary_networks = auxiliary_networks or {}
    auxiliary_netlist = []
    self.auxiliary_names = []
    for network in auxiliary_networks:
      self.auxiliary_names.append(network)
      network_object = auxiliary_networks[network].to(self.device)
      setattr(self, network, network_object)
      auxiliary_netlist.extend(list(network_object.parameters()))

    aux_optimizer_kwargs = aux_optimizer_kwargs or {}
    self.auxiliary_optimizer = aux_optimizer(auxiliary_netlist,
                                             **aux_optimizer_kwargs)

    self.checkpoint_names = dict(
      policy=self.policy,
      **{name: getattr(self, name) for name in self.auxiliary_names})