Exemplo n.º 1
0
def train_full(config,
               params=None,
               warm_start_NN=None,
               restore_old_checkpoint=False,
               workers=1,
               verbosity=0,
               cross_fold_loaders=None):
    """
    OLD AND OUTDATED, is replaced by train.

    Runs an inline train/validate loop with early stopping, then loads the
    state dict stored at ``config['save_model_path']`` (presumably the
    checkpoint written by EarlyStopping) and passes the network and the
    per-epoch loss curves to plot_results / plot_early_stopping.

    :param config: mapping; the keys 'scheduler', 'scheduler_milestones'
        (only when 'scheduler' is truthy), 'epochs', 'save_model_path' and
        'diagnostics_path' are read here. It is also handed to
        prepare_dataset_torch and map_model.
    :param params: mapping; the keys 'batch_size', 'optimizer',
        'learning_rate' and 'loss' are read. Must be supplied even though
        the default is None.
    :param warm_start_NN: unused
    :param restore_old_checkpoint: unused
    :param workers: unused
    :param verbosity: 0 sets the logger to INFO, >= 1 sets it to DEBUG
    :param cross_fold_loaders: not supported; must be None
    :return: None
    :raises NotImplementedError: if cross_fold_loaders is not None
    """
    if verbosity == 0:
        logger.setLevel(logging.INFO)
    elif verbosity >= 1:
        logger.setLevel(logging.DEBUG)

    if cross_fold_loaders is not None:
        # Nothing below knows how to build loaders or a validation dataset
        # from this argument, so fail early with a clear message instead of
        # hitting an unbound local in the middle of training.
        raise NotImplementedError(
            'train_full does not support cross_fold_loaders')

    start = time.time()

    logger.info('Preparing Datasets')
    train_dataset, validation_dataset = prepare_dataset_torch(config)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=params['batch_size'], shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=params['batch_size'], shuffle=True)

    logger.info('Initializing Torch Network')

    net = map_model(config, params)

    logger.info('Optimizer Initialize')
    optimizer = map_optimizer(params['optimizer'], net.parameters(),
                              params['learning_rate'])
    loss_func = map_loss_func(params['loss'])

    logger.info('Start Training!')
    use_scheduler = config['scheduler']
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config['scheduler_milestones'], gamma=0.1)

    epochs = config['epochs']
    criterion = torch.nn.MSELoss()

    # Per-epoch mean losses, handed to plot_early_stopping at the end.
    avg_train_loss = []
    avg_valid_loss = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(verbose=True,
                                   trace_func=logger.info,
                                   path=config['save_model_path'])

    for epoch in range(epochs):

        # TRAINING
        net.train()
        # Two separate lists, created fresh every epoch, so the train and
        # validation averages never share samples.
        train_losses = []
        validation_losses = []

        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch['input'], batch['target']
            output = net(inputs)
            loss = loss_func(output, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        if use_scheduler:
            scheduler.step()

        # VALIDATION
        net.eval()
        # Largest per-batch RMSE seen during this epoch's validation pass.
        max_error = 0.0

        # No gradients are needed while evaluating.
        with torch.no_grad():
            for batch in test_loader:
                inputs, targets = batch['input'], batch['target']
                output = net(inputs)

                rmse = torch.sqrt(criterion(output, targets))
                max_error = max(rmse, max_error)

                validation_losses.append(loss_func(output, targets).item())

        train_loss = np.average(train_losses)
        validation_loss = np.average(validation_losses)

        avg_train_loss.append(train_loss)
        avg_valid_loss.append(validation_loss)

        early_stopping(validation_loss, net)

        logger.info(
            'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Validation RMSE: {:.5}'
            .format(epoch, train_loss, validation_loss, max_error))
        if early_stopping.early_stop:
            logger.info('Early Stopping')
            break

    net.load_state_dict(torch.load(config['save_model_path']))

    end = time.time()
    # Fixed-point format: '{:.2}' would print two significant digits and
    # switch to exponent notation for runs of 100 seconds or more.
    logger.info(
        'Training Completed: Time elapsed: {:.2f} Seconds'.format(end - start))
    save_path = config['diagnostics_path']
    plot_results(net, validation_dataset, criterion, save_path=save_path)
    plot_early_stopping(avg_train_loss,
                        avg_valid_loss,
                        save_path=save_path + '_loss')
Exemplo n.º 2
0
def train(config,
          params=None,
          warm_start_NN=None,
          restore_old_checkpoint=False,
          workers=1,
          verbosity=0,
          diagnostics=False):
    """
    ---------------------------------
    Trains a pytorch model epoch by epoch (via train_epoch) based on the
    configuration file, with early stopping. The EarlyStopping object is
    given ``config['save_model_path']`` as its path, which is presumably
    where the model gets saved.

    ---------------------------------
    :param config: see example config in repository; the keys 'scheduler',
        'scheduler_milestones' (only when 'scheduler' is truthy), 'epochs'
        and 'save_model_path' are read here, 'diagnostics_path' optionally
    :param params: hyperparameter mapping; 'batch_size', 'optimizer',
        'learning_rate' and 'loss' are read. Must be supplied even though
        the default is None.
    :param warm_start_NN: this needs to be a loaded pytorch module, so you
        can define your own model, and use that; when None the network is
        built with map_model(config, params)
    :param restore_old_checkpoint: TBD (unused)
    :param workers: TBD (unused)
    :param verbosity: logging; 0 sets INFO, >= 1 sets DEBUG
    :param diagnostics: when truthy, reload the state dict from
        ``config['save_model_path']`` and output the results figures
    :return: None, but maybe in future it returns the net
    """
    if verbosity == 0:
        logger.setLevel(logging.INFO)
    elif verbosity >= 1:
        logger.setLevel(logging.DEBUG)

    logger.info('Preparing Datasets')

    train_dataset, validation_dataset = prepare_dataset_torch(config)
    batch_size = params['batch_size']
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(validation_dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    logger.info('Initializing Torch Network')

    net = warm_start_NN if warm_start_NN is not None else map_model(
        config, params)

    logger.info('Optimizer Initialize')
    optimizer = map_optimizer(params['optimizer'], net.parameters(),
                              params['learning_rate'])
    loss_func = map_loss_func(params['loss'])

    logger.info('Start Training!')
    if config['scheduler']:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config['scheduler_milestones'], gamma=0.1)
    else:
        # train_epoch always receives a scheduler argument, so bind an
        # explicit None when scheduling is disabled.
        scheduler = None

    epochs = config['epochs']
    criterion = torch.nn.MSELoss()

    # Track the losses to determine early stopping
    avg_train_loss = []
    avg_valid_loss = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(verbose=True,
                                   trace_func=logger.info,
                                   path=config['save_model_path'])

    for epoch in range(epochs):
        train_loss, validation_loss, RMSE = train_epoch(
            net,
            optimizer,
            loss_func,
            train_loader=train_loader,
            test_loader=test_loader,
            scheduler=scheduler,
            criterion=criterion)
        early_stopping(validation_loss, net, RMSE)

        avg_train_loss.append(train_loss)
        avg_valid_loss.append(validation_loss)

        # From here on report the RMSE tracked by the early-stopping object
        # rather than this epoch's own value.
        RMSE = early_stopping.RMSE
        logger.info(
            'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Best Validation RMSE: {:.5}'
            .format(epoch, train_loss, validation_loss, RMSE))
        print(
            'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Validation RMSE: {:.5}'
            .format(epoch, train_loss, validation_loss, RMSE))
        if early_stopping.early_stop:
            logger.info('Early Stopping')
            break

    if diagnostics:
        net.load_state_dict(torch.load(config['save_model_path']))
        try:
            save_path = config['diagnostics_path']
        except KeyError:
            # Actually emit the message; constructing a Warning object
            # without raising or logging it has no effect.
            logger.warning('No Path to Save Diagnostics, saving to root dir')
            save_path = 'trial_run'

        plot_results(net,
                     validation_dataset,
                     criterion,
                     save_path=save_path,
                     config=config)
        plot_early_stopping(avg_train_loss,
                            avg_valid_loss,
                            save_path=save_path + '_loss')
Exemplo n.º 3
0
def main(config, params=None, verbosity=0):
    """
    Run a 5-fold cross validation of the network selected by
    ``config['nn_type']`` and print the per-fold RMSE, its average and its
    standard deviation.

    For every fold a fresh network, optimizer and EarlyStopping object are
    created, the network is trained on the fold's training indices and
    validated on the fold's held-out indices (both drawn from the training
    dataset), and the RMSE stored on the EarlyStopping object is recorded.

    :param config: mapping; the keys 'hyperparameters', 'nn_type',
        'scheduler', 'scheduler_milestones' (only when 'scheduler' is
        truthy) and 'epochs' are read here. It is also handed to
        prepare_dataset_torch, the network constructors and plot_results.
    :param params: ignored; always rebuilt from config['hyperparameters']
    :param verbosity: 0 sets the logger to INFO, >= 1 sets it to DEBUG
    :return: None
    :raises KeyError: if config['nn_type'] is not one of 'SimpleNet',
        'ComplexCross', 'SimpleCross'
    """
    if verbosity == 0:
        logger.setLevel(logging.INFO)
    elif verbosity >= 1:
        logger.setLevel(logging.DEBUG)

    params = map_config_params(config['hyperparameters'])
    logger.info('Preparing Datasets')

    # Split dataset into K-Fold
    # Train on each fold
    # make predictions on test set
    # calculate RMSE and store in list
    # calc std. dev in the RMSE for predictions

    kfold = KFold(n_splits=5, shuffle=True)

    train_dataset, validation_dataset = prepare_dataset_torch(config)
    results = {}
    for fold, (train_ids, test_ids) in enumerate(kfold.split(train_dataset)):

        print(f'Fold {fold}')
        print('-----------------------')
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Both loaders draw from train_dataset; the samplers select the
        # fold's training and held-out indices respectively.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64, sampler=train_subsampler)
        test_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64, sampler=test_subsampler)

        logger.info('Initializing Torch Network')

        nn_type = config.get('nn_type')
        if nn_type == 'SimpleNet':
            net = SimpleNet2(config, params)
        elif nn_type == 'ComplexCross':
            net = PedDeepCross(config, params)
        elif nn_type == 'SimpleCross':
            net = Simple_Cross2(config, params)
        else:
            raise KeyError('NN Type does not yet exist... please choose from SimpleNet, ComplexCross, SimpleCross')

        logger.info('Optimizer Initialize')
        optimizer = map_optimizer(params['optimizer'], net.parameters(), params['learning_rate'])
        loss_func = map_loss_func(params['loss'])

        logger.info('Start Training!')
        use_scheduler = config['scheduler']
        if use_scheduler:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config['scheduler_milestones'],
                                                             gamma=0.1)
        early_stopping = EarlyStopping(verbose=True, trace_func=logger.info)

        epochs = config['epochs']
        criterion = torch.nn.MSELoss()

        for epoch in range(epochs):

            # Two separate lists, created fresh every epoch, so the train
            # and validation averages never share samples.
            train_losses = []
            validation_losses = []

            # Training
            net.train()

            for batch in train_loader:
                optimizer.zero_grad()
                inputs, targets = batch['input'], batch['target']
                output = net(inputs)
                loss = loss_func(output, targets)
                loss.backward()
                optimizer.step()

                train_losses.append(loss.item())

            if use_scheduler:
                scheduler.step()

            # Validation
            net.eval()
            # Largest per-batch RMSE seen during this epoch's validation pass.
            max_error = 0.0

            # No gradients are needed while evaluating.
            with torch.no_grad():
                for batch in test_loader:
                    inputs, targets = batch['input'], batch['target']
                    output = net(inputs)

                    rmse = torch.sqrt(criterion(output, targets))
                    max_error = max(rmse, max_error)

                    validation_losses.append(loss_func(output, targets).item())

            train_loss = np.average(train_losses)
            validation_loss = np.average(validation_losses)

            early_stopping(validation_loss, net, max_error)

            message = 'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Max Validation RMSE: {:.5}'.format(
                epoch, train_loss, validation_loss, max_error)
            logger.info(message)
            print(message)

            if early_stopping.early_stop:
                logger.info('Early Stopping')
                break

        # NOTE(review): machine-specific absolute path; should probably come from config.
        save_path = '/home/adam/Uni_Sache/Bachelors/Thesis/next_phase/NNI_meditations/density_predictions/final_cuts/SimpleNet/' + str(fold) + '_crossvalid'
        plot_results(net=net, dataset=validation_dataset, criterion=criterion, save_path=save_path, config=config)
        results[fold] = early_stopping.RMSE

    print('K_FOLD CROSS VAL Ssores for 5 Folds')
    print('-------------------------')
    # Named 'total' so the builtin sum() is not shadowed.
    total = 0.0
    values = []
    for key, value in results.items():
        print(f'Fold {key}: {value} ')
        total += value
        values.append(value.detach().numpy())
    print(f'Average: {total / len(results)} ')
    print(f'Std Dev: {np.std(values)}')
Exemplo n.º 4
0
def main2(config, params=None, verbosity=0):
    """
    Run a 5-fold cross validation of the network built by map_model and
    print the per-fold RMSE, its average and its standard deviation.

    For every fold a fresh network, optimizer and EarlyStopping object
    (patience=35) are created and train_epoch is called once per epoch.
    After the fold, results and loss curves are passed to plot_results and
    plot_early_stopping, and the RMSE stored on the EarlyStopping object is
    recorded.

    :param config: mapping; the keys 'hyperparameters', 'scheduler',
        'scheduler_milestones' (only when 'scheduler' is truthy) and
        'epochs' are read here. It is also handed to prepare_dataset_torch,
        map_model and plot_results.
    :param params: ignored; always rebuilt from config['hyperparameters']
    :param verbosity: 0 sets the logger to INFO, >= 1 sets it to DEBUG
    :return: None
    """
    if verbosity == 0:
        logger.setLevel(logging.INFO)
    elif verbosity >= 1:
        logger.setLevel(logging.DEBUG)

    params = map_config_params(config['hyperparameters'])
    logger.info('Preparing Datasets')

    # Split dataset into K-Fold
    # Train on each fold
    # make predictions on test set
    # calculate RMSE and store in list
    # calc std. dev in the RMSE for predictions

    kfold = KFold(n_splits=5, shuffle=True)

    train_dataset, validation_dataset = prepare_dataset_torch(config)
    results = {}
    for fold, (train_ids, test_ids) in enumerate(kfold.split(train_dataset)):

        print(f'Fold {fold}')
        print('-----------------------')
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Both loaders draw from train_dataset; the samplers select the
        # fold's training and held-out indices respectively.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64, sampler=train_subsampler)
        test_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64, sampler=test_subsampler)

        logger.info('Initializing Torch Network')

        net = map_model(config, params)

        logger.info('Optimizer Initialize')
        optimizer = map_optimizer(params['optimizer'], net.parameters(), params['learning_rate'])
        loss_func = map_loss_func(params['loss'])

        logger.info('Start Training!')
        if config['scheduler']:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config['scheduler_milestones'],
                                                             gamma=0.1)
        else:
            # train_epoch always receives a scheduler argument, so bind an
            # explicit None when scheduling is disabled.
            scheduler = None

        early_stopping = EarlyStopping(verbose=True, trace_func=logger.info, patience=35)

        epochs = config['epochs']
        criterion = torch.nn.MSELoss()

        # Track the losses to determine early stopping
        avg_train_loss = []
        avg_valid_loss = []

        for epoch in range(epochs):

            train_loss, validation_loss, RMSE = train_epoch(net, optimizer, loss_func, train_loader=train_loader,
                                                            test_loader=test_loader, scheduler=scheduler,
                                                            criterion=criterion)

            early_stopping(validation_loss, net, RMSE)

            message = 'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Max Validation RMSE: {:.5}'.format(
                epoch, train_loss, validation_loss, RMSE)
            logger.info(message)
            print(message)

            avg_train_loss.append(train_loss)
            avg_valid_loss.append(validation_loss)

            if early_stopping.early_stop:
                logger.info('Early Stopping')
                break

        # NOTE(review): machine-specific absolute path; should probably come from config.
        save_path = '/home/adam/Uni_Sache/Bachelors/Thesis/next_phase/NNI_meditations/density_predictions/final_cuts/ComplexCross/' + str(fold) + '_crossvalid'
        plot_results(net, dataset=validation_dataset, criterion=criterion, save_path=save_path, config=config)
        plot_early_stopping(avg_train_loss, avg_valid_loss, save_path=save_path + str(fold)  + '_loss')
        results[fold] = early_stopping.RMSE

    print('K_FOLD CROSS VAL Ssores for 5 Folds')
    print('-------------------------')
    # Named 'total' so the builtin sum() is not shadowed.
    total = 0.0
    values = []
    for key, value in results.items():
        print(f'Fold {key}: {value} ')
        total += value
        values.append(value.detach().numpy())
    print(f'Average: {total / len(results)} ')
    print(f'Std Dev: {np.std(values)}')
def train_search(config,
                 params=None,
                 warm_start_NN=None,
                 restore_old_checkpoint=False,
                 workers=1,
                 verbosity=0):
    """
    train_search is practically the same as the train function from training_torch, just made for NNI experiments

    After every epoch ``-log10(RMSE)`` of that epoch is reported to NNI as
    an intermediate result; after training ``-log10`` of the RMSE stored on
    the EarlyStopping object is reported as the final result.

    :param config: mapping; the keys 'scheduler', 'scheduler_milestones'
        (only when 'scheduler' is truthy) and 'epochs' are read here. It is
        also handed to prepare_dataset_torch and map_model. 'epochs' must be
        at least 1, because the final report uses a value that is only set
        inside the epoch loop.
    :param params: hyperparameter mapping; 'batch_size', 'optimizer',
        'learning_rate' and 'loss' are read. Must be supplied even though
        the default is None.
    :param warm_start_NN: unused
    :param restore_old_checkpoint: unused
    :param workers: unused
    :param verbosity: 0 sets the logger to INFO, >= 1 sets it to DEBUG
    :return: None
    """
    if verbosity == 0:
        logger.setLevel(logging.INFO)
    elif verbosity >= 1:
        logger.setLevel(logging.DEBUG)
    start = time.time()

    logger.info('Preparing Datasets')

    train_dataset, validation_dataset = prepare_dataset_torch(config)
    batch_size = params['batch_size']
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(validation_dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    logger.info('Initializing Torch Network')

    net = map_model(config, params)

    logger.info('Optimizer Initialize')
    optimizer = map_optimizer(params['optimizer'], net.parameters(),
                              params['learning_rate'])
    loss_func = map_loss_func(params['loss'])
    criterion = torch.nn.MSELoss()

    if config['scheduler']:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config['scheduler_milestones'], gamma=0.1)
    else:
        scheduler = None

    epochs = config['epochs']

    # Track the losses to determine early stopping
    avg_train_loss = []
    avg_valid_loss = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(verbose=True, trace_func=logger.info)

    logger.info('Start Training!')
    for epoch in range(epochs):

        train_loss, validation_loss, RMSE = train_epoch(
            net,
            optimizer,
            loss_func,
            train_loader=train_loader,
            test_loader=test_loader,
            scheduler=scheduler,
            criterion=criterion)

        # The intermediate report uses this epoch's own RMSE ...
        nni.report_intermediate_result(-math.log10(RMSE))
        early_stopping(validation_loss, net, RMSE)
        # ... while logging and the final report use the value tracked by
        # the early-stopping object.
        RMSE = early_stopping.RMSE

        avg_train_loss.append(train_loss)
        avg_valid_loss.append(validation_loss)

        logger.info(
            'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Best Validation RMSE: {:.5}'
            .format(epoch, train_loss, validation_loss, RMSE))
        print(
            'Epoch {}; Train Loss: {:.5}; Valid Loss: {:.5}; Validation RMSE: {:.5}'
            .format(epoch, train_loss, validation_loss, RMSE))
        if early_stopping.early_stop:
            logger.info('Early Stopping')
            break

    nni.report_final_result(-math.log10(RMSE))
    end = time.time()
    # Fixed-point format: '{:.2}' would print two significant digits and
    # switch to exponent notation for runs of 100 seconds or more.
    logger.info(
        'Training Completed: Time elapsed: {:.2f} Seconds'.format(end - start))
    plot_against_scaling(net,
                         validation_dataset,
                         criterion,
                         trial_id=str(nni.get_trial_id()),
                         exp_id=str(nni.get_experiment_id()))