Example no. 1
def train(model,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          logger=None):
    """
    Function used to train a CNN.
    The best model and the training checkpoint are saved in model_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the model's weights and biases
        options: (Namespace) ensemble of other options given to the main script.
        logger: (logging object) writer to stdout and stderr
    """
    from tensorboardX import SummaryWriter
    from time import time

    if logger is None:
        logger = logging

    columns = [
        'epoch', 'iteration', 'time', 'balanced_accuracy_train', 'loss_train',
        'balanced_accuracy_valid', 'loss_valid'
    ]
    if hasattr(model, "variational") and model.variational:
        columns += ["kl_loss_train", "kl_loss_valid"]
    filename = os.path.join(os.path.dirname(log_dir), 'training.tsv')

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = -1.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the model to training mode
    train_loader.dataset.train()

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None
    t_beginning = time()

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        logger.info("Beginning epoch %i." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']

            if hasattr(model, "variational") and model.variational:
                z, mu, std, train_output = model(imgs)
                kl_loss = kl_divergence(z, mu, std)
                loss = criterion(train_output, labels) + kl_loss
            else:
                train_output = model(imgs)
                loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False

                    _, results_train = test(model, train_loader, options.gpu,
                                            criterion)
                    mean_loss_train = results_train["total_loss"] / (
                        len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model, valid_loader, options.gpu,
                                            criterion)
                    mean_loss_valid = results_valid["total_loss"] / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()
                    train_loader.dataset.train()

                    global_step = i + epoch * len(train_loader)
                    writer_train.add_scalar('balanced_accuracy',
                                            results_train["balanced_accuracy"],
                                            global_step)
                    writer_train.add_scalar('loss', mean_loss_train,
                                            global_step)
                    writer_valid.add_scalar('balanced_accuracy',
                                            results_valid["balanced_accuracy"],
                                            global_step)
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            global_step)
                    logger.info(
                        "%s level training accuracy is %f at the end of iteration %d"
                        %
                        (options.mode, results_train["balanced_accuracy"], i))
                    logger.info(
                        "%s level validation accuracy is %f at the end of iteration %d"
                        %
                        (options.mode, results_valid["balanced_accuracy"], i))

                    t_current = time() - t_beginning
                    row = [
                        epoch, i, t_current,
                        results_train["balanced_accuracy"], mean_loss_train,
                        results_valid["balanced_accuracy"], mean_loss_valid
                    ]
                    if hasattr(model, "variational") and model.variational:
                        row += [
                            results_train["total_kl_loss"] /
                            (len(train_loader) * train_loader.batch_size),
                            results_valid["total_kl_loss"] /
                            (len(valid_loader) * valid_loader.batch_size)
                        ]
                    row_df = pd.DataFrame([row], columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

            tend = time()
        logger.debug(
            'Mean time per batch loading: %.10f s' %
            (total_time / len(train_loader) * train_loader.batch_size))

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        logger.debug('Last checkpoint at the end of the epoch %d' % epoch)

        _, results_train = test(model, train_loader, options.gpu, criterion)
        mean_loss_train = results_train["total_loss"] / (
            len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model, valid_loader, options.gpu, criterion)
        mean_loss_valid = results_valid["total_loss"] / (
            len(valid_loader) * valid_loader.batch_size)
        model.train()
        train_loader.dataset.train()

        global_step = (epoch + 1) * len(train_loader)
        writer_train.add_scalar('balanced_accuracy',
                                results_train["balanced_accuracy"],
                                global_step)
        writer_train.add_scalar('loss', mean_loss_train, global_step)
        writer_valid.add_scalar('balanced_accuracy',
                                results_valid["balanced_accuracy"],
                                global_step)
        writer_valid.add_scalar('loss', mean_loss_valid, global_step)
        logger.info(
            "%s level training accuracy is %f at the end of iteration %d" %
            (options.mode, results_train["balanced_accuracy"],
             len(train_loader)))
        logger.info(
            "%s level validation accuracy is %f at the end of iteration %d" %
            (options.mode, results_valid["balanced_accuracy"],
             len(train_loader)))

        t_current = time() - t_beginning
        row = [
            epoch, i, t_current, results_train["balanced_accuracy"],
            mean_loss_train, results_valid["balanced_accuracy"],
            mean_loss_valid
        ]
        if hasattr(model, "variational") and model.variational:
            row += [
                results_train["total_kl_loss"] /
                (len(train_loader) * train_loader.batch_size),
                results_valid["total_kl_loss"] /
                (len(valid_loader) * valid_loader.batch_size)
            ]
        row_df = pd.DataFrame([row], columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        accuracy_is_best = results_valid[
            "balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"],
                                  best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': mean_loss_valid,
                'valid_acc': results_valid["balanced_accuracy"]
            }, accuracy_is_best, loss_is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
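
The variational branch in the function above adds a kl_divergence(z, mu, std) term to the classification loss. That helper is defined elsewhere in the package and is not shown here; as a hedged sketch, a Monte Carlo estimate of KL(q(z|x) || N(0, I)) computed from the sampled z, the posterior mean mu and the standard deviation std would fit the call site:

# Hedged sketch only: the real kl_divergence used by train() lives elsewhere in
# the package and may be implemented differently. This version computes the
# usual Monte Carlo estimate of KL(q(z|x) || p(z)) with p(z) = N(0, I).
import torch


def kl_divergence(z, mu, std):
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)
    # Sum over the latent dimensions, average over the batch.
    return (log_qzx - log_pz).sum(dim=tuple(range(1, z.dim()))).mean()
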
def train_CNN_bad_data_split(params):

    from tensorboardX import SummaryWriter
    from time import time

    # Initialize the model
    print('Transfer learning from a model pretrained on ImageNet!\n')
    print('The chosen network is %s!' % params.network)

    model = create_model(params.network, params.gpu)
    trg_size = (224, 224)  # most ImageNet-pretrained models expect this input size

    # All pre-trained models expect input images normalized in the same way,
    # i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H
    # and W are expected to be at least 224. The images have to be loaded into
    # a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406]
    # and std = [0.229, 0.224, 0.225].
    transformations = transforms.Compose([
        MinMaxNormalization(),
        transforms.ToPILImage(),
        transforms.Resize(trg_size),
        transforms.ToTensor()
    ])

    total_time = time()
    init_state = copy.deepcopy(model.state_dict())

    if params.split is None:
        fold_iterator = range(params.n_splits)
    else:
        fold_iterator = [params.split]

    for fi in fold_iterator:
        print("Running for the %d-th fold" % fi)

        training_sub_df, valid_sub_df = load_data(params.tsv_path,
                                                  params.diagnoses,
                                                  fi,
                                                  n_splits=params.n_splits,
                                                  baseline=params.baseline)

        # split the training + validation by slice
        training_df, valid_df = mix_slices(training_sub_df,
                                           valid_sub_df,
                                           mri_plane=params.mri_plane)

        data_train = MRIDataset_slice_mixed(params.caps_directory,
                                            training_df,
                                            transformations=transformations,
                                            mri_plane=params.mri_plane,
                                            prepare_dl=params.prepare_dl)

        data_valid = MRIDataset_slice_mixed(params.caps_directory,
                                            valid_df,
                                            transformations=transformations,
                                            mri_plane=params.mri_plane,
                                            prepare_dl=params.prepare_dl)

        # Build the training and validation data loaders
        train_loader = DataLoader(data_train,
                                  batch_size=params.batch_size,
                                  shuffle=True,
                                  num_workers=params.num_workers,
                                  pin_memory=True)

        valid_loader = DataLoader(data_valid,
                                  batch_size=params.batch_size,
                                  shuffle=False,
                                  num_workers=params.num_workers,
                                  pin_memory=True)

        # Optimizer chosen for back-propagation (looked up by name in torch.optim)
        optimizer = getattr(torch.optim, params.optimizer)(
            filter(lambda x: x.requires_grad, model.parameters()),
            params.learning_rate,
            weight_decay=params.weight_decay)

        model.load_state_dict(init_state)

        # Cross-entropy loss
        loss = torch.nn.CrossEntropyLoss()

        # parameters used in training
        best_accuracy = 0.0
        best_loss_valid = np.inf

        writer_train_batch = SummaryWriter(
            log_dir=(os.path.join(params.output_dir, "log_dir", "fold_" +
                                  str(fi), "train_batch")))
        writer_train_all_data = SummaryWriter(
            log_dir=(os.path.join(params.output_dir, "log_dir", "fold_" +
                                  str(fi), "train_all_data")))

        writer_valid = SummaryWriter(
            log_dir=(os.path.join(params.output_dir, "log_dir", "fold_" +
                                  str(fi), "valid")))

        # initialize the early stopping instance
        early_stopping = EarlyStopping('min',
                                       min_delta=params.tolerance,
                                       patience=params.patience)

        for epoch in range(params.epochs):
            print("At %i-th epoch." % epoch)

            # train the model
            train_df, acc_mean_train, loss_batch_mean_train, global_step \
                = train(
                        model,
                        train_loader,
                        params.gpu,
                        loss,
                        optimizer,
                        writer_train_batch,
                        epoch,
                        model_mode='train',
                        selection_threshold=params.selection_threshold
                        )

            # calculate the accuracy with the whole training data to monitor overfitting
            train_all_df, acc_mean_train_all, loss_batch_mean_train_all, _\
                = train(
                        model,
                        train_loader,
                        params.gpu,
                        loss,
                        optimizer,
                        writer_train_all_data,
                        epoch,
                        model_mode='valid',
                        selection_threshold=params.selection_threshold
                        )

            print(
                "For training, subject level balanced accuracy is %f at the end of epoch %d"
                % (acc_mean_train_all, epoch))

            # at the end of each epoch, validate the model on the validation data
            valid_df, acc_mean_valid, loss_batch_mean_valid, _ \
                = train(
                        model,
                        valid_loader,
                        params.gpu,
                        loss,
                        optimizer,
                        writer_valid,
                        epoch,
                        model_mode='valid',
                        selection_threshold=params.selection_threshold
                        )

            print(
                "For validation, subject level balanced accuracy is %f at the end of epoch %d"
                % (acc_mean_valid, epoch))

            # save the best model based on the best loss and accuracy
            acc_is_best = acc_mean_valid > best_accuracy
            best_accuracy = max(best_accuracy, acc_mean_valid)
            loss_is_best = loss_batch_mean_valid < best_loss_valid
            best_loss_valid = min(loss_batch_mean_valid, best_loss_valid)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'loss': loss_batch_mean_valid,
                    'accuracy': acc_mean_valid,
                    'optimizer': optimizer.state_dict(),
                    'global_step': global_step
                }, acc_is_best, loss_is_best,
                os.path.join(params.output_dir, "best_model_dir",
                             "fold_" + str(fi), "CNN"))

            # check the early stopping criterion
            if early_stopping.step(
                    loss_batch_mean_valid) or epoch == params.epochs - 1:
                print(
                    "Training stopped at epoch %d (early stopping triggered "
                    "or the last epoch defined by the user reached)." % epoch)

                break

        # Final evaluation for all criteria
        for selection in ['best_loss', 'best_acc']:
            model, best_epoch = load_model(
                model,
                os.path.join(params.output_dir, 'best_model_dir',
                             'fold_%i' % fi, 'CNN', str(selection)),
                gpu=params.gpu,
                filename='model_best.pth.tar')

            train_df, metrics_train = test(model, train_loader, params.gpu,
                                           loss)
            valid_df, metrics_valid = test(model, valid_loader, params.gpu,
                                           loss)

            # write the subject information and performance metrics to TSV files
            slice_level_to_tsvs(params.output_dir,
                                train_df,
                                metrics_train,
                                fi,
                                dataset='train',
                                selection=selection)
            slice_level_to_tsvs(params.output_dir,
                                valid_df,
                                metrics_valid,
                                fi,
                                dataset='validation',
                                selection=selection)

            soft_voting_to_tsvs(params.output_dir,
                                fi,
                                dataset='train',
                                selection=selection,
                                selection_threshold=params.selection_threshold)
            soft_voting_to_tsvs(params.output_dir,
                                fi,
                                dataset='validation',
                                selection=selection,
                                selection_threshold=params.selection_threshold)
            torch.cuda.empty_cache()

    total_time = time() - total_time
    print("Total time of computation: %d s" % total_time)
Example no. 3
def train(decoder, train_loader, valid_loader, criterion, optimizer, resume,
          log_dir, model_dir, options):
    """
    Function used to train an autoencoder.
    The best autoencoder and the training checkpoint are saved in model_dir.

    Args:
        decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
        train_loader: (DataLoader) wrapper of the training dataset.
        valid_loader: (DataLoader) wrapper of the validation dataset.
        criterion: (loss) function to calculate the loss.
        optimizer: (torch.optim) optimizer linked to model parameters.
        resume: (bool) if True, a begun job is resumed.
        log_dir: (str) path to the folder containing the logs.
        model_dir: (str) path to the folder containing the model's weights and biases.
        options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)
        options.beginning_epoch = 0

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    decoder.train()
    print(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    epoch = options.beginning_epoch

    print("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        print("At %d-th epoch." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / (len(train_loader) *
                                                    train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / (len(valid_loader) *
                                                    valid_loader.batch_size)
                    decoder.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Scan level validation loss is %f at the end of iteration %d"
                        % (loss_valid, i))

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        print('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / (len(train_loader) *
                                        train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / (len(valid_loader) *
                                        valid_loader.batch_size)
        decoder.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print("Scan level validation loss is %f at the end of iteration %d" %
              (loss_valid, i))

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'epoch': epoch,
                'valid_loss': loss_valid
            }, False, is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
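
save_checkpoint is another helper imported from the surrounding package. Its call sites only reveal the signature (a state dictionary, an accuracy_is_best flag, a loss_is_best flag, a target directory and an optional filename), while the load_model calls in the second example read 'model_best.pth.tar' from 'best_loss' / 'best_acc' subfolders. The sketch below is an assumption consistent with those call sites, not the actual implementation.

# Hedged sketch of save_checkpoint, inferred from how it is called above:
# always write the latest state, and copy it into per-criterion "best" folders.
import os
import shutil

import torch


def save_checkpoint(state, accuracy_is_best, loss_is_best, checkpoint_dir,
                    filename='checkpoint.pth.tar'):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    # Persist the latest state (model or optimizer dictionary).
    torch.save(state, checkpoint_path)
    # Keep a separate copy of the model selected on each criterion.
    if loss_is_best:
        best_dir = os.path.join(checkpoint_dir, 'best_loss')
        os.makedirs(best_dir, exist_ok=True)
        shutil.copyfile(checkpoint_path,
                        os.path.join(best_dir, 'model_best.pth.tar'))
    if accuracy_is_best:
        best_dir = os.path.join(checkpoint_dir, 'best_acc')
        os.makedirs(best_dir, exist_ok=True)
        shutil.copyfile(checkpoint_path,
                        os.path.join(best_dir, 'model_best.pth.tar'))
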
Example no. 4
def ae_finetuning(decoder, train_loader, valid_loader, criterion, optimizer,
                  resume, options):
    """
    Function used to train an autoencoder.
    The best autoencoder and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    :param decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
    :param train_loader: (DataLoader) wrapper of the training dataset.
    :param valid_loader: (DataLoader) wrapper of the validation dataset.
    :param criterion: (loss) function to calculate the loss.
    :param optimizer: (torch.optim) optimizer linked to model parameters.
    :param resume: (bool) if True, a begun job is resumed.
    :param options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter

    log_dir = os.path.join(options.output_dir, 'log_dir',
                           'fold_' + str(options.split), 'ConvAutoencoder')
    visualization_path = os.path.join(options.output_dir, 'visualize',
                                      'fold_' + str(options.split))
    best_model_dir = os.path.join(options.output_dir, 'best_model_dir',
                                  'fold_' + str(options.split),
                                  'ConvAutoencoder')
    filename = os.path.join(log_dir, 'training.tsv')

    # Columns of the training log, needed for both fresh and resumed runs
    columns = [
        'epoch', 'iteration', 'loss_train', 'mean_loss_train',
        'loss_valid', 'mean_loss_valid'
    ]

    if not resume:
        check_and_clean(best_model_dir)
        check_and_clean(visualization_path)
        check_and_clean(log_dir)
        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'valid'))

    decoder.train()
    first_visu = True
    print(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    epoch = options.beginning_epoch

    print("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        print("At %d-th epoch." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / (len(train_loader) *
                                                    train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / (len(valid_loader) *
                                                    valid_loader.batch_size)
                    decoder.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Scan level validation loss is %f at the end of iteration %d"
                        % (loss_valid, i))
                    row = np.array([
                        epoch, i, loss_train, mean_loss_train, loss_valid,
                        mean_loss_valid
                    ]).reshape(1, -1)
                    row_df = pd.DataFrame(row, columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        print('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / (len(train_loader) *
                                        train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / (len(valid_loader) *
                                        valid_loader.batch_size)
        decoder.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print("Scan level validation loss is %f at the end of iteration %d" %
              (loss_valid, i))

        row = np.array([
            epoch, i, loss_train, mean_loss_train, loss_valid, mean_loss_valid
        ]).reshape(1, -1)
        row_df = pd.DataFrame(row, columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'iteration': i,
                'epoch': epoch,
                'loss_valid': loss_valid
            }, False, is_best, best_model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            best_model_dir,
            filename='optimizer.pth.tar')

        if epoch % 10 == 0:
            visualize_subject(decoder,
                              train_loader,
                              visualization_path,
                              options,
                              epoch=epoch,
                              save_input=first_visu)
            first_visu = False

        epoch += 1

    visualize_subject(decoder,
                      train_loader,
                      visualization_path,
                      options,
                      epoch=epoch,
                      save_input=first_visu)
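
The reconstruction metrics in the autoencoder examples come from a test_ae helper that is not shown. Because its return value is divided by len(loader) * batch_size, it presumably returns a loss accumulated over the whole loader; the sketch below makes that assumption explicit and additionally assumes the criterion averages over the batch.

# Hedged sketch of test_ae, inferred from how its result is normalized above.
# The real helper may differ.
import torch


def test_ae(decoder, dataloader, use_gpu, criterion):
    decoder.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            imgs = data['image'].cuda() if use_gpu else data['image']
            outputs = decoder(imgs)
            # The criterion is assumed to average over the batch, so scale back
            # to a per-batch sum before accumulating.
            total_loss += criterion(outputs, imgs).item() * imgs.size(0)
    return total_loss
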
Example no. 5
def train(model, train_loader, valid_loader, criterion, optimizer, resume,
          options):
    """
    Function used to train a CNN.
    The best model and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    :param model: (Module) CNN to be trained
    :param train_loader: (DataLoader) wrapper of the training dataset
    :param valid_loader: (DataLoader) wrapper of the validation dataset
    :param criterion: (loss) function to calculate the loss
    :param optimizer: (torch.optim) optimizer linked to model parameters
    :param resume: (bool) if True, a begun job is resumed
    :param options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter
    from time import time

    columns = [
        'epoch', 'iteration', 'acc_train', 'mean_loss_train', 'acc_valid',
        'mean_loss_valid', 'time'
    ]
    log_dir = os.path.join(options.output_dir, 'log_dir',
                           'fold_' + str(options.split), 'CNN')
    best_model_dir = os.path.join(options.output_dir, 'best_model_dir',
                                  'fold_' + str(options.split), 'CNN')
    filename = os.path.join(log_dir, 'training.tsv')

    if not resume:
        check_and_clean(best_model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'valid'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None
    t_beginning = time()

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        print("At %d-th epoch." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']
            train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    acc_mean_train, total_loss_train = test(
                        model, train_loader, options.gpu, criterion)
                    mean_loss_train = total_loss_train / (
                        len(train_loader) * train_loader.batch_size)

                    acc_mean_valid, total_loss_valid = test(
                        model, valid_loader, options.gpu, criterion)
                    mean_loss_valid = total_loss_valid / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    writer_train.add_scalar('balanced_accuracy',
                                            acc_mean_train,
                                            i + epoch * len(train_loader))
                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('balanced_accuracy',
                                            acc_mean_valid,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    print(
                        "Subject level training accuracy is %f at the end of iteration %d"
                        % (acc_mean_train, i))
                    print(
                        "Subject level validation accuracy is %f at the end of iteration %d"
                        % (acc_mean_valid, i))

                    t_current = time() - t_beginning
                    row = np.array([
                        epoch, i, acc_mean_train, mean_loss_train,
                        acc_mean_valid, mean_loss_valid, t_current
                    ]).reshape(1, -1)
                    row_df = pd.DataFrame(row, columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

            tend = time()
        print('Mean time per batch (train):',
              total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('Last checkpoint at the end of the epoch %d' % epoch)

        acc_mean_train, total_loss_train = test(model, train_loader,
                                                options.gpu, criterion)
        mean_loss_train = total_loss_train / (len(train_loader) *
                                              train_loader.batch_size)

        acc_mean_valid, total_loss_valid = test(model, valid_loader,
                                                options.gpu, criterion)
        mean_loss_valid = total_loss_valid / (len(valid_loader) *
                                              valid_loader.batch_size)
        model.train()

        writer_train.add_scalar('balanced_accuracy', acc_mean_train,
                                i + epoch * len(train_loader))
        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('balanced_accuracy', acc_mean_valid,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        print(
            "Subject level training accuracy is %f at the end of iteration %d"
            % (acc_mean_train, i))
        print(
            "Subject level validation accuracy is %f at the end of iteration %d"
            % (acc_mean_valid, i))

        t_current = time() - t_beginning
        row = np.array([
            epoch, i, acc_mean_train, mean_loss_train, acc_mean_valid,
            mean_loss_valid, t_current
        ]).reshape(1, -1)
        row_df = pd.DataFrame(row, columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')
        accuracy_is_best = acc_mean_valid > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(acc_mean_valid, best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_acc': acc_mean_valid
            }, accuracy_is_best, loss_is_best, best_model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            best_model_dir,
            filename='optimizer.pth.tar')

        epoch += 1
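
In this example the test helper is unpacked as a (balanced accuracy, summed loss) pair, unlike the dictionary-returning variant used in the first example. A hedged sketch of an evaluation routine compatible with these call sites (integer class labels, criterion averaging over the batch) could look like this:

# Hedged sketch of a `test` function returning (balanced accuracy, summed loss),
# matching how it is unpacked in the example above. Not the real implementation.
import torch


def test(model, dataloader, use_gpu, criterion):
    model.eval()
    total_loss = 0.0
    correct_per_class, total_per_class = {}, {}
    with torch.no_grad():
        for data in dataloader:
            imgs, labels = data['image'], data['label']
            if use_gpu:
                imgs, labels = imgs.cuda(), labels.cuda()
            outputs = model(imgs)
            total_loss += criterion(outputs, labels).item() * imgs.size(0)
            _, predictions = outputs.topk(1)
            for label, pred in zip(labels.view(-1).tolist(),
                                   predictions.view(-1).tolist()):
                total_per_class[label] = total_per_class.get(label, 0) + 1
                correct_per_class[label] = (correct_per_class.get(label, 0) +
                                            int(label == pred))
    # Balanced accuracy: mean of the per-class recalls.
    recalls = [correct_per_class[c] / total_per_class[c] for c in total_per_class]
    return sum(recalls) / len(recalls), total_loss
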
Example no. 6
def train(decoder,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          logger=None):
    """
    Function used to train an autoencoder.
    The best autoencoder and the training checkpoint are saved in model_dir.

    Args:
        decoder: (Autoencoder) Autoencoder constructed from a CNN with the Autoencoder class.
        train_loader: (DataLoader) wrapper of the training dataset.
        valid_loader: (DataLoader) wrapper of the validation dataset.
        criterion: (loss) function to calculate the loss.
        optimizer: (torch.optim) optimizer linked to model parameters.
        resume: (bool) if True, a begun job is resumed.
        log_dir: (str) path to the folder containing the logs.
        model_dir: (str) path to the folder containing the model's weights and biases.
        options: (Namespace) ensemble of other options given to the main script.
        logger: (logging object) writer to stdout and stderr
    """
    from tensorboardX import SummaryWriter
    from time import time

    if logger is None:
        logger = logging

    columns = ['epoch', 'iteration', 'time', 'loss_train', 'loss_valid']
    filename = os.path.join(os.path.dirname(log_dir), 'training.tsv')

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

        results_df = pd.DataFrame(columns=columns)
        with open(filename, 'w') as f:
            results_df.to_csv(f, index=False, sep='\t')
        options.beginning_epoch = 0

    else:
        if not os.path.exists(filename):
            raise ValueError(
                'The training.tsv file of the resumed experiment does not exist.'
            )
        truncated_tsv = pd.read_csv(filename, sep='\t')
        truncated_tsv.set_index(['epoch', 'iteration'], inplace=True)
        truncated_tsv.drop(options.beginning_epoch, level=0, inplace=True)
        truncated_tsv.to_csv(filename, index=True, sep='\t')

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    decoder.train()
    train_loader.dataset.train()
    logger.debug(decoder)

    if options.gpu:
        decoder.cuda()

    # Initialize variables
    best_loss_valid = np.inf
    epoch = options.beginning_epoch

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    loss_valid = None
    t_beginning = time()

    logger.debug("Beginning training")
    while epoch < options.epochs and not early_stopping.step(loss_valid):
        logger.info("Beginning epoch %i." % epoch)

        decoder.zero_grad()
        evaluation_flag = True
        step_flag = True
        for i, data in enumerate(train_loader):
            if options.gpu:
                imgs = data['image'].cuda()
            else:
                imgs = data['image']

            train_output = decoder(imgs)
            loss = criterion(train_output, imgs)
            loss.backward()

            del imgs, train_output

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                # Evaluate the decoder only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    loss_train = test_ae(decoder, train_loader, options.gpu,
                                         criterion)
                    mean_loss_train = loss_train / \
                        (len(train_loader) * train_loader.batch_size)

                    loss_valid = test_ae(decoder, valid_loader, options.gpu,
                                         criterion)
                    mean_loss_valid = loss_valid / \
                        (len(valid_loader) * valid_loader.batch_size)
                    decoder.train()
                    train_loader.dataset.train()

                    writer_train.add_scalar('loss', mean_loss_train,
                                            i + epoch * len(train_loader))
                    writer_valid.add_scalar('loss', mean_loss_valid,
                                            i + epoch * len(train_loader))
                    logger.info(
                        "%s level training loss is %f at the end of iteration %d"
                        % (options.mode, mean_loss_train, i))
                    logger.info(
                        "%s level validation loss is %f at the end of iteration %d"
                        % (options.mode, mean_loss_valid, i))

                    t_current = time() - t_beginning
                    row = [
                        epoch, i, t_current, mean_loss_train, mean_loss_valid
                    ]
                    row_df = pd.DataFrame([row], columns=columns)
                    with open(filename, 'a') as f:
                        row_df.to_csv(f, header=False, index=False, sep='\t')

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        if evaluation_flag and options.evaluation_steps != 0:
            logger.warning(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        logger.debug('Last checkpoint at the end of the epoch %d' % epoch)

        loss_train = test_ae(decoder, train_loader, options.gpu, criterion)
        mean_loss_train = loss_train / \
            (len(train_loader) * train_loader.batch_size)

        loss_valid = test_ae(decoder, valid_loader, options.gpu, criterion)
        mean_loss_valid = loss_valid / \
            (len(valid_loader) * valid_loader.batch_size)
        decoder.train()
        train_loader.dataset.train()

        writer_train.add_scalar('loss', mean_loss_train,
                                i + epoch * len(train_loader))
        writer_valid.add_scalar('loss', mean_loss_valid,
                                i + epoch * len(train_loader))
        logger.info("%s level training loss is %f at the end of iteration %d" %
                    (options.mode, mean_loss_train, i))
        logger.info(
            "%s level validation loss is %f at the end of iteration %d" %
            (options.mode, mean_loss_valid, i))

        t_current = time() - t_beginning
        row = [epoch, i, t_current, mean_loss_train, mean_loss_valid]
        row_df = pd.DataFrame([row], columns=columns)
        with open(filename, 'a') as f:
            row_df.to_csv(f, header=False, index=False, sep='\t')

        is_best = loss_valid < best_loss_valid
        best_loss_valid = min(best_loss_valid, loss_valid)
        # Always save the model at the end of the epoch and update best model
        save_checkpoint(
            {
                'model': decoder.state_dict(),
                'epoch': epoch,
                'valid_loss': loss_valid
            }, False, is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
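
Every non-resumed run above starts by calling check_and_clean on the model and log folders. The helper is imported from elsewhere in the package; the resume logic suggests that it wipes whatever a previous experiment left in the folder, which is the assumption made in this minimal sketch.

# Hedged sketch of check_and_clean: assumed to reset a folder for a fresh run.
import os
import shutil


def check_and_clean(directory):
    # Drop leftovers from a previous (non-resumed) experiment, then recreate the
    # folder so that writers and checkpoints start from a clean state.
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)
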
Example no. 7
def train(model,
          train_loader,
          valid_loader,
          criterion,
          optimizer,
          resume,
          log_dir,
          model_dir,
          options,
          fi=None,
          cnn_index=None,
          num_cnn=None,
          train_begin_time=None):
    """
    Function used to train a CNN.
    The best model and the training checkpoint are saved in model_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the model's weights and biases
        options: (Namespace) ensemble of other options given to the main script.
        fi: (int) index of the current fold, used for logging.
        cnn_index: (int) index of the CNN in a multi-CNN setting, used to tag TensorBoard scalars.
        num_cnn: (int) total number of CNNs in a multi-CNN setting.
        train_begin_time: (float) timestamp of the start of training, used to log elapsed time.
    """
    from tensorboardX import SummaryWriter
    from time import time
    import wandb

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min',
                                   min_delta=options.tolerance,
                                   patience=options.patience)
    mean_loss_valid = None

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        if fi is not None and options.n_splits is not None:
            print("[%s]: At (%d/%d) fold (%d/%d) epoch." %
                  (timeSince(train_begin_time), fi, options.n_splits, epoch,
                   options.epochs))
        else:
            print("[%s]: At (%d/%d) epoch." %
                  (timeSince(train_begin_time), epoch, options.epochs))

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                device = torch.device("cuda:{}".format(options.device))
                imgs, labels, participant_id, session_id = data['image'].to(
                    device), data['label'].to(
                        device), data['participant_id'], data['session_id']
            else:
                imgs, labels, participant_id, session_id = data['image'], data[
                    'label'], data['participant_id'], data['session_id']
            if options.model == 'ROI_GCN' or 'gcn' in options.model:
                # roi_image = data['roi_image'].to(device)
                train_output = model(imgs,
                                     label_list=labels,
                                     id=participant_id,
                                     session=session_id,
                                     fi=fi,
                                     epoch=epoch)
            else:
                train_output = model(imgs)
            # train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()
            # for name, param in model.named_parameters():
            #     if param.requires_grad:
            #         if param.grad is not None:
            #             pass
            #             # print("{}, gradient: {}".format(name, param.grad.mean()))
            #         else:
            #             print("{} has not gradient".format(name))
            # del imgs, labels

            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (
                        i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    _, results_train = test(model,
                                            train_loader,
                                            options.gpu,
                                            criterion,
                                            device_index=options.device,
                                            train_begin_time=train_begin_time)
                    mean_loss_train = results_train["total_loss"] / (
                        len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model,
                                            valid_loader,
                                            options.gpu,
                                            criterion,
                                            device_index=options.device,
                                            train_begin_time=train_begin_time)
                    mean_loss_valid = results_valid["total_loss"] / (
                        len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    global_step = i + epoch * len(train_loader)
                    if cnn_index is None:
                        writer_train.add_scalar(
                            'balanced_accuracy',
                            results_train["balanced_accuracy"], global_step)
                        writer_train.add_scalar('loss', mean_loss_train,
                                                global_step)
                        writer_valid.add_scalar(
                            'balanced_accuracy',
                            results_valid["balanced_accuracy"], global_step)
                        writer_valid.add_scalar('loss', mean_loss_valid,
                                                global_step)
                        wandb.log({
                            'train_balanced_accuracy':
                            results_train["balanced_accuracy"],
                            'train_loss':
                            mean_loss_train,
                            'valid_balanced_accuracy':
                            results_valid["balanced_accuracy"],
                            'valid_loss':
                            mean_loss_valid,
                            'global_step':
                            global_step
                        })
                        print(
                            "[%s]: %s level training accuracy is %f at the end of iteration %d - fake mri count: %d"
                            % (timeSince(train_begin_time), options.mode,
                               results_train["balanced_accuracy"], i,
                               data['num_fake_mri']))
                        print(
                            "[%s]: %s level validation accuracy is %f at the end of iteration %d - fake mri count: %d"
                            % (timeSince(train_begin_time), options.mode,
                               results_valid["balanced_accuracy"], i,
                               data['num_fake_mri']))
                    else:
                        writer_train.add_scalar(
                            '{}_model_balanced_accuracy'.format(cnn_index),
                            results_train["balanced_accuracy"], global_step)
                        writer_train.add_scalar(
                            '{}_model_loss'.format(cnn_index), mean_loss_train,
                            global_step)
                        writer_valid.add_scalar(
                            '{}_model_balanced_accuracy'.format(cnn_index),
                            results_valid["balanced_accuracy"], global_step)
                        writer_valid.add_scalar(
                            '{}_model_loss'.format(cnn_index), mean_loss_valid,
                            global_step)
                        wandb.log({
                            '{}_model_train_balanced_accuracy'.format(cnn_index):
                            results_train["balanced_accuracy"],
                            '{}_model_train_loss'.format(cnn_index):
                            mean_loss_train,
                            '{}_model_valid_balanced_accuracy'.format(cnn_index):
                            results_valid["balanced_accuracy"],
                            '{}_model_valid_loss'.format(cnn_index):
                            mean_loss_valid,
                            'global_step':
                            global_step
                        })
                        print(
                            "[{}]: ({}/{}) model {} level training accuracy is {} at the end of iteration {} - fake mri count: {}"
                            .format(timeSince(train_begin_time), cnn_index,
                                    num_cnn, options.mode,
                                    results_train["balanced_accuracy"], i,
                                    data['num_fake_mri']))
                        print(
                            "[{}]: ({}/{}) model {} level validation accuracy is {} at the end of iteration {} - fake mri count: {}"
                            .format(timeSince(train_begin_time), cnn_index,
                                    num_cnn, options.mode,
                                    results_valid["balanced_accuracy"], i,
                                    data['num_fake_mri']))

            tend = time()
        print(
            '[{}]: Mean time per batch loading (train):'.format(
                timeSince(train_begin_time)),
            total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception(
                'The model has not been updated once in the epoch. The accumulation step may be too large.'
            )

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn(
                'Your evaluation steps are too big compared to the size of the dataset. '
                'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('[%s]: Last checkpoint at the end of the epoch %d' %
              (timeSince(train_begin_time), epoch))

        _, results_train = test(model,
                                train_loader,
                                options.gpu,
                                criterion,
                                device_index=options.device,
                                train_begin_time=train_begin_time,
                                model_options=options)
        mean_loss_train = results_train["total_loss"] / (
            len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model,
                                valid_loader,
                                options.gpu,
                                criterion,
                                device_index=options.device,
                                train_begin_time=train_begin_time,
                                model_options=options)
        mean_loss_valid = results_valid["total_loss"] / (
            len(valid_loader) * valid_loader.batch_size)
        model.train()

        global_step = (epoch + 1) * len(train_loader)
        if cnn_index is None:
            writer_train.add_scalar('balanced_accuracy',
                                    results_train["balanced_accuracy"],
                                    global_step)
            writer_train.add_scalar('loss', mean_loss_train, global_step)
            writer_valid.add_scalar('balanced_accuracy',
                                    results_valid["balanced_accuracy"],
                                    global_step)
            writer_valid.add_scalar('loss', mean_loss_valid, global_step)
            wandb.log({
                'train_balanced_accuracy':
                results_train["balanced_accuracy"],
                'train_loss':
                mean_loss_train,
                'valid_balanced_accuracy':
                results_valid["balanced_accuracy"],
                'valid_loss':
                mean_loss_valid,
                'global_step':
                global_step
            })
            print(
                "[%s]: %s level training accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), options.mode,
                   results_train["balanced_accuracy"], i))
            print(
                "[%s]: %s level validation accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), options.mode,
                   results_valid["balanced_accuracy"], i))
        else:
            writer_train.add_scalar(
                '{}_model_balanced_accuracy'.format(cnn_index),
                results_train["balanced_accuracy"], global_step)
            writer_train.add_scalar('{}_model_loss'.format(cnn_index),
                                    mean_loss_train, global_step)
            writer_valid.add_scalar(
                '{}_model_balanced_accuracy'.format(cnn_index),
                results_valid["balanced_accuracy"], global_step)
            writer_valid.add_scalar('{}_model_loss'.format(cnn_index),
                                    mean_loss_valid, global_step)
            wandb.log({
                '{}_model_train_balanced_accuracy'.format(cnn_index):
                results_train["balanced_accuracy"],
                '{}_model_train_loss'.format(cnn_index):
                mean_loss_train,
                '{}_model_valid_balanced_accuracy'.format(cnn_index):
                results_valid["balanced_accuracy"],
                '{}_model_valid_loss'.format(cnn_index):
                mean_loss_valid,
                'global_step':
                global_step
            })
            print(
                "[%s]: %s model %s level training accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), cnn_index, options.mode,
                   results_train["balanced_accuracy"], i))
            print(
                "[%s]: %s model %s level validation accuracy is %f at the end of iteration %d"
                % (timeSince(train_begin_time), cnn_index, options.mode,
                   results_valid["balanced_accuracy"], i))

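        # Track the best validation metrics seen so far; these flags tell
        # save_checkpoint whether the current weights improve on the best
        # balanced accuracy and/or the best validation loss.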
        accuracy_is_best = results_valid[
            "balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"],
                                  best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': mean_loss_valid,
                'valid_acc': results_valid["balanced_accuracy"]
            }, accuracy_is_best, loss_is_best, model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint(
            {
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'name': options.optimizer,
            },
            False,
            False,
            model_dir,
            filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))
Example n. 8
0
def train(model, train_loader, valid_loader, criterion, optimizer, resume, log_dir, model_dir, options):
    """
    Function used to train a CNN.
    The best model and checkpoint will be found in the 'best_model_dir' of options.output_dir.

    Args:
        model: (Module) CNN to be trained
        train_loader: (DataLoader) wrapper of the training dataset
        valid_loader: (DataLoader) wrapper of the validation dataset
        criterion: (loss) function to calculate the loss
        optimizer: (torch.optim) optimizer linked to model parameters
        resume: (bool) if True, a begun job is resumed
        log_dir: (str) path to the folder containing the logs
        model_dir: (str) path to the folder containing the models weights and biases
        options: (Namespace) ensemble of other options given to the main script.
    """
    from tensorboardX import SummaryWriter
    from time import time

    if not resume:
        check_and_clean(model_dir)
        check_and_clean(log_dir)

    # Create writers
    writer_train = SummaryWriter(os.path.join(log_dir, 'train'))
    writer_valid = SummaryWriter(os.path.join(log_dir, 'validation'))

    # Initialize variables
    best_valid_accuracy = 0.0
    best_valid_loss = np.inf
    epoch = options.beginning_epoch

    model.train()  # set the module to training mode

    early_stopping = EarlyStopping('min', min_delta=options.tolerance, patience=options.patience)
    mean_loss_valid = None

    while epoch < options.epochs and not early_stopping.step(mean_loss_valid):
        print("At %d-th epoch." % epoch)

        model.zero_grad()
        evaluation_flag = True
        step_flag = True
        tend = time()
        total_time = 0

        for i, data in enumerate(train_loader, 0):
            t0 = time()
            total_time = total_time + t0 - tend
            if options.gpu:
                imgs, labels = data['image'].cuda(), data['label'].cuda()
            else:
                imgs, labels = data['image'], data['label']
            train_output = model(imgs)
            _, predict_batch = train_output.topk(1)
            loss = criterion(train_output, labels)

            # Back propagation
            loss.backward()

            del imgs, labels

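            # Perform the optimizer step only every `accumulation_steps` mini-batches
            # (gradient accumulation), then reset the gradients.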
            if (i + 1) % options.accumulation_steps == 0:
                step_flag = False
                optimizer.step()
                optimizer.zero_grad()

                del loss

                # Evaluate the model only when no gradients are accumulated
                if options.evaluation_steps != 0 and (i + 1) % options.evaluation_steps == 0:
                    evaluation_flag = False
                    print('Iteration %d' % i)

                    _, results_train = test(model, train_loader, options.gpu, criterion)
                    mean_loss_train = results_train["total_loss"] / (len(train_loader) * train_loader.batch_size)

                    _, results_valid = test(model, valid_loader, options.gpu, criterion)
                    mean_loss_valid = results_valid["total_loss"] / (len(valid_loader) * valid_loader.batch_size)
                    model.train()

                    global_step = i + epoch * len(train_loader)
                    writer_train.add_scalar('balanced_accuracy', results_train["balanced_accuracy"], global_step)
                    writer_train.add_scalar('loss', mean_loss_train, global_step)
                    writer_valid.add_scalar('balanced_accuracy', results_valid["balanced_accuracy"], global_step)
                    writer_valid.add_scalar('loss', mean_loss_valid, global_step)
                    print("%s level training accuracy is %f at the end of iteration %d"
                          % (options.mode, results_train["balanced_accuracy"], i))
                    print("%s level validation accuracy is %f at the end of iteration %d"
                          % (options.mode, results_valid["balanced_accuracy"], i))

            tend = time()
        print('Mean time per batch loading (train):', total_time / len(train_loader) * train_loader.batch_size)

        # If no step has been performed, raise Exception
        if step_flag:
            raise Exception('The model has not been updated once in the epoch. The accumulation step may be too large.')

        # If no evaluation has been performed, warn the user
        elif evaluation_flag and options.evaluation_steps != 0:
            warnings.warn('Your evaluation steps are too big compared to the size of the dataset. '
                          'The model is evaluated only once at the end of the epoch.')

        # Always test the results and save them once at the end of the epoch
        model.zero_grad()
        print('Last checkpoint at the end of the epoch %d' % epoch)

        _, results_train = test(model, train_loader, options.gpu, criterion)
        mean_loss_train = results_train["total_loss"] / (len(train_loader) * train_loader.batch_size)

        _, results_valid = test(model, valid_loader, options.gpu, criterion)
        mean_loss_valid = results_valid["total_loss"] / (len(valid_loader) * valid_loader.batch_size)
        model.train()

        global_step = (epoch + 1) * len(train_loader)
        writer_train.add_scalar('balanced_accuracy', results_train["balanced_accuracy"], global_step)
        writer_train.add_scalar('loss', mean_loss_train, global_step)
        writer_valid.add_scalar('balanced_accuracy', results_valid["balanced_accuracy"], global_step)
        writer_valid.add_scalar('loss', mean_loss_valid, global_step)
        print("%s level training accuracy is %f at the end of iteration %d"
              % (options.mode, results_train["balanced_accuracy"], len(train_loader)))
        print("%s level validation accuracy is %f at the end of iteration %d"
              % (options.mode, results_valid["balanced_accuracy"], len(train_loader)))

        accuracy_is_best = results_valid["balanced_accuracy"] > best_valid_accuracy
        loss_is_best = mean_loss_valid < best_valid_loss
        best_valid_accuracy = max(results_valid["balanced_accuracy"], best_valid_accuracy)
        best_valid_loss = min(mean_loss_valid, best_valid_loss)

        save_checkpoint({'model': model.state_dict(),
                         'epoch': epoch,
                         'valid_loss': mean_loss_valid,
                         'valid_acc': results_valid["balanced_accuracy"]},
                        accuracy_is_best, loss_is_best,
                        model_dir)
        # Save optimizer state_dict to be able to reload
        save_checkpoint({'optimizer': optimizer.state_dict(),
                         'epoch': epoch,
                         'name': options.optimizer,
                         },
                        False, False,
                        model_dir,
                        filename='optimizer.pth.tar')

        epoch += 1

    os.remove(os.path.join(model_dir, "optimizer.pth.tar"))
    os.remove(os.path.join(model_dir, "checkpoint.pth.tar"))