Exemplo n.º 1
0
def _postprocess_complete_hypos(hypos):
    """Postprocess the complete hypotheses returned by the Decoder.

    The following operations are applied, in this order:

      - </s> removal (only if ``args.remove_eos`` is set)
      - Apply --nbest parameter if necessary
      - Applies combination_scheme on full hypotheses, reorder list

    This function relies on the global ``args`` variable. If
    ``args.combination_scheme`` is ``'sum'`` or is not a known scheme,
    the scores are left untouched and the list is not reordered (a
    warning is logged in the unknown case).

    Args:
      hypos (list): List of complete hypotheses

    Returns:
      list. Postprocessed hypotheses.
    """
    if args.remove_eos:
        for hypo in hypos:
            if (hypo.trgt_sentence
                    and hypo.trgt_sentence[-1] == utils.EOS_ID):
                hypo.trgt_sentence = hypo.trgt_sentence[:-1]
    if args.nbest > 0:
        hypos = hypos[:args.nbest]
    # 'sum' needs no rescoring, so the list is returned as it is.
    if args.combination_scheme == 'sum':
        return hypos
    kwargs = {'full': True}
    if args.combination_scheme == 'length_norm':
        breakdown_fn = combination.breakdown2score_length_norm
    elif args.combination_scheme == 'bayesian_loglin':
        breakdown_fn = combination.breakdown2score_bayesian_loglin
    elif args.combination_scheme == 'bayesian':
        breakdown_fn = combination.breakdown2score_bayesian
    elif args.combination_scheme == 'bayesian_state_dependent':
        breakdown_fn = combination.breakdown2score_bayesian_state_dependent
        kwargs['lambdas'] = CombiBeamDecoder.get_domain_task_weights(
            args.bayesian_domain_task_weights)
    else:
        # There is no scoring function to apply for an unknown scheme.
        # Return here instead of falling through to the rescoring loop,
        # which would fail on the unassigned ``breakdown_fn``.
        logging.warning("Unknown combination scheme '%s'",
                        args.combination_scheme)
        return hypos
    for hypo in hypos:
        hypo.total_score = breakdown_fn(
            hypo.total_score, hypo.score_breakdown, **kwargs)
    # Best (highest) rescored hypothesis first. Note that this sorts in
    # place, i.e. the caller's list is reordered when --nbest did not
    # create a truncated copy above.
    hypos.sort(key=lambda hypo: hypo.total_score, reverse=True)
    return hypos
Exemplo n.º 2
0
def _postprocess_complete_hypos(hypos):
    """Postprocess the complete hypotheses returned by the Decoder.

    The following operations are applied, in this order:

      - </s> removal (only if ``args.remove_eos`` is set)
      - Apply --nbest parameter if necessary
      - Applies combination_scheme on full hypotheses, reorder list

    This function relies on the global ``args`` variable. If
    ``args.combination_scheme`` is ``'sum'`` or is not a known scheme,
    the scores are left untouched and the list is not reordered (a
    warning is logged in the unknown case).

    Args:
      hypos (list): List of complete hypotheses

    Returns:
      list. Postprocessed hypotheses.
    """
    if args.remove_eos:
        for hypo in hypos:
            if (hypo.trgt_sentence
                    and hypo.trgt_sentence[-1] == utils.EOS_ID):
                hypo.trgt_sentence = hypo.trgt_sentence[:-1]
    if args.nbest > 0:
        hypos = hypos[:args.nbest]
    # 'sum' needs no rescoring, so the list is returned as it is.
    if args.combination_scheme == 'sum':
        return hypos
    kwargs = {'full': True}
    if args.combination_scheme == 'length_norm':
        breakdown_fn = combination.breakdown2score_length_norm
    elif args.combination_scheme == 'bayesian_loglin':
        breakdown_fn = combination.breakdown2score_bayesian_loglin
    elif args.combination_scheme == 'bayesian':
        breakdown_fn = combination.breakdown2score_bayesian
    elif args.combination_scheme == 'bayesian_state_dependent':
        breakdown_fn = combination.breakdown2score_bayesian_state_dependent
        kwargs['lambdas'] = CombiBeamDecoder.get_domain_task_weights(
            args.bayesian_domain_task_weights)
    else:
        # There is no scoring function to apply for an unknown scheme.
        # Return here instead of falling through to the rescoring loop,
        # which would fail on the unassigned ``breakdown_fn``.
        logging.warning("Unknown combination scheme '%s'",
                        args.combination_scheme)
        return hypos
    for hypo in hypos:
        hypo.total_score = breakdown_fn(
            hypo.total_score, hypo.score_breakdown, **kwargs)
    # Best (highest) rescored hypothesis first. Note that this sorts in
    # place, i.e. the caller's list is reordered when --nbest did not
    # create a truncated copy above.
    hypos.sort(key=lambda hypo: hypo.total_score, reverse=True)
    return hypos
Exemplo n.º 3
0
def create_decoder():
    """Creates the ``Decoder`` instance. This specifies the search
    strategy used to traverse the space spanned by the predictors. This
    method relies on the global ``args`` variable.

    After construction, predictors are added to the decoder, heuristics
    are added if ``args.heuristics`` is set, and the start sentence id
    is updated if ``args.range`` is set.

    TODO: Refactor to avoid long argument lists

    Returns:
        Decoder. Instance of the search strategy

    Raises:
        SystemExit: If ``args.decoder`` does not name an available
            decoder.
    """
    # Create decoder instance and add predictors
    if args.decoder == "greedy":
        decoder = GreedyDecoder(args)
    elif args.decoder == "beam":
        decoder = BeamDecoder(args)
    elif args.decoder == "simbeam":
        decoder = SimBeamDecoder(args)
    elif args.decoder == "simbeamv2":
        decoder = SimBeamDecoder_v2(args)
    elif args.decoder == "multisegbeam":
        decoder = MultisegBeamDecoder(args, args.hypo_recombination, args.beam,
                                      args.multiseg_tokenizations,
                                      args.early_stopping, args.max_word_len)
    elif args.decoder == "syncbeam":
        decoder = SyncBeamDecoder(args)
    elif args.decoder == "mbrbeam":
        decoder = MBRBeamDecoder(args)
    elif args.decoder == "sepbeam":
        decoder = SepBeamDecoder(args)
    elif args.decoder == "syntaxbeam":
        decoder = SyntaxBeamDecoder(args)
    elif args.decoder == "combibeam":
        decoder = CombiBeamDecoder(args)
    elif args.decoder == "dfs":
        decoder = DFSDecoder(args)
    elif args.decoder == "restarting":
        decoder = RestartingDecoder(args, args.hypo_recombination,
                                    args.max_node_expansions,
                                    args.low_decoder_memory,
                                    args.restarting_node_score,
                                    args.stochastic_decoder,
                                    args.decode_always_single_step)
    elif args.decoder == "bow":
        decoder = BOWDecoder(args)
    elif args.decoder == "flip":
        decoder = FlipDecoder(args)
    elif args.decoder == "bigramgreedy":
        decoder = BigramGreedyDecoder(args)
    elif args.decoder == "bucket":
        decoder = BucketDecoder(
            args, args.hypo_recombination, args.max_node_expansions,
            args.low_decoder_memory, args.beam, args.pure_heuristic_scores,
            args.decoder_diversity_factor, args.early_stopping,
            args.stochastic_decoder, args.bucket_selector,
            args.bucket_score_strategy, args.collect_statistics)
    elif args.decoder == "astar":
        decoder = AstarDecoder(args)
    elif args.decoder == "vanilla":
        decoder = construct_nmt_vanilla_decoder()
        # The vanilla decoder overrides whatever predictors were requested.
        args.predictors = "vanilla"
    else:
        logging.critical("Decoder %s not available. Please double-check the "
                         "--decoder parameter.", args.decoder)
        # Logging alone does not stop execution. Exit here instead of
        # falling through to add_predictors() with ``decoder`` unassigned.
        raise SystemExit("Could not initialize decoder.")
    add_predictors(decoder)

    # Add heuristics for search strategies like A*
    if args.heuristics:
        add_heuristics(decoder)

    # Update start sentence id if necessary
    if args.range:
        # Only the part before the first ':' is used; anything after it
        # is ignored here.
        start_idx = args.range.partition(":")[0]
        # -1 because indices start with 1
        decoder.set_start_sen_id(int(start_idx) - 1)
    return decoder
Exemplo n.º 4
0
def create_decoder():
    """Build the ``Decoder`` selected by the global ``args.decoder``.

    The decoder specifies the search strategy used to traverse the
    space spanned by the predictors. Once it has been constructed,
    predictors are added to it, and heuristics as well if
    ``args.heuristics`` is set.

    If the decoder name is unknown, or anything goes wrong while the
    decoder is being constructed, the problem is logged and the process
    exits via ``sys.exit``.

    TODO: Refactor to avoid long argument lists

    Returns:
        Decoder. Instance of the search strategy
    """
    def _make_vanilla():
        # The predictors setting is overwritten only after the vanilla
        # decoder has been constructed successfully.
        vanilla = construct_nmt_vanilla_decoder()
        args.predictors = "vanilla"
        return vanilla

    # Maps each --decoder value to a zero-argument factory. Wrapping the
    # constructors in lambdas keeps the class names and ``args``
    # attributes unevaluated until the selected factory is called.
    factories = {
        "greedy": lambda: GreedyDecoder(args),
        "beam": lambda: BeamDecoder(args),
        "multisegbeam": lambda: MultisegBeamDecoder(
            args,
            args.hypo_recombination,
            args.beam,
            args.multiseg_tokenizations,
            args.early_stopping,
            args.max_word_len),
        "syncbeam": lambda: SyncBeamDecoder(args),
        "mbrbeam": lambda: MBRBeamDecoder(args),
        "sepbeam": lambda: SepBeamDecoder(args),
        "syntaxbeam": lambda: SyntaxBeamDecoder(args),
        "combibeam": lambda: CombiBeamDecoder(args),
        "dfs": lambda: DFSDecoder(args),
        "restarting": lambda: RestartingDecoder(
            args,
            args.hypo_recombination,
            args.max_node_expansions,
            args.low_decoder_memory,
            args.restarting_node_score,
            args.stochastic_decoder,
            args.decode_always_single_step),
        "bow": lambda: BOWDecoder(args),
        "flip": lambda: FlipDecoder(args),
        "bigramgreedy": lambda: BigramGreedyDecoder(args),
        "bucket": lambda: BucketDecoder(
            args,
            args.hypo_recombination,
            args.max_node_expansions,
            args.low_decoder_memory,
            args.beam,
            args.pure_heuristic_scores,
            args.decoder_diversity_factor,
            args.early_stopping,
            args.stochastic_decoder,
            args.bucket_selector,
            args.bucket_score_strategy,
            args.collect_statistics),
        "astar": lambda: AstarDecoder(args),
        "vanilla": _make_vanilla,
    }

    # Create decoder instance and add predictors
    decoder = None
    try:
        factory = factories.get(args.decoder)
        if factory is not None:
            decoder = factory()
        else:
            logging.fatal("Decoder %s not available. Please double-check the "
                          "--decoder parameter." % args.decoder)
    except Exception as e:
        exc_type = sys.exc_info()[0]
        stack_trace = traceback.format_exc()
        logging.fatal("An %s has occurred while initializing the decoder: %s"
                      " Stack trace: %s" % (exc_type, e, stack_trace))
    # Both failure paths above leave ``decoder`` as None.
    if decoder is None:
        sys.exit("Could not initialize decoder.")
    add_predictors(decoder)
    # Add heuristics for search strategies like A*
    if args.heuristics:
        add_heuristics(decoder)
    return decoder