Example #1
    def __init__(self, modelPath, maxSeqLen=128):

        device = torch.device('cpu')
        if torch.cuda.is_available():
            device = torch.device('cuda')

        self.maxSeqLen = maxSeqLen
        self.modelPath = modelPath
        assert os.path.exists(
            self.modelPath), "saved model not present at {}".format(
                self.modelPath)

        loadedDict = torch.load(self.modelPath, map_location=device)
        self.taskParams = loadedDict['task_params']
        logger.info('Task Params loaded from saved model.')

        modelName = self.taskParams.modelType.name.lower()
        _, _, tokenizerClass, defaultName = NLP_MODELS[modelName]
        configName = self.taskParams.modelConfig
        if configName is None:
            configName = defaultName
        #making tokenizer for model
        self.tokenizer = tokenizerClass.from_pretrained(configName)
        logger.info('{} model tokenizer loaded for config {}'.format(
            modelName, configName))

        allParams = {}
        allParams['task_params'] = self.taskParams
        allParams['gpu'] = torch.cuda.is_available()
        # dummy values
        allParams['num_train_steps'] = 10
        allParams['warmup_steps'] = 0
        allParams['learning_rate'] = 2e-5
        allParams['epsilon'] = 1e-8

        #making and loading model
        self.model = multiTaskModel(allParams)
        self.model.load_multi_task_model(loadedDict)
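
A minimal standalone sketch of the device-selection and checkpoint-loading pattern used in this constructor (the helper name and return values are illustrative, not part of the original code; the checkpoint is assumed to be a dict carrying a 'task_params' entry, as above):

import os

import torch


def load_checkpoint(model_path):
    # pick GPU when available, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    assert os.path.exists(model_path), "saved model not present at {}".format(model_path)
    # map_location makes a GPU-trained checkpoint loadable on a CPU-only machine
    loaded_dict = torch.load(model_path, map_location=device)
    task_params = loaded_dict['task_params']
    return loaded_dict, task_params, device
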
Example #2
def main():
    allParams = vars(args)
    logger.info('ARGS : {}'.format(allParams))
    # load the saved checkpoint if load_saved_model is given
    if args.load_saved_model is not None:
        assert os.path.exists(
            args.load_saved_model), "saved model not present at {}".format(
                args.load_saved_model)
        loadedDict = torch.load(args.load_saved_model, map_location=device)
        logger.info('Saved Model loaded from {}'.format(args.load_saved_model))

        if args.finetune is True:
            '''
            NOTE :-
            In finetune mode, only the weights of the shared (pre-trained) encoder are loaded
            from the saved model. The task-specific headers on top of it are built from the
            task file, and the entire model can then be fine-tuned further. Freezing the
            pre-trained model is also possible with the freeze_shared_model argument.
            '''
            logger.info(
                'In Finetune model. Only shared Encoder weights will be loaded from {}'
                .format(args.load_saved_model))
            logger.info(
                'Task specific headers will be made according to task file')
            taskParams = TasksParam(args.task_file)

        else:
            '''
            NOTE :-
            The task params used while training this saved model are stored inside the
            checkpoint. The saved task params are used here to avoid any discrepancies
            with the task file; hence, changes made to the task file after saving this
            model are ignored (except for the data file paths, which are read below).
            '''
            taskParams = loadedDict['task_params']
            logger.info('Task Params loaded from saved model.')
            logger.info('Any changes made to the task file after saving this model, '
                        'except the data file paths, shall be ignored')
            tempTaskParams = TasksParam(args.task_file)
            # transferring the data file paths from the new task file to the loaded task params
            for taskId, taskName in taskParams.taskIdNameMap.items():
                assert taskName in tempTaskParams.taskIdNameMap.values(), \
                    "task names changed in the task file given. Tasks supported for the " \
                    "loaded model are {}".format(list(taskParams.taskIdNameMap.values()))

                taskParams.fileNamesMap[
                    taskName] = tempTaskParams.fileNamesMap[taskName]
    else:
        taskParams = TasksParam(args.task_file)
        logger.info("Task params object created from task file...")

    allParams['task_params'] = taskParams
    allParams['gpu'] = torch.cuda.is_available()
    logger.info('task parameters:\n {}'.format(taskParams.taskDetails))

    tensorboard = SummaryWriter(log_dir=os.path.join(logDir, 'tb_logs'))
    logger.info("Tensorboard writing at {}".format(
        os.path.join(logDir, 'tb_logs')))

    # making handlers for train
    logger.info("Creating data handlers for training...")
    allDataTrain, BatchSamplerTrain, multiTaskDataLoaderTrain = make_data_handlers(
        taskParams, "train", isTrain=True, gpu=allParams['gpu'])
    # if evaluation on dev set is required during training. Labels are required
    # It will occur at the end of each epoch
    if args.eval_while_train:
        logger.info("Creating data handlers for dev...")
        allDataDev, BatchSamplerDev, multiTaskDataLoaderDev = make_data_handlers(
            taskParams, "dev", isTrain=False, gpu=allParams['gpu'])
    # if evaluation on test set is required during training. Labels are required
    # It will occur at the end of each epoch
    if args.test_while_train:
        logger.info("Creating data handlers for test...")
        allDataTest, BatchSamplerTest, multiTaskDataLoaderTest = make_data_handlers(
            taskParams, "test", isTrain=False, gpu=allParams['gpu'])
    # computing the total number of training update steps
    allParams['num_train_steps'] = math.ceil(
        len(multiTaskDataLoaderTrain) /
        args.train_batch_size) * args.epochs // args.grad_accumulation_steps
    allParams['warmup_steps'] = args.num_of_warmup_steps
    logger.info("NUM TRAIN STEPS: {}".format(allParams['num_train_steps']))
    logger.info("len of dataloader: {}".format(len(multiTaskDataLoaderTrain)))
    logger.info("Making multi-task model...")
    model = multiTaskModel(allParams)
    #logger.info('################ Network ###################')
    #logger.info('\n{}\n'.format(model.network))

    if args.load_saved_model:
        if args.finetune is True:
            model.load_shared_model(loadedDict, args.freeze_shared_model)
            logger.info('shared model loaded for finetune from {}'.format(
                args.load_saved_model))
        else:
            model.load_multi_task_model(loadedDict)
            logger.info(
                'saved model loaded with global step {} from {}'.format(
                    model.globalStep, args.load_saved_model))
        if args.resume_train:
            logger.info(
                "Resuming training from global step {}. Steps before it will be skipped"
                .format(model.globalStep))

    # training
    resCnt = 0
    for epoch in range(args.epochs):
        logger.info(
            '\n####################### EPOCH {} ###################\n'.format(
                epoch))
        totalEpochLoss = 0
        text = "Epoch: {}".format(epoch)
        tt = int(allParams['num_train_steps'] * args.grad_accumulation_steps /
                 args.epochs)
        with tqdm(total=tt, position=epoch, desc=text) as progress:
            for i, (batchMetaData,
                    batchData) in enumerate(multiTaskDataLoaderTrain):
                batchMetaData, batchData = BatchSamplerTrain.patch_data(
                    batchMetaData, batchData, gpu=allParams['gpu'])
                if args.resume_train and args.load_saved_model and resCnt * args.grad_accumulation_steps < model.globalStep:
                    '''
                    NOTE :-
                    The resume function is only to be used in case the training process couldn't
                    complete, or you wish to extend training for some more epochs.
                    Please keep the gradient accumulation steps the same for exact resuming.
                    '''
                    resCnt += 1
                    progress.update(1)
                    continue
                model.update_step(batchMetaData, batchData)
                totalEpochLoss += model.taskLoss.item()

                if model.globalStep % args.log_per_updates == 0 and (
                        model.accumulatedStep + 1
                        == args.grad_accumulation_steps):
                    taskId = batchMetaData['task_id']
                    taskName = taskParams.taskIdNameMap[taskId]
                    #avgLoss = totalEpochLoss / ((i+1)*args.train_batch_size)
                    avgLoss = totalEpochLoss / (i + 1)
                    logger.info(
                        'Steps: {} Task: {} Avg.Loss: {} Task Loss: {}'.format(
                            model.globalStep, taskName, avgLoss,
                            model.taskLoss.item()))

                    tensorboard.add_scalar('train/avg_loss',
                                           avgLoss,
                                           global_step=model.globalStep)
                    tensorboard.add_scalar('train/{}_loss'.format(taskName),
                                           model.taskLoss.item(),
                                           global_step=model.globalStep)

                if args.save_per_updates > 0 and (
                    (model.globalStep + 1) % args.save_per_updates) == 0 and (
                        model.accumulatedStep + 1
                        == args.grad_accumulation_steps):
                    savePath = os.path.join(
                        args.out_dir, 'multi_task_model_{}_{}.pt'.format(
                            epoch, model.globalStep))
                    model.save_multi_task_model(savePath)

                    # limit the number of saved checkpoints: remove the oldest ones beyond `limit_save`
                    if args.limit_save > 0:
                        # global step -> checkpoint file name (the global step is the last
                        # underscore-separated field of the file name)
                        stepCkpMap = {
                            int(os.path.splitext(ckp)[0].split('_')[-1]): ckp
                            for ckp in os.listdir(args.out_dir)
                            if ckp.endswith('.pt')
                        }

                        # sort by global step; delete all but the newest `limit_save` checkpoints
                        stepToDel = sorted(list(
                            stepCkpMap.keys()))[:-args.limit_save]

                        for ckpStep in stepToDel:
                            os.remove(
                                os.path.join(args.out_dir,
                                             stepCkpMap[ckpStep]))
                            logger.info('Removing checkpoint {}'.format(
                                stepCkpMap[ckpStep]))

                progress.update(1)

            # saving model after epoch, unless the whole epoch was skipped while resuming
            if not (args.resume_train and args.load_saved_model
                    and resCnt * args.grad_accumulation_steps < model.globalStep):
                savePath = os.path.join(
                    args.out_dir,
                    'multi_task_model_{}_{}.pt'.format(epoch, model.globalStep))
                model.save_multi_task_model(savePath)

            if args.eval_while_train:
                logger.info("\nRunning Evaluation on dev...")
                with torch.no_grad():
                    evaluate(allDataDev,
                             BatchSamplerDev,
                             multiTaskDataLoaderDev,
                             taskParams,
                             model,
                             gpu=allParams['gpu'],
                             evalBatchSize=args.eval_batch_size,
                             hasTrueLabels=True,
                             needMetrics=True)

            if args.test_while_train:
                logger.info("\nRunning Evaluation on test...")
                wrtPredpath = "test_predictions_{}.tsv".format(epoch)
                with torch.no_grad():
                    evaluate(allDataTest,
                             BatchSamplerTest,
                             multiTaskDataLoaderTest,
                             taskParams,
                             model,
                             gpu=allParams['gpu'],
                             evalBatchSize=args.eval_batch_size,
                             needMetrics=True,
                             hasTrueLabels=True,
                             wrtDir=args.out_dir,
                             wrtPredPath=wrtPredpath)
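
The limit_save block in the training loop above prunes old checkpoints by parsing the global step out of each file name. A small self-contained sketch of that pruning logic (the helper name is illustrative; the multi_task_model_{epoch}_{globalStep}.pt naming convention is taken from the snippet):

import os


def prune_checkpoints(out_dir, limit_save):
    # keep only the `limit_save` checkpoints with the highest global step
    if limit_save <= 0:
        return
    # global step -> checkpoint file name; the step is the last underscore-separated field
    step_to_ckp = {
        int(os.path.splitext(ckp)[0].split('_')[-1]): ckp
        for ckp in os.listdir(out_dir)
        if ckp.endswith('.pt')
    }
    # sort ascending by step and drop everything except the newest `limit_save`
    for step in sorted(step_to_ckp)[:-limit_save]:
        os.remove(os.path.join(out_dir, step_to_ckp[step]))
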
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred_file_path', type=str, required=True,
                        help="path to the tsv file on which predictions are to be made")
    parser.add_argument('--out_dir', type=str, required=True,
                        help="path to save the predictions")
    parser.add_argument('--has_labels', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help="whether true labels are present in the prediction file")
    parser.add_argument('--task_name', type=str, required=True,
                        help="task name for which prediction is required")
    parser.add_argument('--saved_model_path', type=str, required=True,
                        help="path to the trained model to load")
    parser.add_argument('--eval_batch_size', type=int, default=32,
                        help="batch size for prediction")
    parser.add_argument('--max_seq_len', type=int,
                        help="max seq len used during training of the model")
    parser.add_argument('--seed', type=int, default=42,
                        help="seed")
    args = parser.parse_args()

    allParams = vars(args)
    assert os.path.exists(args.saved_model_path), "saved model not present at {}".format(args.saved_model_path)
    assert os.path.exists(args.pred_file_path), "prediction tsv file not present at {}".format(args.pred_file_path)
    # `device` comes from module level in the original script; defined here for completeness
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loadedDict = torch.load(args.saved_model_path, map_location=device)
    taskParamsModel = loadedDict['task_params']
    logger.info('Task Params loaded from saved model.')

    assert args.task_name in taskParamsModel.taskIdNameMap.values(), "task Name not in task names for loaded model"
    
    taskId = [taskId for taskId, taskName in taskParamsModel.taskIdNameMap.items() if taskName==args.task_name][0]
    taskType = taskParamsModel.taskTypeMap[args.task_name]

    # preparing data from tsv file
    rows = load_data(args.pred_file_path, taskType, hasLabels = args.has_labels)

    modelName = taskParamsModel.modelType.name.lower()
    _, _ , tokenizerClass, defaultName = NLP_MODELS[modelName]
    configName = taskParamsModel.modelConfig
    if configName is None:
        configName = defaultName
    
    #making tokenizer for model
    tokenizer = tokenizerClass.from_pretrained(configName)
    logger.info('{} model tokenizer loaded for config {}'.format(modelName, configName))
    
    dataPath = os.path.join(args.out_dir, '{}_prediction_data'.format(configName))
    if not os.path.exists(dataPath):
        os.makedirs(dataPath)
    wrtFile = os.path.join(dataPath, '{}.json'.format(args.pred_file_path.split('/')[-1].split('.')[0]))
    print('Processing Started...')
    create_data_multithreaded(rows, wrtFile, tokenizer, taskParamsModel, args.task_name,
                            args.max_seq_len, multithreaded = True)
    print('Data Processing done for {}. File saved at {}'.format(args.task_name, wrtFile))

    allTaskslist = [ 
        {"data_task_id" : int(taskId),
         "data_path" : wrtFile,
         "data_task_type" : taskType,
         "data_task_name" : args.task_name}
        ]
    allData = allTasksDataset(allTaskslist)
    batchSampler = Batcher(allData, batchSize=args.eval_batch_size, seed = args.seed)
    batchSamplerUtils = batchUtils(isTrain = False, modelType= taskParamsModel.modelType,
                                  maxSeqLen = args.max_seq_len)
    inferDataLoader = DataLoader(allData, batch_sampler=batchSampler,
                                collate_fn=batchSamplerUtils.collate_fn,
                                pin_memory=torch.cuda.is_available())

    allParams['task_params'] = taskParamsModel
    allParams['gpu'] = torch.cuda.is_available()
    # dummy values
    allParams['num_train_steps'] = 10
    allParams['warmup_steps'] = 0
    allParams['learning_rate'] = 2e-5
    allParams['epsilon'] = 1e-8

    #making and loading model
    model = multiTaskModel(allParams)
    model.load_multi_task_model(loadedDict)

    with torch.no_grad():
        wrtPredFile = 'predictions.tsv'
        evaluate(allData, batchSampler, inferDataLoader, taskParamsModel,
                model, gpu=allParams['gpu'], evalBatchSize=args.eval_batch_size, needMetrics=False, hasTrueLabels=False,
                wrtDir=args.out_dir, wrtPredPath=wrtPredFile)
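
The prediction script above resolves the requested task against the task params stored in the checkpoint before building the data handlers. A minimal sketch of that lookup as a reusable helper (the helper itself is illustrative; taskIdNameMap and taskTypeMap are the attributes the snippet relies on):

def resolve_task(task_params, task_name):
    # invert the id -> name map so the task can be looked up by name
    name_to_id = {name: task_id for task_id, name in task_params.taskIdNameMap.items()}
    if task_name not in name_to_id:
        raise ValueError("task '{}' not among tasks of the loaded model: {}".format(
            task_name, list(name_to_id)))
    # return the task id together with its type (passed to load_data in the snippet above)
    return name_to_id[task_name], task_params.taskTypeMap[task_name]
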