Example No. 1
def main():
    allParams = vars(args)
    logger.info('ARGS : {}'.format(allParams))
    # loading if load_saved_model
    if args.load_saved_model is not None:
        assert os.path.exists(
            args.load_saved_model), "saved model not present at {}".format(
                args.load_saved_model)
        loadedDict = torch.load(args.load_saved_model, map_location=device)
        logger.info('Saved Model loaded from {}'.format(args.load_saved_model))

        if args.finetune:
            '''
            NOTE :-
            In finetune mode, only the weights of the shared (pre-trained) encoder from the saved model will be used.
            The task-specific headers over the model will be created from the task file, and you can then finetune
            the entire model. Freezing the pre-trained encoder is also possible via the freeze_shared_model argument.
            '''
            logger.info(
                'In finetune mode. Only shared encoder weights will be loaded from {}'
                .format(args.load_saved_model))
            logger.info(
                'Task specific headers will be made according to task file')
            taskParams = TasksParam(args.task_file)

        else:
            '''
            NOTE :-
            The taskParams used with this saved model must also have been stored in it. THE SAVED TASK PARAMS
            ARE USED HERE TO AVOID ANY DISCREPANCIES/CHANGES IN THE TASK FILE.
            Hence, any changes made to the task file after saving this model will be ignored.
            '''
            taskParams = loadedDict['task_params']
            logger.info('Task Params loaded from saved model.')
            logger.info('Any changes made to the task file (except the data \
                        file paths) after saving this model will be ignored')
            tempTaskParams = TasksParam(args.task_file)
            # transferring the data file names from the new task file to the loaded task params
            for taskId, taskName in taskParams.taskIdNameMap.items():
                assert taskName in tempTaskParams.taskIdNameMap.values(), \
                    "task names changed in the given task file. Tasks supported for the loaded model are {}".format(
                        list(taskParams.taskIdNameMap.values()))

                taskParams.fileNamesMap[
                    taskName] = tempTaskParams.fileNamesMap[taskName]
    else:
        taskParams = TasksParam(args.task_file)
        logger.info("Task params object created from task file...")

    allParams['task_params'] = taskParams
    allParams['gpu'] = torch.cuda.is_available()
    logger.info('task parameters:\n {}'.format(taskParams.taskDetails))

    tensorboard = SummaryWriter(log_dir=os.path.join(logDir, 'tb_logs'))
    logger.info("Tensorboard writing at {}".format(
        os.path.join(logDir, 'tb_logs')))

    # making handlers for train
    logger.info("Creating data handlers for training...")
    allDataTrain, BatchSamplerTrain, multiTaskDataLoaderTrain = make_data_handlers(
        taskParams, "train", isTrain=True, gpu=allParams['gpu'])
    # evaluation on the dev set during training (labels required);
    # it runs at the end of each epoch
    if args.eval_while_train:
        logger.info("Creating data handlers for dev...")
        allDataDev, BatchSamplerDev, multiTaskDataLoaderDev = make_data_handlers(
            taskParams, "dev", isTrain=False, gpu=allParams['gpu'])
    # evaluation on the test set during training (labels required);
    # it runs at the end of each epoch
    if args.test_while_train:
        logger.info("Creating data handlers for test...")
        allDataTest, BatchSamplerTest, multiTaskDataLoaderTest = make_data_handlers(
            taskParams, "test", isTrain=False, gpu=allParams['gpu'])
    # making the multi-task model
    allParams['num_train_steps'] = math.ceil(
        len(multiTaskDataLoaderTrain) /
        args.train_batch_size) * args.epochs // args.grad_accumulation_steps
    allParams['warmup_steps'] = args.num_of_warmup_steps
    logger.info("NUM TRAIN STEPS: {}".format(allParams['num_train_steps']))
    logger.info("len of dataloader: {}".format(len(multiTaskDataLoaderTrain)))
    logger.info("Making multi-task model...")
    model = multiTaskModel(allParams)
    #logger.info('################ Network ###################')
    #logger.info('\n{}\n'.format(model.network))

    if args.load_saved_model:
        if args.finetune:
            model.load_shared_model(loadedDict, args.freeze_shared_model)
            logger.info('shared model loaded for finetune from {}'.format(
                args.load_saved_model))
        else:
            model.load_multi_task_model(loadedDict)
            logger.info(
                'saved model loaded with global step {} from {}'.format(
                    model.globalStep, args.load_saved_model))
        if args.resume_train:
            logger.info(
                "Resuming training from global step {}. Steps before it will be skipped"
                .format(model.globalStep))

    # training
    resCnt = 0
    for epoch in range(args.epochs):
        logger.info(
            '\n####################### EPOCH {} ###################\n'.format(
                epoch))
        totalEpochLoss = 0
        text = "Epoch: {}".format(epoch)
        tt = int(allParams['num_train_steps'] * args.grad_accumulation_steps /
                 args.epochs)
        with tqdm(total=tt, position=epoch, desc=text) as progress:
            for i, (batchMetaData,
                    batchData) in enumerate(multiTaskDataLoaderTrain):
                batchMetaData, batchData = BatchSamplerTrain.patch_data(
                    batchMetaData, batchData, gpu=allParams['gpu'])
                if args.resume_train and args.load_saved_model and resCnt * args.grad_accumulation_steps < model.globalStep:
                    '''
                    NOTE :-
                    The resume function is only to be used when the training process could not
                    complete, or when you wish to extend training for a few more epochs.
                    Please keep the gradient accumulation steps the same for exact resuming.
                    '''
                    resCnt += 1
                    progress.update(1)
                    continue
                model.update_step(batchMetaData, batchData)
                totalEpochLoss += model.taskLoss.item()

                if model.globalStep % args.log_per_updates == 0 and (
                        model.accumulatedStep + 1
                        == args.grad_accumulation_steps):
                    taskId = batchMetaData['task_id']
                    taskName = taskParams.taskIdNameMap[taskId]
                    #avgLoss = totalEpochLoss / ((i+1)*args.train_batch_size)
                    avgLoss = totalEpochLoss / (i + 1)
                    logger.info(
                        'Steps: {} Task: {} Avg.Loss: {} Task Loss: {}'.format(
                            model.globalStep, taskName, avgLoss,
                            model.taskLoss.item()))

                    tensorboard.add_scalar('train/avg_loss',
                                           avgLoss,
                                           global_step=model.globalStep)
                    tensorboard.add_scalar('train/{}_loss'.format(taskName),
                                           model.taskLoss.item(),
                                           global_step=model.globalStep)

                if args.save_per_updates > 0 and (
                    (model.globalStep + 1) % args.save_per_updates) == 0 and (
                        model.accumulatedStep + 1
                        == args.grad_accumulation_steps):
                    savePath = os.path.join(
                        args.out_dir, 'multi_task_model_{}_{}.pt'.format(
                            epoch, model.globalStep))
                    model.save_multi_task_model(savePath)

                    # limit the number of saved checkpoints; remove the oldest ones beyond the limit
                    if args.limit_save > 0:
                        # map global step -> checkpoint file name
                        stepCkpMap = {
                            int(os.path.splitext(ckp)[0].split('_')[-1]): ckp
                            for ckp in os.listdir(args.out_dir)
                            if ckp.endswith('.pt')
                        }

                        # sort by global step and select the oldest checkpoints beyond the save limit
                        stepToDel = sorted(stepCkpMap.keys())[:-args.limit_save]

                        for ckpStep in stepToDel:
                            os.remove(
                                os.path.join(args.out_dir,
                                             stepCkpMap[ckpStep]))
                            logger.info('Removing checkpoint {}'.format(
                                stepCkpMap[ckpStep]))

                progress.update(1)

            # saving the model after each epoch
            if args.resume_train and args.load_saved_model and resCnt * args.grad_accumulation_steps < model.globalStep:
                pass
            else:
                savePath = os.path.join(
                    args.out_dir,
                    'multi_task_model_{}_{}.pt'.format(epoch,
                                                       model.globalStep))
                model.save_multi_task_model(savePath)

            if args.eval_while_train:
                logger.info("\nRunning Evaluation on dev...")
                with torch.no_grad():
                    evaluate(allDataDev,
                             BatchSamplerDev,
                             multiTaskDataLoaderDev,
                             taskParams,
                             model,
                             gpu=allParams['gpu'],
                             evalBatchSize=args.eval_batch_size,
                             hasTrueLabels=True,
                             needMetrics=True)

            if args.test_while_train:
                logger.info("\nRunning Evaluation on test...")
                wrtPredpath = "test_predictions_{}.tsv".format(epoch)
                with torch.no_grad():
                    evaluate(allDataTest,
                             BatchSamplerTest,
                             multiTaskDataLoaderTest,
                             taskParams,
                             model,
                             gpu=allParams['gpu'],
                             evalBatchSize=args.eval_batch_size,
                             needMetrics=True,
                             hasTrueLabels=True,
                             wrtDir=args.out_dir,
                             wrtPredPath=wrtPredpath)
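
The training loop above stores checkpoints via torch serialization and, when reloading, reads the saved task parameters as loadedDict['task_params']. Below is a minimal sketch of inspecting such a checkpoint; the file path is hypothetical (it only mirrors the multi_task_model_{epoch}_{globalStep}.pt naming used above), and any keys besides 'task_params' depend on multiTaskModel.save_multi_task_model.

import torch

# hypothetical path following the naming pattern used by the training loop above
ckpt = torch.load('output/multi_task_model_0_500.pt', map_location='cpu')
print(type(ckpt['task_params']))  # the TasksParam object stored with the model
print(list(ckpt.keys()))          # remaining keys depend on the library's save method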
Example No. 2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred_file_path', type=str, required=True,
                        help="path to the tsv file on which predictions are to be made")
    parser.add_argument('--out_dir', type=str, required=True,
                        help="path to save the predictions")
    parser.add_argument('--has_labels', action='store_true',
                        help="set this flag if labels are present in the prediction file")
    parser.add_argument('--task_name', type=str, required=True,
                        help="task name for which prediction is required")
    parser.add_argument('--saved_model_path', type=str, required=True,
                        help="path to the trained model to load")
    parser.add_argument('--eval_batch_size', type=int, default=32,
                        help="batch size for prediction")
    parser.add_argument('--max_seq_len', type=int,
                        help="max seq len used during training of the model")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed")
    args = parser.parse_args()

    allParams = vars(args)
    assert os.path.exists(args.saved_model_path), "saved model not present at {}".format(args.saved_model_path)
    assert os.path.exists(args.pred_file_path), "prediction tsv file not present at {}".format(args.pred_file_path)
    loadedDict = torch.load(args.saved_model_path, map_location=device)
    taskParamsModel = loadedDict['task_params']
    logger.info('Task Params loaded from saved model.')

    assert args.task_name in taskParamsModel.taskIdNameMap.values(), "task Name not in task names for loaded model"
    
    taskId = [taskId for taskId, taskName in taskParamsModel.taskIdNameMap.items() if taskName==args.task_name][0]
    taskType = taskParamsModel.taskTypeMap[args.task_name]

    # preparing data from the tsv file
    rows = load_data(args.pred_file_path, taskType, hasLabels=args.has_labels)

    modelName = taskParamsModel.modelType.name.lower()
    _, _, tokenizerClass, defaultName = NLP_MODELS[modelName]
    configName = taskParamsModel.modelConfig
    if configName is None:
        configName = defaultName
    
    # making the tokenizer for the model
    tokenizer = tokenizerClass.from_pretrained(configName)
    logger.info('{} model tokenizer loaded for config {}'.format(modelName, configName))
    
    dataPath = os.path.join(args.out_dir, '{}_prediction_data'.format(configName))
    if not os.path.exists(dataPath):
        os.makedirs(dataPath)
    wrtFile = os.path.join(dataPath, '{}.json'.format(args.pred_file_path.split('/')[-1].split('.')[0]))
    print('Processing Started...')
    create_data_multithreaded(rows, wrtFile, tokenizer, taskParamsModel, args.task_name,
                              args.max_seq_len, multithreaded=True)
    print('Data Processing done for {}. File saved at {}'.format(args.task_name, wrtFile))

    allTaskslist = [
        {"data_task_id": int(taskId),
         "data_path": wrtFile,
         "data_task_type": taskType,
         "data_task_name": args.task_name}
    ]
    allData = allTasksDataset(allTaskslist)
    batchSampler = Batcher(allData, batchSize=args.eval_batch_size, seed=args.seed)
    batchSamplerUtils = batchUtils(isTrain=False, modelType=taskParamsModel.modelType,
                                   maxSeqLen=args.max_seq_len)
    inferDataLoader = DataLoader(allData, batch_sampler=batchSampler,
                                 collate_fn=batchSamplerUtils.collate_fn,
                                 pin_memory=torch.cuda.is_available())

    allParams['task_params'] = taskParamsModel
    allParams['gpu'] = torch.cuda.is_available()
    # dummy training hyperparameters; required to construct the model but not used for prediction
    allParams['num_train_steps'] = 10
    allParams['warmup_steps'] = 0
    allParams['learning_rate'] = 2e-5
    allParams['epsilon'] = 1e-8

    # making and loading the model
    model = multiTaskModel(allParams)
    model.load_multi_task_model(loadedDict)

    with torch.no_grad():
        wrtPredFile = 'predictions.tsv'
        evaluate(allData, batchSampler, inferDataLoader, taskParamsModel,
                 model, gpu=allParams['gpu'], evalBatchSize=args.eval_batch_size,
                 needMetrics=False, hasTrueLabels=False,
                 wrtDir=args.out_dir, wrtPredPath=wrtPredFile)
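
The script above writes its output to predictions.tsv inside --out_dir. A minimal sketch for loading that file back for inspection, assuming pandas is installed; the out_dir value is a placeholder and the column layout depends on what the evaluate utility writes:

import os
import pandas as pd

out_dir = 'output'  # placeholder: same value passed as --out_dir above
preds = pd.read_csv(os.path.join(out_dir, 'predictions.tsv'), sep='\t')  # .tsv => tab-separated
print(preds.head())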
Example No. 3
    def infer(self, dataList, taskNamesList, batchSize=8, seed=42):
        """
        This is the function which can be called to get the predictions for input samples
        for the mentioned tasks.

        - Samples can be packed in a ``list of lists`` manner as the function processes inputs in batches.
        - In case an input sample requires a sentence pair, the two sentences can be kept as elements of the list.
        - In case of single-sentence classification or NER tasks, only the first element of a sample will be used.
        - For NER, the infer function automatically splits the sentence into tokens.
        - All the tasks mentioned in ``taskNamesList`` are performed on all the input samples.

        Args:

            dataList (:obj:`list of lists`) : A batch of input samples. For eg.
                
                [
                    [<sentenceA>, <sentenceB>],
                    
                    [<sentenceA>, <sentenceB>],

                ]

                or in case all the tasks just require single sentence inputs,
                
                [
                    [<sentenceA>],

                    [<sentenceA>],

                ]

            taskNamesList (:obj:`list`) : List of tasks to be performed on dataList samples. For eg.

                ['TaskA', 'TaskB', 'TaskC']

                You can choose the tasks you want to infer. For eg.

                ['TaskB']

            batchSize (:obj:`int`, defaults to :obj:`8`) : Batch size for running inference.


        Return:

            outList (:obj:`list of objects`) :
                List of dictionary objects where each object contains one corresponding input sample and its task outputs. The task outputs
                can also contain the confidence scores. For eg.

                [
                    {'Query' : [<sentence>],

                    'TaskA' : <TaskA output>,

                    'TaskB' : <TaskB output>,

                    'TaskC' : <TaskC output>},

                ]

        Example::

            >>> samples = [ ['sample_sentence_1'], ['sample_sentence_2'] ]
            >>> tasks = ['TaskA', 'TaskB']
            >>> pipe.infer(samples, tasks)

        """
        #print(dataList)
        #print(taskNamesList)
        allTasksList = []
        for taskName in taskNamesList:
            assert taskName in self.taskParams.taskIdNameMap.values(
            ), "task Name not in task names for loaded model"
            taskId = [
                taskId
                for taskId, tName in self.taskParams.taskIdNameMap.items()
                if tName == taskName
            ][0]
            taskType = self.taskParams.taskTypeMap[taskName]

            taskData = self.make_feature_samples(dataList, taskType, taskName)
            #print('task data :', taskData)

            tasksDict = {
                "data_task_id": int(taskId),
                "data_": taskData,
                "data_task_type": taskType,
                "data_task_name": taskName
            }
            allTasksList.append(tasksDict)

        allData = allTasksDataset(allTasksList, pipeline=True)
        batchSampler = Batcher(allData,
                               batchSize=batchSize,
                               seed=seed,
                               shuffleBatch=False,
                               shuffleTask=False)
        # VERY IMPORTANT: batch shuffling must be turned off during inference,
        # otherwise the prediction scores will get jumbled

        batchSamplerUtils = batchUtils(isTrain=False,
                                       modelType=self.taskParams.modelType,
                                       maxSeqLen=self.maxSeqLen)
        inferDataLoader = DataLoader(allData,
                                     batch_sampler=batchSampler,
                                     collate_fn=batchSamplerUtils.collate_fn,
                                     pin_memory=torch.cuda.is_available())

        with torch.no_grad():
            allIds, allPreds, allScores = evaluate(
                allData,
                batchSampler,
                inferDataLoader,
                self.taskParams,
                self.model,
                gpu=torch.cuda.is_available(),
                evalBatchSize=batchSize,
                needMetrics=False,
                hasTrueLabels=False,
                returnPred=True)

            finalOutList = self.format_output(dataList, allIds, allPreds,
                                              allScores)
            #print(finalOutList)
            return finalOutList
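
A slightly fuller usage sketch following the docstring above; `pipe` is assumed to be an already-constructed pipeline object exposing this infer method, and 'TaskA'/'TaskB' are placeholder task names from the loaded task file:

samples = [
    ['sample_sentence_1'],
    ['sample_sentence_2'],
]
tasks = ['TaskA', 'TaskB']

outList = pipe.infer(samples, tasks, batchSize=8)
# each entry pairs one input sample with one output per requested task, e.g.
# {'Query': ['sample_sentence_1'], 'TaskA': <TaskA output>, 'TaskB': <TaskB output>}
for out in outList:
    print(out['Query'], {t: out[t] for t in tasks})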
Example No. 4
                        train_dataloader,
                        optimizer,
                        criterion,
                        scheduler,
                        device,
                        epoch,
                        metric_logger,
                        graph_loss,
                        graph_accuracy,
                        args.out_dir,
                        print_freq=100)
        evaluate(model,
                 val_dataloader,
                 criterion,
                 device,
                 epoch,
                 metric_logger,
                 graph_loss,
                 graph_accuracy,
                 print_freq=5)
    test(model, test_dataloader, device)

    if args.save_model:
        functions.save_model(model, args.out_dir)

    session.done()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
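    # illustration: str(datetime.timedelta(seconds=4321)) == '1:12:01' (hours:minutes:seconds)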
    logger.info('Training ended')
    logger.info('Training time %s', total_time_str)