def __init__(self, net, train_data, validate_data, losses,
             optimizer=torch.optim.Adam,
             schedule=None,
             accumulate=None,
             valid_callback=None,
             **kwargs):
  super(SupervisedTraining, self).__init__(**kwargs)
  self.valid_callback = valid_callback or (lambda x, y, z: None)
  self.accumulate = accumulate
  self.optimizer = optimizer(net.parameters())
  if schedule is None:
    self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10)
  else:
    self.schedule = schedule
  self.losses = losses
  self.train_data = DataLoader(
    train_data, batch_size=self.batch_size,
    num_workers=self.num_workers, shuffle=True,
    drop_last=True, prefetch_factor=self.prefetch_factor
  )
  # default to no validation unless validation data is provided
  self.validate_data = None
  self.valid_iter = None
  if validate_data is not None:
    self.validate_data = DataLoader(
      validate_data, batch_size=self.batch_size,
      num_workers=self.num_workers, shuffle=True,
      drop_last=True, prefetch_factor=self.prefetch_factor
    )
    self.valid_iter = iter(self.validate_data)
  self.net = net.to(self.device)
  self.checkpoint_names = dict(checkpoint=self.net)
  self.validation_losses = [0 for _ in range(len(self.losses))]
  self.training_losses = [0 for _ in range(len(self.losses))]
  self.best = None

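# A hypothetical usage sketch for the constructor above (not part of the
# original code): `train_set` and `valid_set` stand in for arbitrary
# map-style datasets, and keyword arguments such as batch_size,
# num_workers, prefetch_factor and device are assumed to be consumed by
# the parent training class via **kwargs.
def _example_supervised_setup(train_set, valid_set):
  import torch.nn as nn
  net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
  training = SupervisedTraining(
    net, train_set, valid_set,
    losses=[nn.CrossEntropyLoss()],
    batch_size=32, device="cpu"
  )
  return training.train()
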
def train(self): """Trains a VAE until the maximum number of epochs is reached.""" for epoch_id in range(self.max_epochs): self.epoch_id = epoch_id self.train_data = None self.train_data = DataLoader(self.data, batch_size=self.batch_size, num_workers=8, shuffle=True) if self.valid is not None: self.valid_data = DataLoader(self.valid, batch_size=self.batch_size, num_workers=8, shuffle=True) for data in self.train_data: self.step(data) if self.step_id % self.checkpoint_interval == 0: self.checkpoint() if self.valid is not None and self.step_id % self.report_interval == 0: vdata = None try: vdata = next(valid_iter) except StopIteration: valid_iter = iter(self.valid_data) vdata = next(valid_iter) vdata = to_device(vdata, self.device) self.validate(vdata) self.step_id += 1 netlist = [getattr(self, name) for name in self.network_names] return netlist
def train(self):
  # Aggressive training schedule: keep taking aggressive encoder updates
  # while the mutual information between data and latents still improves
  # (cf. He et al., "Lagging Inference Networks and Posterior Collapse
  # in Variational Autoencoders", 2019).
  aggressive = True
  old_mi = 0
  new_mi = 0
  self.step_id = 0
  for epoch_id in range(self.max_epochs):
    self.epoch_id = epoch_id
    self.train_data = None
    self.train_data = DataLoader(
      self.data, batch_size=self.batch_size,
      num_workers=8, shuffle=True
    )
    for data in self.train_data:
      if aggressive:
        self.aggressive_update(data)
      else:
        self.step(data)
      if self.step_id % self.checkpoint_interval == 0:
        self.checkpoint()
      self.step_id += 1
    valid_data = DataLoader(self.valid, batch_size=self.batch_size, shuffle=True)
    new_mi = self.compute_mi(next(iter(valid_data)))
    # switch off aggressive updates once mutual information stops improving
    aggressive = new_mi > old_mi
    old_mi = new_mi
  netlist = [getattr(self, name) for name in self.network_names]
  return netlist

def __init__(self, net, train_data, validate_data, losses, **kwargs):
  super(FewShotTraining, self).__init__(
    net, train_data, validate_data, losses, **kwargs
  )
  # copy the training data before switching it to single-shot mode,
  # so the support set keeps the original data mode
  support_data = copy(train_data)
  train_data.data_mode = type(train_data.data_mode)(1)
  support_data = SupportData(support_data, shots=5)
  validate_support_data = SupportData(validate_data, shots=5)
  self.support_loader = iter(DataLoader(support_data))
  self.valid_support_loader = iter(DataLoader(validate_support_data))

def __init__(self, net, train_data, validate_data, losses,
             optimizer=torch.optim.Adam,
             schedule=None,
             max_epochs=50,
             batch_size=128,
             accumulate=None,
             device="cpu",
             network_name="network",
             path_prefix=".",
             report_interval=10,
             checkpoint_interval=1000,
             num_workers=8,
             valid_callback=lambda x: None):
  super(SupervisedTraining, self).__init__()
  self.valid_callback = valid_callback
  self.network_name = network_name
  self.batch_size = batch_size
  # self.train_writer = SummaryWriter(f'{network_name}-train')
  # self.valid_writer = SummaryWriter(f'{network_name}-valid')
  # self.meta_writer = SummaryWriter(f'{network_name}-meta')
  self.device = device
  self.accumulate = accumulate
  self.num_workers = num_workers
  self.optimizer = optimizer(net.parameters())
  if schedule is None:
    self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10)
  else:
    self.schedule = schedule
  self.losses = losses
  self.train_data = DataLoader(
    train_data, batch_size=batch_size,
    num_workers=self.num_workers, shuffle=True, drop_last=True
  )
  self.validate_data = DataLoader(
    validate_data, batch_size=batch_size,
    num_workers=self.num_workers, shuffle=True, drop_last=True
  )
  self.net = net.to(self.device)
  self.max_epochs = max_epochs
  self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint"
  self.report_interval = report_interval
  self.checkpoint_interval = checkpoint_interval
  self.step_id = 0
  self.epoch_id = 0
  self.validation_losses = [0 for _ in range(len(self.losses))]
  self.training_losses = [0 for _ in range(len(self.losses))]
  self.best = None

def train(self):
  expectation, embedding = self.expectation()
  weights, labels, centers = self.cluster(embedding)
  # one-hot encode the current class expectation as initial labels
  self.data.labels = torch.zeros_like(expectation)
  self.data.labels[
    torch.arange(expectation.size(0)), expectation.argmax(dim=1)
  ] = 1
  for epoch_id in range(self.max_epochs):
    self.train_data = None
    self.train_data = DataLoader(
      self.data, batch_size=self.batch_size, num_workers=8,
      sampler=WeightedRandomSampler(weights, len(self.data) * 4, replacement=True)
    )
    for data, expected_logits in self.train_data:
      self.step(data, expected_logits, centers)
      self.log()
      self.step_id += 1
    expectation, embedding = self.expectation()
    labels = expectation.argmax(dim=1).to("cpu").squeeze()
    self.each_cluster(expectation.to("cpu"), labels.numpy())
    self.data.labels = expectation.to("cpu").squeeze()
    self.epoch_id += 1
  return self.net

def train(self):
  for epoch_id in range(self.max_epochs):
    embedding = self.embed_all()
    label_hierarchy = []
    center_hierarchy = []
    for clustering in self.clusterings:
      self.clustering = clustering
      weights, labels, centers = self.cluster(embedding)
      label_hierarchy.append(np.expand_dims(labels, axis=1))
      center_hierarchy.append(centers)
    self.each_cluster(embedding, label_hierarchy)
    label_hierarchy = np.concatenate(label_hierarchy, axis=1)
    self.data.labels = label_hierarchy
    self.train_data = None
    self.train_data = DataLoader(
      self.data, batch_size=self.batch_size, num_workers=0,
      sampler=WeightedRandomSampler(weights, min(20000, len(self.data)), replacement=True)
    )
    for inner_epoch in range(1):
      for data, label in self.train_data:
        self.step(data, label, center_hierarchy)
        self.log()
        self.step_id += 1
    self.epoch_id += 1
  return self.net

def __init__(self, score, *args,
             buffer_size=10000,
             buffer_probability=0.95,
             sample_steps=10,
             decay=1,
             reset_threshold=1000,
             integrator=None,
             oos_penalty=True,
             accept_probability=1.0,
             sampler_likelihood=1.0,
             maximum_entropy=0.3,
             **kwargs):
  # placeholder; the parent constructor registers the score network
  self.score = ...
  super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
  self.sampler_likelihood = sampler_likelihood
  self.maximum_entropy = maximum_entropy
  self.target_score = deepcopy(score).eval()
  self.reset_threshold = reset_threshold
  self.oos_penalty = oos_penalty
  self.decay = decay
  self.integrator = integrator if integrator is not None else Langevin()
  self.sample_steps = sample_steps
  self.buffer = SampleBuffer(
    self, buffer_size=buffer_size,
    buffer_probability=buffer_probability,
    accept_probability=accept_probability
  )
  self.buffer_loader = lambda x: DataLoader(
    x, batch_size=self.batch_size, shuffle=True, drop_last=True
  )

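# A hypothetical sketch of setting up the energy-based training above
# (not part of the original code): `ScoreNet` and `data_set` are
# stand-ins, and `sample_steps` trades negative-sample quality for speed
# in the Langevin integrator.
def _example_energy_setup(data_set):
  import torch.nn as nn

  class ScoreNet(nn.Module):
    """Toy energy network returning one scalar energy per sample."""
    def __init__(self):
      super().__init__()
      self.body = nn.Sequential(nn.Flatten(), nn.Linear(784, 1))

    def forward(self, x):
      return self.body(x)

  return EnergyTraining(
    ScoreNet(), data_set,
    sample_steps=60,
    batch_size=64
  )
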
def train(self): """Trains a GAN until the maximum number of epochs is reached.""" for epoch_id in range(self.max_epochs): self.epoch_id = epoch_id self.train_data = None self.train_data = DataLoader(self.data, batch_size=self.batch_size, num_workers=8, shuffle=True, drop_last=True) batches_per_step = self.n_actor + self.n_critic steps_per_episode = len(self.train_data) // batches_per_step data = iter(self.train_data) for _ in range(steps_per_episode): self.step(data) if self.step_id % self.checkpoint_interval == 0: self.checkpoint() self.step_id += 1 generators = [getattr(self, name) for name in self.generator_names] discriminators = [ getattr(self, name) for name in self.discriminator_names ] return generators, discriminators
def __init__(self, score, sampler, *args,
             transition_buffer_size=10000,
             sampler_steps=10,
             sampler_optimizer=torch.optim.Adam,
             n_sampler=1,
             sampler_optimizer_kwargs=None,
             sampler_wrapper=lambda x: x,
             **kwargs):
  super().__init__(score, *args, **kwargs)
  self.transition_buffer = TransitionBuffer(self, transition_buffer_size)
  self.sampler = sampler.to(self.device)
  self.wrapper = sampler_wrapper(self.sampler)
  self.sampler_steps = sampler_steps
  self.n_sampler = n_sampler
  self.transition_buffer_loader = lambda x: DataLoader(
    x, batch_size=2 * self.batch_size, shuffle=True, drop_last=True
  )
  if sampler_optimizer_kwargs is None:
    sampler_optimizer_kwargs = {"lr": 5e-4}
  self.sampler_optimizer = sampler_optimizer(
    sampler.parameters(), **sampler_optimizer_kwargs
  )

def __init__(self, data_set, batch_size=1, device="cpu", **kwargs):
  self.data = data_set
  self.device = device
  self.loader = DataLoader(
    data_set, batch_size=batch_size, drop_last=True,
    sampler=InfiniteSampler(data_set), **kwargs
  )
  self.iter = iter(self.loader)

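# `InfiniteSampler` is referenced above but not defined in this section.
# A minimal sketch of what it is assumed to look like: a sampler that
# yields random indices forever, so the wrapping DataLoader never raises
# StopIteration.
import torch
from torch.utils.data import Sampler

class InfiniteSampler(Sampler):
  def __init__(self, data_source):
    self.size = len(data_source)

  def __iter__(self):
    while True:
      # a fresh random permutation for every pass over the data
      yield from torch.randperm(self.size).tolist()
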
def __init__(self, net, train_data, validate_data, losses,
             optimizer=torch.optim.Adam,
             schedule=None,
             max_epochs=50,
             batch_size=128,
             device="cpu",
             network_name="network",
             path_prefix=".",
             report_interval=10,
             checkpoint_interval=1000,
             valid_callback=lambda x: None):
  super(FewShotTraining, self).__init__(
    net, train_data, validate_data, losses,
    optimizer=optimizer,
    schedule=schedule,
    max_epochs=max_epochs,
    batch_size=batch_size,
    device=device,
    network_name=network_name,
    path_prefix=path_prefix,
    report_interval=report_interval,
    checkpoint_interval=checkpoint_interval,
    valid_callback=valid_callback
  )
  # copy the training data before switching it to single-shot mode,
  # so the support set keeps the original data mode
  support_data = copy(train_data)
  train_data.data_mode = type(train_data.data_mode)(1)
  support_data = SupportData(support_data, shots=5)
  validate_support_data = SupportData(validate_data, shots=5)
  self.support_loader = iter(DataLoader(support_data))
  self.valid_support_loader = iter(DataLoader(validate_support_data))

def __init__(self, networks, data, valid=None,
             optimizer=torch.optim.Adam,
             optimizer_kwargs=None,
             gradient_clip=200.0,
             gradient_skip=400.0,
             **kwargs):
  """Generic training setup for variational autoencoders.

  Args:
    networks (dict): named networks used in the training step.
    data (Dataset): provider of training data.
    valid (Dataset): provider of validation data, if any.
    optimizer (Optimizer): optimizer class for gradient descent.
    optimizer_kwargs (dict): keyword arguments for the optimizer
      used in network training.
    gradient_clip (float): norm above which gradients are clipped.
    gradient_skip (float): gradient norm above which an update
      step is skipped entirely.
    max_epochs (int): maximum number of training epochs.
    batch_size (int): number of training samples per batch.
    device (string): device to use for training.
    network_name (string): identifier of the network architecture.
    verbose (bool): log all events and losses?
  """
  super(AbstractVAETraining, self).__init__(**kwargs)
  self.data = data
  self.valid = valid
  self.train_data = None
  self.valid_data = None
  self.gradient_clip = gradient_clip
  self.gradient_skip = gradient_skip
  self.valid_iter = None
  if self.valid is not None:
    self.valid_data = DataLoader(
      self.valid, batch_size=self.batch_size,
      num_workers=8, shuffle=True
    )
    self.valid_iter = iter(self.valid_data)
  self.network_names, netlist = self.collect_netlist(networks)
  if optimizer_kwargs is None:
    optimizer_kwargs = {"lr": 5e-4}
  self.optimizer = optimizer(netlist, **optimizer_kwargs)
  self.checkpoint_names.update(self.get_netlist(self.network_names))

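# A hypothetical wiring sketch for the generic VAE setup above (not part
# of the original code): the encoder/decoder modules are stand-ins,
# `batch_size` is assumed to be consumed by the parent class via
# **kwargs, and a concrete subclass would still have to implement the
# actual training step.
def _example_vae_setup(data_set):
  import torch.nn as nn
  networks = {
    "encoder": nn.Linear(784, 2 * 32),  # mean and log-variance
    "decoder": nn.Linear(32, 784),
  }
  return AbstractVAETraining(
    networks, data_set,
    optimizer_kwargs={"lr": 1e-4},
    batch_size=64
  )
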
def get_data(self, step):
  if self.data[step] is None:
    return None
  data = self.loaders[step]
  if data is None:
    data = iter(DataLoader(
      self.data[step], batch_size=self.batch_size,
      num_workers=self.num_workers, shuffle=True, drop_last=True
    ))
    self.loaders[step] = data
  try:
    data_point = to_device(next(data), self.device)
  except StopIteration:
    # loader exhausted: start a fresh pass over this dataset
    data = iter(DataLoader(
      self.data[step], batch_size=self.batch_size,
      num_workers=self.num_workers, shuffle=True, drop_last=True
    ))
    self.loaders[step] = data
    data_point = to_device(next(data), self.device)
  return data_point

def embed_all(self):
  self.net.eval()
  with torch.no_grad():
    embedding = []
    batch_loader = DataLoader(self.data, batch_size=self.batch_size, shuffle=False)
    # embed at most 5000 points to keep clustering affordable
    for point, *_ in islice(batch_loader, 5000 // self.batch_size):
      latent_point = self.net(point.to(self.device))
      latent_point = latent_point.to("cpu")
      latent_point = latent_point.reshape(latent_point.size(0), -1)
      embedding.append(latent_point)
    embedding = torch.cat(embedding, dim=0)
  self.net.train()
  return embedding

def train(self): """Runs contrastive training until the maximum number of epochs is reached.""" for epoch_id in range(self.max_epochs): self.epoch_id = epoch_id self.train_data = None self.train_data = DataLoader( self.data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True ) for data in self.train_data: self.step(data) self.log() self.step_id += 1 return self.get_netlist(self.names)
def train(self): """Trains a VAE until the maximum number of epochs is reached.""" for epoch_id in range(self.max_epochs): self.epoch_id = epoch_id self.train_data = None self.train_data = DataLoader( self.data, batch_size=self.batch_size, num_workers=8, shuffle=True ) for data in self.train_data: self.step(data) self.log() self.step_id += 1 netlist = self.get_netlist(self.network_names) return netlist
def train(self):
  for epoch_id in range(self.max_epochs):
    self.train_data = None
    self.train_data = DataLoader(
      self.data, batch_size=self.batch_size,
      num_workers=8, shuffle=True
    )
    for internal_epoch in range(1):
      for data, *_ in islice(self.train_data, 100):
        self.step(data)
        self.log()
        self.step_id += 1
    self.each_cluster()
    # anneal alpha, with increasingly strong decay in later epochs
    self.alpha *= float(np.power(2.0, -(np.log(epoch_id + 1) ** 2)))
    self.epoch_id += 1
  return self.net

def train(self): """Trains an EBM until the maximum number of epochs is reached.""" for epoch_id in range(self.max_epochs): self.train_data = None self.train_data = DataLoader(self.data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True) for data in self.train_data: self.step(data) self.log() self.step_id += 1 self.epoch_id += 1 scores = [getattr(self, name) for name in self.names] return scores
def __init__(self, score, critic, *args,
             decay=0.1,
             integrator=None,
             optimizer=torch.optim.Adam,
             n_critic=5,
             critic_optimizer_kwargs=None,
             **kwargs):
  # placeholders; filled in by the parent constructor and the loop below
  self.score = ...
  self.critic = ...
  super().__init__({"score": score}, *args, optimizer=optimizer, **kwargs)
  critics = {"critic": critic}
  netlist = []
  self.critic_names = []
  for network in critics:
    self.critic_names.append(network)
    network_object = critics[network].to(self.device)
    setattr(self, network, network_object)
    netlist.extend(list(network_object.parameters()))
  if critic_optimizer_kwargs is None:
    critic_optimizer_kwargs = {"lr": 5e-4}
  self.critic_data = DataLoader(
    self.data, batch_size=self.batch_size,
    num_workers=8, shuffle=True, drop_last=True
  )
  self.critic_optimizer = optimizer(netlist, **critic_optimizer_kwargs)
  self.n_critic = n_critic
  self.decay = decay
  self.integrator = integrator
  self.checkpoint_names.update(
    {name: getattr(self, name) for name in self.critic_names}
  )

def train(self):
  for epoch_id in range(self.max_epochs):
    self.epoch_id = epoch_id
    embedding = self.embed_all()
    weights, labels, centers = self.cluster(embedding)
    self.each_cluster(embedding, labels)
    self.data.labels = labels
    self.train_data = None
    self.train_data = DataLoader(
      self.data, batch_size=self.batch_size, num_workers=8,
      sampler=WeightedRandomSampler(weights, len(self.data) * 4, replacement=True)
    )
    for data, label in self.train_data:
      self.step(data, label, centers)
      self.log()
      self.step_id += 1
  return self.net

def aggressive_update(self, data):
  inner_data = DataLoader(
    self.data, batch_size=self.batch_size,
    num_workers=8, shuffle=True
  )
  # freeze the decoder and take encoder-only steps until the loss
  # stops improving over a window of ten steps
  for parameter in self.decoder.parameters():
    parameter.requires_grad = False
  last_ten = [None] * 10
  for idx, data_p in enumerate(inner_data):
    loss = self.step(data_p)
    last_ten[idx % 10] = loss
    if last_ten[-1] is not None and last_ten[-1] >= last_ten[0]:
      break
  for parameter in self.decoder.parameters():
    parameter.requires_grad = True
  # then take a single decoder-only step on the current batch
  for parameter in self.encoder.parameters():
    parameter.requires_grad = False
  self.step(data)
  for parameter in self.encoder.parameters():
    parameter.requires_grad = True

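# The freeze/unfreeze bookkeeping above can be factored into a reusable
# context manager; a minimal sketch (hypothetical, not part of the
# original code):
from contextlib import contextmanager

@contextmanager
def frozen(module):
  """Temporarily disable gradients for all parameters of `module`."""
  for parameter in module.parameters():
    parameter.requires_grad = False
  try:
    yield module
  finally:
    for parameter in module.parameters():
      parameter.requires_grad = True

# Usage: `with frozen(self.decoder): self.step(data)` performs an
# encoder-only step, mirroring the aggressive update above.
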
def expectation(self):
  self.net.eval()
  with torch.no_grad():
    embedding = []
    batch_loader = DataLoader(self.data, batch_size=self.batch_size, shuffle=False)
    for point, *_ in batch_loader:
      features, mean, logvar = self.net(point.to(self.device))
      # reparameterized sample from the approximate posterior
      std = torch.exp(0.5 * logvar)
      sample = torch.randn_like(std).mul(std).add_(mean)
      latent_point = func.adaptive_avg_pool2d(sample, 1)
      latent_point = latent_point.reshape(latent_point.size(0), -1)
      embedding.append(latent_point)
    embedding = torch.cat(embedding, dim=0)
    expectation = self.classifier(embedding)
  self.net.train()
  return expectation.to("cpu"), embedding.to("cpu")

def __init__(self, score, *args,
             buffer_size=100,
             buffer_probability=0.9,
             sample_steps=10,
             decay=1,
             integrator=None,
             oos_penalty=True,
             **kwargs):
  # placeholder; the parent constructor registers the score network
  self.score = ...
  super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
  self.oos_penalty = oos_penalty
  self.decay = decay
  self.integrator = integrator if integrator is not None else Langevin()
  self.sample_steps = sample_steps
  self.buffer = SampleBuffer(
    self, buffer_size=buffer_size,
    buffer_probability=buffer_probability
  )
  self.buffer_loader = lambda x: DataLoader(
    x, batch_size=self.batch_size, shuffle=True, drop_last=True
  )