Example #1
0
 def train(self, model: Seq2Seq, discriminator: Discriminator,
           src_file_names: List[str], tgt_file_names: List[str],
           unsupervised_big_epochs: int, print_every: int, save_every: int,
           num_words_in_batch: int, max_length: int, teacher_forcing: bool,
           save_file: str="model", n_unsupervised_batches: int=None,
           enable_unsupervised_backtranslation: bool=False):
     """Run unsupervised adversarial training for several "big epochs".

     Each big epoch builds fresh source/target batch generators, trains on
     zipped (src, tgt) batch pairs via ``self.train_batch``, periodically
     checkpoints and logs averaged losses, and saves a checkpoint at the end
     of the big epoch. Optimizers are created lazily on first use.

     :param model: seq2seq model being trained; ``self.main_optimizer`` is
         built over (and steps) this object's trainable parameters.
     :param discriminator: adversarial discriminator.
     :param src_file_names: files read for the "src" language.
     :param tgt_file_names: files read for the "tgt" language.
     :param unsupervised_big_epochs: number of big epochs to run.
     :param print_every: log averaged losses every this many batches.
     :param save_every: checkpoint every this many batches.
     :param num_words_in_batch: batch budget passed to ``BatchGenerator``.
     :param max_length: ``max_len`` passed to ``BatchGenerator``.
     :param teacher_forcing: forwarded to ``self.train_batch``.
     :param save_file: checkpoint path prefix; ``".pt"`` is appended.
     :param n_unsupervised_batches: optional ``max_batch_count`` per generator.
     :param enable_unsupervised_backtranslation: if true, after every big
         epoch ``self.current_translation_model`` is set to a ``Translator``
         over a deep-copied snapshot of the model.
     """
     if self.main_optimizer is None or self.discriminator_optimizer is None:
         logger.info("Initializing optimizers...")
         self.main_optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                          lr=self.main_lr, betas=self.main_betas)
         self.discriminator_optimizer = optim.RMSprop(discriminator.parameters(), lr=self.discriminator_lr)
     for big_epoch in range(unsupervised_big_epochs):
         src_batch_gen = BatchGenerator(src_file_names, num_words_in_batch, max_len=max_length,
                                        vocabulary=self.vocabulary, language="src",
                                        max_batch_count=n_unsupervised_batches)
         tgt_batch_gen = BatchGenerator(tgt_file_names, num_words_in_batch, max_len=max_length,
                                        vocabulary=self.vocabulary, language="tgt",
                                        max_batch_count=n_unsupervised_batches)
         # Peek at one batch of each side for debugging.
         # NOTE(review): assumes iter() on a BatchGenerator starts a fresh
         # pass, so this does not consume a training batch — confirm.
         logger.debug("Src batch:" + str(next(iter(src_batch_gen))))
         logger.debug("Tgt batch:" + str(next(iter(tgt_batch_gen))))
         timer = time.time()
         main_loss_total = 0
         discriminator_loss_total = 0
         window_batches = 0  # batches summed since the last progress log
         # `epoch` is the batch index within this big epoch.
         for epoch, (src_batch, tgt_batch) in enumerate(zip(src_batch_gen, tgt_batch_gen)):
             model.train()
             discriminator_loss, losses = self.train_batch(model, discriminator, src_batch,
                                                           tgt_batch, teacher_forcing)
             main_loss = sum(losses)
             main_loss_total += main_loss
             discriminator_loss_total += discriminator_loss
             window_batches += 1
             if epoch % save_every == 0 and epoch != 0:
                 save_model(model, discriminator, self.main_optimizer,
                            self.discriminator_optimizer, save_file + ".pt")
             if epoch % print_every == 0 and epoch != 0:
                 # Average over the batches actually summed: the first window
                 # covers batches 0..print_every, i.e. print_every + 1 batches.
                 main_loss_avg = main_loss_total / window_batches
                 discriminator_loss_avg = discriminator_loss_total / window_batches
                 main_loss_total = 0
                 discriminator_loss_total = 0
                 window_batches = 0
                 diff = time.time() - timer
                 timer = time.time()
                 translator = Translator(model, self.vocabulary, self.use_cuda)
                 logger.debug("Auto: " + translator.translate_sentence("you can prepare your meals here .",
                                                                       "src", "src"))
                 logger.debug("Translated: " + translator.translate_sentence("you can prepare your meals here .",
                                                                             "src", "tgt"))
                 logger.info('%s big epoch, %s epoch, %s sec, %.4f main loss, '
                              '%.4f discriminator loss, current losses: %s' %
                              (big_epoch, epoch, diff, main_loss_avg, discriminator_loss_avg, losses))
         save_model(model, discriminator, self.main_optimizer,
                    self.discriminator_optimizer, save_file + ".pt")
         if enable_unsupervised_backtranslation:
             # Translate with a deep-copied snapshot and keep training `model`
             # itself. self.main_optimizer holds references to `model`'s
             # parameters, so rebinding `model` to a copy would leave the
             # network being trained with no optimizer updating it.
             self.current_translation_model = Translator(copy.deepcopy(model), self.vocabulary,
                                                         self.use_cuda)
Example #2
0
def init_optimizers(model: Seq2Seq,
                    discriminator: Discriminator,
                    discriminator_lr=0.0005,
                    main_lr=0.0003,
                    main_betas=(0.5, 0.999)):
    """Create the optimizers for the seq2seq model and the discriminator.

    Args:
        model: the seq2seq model; only parameters with ``requires_grad`` set
            are handed to its optimizer.
        discriminator: the adversarial discriminator; all of its parameters
            are optimized.
        discriminator_lr: RMSprop learning rate for the discriminator.
        main_lr: Adam learning rate for the model.
        main_betas: Adam ``betas`` for the model.

    Returns:
        A ``(main_optimizer, discriminator_optimizer)`` tuple: Adam over the
        model's trainable parameters, RMSprop over the discriminator.
    """
    logging.info("Initializing optimizers...")
    trainable_params = (param for param in model.parameters() if param.requires_grad)
    return (
        optim.Adam(trainable_params, lr=main_lr, betas=main_betas),
        optim.RMSprop(discriminator.parameters(), lr=discriminator_lr),
    )
Example #3
0
    def discriminator_step(self, model: Seq2Seq, discriminator: Discriminator,
                           input_batches: Dict[str, Batch], adv_targets: Dict[str, Variable]):
        """Run one optimization step of the discriminator.

        The model's encoder is run in eval mode on every batch in
        ``input_batches``. The discriminator loss against the matching entry
        of ``adv_targets`` is summed over all keys and back-propagated. The
        discriminator gradients are clipped to norm 5 and then applied with
        ``self.discriminator_optimizer``.

        :param model: seq2seq model whose encoder produces the features.
        :param discriminator: discriminator being trained.
        :param input_batches: batches keyed by name; each must expose
            ``.variable`` and ``.lengths``.
        :param adv_targets: discriminator targets, keyed like ``input_batches``.
        :return: the summed discriminator loss as a Python number.
        """
        discriminator.train()
        model.eval()
        self.discriminator_optimizer.zero_grad()
        adv_loss_computer = DiscriminatorLossCompute(discriminator)

        losses = []
        for key, input_batch in input_batches.items():
            encoder_output, _ = model.encoder(input_batch.variable, input_batch.lengths)
            # Only the discriminator is stepped here, so cut the graph at the
            # encoder output. This avoids back-propagating into the encoder
            # and leaving stale gradients on its parameters.
            losses.append(adv_loss_computer.compute(encoder_output.detach(), adv_targets[key]))

        discriminator_loss = sum(losses)
        discriminator_loss.backward()
        # `clip_grad_norm` (no underscore) is deprecated in favour of the
        # in-place `clip_grad_norm_`.
        nn.utils.clip_grad_norm_(discriminator.parameters(), 5)
        self.discriminator_optimizer.step()
        # `.data[0]` raises IndexError on 0-dim loss tensors (PyTorch >= 0.4);
        # `.item()` works for any single-element tensor.
        return discriminator_loss.item()
Example #4
0
def test():
    """Smoke-test the Notebook training API.

    Loads mel spectrograms from ``data/test`` and builds a generator and a
    discriminator with one default Adam optimizer each, plus the loss and
    optimizer callbacks. It then runs ``MelGANRunner.train`` with
    ``check=True``.
    """
    train_loader = torch.utils.data.DataLoader(MelFromDisk(path="data/test"))
    loaders = OrderedDict({"train": train_loader})

    generator, discriminator = Generator(80), Discriminator()
    networks = torch.nn.ModuleDict(
        {"generator": generator, "discriminator": discriminator}
    )

    # One optimizer per network; the callbacks below refer to them by key.
    optimizers = {
        "opt_g": torch.optim.Adam(generator.parameters()),
        "opt_d": torch.optim.Adam(discriminator.parameters()),
    }
    callbacks = {
        "loss_g": GeneratorLossCallback(),
        "loss_d": DiscriminatorLossCallback(),
        "o_g": dl.OptimizerCallback(metric_key="generator_loss", optimizer_key="opt_g"),
        "o_d": dl.OptimizerCallback(metric_key="discriminator_loss", optimizer_key="opt_d"),
    }

    MelGANRunner().train(
        model=networks,
        loaders=loaders,
        optimizer=optimizers,
        callbacks=callbacks,
        check=True,
        main_metric="discriminator_loss",
    )
Example #5
0
def train(
    max_int: int = 128,
    batch_size: int = 16,
    training_steps: int = 500,
    learning_rate: float = 0.001,
    print_output_every_n_steps: int = 10,
):
    """Trains the even GAN

    Args:
        max_int: The maximum integer our dataset goes to.  It is used to set the size of the binary
            lists
        batch_size: The number of examples in a training batch
        training_steps: The number of steps to train on.
        learning_rate: The learning rate for the generator and discriminator
        print_output_every_n_steps: The number of training steps before we print generated output

    Returns:
        generator: The trained generator model
        discriminator: The trained discriminator model
        history: dict holding the per-step losses under ``'dis_loss'`` and
            ``'gen_loss'``
    """
    input_length = int(math.log(max_int, 2))

    # Models
    generator = Generator(input_length)
    discriminator = Discriminator(input_length)

    # Optimizers: both use the caller-supplied learning rate.
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                               lr=learning_rate)

    # loss
    loss = nn.BCELoss()
    gen_loss = []
    dis_loss = []

    for i in range(training_steps):
        # zero the gradients on each iteration
        generator_optimizer.zero_grad()

        # Create noisy input for generator
        # Need float type instead of int
        noise = torch.randint(0, 2, size=(batch_size, input_length)).float()
        generated_data = generator(noise)

        # Generate examples of even real data
        # true labels: [1,1,1,1,1,1,....] i.e all ones
        # true data: [[0,0,0,0,1,0,0],....] i.e binary code for even numbers
        true_labels, true_data = generate_even_data(max_int,
                                                    batch_size=batch_size)
        true_labels = torch.tensor(true_labels).float()
        true_data = torch.tensor(true_data).float()

        # Train the generator
        # We invert the labels here and don't train the discriminator because we want the generator
        # to make things the discriminator classifies as true.
        # true labels: [1,1,1,1,....]
        # Discriminator outputs are flattened with reshape(-1) rather than
        # squeeze(): with batch_size == 1, squeeze() would yield a 0-dim tensor
        # that BCELoss rejects against the 1-element label vector.
        # (Output presumably has shape (batch_size, 1) — confirm.)
        discriminator_out_gen_data = discriminator(generated_data).reshape(-1)
        generator_loss = loss(discriminator_out_gen_data, true_labels)
        gen_loss.append(generator_loss.item())
        generator_loss.backward()
        generator_optimizer.step()

        # Train the discriminator
        # Teach Discriminator to distinguish true data with true label i.e [1,1,1,1,....]
        discriminator_optimizer.zero_grad()
        discriminator_out_true_data = discriminator(true_data).reshape(-1)
        discriminator_loss_true_data = loss(discriminator_out_true_data, true_labels)

        # Detach so the discriminator update does not back-propagate into the generator.
        discriminator_out_fake_data = discriminator(generated_data.detach()).reshape(-1)
        fake_labels = torch.zeros(batch_size)  # [0,0,0,.....]
        discriminator_loss_fake_data = loss(discriminator_out_fake_data, fake_labels)
        # total discriminator loss
        discriminator_loss = (discriminator_loss_true_data +
                              discriminator_loss_fake_data) / 2

        dis_loss.append(discriminator_loss.item())

        discriminator_loss.backward()
        discriminator_optimizer.step()
        if i % print_output_every_n_steps == 0:
            output = convert_float_matrix_to_int_list(generated_data)
            even_count = sum(1 for value in output if value % 2 == 0)
            print(
                f"steps: {i}, output: {output}, even count: {even_count}/{batch_size}, Gen Loss: {np.round(generator_loss.item(),4)}, Dis Loss: {np.round(discriminator_loss.item(),4)}"
            )

    history = {'dis_loss': dis_loss, 'gen_loss': gen_loss}

    return generator, discriminator, history
# NOTE(review): script fragment. netD_B, netD_A, netG_A2B, netG_B2A, opt,
# weights_init_normal, itertools and the bare `LambdaLR` helper are all
# defined/imported earlier in the script (not visible here).
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()  # MSE used as the adversarial (GAN) criterion
criterion_cycle = torch.nn.L1Loss()  # L1 criterion for the cycle term
criterion_identity = torch.nn.L1Loss()  # L1 criterion for the identity term

## Extra classification criterion. BCELoss expects probabilities in [0, 1]
## (i.e. already-sigmoided outputs).
criterion_classification = torch.nn.BCELoss()

# Optimizers & LR schedulers
# A single Adam optimizer updates both generators (their parameter iterators
# are chained). Each discriminator gets its own Adam. All use lr=opt.lr and
# betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))

# torch's LambdaLR scheduler multiplies each optimizer's initial lr by
# lr_lambda(epoch). The bare `LambdaLR(...)` here is presumably a project helper
# (not torch's class); its bound `.step` method serves as that multiplier,
# parameterized by opt.n_epochs, opt.epoch and opt.decay_epoch. TODO confirm
# its schedule shape.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(
                                                       opt.n_epochs, opt.epoch,
                                                       opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A,
    lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B,
# NOTE(review): script fragment. opt, nz, nc, ngf, ndf, ngpu, device,
# weights_init and truncated_noise_sample are defined earlier in the script.
generator = Generator(nz, nc, ngf, opt.imageSize, ngpu).to(device)
generator.apply(weights_init)
if opt.generator != "":
    # Resume from a checkpoint. map_location remaps stored tensors onto
    # `device`, so a checkpoint saved on a GPU also loads on a CPU-only run.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    generator.load_state_dict(torch.load(opt.generator, map_location=device))
print(generator)

discriminator = Discriminator(nc, ndf, opt.imageSize, ngpu).to(device)
discriminator.apply(weights_init)
if opt.discriminator != "":
    # Same remapping / trust caveat as for the generator checkpoint above.
    discriminator.load_state_dict(torch.load(opt.discriminator, map_location=device))
print(discriminator)

# setup optimizer: separate learning rates for D and G, shared beta1.
optimizerD = optim.Adam(
    discriminator.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
)
optimizerG = optim.Adam(generator.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))

# 64 truncated noise vectors reshaped to (64, nz, 1, 1). Built once, so
# presumably reused for comparable sample snapshots across training.
fixed_noise = (
    torch.from_numpy(truncated_noise_sample(batch_size=64, dim_z=nz, truncation=0.4))
    .view(64, nz, 1, 1)
    .to(device)
)
# Soft target for real images (0.9 instead of 1.0); hard 0 for fakes.
real_label = 0.9
fake_label = 0

criterion = nn.BCELoss()

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):