Example #1
def train(args):
    """
    Training of the algorithm
    """
    logger = logging.getLogger(__name__)
    logger.info("Training")

    # Parameters
    dataset = args.dataset
    dataset_root = args.dataset_root
    nthreads = args.nthreads
    batch_size = args.batch_size
    dropout = args.dropout
    debug = args.debug
    base_lr = args.base_lr
    num_epochs = args.num_epochs
    discriminator_base_c = args.discriminator_base_c
    generator_base_c = args.generator_base_c
    latent_size = args.latent_size
    sample_nrows = 8
    sample_ncols = 8

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # Dataloaders
    train_loader, valid_loader, img_shape = data.get_dataloaders(
        dataset_root=dataset_root,
        cuda=use_cuda,
        batch_size=batch_size,
        n_threads=nthreads,
        dataset=dataset,
        small_experiment=debug)

    # Model definition
    model = models.GAN(img_shape, dropout, discriminator_base_c, latent_size,
                       generator_base_c)
    model.to(device)

    # Optimizers
    critic = model.discriminator
    generator = model.generator

    # Step 1 - Define the optimizer for the critic
    optim_critic = torch.optim.Adam(critic.parameters(), lr=base_lr)
    # Step 2 - Define the optimizer for the generator
    optim_generator = torch.optim.Adam(generator.parameters(), lr=base_lr / 2)

    # Step 3 - Define the loss (it must embed the sigmoid)
    loss = torch.nn.BCEWithLogitsLoss()
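    # (BCEWithLogitsLoss fuses the sigmoid with the binary cross entropy in a
    # numerically stable way, so both networks must output raw logits; the
    # sigmoid is only applied explicitly below to turn logits into
    # probabilities for the accuracy computation.)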

    # Callbacks
    summary_text = "## Summary of the model architecture\n" + \
                   f"{deepcs.display.torch_summarize(model)}\n"
    summary_text += "\n\n## Executed command :\n" + \
                    "{}".format(" ".join(sys.argv))
    summary_text += "\n\n## Args : \n {}".format(args)

    logger.info(summary_text)

    logdir = generate_unique_logpath('./logs', 'gan')
    tensorboard_writer = SummaryWriter(log_dir=logdir, flush_secs=5)
    tensorboard_writer.add_text("Experiment summary",
                                deepcs.display.htmlize(summary_text))

    with open(os.path.join(logdir, "summary.txt"), 'w') as f:
        f.write(summary_text)

    save_path = os.path.join(logdir, 'generator.pt')

    logger.info(f">>>>> Results saved in {logdir}")

    # Define a fixed noise used for sampling
    fixed_noise = torch.randn(sample_nrows * sample_ncols,
                              latent_size).to(device)

    # Generate a few samples from the initial generator
    model.eval()
    with torch.no_grad():
        fake_images = model.generator(X=fixed_noise)
    grid = torchvision.utils.make_grid(fake_images,
                                       nrow=sample_nrows,
                                       normalize=True)
    tensorboard_writer.add_image("Generated", grid, 0)
    os.makedirs('images', exist_ok=True)
    torchvision.utils.save_image(grid, 'images/images-0000.png')

    # Training loop
    for e in range(num_epochs):

        tot_closs = tot_gloss = 0
        critic_accuracy = 0
        Nc = Ng = 0
        model.train()
        for ei, (X, _) in enumerate(tqdm.tqdm(train_loader)):
            # X is a batch of real data
            X = X.to(device)
            bi = X.shape[0]

            pos_labels = torch.ones((bi, )).to(device)
            neg_labels = torch.zeros((bi, )).to(device)

            # Step 1 - Forward pass for training the discriminator
            real_logits, _ = model(X, bi)
            fake_logits, _ = model(None, bi)

            # Step 2 - Compute the loss of the critic
            Dloss = (loss(real_logits, pos_labels) +
                     loss(fake_logits, neg_labels)) / 2

            # Step 3 - Reinitialize the gradient accumulator of the critic
            optim_critic.zero_grad()

            # Step 4 - Perform the backward pass on the loss
            Dloss.backward()

            # Step 5 - Update the parameters of the critic
            optim_critic.step()

            real_probs = torch.sigmoid(real_logits)
            fake_probs = torch.sigmoid(fake_logits)
            critic_accuracy += (real_probs > 0.5).sum().item() + (
                fake_probs < 0.5).sum().item()
            dloss_e = Dloss.item()

            # Step 1 - Forward pass for training the generator
            fake_logits, _ = model(None, bi)

            # Step 2 - Compute the loss of the generator
            # The generator wants its fake images to be classified as real
            Gloss = loss(fake_logits, pos_labels)

            # Step 3 - Reinitialize the gradient accumulator of the generator
            optim_generator.zero_grad()

            # Step 4 - Perform the backward pass on the loss
            Gloss.backward()

            # Step 5 - Update the parameters of the generator
            optim_generator.step()

            gloss_e = Gloss.item()

            Nc += 2 * bi
            tot_closs += 2 * bi * dloss_e
            Ng += bi
            tot_gloss += bi * gloss_e

        critic_accuracy /= Nc
        tot_closs /= Nc
        tot_gloss /= Ng
        logger.info(
            f"[Epoch {e + 1}] C loss : {tot_closs:.4f} ; "
            f"C accuracy : {critic_accuracy:.4f} ; G loss : {tot_gloss:.4f}")

        tensorboard_writer.add_scalar("Critic loss", tot_closs, e + 1)
        tensorboard_writer.add_scalar("Critic accuracy", critic_accuracy,
                                      e + 1)
        tensorboard_writer.add_scalar("Generator loss", tot_gloss, e + 1)

        # Generate a few samples from the generator
        model.eval()
        with torch.no_grad():
            fake_images = model.generator(X=fixed_noise)
        # Unscale the images
        fake_images = fake_images * data._MNIST_STD + data._MNIST_MEAN
        grid = torchvision.utils.make_grid(fake_images,
                                           nrow=sample_nrows,
                                           normalize=True)
        tensorboard_writer.add_image("Generated", grid, e + 1)
        torchvision.utils.save_image(grid, f'images/images-{e + 1:04d}.png')

        # Unscale a batch of real images for comparison
        real_images = X[:sample_nrows * sample_ncols,
                        ...] * data._MNIST_STD + data._MNIST_MEAN
        grid = torchvision.utils.make_grid(real_images,
                                           nrow=sample_nrows,
                                           normalize=True)
        tensorboard_writer.add_image("Real", grid, e + 1)

        # Save the generator
        torch.save(model.generator, save_path)
        logger.info(f"Generator saved at {save_path}")
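
For reference, a minimal command-line entry point for this `train` function. The flag names mirror the attributes read from `args` at the top of the function; the default values are illustrative assumptions, not the original project's settings.

import argparse

# Defaults below are placeholders, not the original project's values
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='MNIST')
parser.add_argument('--dataset_root', default='./data')
parser.add_argument('--nthreads', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--base_lr', type=float, default=2e-4)
parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--discriminator_base_c', type=int, default=32)
parser.add_argument('--generator_base_c', type=int, default=64)
parser.add_argument('--latent_size', type=int, default=100)

if __name__ == '__main__':
    train(parser.parse_args())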
Example #2
                # rising phase
                return low_lr + 2.0 * dt * (high_lr - low_lr)
            else:
                return high_lr + 2.0 * (dt - 0.5) * (low_lr - high_lr)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cyclical_lr)
    else:
        #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,60,90,120,150], gamma=0.5)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=30,
                                                    gamma=0.5)

    # Callbacks

    ## Where to store the logs
    if not os.path.exists(args.logdir):
        os.mkdir(args.logdir)
    logdir = generate_unique_logpath(args.logdir, args.model)
    print(f"Logging to {logdir}")
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    # Display information about the model
    summary_text = f"""## Summary of the model architecture

{deepcs.display.torch_summarize(model, (batch_size, ) + input_dim)}

## Executed command

{' '.join(sys.argv)}
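
The fragment at the top of this example implements a triangular cyclical learning rate: the rate climbs linearly from a low to a high bound during the first half of each cycle and descends back during the second half. A self-contained sketch of the full schedule follows; the cycle length and the lr bounds are assumptions, since the excerpt does not show them. Note that `LambdaLR` multiplies the optimizer's base lr by the lambda's return value, so the base lr must be 1.0 for the returned values to be used as-is.

import torch

def make_cyclical_lr(period, low_lr=1e-4, high_lr=1e-2):
    # `period`, `low_lr` and `high_lr` are assumed values for illustration
    def cyclical_lr(step):
        # Position within the current cycle, in [0, 1)
        dt = (step % period) / period
        if dt < 0.5:
            # rising phase: low_lr -> high_lr over the first half cycle
            return low_lr + 2.0 * dt * (high_lr - low_lr)
        else:
            # falling phase: high_lr -> low_lr over the second half cycle
            return high_lr + 2.0 * (dt - 0.5) * (low_lr - high_lr)
    return cyclical_lr

model = torch.nn.Linear(10, 2)
# base lr of 1.0 so that LambdaLR passes the lambda's value through unchanged
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                              make_cyclical_lr(period=2000))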
Example #3
            nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, num_classes))
        self.classifier = nn.Sequential(conv_classifier, nn.Flatten(),
                                        fc_classifier)

    def forward(self, x):
        return self.classifier(x)


# Parameters
dataset_dir = os.path.join(os.path.expanduser("~"), 'Datasets', 'MNIST')
batch_size = 64
num_workers = 4
n_epochs = 30
learning_rate = 0.01
device = torch.device('cpu')
logdir = generate_unique_logpath('./logs', 'linear')

# Datasets
train_valid_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_valid_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=num_workers)
test_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor())
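
The listing is cut off at this point; a plausible continuation mirrors the train loader above. The `evaluate` helper is an illustrative sketch, not part of the original excerpt.

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=num_workers)

def evaluate(model, loader):
    # Fraction of correctly classified samples over a loader
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            correct += (model(X).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total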
Example #4
def train(args):
    """
    Training of the algorithm
    """
    logger = logging.getLogger(__name__)
    logger.info("Training")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # Data loading
    loaders = data.get_dataloaders(args.datasetroot,
                                   args.datasetversion,
                                   cuda=use_cuda,
                                   batch_size=args.batch_size,
                                   n_threads=args.nthreads,
                                   min_duration=args.min_duration,
                                   max_duration=args.max_duration,
                                   small_experiment=args.debug,
                                   train_augment=args.train_augment,
                                   nmels=args.nmels,
                                   logger=logger)
    train_loader, valid_loader, test_loader = loaders

    # Parameters
    n_mels = args.nmels
    nhidden_rnn = args.nhidden_rnn
    nlayers_rnn = args.nlayers_rnn
    cell_type = args.cell_type
    dropout = args.dropout
    base_lr = args.base_lr
    num_epochs = args.num_epochs
    grad_clip = args.grad_clip

    # We need the char map to know about the vocabulary size
    charmap = data.CharMap()
    blank_id = charmap.blankid

    # Model definition
    ###########################
    #### START CODING HERE ####
    ###########################
    model = None
    ##########################
    #### STOP CODING HERE ####
    ##########################

    decode = model.decode

    model.to(device)

    # Loss, optimizer
    baseloss = nn.CTCLoss(blank=blank_id, reduction='mean', zero_infinity=True)
    def loss(*params):
        return baseloss(*wrap_ctc_args(*params))

    ###########################
    #### START CODING HERE ####
    ###########################
    optimizer = None
    ##########################
    #### STOP CODING HERE ####
    ##########################
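    # A possible completion of the two exercise blocks above, for reference
    # only (a sketch: `models.CTCModel` is a hypothetical class name, and the
    # optimizer choice is an assumption, not the official solution):
    #
    #   model = models.CTCModel(charmap, n_mels, nhidden_rnn,
    #                           nlayers_rnn, cell_type, dropout)
    #   optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)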

    metrics = {'CTC': loss}

    # Callbacks
    summary_text = "## Summary of the model architecture\n" + \
            f"{deepcs.display.torch_summarize(model)}\n"
    summary_text += "\n\n## Executed command :\n" +\
        "{}".format(" ".join(sys.argv))
    summary_text += "\n\n## Args : \n {}".format(args)

    logger.info(summary_text)

    logdir = generate_unique_logpath('./logs', 'ctc')
    tensorboard_writer = SummaryWriter(log_dir=logdir, flush_secs=5)
    tensorboard_writer.add_text("Experiment summary",
                                deepcs.display.htmlize(summary_text))

    with open(os.path.join(logdir, "summary.txt"), 'w') as f:
        f.write(summary_text)

    logger.info(f">>>>> Results saved in {logdir}")

    model_checkpoint = ModelCheckpoint(model,
                                       os.path.join(logdir, 'best_model.pt'))
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Training loop
    for e in range(num_epochs):
        ftrain(model,
               train_loader,
               loss,
               optimizer,
               device,
               metrics,
               grad_clip=grad_clip,
               num_model_args=1,
               num_epoch=e,
               tensorboard_writer=tensorboard_writer)

        # Compute and record the metrics on the validation set
        valid_metrics = ftest(model,
                              valid_loader,
                              device,
                              metrics,
                              num_model_args=1)
        better_model = model_checkpoint.update(valid_metrics['CTC'])
        scheduler.step()

        logger.info("[%d/%d] Validation:   CTCLoss : %.3f %s" %
                    (e, num_epochs, valid_metrics['CTC'],
                     "[>> BETTER <<]" if better_model else ""))

        for m_name, m_value in valid_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/valid_{m_name}', m_value,
                                          e + 1)
        # Compute and record the metrics on the test set
        test_metrics = ftest(model,
                             test_loader,
                             device,
                             metrics,
                             num_model_args=1)
        logger.info("[%d/%d] Test:   Loss : %.3f " %
                    (e, num_epochs, test_metrics['CTC']))
        for m_name, m_value in test_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/test_{m_name}', m_value,
                                          e + 1)
        # Try to decode some of the validation samples
        model.eval()
        valid_decodings = decode_samples(decode,
                                         valid_loader,
                                         n=2,
                                         device=device,
                                         charmap=charmap)
        train_decodings = decode_samples(decode,
                                         train_loader,
                                         n=2,
                                         device=device,
                                         charmap=charmap)

        decoding_results = "## Decoding results on the training set\n"
        decoding_results += train_decodings
        decoding_results += "## Decoding results on the validation set\n"
        decoding_results += valid_decodings
        tensorboard_writer.add_text("Decodings",
                                    deepcs.display.htmlize(decoding_results),
                                    global_step=e + 1)
        logger.info("\n" + decoding_results)