Example #1
    def infer(self, model, output_path, args):
        output = open(output_path, 'w')

        with torch.no_grad():
            if args.mode == 'infer':
                orig_data = registry.construct('dataset', self.config['data'][args.section])
                preproc_data = self.model_preproc.dataset(args.section)
                if args.limit:
                    sliced_orig_data = itertools.islice(orig_data, args.limit)
                    sliced_preproc_data = itertools.islice(preproc_data, args.limit)
                else:
                    sliced_orig_data = orig_data
                    sliced_preproc_data = preproc_data
                assert len(orig_data) == len(preproc_data)
                self._inner_infer(model, args.beam_size, args.output_history, sliced_orig_data, sliced_preproc_data, output, args.use_heuristic)
            elif args.mode == 'debug':
                data = self.model_preproc.dataset(args.section)
                if args.limit:
                    sliced_data = itertools.islice(data, args.limit)
                else:
                    sliced_data = data
                self._debug(model, sliced_data, output)
            elif args.mode == 'visualize_attention':
                model.visualize_flag = True
                model.decoder.visualize_flag = True
                data = registry.construct('dataset', self.config['data'][args.section])
                if args.limit:
                    sliced_data = itertools.islice(data, args.limit)
                else:
                    sliced_data = data
                self._visualize_attention(model, args.beam_size, args.output_history, sliced_data, args.res1, args.res2, args.res3, output)
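
Every example in this listing goes through the same `registry` helpers: `registry.construct(kind, config, ...)`, `registry.instantiate(callable, config, ...)`, and `registry.lookup(kind, config)`. The module itself is not shown here; the following is a minimal sketch of how such a registry could work, assuming it maps a (kind, name) pair to a class and dispatches on the config dict's 'name' key. The `register` decorator and `_REGISTRY` table are hypothetical names.

import collections

_REGISTRY = collections.defaultdict(dict)  # kind -> name -> callable

def register(kind, name):
    # Hypothetical decorator used elsewhere to populate the registry.
    def decorator(callable_):
        _REGISTRY[kind][name] = callable_
        return callable_
    return decorator

def lookup(kind, config):
    # Accept either a bare name or a config dict carrying a 'name' key.
    name = config['name'] if isinstance(config, dict) else config
    return _REGISTRY[kind][name]

def instantiate(callable_, config, unused_keys=(), **kwargs):
    # Forward every config key that is not explicitly declared unused.
    kwargs.update({k: v for k, v in config.items()
                   if k not in unused_keys and k not in kwargs})
    return callable_(**kwargs)

def construct(kind, config, unused_keys=(), **kwargs):
    # 'name' only selects the implementation, so it is never forwarded.
    return instantiate(lookup(kind, config), config,
                       tuple(unused_keys) + ('name',), **kwargs)

Under this reading, the config dict in Example #13 ({"name": "spider", ...}) selects the dataset class registered as 'spider' and passes the remaining keys as constructor arguments, which matches the explicit unused_keys=('encoder_preproc', 'decoder_preproc') seen in the model-construction examples.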
Example #2
    def __init__(self, preproc, device, encoder, decoder):
        super().__init__()
        self.preproc = preproc
        self.encoder = registry.construct(
            'encoder', encoder, device=device, preproc=preproc.enc_preproc)
        self.decoder = registry.construct(
            'decoder', decoder, device=device, preproc=preproc.dec_preproc)
        self.decoder.visualize_flag = False

        # Default of False makes encoders without a `batched` attribute
        # fall back to the unbatched loss instead of raising AttributeError.
        if getattr(self.encoder, 'batched', False):
            self.compute_loss = self._compute_loss_enc_batched
        else:
            self.compute_loss = self._compute_loss_unbatched
Example #3
    def __init__(self, grammar, save_path, censor_pointers):
        self.save_path = save_path
        self.censor_pointers = censor_pointers
        self.grammar = registry.construct('grammar', grammar)
        self.ast_wrapper = self.grammar.ast_wrapper

        self.items = collections.defaultdict(list)
Example #4
def compute_metrics(config_path, config_args, section, inferred_path,
                    logdir=None):
    if config_args:
        config = json.loads(_jsonnet.evaluate_file(config_path, tla_codes={'args': config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(config_path))

    if 'model_name' in config and logdir:
        logdir = os.path.join(logdir, config['model_name'])
    if logdir:
        inferred_path = inferred_path.replace('__LOGDIR__', logdir)

    inferred = open(inferred_path)
    data = registry.construct('dataset', config['data'][section])
    metrics = data.Metrics(data)

    inferred_lines = list(inferred)
    if len(inferred_lines) < len(data):
        raise Exception('Not enough inferred: {} vs {}'.format(
            len(inferred_lines), len(data)))

    for line in inferred_lines:
        infer_results = json.loads(line)
        if infer_results['beams']:
            inferred_code = infer_results['beams'][0]['inferred_code']
        else:
            inferred_code = None
        if 'index' in infer_results:
            metrics.add(data[infer_results['index']], inferred_code)
        else:
            metrics.add(None, inferred_code, obsolete_gold_code=infer_results['gold_code'])

    return logdir, metrics.finalize()
Example #5
    def __init__(self, logger, config):
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        self.logger = logger
        self.train_config = registry.instantiate(TrainConfig, config['train'])
        self.data_random = random_state.RandomContext(
            self.train_config.data_seed)
        self.model_random = random_state.RandomContext(
            self.train_config.model_seed)

        self.init_random = random_state.RandomContext(
            self.train_config.init_seed)
        with self.init_random:
            # 0. Construct preprocessors
            self.model_preproc = registry.instantiate(
                registry.lookup('model', config['model']).Preproc,
                config['model'],
                unused_keys=('name',))
            self.model_preproc.load()

            # 1. Construct model
            self.model = registry.construct('model',
                                            config['model'],
                                            unused_keys=('encoder_preproc',
                                                         'decoder_preproc'),
                                            preproc=self.model_preproc,
                                            device=device)
            self.model.to(device)
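
`random_state.RandomContext` scopes each phase (init, data, model) to its own seed. Its implementation is not part of this listing; here is a minimal sketch of the idea for the built-in `random` and NumPy generators (the real module may also cover PyTorch):

import random
import numpy as np

class RandomContext:
    '''Context manager that runs its block under a seeded RNG state that is
    isolated from, and resumable across, uses of the surrounding state.'''

    def __init__(self, seed):
        self._py_state = random.Random(seed).getstate()
        self._np_state = np.random.RandomState(seed).get_state()

    def __enter__(self):
        # Swap the seeded state in, remembering the outside state.
        self._outer_py = random.getstate()
        self._outer_np = np.random.get_state()
        random.setstate(self._py_state)
        np.random.set_state(self._np_state)

    def __exit__(self, exc_type, exc, tb):
        # Remember where this context left off, then restore the outside state.
        self._py_state = random.getstate()
        self._np_state = np.random.get_state()
        random.setstate(self._outer_py)
        np.random.set_state(self._outer_np)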
Example #6
    def infer(self, model, output_path, args):
        # 3. Get training data somewhere
        output = open(output_path, 'w')
        orig_data = registry.construct('dataset',
                                       self.config['data'][args.section])
        sliced_orig_data = maybe_slice(orig_data, args.start_offset,
                                       args.limit)
        preproc_data = self.model_preproc.dataset(args.section)
        sliced_preproc_data = maybe_slice(preproc_data, args.start_offset,
                                          args.limit)

        with torch.no_grad():
            if args.mode == 'infer':
                assert len(orig_data) == len(preproc_data)
                self._inner_infer(model, args.beam_size, args.output_history,
                                  sliced_orig_data, sliced_preproc_data,
                                  output, args.nproc)
            elif args.mode == 'debug':
                self._debug(model, sliced_orig_data, output)
            elif args.mode == 'visualize_attention':
                model.visualize_flag = True
                model.decoder.visualize_flag = True
                self._visualize_attention(model, args.beam_size,
                                          args.output_history,
                                          sliced_orig_data, args.res1,
                                          args.res2, args.res3, output)
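
Unlike Example #1, this variant delegates slicing to a `maybe_slice` helper that also takes a start offset. The helper is not shown in this listing; a plausible sketch, assuming the same `itertools.islice` semantics as Example #1:

import itertools

def maybe_slice(data, start_offset, limit):
    # Mirror the `if args.limit` pattern from Example #1: falsy bounds
    # mean "no slicing".
    if not start_offset and not limit:
        return data
    start = start_offset or 0
    stop = start + limit if limit else None
    return itertools.islice(data, start, stop)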
Example #7
    def __init__(self,
                 grammar,
                 save_path,
                 min_freq=3,
                 max_count=5000,
                 use_seq_elem_rules=False):
        self.grammar = registry.construct('grammar', grammar)
        self.ast_wrapper = self.grammar.ast_wrapper

        self.vocab_path = os.path.join(save_path, 'dec_vocab.json')
        self.observed_productions_path = os.path.join(
            save_path, 'observed_productions.json')
        self.grammar_rules_path = os.path.join(save_path, 'grammar_rules.json')
        self.data_dir = os.path.join(save_path, 'dec')

        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.use_seq_elem_rules = use_seq_elem_rules

        self.items = collections.defaultdict(list)
        self.sum_type_constructors = collections.defaultdict(set)
        self.field_presence_infos = collections.defaultdict(set)
        self.seq_lengths = collections.defaultdict(set)
        self.primitive_types = set()

        self.vocab = None
        self.all_rules = None
        self.rules_mask = None
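
Both this preprocessor and the encoder-side one in Example #15 build their vocabularies through `vocab.VocabBuilder(min_freq, max_count)`. The class is not included here; the following sketch shows one plausible reading of those two parameters. The method names `add_word` and `finish` are assumptions.

import collections

class VocabBuilder:
    '''Counts tokens, then keeps those seen at least min_freq times,
    capped at the max_count most frequent entries.'''

    def __init__(self, min_freq, max_count):
        self.min_freq = min_freq
        self.max_count = max_count
        self.counts = collections.Counter()

    def add_word(self, word):
        self.counts[word] += 1

    def finish(self):
        # most_common sorts by frequency, so the cap keeps the commonest tokens.
        return [word for word, count in self.counts.most_common(self.max_count)
                if count >= self.min_freq]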
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    parser.add_argument('--section', required=True)
    parser.add_argument('--inferred', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()
    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config,
                                   tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    os.makedirs(args.output, exist_ok=True)
    gold = open(os.path.join(args.output, 'gold.txt'), 'w')
    predicted = open(os.path.join(args.output, 'predicted.txt'), 'w')

    inferred = open(args.inferred)
    data = registry.construct('dataset', config['data'][args.section])

    for line in inferred:
        infer_results = json.loads(line)
        if infer_results['beams']:
            inferred_code = infer_results['beams'][0]['inferred_code']
        else:
            inferred_code = 'SELECT a FROM b'
        item = data[infer_results['index']]
        gold.write('{}\t{}\n'.format(item.orig['query'].replace('\t', ' '),
                                     item.schema.db_id))
        predicted.write('{}\n'.format(inferred_code))
Example #9
    def load_model(self, logdir, step):
        '''Load a model (identified by the config used for construction) and return it'''
        # 1. Construct model
        model = registry.construct('model', self.config['model'], preproc=self.model_preproc, device=self.device)
        model.to(self.device)
        model.eval()
        model.visualize_flag = False

        optimizer = registry.construct('optimizer', self.config['optimizer'], params=model.parameters())

        # 2. Restore its parameters
        saver = saver_mod.Saver(model, optimizer)
        last_step = saver.restore(logdir, step=step, map_location=self.device)
        if not last_step:
            raise Exception('Attempting to infer on untrained model')
        return model
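
`saver_mod.Saver` appears throughout the listing: it wraps a model (and optionally an optimizer), `save(dir, step)` writes a checkpoint, and `restore(dir, step=..., map_location=...)` returns the restored step, or a falsy value when nothing is saved (hence the `if not last_step` check above). A minimal sketch under those assumptions; the checkpoint filename, and the omission of keep_every_n and of Example #21's dict-based form, are simplifications of mine:

import os
import torch

class Saver:
    def __init__(self, model, optimizer=None, keep_every_n=None):
        self.model = model
        self.optimizer = optimizer

    def _path(self, logdir, step):
        return os.path.join(logdir, 'model_checkpoint-{:08d}'.format(step))

    def save(self, logdir, step):
        state = {'model': self.model.state_dict(), 'step': step}
        if self.optimizer is not None:
            state['optimizer'] = self.optimizer.state_dict()
        torch.save(state, self._path(logdir, step))

    def restore(self, logdir, step=None, map_location=None):
        # Return 0 when there is nothing to restore.
        if step is None:
            ckpts = sorted(
                f for f in os.listdir(logdir)
                if f.startswith('model_checkpoint-')) if os.path.isdir(logdir) else []
            if not ckpts:
                return 0
            path = os.path.join(logdir, ckpts[-1])
        else:
            path = self._path(logdir, step)
            if not os.path.exists(path):
                return 0
        state = torch.load(path, map_location=map_location)
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None and 'optimizer' in state:
            self.optimizer.load_state_dict(state['optimizer'])
        return state['step']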
Example #10
    def preprocess(self):
        self.model_preproc.clear_items()
        for section in self.config['data']:
            data = registry.construct('dataset', self.config['data'][section])
            for item in tqdm.tqdm(data, desc=section, dynamic_ncols=True):
                to_add, validation_info = self.model_preproc.validate_item(
                    item, section)
                if to_add:
                    self.model_preproc.add_item(item, section, validation_info)
        self.model_preproc.save()
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config,
                                   tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    train_data = registry.construct('dataset', config['data']['train'])
    grammar = registry.construct('grammar',
                                 config['model']['decoder_preproc']['grammar'])
    base_grammar = registry.construct(
        'grammar',
        config['model']['decoder_preproc']['grammar']['base_grammar'])

    for i, item in enumerate(tqdm.tqdm(train_data, dynamic_ncols=True)):
        parsed = grammar.parse(item.code, 'train')
        orig_parsed = base_grammar.parse(item.orig['orig'], 'train')

        canonicalized_orig_code = base_grammar.unparse(
            base_grammar.parse(item.orig['orig'], 'train'), item)
        unparsed = grammar.unparse(parsed, item)
        if canonicalized_orig_code != unparsed:
            print('Original tree:')
            pprint.pprint(orig_parsed)
            print('Rewritten tree:')
            pprint.pprint(parsed)
            print('Reconstructed tree:')
            pprint.pprint(grammar._expand_templates(parsed))
            print('Original code:')
            print(canonicalized_orig_code)
            print('Reconstructed code:')
            print(unparsed)

            import IPython
            IPython.embed()
            break
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config,
                                   tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    os.makedirs(args.output, exist_ok=True)
    gold = open(os.path.join(args.output, 'gold.txt'), 'w')
    predicted = open(os.path.join(args.output, 'predicted.txt'), 'w')

    train_data = registry.construct('dataset', config['data']['train'])
    grammar = registry.construct('grammar',
                                 config['model']['decoder_preproc']['grammar'])

    evaluator = evaluation.Evaluator(
        'data/spider-20190205/database',
        evaluation.build_foreign_key_map_from_json(
            'data/spider-20190205/tables.json'), 'match')

    for i, item in enumerate(tqdm.tqdm(train_data, dynamic_ncols=True)):
        parsed = grammar.parse(item.code, 'train')
        sql = grammar.unparse(parsed, item)

        evaluator.evaluate_one(item.schema.db_id,
                               item.orig['query'].replace('\t', ' '), sql)

        gold.write('{}\t{}\n'.format(item.orig['query'].replace('\t', ' '),
                                     item.schema.db_id))
        predicted.write('{}\n'.format(sql))
Example #13
def load_db():
    db_id = "singer"
    my_schema = dump_db_json_schema(
        "data/sqlite_files/{db_id}/{db_id}.sqlite".format(db_id=db_id), db_id)
    schema, eval_foreign_key_maps = load_tables_from_schema_dict(my_schema)
    dataset = registry.construct(
        'dataset_infer', {
            "name": "spider",
            "schemas": schema,
            "eval_foreign_key_maps": eval_foreign_key_maps,
            "db_path": "data/sqlite_files/"
        })
    # NOTE: `model` is not defined in this snippet; it is assumed to exist
    # at module level.
    for _, db_schema in dataset.schemas.items():
        model.preproc.enc_preproc._preprocess_schema(db_schema)
    return dataset.schemas[db_id]
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(_jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc,
        config['model'])

    for section in config['data']:
        data = registry.construct('dataset', config['data'][section])
        for item in tqdm.tqdm(data, desc=section, dynamic_ncols=True):
            to_add, validation_info = model_preproc.validate_item(item, section)
            if to_add:
                model_preproc.add_item(item, section, validation_info)
    model_preproc.save()
Example #15
        def __init__(self,
                     save_path,
                     min_freq=3,
                     max_count=5000,
                     include_table_name_in_column=True,
                     word_emb=None,
                     count_tokens_in_word_emb_for_vocab=False):
            if word_emb is None:
                self.word_emb = None
            else:
                self.word_emb = registry.construct('word_emb', word_emb)

            self.data_dir = os.path.join(save_path, 'enc')
            self.include_table_name_in_column = include_table_name_in_column
            self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab
            # TODO: Write 'train', 'val', 'test' somewhere else
            self.texts = {'train': [], 'val': [], 'test': []}

            self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
            self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
            self.vocab = None
            self.counted_db_ids = set()
Example #16
def compute_metrics(config_path, config_args, section, inferred_path,
                    logdir=None, evaluate_beams_individually=False):
    if config_args:
        config = json.loads(_jsonnet.evaluate_file(config_path, tla_codes={'args': config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(config_path))

    if 'model_name' in config and logdir:
        logdir = os.path.join(logdir, config['model_name'])
    if logdir:
        inferred_path = inferred_path.replace('__LOGDIR__', logdir)

    inferred = open(inferred_path)
    data = registry.construct('dataset', config['data'][section])

    inferred_lines = list(inferred)
    if len(inferred_lines) < len(data):
        raise Exception('Not enough inferred: {} vs {}'.format(
            len(inferred_lines), len(data)))

    if evaluate_beams_individually:
        return logdir, evaluate_all_beams(data, inferred_lines)
    else:
        return logdir, evaluate_default(data, inferred_lines)
Example #17
def main():
    # NOTE: `args` is never defined in this snippet; it is assumed to come
    # from a module-level argparse setup like the ones in the other examples.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        torch.set_num_threads(1)
    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config,
                                   tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])

    output_path = args.output.replace('__LOGDIR__', args.logdir)
    if os.path.exists(output_path):
        print('Output file {} already exists'.format(output_path))
        sys.exit(1)

    # 0. Construct preprocessors
    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc, config['model'])
    model_preproc.load()

    # 1. Construct model
    model = registry.construct('model',
                               config['model'],
                               preproc=model_preproc,
                               device=device)
    model.to(device)
    model.eval()
    model.visualize_flag = False

    optimizer = registry.construct('optimizer',
                                   config['optimizer'],
                                   params=model.parameters())

    # 2. Restore its parameters
    saver = saver_mod.Saver(model, optimizer)
    last_step = saver.restore(args.logdir, step=args.step, map_location=device)
    if not last_step:
        raise Exception('Attempting to infer on untrained model')

    # 3. Get training data somewhere
    output = open(output_path, 'w')
    data = registry.construct('dataset', config['data'][args.section])
    if args.limit:
        sliced_data = itertools.islice(data, args.limit)
    else:
        sliced_data = data

    with torch.no_grad():
        if args.mode == 'infer':
            orig_data = registry.construct('dataset',
                                           config['data'][args.section])
            preproc_data = model_preproc.dataset(args.section)
            if args.limit:
                sliced_orig_data = itertools.islice(orig_data, args.limit)
                sliced_preproc_data = itertools.islice(preproc_data, args.limit)
            else:
                sliced_orig_data = orig_data
                sliced_preproc_data = preproc_data
            assert len(orig_data) == len(preproc_data)
            infer(model, args.beam_size, args.output_history, sliced_orig_data,
                  sliced_preproc_data, output)
        elif args.mode == 'debug':
            data = model_preproc.dataset(args.section)
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            debug(model, sliced_data, output)
        elif args.mode == 'visualize_attention':
            model.visualize_flag = True
            model.decoder.visualize_flag = True
            data = registry.construct('dataset', config['data'][args.section])
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            visualize_attention(model, args.beam_size, args.output_history,
                                sliced_data, output)
Example #18
    def train(self, config, modeldir):
        # Slight difference vs. the unrefactored train: init_random starts
        # over here. Could be fixed, if it mattered, by saving the random
        # state at the end of __init__.
        with self.init_random:
            # Optimizer and lr_scheduler could probably be moved to __init__.
            # Empirically it works, likely because saver.restore resets the
            # optimizer state via optimizer.load_state_dict; when no saved
            # file exists yet that may not hold, so the optimizer might need
            # a manual reset. For now, creating both from scratch each time
            # is safer and appears just as fast, at the cost of passing the
            # config into train(), which is somewhat ugly.
            optimizer = registry.construct('optimizer', config['optimizer'], params=self.model.parameters())
            lr_scheduler = registry.construct(
                    'lr_scheduler',
                    config.get('lr_scheduler', {'name': 'noop'}),
                    optimizer=optimizer)

        # 2. Restore model parameters
        saver = saver_mod.Saver(
            self.model, optimizer, keep_every_n=self.train_config.keep_every_n)
        last_step = saver.restore(modeldir)

        # 3. Get training data somewhere
        with self.data_random:
            train_data = self.model_preproc.dataset('train')
            train_data_loader = self._yield_batches_from_epochs(
                torch.utils.data.DataLoader(
                    train_data,
                    batch_size=self.train_config.batch_size,
                    shuffle=True,
                    drop_last=True,
                    collate_fn=lambda x: x))
        train_eval_data_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.train_config.eval_batch_size,
                collate_fn=lambda x: x)

        val_data = self.model_preproc.dataset('val')
        val_data_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.train_config.eval_batch_size,
                collate_fn=lambda x: x)

        # 4. Start training loop
        with self.data_random:
            for batch in train_data_loader:
                # Quit if too long
                if last_step >= self.train_config.max_steps:
                    break

                # Evaluate model
                if last_step % self.train_config.eval_every_n == 0:
                    if self.train_config.eval_on_train:
                        self._eval_model(
                            self.logger, self.model, last_step,
                            train_eval_data_loader, 'train',
                            num_eval_items=self.train_config.num_eval_items)
                    if self.train_config.eval_on_val:
                        self._eval_model(
                            self.logger, self.model, last_step,
                            val_data_loader, 'val',
                            num_eval_items=self.train_config.num_eval_items)

                # Compute and apply gradient
                with self.model_random:
                    optimizer.zero_grad()
                    loss = self.model.compute_loss(batch)
                    loss.backward()
                    lr_scheduler.update_lr(last_step)
                    optimizer.step()

                # Report metrics
                if last_step % self.train_config.report_every_n == 0:
                    self.logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

                last_step += 1
                # Run saver
                if last_step % self.train_config.save_every_n == 0:
                    saver.save(modeldir, last_step)
Example #19
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config,
                                   tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    # 0. Construct preprocessors
    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc,
        config['model'],
        unused_keys=('name',))
    model_preproc.load()

    # 1. Construct model
    model = registry.construct('model',
                               config['model'],
                               unused_keys=('encoder_preproc',
                                            'decoder_preproc'),
                               preproc=model_preproc,
                               device=device)
    model.to(device)
    model.eval()

    # 3. Get training data somewhere
    train_data = model_preproc.dataset('train')
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=10, collate_fn=lambda x: x)

    batch = next(iter(train_eval_data_loader))
    descs = [x for x, y in batch]

    q0, qb = test_enc_equal([descs[0]['question']],
                            [[desc['question']] for desc in descs],
                            model.encoder.question_encoder)

    c0, cb = test_enc_equal(descs[0]['columns'],
                            [desc['columns'] for desc in descs],
                            model.encoder.column_encoder)

    t0, tb = test_enc_equal(descs[0]['tables'],
                            [desc['tables'] for desc in descs],
                            model.encoder.table_encoder)

    q0_enc, c0_enc, t0_enc = model.encoder.encs_update.forward_unbatched(
        descs[0], q0[0], c0[0], c0[1], t0[0], t0[1])
    qb_enc, cb_enc, tb_enc = model.encoder.encs_update.forward(
        descs, qb[0], cb[0], cb[1], tb[0], tb[1])

    check_close(q0_enc.squeeze(1), qb_enc.select(0))
    check_close(c0_enc.squeeze(1), cb_enc.select(0))
    check_close(t0_enc.squeeze(1), tb_enc.select(0))
Example #20
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.config_args:
        config = json.loads(_jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])
    train_config = registry.instantiate(TrainConfig, config['train'])

    reopen_to_flush = config.get('log', {}).get('reopen_to_flush')
    logger = Logger(os.path.join(args.logdir, 'log.txt'), reopen_to_flush)
    config_dump_path = os.path.join(
        args.logdir, 'config-{}.json'.format(
            datetime.datetime.now().strftime('%Y%m%dT%H%M%S%Z')))
    with open(config_dump_path, 'w') as f:
        json.dump(config, f, sort_keys=True, indent=4)
    logger.log('Logging to {}'.format(args.logdir))

    init_random = random_state.RandomContext(train_config.init_seed)
    data_random = random_state.RandomContext(train_config.data_seed)
    model_random = random_state.RandomContext(train_config.model_seed)

    with init_random:
        # 0. Construct preprocessors
        model_preproc = registry.instantiate(
            registry.lookup('model', config['model']).Preproc,
            config['model'],
            unused_keys=('name',))
        model_preproc.load()

        # 1. Construct model
        model = registry.construct(
            'model', config['model'],
            unused_keys=('encoder_preproc', 'decoder_preproc'),
            preproc=model_preproc, device=device)
        model.to(device)

        optimizer = registry.construct('optimizer', config['optimizer'], params=model.parameters())
        lr_scheduler = registry.construct(
                'lr_scheduler',
                config.get('lr_scheduler', {'name': 'noop'}),
                optimizer=optimizer)

    # 2. Restore its parameters
    saver = saver_mod.Saver(
        model, optimizer, keep_every_n=train_config.keep_every_n)
    last_step = saver.restore(args.logdir)

    # 3. Get training data somewhere
    with data_random:
        train_data = model_preproc.dataset('train')
        train_data_loader = yield_batches_from_epochs(
            torch.utils.data.DataLoader(
                train_data,
                batch_size=train_config.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=lambda x: x))
    train_eval_data_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=train_config.eval_batch_size,
            collate_fn=lambda x: x)

    val_data = model_preproc.dataset('val')
    val_data_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=train_config.eval_batch_size,
            collate_fn=lambda x: x)

    # 4. Start training loop
    with data_random:
        for batch in train_data_loader:
            # Quit if too long
            if last_step >= train_config.max_steps:
                break

            # Evaluate model
            if last_step % train_config.eval_every_n == 0:
                if train_config.eval_on_train:
                    eval_model(logger, model, last_step, train_eval_data_loader, 'train', num_eval_items=train_config.num_eval_items)
                if train_config.eval_on_val:
                    eval_model(logger, model, last_step, val_data_loader, 'val', num_eval_items=train_config.num_eval_items)

            # Compute and apply gradient
            with model_random:
                optimizer.zero_grad()
                loss = model.compute_loss(batch)
                loss.backward()
                lr_scheduler.update_lr(last_step)
                optimizer.step()

            # Report metrics
            if last_step % train_config.report_every_n == 0:
                logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

            last_step += 1
            # Run saver
            if last_step % train_config.save_every_n == 0:
                saver.save(args.logdir, last_step)
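
`yield_batches_from_epochs` (and the `_yield_batches_from_epochs` method in Examples #18 and #21) is not defined in this listing. Its usage, an unbounded `for batch in train_data_loader` loop that exits only on `max_steps`, suggests a generator that restarts the DataLoader at every epoch boundary. A minimal sketch under that reading:

def yield_batches_from_epochs(loader):
    # Iterate the loader forever; the training loop, not the data,
    # decides when to stop (via the max_steps check above).
    while True:
        for batch in loader:
            yield batch

This reading also explains why Example #21 can call next(train_data_loader) during gradient accumulation without guarding against StopIteration.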
Example #21
    def train(self, config, modeldir):
        # Slight difference vs. the unrefactored train: init_random starts
        # over here. Could be fixed, if it mattered, by saving the random
        # state at the end of __init__.
        with self.init_random:
            # Optimizer and lr_scheduler could probably be moved to __init__.
            # Empirically it works, likely because saver.restore resets the
            # optimizer state via optimizer.load_state_dict; when no saved
            # file exists yet that may not hold, so the optimizer might need
            # a manual reset. For now, creating both from scratch each time
            # is safer and appears just as fast, at the cost of passing the
            # config into train(), which is somewhat ugly.

            # TODO: not nice
            if config["optimizer"].get("name", None) == 'bertAdamw':
                bert_params = list(self.model.encoder.bert_model.parameters())
                assert len(bert_params) > 0
                non_bert_params = []
                for name, _param in self.model.named_parameters():
                    if "bert" not in name:
                        non_bert_params.append(_param)
                assert len(non_bert_params) + len(bert_params) == len(
                    list(self.model.parameters()))

                optimizer = registry.construct(
                    'optimizer', config['optimizer'],
                    non_bert_params=non_bert_params,
                    bert_params=bert_params)
                lr_scheduler = registry.construct(
                    'lr_scheduler',
                    config.get('lr_scheduler', {'name': 'noop'}),
                    param_groups=[optimizer.non_bert_param_group,
                                  optimizer.bert_param_group])
            else:
                optimizer = registry.construct('optimizer',
                                               config['optimizer'],
                                               params=self.model.parameters())
                lr_scheduler = registry.construct(
                    'lr_scheduler',
                    config.get('lr_scheduler', {'name': 'noop'}),
                    param_groups=optimizer.param_groups)

        # 2. Restore model parameters
        saver = saver_mod.Saver(
            {"model": self.model, "optimizer": optimizer},
            keep_every_n=self.train_config.keep_every_n)
        last_step = saver.restore(modeldir, map_location=self.device)

        if "pretrain" in config and last_step == 0:
            pretrain_config = config["pretrain"]
            _path = pretrain_config["pretrained_path"]
            _step = pretrain_config["checkpoint_step"]
            pretrain_step = saver.restore(_path,
                                          step=_step,
                                          map_location=self.device,
                                          item_keys=["model"])
            saver.save(modeldir,
                       pretrain_step)  # for evaluating pretrained models
            last_step = pretrain_step

        # 3. Get training data somewhere
        with self.data_random:
            train_data = self.model_preproc.dataset('train')
            train_data_loader = self._yield_batches_from_epochs(
                torch.utils.data.DataLoader(
                    train_data,
                    batch_size=self.train_config.batch_size,
                    shuffle=True,
                    drop_last=True,
                    collate_fn=lambda x: x))
        train_eval_data_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.train_config.eval_batch_size,
            collate_fn=lambda x: x)

        val_data = self.model_preproc.dataset('val')
        val_data_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=self.train_config.eval_batch_size,
            collate_fn=lambda x: x)

        # 4. Start training loop
        with self.data_random:
            for batch in train_data_loader:
                # Quit if too long
                if last_step >= self.train_config.max_steps:
                    break

                # Evaluate model
                if last_step % self.train_config.eval_every_n == 0:
                    if self.train_config.eval_on_train:
                        self._eval_model(
                            self.logger,
                            self.model,
                            last_step,
                            train_eval_data_loader,
                            'train',
                            num_eval_items=self.train_config.num_eval_items)
                    if self.train_config.eval_on_val:
                        self._eval_model(
                            self.logger,
                            self.model,
                            last_step,
                            val_data_loader,
                            'val',
                            num_eval_items=self.train_config.num_eval_items)

                # Compute and apply gradient
                with self.model_random:
                    for _i in range(self.train_config.num_batch_accumulated):
                        if _i > 0:
                            batch = next(train_data_loader)
                        loss = self.model.compute_loss(batch)
                        norm_loss = loss / self.train_config.num_batch_accumulated
                        norm_loss.backward()

                    if self.train_config.clip_grad:
                        torch.nn.utils.clip_grad_norm_(
                            optimizer.bert_param_group["params"],
                            self.train_config.clip_grad)
                    optimizer.step()
                    lr_scheduler.update_lr(last_step)
                    optimizer.zero_grad()

                # Report metrics
                if last_step % self.train_config.report_every_n == 0:
                    self.logger.log('Step {}: loss={:.4f}'.format(
                        last_step, loss.item()))

                last_step += 1
                # Run saver
                if last_step % self.train_config.save_every_n == 0:
                    saver.save(modeldir, last_step)

            # Save final model
            saver.save(modeldir, last_step)
Example #22
    def __init__(self,
                 base_grammar,
                 template_file,
                 root_type=None,
                 all_sections_rewritten=False):
        self.base_grammar = registry.construct('grammar', base_grammar)
        self.templates = json.load(open(template_file))
        self.all_sections_rewritten = all_sections_rewritten

        self.pointers = self.base_grammar.pointers
        self.ast_wrapper = copy.deepcopy(self.base_grammar.ast_wrapper)
        self.base_ast_wrapper = self.base_grammar.ast_wrapper
        # TODO: Override root_type more intelligently
        self.root_type = self.base_grammar.root_type
        if base_grammar['name'] == 'python':
            self.root_type = 'mod'

        singular_types_with_single_seq_field = set(
            name
            for name, type_info in self.ast_wrapper.singular_types.items()
            if len(type_info.fields) == 1 and type_info.fields[0].seq)
        seq_fields = {
            '{}-{}'.format(name, field.name): SeqField(name, field)
            for name, type_info in self.ast_wrapper.singular_types.items()
            for field in type_info.fields if field.seq
        }

        templates_by_head_type = collections.defaultdict(list)
        for template in self.templates:
            head_type = template['idiom'][0]
            # head_type can be one of the following:
            # 1. name of a constructor/product with a single seq field.
            # 2. name of any other constructor/product
            # 3. name of a seq field (e.g. 'Dict-keys'),
            #    when the containing constructor/product contains more than one field
            #    (not yet implemented)
            # For 1 and 3, the template should be treated as a 'seq fragment'
            # which can occur in any seq field of the corresponding sum/product type.
            # However, the NL2Code model has no such notion currently.
            if head_type in singular_types_with_single_seq_field:
                # field.type could be sum type or product type, but not constructor
                field = self.ast_wrapper.singular_types[head_type].fields[0]
                templates_by_head_type[field.type].append(
                    (template, SeqField(head_type, field)))
                templates_by_head_type[head_type].append((template, None))
            elif head_type in seq_fields:
                seq_field = seq_fields[head_type]
                templates_by_head_type[seq_field.field.type].append(
                    (template, seq_field))
            else:
                templates_by_head_type[head_type].append((template, None))

        types_to_replace = {}

        for head_type, templates in templates_by_head_type.items():
            constructors, seq_fragment_constructors = [], []
            for template, seq_field in templates:
                if seq_field:
                    if head_type in self.ast_wrapper.product_types:
                        seq_type = '{}_plus_templates'.format(head_type)
                    else:
                        seq_type = head_type

                    seq_fragment_constructors.append(
                        self._template_to_constructor(
                            template, '_{}_seq'.format(seq_type), seq_field))
                else:
                    constructors.append(
                        self._template_to_constructor(template, '', seq_field))

            # head type can be:
            # constructor (member of sum type)
            if head_type in self.ast_wrapper.constructors:
                assert constructors
                assert not seq_fragment_constructors

                self.ast_wrapper.add_constructors_to_sum_type(
                    self.ast_wrapper.constructor_to_sum_type[head_type],
                    constructors)

            # sum type
            elif head_type in self.ast_wrapper.sum_types:
                assert not constructors
                assert seq_fragment_constructors
                self.ast_wrapper.add_seq_fragment_type(
                    head_type, seq_fragment_constructors)

            # product type
            elif head_type in self.ast_wrapper.product_types:
                # Replace Product with Constructor
                # - make a Constructor
                orig_prod_type = self.ast_wrapper.product_types[head_type]
                new_constructor_for_prod_type = asdl.Constructor(
                    name=head_type, fields=orig_prod_type.fields)
                # - remove Product in ast_wrapper
                self.ast_wrapper.remove_product_type(head_type)

                # Define a new sum type
                # Add the original product type and template as constructors
                name = '{}_plus_templates'.format(head_type)
                self.ast_wrapper.add_sum_type(
                    name,
                    asdl.Sum(types=constructors +
                             [new_constructor_for_prod_type]))
                # Add seq fragment constructors
                self.ast_wrapper.add_seq_fragment_type(
                    name, seq_fragment_constructors)

                # Replace every occurrence of the product type in the grammar
                types_to_replace[head_type] = name

            # built-in type
            elif head_type in self.ast_wrapper.primitive_types:
                raise NotImplementedError(
                    'built-in type as head type of idiom unsupported: {}'.format(
                        head_type))
                # Define a new sum type
                # Add the original built-in type and template as constructors
                # Replace every occurrence of the product type in the grammar

            else:
                raise NotImplementedError(
                    'Unable to handle head type of idiom: {}'.format(
                        head_type))

        # Replace occurrences of product types which have been used as idiom head types
        for constructor_or_product in self.ast_wrapper.singular_types.values():
            for field in constructor_or_product.fields:
                if field.type in types_to_replace:
                    field.type = types_to_replace[field.type]

        self.templates_containing_placeholders = {}
        for name, constructor in self.ast_wrapper.singular_types.items():
            if not hasattr(constructor, 'template'):
                continue
            hole_values = {}
            for field in constructor.fields:
                hole_id = self.get_hole_id(field.name)
                placeholder = ast_util.HoleValuePlaceholder(id=hole_id,
                                                            is_seq=field.seq,
                                                            is_opt=field.opt)
                if field.seq:
                    hole_values[hole_id] = [placeholder]
                else:
                    hole_values[hole_id] = placeholder

            self.templates_containing_placeholders[
                name] = constructor.template(hole_values)

        if root_type is not None:
            if isinstance(root_type, (list, tuple)):
                for choice in root_type:
                    if (choice in self.ast_wrapper.singular_types
                            or choice in self.ast_wrapper.sum_types):
                        self.root_type = choice
                        break
            else:
                self.root_type = root_type