Example #1
def train_and_evaluate(model: nn.Module, train_loader: DataLoader,
                       test_loader: DataLoader, optimizer: optim.Optimizer,
                       args) -> None:
    '''Train the model and evaluate on the test set after every epoch,
    tracking the best test R2 seen so far (lower is treated as better here).'''
    logger.info('begin training and evaluation')
    best_test_R2 = float('inf')
    train_len = len(train_loader)
    loss_summary = np.zeros((train_len * args.num_epochs))
    R2_summary = np.zeros(args.num_epochs)
    for epoch in range(args.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, args.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, train_loader, test_loader, args, epoch)
        test_metrics = evaluate(model, test_loader, args, epoch)
        R2_summary[epoch] = test_metrics['R2']
        is_best = R2_summary[epoch] <= best_test_R2

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            epoch=epoch,
            is_best=is_best,
            checkpoint=args.model_dir)

        if is_best:
            logger.info('- Found new best R2')
            best_test_R2 = R2_summary[epoch]
            best_json_path = os.path.join(args.model_dir,
                                          'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best R2 is: %.5f' % best_test_R2)

        utils.plot_all_epoch(R2_summary[:epoch + 1], args.dataset + '_R2',
                             args.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len],
                             args.dataset + '_loss', args.plot_dir)

        last_json_path = os.path.join(args.model_dir,
                                      'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)
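
Example #1 expects an argparse-style `args` object plus project-local `train`, `evaluate`, `utils`, and `logger` helpers. The driver below is only a hypothetical sketch of how it might be wired up; the placeholder model, toy loaders, and attribute names are assumptions, not part of the original project.

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Placeholder hyperparameters mirroring the attributes Example #1 reads from `args`.
args = argparse.Namespace(
    num_epochs=2,
    model_dir='experiments/base_model',
    plot_dir='experiments/base_model/figures',
    dataset='elect',
)

# Toy tensors stand in for the real train/test datasets.
x, y = torch.randn(256, 10), torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32)
test_loader = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

model = nn.Linear(10, 1)  # placeholder for the project's real nn.Module
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_and_evaluate(model, train_loader, test_loader, optimizer, args)
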
Example #2
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim.Optimizer,
                       loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch.
    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional - name of the checkpoint file to restore from (without the .pth.tar extension)
    '''
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir,
                                    restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
    logger.info('begin training and evaluation')
    best_test_ND = float('inf')
    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
    for epoch in range(params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, loss_fn, train_loader, test_loader, params,
            epoch)
        test_metrics = evaluate(model,
                                loss_fn,
                                test_loader,
                                params,
                                epoch,
                                sample=args.sampling)
        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            epoch=epoch,
            is_best=is_best,
            checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir,
                                          'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND',
                             params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len],
                             args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir,
                                      'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)

    if args.save_best:
        list_of_params = args.search_params.split(',')
        print_params = ''
        for param in list_of_params:
            param_value = getattr(params, param)
            print_params += f'{param}: {param_value:.2f} '
        print_params = print_params.strip()
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        utils.plot_all_epoch(ND_summary,
                             print_params + '_ND',
                             location=params.plot_dir)
        utils.plot_all_epoch(loss_summary,
                             print_params + '_loss',
                             location=params.plot_dir)
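
Examples #2 and #3 both rely on `utils.save_checkpoint` and `utils.load_checkpoint`. The helpers below are a minimal sketch of what such functions could look like, inferred only from the keyword arguments used above (`epoch`, `is_best`, `checkpoint`) and the `.pth.tar` naming; they are assumptions, not the repository's actual `utils` module.

import os
import shutil

import torch


def save_checkpoint(state, epoch, is_best, checkpoint):
    '''Save `state` as checkpoint/epoch_<n>.pth.tar and copy it to best.pth.tar when is_best.'''
    os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, f'epoch_{epoch}.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))


def load_checkpoint(restore_path, model, optimizer=None):
    '''Load model (and optionally optimizer) state from a .pth.tar checkpoint.'''
    if not os.path.exists(restore_path):
        raise FileNotFoundError(f'Checkpoint not found: {restore_path}')
    state = torch.load(restore_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optim_dict'])
    return state
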
Example #3
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim.Optimizer, loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch.
    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional - name of the checkpoint file to restore from (without the .pth.tar extension)
    '''
    # reload weights from restore_file if specified
    restore_epoch = 0
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
        restore_epoch = int(restore_file[-2:].replace('_', '')) + 1  # parse the epoch number from the checkpoint file name
    logger.info('Restoring epoch: {}'.format(restore_epoch))
    logger.info('Begin training and evaluation')
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=25, verbose=True, delta=0.0001, folder=params.model_dir)
    
    if os.path.exists(os.path.join(params.model_dir, 'metrics_test_best_weights.json')):
        with open(os.path.join(params.model_dir, 'metrics_test_best_weights.json')) as json_file:
            best_test_ND = json.load(json_file)['ND']
            early_stopping.best_score = best_test_ND
    else:
        best_test_ND = float('inf')
        early_stopping.best_score = best_test_ND
    
    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
    
    for epoch in range(restore_epoch, params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(model, optimizer, loss_fn, train_loader,
                                                                        test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch, sample=args.sampling)
        # (optional) guard against a NaN metric before recording it, e.g.:
        # if np.isnan(test_metrics['ND']):
        #     test_metrics['ND'] = 1000

        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir, 'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)
        
        
        # early_stopping tracks the validation metric (ND here) to check whether it
        # has decreased; if it has, it saves a checkpoint of the current model
        logger.info('ND : %.5f ' % test_metrics['ND'])
        early_stopping(test_metrics['ND'], model)
        
        if early_stopping.early_stop:
            logger.info('Early stopping')
            break
        
#     # load the last checkpoint with the best model
#     model.load_state_dict(torch.load('checkpoint.pt'))

    if args.save_best:
        list_of_params = args.search_params.split(',')
        print_params = ''
        for param in list_of_params:
            param_value = getattr(params, param)
            print_params += f'{param}: {param_value:.2f} '
        print_params = print_params.strip()
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        utils.plot_all_epoch(ND_summary, print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary, print_params + '_loss', location=params.plot_dir)
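
Example #3 also constructs an `EarlyStopping` object with `patience`, `verbose`, `delta`, and `folder` arguments, sets `best_score` directly, and calls it with the current ND and the model. The class below is a minimal sketch compatible with that usage (lower ND is better); it is an assumption, not the helper actually imported by the repository.

import os

import numpy as np
import torch


class EarlyStopping:
    '''Stop training when the monitored metric has not improved for `patience` epochs.'''

    def __init__(self, patience=25, verbose=False, delta=0.0, folder='.'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.folder = folder
        self.counter = 0
        self.best_score = np.inf   # lower-is-better metric (e.g. ND)
        self.early_stop = False

    def __call__(self, metric, model):
        if metric < self.best_score - self.delta:
            # Improvement: remember the new best value, reset the counter, checkpoint the model.
            self.best_score = metric
            self.counter = 0
            torch.save(model.state_dict(), os.path.join(self.folder, 'checkpoint.pt'))
            if self.verbose:
                print(f'Metric improved to {metric:.5f}; checkpoint saved.')
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
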