示例#1
0
    def __init__(self, cfg):
        """Set up the form selector from the configuration `cfg`.

        `cfg` must provide a dict-style ``get(key, default)``. Every option
        is optional; the fallback values are the ones spelled out below.
        Besides reading the options, this creates the initial vocabulary
        (special symbols only) and reseeds ``np.random``.
        """
        FormSelect.__init__(self)
        scope_name = 'formselect-' + cfg.get('scope_suffix', '')
        TFModel.__init__(self, scope_name=scope_name)

        opt = cfg.get  # shorthand for the option lookups below

        # the config key is 'form_sample', the attribute is '_sample'
        self._sample = opt('form_sample', False)

        # training / network options
        self.randomize = opt('randomize', True)
        self.emb_size = opt('emb_size', 50)
        self.passes = opt('passes', 200)
        self.alpha = opt('alpha', 1)
        self.batch_size = opt('batch_size', 1)
        self.max_sent_len = opt('max_sent_len', 32)
        self.cell_type = opt('cell_type', 'lstm')
        self.max_grad_norm = opt('max_grad_norm', 100)
        self.optimizer_type = opt('optimizer_type', 'adam')
        self.max_cores = opt('max_cores', 4)
        self.alpha_decay = opt('alpha_decay', 0.0)

        # special symbols, listed once and used for both mapping directions
        special_symbols = (
            ('<VOID>', self.VOID),
            ('<GO>', self.GO),
            ('<STOP>', self.STOP),
            ('<UNK>', self.UNK),
        )
        self.vocab = dict(special_symbols)
        self.reverse_dict = dict((idx, sym) for sym, idx in special_symbols)
        # not known yet at construction time
        self.vocab_size = None

        # reseed np.random with a value drawn from `rnd`
        # (presumably the project's own seeded generator -- confirm)
        np.random.seed(rnd.randint(0, 2**32 - 1))
示例#2
0
文件: seq2seq.py 项目: pdsujnow/tgen
    def __init__(self, cfg):
        """Create the generator and copy its settings out of `cfg`.

        `cfg` must provide a dict-style ``get(key[, default])``. All options
        are optional; the fallback values are the ones spelled out below.
        """
        Seq2SeqBase.__init__(self, cfg)
        scope_name = 'seq2seq_gen-' + cfg.get('scope_suffix', '')
        TFModel.__init__(self, scope_name=scope_name)

        opt = cfg.get  # shorthand for the option lookups below

        # network size and optimization
        self.emb_size = opt('emb_size', 50)
        self.batch_size = opt('batch_size', 10)
        # NOTE(review): the config key is 'dropout_prob' although the value is
        # stored as dropout_keep_prob -- presumably a keep probability, confirm
        self.dropout_keep_prob = opt('dropout_prob', 1)
        self.optimizer_type = opt('optimizer_type', 'adam')

        # training length and decoding
        self.passes = opt('passes', 5)
        self.min_passes = opt('min_passes', 1)
        self.improve_interval = opt('improve_interval', 10)
        self.top_k = opt('top_k', 5)
        # TODO: the 'checkpoint_dir' option is not read at the moment
        self.use_dec_cost = opt('use_dec_cost', False)

        # learning rate and validation
        self.alpha = opt('alpha', 1e-3)
        self.alpha_decay = opt('alpha_decay', 0.0)
        self.validation_size = opt('validation_size', 0)
        self.validation_freq = opt('validation_freq', 10)
        # multiple references for validation
        self.multiple_refs = opt('multiple_refs', False)
        # selectors of validation trees (if in separate file)
        self.ref_selectors = opt('ref_selectors', None)
        # no explicit default here: whatever cfg.get returns for a missing key
        self.max_cores = opt('max_cores')
        self.use_tokens = opt('use_tokens', False)
        self.nn_type = opt('nn_type', 'emb_seq2seq')
        self.randomize = opt('randomize', True)
        self.cell_type = opt('cell_type', 'lstm')
        self.bleu_validation_weight = opt('bleu_validation_weight', 0.0)

        self.use_context = opt('use_context', False)
示例#3
0
文件: seq2seq.py 项目: qjay612/tgen
    def __init__(self, cfg):
        """Create the generator and copy its settings out of `cfg`.

        `cfg` must provide a dict-style ``get(key[, default])``. All options
        are optional; the fallback values are the ones spelled out below.
        """
        Seq2SeqBase.__init__(self, cfg)
        scope_name = 'seq2seq_gen-' + cfg.get('scope_suffix', '')
        TFModel.__init__(self, scope_name=scope_name)

        opt = cfg.get  # shorthand for the option lookups below

        # network size and optimization
        self.emb_size = opt('emb_size', 50)
        self.batch_size = opt('batch_size', 10)
        # NOTE(review): the config key is 'dropout_prob' although the value is
        # stored as dropout_keep_prob -- presumably a keep probability, confirm
        self.dropout_keep_prob = opt('dropout_prob', 1)
        self.optimizer_type = opt('optimizer_type', 'adam')

        # training length and decoding
        self.passes = opt('passes', 5)
        self.min_passes = opt('min_passes', 1)
        self.improve_interval = opt('improve_interval', 10)
        self.top_k = opt('top_k', 5)
        # TODO: the 'checkpoint_dir' option is not read at the moment
        self.use_dec_cost = opt('use_dec_cost', False)

        # learning rate and validation
        self.alpha = opt('alpha', 1e-3)
        self.alpha_decay = opt('alpha_decay', 0.0)
        self.validation_size = opt('validation_size', 0)
        self.validation_freq = opt('validation_freq', 10)
        # multiple references for validation
        self.multiple_refs = opt('multiple_refs', False)
        # selectors of validation trees (if in separate file)
        self.ref_selectors = opt('ref_selectors', None)
        # no explicit default here: whatever cfg.get returns for a missing key
        self.max_cores = opt('max_cores')

        # an explicit 'mode' wins; otherwise a truthy 'use_tokens' option
        # selects 'tokens' and anything else selects 'trees'
        if opt('use_tokens'):
            fallback_mode = 'tokens'
        else:
            fallback_mode = 'trees'
        self.mode = opt('mode', fallback_mode)

        self.nn_type = opt('nn_type', 'emb_seq2seq')
        self.randomize = opt('randomize', True)
        self.cell_type = opt('cell_type', 'lstm')
        self.bleu_validation_weight = opt('bleu_validation_weight', 0.0)

        self.use_context = opt('use_context', False)
示例#4
0
    def __init__(self, cfg):
        """Create the reranker and copy its settings out of `cfg`.

        `cfg` must provide a dict-style ``get(key[, default])``. All options
        are optional; the fallback values are the ones spelled out below.
        """
        Reranker.__init__(self, cfg)
        scope_name = 'rerank-' + cfg.get('scope_suffix', '')
        TFModel.__init__(self, scope_name=scope_name)

        opt = cfg.get  # shorthand for the option lookups below

        # tree_embs starts out as a flag (does the 'nn' option begin with
        # 'emb'?) and is replaced by the extractor object when the flag is set;
        # emb_size is only defined in that case
        wants_tree_embs = opt('nn', '').startswith('emb')
        self.tree_embs = wants_tree_embs
        if wants_tree_embs:
            self.tree_embs = TreeEmbeddingClassifExtract(cfg)
            self.emb_size = opt('emb_size', 50)

        # network shape
        self.nn_shape = opt('nn_shape', 'ff')
        self.num_hidden_units = opt('num_hidden_units', 512)

        # training settings
        self.passes = opt('passes', 200)
        self.min_passes = opt('min_passes', 0)
        self.alpha = opt('alpha', 0.1)
        self.randomize = opt('randomize', True)
        self.batch_size = opt('batch_size', 1)

        self.validation_freq = opt('validation_freq', 10)
        self.checkpoint_path = None
        # no explicit default here: whatever cfg.get returns for a missing key
        self.max_cores = opt('max_cores')

        # training summaries: the three placeholders below exist only when
        # the 'tb_summary_dir' option is set to a truthy value
        summary_dir = opt('tb_summary_dir', None)
        self.train_summary_dir = summary_dir
        if summary_dir:
            self.loss_summary_reranker = None
            self.train_summary_op = None
            self.train_summary_writer = None

        # backward compatibility flag -- will be 1 when loading older models
        self.version = 2