def main(opt, device_id):
    """Set up and run multi-task training, one model per task.

    The number of tasks is the number of comma-separated entries in
    ``opt.data``. For every task this loads the fields, builds a model
    (or reuses the first model when ``opt.mtl_fully_share`` is set), and
    builds either one optimizer per task or a single optimizer over the
    de-duplicated parameters of all models
    (``opt.mtl_shared_optimizer``). Training is then delegated to the
    trainer returned by ``build_trainer``.

    Args:
        opt: Parsed training options. This function mutates it: it sets
            ``num_tasks`` and ``shared_params`` and, when
            ``opt.epochs > -1``, overrides ``train_steps``,
            ``valid_steps`` and ``save_checkpoint_steps``.
        device_id: Device identifier, forwarded to
            ``training_opt_postprocessing`` and ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Gather information related to the training script and commit version.
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(os.path.dirname(script_path))
    logger.info('Train script dir: %s' % script_dir)
    # Decode the script output instead of slicing the repr of the bytes
    # object, which only worked for output ending in exactly one newline.
    git_commit = subprocess.check_output(
        ['bash', os.path.join(script_dir, 'cluster_scripts',
                              'git_version.sh')]
    ).decode('utf-8', errors='replace').strip()
    logger.info("Git Commit: %s" % git_commit)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        # TODO: load MTL model
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite them with opts from
        # the checkpoint. It's useful in order to re-train a model
        # after adding a new option (not set in checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    num_tasks = len(opt.data.split(','))
    opt.num_tasks = num_tasks

    # One (optional) warm-start checkpoint per task. The "X" in
    # opt.warm_model is replaced by the task id; when that file does not
    # exist, the path for task 0 is used instead.
    checkpoint_list = []
    if opt.warm_model:
        base_name = opt.warm_model
        for task_id in range(num_tasks):
            chkpt_path = base_name.replace("X", str(task_id))
            if not os.path.isfile(chkpt_path):
                chkpt_path = base_name.replace("X", str(0))
            logger.info('Loading a checkpoint from %s' % chkpt_path)
            # NOTE(review): this rebinds `checkpoint`, so the optimizer(s)
            # built below receive the last warm-start checkpoint rather
            # than the opt.train_from one -- confirm this is intended.
            checkpoint = torch.load(
                chkpt_path, map_location=lambda storage, loc: storage)
            checkpoint_list.append(checkpoint)
    else:
        checkpoint_list = [None] * num_tasks

    fields_list = []
    data_type = None
    for task_id in range(num_tasks):
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(
            lazily_load_dataset("train", opt, task_id=task_id))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        if opt.mtl_shared_vocab and task_id > 0:
            logger.info(' * vocabulary size. Same as the main task!')
            fields = fields_list[0]
        else:
            fields = load_fields(first_dataset, opt,
                                 checkpoint_list[task_id], task_id=task_id)

        # Report src/tgt features.
        src_features, tgt_features = _collect_report_features(fields)
        for j, feat in enumerate(src_features):
            logger.info(' * (Task %d) src feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        for j, feat in enumerate(tgt_features):
            # The original format string used "%)" here, which raises
            # ValueError (unsupported format character) at runtime.
            logger.info(' * (Task %d) tgt feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        fields_list.append(fields)

    if opt.epochs > -1:
        # Derive the step budget from the number of batches by iterating
        # the training data once.
        total_num_batch = 0
        for task_id in range(num_tasks):
            train_iter = build_dataset_iter(
                lazily_load_dataset("train", opt, task_id=task_id),
                fields_list[task_id], opt)
            # NOTE(review): this records the index of the last batch
            # (batch count - 1) and leaves num_batch unset/stale when the
            # iterator is empty -- kept as-is, confirm intent.
            for i, _ in enumerate(train_iter):
                num_batch = i
            total_num_batch += num_batch
            if opt.mtl_schedule < 10:
                # Only the first task is counted for these schedules.
                break
        num_batch = total_num_batch
        opt.train_steps = (num_batch * opt.epochs) + 1
        # Do the validation and save after each epoch.
        opt.valid_steps = num_batch
        opt.save_checkpoint_steps = 1

    # logger.info(opt_to_string(opt))
    logger.info(opt)

    # Build model(s).
    models_list = []
    for task_id in range(num_tasks):
        if opt.mtl_fully_share and task_id > 0:
            # Since we only have one model, every task points at it.
            models_list.append(models_list[0])
        else:
            main_model = models_list[0] if task_id > 0 else None
            model = build_model(model_opt, opt, fields_list[task_id],
                                checkpoint_list[task_id],
                                main_model=main_model, task_id=task_id)
            n_params, enc, dec = _tally_parameters(model)
            logger.info('(Task %d) encoder: %d' % (task_id, enc))
            logger.info('(Task %d) decoder: %d' % (task_id, dec))
            logger.info('* number of parameters: %d' % n_params)
            _check_save_model_path(opt)
            models_list.append(model)

    def combine_named_parameters(named_params_list):
        """Yield (name, param) pairs, emitting each parameter object once.

        Parameters are compared by identity, so a parameter shared by
        several models is yielded only the first time it is seen.
        """
        seen_ids = set()
        for model_named_params in named_params_list:
            for name, p in model_named_params:
                if id(p) not in seen_ids:
                    seen_ids.add(id(p))
                    yield name, p

    # Build optimizer(s).
    optims_list = []
    all_models_params = []
    for task_id in range(num_tasks):
        if not opt.mtl_shared_optimizer:
            optim = build_optim(models_list[task_id], opt, checkpoint)
            optims_list.append(optim)
        else:
            all_models_params.append(models_list[task_id].named_parameters())

    # Extract the names of parameters shared among the models of all
    # tasks: a name is recorded every time its parameter object has
    # already been seen (identity comparison).
    seen_param_ids = set()
    shared_params = []
    for task_id in range(num_tasks):
        for name, p in models_list[task_id].named_parameters():
            if id(p) in seen_param_ids:
                shared_params.append(name)
            else:
                seen_param_ids.add(id(p))
    opt.shared_params = shared_params

    if opt.mtl_shared_optimizer:
        optim = build_optim_mtl_params(
            combine_named_parameters(all_models_params), opt, checkpoint)
        optims_list.append(optim)

    # Build model saver.
    model_saver = build_mtl_model_saver(model_opt, opt, models_list,
                                        fields_list, optims_list)

    trainer = build_trainer(opt, device_id, models_list, fields_list,
                            optims_list, data_type, model_saver=model_saver)

    def train_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("train", opt, task_id=task_id),
            fields_list[task_id], opt)

    def valid_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("valid", opt, task_id=task_id),
            fields_list[task_id], opt)

    def meta_valid_iter_fct(task_id, is_log=False):
        return build_dataset_iter(
            lazily_load_dataset("meta_valid", opt, task_id=task_id,
                                is_log=is_log),
            fields_list[task_id], opt)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, meta_valid_iter_fct=meta_valid_iter_fct)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Set up and run training of a single model.

    Loads the optional checkpoint named by ``opt.train_from``, loads the
    fields, builds the model, optimizer, model saver and trainer, then
    calls ``trainer.train`` with factories for the training and
    validation iterators.

    Args:
        opt: Parsed training options.
        device_id: Device identifier, forwarded to
            ``training_opt_postprocessing`` and ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Without a checkpoint to resume from, the training options double as
    # the model options.
    checkpoint = None
    model_opt = opt
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Start from the default model options and overlay the ones stored
        # in the checkpoint, so that options added after the checkpoint
        # was written still get a value.
        defaults_parser = configargparse.ArgumentParser()
        opts.model_opts(defaults_parser)
        model_opt, _ = defaults_parser.parse_known_args([])
        vars(model_opt).update(vars(checkpoint['opt']))

    # The first training shard tells us the data_type
    # (all datasets share the same data_type).
    first_shard = next(lazily_load_dataset("train", opt))
    data_type = first_shard.data_type

    # Fields produced by the preprocessing phase.
    fields = load_fields(first_shard, opt, checkpoint)

    # Log the vocabulary size of every src feature, then every tgt feature.
    src_features, tgt_features = _collect_report_features(fields)
    feature_reports = (
        (' * src feature %d size = %d', src_features),
        (' * tgt feature %d size = %d', tgt_features),
    )
    for template, feature_names in feature_reports:
        for index, feature_name in enumerate(feature_names):
            logger.info(template % (index, len(fields[feature_name].vocab)))

    # Model and its parameter counts.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Optimizer, saver and trainer.
    optim = build_optim(model, opt, checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            data_type, model_saver=model_saver)

    def _dataset_iter(corpus_type, **iter_kwargs):
        # Fresh iterator over the lazily loaded `corpus_type` datasets.
        return build_dataset_iter(lazily_load_dataset(corpus_type, opt),
                                  fields, opt, **iter_kwargs)

    def train_iter_fct():
        return _dataset_iter("train")

    def valid_iter_fct():
        return _dataset_iter("valid", is_train=False)

    # Run the training loop.
    if len(opt.gpu_ranks) > 0:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Run base training and, optionally, comparable-data training.

    The base phase (skipped when ``opt.no_base`` is set) loads the
    optional checkpoint named by ``opt.train_from``, builds the model,
    optimizer, saver and trainer, and calls ``trainer.train``. When
    ``opt.comparable`` is set, a ``Comparable`` object then runs
    ``opt.comp_epochs`` rounds of extraction and training on
    ``opt.comparable_data``, validating (unless ``opt.no_valid``) and
    saving a checkpoint after every round.

    Args:
        opt: Parsed training options. ``no_base`` and ``no_valid`` are
            treated as boolean flags.
        device_id: Device identifier, forwarded to
            ``training_opt_postprocessing`` and ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite them with opts from
        # the checkpoint. It's useful in order to re-train a model
        # after adding a new option (not set in checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt),
                                  fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    if not opt.no_base:
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    # NOTE(review): the writer is closed before the comparable phase
    # below; if that phase reports through the same report manager this
    # close may need to move to the end -- confirm.
    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')

        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)

        # 2. Infer similarity threshold from training data
        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)

            # 4. Extract parallel data and train
            # if opt.match_articles:
            #     comparable_data = comp.match_articles(opt.match_articles)
            #     train_stats = comp.extract_and_train(comparable_data)
            # else:
            comp.extract_and_train(opt.comparable_data)

            # 5. Validate on validation set
            if not opt.no_valid:
                # Reuse valid_iter_fct so the validation data is iterated
                # with is_train=False, exactly as in the base phase; the
                # original built this iterator without that argument.
                comp.validate(valid_iter_fct())

            # 6. Drop a checkpoint if needed
            comp.trainer.model_saver._save(epoch)