Example #1
  def __init__(self, net, train_data, validate_data, losses,
               optimizer=torch.optim.Adam,
               schedule=None,
               accumulate=None,
               valid_callback=None,
               **kwargs):
    super(SupervisedTraining, self).__init__(**kwargs)
    self.valid_callback = valid_callback or (lambda x, y, z: None)
    self.accumulate = accumulate
    self.optimizer = optimizer(net.parameters())
    if schedule is None:
      self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10)
    else:
      self.schedule = schedule
    self.losses = losses
    self.train_data = DataLoader(
      train_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
      prefetch_factor=self.prefetch_factor
    )
    self.valid_iter = None
    if validate_data is not None:
      self.validate_data = DataLoader(
        validate_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
        prefetch_factor=self.prefetch_factor
      )
      self.valid_iter = iter(self.validate_data)
    
    self.net = net.to(self.device)

    self.checkpoint_names = dict(checkpoint=self.net)
    self.validation_losses = [0 for _ in range(len(self.losses))]
    self.training_losses = [0 for _ in range(len(self.losses))]
    self.best = None
Example #2
    def train(self):
        """Trains a VAE until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            if self.valid is not None:
                self.valid_data = DataLoader(self.valid,
                                             batch_size=self.batch_size,
                                             num_workers=8,
                                             shuffle=True)
                # create the validation iterator up front; without this,
                # next(valid_iter) below would raise NameError rather than
                # StopIteration on its first use
                valid_iter = iter(self.valid_data)
            for data in self.train_data:
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                if self.valid is not None and self.step_id % self.report_interval == 0:
                    vdata = None
                    try:
                        vdata = next(valid_iter)
                    except StopIteration:
                        valid_iter = iter(self.valid_data)
                        vdata = next(valid_iter)
                    vdata = to_device(vdata, self.device)
                    self.validate(vdata)
                self.step_id += 1

        netlist = [getattr(self, name) for name in self.network_names]

        return netlist
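The loop above moves validation batches with a `to_device` helper that is not part of the excerpt. A minimal sketch of such a helper, assuming batches are tensors or plain containers of tensors (an assumption, not necessarily the library's actual implementation):

import torch

def to_device(data, device):
    # Recursively move a tensor, or a list/tuple/dict of tensors,
    # onto the given device; anything else passes through unchanged.
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, tuple):
        return tuple(to_device(item, device) for item in data)
    if isinstance(data, list):
        return [to_device(item, device) for item in data]
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    return data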
Example #3
    def train(self):
        aggressive = True
        old_mi = 0
        new_mi = 0
        self.step_id = 0
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            for data in self.train_data:
                if aggressive:
                    self.aggressive_update(data)
                else:
                    self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                self.step_id += 1
            valid_data = DataLoader(self.valid,
                                    batch_size=self.batch_size,
                                    shuffle=True)
            new_mi = self.compute_mi(next(iter(valid_data)))
            # stay in the aggressive (encoder-only) regime only while the
            # mutual-information estimate keeps improving; old_mi must be
            # updated, otherwise it compares against 0 forever
            aggressive = new_mi > old_mi
            old_mi = new_mi

        netlist = [getattr(self, name) for name in self.network_names]

        return netlist
Example #4
    def __init__(self, net, train_data, validate_data, losses, **kwargs):
        super(FewShotTraining, self).__init__(net, train_data, validate_data,
                                              losses, **kwargs)

        # keep an unmodified copy for building the support set before
        # switching the training set's data mode
        support_data = copy(train_data)
        train_data.data_mode = type(train_data.data_mode)(1)
        support_data = SupportData(support_data, shots=5)
        validate_support_data = SupportData(validate_data, shots=5)
        self.support_loader = iter(DataLoader(support_data))
        self.valid_support_loader = iter(DataLoader(validate_support_data))
Example #5
 def __init__(self,
              net,
              train_data,
              validate_data,
              losses,
              optimizer=torch.optim.Adam,
              schedule=None,
              max_epochs=50,
              batch_size=128,
              accumulate=None,
              device="cpu",
              network_name="network",
              path_prefix=".",
              report_interval=10,
              checkpoint_interval=1000,
              num_workers=8,
              valid_callback=lambda x: None):
     super(SupervisedTraining, self).__init__()
     self.valid_callback = valid_callback
     self.network_name = network_name
     self.batch_size = batch_size
     # self.train_writer = SummaryWriter(f'{network_name}-train')
     # self.valid_writer = SummaryWriter(f'{network_name}-valid')
     # self.meta_writer = SummaryWriter(f'{network_name}-meta')
     self.device = device
     self.accumulate = accumulate
     self.num_workers = num_workers
     self.optimizer = optimizer(net.parameters())
     if schedule is None:
         self.schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer, patience=10)
     else:
         self.schedule = schedule
     self.losses = losses
     self.train_data = DataLoader(train_data,
                                  batch_size=batch_size,
                                  num_workers=self.num_workers,
                                  shuffle=True,
                                  drop_last=True)
     self.validate_data = DataLoader(validate_data,
                                     batch_size=batch_size,
                                     num_workers=self.num_workers,
                                     shuffle=True,
                                     drop_last=True)
     self.net = net.to(self.device)
     self.max_epochs = max_epochs
     self.checkpoint_path = f"{path_prefix}/{network_name}-checkpoint"
     self.report_interval = report_interval
     self.checkpoint_interval = checkpoint_interval
     self.step_id = 0
     self.epoch_id = 0
     self.validation_losses = [0 for _ in range(len(self.losses))]
     self.training_losses = [0 for _ in range(len(self.losses))]
     self.best = None
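For context, a hypothetical way to instantiate the trainer above; `net`, `train_set`, and `valid_set` are stand-ins for any compatible model and datasets:

import torch.nn as nn

training = SupervisedTraining(
    net,                      # any nn.Module
    train_set,                # training Dataset
    valid_set,                # validation Dataset
    losses=[nn.CrossEntropyLoss()],
    batch_size=64,
    device="cuda:0",
    network_name="classifier")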
Example #6
    def train(self):
        expectation, embedding = self.expectation()
        weights, labels, centers = self.cluster(embedding)
        # one-hot encode each sample's most likely cluster (plain row
        # indexing by argmax would set entire rows instead of one entry
        # per sample)
        self.data.labels = torch.zeros_like(expectation)
        self.data.labels.scatter_(1, expectation.argmax(dim=1, keepdim=True), 1)

        for epoch_id in range(self.max_epochs):
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         sampler=WeightedRandomSampler(
                                             weights,
                                             len(self.data) * 4,
                                             replacement=True))
            for data, expected_logits in self.train_data:
                self.step(data, expected_logits, centers)
                self.log()
                self.step_id += 1

            expectation, embedding = self.expectation()
            labels = expectation.argmax(dim=1).to("cpu").squeeze()
            self.each_cluster(expectation.to("cpu"), labels.numpy())
            self.data.labels = expectation.to("cpu").squeeze()
            self.epoch_id += 1

        return self.net
Example #7
    def train(self):
        for epoch_id in range(self.max_epochs):
            embedding = self.embed_all()

            label_hierarchy = []
            center_hierarchy = []
            for clustering in self.clusterings:
                self.clustering = clustering
                weights, labels, centers = self.cluster(embedding)
                label_hierarchy.append(np.expand_dims(labels, axis=1))
                center_hierarchy.append(centers)
            self.each_cluster(embedding, label_hierarchy)
            label_hierarchy = np.concatenate(label_hierarchy, axis=1)

            self.data.labels = label_hierarchy
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=0,
                                         sampler=WeightedRandomSampler(
                                             weights,
                                             min(20000, len(self.data)),
                                             replacement=True))
            for inner_epoch in range(1):
                for data, label in self.train_data:
                    self.step(data, label, center_hierarchy)
                    self.log()
                    self.step_id += 1
            self.epoch_id += 1

        return self.net
Example #8
 def __init__(self,
              score,
              *args,
              buffer_size=10000,
              buffer_probability=0.95,
              sample_steps=10,
              decay=1,
              reset_threshold=1000,
              integrator=None,
              oos_penalty=True,
              accept_probability=1.0,
              sampler_likelihood=1.0,
              maximum_entropy=0.3,
              **kwargs):
     # placeholder attribute; the base-class constructor below fills it
     # in from the networks dict
     self.score = ...
     super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
     self.sampler_likelihood = sampler_likelihood
     self.maximum_entropy = maximum_entropy
     self.target_score = deepcopy(score).eval()
     self.reset_threshold = reset_threshold
     self.oos_penalty = oos_penalty
     self.decay = decay
     self.integrator = integrator if integrator is not None else Langevin()
     self.sample_steps = sample_steps
     self.buffer = SampleBuffer(self,
                                buffer_size=buffer_size,
                                buffer_probability=buffer_probability,
                                accept_probability=accept_probability)
     self.buffer_loader = lambda x: DataLoader(
         x, batch_size=self.batch_size, shuffle=True, drop_last=True)
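The SampleBuffer used above is not part of the excerpt. Below is a minimal sketch of a persistent sample buffer in the same spirit (the real class additionally receives the training object and an accept_probability): with probability buffer_probability a stored sample seeds the next sampling chain, otherwise the chain restarts from noise.

import random
import torch

class SampleBuffer:
    def __init__(self, buffer_size=10000, buffer_probability=0.95):
        self.buffer_size = buffer_size
        self.buffer_probability = buffer_probability
        self.samples = []

    def push(self, batch):
        # store detached samples, dropping the oldest beyond capacity
        self.samples.extend(batch.detach().cpu().unbind(0))
        self.samples = self.samples[-self.buffer_size:]

    def sample(self, batch_size, shape):
        out = []
        for _ in range(batch_size):
            if self.samples and random.random() < self.buffer_probability:
                # continue a previously stored sampling chain
                out.append(random.choice(self.samples))
            else:
                # restart the chain from fresh noise
                out.append(torch.randn(shape))
        return torch.stack(out, dim=0)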
Example #9
    def train(self):
        """Trains a GAN until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True,
                                         drop_last=True)

            batches_per_step = self.n_actor + self.n_critic
            steps_per_episode = len(self.train_data) // batches_per_step

            data = iter(self.train_data)
            for _ in range(steps_per_episode):
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                self.step_id += 1

        generators = [getattr(self, name) for name in self.generator_names]

        discriminators = [
            getattr(self, name) for name in self.discriminator_names
        ]

        return generators, discriminators
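step(data) above receives the iterator itself and is expected to consume n_actor + n_critic batches per call, matching the batches_per_step arithmetic. A sketch of such a step, assuming separate critic_step and actor_step methods (neither is shown in the excerpt):

    def step(self, data):
        # one batch per critic update, then per actor update, so each
        # call consumes n_critic + n_actor batches in total
        for _ in range(self.n_critic):
            self.critic_step(next(data))
        for _ in range(self.n_actor):
            self.actor_step(next(data))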
Example #10
    def __init__(self,
                 score,
                 sampler,
                 *args,
                 transition_buffer_size=10000,
                 sampler_steps=10,
                 sampler_optimizer=torch.optim.Adam,
                 n_sampler=1,
                 sampler_optimizer_kwargs=None,
                 sampler_wrapper=lambda x: x,
                 **kwargs):
        super().__init__(score, *args, **kwargs)
        self.transition_buffer = TransitionBuffer(self, transition_buffer_size)
        self.sampler = sampler.to(self.device)
        self.wrapper = sampler_wrapper(self.sampler)
        self.sampler_steps = sampler_steps
        self.n_sampler = n_sampler
        self.transition_buffer_loader = lambda x: DataLoader(
            x, batch_size=2 * self.batch_size, shuffle=True, drop_last=True)

        if sampler_optimizer_kwargs is None:
            sampler_optimizer_kwargs = {"lr": 5e-4}

        self.sampler_optimizer = sampler_optimizer(sampler.parameters(),
                                                   **sampler_optimizer_kwargs)
Example #11
 def __init__(self, data_set, batch_size=1, device="cpu", **kwargs):
     self.data = data_set
     self.device = device
     self.loader = DataLoader(data_set,
                              batch_size=batch_size,
                              drop_last=True,
                              sampler=InfiniteSampler(data_set),
                              **kwargs)
     self.iter = iter(self.loader)
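InfiniteSampler is assumed to yield indices without ever raising StopIteration, so the single iterator above can be drawn from indefinitely. A minimal sketch of such a sampler (an assumption, not the library's actual code):

import torch
from torch.utils.data import Sampler

class InfiniteSampler(Sampler):
    def __init__(self, data_source):
        self.size = len(data_source)

    def __iter__(self):
        # reshuffle every time the dataset has been cycled through
        while True:
            yield from torch.randperm(self.size).tolist()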
Example #12
    def __init__(self,
                 net,
                 train_data,
                 validate_data,
                 losses,
                 optimizer=torch.optim.Adam,
                 schedule=None,
                 max_epochs=50,
                 batch_size=128,
                 device="cpu",
                 network_name="network",
                 path_prefix=".",
                 report_interval=10,
                 checkpoint_interval=1000,
                 valid_callback=lambda x: None):
        super(FewShotTraining,
              self).__init__(net,
                             train_data,
                             validate_data,
                             losses,
                             optimizer=optimizer,
                             schedule=schedule,
                             max_epochs=max_epochs,
                             batch_size=batch_size,
                             device=device,
                             network_name=network_name,
                             path_prefix=path_prefix,
                             report_interval=report_interval,
                             checkpoint_interval=checkpoint_interval,
                             valid_callback=valid_callback)

        # keep an unmodified copy for building the support set before
        # switching the training set's data mode
        support_data = copy(train_data)
        train_data.data_mode = type(train_data.data_mode)(1)
        support_data = SupportData(support_data, shots=5)
        validate_support_data = SupportData(validate_data, shots=5)
        self.support_loader = iter(DataLoader(support_data))
        self.valid_support_loader = iter(DataLoader(validate_support_data))
Example #13
  def __init__(self, networks, data, valid=None,
               optimizer=torch.optim.Adam,
               optimizer_kwargs=None,
               gradient_clip=200.0,
               gradient_skip=400.0,
               **kwargs):
    """Generic training setup for variational autoencoders.

    Args:
      networks (list): networks used in the training step.
      data (Dataset): provider of training data.
      valid (Dataset): provider of validation data.
      optimizer (Optimizer): optimizer class for gradient descent.
      optimizer_kwargs (dict): keyword arguments for the
        optimizer used in network training.
      gradient_clip (float): norm at which gradients are clipped.
      gradient_skip (float): gradient norm above which an
        optimization step is skipped entirely.
      max_epochs (int): maximum number of training epochs.
      batch_size (int): number of training samples per batch.
      device (string): device to use for training.
      network_name (string): identifier of the network architecture.
      verbose (bool): log all events and losses?
    """
    super(AbstractVAETraining, self).__init__(**kwargs)

    self.data = data
    self.valid = valid
    self.train_data = None
    self.valid_data = None
    self.gradient_clip = gradient_clip
    self.gradient_skip = gradient_skip

    self.valid_iter = None
    if self.valid is not None:
      self.valid_data = DataLoader(
        self.valid, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      self.valid_iter = iter(self.valid_data)

    self.network_names, netlist = self.collect_netlist(networks)

    if optimizer_kwargs is None:
      optimizer_kwargs = {"lr" : 5e-4}

    self.optimizer = optimizer(
      netlist,
      **optimizer_kwargs
    )
    self.checkpoint_names.update(
      self.get_netlist(self.network_names)
    )
Example #14
 def get_data(self, step):
     if self.data[step] is None:
         return None
     data = self.loaders[step]
     if data is None:
         data = iter(
             DataLoader(self.data[step],
                        batch_size=self.batch_size,
                        num_workers=self.num_workers,
                        shuffle=True,
                        drop_last=True))
         self.loaders[step] = data
     try:
         data_point = to_device(next(data), self.device)
     except StopIteration:
         data = iter(
             DataLoader(self.data[step],
                        batch_size=self.batch_size,
                        num_workers=self.num_workers,
                        shuffle=True,
                        drop_last=True))
         self.loaders[step] = data
         data_point = to_device(next(data), self.device)
     return data_point
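get_data above assumes self.data maps step names to datasets and self.loaders caches one iterator per step. A hypothetical initialization consistent with that (names are illustrative only):

 def __init__(self, datasets, batch_size=128, num_workers=8, device="cpu"):
     self.data = datasets                       # e.g. {"real": real_set}
     self.loaders = {step: None for step in datasets}
     self.batch_size = batch_size
     self.num_workers = num_workers
     self.device = device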
Example #15
 def embed_all(self):
     self.net.eval()
     with torch.no_grad():
         embedding = []
         batch_loader = DataLoader(self.data,
                                   batch_size=self.batch_size,
                                   shuffle=False)
         # embed at most ~5000 samples to keep the pass cheap
         for point, *_ in islice(batch_loader, 5000 // self.batch_size):
             latent_point = self.net(point.to(self.device))
             latent_point = latent_point.to("cpu")
             latent_point = latent_point.reshape(latent_point.size(0), -1)
             embedding.append(latent_point)
         embedding = torch.cat(embedding, dim=0)
     self.net.train()
     return embedding
Example #16
  def train(self):
    """Runs contrastive training until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=self.num_workers,
        shuffle=True, drop_last=True
      )

      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    return self.get_netlist(self.names)
Example #17
  def train(self):
    """Trains a VAE until the maximum number of epochs is reached."""
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8,
        shuffle=True
      )
      for data in self.train_data:
        self.step(data)
        self.log()
        self.step_id += 1

    netlist = self.get_netlist(self.network_names)

    return netlist
Example #18
    def train(self):
        for epoch_id in range(self.max_epochs):
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            for internal_epoch in range(1):
                for data, *_ in islice(self.train_data, 100):
                    self.step(data)
                    self.log()
                    self.step_id += 1

            self.each_cluster()
            self.alpha *= float(np.power(2.0, (-(np.log(epoch_id + 1)**2))))
            self.epoch_id += 1

        return self.net
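The annealing line above multiplies alpha by 2^(-(ln(epoch_id + 1))^2) after each epoch. A quick numeric check of the factor (since ln 1 = 0, the first epoch leaves alpha unchanged, and later epochs shrink it progressively faster):

import numpy as np

for epoch_id in range(4):
    factor = float(np.power(2.0, -(np.log(epoch_id + 1) ** 2)))
    print(epoch_id, round(factor, 3))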
Example #19
    def train(self):
        """Trains an EBM until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=self.num_workers,
                                         shuffle=True,
                                         drop_last=True)

            for data in self.train_data:
                self.step(data)
                self.log()
                self.step_id += 1
            self.epoch_id += 1

        scores = [getattr(self, name) for name in self.names]

        return scores
Example #20
    def __init__(self,
                 score,
                 critic,
                 *args,
                 decay=0.1,
                 integrator=None,
                 optimizer=torch.optim.Adam,
                 n_critic=5,
                 critic_optimizer_kwargs=None,
                 **kwargs):
        self.score = ...
        self.critic = ...
        super().__init__({"score": score},
                         *args,
                         optimizer=optimizer,
                         **kwargs)

        critics = {"critic": critic}
        netlist = []
        self.critic_names = []
        for network in critics:
            self.critic_names.append(network)
            network_object = critics[network].to(self.device)
            setattr(self, network, network_object)
            netlist.extend(list(network_object.parameters()))

        if critic_optimizer_kwargs is None:
            critic_optimizer_kwargs = {"lr": 5e-4}

        self.critic_data = DataLoader(self.data,
                                      batch_size=self.batch_size,
                                      num_workers=8,
                                      shuffle=True,
                                      drop_last=True)
        self.critic_optimizer = optimizer(netlist, **critic_optimizer_kwargs)

        self.n_critic = n_critic
        self.decay = decay
        self.integrator = integrator
        self.checkpoint_names.update(
            {name: getattr(self, name)
             for name in self.critic_names})
Example #21
  def train(self):
    for epoch_id in range(self.max_epochs):
      self.epoch_id = epoch_id
      embedding = self.embed_all()
      weights, labels, centers = self.cluster(embedding)

      self.each_cluster(embedding, labels)

      self.data.labels = labels
      self.train_data = None
      self.train_data = DataLoader(
        self.data, batch_size=self.batch_size, num_workers=8,
        sampler=WeightedRandomSampler(weights, len(self.data) * 4, replacement=True)
      )
      for data, label in self.train_data:
        self.step(data, label, centers)
        self.log()
        self.step_id += 1

    return self.net
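cluster() above is not part of the excerpt; a sketch of what it could return, assuming a scikit-learn k-means backend. The weights draw each sample inversely to its cluster's size so the WeightedRandomSampler yields roughly balanced batches:

import numpy as np
import torch
from sklearn.cluster import KMeans

def cluster(embedding, n_clusters=10):
    km = KMeans(n_clusters=n_clusters).fit(embedding.numpy())
    labels = km.labels_
    counts = np.bincount(labels, minlength=n_clusters)
    # inverse-frequency weights -> roughly balanced batches
    weights = (1.0 / counts[labels]).tolist()
    return weights, labels, torch.from_numpy(km.cluster_centers_)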
Example #22
 def aggressive_update(self, data):
     inner_data = DataLoader(self.data,
                             batch_size=self.batch_size,
                             num_workers=8,
                             shuffle=True)
     # freeze the decoder for the aggressive (encoder-only) inner loop;
     # iterate over parameters, not submodules
     for parameter in self.decoder.parameters():
         parameter.requires_grad = False
     last_ten = [None] * 10
     for idx, data_p in enumerate(inner_data):
         loss = self.step(data_p)
         last_ten[idx % 10] = loss
         # stop once the loss has plateaued over a ten-step window
         if last_ten[-1] is not None and last_ten[-1] >= last_ten[0]:
             break
     for parameter in self.decoder.parameters():
         parameter.requires_grad = True
     # then take a single decoder-only step on the outer batch
     for parameter in self.encoder.parameters():
         parameter.requires_grad = False
     self.step(data)
     for parameter in self.encoder.parameters():
         parameter.requires_grad = True
Example #23
    def expectation(self):
        self.net.eval()
        with torch.no_grad():
            embedding = []
            batch_loader = DataLoader(self.data,
                                      batch_size=self.batch_size,
                                      shuffle=False)
            for point, *_ in batch_loader:
                features, mean, logvar = self.net(point.to(self.device))
                std = torch.exp(0.5 * logvar)
                sample = torch.randn_like(std).mul(std).add_(mean)
                latent_point = func.adaptive_avg_pool2d(sample, 1)

                latent_point = latent_point.reshape(latent_point.size(0), -1)
                embedding.append(latent_point)
            embedding = torch.cat(embedding, dim=0)
            expectation = self.classifier(embedding)
        self.net.train()
        return expectation.to("cpu"), embedding.to("cpu")
Example #24
 def __init__(self,
              score,
              *args,
              buffer_size=100,
              buffer_probability=0.9,
              sample_steps=10,
              decay=1,
              integrator=None,
              oos_penalty=True,
              **kwargs):
     self.score = ...
     super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
     self.oos_penalty = oos_penalty
     self.decay = decay
     self.integrator = integrator if integrator is not None else Langevin()
     self.sample_steps = sample_steps
     self.buffer = SampleBuffer(self,
                                buffer_size=buffer_size,
                                buffer_probability=buffer_probability)
     self.buffer_loader = lambda x: DataLoader(
         x, batch_size=self.batch_size, shuffle=True, drop_last=True)