def processDescriptionParam(descOpts, bugReportDatabase, inputHandlers, preprocessors, encoders, databasePath,
                            cacheFolder,
                            logger, paddingSym):
    """Configure the preprocessing and encoding of the bug report description.

    Loads the word embedding, filters and tokenizer described by ``descOpts``,
    registers a ``DescriptionPreprocessor`` (backed by a ``PreprocessingCache``)
    and then builds the encoder selected by ``descOpts['encoder_type']`` together
    with its input handler. Results are delivered by appending to the collections
    passed in; nothing is returned.

    :param descOpts: description options. Always read: 'filters', 'tokenizer',
        'word_embedding', 'lexicon' and 'encoder_type'; the dict is also forwarded
        to ``load_embedding``. Other keys depend on the chosen encoder type.
    :param bugReportDatabase: bug report database forwarded to ``DescriptionPreprocessor``.
    :param inputHandlers: list to which the encoder's input handler is appended.
    :param preprocessors: collection to which the ``DescriptionPreprocessor`` is appended.
    :param encoders: list to which the description encoder is appended.
    :param databasePath: bug database path; part of the preprocessing cache arguments.
    :param cacheFolder: folder handed to ``PreprocessingCache``.
    :param logger: logger used for progress messages.
    :param paddingSym: padding token; its lexicon index is used as the padding id.
    :raises ArgumentError: if 'tokenizer' or 'encoder_type' holds an unsupported value.
        NOTE(review): if ``ArgumentError`` is ``argparse.ArgumentError``, its
        constructor expects ``(argument, message)`` and the one-argument calls below
        would raise ``TypeError`` instead -- confirm the import.
    """
    # Only the description field is handled here.
    logger.info("Using Description information.")

    # Loading word embedding
    lexicon, embedding = load_embedding(descOpts, paddingSym)
    logger.info("Lexicon size: %d" % (lexicon.getLen()))
    logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
    paddingId = lexicon.getLexiconIndex(paddingSym)

    # Loading Filters
    filters = loadFilters(descOpts['filters'])

    # Tokenizer
    tokenizerName = descOpts['tokenizer']
    if tokenizerName == 'default':
        logger.info("Use default tokenizer to tokenize description information")
        tokenizer = MultiLineTokenizer()
    elif tokenizerName == 'white_space':
        logger.info("Use white space tokenizer to tokenize description information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space" %
            tokenizerName)

    # Options that affect the preprocessed output are passed to the cache
    # (presumably PreprocessingCache derives its cache key from them). Keep this
    # tuple stable so previously built caches are reused.
    arguments = (
        databasePath, descOpts['word_embedding'], str(descOpts['lexicon']),
        ' '.join(sorted([fil.__class__.__name__ for fil in filters])),
        descOpts['tokenizer'], "description")

    descCache = PreprocessingCache(cacheFolder, arguments)
    descPreprocessor = DescriptionPreprocessor(lexicon, bugReportDatabase, filters, tokenizer, paddingId, descCache)
    preprocessors.append(descPreprocessor)

    encoderType = descOpts['encoder_type']

    if encoderType == 'rnn':
        rnnType = descOpts.get('rnn_type')
        hiddenSize = descOpts.get('hidden_size')
        bidirectional = descOpts.get('bidirectional', False)
        numLayers = descOpts.get('num_layers', 1)
        dropout = descOpts.get('dropout', 0.0)
        updateEmb = descOpts.get('update_embedding', False)
        fixedOpt = descOpts.get('fixed_opt', False)

        descRNN = SortedRNNEncoder(rnnType, embedding, hiddenSize, numLayers, bidirectional, updateEmb,
                                   dropout)

        # 'fixed_opt' decides how the RNN output is turned into the final encoding:
        # 'self_att' uses a self-attention layer, any other value is handed to RNNFixedOuput.
        if fixedOpt == 'self_att':
            att = SelfAttention(descRNN.getOutputSize(), descOpts['self_att_hidden'], descOpts['n_hops'])
            descEncoder = RNN_Self_Attention(descRNN, att, paddingId, dropout)
        else:
            descEncoder = RNNFixedOuput(descRNN, fixedOpt, dropout)

        encoders.append(descEncoder)
        inputHandlers.append(RNNInputHandler(paddingId))
    elif encoderType == 'cnn':
        windowSizes = descOpts.get('window_sizes', [3])
        nFilters = descOpts.get('nfilters', 100)
        updateEmb = descOpts.get('update_embedding', False)
        actFunc = loadActivationFunction(descOpts.get('activation', 'relu'))
        batchNorm = descOpts.get('batch_normalization', False)
        dropout = descOpts.get('dropout', 0.0)

        descEncoder = TextCNN(windowSizes, nFilters, embedding, updateEmb, actFunc, batchNorm, dropout)
        encoders.append(descEncoder)
        # The largest window size is given to the handler (presumably as a minimum input length).
        inputHandlers.append(TextCNNInputHandler(paddingId, max(windowSizes)))
    elif encoderType == 'cnn+dense':
        windowSizes = descOpts.get('window_sizes', [3])
        nFilters = descOpts.get('nfilters', 100)
        updateEmb = descOpts.get('update_embedding', False)
        actFunc = loadActivationFunction(descOpts.get('activation', 'relu'))
        batchNorm = descOpts.get('batch_normalization', False)
        dropout = descOpts.get('dropout', 0.0)
        hiddenSizes = descOpts.get('hidden_sizes')
        hiddenAct = loadActivationClass(descOpts.get('hidden_act'))
        hiddenDropout = descOpts.get('hidden_dropout')
        batchLast = descOpts.get("bn_last_layer", False)

        # A TextCNN whose output is fed through a stack of dense layers.
        cnnEnc = TextCNN(windowSizes, nFilters, embedding, updateEmb, actFunc, batchNorm, dropout)
        descEncoder = MultilayerDense(cnnEnc, hiddenSizes, hiddenAct, batchNorm, batchLast, hiddenDropout)
        encoders.append(descEncoder)
        inputHandlers.append(TextCNNInputHandler(paddingId, max(windowSizes)))
    elif encoderType == 'dense+self_att':
        dropout = descOpts.get('dropout', 0.0)
        hiddenSize = descOpts.get('hidden_size')
        self_att_hidden = descOpts['self_att_hidden']
        n_hops = descOpts['n_hops']
        updateEmb = descOpts.get('update_embedding', False)

        descEncoder = Dense_Self_Attention(embedding, hiddenSize, self_att_hidden, n_hops, paddingId, updateEmb,
                                           dropout=dropout)
        encoders.append(descEncoder)
        inputHandlers.append(TextCNNInputHandler(paddingId, -1))
    elif encoderType == 'word_mean':
        standardization = descOpts.get('standardization', False)
        dropout = descOpts.get('dropout', 0.0)
        updateEmb = descOpts.get('update_embedding', False)
        # Read from its own key; it was previously read from 'update_embedding',
        # which silently tied batch normalization to the embedding-update flag.
        batchNormalization = descOpts.get('batch_normalization', False)
        hiddenSize = descOpts.get('hidden_size')

        descEncoder = WordMean(embedding, updateEmb, hiddenSize, standardization, dropout, batchNormalization)

        encoders.append(descEncoder)
        inputHandlers.append(RNNInputHandler(paddingId))
    else:
        raise ArgumentError(
            "Encoder type of description is invalid (%s). You should choose one of these: "
            "rnn, cnn, cnn+dense, dense+self_att and word_mean" %
            encoderType)
def processSumDescParam(sum_desc_opts, bugReportDatabase, inputHandlers, preprocessors, encoders, cacheFolder,
                        databasePath, logger, paddingSym):
    """Configure the preprocessing and encoding of the concatenated summary + description.

    Loads the word embedding, filters and tokenizer described by ``sum_desc_opts``,
    registers a ``SummaryDescriptionPreprocessor`` (backed by a ``PreprocessingCache``)
    and then builds the encoder selected by ``sum_desc_opts['encoder_type']`` together
    with its input handler. Results are delivered by appending to the collections
    passed in; nothing is returned.

    Note that, unlike ``processDescriptionParam``, ``cacheFolder`` comes before
    ``databasePath`` in this signature.

    :param sum_desc_opts: summary+description options. Always read: 'filters',
        'tokenizer', 'word_embedding' and 'encoder_type'; the dict is also forwarded
        to ``load_embedding``. Other keys depend on the chosen encoder type.
    :param bugReportDatabase: bug report database forwarded to ``SummaryDescriptionPreprocessor``.
    :param inputHandlers: list to which the encoder's input handler is appended.
    :param preprocessors: collection to which the ``SummaryDescriptionPreprocessor`` is appended.
    :param encoders: list to which the summary+description encoder is appended.
    :param cacheFolder: folder handed to ``PreprocessingCache``.
    :param databasePath: bug database path; part of the preprocessing cache arguments.
    :param logger: logger used for progress messages.
    :param paddingSym: padding token; its lexicon index is used as the padding id.
    :raises ArgumentError: if 'tokenizer' or 'encoder_type' holds an unsupported value.
        NOTE(review): if ``ArgumentError`` is ``argparse.ArgumentError``, its
        constructor expects ``(argument, message)`` and the one-argument calls below
        would raise ``TypeError`` instead -- confirm the import.
    """
    # Summary and description are concatenated and encoded as a single text.
    logger.info("Using summary and description information.")

    # Loading word embedding
    lexicon, embedding = load_embedding(sum_desc_opts, paddingSym)
    logger.info("Lexicon size: %d" % (lexicon.getLen()))
    logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
    paddingId = lexicon.getLexiconIndex(paddingSym)

    # Loading Filters
    filters = loadFilters(sum_desc_opts['filters'])

    # Tokenizer
    tokenizerName = sum_desc_opts['tokenizer']
    if tokenizerName == 'default':
        logger.info("Use default tokenizer to tokenize summary+description information")
        tokenizer = MultiLineTokenizer()
    elif tokenizerName == 'white_space':
        logger.info("Use white space tokenizer to tokenize summary+description information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space" %
            tokenizerName)

    # Options that affect the preprocessed output are passed to the cache
    # (presumably PreprocessingCache derives its cache key from them). Keep this
    # tuple stable so previously built caches are reused.
    arguments = (
        databasePath, sum_desc_opts['word_embedding'],
        ' '.join(sorted([fil.__class__.__name__ for fil in filters])),
        sum_desc_opts['tokenizer'], "summary_description")
    cacheSumDesc = PreprocessingCache(cacheFolder, arguments)
    sumDescPreprocessor = SummaryDescriptionPreprocessor(lexicon, bugReportDatabase, filters, tokenizer, paddingId,
                                                         cacheSumDesc)
    preprocessors.append(sumDescPreprocessor)

    encoderType = sum_desc_opts['encoder_type']

    if encoderType == 'cnn':
        windowSizes = sum_desc_opts.get('window_sizes', [3])
        nFilters = sum_desc_opts.get('nfilters', 100)
        updateEmb = sum_desc_opts.get('update_embedding', False)
        actFunc = loadActivationFunction(sum_desc_opts.get('activation', 'relu'))
        batchNorm = sum_desc_opts.get('batch_normalization', False)
        dropout = sum_desc_opts.get('dropout', 0.0)

        sumDescEncoder = TextCNN(windowSizes, nFilters, embedding, updateEmb, actFunc, batchNorm, dropout)
        encoders.append(sumDescEncoder)
        # The largest window size is given to the handler (presumably as a minimum input length).
        inputHandlers.append(TextCNNInputHandler(paddingId, max(windowSizes)))
    elif encoderType == 'cnn+dense':
        windowSizes = sum_desc_opts.get('window_sizes', [3])
        nFilters = sum_desc_opts.get('nfilters', 100)
        updateEmb = sum_desc_opts.get('update_embedding', False)
        actFunc = loadActivationFunction(sum_desc_opts.get('activation', 'relu'))
        batchNorm = sum_desc_opts.get('batch_normalization', False)
        dropout = sum_desc_opts.get('dropout', 0.0)
        hiddenSizes = sum_desc_opts.get('hidden_sizes')
        hiddenAct = loadActivationClass(sum_desc_opts.get('hidden_act'))
        hiddenDropout = sum_desc_opts.get('hidden_dropout')
        batchLast = sum_desc_opts.get("bn_last_layer", False)

        # A TextCNN whose output is fed through a stack of dense layers.
        cnnEnc = TextCNN(windowSizes, nFilters, embedding, updateEmb, actFunc, batchNorm, dropout)
        sumDescEncoder = MultilayerDense(cnnEnc, hiddenSizes, hiddenAct, batchNorm, batchLast, hiddenDropout)
        encoders.append(sumDescEncoder)
        inputHandlers.append(TextCNNInputHandler(paddingId, max(windowSizes)))
    elif encoderType == 'word_mean':
        standardization = sum_desc_opts.get('standardization', False)
        dropout = sum_desc_opts.get('dropout', 0.0)
        updateEmb = sum_desc_opts.get('update_embedding', False)
        # Read from its own key; it was previously read from 'update_embedding',
        # which silently tied batch normalization to the embedding-update flag.
        batchNormalization = sum_desc_opts.get('batch_normalization', False)
        hiddenSize = sum_desc_opts.get('hidden_size')

        sumDescEncoder = WordMean(embedding, updateEmb, hiddenSize, standardization, dropout, batchNormalization)

        encoders.append(sumDescEncoder)
        inputHandlers.append(RNNInputHandler(paddingId))
    else:
        raise ArgumentError(
            "Encoder type of summary and description is invalid (%s). You should choose one of these: "
            "cnn, cnn+dense and word_mean" %
            encoderType)
# Example #3 (scraped-page residue; original text: "Beispiel #3" / "0")
def main(_run, _config, _seed, _log):
    """

    :param _run:
    :param _config:
    :param _seed:
    :param _log:
    :return:
    """
    """
    Setting and loading parameters
    """
    # Setting logger
    args = _config
    logger = _log

    logger.info(args)
    logger.info('It started at: %s' % datetime.now())

    torch.manual_seed(_seed)

    bugReportDatabase = BugReportDatabase.fromJson(args['bug_database'])
    paddingSym = "</s>"
    batchSize = args['batch_size']

    device = torch.device('cuda' if args['cuda'] else "cpu")

    if args['cuda']:
        logger.info("Turning CUDA on")
    else:
        logger.info("Turning CUDA off")

    # It is the folder where the preprocessed information will be stored.
    cacheFolder = args['cache_folder']

    # Setting the parameter to save and loading parameters
    importantParameters = ['compare_aggregation', 'categorical']
    parametersToSave = dict([(parName, args[parName])
                             for parName in importantParameters])

    if args['load'] is not None:
        mapLocation = (
            lambda storage, loc: storage.cuda()) if cudaOn else 'cpu'
        modelInfo = torch.load(args['load'], map_location=mapLocation)
        modelState = modelInfo['model']

        for paramName, paramValue in modelInfo['params'].items():
            args[paramName] = paramValue
    else:
        modelState = None

    if args['rep'] is not None and args['rep']['model']:
        logger.info("Loading REP")
        rep = read_weights(args['rep']['model'])
        rep_input, max_tkn_id = read_dbrd_file(args['rep']['input'], math.inf)
        rep_recommendation = args['rep']['k']

        rep.fit_transform(rep_input, max_tkn_id, True)

        rep_input_by_id = {}

        for inp in rep_input:
            rep_input_by_id[inp[SUN_REPORT_ID_INDEX]] = inp

    else:
        rep = None

    preprocessors = PreprocessorList()
    inputHandlers = []

    categoricalOpt = args.get('categorical')

    if categoricalOpt is not None and len(categoricalOpt) != 0:
        categoricalEncoder, _, _ = processCategoricalParam(
            categoricalOpt, bugReportDatabase, inputHandlers, preprocessors,
            None, logger, cudaOn)
    else:
        categoricalEncoder = None

    filterInputHandlers = []

    compareAggOpt = args['compare_aggregation']
    databasePath = args['bug_database']

    # Loading word embedding
    if compareAggOpt["word_embedding"]:
        # todo: Allow use embeddings and other representation
        lexicon, embedding = Embedding.fromFile(
            compareAggOpt['word_embedding'],
            'UUUKNNN',
            hasHeader=False,
            paddingSym=paddingSym)
        logger.info("Lexicon size: %d" % (lexicon.getLen()))
        logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
        paddingId = lexicon.getLexiconIndex(paddingSym)
        lazy = False
    else:
        embedding = None

    # Tokenizer
    if compareAggOpt['tokenizer'] == 'default':
        logger.info("Use default tokenizer to tokenize summary information")
        tokenizer = MultiLineTokenizer()
    elif compareAggOpt['tokenizer'] == 'white_space':
        logger.info(
            "Use white space tokenizer to tokenize summary information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space"
            % compareAggOpt['tokenizer'])

    # Preparing input handlers, preprocessors and cache
    minSeqSize = max(compareAggOpt['aggregate']["window"]
                     ) if compareAggOpt['aggregate']["model"] == "cnn" else -1

    if compareAggOpt['summary'] is not None:
        # Use summary and description (concatenated) to address this problem
        logger.info("Using Summary information.")
        # Loading Filters
        sumFilters = loadFilters(compareAggOpt['summary']['filters'])

        if compareAggOpt['summary']['model_type'] in ('lstm', 'gru',
                                                      'word_emd', 'residual'):
            arguments = (databasePath, compareAggOpt['word_embedding'],
                         ' '.join(
                             sorted([
                                 fil.__class__.__name__ for fil in sumFilters
                             ])), compareAggOpt['tokenizer'],
                         SummaryPreprocessor.__name__)

            inputHandlers.append(
                RNNInputHandler(paddingId, minInputSize=minSeqSize))

            summaryCache = PreprocessingCache(cacheFolder, arguments)
            summaryPreprocessor = SummaryPreprocessor(lexicon,
                                                      bugReportDatabase,
                                                      sumFilters, tokenizer,
                                                      paddingId, summaryCache)
        elif compareAggOpt['summary']['model_type'] == 'ELMo':
            raise NotImplementedError("ELMO is not implemented!")
            # inputHandlers.append(ELMoInputHandler(cudaOn, minInputSize=minSeqSize))
            # summaryPreprocessor = ELMoPreprocessor(0, elmoEmbedding)
            # compareAggOpt['summary']["input_size"] = elmoEmbedding.get_size()
        elif compareAggOpt['summary']['model_type'] == 'BERT':
            arguments = (databasePath, "CADD SUMMARY", "BERT",
                         "bert-base-uncased")

            inputHandlers.append(BERTInputHandler(0, minInputSize=minSeqSize))

            summaryCache = PreprocessingCache(cacheFolder, arguments)
            summaryPreprocessor = TransformerPreprocessor(
                "short_desc", "bert-base-uncased", BertTokenizer, 0,
                bugReportDatabase, summaryCache)
#            compareAggOpt['summary']["input_size"] = 768

        preprocessors.append(summaryPreprocessor)

    if compareAggOpt['desc'] is not None:
        # Use summary and description (concatenated) to address this problem
        logger.info("Using Description information.")
        descFilters = loadFilters(compareAggOpt['desc']['filters'])

        if compareAggOpt['desc']['model_type'] in ('lstm', 'gru', 'word_emd',
                                                   'residual'):
            arguments = (databasePath, compareAggOpt['word_embedding'],
                         ' '.join(
                             sorted([
                                 fil.__class__.__name__ for fil in descFilters
                             ])), compareAggOpt['tokenizer'], "CADD DESC",
                         str(compareAggOpt['desc']['summarization']))

            inputHandlers.append(
                RNNInputHandler(paddingId, minInputSize=minSeqSize))

            descriptionCache = PreprocessingCache(cacheFolder, arguments)
            descPreprocessor = DescriptionPreprocessor(lexicon,
                                                       bugReportDatabase,
                                                       descFilters,
                                                       tokenizer,
                                                       paddingId,
                                                       cache=descriptionCache)
        elif compareAggOpt['desc']['model_type'] == 'ELMo':
            raise NotImplementedError("ELMO is not implemented!")
            # inputHandlers.append(ELMoInputHandler(cudaOn, minInputSize=minSeqSize))
            # descPreprocessor = ELMoPreprocessor(1, elmoEmbedding)
            # compareAggOpt['desc']["input_size"] = elmoEmbedding.get_size()
        elif compareAggOpt['desc']['model_type'] == 'BERT':
            arguments = (databasePath, "CADD DESC", "BERT",
                         "bert-base-uncased")

            inputHandlers.append(BERTInputHandler(0, minInputSize=minSeqSize))

            descriptionCache = PreprocessingCache(cacheFolder, arguments)
            descPreprocessor = TransformerPreprocessor("description",
                                                       "bert-base-uncased",
                                                       BertTokenizer, 0,
                                                       bugReportDatabase,
                                                       descriptionCache)
#            compareAggOpt['desc']["input_size"] = 768

        preprocessors.append(descPreprocessor)

    # Create model
    model = CADD(embedding,
                 categoricalEncoder,
                 compareAggOpt,
                 compareAggOpt['summary'],
                 compareAggOpt['desc'],
                 compareAggOpt['matching'],
                 compareAggOpt['aggregate'],
                 cudaOn=cudaOn)

    lossFn = F.nll_loss
    lossNoReduction = NLLLoss(reduction='none')

    if cudaOn:
        model.cuda()

    if modelState:
        model.load_state_dict(modelState)
    """
    Loading the training and validation. Also, it sets how the negative example will be generated.
    """
    cmpAggCollate = PairBugCollate(inputHandlers, torch.int64)

    # load training
    if args.get('pairs_training'):
        negativePairGenOpt = args.get('neg_pair_generator', )
        pairTrainingFile = args.get('pairs_training')

        offlineGeneration = not (negativePairGenOpt is None
                                 or negativePairGenOpt['type'] == 'none')
        masterIdByBugId = bugReportDatabase.getMasterIdByBugId()
        randomAnchor = negativePairGenOpt['random_anchor']

        if rep:
            logger.info("Generate negative examples using REP.")
            randomAnchor = negativePairGenOpt['random_anchor']
            trainingDataset = BugDataset(args['rep']['training'])

            bugIds = trainingDataset.bugIds
            negativePairGenerator = REPGenerator(rep, rep_input_by_id,
                                                 args['rep']['neg_training'],
                                                 preprocessors, bugIds,
                                                 masterIdByBugId,
                                                 args['rep']['rate'],
                                                 randomAnchor)
        elif not offlineGeneration:
            logger.info("Not generate dynamically the negative examples.")
            negativePairGenerator = None
        else:
            pairGenType = negativePairGenOpt['type']

            if pairGenType == 'random':
                logger.info("Random Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = RandomGenerator(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    randomAnchor=randomAnchor)

            elif pairGenType == 'non_negative':
                logger.info("Non Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = NonNegativeRandomGenerator(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'misc_non_zero':
                logger.info("Misc Non Zero Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = MiscNonZeroRandomGen(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    trainingDataset.duplicateIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'random_k':
                logger.info("Random K Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = KRandomGenerator(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['k'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == "pre":
                logger.info("Pre-selected list generator")
                negativePairGenerator = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

            elif pairGenType == "positive_pre":
                logger.info("Positive Pre-selected list generator")
                negativePairGenerator = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)
            elif pairGenType == "misc_non_zero_pre":
                logger.info("Misc: non-zero and Pre-selected list generator")
                negativePairGenerator1 = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))
            elif pairGenType == "misc_non_zero_positive_pre":
                logger.info(
                    "Misc: non-zero and Positive Pre-selected list generator")
                negativePairGenerator1 = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmpAggCollate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))

            else:
                raise ArgumentError(
                    "Offline generator is invalid (%s). You should choose one of these: random, hard and pre"
                    % pairGenType)

        pairTrainingReader = PairBugDatasetReader(
            pairTrainingFile,
            preprocessors,
            negativePairGenerator,
            randomInvertPair=args['random_switch'])
        trainingCollate = cmpAggCollate
        trainingLoader = DataLoader(pairTrainingReader,
                                    batch_size=batchSize,
                                    collate_fn=trainingCollate.collate,
                                    shuffle=True)
        logger.info("Training size: %s" % (len(trainingLoader.dataset)))

    # load validation
    if args.get('pairs_validation'):
        pairValidationReader = PairBugDatasetReader(
            args.get('pairs_validation'), preprocessors)
        validationLoader = DataLoader(pairValidationReader,
                                      batch_size=batchSize,
                                      collate_fn=cmpAggCollate.collate)

        logger.info("Validation size: %s" % (len(validationLoader.dataset)))
    else:
        validationLoader = None
    """
    Training and evaluate the model. 
    """
    optimizer_opt = args.get('optimizer', 'adam')

    if optimizer_opt == 'sgd':
        logger.info('SGD')
        optimizer = optim.SGD(model.parameters(),
                              lr=args['lr'],
                              weight_decay=args['l2'])
    elif optimizer_opt == 'adam':
        logger.info('Adam')
        optimizer = optim.Adam(model.parameters(),
                               lr=args['lr'],
                               weight_decay=args['l2'])

    # Recall rate
    rankingScorer = GeneralScorer(model, preprocessors, device, cmpAggCollate)
    recallEstimationTrainOpt = args.get('recall_estimation_train')

    if recallEstimationTrainOpt:
        preselectListRankingTrain = PreselectListRanking(
            recallEstimationTrainOpt, args['sample_size_rr_tr'])

    recallEstimationOpt = args.get('recall_estimation')

    if recallEstimationOpt:
        preselectListRanking = PreselectListRanking(recallEstimationOpt,
                                                    args['sample_size_rr_val'])

    # LR scheduler
    lrSchedulerOpt = args.get('lr_scheduler', None)

    if lrSchedulerOpt is None:
        logger.info("Scheduler: Constant")
        lrSched = None
    elif lrSchedulerOpt["type"] == 'step':
        logger.info("Scheduler: StepLR (step:%s, decay:%f)" %
                    (lrSchedulerOpt["step_size"], args["decay"]))
        lrSched = StepLR(optimizer, lrSchedulerOpt["step_size"],
                         lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'exp':
        logger.info("Scheduler: ExponentialLR (decay:%f)" %
                    (lrSchedulerOpt["decay"]))
        lrSched = ExponentialLR(optimizer, lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'linear':
        logger.info(
            "Scheduler: Divide by (1 + epoch * decay) ---- (decay:%f)" %
            (lrSchedulerOpt["decay"]))

        lrDecay = lrSchedulerOpt["decay"]
        lrSched = LambdaLR(optimizer, lambda epoch: 1 /
                           (1.0 + epoch * lrDecay))
    else:
        raise ArgumentError(
            "LR Scheduler is invalid (%s). You should choose one of these: step, exp and linear "
            % pairGenType)

    # Set training functions
    def trainingIteration(engine, batch):
        engine.kk = 0

        model.train()
        optimizer.zero_grad()
        x, y = batch
        output = model(*x)
        loss = lossFn(output, y)
        loss.backward()
        optimizer.step()
        return loss, output, y

    def scoreDistanceTrans(output):
        if len(output) == 3:
            _, y_pred, y = output
        else:
            y_pred, y = output

        if lossFn == F.nll_loss:
            return torch.exp(y_pred[:, 1]), y

    trainer = Engine(trainingIteration)
    trainingMetrics = {
        'training_loss':
        AverageLoss(lossFn, batch_size=lambda x: x[0].shape[0]),
        'training_dist_target':
        MeanScoreDistance(output_transform=scoreDistanceTrans)
    }

    # Add metrics to trainer
    for name, metric in trainingMetrics.items():
        metric.attach(trainer, name)

    # Set validation functions
    def validationIteration(engine, batch):
        if not hasattr(engine, 'kk'):
            engine.kk = 0

        model.eval()
        with torch.no_grad():
            x, y = batch
            y_pred = model(*x)

            # for k, (pred, t) in enumerate(zip(y_pred, y)):
            #     engine.kk += 1
            #     print("{}: {} \t {}".format(engine.kk, torch.round(torch.exp(pred) * 100), t))
            return y_pred, y

    validationMetrics = {
        'validation_loss':
        ignite.metrics.Loss(lossFn),
        'validation_dist_target':
        MeanScoreDistance(output_transform=scoreDistanceTrans)
    }
    evaluator = Engine(validationIteration)

    # Add metrics to evaluator
    for name, metric in validationMetrics.items():
        metric.attach(evaluator, name)

    # recommendation
    if rep:
        recommendation_fn = REP_CADD_Recommender(
            rep, rep_input_by_id,
            rep_recommendation).generateRecommendationList
    else:
        recommendation_fn = generateRecommendationList

    @trainer.on(Events.EPOCH_STARTED)
    def onStartEpoch(engine):
        """Epoch-start hook: log the epoch, step the LR scheduler if present, log the LR."""
        current_epoch = engine.state.epoch
        logger.info("Epoch: %d" % current_epoch)

        # NOTE(review): the scheduler is stepped before any optimizer.step() of this
        # epoch; recent PyTorch expects scheduler.step() after optimizer.step(), so
        # the schedule is shifted by one epoch -- confirm this is intended.
        if lrSched:
            lrSched.step()

        current_lr = optimizer.param_groups[0]["lr"]
        logger.info("LR: %s" % str(current_lr))

    @trainer.on(Events.EPOCH_COMPLETED)
    def onEndEpoch(engine):
        """Epoch-end hook: log metrics, evaluate, log rankings, resample, checkpoint.

        In order:
          1. log the training metrics accumulated on ``trainer``;
          2. if ``validationLoader`` is set, run ``evaluator`` on it and log its metrics;
          3. every ``args['rr_train_epoch']`` / ``args['rr_val_epoch']`` epochs (when
             the corresponding option is enabled), log ranking results on the
             training / validation preselected lists;
          4. call ``pairTrainingReader.sampleNewNegExamples(model, lossNoReduction)``
             (presumably re-mines negative pairs with the current model);
          5. if ``args['save']`` is set, save the model state dict together with
             ``parametersToSave``. The file is per-epoch when the epoch is listed in
             the optional ``args['save_by_epoch']``; otherwise ``args['save']`` is
             overwritten.
        """
        epoch = engine.state.epoch

        logMetrics(_run, logger, engine.state.metrics, epoch)

        # Evaluate on the validation set
        if validationLoader:
            evaluator.run(validationLoader)
            logMetrics(_run, logger, evaluator.state.metrics, epoch)

        if recallEstimationTrainOpt and (epoch % args['rr_train_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRankingTrain,
                             rankingScorer,
                             bugReportDatabase,
                             None,
                             epoch,
                             "train",
                             recommendationListfn=recommendation_fn)
            # presumably releases state cached by the scorer during ranking -- confirm
            rankingScorer.free()

        if recallEstimationOpt and (epoch % args['rr_val_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             epoch,
                             "validation",
                             recommendationListfn=recommendation_fn)
            rankingScorer.free()

        pairTrainingReader.sampleNewNegExamples(model, lossNoReduction)

        save_path = args.get('save')
        if save_path:
            # 'save_by_epoch' is optional; without it every epoch overwrites save_path.
            save_by_epoch = args.get('save_by_epoch')

            if save_by_epoch and epoch in save_by_epoch:
                file_name, file_extension = os.path.splitext(save_path)
                file_path = '{}_epoch_{}{}'.format(file_name, epoch,
                                                   file_extension)
            else:
                file_path = save_path

            modelInfo = {
                'model': model.state_dict(),
                'params': parametersToSave
            }

            logger.info("==> Saving Model: %s" % file_path)
            torch.save(modelInfo, file_path)

    if args.get('pairs_training'):
        trainer.run(trainingLoader, max_epochs=args['epochs'])
    elif args.get('pairs_validation'):
        # Evaluate Training
        evaluator.run(validationLoader)
        logMetrics(_run, logger, evaluator.state.metrics, 0)

        if recallEstimationOpt:
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             0,
                             "validation",
                             recommendationListfn=recommendation_fn)

    recallRateOpt = args.get('recall_rate', {'type': 'none'})
    if recallRateOpt['type'] != 'none':
        # Build the ranking methodology selected by recall_rate.type
        # ('none' is already excluded by the enclosing check).
        if recallRateOpt['type'] == 'sun2011':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])

            rankingClass = SunRanking(bugReportDatabase, recallRateDataset,
                                      recallRateOpt['window'])
            # We always group all bug reports by master in the results in the sun 2011 methodology
            group_by_master = True
        elif recallRateOpt['type'] == 'deshmukh':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])
            rankingClass = DeshmukhRanking(bugReportDatabase,
                                           recallRateDataset)
            group_by_master = recallRateOpt['group_by_master']
        else:
            # The message lists the values this code actually accepts.
            raise ArgumentError(
                "recall_rate.type is invalid (%s). You should choose one of these: none, sun2011 and deshmukh"
                % recallRateOpt['type'])

        logRankingResult(_run,
                         logger,
                         rankingClass,
                         rankingScorer,
                         bugReportDatabase,
                         recallRateOpt["result_file"],
                         0,
                         None,
                         group_by_master,
                         recommendationListfn=recommendation_fn)