Example 1
        disc.cuda()

    lr = 0.0001  # The celebrated learning rate.

    #   # Optimizer of the generator (Lookahead with Lamb).
    #   gbase = Lamb(genr.parameters(),
    #         lr=lr, weight_decay=0.01, betas=(.9, .999), adam=True)
    #   gopt = Lookahead(base_optimizer=gbase, k=5, alpha=0.8)

    # Optimizer of the generator (SGD).
    gopt = torch.optim.SGD(genr.parameters(), lr=10 * lr)
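    # Note that the two players do not train at the same speed: the
    # generator takes plain SGD steps at 10 times the base rate, while
    # the discriminator (below) gets the smoother Lamb + Lookahead
    # combination. Unbalanced update rules like this are a common way
    # to keep GAN training in rough equilibrium.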

    # Optimizer of the discriminator (Lookahead with Lamb).
    dbase = Lamb(disc.parameters(),
                 lr=lr,
                 weight_decay=0.01,
                 betas=(.9, .999),
                 adam=True)
    dopt = Lookahead(base_optimizer=dbase, k=5, alpha=0.8)
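    # Lookahead keeps a second, "slow" copy of the weights: after every
    # k = 5 Lamb steps, the slow copy moves a fraction alpha = 0.8
    # toward the fast weights, and the fast weights are reset to it.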

    # (Binary) cross-entropy loss.
    loss_fun = nn.CrossEntropyLoss(reduction='mean')
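    # On two classes, nn.CrossEntropyLoss computes exactly the binary
    # cross-entropy; the targets are class indices (real vs. fake).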

    for epoch in range(200):
        dloss = floss = 0.  # Cumulative losses for the epoch.
        for batch in data.batches():
            nsmpl = batch.shape[0]

            # Prevent NaNs in the log-likelihood.
      d_model = 256,     # Hidden dimension.
      d_ffn = 512,       # Boom dimension.
      nwrd = len(vocab)  # Output alphabet (protein).
   )
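   # "Boom" is Merity's (SHA-RNN) nickname for the feed-forward layer
   # that expands d_model up to d_ffn and projects it back down.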

   protdata = SeqData('proteins.txt', vocab)
   mask_symbol = len(vocab)-1 # The last symbol is the mask.

   # Do it with CUDA if possible.
   device = 'cuda' if torch.cuda.is_available() else 'cpu'
   if device == 'cuda': model.cuda()
   
   lr = 0.0001  # The celebrated learning rate.

   # Optimizer (Lookahead with Lamb).
   baseopt = Lamb(model.parameters(),
         lr=lr, weight_decay=0.01, betas=(.9, .999), adam=True)
   opt = Lookahead(base_optimizer=baseopt, k=5, alpha=0.8)
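   # Lamb is a layerwise-adaptive variant of Adam aimed at large-batch
   # training; in the widespread pytorch-lamb implementation, adam=True
   # pins the trust ratio to 1, reducing the update to plain Adam.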

   loss_fun = nn.CrossEntropyLoss(reduction='mean')

   nbtch = 0
   for epoch in range(20):
      epoch_loss = 0.
      for batch in protdata.batches():
         nbtch += 1

         # n random amino-acid codes, 1-20 ('random' must be imported).
         rnd = lambda n: [random.randint(1, 20) for _ in range(n)]

         # Choose symbols to guess (15%).
         guess_pos = (torch.rand(batch.shape) < 0.15) & (batch > 0)
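         # This is BERT-style masked-language training: about 15% of
         # the positions are drawn (batch > 0 keeps the padding symbol
         # 0 out of the draw), hidden from the model, and reconstructed
         # from the surrounding context.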
         # Record original symbols (targets).