Esempio n. 1
0
    def __init__(self, logger, config):
        """Set up device, RNG contexts, preprocessor, and model for training."""
        # Prefer the GPU whenever CUDA is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.logger = logger
        self.train_config = registry.instantiate(TrainConfig, config['train'])

        # Independent seeded RNG streams: data order, model stochasticity,
        # and parameter initialization each get their own context.
        self.data_random = random_state.RandomContext(self.train_config.data_seed)
        self.model_random = random_state.RandomContext(self.train_config.model_seed)
        self.init_random = random_state.RandomContext(self.train_config.init_seed)

        with self.init_random:
            # 0. Construct preprocessors
            preproc_cls = registry.lookup('model', config['model']).Preproc
            self.model_preproc = registry.instantiate(
                preproc_cls, config['model'], unused_keys=('name', ))
            self.model_preproc.load()

            # 1. Construct model
            self.model = registry.construct(
                'model',
                config['model'],
                unused_keys=('encoder_preproc', 'decoder_preproc'),
                preproc=self.model_preproc,
                device=self.device)
            self.model.to(self.device)
Esempio n. 2
0
    def __init__(self, config):
        """Remember *config*, choose a device, and load the model preprocessor."""
        self.config = config
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            # CPU-only path also pins torch to a single thread
            # (presumably for reproducibility / to limit contention).
            torch.set_num_threads(1)

        # 0. Construct preprocessors
        model_cls = registry.lookup('model', config['model'])
        self.model_preproc = registry.instantiate(model_cls.Preproc,
                                                  config['model'])
        self.model_preproc.load()
Esempio n. 3
0
    def __init__(self, logger, config, gpu):
        """Set up a fully seeded, deterministic training run on GPU *gpu*."""
        # Bind to the requested GPU index when CUDA is present.
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu}')
        else:
            self.device = torch.device('cpu')

        # Seed every RNG source and force deterministic cuDNN kernels
        # so repeated runs are reproducible.
        seed = 1
        random.seed(seed)
        numpy.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        self.logger = logger
        self.train_config = registry.instantiate(TrainConfig, config['train'])
        # Override the configured cadence: evaluate and checkpoint every 500 steps.
        self.train_config.eval_every_n = 500
        self.train_config.save_every_n = 500

        # Separate seeded RNG contexts for data order, model stochasticity,
        # and parameter initialization.
        self.data_random = random_state.RandomContext(self.train_config.data_seed)
        self.model_random = random_state.RandomContext(self.train_config.model_seed)
        self.init_random = random_state.RandomContext(self.train_config.init_seed)

        with self.init_random:
            # 0. Construct preprocessors
            preproc_cls = registry.lookup('model', config['model']).Preproc
            self.model_preproc = registry.instantiate(
                preproc_cls, config['model'], unused_keys=('name', ))
            self.model_preproc.load()

            # 1. Construct model
            self.model = registry.construct(
                'model',
                config['model'],
                unused_keys=('encoder_preproc', 'decoder_preproc'),
                preproc=self.model_preproc,
                device=self.device)
            self.model.to(self.device)
Esempio n. 4
0
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.,
            question_encoder=('emb', 'bilstm'),
            column_encoder=('emb', 'bilstm'),
            table_encoder=('emb', 'bilstm'),
            update_config=None,
            include_in_memory=('question', 'column', 'table'),
            batch_encs_update=True,
            top_k_learnable=0):
        """Build question/column/table encoders plus the encoding-update module.

        Args:
            device: torch device the sub-modules run on.
            preproc: preprocessor exposing ``vocab`` and ``vocab_builder``.
            word_emb_size: word embedding dimensionality.
            recurrent_size: hidden size of the recurrent encoders; must be
                even (split across the two bi-LSTM directions).
            dropout: dropout probability.
            question_encoder/column_encoder/table_encoder: module-name
                pipelines handed to ``self._build_modules``.
            update_config: config dict for the update module; its ``'name'``
                key selects the class. Defaults to an empty dict.
            include_in_memory: which encodings the decoder memory includes.
            batch_encs_update: whether the update module runs batched.
            top_k_learnable: how many most-frequent words get learnable
                embeddings.
        """
        super().__init__()
        # BUG FIX: `update_config={}` was a mutable default argument shared
        # across all calls; use a None sentinel instead.
        if update_config is None:
            update_config = {}
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        # Only the top-k most frequent words receive learnable embeddings.
        self.learnable_words = {
            word for word, _ in word_freq.most_common(top_k_learnable)}

        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        self.question_encoder = self._build_modules(question_encoder)
        self.column_encoder = self._build_modules(column_encoder)
        self.table_encoder = self._build_modules(table_encoder)

        update_modules = {
            'relational_transformer':
                spider_enc_modules.RelationalTransformerUpdate,
            'none':
                spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=recurrent_size,
        )
        self.batch_encs_update = batch_encs_update
Esempio n. 5
0
    def __init__(self,
                 device,
                 preproc,
                 update_config=None,
                 bert_token_type=False,
                 bert_version="bert-base-uncased",
                 summarize_header="first",
                 use_column_type=True,
                 include_in_memory=('question', 'column', 'table')):
        """Build a BERT/RoBERTa-backed encoder plus the encoding-update module.

        Args:
            device: torch device the sub-modules run on.
            preproc: preprocessor exposing ``tokenizer``.
            update_config: config dict for the update module; its ``'name'``
                key selects the class. Defaults to an empty dict.
            bert_token_type: whether token-type ids are used.
            bert_version: Hugging Face model name; large variants yield a
                1024-dim hidden size, otherwise 768.
            summarize_header: how multi-token headers are summarized,
                "first" token or "avg" over tokens.
            use_column_type: whether column-type info is used.
            include_in_memory: which encodings the decoder memory includes.
        """
        super().__init__()
        # BUG FIX: `update_config={}` was a mutable default argument shared
        # across all calls; use a None sentinel instead.
        if update_config is None:
            update_config = {}
        self._device = device
        self.preproc = preproc
        self.bert_token_type = bert_token_type
        # Large models expose 1024-dim hidden states; base models 768.
        self.base_enc_hidden_size = 1024 if bert_version in [
            "bert-large-uncased-whole-word-masking",
            "Salesforce/grappa_large_jnt"
        ] else 768
        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size
        self.use_column_type = use_column_type

        self.include_in_memory = set(include_in_memory)
        update_modules = {
            'relational_transformer':
            spider_enc_modules.RelationalTransformerUpdate,
            'none': spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        # GraPPa is a RoBERTa checkpoint; everything else loads as BERT.
        if bert_version == 'Salesforce/grappa_large_jnt':
            self.bert_model = RobertaModel.from_pretrained(bert_version)
        else:
            self.bert_model = BertModel.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
        self.bert_model.resize_token_embeddings(len(
            self.tokenizer))  # several tokens added
Esempio n. 6
0
 def __init__(self, config):
     """Store *config* and instantiate the model preprocessor it names."""
     self.config = config
     model_cls = registry.lookup('model', config['model'])
     self.model_preproc = registry.instantiate(model_cls.Preproc,
                                               config['model'])
Esempio n. 7
0
    def __init__(self,
                 device,
                 preproc,
                 encode_size,
                 update_config=None,
                 inputembedding_config=None,
                 dropout=0.1,
                 encoder_num_layers=1,
                 bert_token_type=False,
                 bert_version="bert-base-uncased",
                 summarize_header="first",
                 use_column_type=True,
                 use_discourse_level_lstm=True,
                 use_utterance_attention=True,
                 include_in_memory=('question', 'column', 'table')):
        """Build a BERT-backed multi-utterance encoder (SParC-style).

        Args:
            device: torch device the sub-modules run on.
            preproc: preprocessor exposing ``tokenizer``.
            encode_size: hidden size of the recurrent encoders.
            update_config: config dict for the update module; its ``'name'``
                key selects the class. Defaults to an empty dict.
            inputembedding_config: config for the input-sequence embedding;
                must contain ``'num_utterance_keep'``. Defaults to an empty
                dict.
            dropout: dropout probability.
            encoder_num_layers: number of layers for each ``Encoder``.
            bert_token_type: whether token-type ids are used.
            bert_version: Hugging Face model name; the large whole-word-
                masking variant yields 1024-dim hidden states, else 768.
            summarize_header: "first" token or "avg" over header tokens.
            use_column_type: whether column-type info is used.
            use_discourse_level_lstm: whether a discourse-level LSTM state is
                concatenated onto each utterance token embedding.
            use_utterance_attention: whether attention over past utterances
                is used.
            include_in_memory: which encodings the decoder memory includes.
        """
        super().__init__()
        # BUG FIX: `update_config={}` and `inputembedding_config={}` were
        # mutable default arguments shared across all calls; use None
        # sentinels instead.
        if update_config is None:
            update_config = {}
        if inputembedding_config is None:
            inputembedding_config = {}
        self._device = device
        self.dropout = dropout
        self.preproc = preproc  # preprocessor
        self.bert_token_type = bert_token_type
        self.base_enc_hidden_size = 1024 if bert_version == "bert-large-uncased-whole-word-masking" else 768

        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = encode_size
        self.use_discourse_level_lstm = use_discourse_level_lstm
        self.use_column_type = use_column_type
        self.use_utterance_attention = use_utterance_attention
        # How many past utterances are kept in the input window.
        self.num_utterances_to_keep = inputembedding_config[
            "num_utterance_keep"]

        if self.use_discourse_level_lstm:
            # The discourse-level state (enc_hidden_size // 2) is concatenated
            # onto each BERT token embedding.
            # BUG FIX: was `self.enc_hidden_size / 2`, which produces a float
            # layer size; use integer division (sizes must be ints).
            self.utterance_encoder = Encoder(
                encoder_num_layers,
                self.base_enc_hidden_size + self.enc_hidden_size // 2,
                self.enc_hidden_size)
        else:
            self.utterance_encoder = Encoder(encoder_num_layers,
                                             self.base_enc_hidden_size,
                                             self.enc_hidden_size)
        self.schema_encoder = Encoder(encoder_num_layers,
                                      self.base_enc_hidden_size,
                                      self.enc_hidden_size)
        self.table_encoder = Encoder(encoder_num_layers,
                                     self.base_enc_hidden_size,
                                     self.enc_hidden_size)
        self.include_in_memory = set(
            include_in_memory)  # e.g. ('question', 'column', 'table')
        update_modules = {
            'relational_transformer':
            sparc_enc_modules.RelationalTransformerUpdate,  # the branch used here
            'none': sparc_enc_modules.NoOpUpdate,
        }
        self.input_embedding = registry.instantiate(
            sparc_enc_modules.InputsquenceEmbedding,
            inputembedding_config,
            unused_keys=('name', ),
            device=self._device,
            hidden_size=self.enc_hidden_size,
        )

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        self.bert_model = BertModel.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
        self.bert_model.resize_token_embeddings(len(
            self.tokenizer))  # several tokens added