def infer(self, model, output_path, args):
    output = open(output_path, 'w')
    with torch.no_grad():
        if args.mode == 'infer':
            orig_data = registry.construct('dataset', self.config['data'][args.section])
            preproc_data = self.model_preproc.dataset(args.section)
            if args.limit:
                sliced_orig_data = itertools.islice(orig_data, args.limit)
                sliced_preproc_data = itertools.islice(preproc_data, args.limit)
            else:
                sliced_orig_data = orig_data
                sliced_preproc_data = preproc_data
            assert len(orig_data) == len(preproc_data)
            self._inner_infer(model, args.beam_size, args.output_history,
                              sliced_orig_data, sliced_preproc_data, output,
                              args.use_heuristic)
        elif args.mode == 'debug':
            data = self.model_preproc.dataset(args.section)
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            self._debug(model, sliced_data, output)
        elif args.mode == 'visualize_attention':
            model.visualize_flag = True
            model.decoder.visualize_flag = True
            data = registry.construct('dataset', self.config['data'][args.section])
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            self._visualize_attention(model, args.beam_size, args.output_history,
                                      sliced_data, args.res1, args.res2, args.res3,
                                      output)
def __init__(self, preproc, device, encoder, decoder):
    super().__init__()
    self.preproc = preproc
    self.encoder = registry.construct(
        'encoder', encoder, device=device, preproc=preproc.enc_preproc)
    self.decoder = registry.construct(
        'decoder', decoder, device=device, preproc=preproc.dec_preproc)
    self.decoder.visualize_flag = False
    # Default to the unbatched path when the encoder does not declare `batched`;
    # a bare getattr would raise AttributeError instead.
    if getattr(self.encoder, 'batched', False):
        self.compute_loss = self._compute_loss_enc_batched
    else:
        self.compute_loss = self._compute_loss_unbatched
def __init__(self, grammar, save_path, censor_pointers):
    self.save_path = save_path
    self.censor_pointers = censor_pointers
    self.grammar = registry.construct('grammar', grammar)
    self.ast_wrapper = self.grammar.ast_wrapper
    self.items = collections.defaultdict(list)
def compute_metrics(config_path, config_args, section, inferred_path, logdir=None):
    if config_args:
        config = json.loads(
            _jsonnet.evaluate_file(config_path, tla_codes={'args': config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(config_path))
    if 'model_name' in config and logdir:
        logdir = os.path.join(logdir, config['model_name'])
    if logdir:
        inferred_path = inferred_path.replace('__LOGDIR__', logdir)

    inferred = open(inferred_path)
    data = registry.construct('dataset', config['data'][section])
    metrics = data.Metrics(data)
    inferred_lines = list(inferred)
    if len(inferred_lines) < len(data):
        raise Exception('Not enough inferred: {} vs {}'.format(
            len(inferred_lines), len(data)))

    for line in inferred_lines:
        infer_results = json.loads(line)
        if infer_results['beams']:
            inferred_code = infer_results['beams'][0]['inferred_code']
        else:
            inferred_code = None
        if 'index' in infer_results:
            metrics.add(data[infer_results['index']], inferred_code)
        else:
            metrics.add(None, inferred_code,
                        obsolete_gold_code=infer_results['gold_code'])
    return logdir, metrics.finalize()
def __init__(self, logger, config):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    self.logger = logger
    self.train_config = registry.instantiate(TrainConfig, config['train'])
    self.data_random = random_state.RandomContext(self.train_config.data_seed)
    self.model_random = random_state.RandomContext(self.train_config.model_seed)
    self.init_random = random_state.RandomContext(self.train_config.init_seed)

    with self.init_random:
        # 0. Construct preprocessors
        self.model_preproc = registry.instantiate(
            registry.lookup('model', config['model']).Preproc,
            config['model'], unused_keys=('name',))
        self.model_preproc.load()

        # 1. Construct model
        self.model = registry.construct(
            'model', config['model'],
            unused_keys=('encoder_preproc', 'decoder_preproc'),
            preproc=self.model_preproc, device=device)
        self.model.to(device)
def infer(self, model, output_path, args):
    # 3. Get training data somewhere
    output = open(output_path, 'w')
    orig_data = registry.construct('dataset', self.config['data'][args.section])
    sliced_orig_data = maybe_slice(orig_data, args.start_offset, args.limit)
    preproc_data = self.model_preproc.dataset(args.section)
    sliced_preproc_data = maybe_slice(preproc_data, args.start_offset, args.limit)

    with torch.no_grad():
        if args.mode == 'infer':
            assert len(orig_data) == len(preproc_data)
            self._inner_infer(model, args.beam_size, args.output_history,
                              sliced_orig_data, sliced_preproc_data, output,
                              args.nproc)
        elif args.mode == 'debug':
            self._debug(model, sliced_orig_data, output)
        elif args.mode == 'visualize_attention':
            model.visualize_flag = True
            model.decoder.visualize_flag = True
            self._visualize_attention(model, args.beam_size, args.output_history,
                                      sliced_orig_data, args.res1, args.res2,
                                      args.res3, output)
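# `maybe_slice` is called above but not defined in this section. A minimal
# sketch of what it plausibly does, assuming `limit` counts items taken after
# skipping `start_offset` (hypothetical reconstruction, not the repo's actual code):
import itertools

def maybe_slice(data, start_offset, limit):
    if start_offset or limit:
        start = start_offset or 0
        stop = start + limit if limit else None
        return itertools.islice(data, start, stop)
    return data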
def __init__(self, grammar, save_path, min_freq=3, max_count=5000,
             use_seq_elem_rules=False):
    self.grammar = registry.construct('grammar', grammar)
    self.ast_wrapper = self.grammar.ast_wrapper
    self.vocab_path = os.path.join(save_path, 'dec_vocab.json')
    self.observed_productions_path = os.path.join(save_path, 'observed_productions.json')
    self.grammar_rules_path = os.path.join(save_path, 'grammar_rules.json')
    self.data_dir = os.path.join(save_path, 'dec')
    self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
    self.use_seq_elem_rules = use_seq_elem_rules
    self.items = collections.defaultdict(list)
    self.sum_type_constructors = collections.defaultdict(set)
    self.field_presence_infos = collections.defaultdict(set)
    self.seq_lengths = collections.defaultdict(set)
    self.primitive_types = set()
    self.vocab = None
    self.all_rules = None
    self.rules_mask = None
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    parser.add_argument('--section', required=True)
    parser.add_argument('--inferred', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    os.makedirs(args.output, exist_ok=True)
    gold = open(os.path.join(args.output, 'gold.txt'), 'w')
    predicted = open(os.path.join(args.output, 'predicted.txt'), 'w')

    inferred = open(args.inferred)
    data = registry.construct('dataset', config['data'][args.section])
    for line in inferred:
        infer_results = json.loads(line)
        if infer_results['beams']:
            inferred_code = infer_results['beams'][0]['inferred_code']
        else:
            # Dummy query so the evaluator still receives one prediction per item
            inferred_code = 'SELECT a FROM b'
        item = data[infer_results['index']]
        gold.write('{}\t{}\n'.format(
            item.orig['query'].replace('\t', ' '), item.schema.db_id))
        predicted.write('{}\n'.format(inferred_code))
def load_model(self, logdir, step):
    '''Load a model (identified by the config used for construction) and return it.'''
    # 1. Construct model
    model = registry.construct('model', self.config['model'],
                               preproc=self.model_preproc, device=self.device)
    model.to(self.device)
    model.eval()
    model.visualize_flag = False
    optimizer = registry.construct('optimizer', self.config['optimizer'],
                                   params=model.parameters())

    # 2. Restore its parameters
    saver = saver_mod.Saver(model, optimizer)
    last_step = saver.restore(logdir, step=step, map_location=self.device)
    if not last_step:
        raise Exception('Attempting to infer on untrained model')
    return model
def preprocess(self):
    self.model_preproc.clear_items()
    for section in self.config['data']:
        data = registry.construct('dataset', self.config['data'][section])
        for item in tqdm.tqdm(data, desc=section, dynamic_ncols=True):
            to_add, validation_info = self.model_preproc.validate_item(item, section)
            if to_add:
                self.model_preproc.add_item(item, section, validation_info)
    self.model_preproc.save()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    train_data = registry.construct('dataset', config['data']['train'])
    grammar = registry.construct('grammar', config['model']['decoder_preproc']['grammar'])
    base_grammar = registry.construct(
        'grammar', config['model']['decoder_preproc']['grammar']['base_grammar'])
    for i, item in enumerate(tqdm.tqdm(train_data, dynamic_ncols=True)):
        parsed = grammar.parse(item.code, 'train')
        orig_parsed = base_grammar.parse(item.orig['orig'], 'train')
        # Reuse the parse above instead of parsing item.orig['orig'] a second time
        canonicalized_orig_code = base_grammar.unparse(orig_parsed, item)
        unparsed = grammar.unparse(parsed, item)
        if canonicalized_orig_code != unparsed:
            print('Original tree:')
            pprint.pprint(orig_parsed)
            print('Rewritten tree:')
            pprint.pprint(parsed)
            print('Reconstructed tree:')
            pprint.pprint(grammar._expand_templates(parsed))
            print('Original code:')
            print(canonicalized_orig_code)
            print('Reconstructed code:')
            print(unparsed)
            import IPython
            IPython.embed()
            break
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    os.makedirs(args.output, exist_ok=True)
    gold = open(os.path.join(args.output, 'gold.txt'), 'w')
    predicted = open(os.path.join(args.output, 'predicted.txt'), 'w')

    train_data = registry.construct('dataset', config['data']['train'])
    grammar = registry.construct('grammar', config['model']['decoder_preproc']['grammar'])
    evaluator = evaluation.Evaluator(
        'data/spider-20190205/database',
        evaluation.build_foreign_key_map_from_json('data/spider-20190205/tables.json'),
        'match')
    for i, item in enumerate(tqdm.tqdm(train_data, dynamic_ncols=True)):
        parsed = grammar.parse(item.code, 'train')
        sql = grammar.unparse(parsed, item)
        evaluator.evaluate_one(
            item.schema.db_id, item.orig['query'].replace('\t', ' '), sql)
        gold.write('{}\t{}\n'.format(
            item.orig['query'].replace('\t', ' '), item.schema.db_id))
        predicted.write('{}\n'.format(sql))
def load_db():
    db_id = "singer"
    # `dump_db_json_schema`, `load_tables_from_schema_dict`, and `model` are
    # assumed to be defined at module level.
    my_schema = dump_db_json_schema(
        "data/sqlite_files/{db_id}/{db_id}.sqlite".format(db_id=db_id), db_id)
    schema, eval_foreign_key_maps = load_tables_from_schema_dict(my_schema)
    dataset = registry.construct(
        'dataset_infer', {
            "name": "spider",
            "schemas": schema,
            "eval_foreign_key_maps": eval_foreign_key_maps,
            "db_path": "data/sqlite_files/"
        })
    for _, db_schema in dataset.schemas.items():
        model.preproc.enc_preproc._preprocess_schema(db_schema)
    return dataset.schemas[db_id]
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc, config['model'])
    for section in config['data']:
        data = registry.construct('dataset', config['data'][section])
        for item in tqdm.tqdm(data, desc=section, dynamic_ncols=True):
            to_add, validation_info = model_preproc.validate_item(item, section)
            if to_add:
                model_preproc.add_item(item, section, validation_info)
    model_preproc.save()
def __init__(self, save_path, min_freq=3, max_count=5000,
             include_table_name_in_column=True, word_emb=None,
             count_tokens_in_word_emb_for_vocab=False):
    if word_emb is None:
        self.word_emb = None
    else:
        self.word_emb = registry.construct('word_emb', word_emb)
    self.data_dir = os.path.join(save_path, 'enc')
    self.include_table_name_in_column = include_table_name_in_column
    self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab
    # TODO: Write 'train', 'val', 'test' somewhere else
    self.texts = {'train': [], 'val': [], 'test': []}
    self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
    self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
    self.vocab = None
    self.counted_db_ids = set()
def compute_metrics(config_path, config_args, section, inferred_path,
                    logdir=None, evaluate_beams_individually=False):
    if config_args:
        config = json.loads(
            _jsonnet.evaluate_file(config_path, tla_codes={'args': config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(config_path))
    if 'model_name' in config and logdir:
        logdir = os.path.join(logdir, config['model_name'])
    if logdir:
        inferred_path = inferred_path.replace('__LOGDIR__', logdir)

    inferred = open(inferred_path)
    data = registry.construct('dataset', config['data'][section])
    inferred_lines = list(inferred)
    if len(inferred_lines) < len(data):
        raise Exception('Not enough inferred: {} vs {}'.format(
            len(inferred_lines), len(data)))

    if evaluate_beams_individually:
        return logdir, evaluate_all_beams(data, inferred_lines)
    else:
        return logdir, evaluate_default(data, inferred_lines)
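# A minimal sketch of driving compute_metrics from a script; the config path,
# section, and logdir below are hypothetical placeholders:
if __name__ == '__main__':
    logdir, metrics = compute_metrics(
        'configs/spider.jsonnet',    # hypothetical config path
        None,                        # no extra config args
        'val',
        '__LOGDIR__/val.infer',      # '__LOGDIR__' is substituted when logdir is set
        logdir='logdir/run1')        # hypothetical run directory
    print(metrics)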
def main():
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    torch.set_num_threads(1)

    # `args` is assumed to be parsed at module level; this function never defines it.
    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))
    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])
    output_path = args.output.replace('__LOGDIR__', args.logdir)
    if os.path.exists(output_path):
        print('Output file {} already exists'.format(output_path))
        sys.exit(1)

    # 0. Construct preprocessors
    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc, config['model'])
    model_preproc.load()

    # 1. Construct model
    model = registry.construct('model', config['model'],
                               preproc=model_preproc, device=device)
    model.to(device)
    model.eval()
    model.visualize_flag = False
    optimizer = registry.construct('optimizer', config['optimizer'],
                                   params=model.parameters())

    # 2. Restore its parameters
    saver = saver_mod.Saver(model, optimizer)
    last_step = saver.restore(args.logdir, step=args.step, map_location=device)
    if not last_step:
        raise Exception('Attempting to infer on untrained model')

    # 3. Get training data somewhere
    output = open(output_path, 'w')
    with torch.no_grad():
        if args.mode == 'infer':
            orig_data = registry.construct('dataset', config['data'][args.section])
            preproc_data = model_preproc.dataset(args.section)
            if args.limit:
                # Slice the two datasets themselves; the original code sliced an
                # unrelated `data` variable here by mistake.
                sliced_orig_data = itertools.islice(orig_data, args.limit)
                sliced_preproc_data = itertools.islice(preproc_data, args.limit)
            else:
                sliced_orig_data = orig_data
                sliced_preproc_data = preproc_data
            assert len(orig_data) == len(preproc_data)
            infer(model, args.beam_size, args.output_history,
                  sliced_orig_data, sliced_preproc_data, output)
        elif args.mode == 'debug':
            data = model_preproc.dataset(args.section)
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            debug(model, sliced_data, output)
        elif args.mode == 'visualize_attention':
            model.visualize_flag = True
            model.decoder.visualize_flag = True
            data = registry.construct('dataset', config['data'][args.section])
            if args.limit:
                sliced_data = itertools.islice(data, args.limit)
            else:
                sliced_data = data
            visualize_attention(model, args.beam_size, args.output_history,
                                sliced_data, output)
def train(self, config, modeldir):
    # Slight difference here vs. unrefactored train: init_random starts over
    # here. Could be fixed, if it were important, by saving the random state at
    # the end of __init__.
    with self.init_random:
        # We may be able to move optimizer and lr_scheduler to __init__ instead.
        # Empirically it works fine, probably because saver.restore resets the
        # state by calling optimizer.load_state_dict. If there is no saved file
        # yet, that does not happen, so the optimizer might need to be reset
        # manually. For now, creating it from scratch each time is safer and
        # appears to be the same speed, but it also means the config has to be
        # passed into train, which is kind of ugly.
        optimizer = registry.construct('optimizer', config['optimizer'],
                                       params=self.model.parameters())
        lr_scheduler = registry.construct(
            'lr_scheduler', config.get('lr_scheduler', {'name': 'noop'}),
            optimizer=optimizer)

    # 2. Restore model parameters
    saver = saver_mod.Saver(
        self.model, optimizer, keep_every_n=self.train_config.keep_every_n)
    last_step = saver.restore(modeldir)

    # 3. Get training data somewhere
    with self.data_random:
        train_data = self.model_preproc.dataset('train')
        train_data_loader = self._yield_batches_from_epochs(
            torch.utils.data.DataLoader(
                train_data,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=lambda x: x))
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=self.train_config.eval_batch_size,
        collate_fn=lambda x: x)
    val_data = self.model_preproc.dataset('val')
    val_data_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=self.train_config.eval_batch_size,
        collate_fn=lambda x: x)

    # 4. Start training loop
    with self.data_random:
        for batch in train_data_loader:
            # Quit if too long
            if last_step >= self.train_config.max_steps:
                break

            # Evaluate model
            if last_step % self.train_config.eval_every_n == 0:
                if self.train_config.eval_on_train:
                    self._eval_model(self.logger, self.model, last_step,
                                     train_eval_data_loader, 'train',
                                     num_eval_items=self.train_config.num_eval_items)
                if self.train_config.eval_on_val:
                    self._eval_model(self.logger, self.model, last_step,
                                     val_data_loader, 'val',
                                     num_eval_items=self.train_config.num_eval_items)

            # Compute and apply gradient
            with self.model_random:
                optimizer.zero_grad()
                loss = self.model.compute_loss(batch)
                loss.backward()
                lr_scheduler.update_lr(last_step)
                optimizer.step()

            # Report metrics
            if last_step % self.train_config.report_every_n == 0:
                self.logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

            last_step += 1
            # Run saver
            if last_step % self.train_config.save_every_n == 0:
                saver.save(modeldir, last_step)
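# `self._yield_batches_from_epochs` (and the free function
# `yield_batches_from_epochs` used further below) is not shown in this section.
# A minimal sketch, assuming it simply re-iterates the DataLoader forever so
# the loop above can be bounded by max_steps rather than by epochs
# (hypothetical reconstruction):
def yield_batches_from_epochs(loader):
    while True:
        for batch in loader:
            yield batch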
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    # 0. Construct preprocessors
    model_preproc = registry.instantiate(
        registry.lookup('model', config['model']).Preproc,
        config['model'], unused_keys=('name',))
    model_preproc.load()

    # 1. Construct model
    model = registry.construct('model', config['model'],
                               unused_keys=('encoder_preproc', 'decoder_preproc'),
                               preproc=model_preproc, device=device)
    model.to(device)
    model.eval()

    # 3. Get training data somewhere
    train_data = model_preproc.dataset('train')
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=10, collate_fn=lambda x: x)
    batch = next(iter(train_eval_data_loader))
    descs = [x for x, y in batch]

    q0, qb = test_enc_equal([descs[0]['question']],
                            [[desc['question']] for desc in descs],
                            model.encoder.question_encoder)
    c0, cb = test_enc_equal(descs[0]['columns'],
                            [desc['columns'] for desc in descs],
                            model.encoder.column_encoder)
    t0, tb = test_enc_equal(descs[0]['tables'],
                            [desc['tables'] for desc in descs],
                            model.encoder.table_encoder)

    q0_enc, c0_enc, t0_enc = model.encoder.encs_update.forward_unbatched(
        descs[0], q0[0], c0[0], c0[1], t0[0], t0[1])
    qb_enc, cb_enc, tb_enc = model.encoder.encs_update.forward(
        descs, qb[0], cb[0], cb[1], tb[0], tb[1])

    check_close(q0_enc.squeeze(1), qb_enc.select(0))
    check_close(c0_enc.squeeze(1), cb_enc.select(0))
    check_close(t0_enc.squeeze(1), tb_enc.select(0))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.config_args:
        config = json.loads(
            _jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))
    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])
    train_config = registry.instantiate(TrainConfig, config['train'])

    reopen_to_flush = config.get('log', {}).get('reopen_to_flush')
    logger = Logger(os.path.join(args.logdir, 'log.txt'), reopen_to_flush)
    with open(os.path.join(args.logdir, 'config-{}.json'.format(
            datetime.datetime.now().strftime('%Y%m%dT%H%M%S%Z'))), 'w') as f:
        json.dump(config, f, sort_keys=True, indent=4)
    logger.log('Logging to {}'.format(args.logdir))

    init_random = random_state.RandomContext(train_config.init_seed)
    data_random = random_state.RandomContext(train_config.data_seed)
    model_random = random_state.RandomContext(train_config.model_seed)

    with init_random:
        # 0. Construct preprocessors
        model_preproc = registry.instantiate(
            registry.lookup('model', config['model']).Preproc,
            config['model'], unused_keys=('name',))
        model_preproc.load()

        # 1. Construct model
        model = registry.construct('model', config['model'],
                                   unused_keys=('encoder_preproc', 'decoder_preproc'),
                                   preproc=model_preproc, device=device)
        model.to(device)

        optimizer = registry.construct('optimizer', config['optimizer'],
                                       params=model.parameters())
        lr_scheduler = registry.construct(
            'lr_scheduler', config.get('lr_scheduler', {'name': 'noop'}),
            optimizer=optimizer)

    # 2. Restore its parameters
    saver = saver_mod.Saver(model, optimizer, keep_every_n=train_config.keep_every_n)
    last_step = saver.restore(args.logdir)

    # 3. Get training data somewhere
    with data_random:
        train_data = model_preproc.dataset('train')
        train_data_loader = yield_batches_from_epochs(
            torch.utils.data.DataLoader(
                train_data,
                batch_size=train_config.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=lambda x: x))
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=train_config.eval_batch_size,
        collate_fn=lambda x: x)
    val_data = model_preproc.dataset('val')
    val_data_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=train_config.eval_batch_size,
        collate_fn=lambda x: x)

    # 4. Start training loop
    with data_random:
        for batch in train_data_loader:
            # Quit if too long
            if last_step >= train_config.max_steps:
                break

            # Evaluate model
            if last_step % train_config.eval_every_n == 0:
                if train_config.eval_on_train:
                    eval_model(logger, model, last_step, train_eval_data_loader,
                               'train', num_eval_items=train_config.num_eval_items)
                if train_config.eval_on_val:
                    eval_model(logger, model, last_step, val_data_loader,
                               'val', num_eval_items=train_config.num_eval_items)

            # Compute and apply gradient
            with model_random:
                optimizer.zero_grad()
                loss = model.compute_loss(batch)
                loss.backward()
                lr_scheduler.update_lr(last_step)
                optimizer.step()

            # Report metrics
            if last_step % train_config.report_every_n == 0:
                logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

            last_step += 1
            # Run saver
            if last_step % train_config.save_every_n == 0:
                saver.save(args.logdir, last_step)
def train(self, config, modeldir):
    # Slight difference here vs. unrefactored train: init_random starts over
    # here. Could be fixed, if it were important, by saving the random state at
    # the end of __init__.
    with self.init_random:
        # We may be able to move optimizer and lr_scheduler to __init__ instead.
        # Empirically it works fine, probably because saver.restore resets the
        # state by calling optimizer.load_state_dict. If there is no saved file
        # yet, that does not happen, so the optimizer might need to be reset
        # manually. For now, creating it from scratch each time is safer and
        # appears to be the same speed, but it also means the config has to be
        # passed into train, which is kind of ugly.
        # TODO: not nice
        if config['optimizer'].get('name', None) == 'bertAdamw':
            bert_params = list(self.model.encoder.bert_model.parameters())
            assert len(bert_params) > 0
            non_bert_params = []
            for name, _param in self.model.named_parameters():
                if 'bert' not in name:
                    non_bert_params.append(_param)
            assert len(non_bert_params) + len(bert_params) == len(list(self.model.parameters()))

            optimizer = registry.construct('optimizer', config['optimizer'],
                                           non_bert_params=non_bert_params,
                                           bert_params=bert_params)
            lr_scheduler = registry.construct(
                'lr_scheduler', config.get('lr_scheduler', {'name': 'noop'}),
                param_groups=[optimizer.non_bert_param_group,
                              optimizer.bert_param_group])
        else:
            optimizer = registry.construct('optimizer', config['optimizer'],
                                           params=self.model.parameters())
            lr_scheduler = registry.construct(
                'lr_scheduler', config.get('lr_scheduler', {'name': 'noop'}),
                param_groups=optimizer.param_groups)

    # 2. Restore model parameters
    saver = saver_mod.Saver(
        {'model': self.model, 'optimizer': optimizer},
        keep_every_n=self.train_config.keep_every_n)
    last_step = saver.restore(modeldir, map_location=self.device)

    if 'pretrain' in config and last_step == 0:
        pretrain_config = config['pretrain']
        _path = pretrain_config['pretrained_path']
        _step = pretrain_config['checkpoint_step']
        pretrain_step = saver.restore(_path, step=_step,
                                      map_location=self.device,
                                      item_keys=['model'])
        saver.save(modeldir, pretrain_step)  # for evaluating pretrained models
        last_step = pretrain_step

    # 3. Get training data somewhere
    with self.data_random:
        train_data = self.model_preproc.dataset('train')
        train_data_loader = self._yield_batches_from_epochs(
            torch.utils.data.DataLoader(
                train_data,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=lambda x: x))
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=self.train_config.eval_batch_size,
        collate_fn=lambda x: x)
    val_data = self.model_preproc.dataset('val')
    val_data_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=self.train_config.eval_batch_size,
        collate_fn=lambda x: x)
    # 4. Start training loop
    with self.data_random:
        for batch in train_data_loader:
            # Quit if too long
            if last_step >= self.train_config.max_steps:
                break

            # Evaluate model
            if last_step % self.train_config.eval_every_n == 0:
                if self.train_config.eval_on_train:
                    self._eval_model(
                        self.logger, self.model, last_step,
                        train_eval_data_loader, 'train',
                        num_eval_items=self.train_config.num_eval_items)
                if self.train_config.eval_on_val:
                    self._eval_model(
                        self.logger, self.model, last_step,
                        val_data_loader, 'val',
                        num_eval_items=self.train_config.num_eval_items)

            # Compute and apply gradient, accumulating over several batches
            with self.model_random:
                for _i in range(self.train_config.num_batch_accumulated):
                    if _i > 0:
                        batch = next(train_data_loader)
                    loss = self.model.compute_loss(batch)
                    norm_loss = loss / self.train_config.num_batch_accumulated
                    norm_loss.backward()

                if self.train_config.clip_grad:
                    # Assumes the bertAdamw branch above, which defines bert_param_group
                    torch.nn.utils.clip_grad_norm_(
                        optimizer.bert_param_group['params'],
                        self.train_config.clip_grad)
                optimizer.step()
                lr_scheduler.update_lr(last_step)
                optimizer.zero_grad()

            # Report metrics
            if last_step % self.train_config.report_every_n == 0:
                self.logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

            last_step += 1
            # Run saver
            if last_step % self.train_config.save_every_n == 0:
                saver.save(modeldir, last_step)

    # Save final model
    saver.save(modeldir, last_step)
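# Note on the accumulation loop above: each optimizer.step() consumes
# num_batch_accumulated batches, and the loss is divided by
# num_batch_accumulated before backward(), so the summed gradients match those
# of one large batch. The effective batch size is therefore
# batch_size * num_batch_accumulated (e.g. 8 * 4 = 32).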
def __init__(self, base_grammar, template_file, root_type=None,
             all_sections_rewritten=False):
    self.base_grammar = registry.construct('grammar', base_grammar)
    self.templates = json.load(open(template_file))
    self.all_sections_rewritten = all_sections_rewritten
    self.pointers = self.base_grammar.pointers
    self.ast_wrapper = copy.deepcopy(self.base_grammar.ast_wrapper)
    self.base_ast_wrapper = self.base_grammar.ast_wrapper

    # TODO: Override root_type more intelligently
    self.root_type = self.base_grammar.root_type
    if base_grammar['name'] == 'python':
        self.root_type = 'mod'

    singular_types_with_single_seq_field = set(
        name for name, type_info in self.ast_wrapper.singular_types.items()
        if len(type_info.fields) == 1 and type_info.fields[0].seq)
    seq_fields = {
        '{}-{}'.format(name, field.name): SeqField(name, field)
        for name, type_info in self.ast_wrapper.singular_types.items()
        for field in type_info.fields
        if field.seq
    }

    templates_by_head_type = collections.defaultdict(list)
    for template in self.templates:
        head_type = template['idiom'][0]
        # head_type can be one of the following:
        # 1. name of a constructor/product with a single seq field
        # 2. name of any other constructor/product
        # 3. name of a seq field (e.g. 'Dict-keys'), when the containing
        #    constructor/product contains more than one field (not yet implemented)
        # For 1 and 3, the template should be treated as a 'seq fragment' which
        # can occur in any seq field of the corresponding sum/product type.
        # However, the NL2Code model has no such notion currently.
        if head_type in singular_types_with_single_seq_field:
            # field.type could be a sum type or a product type, but not a constructor
            field = self.ast_wrapper.singular_types[head_type].fields[0]
            templates_by_head_type[field.type].append(
                (template, SeqField(head_type, field)))
            templates_by_head_type[head_type].append((template, None))
        elif head_type in seq_fields:
            seq_field = seq_fields[head_type]
            templates_by_head_type[seq_field.field.type].append((template, seq_field))
        else:
            templates_by_head_type[head_type].append((template, None))

    types_to_replace = {}
    for head_type, templates in templates_by_head_type.items():
        constructors, seq_fragment_constructors = [], []
        for template, seq_field in templates:
            if seq_field:
                if head_type in self.ast_wrapper.product_types:
                    seq_type = '{}_plus_templates'.format(head_type)
                else:
                    seq_type = head_type
                seq_fragment_constructors.append(
                    self._template_to_constructor(
                        template, '_{}_seq'.format(seq_type), seq_field))
            else:
                constructors.append(
                    self._template_to_constructor(template, '', seq_field))

        # head_type can be:
        # - a constructor (member of a sum type)
        if head_type in self.ast_wrapper.constructors:
            assert constructors
            assert not seq_fragment_constructors
            self.ast_wrapper.add_constructors_to_sum_type(
                self.ast_wrapper.constructor_to_sum_type[head_type],
                constructors)
        # - a sum type
        elif head_type in self.ast_wrapper.sum_types:
            assert not constructors
            assert seq_fragment_constructors
            self.ast_wrapper.add_seq_fragment_type(head_type, seq_fragment_constructors)
        # - a product type
        elif head_type in self.ast_wrapper.product_types:
            # Replace the Product with a Constructor
            orig_prod_type = self.ast_wrapper.product_types[head_type]
            new_constructor_for_prod_type = asdl.Constructor(
                name=head_type, fields=orig_prod_type.fields)
            # Remove the Product from ast_wrapper
            self.ast_wrapper.remove_product_type(head_type)
            # Define a new sum type containing the original product type and
            # the templates as constructors
            name = '{}_plus_templates'.format(head_type)
            self.ast_wrapper.add_sum_type(
                name,
                asdl.Sum(types=constructors + [new_constructor_for_prod_type]))
            # Add seq fragment constructors
            self.ast_wrapper.add_seq_fragment_type(name, seq_fragment_constructors)
            # Replace every occurrence of the product type in the grammar
            types_to_replace[head_type] = name
        # - a built-in type
        elif head_type in self.ast_wrapper.primitive_types:
            # Would require defining a new sum type containing the original
            # built-in type and the templates as constructors, then replacing
            # every occurrence of the built-in type in the grammar.
            raise NotImplementedError(
                'built-in type as head type of idiom unsupported: {}'.format(head_type))
        else:
            raise NotImplementedError(
                'Unable to handle head type of idiom: {}'.format(head_type))

    # Replace occurrences of product types which have been used as idiom head types
    for constructor_or_product in self.ast_wrapper.singular_types.values():
        for field in constructor_or_product.fields:
            if field.type in types_to_replace:
                field.type = types_to_replace[field.type]

    self.templates_containing_placeholders = {}
    for name, constructor in self.ast_wrapper.singular_types.items():
        if not hasattr(constructor, 'template'):
            continue
        hole_values = {}
        for field in constructor.fields:
            hole_id = self.get_hole_id(field.name)
            placeholder = ast_util.HoleValuePlaceholder(
                id=hole_id, is_seq=field.seq, is_opt=field.opt)
            if field.seq:
                hole_values[hole_id] = [placeholder]
            else:
                hole_values[hole_id] = placeholder
        self.templates_containing_placeholders[name] = constructor.template(hole_values)

    if root_type is not None:
        if isinstance(root_type, (list, tuple)):
            for choice in root_type:
                if (choice in self.ast_wrapper.singular_types
                        or choice in self.ast_wrapper.sum_types):
                    self.root_type = choice
                    break
        else:
            self.root_type = root_type