Пример #1
0
 def configure_optimizers(self):
     """Create the discriminator and generator optimizers with schedulers.

     Two AdamW optimizers are built with the same learning rate
     (``self._cfg.lr``): the first over the parameters of
     ``multiscaledisc`` and ``multiperioddisc``, the second over the
     parameters of ``encoder``, ``generator`` and ``variance_adapter``.
     Each optimizer is paired with its own ``NoamAnnealing`` scheduler,
     registered with ``'interval': 'step'``.

     The scheduler horizon is ``ceil(num_samples / (num_procs * batch_size))``
     iterations per epoch times ``self._trainer.max_epochs``.

     Returns:
         A tuple ``([disc_opt, gen_opt], [disc_sched_dict, gen_sched_dict])``
         where each scheduler dict has the keys ``'scheduler'`` and
         ``'interval'``.
     """
     gen_params = chain(
         self.encoder.parameters(),
         self.generator.parameters(),
         self.variance_adapter.parameters(),
     )
     disc_params = chain(
         self.multiscaledisc.parameters(),
         self.multiperioddisc.parameters(),
     )
     opt1 = torch.optim.AdamW(disc_params, lr=self._cfg.lr)
     opt2 = torch.optim.AdamW(gen_params, lr=self._cfg.lr)

     # A trainer that reports zero GPUs (CPU-only run) would make the product
     # 0 and the division below fail, so assume at least one worker.
     num_procs = max(1, self._trainer.num_gpus * self._trainer.num_nodes)
     num_samples = len(self._train_dl.dataset)
     batch_size = self._train_dl.batch_size
     # np.ceil returns a float; a step count is a whole number, so keep it int.
     iter_per_epoch = int(np.ceil(num_samples / (num_procs * batch_size)))
     max_steps = iter_per_epoch * self._trainer.max_epochs
     logging.info(f"MAX STEPS: {max_steps}")

     def _noam_schedule(optimizer):
         # Both optimizers use identical annealing hyper-parameters.
         scheduler = NoamAnnealing(optimizer,
                                   d_model=256,
                                   warmup_steps=3000,
                                   max_steps=max_steps,
                                   min_lr=1e-5)
         return {
             'scheduler': scheduler,
             'interval': 'step',
         }

     # Discriminator first, generator second: the order of both lists matches.
     return [opt1, opt2], [_noam_schedule(opt1), _noam_schedule(opt2)]
Пример #2
0
# NODE Info
num_gpus = 1
num_nodes = 1

# DATASET Info
num_files = 10
"""
COMPUTATION
"""
# Scheduler horizon: files times epochs, divided across GPUs, nodes and the
# batch size, truncated to a whole number of steps.
steps = int(num_files * num_epochs / num_gpus / num_nodes / batch_size)
# Not passed to the scheduler below, which receives warmup_ratio directly.
warmup_steps = steps * warmup_ratio

optim = Adam([params], lr=initial_lr)
policy = NoamAnnealing(
    optim,
    max_steps=steps,
    warmup_ratio=warmup_ratio,
    d_model=d_model,
)

# Record what the policy reports before each scheduler step.
x = list(range(steps))
y = []
for step in x:
    y.append(policy.get_lr())
    policy.step()

y = np.asarray(y)
print("Num steps :", steps)
print("Peak step :", y.argmax(axis=0))
print("Peak LR :", y.max(axis=0))
print("Final LR :", y[-1])

# Label the curve with the scheduler's class name.
plt.plot(x, y, label='Policy={}'.format(policy.__class__.__name__))