def init_from_scratch(args, train_exs, dev_exs):
    """New model, new data, new dictionary."""
    # Build word dictionaries from the code and summary fields (train/dev splits)
    logger.info('-' * 100)
    logger.info('Build word dictionary')
    src_dict = util.build_word_and_char_dict(args,
                                             examples=train_exs + dev_exs,
                                             fields=['code'],
                                             dict_size=args.src_vocab_size,
                                             no_special_token=True)
    tgt_dict = util.build_word_and_char_dict(args,
                                             examples=train_exs + dev_exs,
                                             fields=['summary'],
                                             dict_size=args.tgt_vocab_size,
                                             no_special_token=False)
    logger.info('Num words in source = %d and target = %d' %
                (len(src_dict), len(tgt_dict)))

    # Initialize model
    model = Code2NaturalLanguage(config.get_model_args(args), src_dict, tgt_dict)

    # Load pretrained embeddings for words in dictionary
    if args.embedding_file:
        pass  # not implemented in this variant

    return model
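# --------------------------------------------------------------------------
# NOTE: the `args.embedding_file` branch above is a stub. Below is a minimal
# sketch of how pretrained vectors for in-dictionary words could be read from
# a GloVe/word2vec-style text file. The function name, the `word2id` mapping,
# and the file format are illustrative assumptions, not part of this repo.
def load_pretrained_embeddings_sketch(embedding_file, word2id, emb_dim):
    """Return a (len(word2id) x emb_dim) tensor filled with pretrained vectors."""
    weights = torch.zeros(len(word2id), emb_dim)
    with open(embedding_file, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word, vec = parts[0], parts[1:]
            if word in word2id and len(vec) == emb_dim:
                weights[word2id[word]] = torch.tensor([float(v) for v in vec])
    return weights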
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load and process data files')
    dev_exs = []
    for dev_src, dev_src_tag, dev_tgt, dataset_name in \
            zip(args.dev_src_files, args.dev_src_tag_files,
                args.dev_tgt_files, args.dataset_name):
        dev_files = dict()
        dev_files['src'] = dev_src
        dev_files['src_tag'] = dev_src_tag
        dev_files['tgt'] = dev_tgt
        exs = util.load_data(args,
                             dev_files,
                             max_examples=args.max_examples,
                             dataset_name=dataset_name,
                             test_split=True)
        dev_exs.extend(exs)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    if not os.path.isfile(args.model_file):
        raise IOError('No such file: %s' % args.model_file)
    model = Code2NaturalLanguage.load(args.model_file)

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Only the dev/test split is evaluated here, so a sequential sampler is enough.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    dev_dataset = data.CommentDataset(dev_exs, model)
    dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=args.test_batch_size,
                                             sampler=dev_sampler,
                                             num_workers=args.data_workers,
                                             collate_fn=vector.batchify,
                                             pin_memory=args.cuda,
                                             drop_last=args.parallel)

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    # logger.info('CONFIG:\n%s' %
    #             json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # DO TEST
    validate_official(args, dev_loader, model)
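# --------------------------------------------------------------------------
# Illustrative sketch only: `vector.batchify` above is the repository's own
# collate_fn and its exact fields are not shown here. A typical collate_fn for
# this kind of DataLoader pads variable-length token-id sequences into one
# batch tensor, roughly as below (the 'code_word_rep' field name and PAD
# index 0 are assumptions for illustration).
def batchify_sketch(batch):
    """Pad a list of per-example 1-D LongTensors into a single batch tensor."""
    max_len = max(ex['code_word_rep'].size(0) for ex in batch)
    code = torch.zeros(len(batch), max_len, dtype=torch.long)  # 0 assumed to be PAD
    for i, ex in enumerate(batch):
        seq = ex['code_word_rep']
        code[i, :seq.size(0)] = seq
    return {'code_word_rep': code, 'batch_size': len(batch)}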
def init_from_scratch(args, train_exs, dev_exs, logger):
    """New model, new data, new dictionary."""
    # Build word dictionaries from the training split only
    logger.print('-' * 100)
    logger.print('Build word dictionary')
    src_dict = util.build_word_and_char_dict(
        args,
        examples=train_exs,  # + dev_exs
        fields=['code'],
        dict_size=args.src_vocab_size,
        special_tokens="pad_unk",
        attrname="tokens" if not args.sum_over_subtokens else "subtokens")
    tgt_dict = util.build_word_and_char_dict(
        args,
        examples=train_exs,  # + dev_exs
        fields=['summary'],
        dict_size=args.tgt_vocab_size,
        special_tokens="pad_unk_bos_eos")
    if args.use_tree_relative_attn:
        rel_dict = util.build_word_and_char_dict(
            args,
            examples=train_exs,
            fields=["rel_matrix"],
            dict_size=None,
            special_tokens="unk")
    else:
        rel_dict = None
    if args.use_code_type:
        type_dict = util.build_word_and_char_dict(
            args,
            examples=train_exs,  # + dev_exs
            fields=['code'],
            dict_size=None,
            special_tokens="pad_unk",
            attrname="type")
    else:
        type_dict = None

    logger.print('Num words in source = %d and target = %d' %
                 (len(src_dict), len(tgt_dict)))
    if args.use_tree_relative_attn:
        logger.print("Num relations in relative matrix = %d" % len(rel_dict))

    # Initialize model
    model = Code2NaturalLanguage(config.get_model_args(args), src_dict, tgt_dict,
                                 rel_dict, type_dict)

    return model
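# --------------------------------------------------------------------------
# Minimal sketch of what the `special_tokens` layouts used above imply; the
# real util.build_word_and_char_dict may differ. It only illustrates why the
# source dict reserves PAD/UNK while the target dict also reserves BOS/EOS for
# sequence generation. All names below are illustrative assumptions.
from collections import Counter

SPECIAL_TOKEN_LAYOUTS = {
    'unk': ['<unk>'],
    'pad_unk': ['<pad>', '<unk>'],
    'pad_unk_bos_eos': ['<pad>', '<unk>', '<s>', '</s>'],
}


def build_vocab_sketch(tokens, special_tokens='pad_unk', dict_size=None):
    """Map the most frequent tokens to ids after the reserved special symbols."""
    word2id = {w: i for i, w in enumerate(SPECIAL_TOKEN_LAYOUTS[special_tokens])}
    for word, _ in Counter(tokens).most_common(dict_size):
        if word not in word2id:
            word2id[word] = len(word2id)
    return word2id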
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load and process data files')

    train_exs = []
    if not args.only_test:
        args.dataset_weights = dict()
        for train_src, train_src_tag, train_tgt, dataset_name in \
                zip(args.train_src_files, args.train_src_tag_files,
                    args.train_tgt_files, args.dataset_name):
            train_files = dict()
            train_files['src'] = train_src
            train_files['src_tag'] = train_src_tag
            train_files['tgt'] = train_tgt
            exs = util.load_data(args,
                                 train_files,
                                 max_examples=args.max_examples,
                                 dataset_name=dataset_name)
            lang_name = constants.DATA_LANG_MAP[dataset_name]
            args.dataset_weights[constants.LANG_ID_MAP[lang_name]] = len(exs)
            train_exs.extend(exs)

        logger.info('Num train examples = %d' % len(train_exs))
        args.num_train_examples = len(train_exs)
        for lang_id in args.dataset_weights.keys():
            weight = (1.0 * args.dataset_weights[lang_id]) / len(train_exs)
            args.dataset_weights[lang_id] = round(weight, 2)
        logger.info('Dataset weights = %s' % str(args.dataset_weights))

    dev_exs = []
    for dev_src, dev_src_tag, dev_tgt, dataset_name in \
            zip(args.dev_src_files, args.dev_src_tag_files,
                args.dev_tgt_files, args.dataset_name):
        dev_files = dict()
        dev_files['src'] = dev_src
        dev_files['src_tag'] = dev_src_tag
        dev_files['tgt'] = dev_tgt
        exs = util.load_data(args,
                             dev_files,
                             max_examples=args.max_examples,
                             dataset_name=dataset_name,
                             test_split=True)
        dev_exs.extend(exs)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 1
    if args.only_test:
        if args.pretrained:
            model = Code2NaturalLanguage.load(args.pretrained)
        else:
            if not os.path.isfile(args.model_file):
                raise IOError('No such file: %s' % args.model_file)
            model = Code2NaturalLanguage.load(args.model_file)
    else:
        if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
            # Just resume training, no modifications.
            logger.info('Found a checkpoint...')
            checkpoint_file = args.model_file + '.checkpoint'
            model, start_epoch = Code2NaturalLanguage.load_checkpoint(
                checkpoint_file, args.cuda)
        else:
            # Training starts fresh. But the model state is either pretrained or
            # newly (randomly) initialized.
            if args.pretrained:
                logger.info('Using pretrained model...')
                model = Code2NaturalLanguage.load(args.pretrained, args)
            else:
                logger.info('Training model from scratch...')
                model = init_from_scratch(args, train_exs, dev_exs)

            # Set up optimizer
            model.init_optimizer()
            # Log the parameter details
            logger.info(
                'Trainable #parameters [encoder-decoder] {} [total] {}'.format(
                    human_format(model.network.count_encoder_parameters() +
                                 model.network.count_decoder_parameters()),
                    human_format(model.network.count_parameters())))
            table = model.network.layer_wise_parameters()
            logger.info('Breakdown of the trainable parameters\n%s' % table)

    # Use the GPU?
    if args.cuda:
        model.cuda()

    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    if not args.only_test:
        train_dataset = data.CommentDataset(train_exs, model)
        if args.sort_by_len:
            train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                    args.batch_size,
                                                    shuffle=True)
        else:
            train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.data_workers,
            collate_fn=vector.batchify,
            pin_memory=args.cuda,
            drop_last=args.parallel)

    dev_dataset = data.CommentDataset(dev_exs, model)
    dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=args.test_batch_size,
                                             sampler=dev_sampler,
                                             num_workers=args.data_workers,
                                             collate_fn=vector.batchify,
                                             pin_memory=args.cuda,
                                             drop_last=args.parallel)

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # DO TEST
    if args.only_test:
        stats = {'timer': Timer(), 'epoch': 0, 'best_valid': 0,
                 'no_improvement': 0}
        validate_official(args, dev_loader, model, stats, mode='test')

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    else:
        logger.info('-' * 100)
        logger.info('Starting training...')
        stats = {'timer': Timer(), 'epoch': start_epoch, 'best_valid': 0,
                 'no_improvement': 0}

        if args.optimizer in ['sgd', 'adam'] and args.warmup_epochs >= start_epoch:
            logger.info("Using warmup learning rate for the first %d epochs, "
                        "ramping from 0 up to %s." %
                        (args.warmup_epochs, args.learning_rate))
            num_batches = len(train_loader.dataset) // args.batch_size
            warmup_factor = (args.learning_rate + 0.) / \
                (num_batches * args.warmup_epochs)
            stats['warmup_factor'] = warmup_factor

        for epoch in range(start_epoch, args.num_epochs + 1):
            stats['epoch'] = epoch
            if args.optimizer in ['sgd', 'adam'] and epoch > args.warmup_epochs:
                model.optimizer.param_groups[0]['lr'] = \
                    model.optimizer.param_groups[0]['lr'] * args.lr_decay

            train(args, train_loader, model, stats)
            result = validate_official(args, dev_loader, model, stats)

            # Save best valid
            if result[args.valid_metric] > stats['best_valid']:
                logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                            (args.valid_metric, result[args.valid_metric],
                             stats['epoch'], model.updates))
                model.save(args.model_file)
                stats['best_valid'] = result[args.valid_metric]
                stats['no_improvement'] = 0
            else:
                stats['no_improvement'] += 1
                if stats['no_improvement'] >= args.early_stop:
                    break
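# --------------------------------------------------------------------------
# Sketch only: the per-batch use of stats['warmup_factor'] happens inside
# train(), which is defined elsewhere in this repo. A common linear-warmup
# rule, assumed here rather than confirmed, ramps the learning rate in
# proportion to the number of optimizer updates made so far.
def warmup_lr_step_sketch(optimizer, num_updates, warmup_factor):
    """Apply one linear-warmup learning-rate step to all parameter groups."""
    lr = warmup_factor * num_updates  # linear ramp from 0 toward args.learning_rate
    for group in optimizer.param_groups:
        group['lr'] = lr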