Example #1
def visualize_image(decoder, dataloader, visualization_path, nb_images=1):
    """
    Writes the nifti files of images and their reconstructions by an autoencoder.

    Args:
        decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
        dataloader: (DataLoader) wrapper of the dataset.
        visualization_path: (str) directory in which the inputs and reconstructions will be stored.
        nb_images: (int) number of images to reconstruct.
    """
    import os

    import nibabel as nib
    import numpy as np
    from .iotools import check_and_clean

    check_and_clean(visualization_path)

    dataset = dataloader.dataset
    decoder.eval()

    for image_index in range(nb_images):
        data = dataset[image_index]
        image = data["image"].unsqueeze(0)
        output = decoder(image)

        output_np = output.squeeze(0).squeeze(0).cpu().detach().numpy()
        input_np = image.squeeze(0).squeeze(0).cpu().detach().numpy()
        output_nii = nib.Nifti1Image(output_np, np.eye(4))
        input_nii = nib.Nifti1Image(input_np, np.eye(4))
        nib.save(
            output_nii,
            os.path.join(visualization_path, 'output-%i.nii.gz' % image_index))
        nib.save(
            input_nii,
            os.path.join(visualization_path, 'input-%i.nii.gz' % image_index))
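
A minimal usage sketch: ToyVolumeDataset and toy_decoder below are hypothetical stand-ins for the project's dataset (whose items are dicts holding a 3D "image" tensor) and its Autoencoder class.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class ToyVolumeDataset(Dataset):
    """Hypothetical dataset whose items are dicts with a 3D 'image' tensor."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"image": torch.rand(1, 16, 16, 16)}  # (channels, depth, height, width)

toy_decoder = nn.Sequential(  # stand-in for an Autoencoder built from a CNN
    nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv3d(8, 1, 3, padding=1),
)
loader = DataLoader(ToyVolumeDataset(), batch_size=1)
# check_and_clean is imported from the package's iotools module, so the call
# below only works inside that package:
# visualize_image(toy_decoder, loader, "./visualization", nb_images=2)
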
Example #2
def generate_shepplogan_dataset(output_dir,
                                img_size,
                                labels_distribution,
                                samples=100,
                                smoothing=True):

    check_and_clean(join(output_dir, "subjects"))
    commandline_to_json({
        "output_dir": output_dir,
        "img_size": img_size,
        "labels_distribution": labels_distribution,
        "samples": samples,
        "smoothing": smoothing
    })
    columns = ["participant_id", "session_id", "diagnosis", "subtype"]
    data_df = pd.DataFrame(columns=columns)

    for i, label in enumerate(labels_distribution.keys()):
        for j in range(samples):
            participant_id = "sub-CLNC%i%04d" % (i, j)
            session_id = "ses-M00"
            subtype = np.random.choice(np.arange(
                len(labels_distribution[label])),
                                       p=labels_distribution[label])
            row_df = pd.DataFrame(
                [[participant_id, session_id, label, subtype]],
                columns=columns)
            data_df = data_df.append(row_df)

            # Image generation
            path_out = join(
                output_dir, "subjects", "%s_%s%s.pt" %
                (participant_id, session_id, FILENAME_TYPE["shepplogan"]))
            img = generate_shepplogan_phantom(img_size,
                                              label=subtype,
                                              smoothing=smoothing)
            torch_img = torch.from_numpy(img).float().unsqueeze(0)
            torch.save(torch_img, path_out)

    data_df.to_csv(join(output_dir, 'data.tsv'), sep="\t", index=False)

    missing_path = join(output_dir, "missing_mods")
    if not exists(missing_path):
        makedirs(missing_path)

    sessions = data_df.session_id.unique()
    for session in sessions:
        session_df = data_df[data_df.session_id == session]
        out_df = copy(session_df[["participant_id"]])
        out_df["t1w"] = [1] * len(out_df)
        out_df.to_csv(join(missing_path, "missing_mods_%s.tsv" % session),
                      sep="\t",
                      index=False)
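
A hypothetical call illustrating the expected shape of labels_distribution: it maps each diagnosis label to the subtype probabilities handed to np.random.choice, so each list must sum to 1. The label names, image size and sample count below are placeholders.

labels_distribution = {
    "AD": [0.05, 0.85, 0.10],  # assumed subtype proportions, for illustration only
    "CN": [1.00, 0.00, 0.00],
}
generate_shepplogan_dataset(output_dir="./shepplogan_dataset",
                            img_size=128,
                            labels_distribution=labels_distribution,
                            samples=50,
                            smoothing=True)
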
Example #3
def train(model,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          logger=None):
    """
    Function used to train a CNN.
    The best model and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the models weights and biases
        options: (Namespace) ensemble of other options given to the main script.
        logger: (logging object) writer to stdout and stderr
    """
    from tensorboardX import SummaryWriter
    from time import time

    if logger is None:
        logger = logging

    columns = [
        'epoch', 'iteration', 'time', 'balanced_accuracy_train', 'loss_train',
        'balanced_accuracy_valid', 'loss_valid'
    ]
    if hasattr(model, "variational") and model.variational:
        columns += ["kl_loss_train", "kl_loss_valid"]
    filename = os.path.join(os.path.dirname(log_dir), 'training.tsv')

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = -1.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the model to training mode
    train_loader.dataset.train()

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None
    t_beginning = time()

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        logger.info("Beginning epoch %i." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']

            if hasattr(model, "variational") and model.variational:
                z, mu, std, train_output = model(imgs)
                kl_loss = kl_divergence(z, mu, std)
                loss = criterion(train_output, labels) + kl_loss
            else:
                train_output = model(imgs)
                loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False

                    _, results_train = test(model, train_loader, options.gpu,
                                            criterion)
                    mean_loss_train = results_train["total_loss"] / (
                        len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model, valid_loader, options.gpu,
                                            criterion)
                    mean_loss_valid = results_valid["total_loss"] / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()
                    train_loader.dataset.train()

                    global_step = i + epoch * len(train_loader)
                    writer_train.add_scalar('balanced_accuracy',
                                            results_train["balanced_accuracy"],
                                            global_step)
                    writer_train.add_scalar('loss', mean_loss_train,
                                            global_step)
                    writer_valid.add_scalar('balanced_accuracy',
                                            results_valid["balanced_accuracy"],
                                            global_step)
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            global_step)
                    logger.info(
                        "%s level training accuracy is %f at the end of iteration %d"
                        %
                        (options.mode, results_train["balanced_accuracy"], i))
                    logger.info(
                        "%s level validation accuracy is %f at the end of iteration %d"
                        %
                        (options.mode, results_valid["balanced_accuracy"], i))

                    t_current = time() - t_beginning
                    row = [
                        epoch, i, t_current,
                        results_train["balanced_accuracy"], mean_loss_train,
                        results_valid["balanced_accuracy"], mean_loss_valid
                    ]
                    if hasattr(model, "variational") and model.variational:
                        row += [
                            results_train["total_kl_loss"] /
                            (len(train_loader) * train_loader.batch_size),
                            results_valid["total_kl_loss"] /
                            (len(valid_loader) * valid_loader.batch_size)
                        ]
                    row_df = pd.DataFrame([row], columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

            tend = time()
        logger.debug(
            'Mean time per batch loading: %.10f s' %
            (total_time / len(train_loader) * train_loader.batch_size))

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        logger.debug('Last checkpoint at the end of the epoch %d' % epoch)

        _, results_train = test(model, train_loader, options.gpu, criterion)
        mean_loss_train = results_train["total_loss"] / (
            len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model, valid_loader, options.gpu, criterion)
        mean_loss_valid = results_valid["total_loss"] / (
            len(valid_loader) * valid_loader.batch_size)
        model.train()
        train_loader.dataset.train()

        global_step = (epoch + 1) * len(train_loader)
        writer_train.add_scalar('balanced_accuracy',
                                results_train["balanced_accuracy"],
                                global_step)
        writer_train.add_scalar('loss', mean_loss_train, global_step)
        writer_valid.add_scalar('balanced_accuracy',
                                results_valid["balanced_accuracy"],
                                global_step)
        writer_valid.add_scalar('loss', mean_loss_valid, global_step)
        logger.info(
            "%s level training accuracy is %f at the end of iteration %d" %
            (options.mode, results_train["balanced_accuracy"],
             len(train_loader)))
        logger.info(
            "%s level validation accuracy is %f at the end of iteration %d" %
            (options.mode, results_valid["balanced_accuracy"],
             len(train_loader)))

        t_current = time() - t_beginning
        row = [
            epoch, i, t_current, results_train["balanced_accuracy"],
            mean_loss_train, results_valid["balanced_accuracy"],
            mean_loss_valid
        ]
        if hasattr(model, "variational") and model.variational:
            row += [
                results_train["total_kl_loss"] /
                (len(train_loader) * train_loader.batch_size),
                results_valid["total_kl_loss"] /
                (len(valid_loader) * valid_loader.batch_size)
            ]
        row_df = pd.DataFrame([row], columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        accuracy_is_best = results_valid[
            "balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"],
                                  best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': mean_loss_valid,
                'valid_acc': results_valid["balanced_accuracy"]
            }, accuracy_is_best, loss_is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
Example #4
def train(decoder, train_loader, valid_loader, criterion, optimizer, resume,
          log_dir, model_dir, options):
    """
    Function used to train an autoencoder.
    The best autoencoder will be found in the 'best_model_dir' of options.output_dir.

    Args:
        decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
        train_loader: (DataLoader) wrapper of the training dataset.
        valid_loader: (DataLoader) wrapper of the validation dataset.
        criterion: (loss) function to calculate the loss.
        optimizer: (torch.optim) optimizer linked to model parameters.
        resume: (bool) if True, a begun job is resumed.
        log_dir: (str) path to the folder containing the logs.
        model_dir: (str) path to the folder containing the models weights and biases.
        options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)
        options.beginning_epoch = 0

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    decoder.train()
    print(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    epoch = options.beginning_epoch

    print("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        print("At %d-th epoch." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / (len(train_loader) *
                                                    train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / (len(valid_loader) *
                                                    valid_loader.batch_size)
                    decoder.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Scan level validation loss is %f at the end of iteration %d"
                        % (loss_valid, i))

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        print('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / (len(train_loader) *
                                        train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / (len(valid_loader) *
                                        valid_loader.batch_size)
        decoder.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print("Scan level validation loss is %f at the end of iteration %d" %
              (loss_valid, i))

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'epoch': epoch,
                'valid_loss': loss_valid
            }, False, is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
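
This autoencoder variant is driven the same way, except that the criterion compares the reconstruction with its own input, so a reconstruction loss such as MSELoss is the natural choice. The paths, the options object and the pre-built decoder and loaders are hypothetical, as in the sketch of Example #3.

import torch.nn as nn
import torch.optim as optim

criterion = nn.MSELoss()  # compared against the input: criterion(train_output, imgs)
optimizer = optim.Adam(decoder.parameters(), lr=1e-4)
train(decoder, train_loader, valid_loader, criterion, optimizer,
      resume=False,
      log_dir="results/fold_0/log_dir/ConvAutoencoder",
      model_dir="results/fold_0/best_model_dir/ConvAutoencoder",
      options=options)
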
Example #5
def ae_finetuning(decoder, train_loader, valid_loader, criterion, optimizer,
                  resume, options):
    """
    Function used to train an autoencoder.
    The best autoencoder and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    :param decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
    :param train_loader: (DataLoader) wrapper of the training dataset.
    :param valid_loader: (DataLoader) wrapper of the validation dataset.
    :param criterion: (loss) function to calculate the loss.
    :param optimizer: (torch.optim) optimizer linked to model parameters.
    :param resume: (bool) if True, a begun job is resumed.
    :param options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter

    log_dir = os.path.join(options.output_dir, 'log_dir',
                           'fold_' + str(options.split), 'ConvAutoencoder')
    visualization_path = os.path.join(options.output_dir, 'visualize',
                                      'fold_' + str(options.split))
    best_model_dir = os.path.join(options.output_dir, 'best_model_dir',
                                  'fold_' + str(options.split),
                                  'ConvAutoencoder')
    filename = os.path.join(log_dir, 'training.tsv')
    # Columns are needed whether or not the job is resumed (rows are appended
    # during training), so define them before the resume branch.
    columns = [
        'epoch', 'iteration', 'loss_train', 'mean_loss_train',
        'loss_valid', 'mean_loss_valid'
    ]

    if not resume:
        check_and_clean(best_model_dir)
        check_and_clean(visualization_path)
        check_and_clean(log_dir)
        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'valid'))

    decoder.train()
    first_visu = True
    print(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    epoch = options.beginning_epoch

    print("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        print("At %d-th epoch." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / (len(train_loader) *
                                                    train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / (len(valid_loader) *
                                                    valid_loader.batch_size)
                    decoder.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Scan level validation loss is %f at the end of iteration %d"
                        % (loss_valid, i))
                    row = np.array([
                        epoch, i, loss_train, mean_loss_train, loss_valid,
                        mean_loss_valid
                    ]).reshape(1, -1)
                    row_df = pd.DataFrame(row, columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        print('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / (len(train_loader) *
                                        train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / (len(valid_loader) *
                                        valid_loader.batch_size)
        decoder.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print("Scan level validation loss is %f at the end of iteration %d" %
              (loss_valid, i))

        row = np.array([
            epoch, i, loss_train, mean_loss_train, loss_valid, mean_loss_valid
        ]).reshape(1, -1)
        row_df = pd.DataFrame(row, columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'iteration': i,
                'epoch': epoch,
                'loss_valid': loss_valid
            }, False, is_best, best_model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            best_model_dir,
            filename='optimizer.pth.tar')

        if epoch % 10 == 0:
            visualize_subject(decoder,
                              train_loader,
                              visualization_path,
                              options,
                              epoch=epoch,
                              save_input=first_visu)
            first_visu = False

        epoch += 1

    visualize_subject(decoder,
                      train_loader,
                      visualization_path,
                      options,
                      epoch=epoch,
                      save_input=first_visu)
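
Unlike the two previous training functions, ae_finetuning derives its whole output tree (logs, best models, visualizations) from options.output_dir and options.split. A hypothetical options object could therefore look as follows; the decoder, loaders, criterion and optimizer are assumed to be already built, and visualize_subject may read extra fields not listed here.

from argparse import Namespace
import torch.nn as nn
import torch.optim as optim

options = Namespace(
    output_dir="results",   # log_dir, best_model_dir and visualize/ are created under it
    split=0,                # fold index, used to name the fold_0 sub-directories
    gpu=False,
    epochs=30,
    accumulation_steps=1,
    evaluation_steps=10,    # evaluate every 10 optimizer steps
    tolerance=0.0,
    patience=5,
    optimizer="Adam",
)
ae_finetuning(decoder, train_loader, valid_loader,
              nn.MSELoss(), optim.Adam(decoder.parameters(), lr=1e-4),
              resume=False, options=options)
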
Example #6
def train(model, train_loader, valid_loader, criterion, optimizer, resume,
          options):
    """
    Function used to train a CNN.
    The best model and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    :param model: (Module) CNN to be trained
    :param train_loader: (DataLoader) wrapper of the training dataset
    :param valid_loader: (DataLoader) wrapper of the validation dataset
    :param criterion: (loss) function to calculate the loss
    :param optimizer: (torch.optim) optimizer linked to model parameters
    :param resume: (bool) if True, a begun job is resumed
    :param options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter
    from time import time

    columns = [
        'epoch', 'iteration', 'acc_train', 'mean_loss_train', 'acc_valid',
        'mean_loss_valid', 'time'
    ]
    log_dir = os.path.join(options.output_dir, 'log_dir',
                           'fold_' + str(options.split), 'CNN')
    best_model_dir = os.path.join(options.output_dir, 'best_model_dir',
                                  'fold_' + str(options.split), 'CNN')
    filename = os.path.join(log_dir, 'training.tsv')

    if not resume:
        check_and_clean(best_model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'valid'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None
    t_beginning = time()

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        print("At %d-th epoch." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']
            train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    acc_mean_train, total_loss_train = test(
                        model, train_loader, options.gpu, criterion)
                    mean_loss_train = total_loss_train / (
                        len(train_loader) * train_loader.batch_size)

                    acc_mean_valid, total_loss_valid = test(
                        model, valid_loader, options.gpu, criterion)
                    mean_loss_valid = total_loss_valid / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    writer_train.add_scalar('balanced_accuracy',
                                            acc_mean_train,
                                            i + epoch * len(train_loader))
                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('balanced_accuracy',
                                            acc_mean_valid,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Subject level training accuracy is %f at the end of iteration %d"
                        % (acc_mean_train, i))
                    print(
                        "Subject level validation accuracy is %f at the end of iteration %d"
                        % (acc_mean_valid, i))

                    t_current = time() - t_beginning
                    row = np.array([
                        epoch, i, acc_mean_train, mean_loss_train,
                        acc_mean_valid, mean_loss_valid, t_current
                    ]).reshape(1, -1)
                    row_df = pd.DataFrame(row, columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

            tend = time()
        print('Mean time per batch (train):',
              total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('Last checkpoint at the end of the epoch %d' % epoch)

        acc_mean_train, total_loss_train = test(model, train_loader,
                                                options.gpu, criterion)
        mean_loss_train = total_loss_train / (len(train_loader) *
                                              train_loader.batch_size)

        acc_mean_valid, total_loss_valid = test(model, valid_loader,
                                                options.gpu, criterion)
        mean_loss_valid = total_loss_valid / (len(valid_loader) *
                                              valid_loader.batch_size)
        model.train()

        writer_train.add_scalar('balanced_accuracy', acc_mean_train,
                                i + epoch * len(train_loader))
        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('balanced_accuracy', acc_mean_valid,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print(
            "Subject level training accuracy is %f at the end of iteration %d"
            % (acc_mean_train, i))
        print(
            "Subject level validation accuracy is %f at the end of iteration %d"
            % (acc_mean_valid, i))

        t_current = time() - t_beginning
        row = np.array([
            epoch, i, acc_mean_train, mean_loss_train, acc_mean_valid,
            mean_loss_valid, t_current
        ]).reshape(1, -1)
        row_df = pd.DataFrame(row, columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')
        accuracy_is_best = acc_mean_valid > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(acc_mean_valid, best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_acc': acc_mean_valid
            }, accuracy_is_best, loss_is_best, best_model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            best_model_dir,
            filename='optimizer.pth.tar')

        epoch += 1
Example #7
def train(decoder,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          logger=None):
    """
    Function used to train an autoencoder.
    The best autoencoder will be found in the 'best_model_dir' of options.output_dir.

    Args:
        decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
        train_loader: (DataLoader) wrapper of the training dataset.
        valid_loader: (DataLoader) wrapper of the validation dataset.
        criterion: (loss) function to calculate the loss.
        optimizer: (torch.optim) optimizer linked to model parameters.
        resume: (bool) if True, a begun job is resumed.
        log_dir: (str) path to the folder containing the logs.
        model_dir: (str) path to the folder containing the models weights and biases.
        options: (Namespace) ensemble of other options given to the main script.
        logger: (logging object) writer to stdout and stderr
    """
    from tensorboardX import SummaryWriter
    from time import time

    if logger is None:
        logger = logging

    columns = ['epoch', 'iteration', 'time', 'loss_train', 'loss_valid']
    filename = os.path.join(os.path.dirname(log_dir), 'training.tsv')

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    decoder.train()
    train_loader.dataset.train()
    logger.debug(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf
    epoch = options.beginning_epoch

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    t_beginning = time()

    logger.debug("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        logger.info("Beginning epoch %i." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / \
                        (len(train_loader) * train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / \
                        (len(valid_loader) * valid_loader.batch_size)
                    decoder.train()
                    train_loader.dataset.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    logger.info(
                        "%s level training loss is %f at the end of iteration %d"
                        % (options.mode, mean_loss_train, i))
                    logger.info(
                        "%s level validation loss is %f at the end of iteration %d"
                        % (options.mode, mean_loss_valid, i))

                    t_current = time() - t_beginning
                    row = [
                        epoch, i, t_current, mean_loss_train, mean_loss_valid
                    ]
                    row_df = pd.DataFrame([row], columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            logger.warning(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        logger.debug('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / \
            (len(train_loader) * train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / \
            (len(valid_loader) * valid_loader.batch_size)
        decoder.train()
        train_loader.dataset.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        logger.info("%s level training loss is %f at the end of iteration %d" %
                    (options.mode, mean_loss_train, i))
        logger.info(
            "%s level validation loss is %f at the end of iteration %d" %
            (options.mode, mean_loss_valid, i))

        t_current = time() - t_beginning
        row = [epoch, i, t_current, mean_loss_train, mean_loss_valid]
        row_df = pd.DataFrame([row], columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'epoch': epoch,
                'valid_loss': loss_valid
            }, False, is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
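
The per-iteration metrics written by this function go to a tab-separated training.tsv placed next to log_dir (os.path.dirname(log_dir)); they can be read back with pandas for a quick look at the learning curves. The path below is hypothetical.

import pandas as pd

history = pd.read_csv("results/fold_0/training.tsv", sep="\t")
print(history[["epoch", "iteration", "loss_train", "loss_valid"]].tail())
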
Example #8
def train(model,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          fi=None,
          cnn_index=None,
          num_cnn=None,
          train_begin_time=None):
    """
    Function used to train a CNN.
    The best model and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the models weights and biases
        options: (Namespace) ensemble of other options given to the main script.
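        fi: (int) index of the current cross-validation fold, used only in the log messages.
        cnn_index: (int) index of the current CNN in the multi-CNN setting (None for a single CNN).
        num_cnn: (int) total number of CNNs trained in the multi-CNN setting.
        train_begin_time: (float) timestamp of the training start, used to display elapsed time.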
    """
    from tensorboardX import SummaryWriter
    from time import time
    import wandb

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        if fi is not None and options.n_splits is not None:
            print("[%s]: At (%d/%d) fold (%d/%d) epoch." %
                  (timeSince(train_begin_time), fi, options.n_splits, epoch,
                   options.epochs))
        else:
            print("[%s]: At (%d/%d) epoch." %
                  (timeSince(train_begin_time), epoch, options.epochs))

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                device = torch.device("cuda:{}".format(options.device))
                imgs, labels, participant_id, session_id = data['image'].to(
                    device), data['label'].to(
                        device), data['participant_id'], data['session_id']
            else:
                imgs, labels, participant_id, session_id = data['image'], data[
                    'label'], data['participant_id'], data['session_id']
            if options.model == 'ROI_GCN' or 'gcn' in options.model:
                # roi_image = data['roi_image'].to(device)
                train_output = model(imgs,
                                     label_list=labels,
                                     id=participant_id,
                                     session=session_id,
                                     fi=fi,
                                     epoch=epoch)
            else:
                train_output = model(imgs)
            # train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()
            # for name, param in model.named_parameters():
            #     if param.requires_grad:
            #         if param.grad is not None:
            #             pass
            #             # print("{}, gradient: {}".format(name, param.grad.mean()))
            #         else:
            #             print("{} has not gradient".format(name))
            # del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    _, results_train = test(model,
                                            train_loader,
                                            options.gpu,
                                            criterion,
                                            device_index=options.device,
                                            train_begin_time=train_begin_time)
                    mean_loss_train = results_train["total_loss"] / (
                        len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model,
                                            valid_loader,
                                            options.gpu,
                                            criterion,
                                            device_index=options.device,
                                            train_begin_time=train_begin_time)
                    mean_loss_valid = results_valid["total_loss"] / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    global_step = i + epoch * len(train_loader)
                    if cnn_index is None:
                        writer_train.add_scalar(
                            'balanced_accuracy',
                            results_train["balanced_accuracy"], global_step)
                        writer_train.add_scalar('loss', mean_loss_train,
                                                global_step)
                        writer_valid.add_scalar(
                            'balanced_accuracy',
                            results_valid["balanced_accuracy"], global_step)
                        writer_valid.add_scalar('loss', mean_loss_valid,
                                                global_step)
                        wandb.log({
                            'train_balanced_accuracy':
                            results_train["balanced_accuracy"],
                            'train_loss':
                            mean_loss_train,
                            'valid_balanced_accuracy':
                            results_valid["balanced_accuracy"],
                            'valid_loss':
                            mean_loss_valid,
                            'global_step':
                            global_step
                        })
                        print(
                            "[%s]: %s level training accuracy is %f at the end of iteration %d - fake mri count: %d"
                            % (timeSince(train_begin_time), options.mode,
                               results_train["balanced_accuracy"], i,
                               data['num_fake_mri']))
                        print(
                            "[%s]: %s level validation accuracy is %f at the end of iteration %d - fake mri count: %d"
                            % (timeSince(train_begin_time), options.mode,
                               results_valid["balanced_accuracy"], i,
                               data['num_fake_mri']))
                    else:
                        writer_train.add_scalar(
                            '{}_model_balanced_accuracy'.format(cnn_index),
                            results_train["balanced_accuracy"], global_step)
                        writer_train.add_scalar(
                            '{}_model_loss'.format(cnn_index), mean_loss_train,
                            global_step)
                        writer_valid.add_scalar(
                            '{}_model_balanced_accuracy'.format(cnn_index),
                            results_valid["balanced_accuracy"], global_step)
                        writer_valid.add_scalar(
                            '{}_model_loss'.format(cnn_index), mean_loss_valid,
                            global_step)
                        wandb.log({
                            '{}_model_train_balanced_accuracy'.format(cnn_index):
                            results_train["balanced_accuracy"],
                            '{}_model_train_loss'.format(cnn_index):
                            mean_loss_train,
                            '{}_model_valid_balanced_accuracy'.format(cnn_index):
                            results_valid["balanced_accuracy"],
                            '{}_model_valid_loss'.format(cnn_index):
                            mean_loss_valid,
                            'global_step':
                            global_step
                        })
                        print(
                            "[{}]: ({}/{}) model {} level training accuracy is {} at the end of iteration {}-fake mri count:{}"
                            .format(timeSince(train_begin_time), cnn_index,
                                    num_cnn, options.mode,
                                    results_train["balanced_accuracy"], i,
                                    data['num_fake_mri']))
                        print(
                            "[{}]: ({}/{}) model {} level validation accuracy is {} at the end of iteration {}-fake mri count:{}"
                            .format(timeSince(train_begin_time), cnn_index,
                                    num_cnn, options.mode,
                                    results_valid["balanced_accuracy"], i,
                                    data['num_fake_mri']))

            tend = time()
        print(
            '[{}]: Mean time per batch loading (train):'.format(
                timeSince(train_begin_time)),
            total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('[%s]: Last checkpoint at the end of the epoch %d' %
              (timeSince(train_begin_time), epoch))

        _, results_train = test(model,
                                train_loader,
                                options.gpu,
                                criterion,
                                device_index=options.device,
                                train_begin_time=train_begin_time,
                                model_options=options)
        mean_loss_train = results_train["total_loss"] / (
            len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model,
                                valid_loader,
                                options.gpu,
                                criterion,
                                device_index=options.device,
                                train_begin_time=train_begin_time,
                                model_options=options)
        mean_loss_valid = results_valid["total_loss"] / (
            len(valid_loader) * valid_loader.batch_size)
        model.train()

        global_step = (epoch + 1) * len(train_loader)
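        # End-of-epoch metrics are logged at the step index of the last batch of the epoch;
        # in multi-CNN mode (cnn_index set) the scalar tags and wandb keys are prefixed with the CNN index.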
        if cnn_index is None:
            writer_train.add_scalar('balanced_accuracy',
                                    results_train["balanced_accuracy"],
                                    global_step)
            writer_train.add_scalar('loss', mean_loss_train, global_step)
            writer_valid.add_scalar('balanced_accuracy',
                                    results_valid["balanced_accuracy"],
                                    global_step)
            writer_valid.add_scalar('loss', mean_loss_valid, global_step)
            wandb.log({
                'train_balanced_accuracy':
                results_train["balanced_accuracy"],
                'train_loss':
                mean_loss_train,
                'valid_balanced_accuracy':
                results_valid["balanced_accuracy"],
                'valid_loss':
                mean_loss_valid,
                'global_step':
                global_step
            })
            print(
                "[%s]: %s level training accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), options.mode,
                   results_train["balanced_accuracy"], i))
            print(
                "[%s]: %s level validation accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), options.mode,
                   results_valid["balanced_accuracy"], i))
        else:
            writer_train.add_scalar(
                '{}_model_balanced_accuracy'.format(cnn_index),
                results_train["balanced_accuracy"], global_step)
            writer_train.add_scalar('{}_model_loss'.format(cnn_index),
                                    mean_loss_train, global_step)
            writer_valid.add_scalar(
                '{}_model_balanced_accuracy'.format(cnn_index),
                results_valid["balanced_accuracy"], global_step)
            writer_valid.add_scalar('{}_model_loss'.format(cnn_index),
                                    mean_loss_valid, global_step)
            wandb.log({
                '{}_model_train_balanced_accuracy'.format(cnn_index):
                results_train["balanced_accuracy"],
                '{}_model_train_loss'.format(cnn_index):
                mean_loss_train,
                '{}_model_valid_balanced_accuracy'.format(cnn_index):
                results_valid["balanced_accuracy"],
                '{}_model_valid_loss'.format(cnn_index):
                mean_loss_valid,
                'global_step':
                global_step
            })
            print(
                "[%s]: %s model %s level training accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), cnn_index, options.mode,
                   results_train["balanced_accuracy"], i))
            print(
                "[%s]: %s model %s level validation accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), cnn_index, options.mode,
                   results_valid["balanced_accuracy"], i))

        accuracy_is_best = results_valid[
            "balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"],
                                  best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)
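        # Two "best" criteria are tracked independently: highest validation balanced accuracy and lowest validation loss.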

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': mean_loss_valid,
                'valid_acc': results_valid["balanced_accuracy"]
            }, accuracy_is_best, loss_is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
Exemple #9
0
def train(model, train_loader, valid_loader, criterion, optimizer, resume, log_dir, model_dir, options):
    """
    Function used to train a CNN.
    The best models (by validation balanced accuracy and by validation loss) and the last checkpoint are saved in model_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the models weights and biases
        options: (Namespace) set of the other options given to the main script.
    """
    from tensorboardX import SummaryWriter
    from time import time

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min', min_delta=options.tolerance, patience=options.patience)
    mean_loss_valid = None

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
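        # One pass through this loop is one epoch; early_stopping.step(mean_loss_valid) is expected to
        # return True once the validation loss has stopped improving (within tolerance) for patience epochs.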
        print("At %d-th epoch." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']
            train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
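            # predict_batch holds the top-1 predicted class index per sample; it is not used further in this loop.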
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
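                # Gradients accumulated by loss.backward() over the last accumulation_steps batches
                # are applied in a single optimizer update.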
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    _, results_train = test(model, train_loader, options.gpu, criterion)
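                    # The loss accumulated by test() is normalized by the nominal dataset size
                    # (number of batches x batch size) to get a comparable mean loss.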
                    mean_loss_train = results_train["total_loss"] / (len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model, valid_loader, options.gpu, criterion)
                    mean_loss_valid = results_valid["total_loss"] / (len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    global_step = i + epoch * len(train_loader)
                    writer_train.add_scalar('balanced_accuracy', results_train["balanced_accuracy"], global_step)
                    writer_train.add_scalar('loss', mean_loss_train, global_step)
                    writer_valid.add_scalar('balanced_accuracy', results_valid["balanced_accuracy"], global_step)
                    writer_valid.add_scalar('loss', mean_loss_valid, global_step)
                    print("%s level training accuracy is %f at the end of iteration %d"
                          % (options.mode, results_train["balanced_accuracy"], i))
                    print("%s level validation accuracy is %f at the end of iteration %d"
                          % (options.mode, results_valid["balanced_accuracy"], i))

            tend = time()
        print('Mean time per batch loading (train):', total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception('The model has not been updated once in the epoch. The accumulation step may be too large.')

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn('Your evaluation steps are too big compared to the size of the dataset. '
                          'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('Last checkpoint at the end of epoch %d' % epoch)

        _, results_train = test(model, train_loader, options.gpu, criterion)
        mean_loss_train = results_train["total_loss"] / (len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model, valid_loader, options.gpu, criterion)
        mean_loss_valid = results_valid["total_loss"] / (len(valid_loader) * valid_loader.batch_size)
        model.train()

        global_step = (epoch + 1) * len(train_loader)
        writer_train.add_scalar('balanced_accuracy', results_train["balanced_accuracy"], global_step)
        writer_train.add_scalar('loss', mean_loss_train, global_step)
        writer_valid.add_scalar('balanced_accuracy', results_valid["balanced_accuracy"], global_step)
        writer_valid.add_scalar('loss', mean_loss_valid, global_step)
        print("%s level training accuracy is %f at the end of iteration %d"
              % (options.mode, results_train["balanced_accuracy"], len(train_loader)))
        print("%s level validation accuracy is %f at the end of iteration %d"
              % (options.mode, results_valid["balanced_accuracy"], len(train_loader)))

        accuracy_is_best = results_valid["balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"], best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint({'model': model.state_dict(),
                         'epoch': epoch,
                         'valid_loss': mean_loss_valid,
                         'valid_acc': results_valid["balanced_accuracy"]},
                        accuracy_is_best, loss_is_best,
                        model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint({'optimizer': optimizer.state_dict(),
                         'epoch': epoch,
                         'name': options.optimizer,
                         },
                        False, False,
                        model_dir,
                        filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))