def get_fields(vocab):
    """Return a fields mapping for ``vocab``, upgrading legacy vocab formats.

    If ``old_style_vocab(vocab)`` reports a legacy vocab, it is converted with
    ``load_old_vocab``. Otherwise ``vocab`` is returned unchanged.

    NOTE(review): ``opt`` is not a parameter. It is resolved from module scope
    when the legacy branch runs, so a global ``opt`` must exist then. Confirm
    with callers.
    """
    if not old_style_vocab(vocab):
        # Already in the new fields format; hand it back as-is.
        return vocab
    return load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
def main(opt):
    """Build a model from options (optionally a checkpoint) and stop in pdb.

    This is an interactive debugging entry point. After validating ``opt`` it
    loads the vocabulary (from ``opt.train_from`` if given, otherwise from
    ``opt.data + '.vocab.pt'``), builds the model with ``build_model`` and
    then hits a ``pdb`` breakpoint so the model can be inspected.

    Side effect: sets ``opt.segment_token_idx`` (``None`` unless
    ``opt.use_segments``).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load the checkpoint exactly once. It used to be re-loaded
    # unconditionally further down, which crashed whenever --train_from was
    # not given and threw away the fields converted from the vocab.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        # No checkpoint: build a fresh model from the current options.
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Index of '.' in the target vocab; presumably used as the segment
    # boundary token when segments are enabled. Read from the converted
    # fields so legacy vocabs work too.
    segment_token_idx = None
    if opt.use_segments:
        segment_token_idx = fields['tgt'].base_field.vocab.stoi['.']
    opt.segment_token_idx = segment_token_idx

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A single (non-multi) field is not iterable.
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    # Intentional breakpoint: this entry point exists to inspect ``model``.
    pdb.set_trace()
def from_opt(cls, opt, embeddings):
    """Alternate constructor.

    Loads the preprocessed vocabulary from ``opt.data + '.vocab.pt'``,
    converting legacy vocabs with ``load_old_vocab``. It then looks up the
    index of ``SLCT_label`` in the vocab of the last ``src`` sub-field.
    That index, the encoder hyper-parameters from ``opt`` and ``embeddings``
    are passed positionally to ``cls``.

    NOTE(review): no ``@classmethod`` decorator is visible at this
    definition; presumably one is applied where it lives. Confirm.
    """
    raw_vocab = torch.load(opt.data + '.vocab.pt')
    if not old_style_vocab(raw_vocab):
        fields = raw_vocab
    else:
        fields = load_old_vocab(raw_vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)

    # ``fields["src"][-1]`` is presumably a (name, field) pair; the trailing
    # ``[-1]`` picks the field object that owns the vocab.
    src_field = fields["src"][-1][-1]
    select_idx = src_field.vocab.stoi[SLCT_label]

    return cls(
        select_idx,
        opt.enc_layers,
        opt.enc_rnn_size,
        opt.heads,
        opt.transformer_ff,
        opt.dropout,
        embeddings,
        opt.max_relative_positions,
    )
def load(cls, path, args):
    """Load fields saved at ``path`` and construct ``cls`` from them.

    Legacy vocab files are upgraded with ``load_old_vocab`` using
    ``args.model_type`` and ``args.copy_attn``. For every src/tgt field (and
    feature sub-field) that uses a vocabulary, its size is printed to stdout.
    """
    loaded = torch.load(path)
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    fields = (
        load_old_vocab(loaded, args.model_type, dynamic_dict=args.copy_attn)
        if old_style_vocab(loaded)
        else loaded
    )

    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            members = iter(side_field)
        except TypeError:
            # Not iterable: a single field, reported under the side name.
            members = [(side, side_field)]
        for member_name, member in members:
            if not member.use_vocab:
                continue
            print(f'| [{member_name}] dictionary: {len(member.vocab)} types')

    return cls(fields)
def main(opt, device_id):
    """Train a model (optionally with a domain classifier) on one device.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    When resuming (``opt.train_from``), the fields stored in
    ``opt.data + '.vocab.pt'`` are kept, but their src/tgt vocabularies are
    replaced by the checkpoint's. When ``opt.pretrain_from`` is set, the
    model/generator (and optional ``dom_classifier``) weights from that file
    are loaded non-strictly on top of the freshly built model. A translator
    is built for the trainer only when ``opt.domain_cls_enc == False``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.'
                    % opt.train_from)
        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        # Keep the dataset's field objects but reuse the checkpoint's
        # vocabularies (presumably so token ids stay consistent with the
        # checkpoint weights).
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        # Upstream OpenNMT-py's model saver stores generator weights from
        # ``generator.state_dict()``, i.e. without the ``generator.`` prefix
        # used by the model-level state dict. Loading them into ``model``
        # with strict=False therefore matched nothing. Re-prefix unprefixed
        # keys so they actually load; already-prefixed keys are unchanged.
        generator_state = {
            (k if k.startswith('generator.') else 'generator.' + k): v
            for k, v in check['generator'].items()
        }
        model.load_state_dict(generator_state, strict=False)
        if 'dom_classifier' in check:
            # NOTE(review): the key layout of 'dom_classifier' is not visible
            # here. If it is saved unprefixed it is ignored by strict=False,
            # like the generator was. Confirm against the model saver.
            model.load_state_dict(check['dom_classifier'], strict=False)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    translator = None
    if opt.domain_cls_enc == False:
        translator = train_build_translator(opt, model, model_opt, fields,
                                            report_score=True)
    trainer = build_trainer(translator, opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def train(opt):
    """Train, resuming automatically from the latest checkpoint if any.

    Validates ``opt``, then scans ``opt.exp_dir`` (and, failing that, the
    directory of ``opt.save_model`` for files starting with its basename)
    for the ``*_<step>.pt`` file with the highest step. If one is found,
    ``opt.train_from`` is pointed at it and ``opt.reset_optim`` is set to
    ``'none'``. Fields are then prepared and the training iterator is built.
    With ``opt.world_size > 1`` one ``run`` process is spawned per GPU, fed
    by a ``batch_producer`` process. Otherwise ``single_main`` runs on GPU 0
    or on CPU (-1).
    """

    def _ckpt_step(filename):
        # Step encoded as ``<prefix>_<step>.pt``. Returns None for names that
        # do not follow the pattern (e.g. ``data.vocab.pt``) instead of
        # letting int() raise ValueError and abort training.
        try:
            return int(filename[filename.rfind('_') + 1:filename.rfind('.pt')])
        except ValueError:
            return None

    def _latest_ckpt(root, prefix=''):
        # Path of the ``.pt`` file under ``root`` (recursively) whose name
        # starts with ``prefix`` and has the largest step > 0, else None.
        # Ties keep the first match in sorted filename order.
        latest_step, latest_path = 0, None
        for subdir, _dirs, filenames in os.walk(root):
            for filename in sorted(filenames):
                if not filename.endswith('.pt'):
                    continue
                if not filename.startswith(prefix):
                    continue
                step = _ckpt_step(filename)
                if step is not None and step > latest_step:
                    latest_path = os.path.join(subdir, filename)
                    latest_step = step
        return latest_path

    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    # and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    if not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    # Scan previous checkpoint to resume training
    latest_ckpt = _latest_ckpt(opt.exp_dir)
    # if not saved in the exp folder, check opt.save_model
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = opt.save_model[opt.save_model.rfind(os.path.sep) + 1:]
        latest_ckpt = _latest_ckpt(save_model_dir, model_prefix)

    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" % latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(vocab, opt.model_type,
                                   dynamic_dict=opt.copy_attn)
    else:
        # No vocab available (this includes encoder_type == 'pretrained').
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False, dtype=torch.long,
                                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt, multi=True)
        else:
            train_shards = ["train_" + train_id for train_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues.append(q)
            procs.append(mp.Process(target=run,
                                    args=(opt, device_id, error_queue, q, semaphore),
                                    daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def main(opt, device_id):
    """Train on several datasets listed in a JSON manifest.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    ``opt.data`` is a JSON file:
      * ``"vocab"`` is the path of the vocab file (used unless resuming);
      * keys starting with ``"valid"`` name the validation data; the part
        after ``"valid-"`` is split on ``"-"``;
      * every other key names a training dataset, split on ``"-"``.
    Each split key is paired with its dataset iterator and handed to
    ``trainer.train`` together with ``opt.smooth``.
    """
    import json

    configure_process(opt, device_id)
    init_logger(opt.log_file)

    with open(opt.data) as json_file:
        data = json.load(json_file)
    vocab_path = data["vocab"]

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(vocab_path)

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    train_iters = []
    # Stays None when the manifest has no "valid..." entry; previously this
    # raised UnboundLocalError at the trainer.train() call.
    valid_iter = None
    for key, value in data.items():
        if key == "vocab":
            continue
        if key.startswith("valid"):
            # NOTE(review): if several "valid..." keys exist, only the last
            # one is kept.
            valid_iter = (key.split("valid-")[1].split("-"),
                          build_dataset_iter(value, fields, opt, is_train=False))
        else:
            train_iters.append(
                (key.split("-"), build_dataset_iter(value, fields, opt)))

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.critic = critic()
    model.critic.to(model.device)
    # A second critic is only attached when the model has a second decoder.
    if model.decoder2 is not None:
        model.critic2 = critic()
        model.critic2.to(model.device)
    else:
        model.critic2 = None
    # NOTE(review): confirm critic3 should be cleared in both branches.
    model.critic3 = None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iters, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps,
                  smooth=opt.smooth)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Train a model on a single device (``device_id == -1`` means CPU).

    NOTE: the caller must already have validated and updated ``opt``.

    Restores model options and vocabulary from ``opt.train_from`` when
    resuming; otherwise it uses ``opt`` and ``opt.data + '.vocab.pt'``.
    It then builds the model, optimizer, model saver and trainer, and runs
    training on the "train"/"valid" dataset iterators.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    if not opt.train_from:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        # Resuming: model options and vocab come from the checkpoint.
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Older files store a bare vocab instead of fields; upgrade those.
    fields = (load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
              if old_style_vocab(vocab) else vocab)

    # Log the vocabulary size of each src/tgt field and feature sub-field.
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for sub_name, sub_field in named_fields:
            if sub_field.use_vocab:
                logger.info(' * %s vocab size = %d'
                            % (sub_name, len(sub_field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    for message in ('encoder: %d' % enc,
                    'decoder: %d' % dec,
                    '* number of parameters: %d' % n_params):
        logger.info(message)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if opt.gpu_ranks:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    # single_pass overrides a positive step budget.
    ignore_steps = opt.single_pass and opt.train_steps > 0
    if ignore_steps:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
    train_steps = 0 if ignore_steps else opt.train_steps

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Train on one device, optionally with BERT knowledge distillation.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    When ``opt.local_rank != -1`` the process binds to that CUDA device,
    joins an NCCL process group and uses ``opt.local_rank`` as ``device_id``.
    Logging is disabled on ranks > 0. When ``opt.bert_kd`` is set, training
    batches come from a ``BertKdDataset`` via ``DataLoader`` and
    ``cycle_loader`` instead of ``build_dataset_iter``.
    """
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
    elif device_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda", device_id)
    if opt.local_rank > 0:
        logger.disabled = True

    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        # Use the converted fields (identical to ``vocab`` for new-style
        # files) so legacy vocab files do not break here.
        src_vocab = fields['src'].fields[0][1].vocab.stoi
        tgt_vocab = fields['tgt'].fields[0][1].vocab.stoi

        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db, opt.bert_dump,
                                      src_vocab, tgt_vocab,
                                      max_len=150, k=opt.kd_topk)
        BUCKET_SIZE = 8192
        # The distributed sampler branch was disabled in the original code
        # (guarded by ``if True or ...`` with ``assert False``; the author
        # noted it "seems like it's handled in training loop"), so the plain
        # token-bucket sampler is always used.
        train_sampler = TokenBucketSampler(train_dataset.keys, BUCKET_SIZE,
                                           opt.batch_size, batch_multiple=1)
        train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    # The writer may be absent (e.g. on non-logging ranks), so check first.
    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Train a topic-aware model on a single device.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    Besides the usual model/optimizer/trainer setup, this builds a
    ``word_to_lemma`` tensor mapping every source-vocab index to a lemma
    (``word_topic``) index, using the alignment file
    ``model_opt.lemma_align``. It also loads the topic matrix from
    ``opt.topic_matrix``. Both are passed to ``trainer.train``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    # Use a context manager so the alignment file is closed (it was leaked).
    with open(model_opt.lemma_align, 'rb') as align_file:
        lemma_aligns = align_file.readlines()
    src_vocab = vocab['src'].base_field.vocab
    src_stoi = src_vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi

    # Source-vocab index -> lemma-vocab index. Each alignment line holds a
    # word and its lemma as whitespace-separated UTF-8 bytes.
    w2l = {}
    for line in lemma_aligns:
        pair = line.strip().split()
        if not pair:
            # Skip blank lines (e.g. a trailing newline) instead of raising
            # IndexError.
            continue
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    unk_lemma = lemma_stoi['unk']
    w2l[src_stoi['unk']] = unk_lemma

    # Unaligned source words fall back to the lemma index of 'unk'. The
    # original looked up ``w2l[lemma_stoi['unk']]``, i.e. used a lemma-vocab
    # index as a source-vocab key, which could KeyError or pick an unrelated
    # word's lemma.
    word_to_lemma = torch.tensor(
        [w2l.get(index, unk_lemma) for index in range(len(src_vocab.itos))])

    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(topic_matrix, word_to_lemma, train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def train_single(self, output_model_dir: Path, opt, device_id, batch_queue=None, semaphore=None):
    """Train a multi-source OpenNMT model and dump training metrics.

    The model, saver, trainer and dataset iterators are all built through
    the roosterize ``MultiSource*`` helpers, keyed by
    ``self.config.get_src_types()``.

    Args:
        output_model_dir: directory that receives ``train-metrics.json`` and
            ``best-step.json``.
        opt: OpenNMT training options.
        device_id: device passed to ``configure_process`` and the trainer.
        batch_queue: if given, training batches are read from this queue
            instead of from dataset shards.
        semaphore: required together with ``batch_queue``; released once per
            batch taken from the queue.

    Raises:
        AssertionError: if ``opt.accum_count`` and ``opt.accum_steps`` differ
            in length, or ``batch_queue`` is given without ``semaphore``.
    """
    # Imports are local to this method (presumably to keep the OpenNMT
    # dependency optional at module import time — confirm).
    from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
    from roosterize.ml.onmt.MultiSourceModelBuilder import MultiSourceModelBuilder
    from roosterize.ml.onmt.MultiSourceModelSaver import MultiSourceModelSaver
    from roosterize.ml.onmt.MultiSourceTrainer import MultiSourceTrainer
    from onmt.inputters.inputter import load_old_vocab, old_style_vocab
    from onmt.train_single import configure_process, _tally_parameters, _check_save_model_path
    from onmt.utils.optimizers import Optimizer
    from onmt.utils.parse import ArgumentParser

    configure_process(opt, device_id)

    assert len(opt.accum_count) == len(
        opt.accum_steps
    ), 'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        self.logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage)
        # Model options come from the checkpoint, not from ``opt``.
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        self.logger.info('Loading vocab from checkpoint at %s.'
                         % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')
    # end if

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # end if

    # Report src and tgt vocab sizes, including for features.
    # Source fields are stored under one key per source type: "src.<type>".
    data_keys = [
        f"src.{src_type}" for src_type in self.config.get_src_types()
    ] + ["tgt"]
    for side in data_keys:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A single (non-multi) field is not iterable.
            f_iter = [(side, f)]
        # end try
        for sn, sf in f_iter:
            if sf.use_vocab:
                self.logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
        # end for

    # Build model
    model = MultiSourceModelBuilder.build_model(
        self.config.get_src_types(), model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    self.logger.info('encoder: %d' % enc)
    self.logger.info('decoder: %d' % dec)
    self.logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = MultiSourceModelSaver.build_model_saver(
        self.config.get_src_types(), model_opt, opt, model, fields, optim)

    trainer = MultiSourceTrainer.build_trainer(self.config.get_src_types(),
                                               opt,
                                               device_id,
                                               model,
                                               fields,
                                               optim,
                                               model_saver=model_saver)

    if batch_queue is None:
        # Build the training iterator from dataset shards: one shard per
        # entry of opt.data_ids ("train_<id>"), or "train" when the single
        # id is None.
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                self.config.get_src_types(), train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = MultiSourceInputter.build_dataset_iter(
                self.config.get_src_types(), shard_base, fields, opt)
        # end if
    else:
        assert semaphore is not None, "Using batch_queue requires semaphore as well"

        def _train_iter():
            # Endless generator over batches pushed into batch_queue by a
            # producer; releasing the semaphore lets the producer enqueue
            # another batch.
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch
            # end while

        # end def

        train_iter = _train_iter()
    # end if

    valid_iter = MultiSourceInputter.build_dataset_iter(
        self.config.get_src_types(), "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        self.logger.info('Starting training on CPU, could be very slow')
    # end if
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        self.logger.warning(
            "Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    # end if

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    # Wall-clock span from the report manager's start time to now.
    time_begin = trainer.report_manager.start_time
    time_end = time.time()

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    # Dump train metrics
    train_history = trainer.report_manager.get_joint_history()
    train_metrics = {
        "time_begin": time_begin,
        "time_end": time_end,
        "time": time_end - time_begin,
        "train_history": train_history,
    }
    IOUtils.dump(output_model_dir / "train-metrics.json", train_metrics, IOUtils.Format.jsonNoSort)

    # Get the best step, depending on the lowest val_xent (cross entropy).
    # NOTE(review): min() raises ValueError on an empty history, and every
    # history entry must carry "val_xent" — confirm get_joint_history()
    # guarantees both.
    best_loss = min([th["val_xent"] for th in train_history])
    best_step = [
        th["step"] for th in train_history if th["val_xent"] == best_loss
    ][-1]  # Take the last if multiple
    IOUtils.dump(output_model_dir / "best-step.json", best_step, IOUtils.Format.json)

    return
def main(opt, device_id):
    """Train a model (with an optional BERT component) on one device.

    Options are post-processed with ``training_opt_postprocessing``. When
    resuming, model options are built from the parser defaults and then
    overwritten by the checkpoint's options. ``fields`` here is indexed as
    ``fields[side][0][1]``, i.e. presumably a dict of lists of (name, field)
    pairs (the older OpenNMT-py layout). It is flattened with
    ``chain.from_iterable`` for the dataset iterators.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opts values then overwrite them with opts from
        # the checkpoint. This is useful in order to re-train a model
        # after adding a new option (not set in checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # HACK: if the src vocab has exactly 30801 entries, drop its last four
    # itos entries (and their stoi keys), then remove '<s>', '</s>', '<t>'
    # and '</t>' from the frequency table.
    # NOTE(review): this assumes those four tokens are the last four itos
    # entries (presumably special tokens appended to a pretrained vocab);
    # freqs.pop raises KeyError if any is missing. Also note the freqs are
    # popped on ``vocab`` while itos/stoi are popped on ``fields`` — these
    # are the same object only when the vocab is not old-style. Confirm.
    if len(fields['src'][0][1].base_field.vocab.itos) == 30801:
        fields['src'][0][1].base_field.vocab.stoi.pop(
            fields['src'][0][1].base_field.vocab.itos.pop(-1))
        fields['src'][0][1].base_field.vocab.stoi.pop(
            fields['src'][0][1].base_field.vocab.itos.pop(-1))
        fields['src'][0][1].base_field.vocab.stoi.pop(
            fields['src'][0][1].base_field.vocab.itos.pop(-1))
        fields['src'][0][1].base_field.vocab.stoi.pop(
            fields['src'][0][1].base_field.vocab.itos.pop(-1))
        vocab['src'][0][1].base_field.vocab.freqs.pop('<s>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('</s>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('<t>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('</t>')

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                # A single (non-multi) field is not iterable.
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid", dataset_fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    trainer.train(train_iter, opt.train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps,
                  bert=opt.bert)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Build a model, optimizer, saver and trainer from ``opt`` and run training.

    When ``opt.train_from`` is set, the model options and the vocab are taken
    from that checkpoint; otherwise ``opt`` itself is used as the model options
    and the vocab is read from ``opt.data + '.vocab.pt'``.

    Args:
        opt: training options (must already be validated/updated, see NOTE).
        device_id: forwarded to ``configure_process`` and ``build_trainer``.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps every tensor where torch deserializes it by default
        # (CPU storage) instead of the device it was saved from.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Debug output left in by the author (goes to stdout, not the logger).
        print("load weight success")
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        print("old style vocab")
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        print("not old style")
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A non-iterable field is reported under the side name itself.
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    # Wrapping in DistributedDataParallel was added and later removed by
    # zhengquan because of CUDA device availability:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #     device_ids=[opt.local_rank],
    #     output_device=opt.local_rank)
    # The DistributedDataParallel docs say:
    # "DistributedDataParallel with multi-device module only works "
    # "with CUDA devices, but module parameters locate in {}."
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter(
        "train", fields, opt)  # build_dataset_iter() loads the data from the dataset_paths stored in opt
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    # single_pass overrides any step budget: train_steps is forced to 0.
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Build a model, optimizer and RL learner ("rltor"), then infer or train per shard.

    With ``opt.infer`` the learner's ``infer`` is called on each
    (src, tag_src) shard; otherwise its ``train`` is called on each group of
    train/valid src, tgt, tag_src and tag_tgt shards.

    Args:
        opt: options; must already be validated/updated (see NOTE below).
        device_id: not used by this function (process configuration is
            commented out); kept for signature compatibility.
        batch_queue: not used; kept for signature compatibility.
        semaphore: not used; kept for signature compatibility.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    # configure_process(opt, device_id)
    # init_logger(opt.log_file)
    # assert len(opt.accum_count) == len(opt.accum_steps), \
    #     'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        # ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        # vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
    # Reading the vocab from the checkpoint is disabled above, so the vocab is
    # always loaded from opt.data (this also keeps ``vocab`` bound on resume).
    vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    rl_model = build_model(model_opt, opt, fields, checkpoint)
    _check_save_model_path(opt)

    # Build optimizer.
    # optim = torch.optim.Adam(rl_model.parameters())
    optim = Optimizer.from_opt(rl_model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, rl_model, optim)
    # model_saver = None
    # trainer = build_trainer(
    #     opt, device_id, model, fields, optim, model_saver=model_saver)
    build_rltor = build_rltor_enc  # if not opt.rl_step else build_rltor_dec
    rltor = build_rltor(opt, rl_model, optim, model_saver, report_score=False)

    def _optional_shards(path):
        """Split ``path`` into shards, or yield ``None`` forever when it is unset."""
        if path is None:
            return repeat(None)
        return split_corpus(path, opt.shard_size)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = _optional_shards(opt.tgt)
    if opt.infer:
        tag_src_shards = _optional_shards(opt.tag_src)
        shard_pairs = zip(src_shards, tag_src_shards)
        for i, (src_shard, tag_src_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            rltor.infer(src_shard, tag_src_shard,
                        batch_size=opt.batch_size,
                        batch_type=opt.batch_type)
    else:
        valid_src_shards = split_corpus(opt.valid_src, opt.shard_size)
        # Guarded by its own path (previously checked opt.tgt by mistake,
        # which crashed when tgt was set but valid_tgt was not).
        valid_tgt_shards = _optional_shards(opt.valid_tgt)
        tag_src_shards = _optional_shards(opt.tag_src)
        valid_tag_src_shards = _optional_shards(opt.valid_tag_src)
        tag_tgt_shards = _optional_shards(opt.tag_tgt)
        valid_tag_tgt_shards = _optional_shards(opt.valid_tag_tgt)
        shard_pairs = zip(src_shards, tgt_shards, tag_src_shards,
                          tag_tgt_shards, valid_src_shards, valid_tgt_shards,
                          valid_tag_src_shards, valid_tag_tgt_shards)
        for i, (train_src_shard, train_tgt_shard, train_tag_src_shard,
                train_tag_tgt_shard, valid_src_shard, valid_tgt_shard,
                valid_tag_src_shard,
                valid_tag_tgt_shard) in enumerate(shard_pairs):
            logger.info("Learning shard %d." % i)
            rltor.train(train_src_shard, train_tgt_shard,
                        train_tag_src_shard, train_tag_tgt_shard,
                        valid_src_shard, valid_tgt_shard,
                        valid_tag_src_shard, valid_tag_tgt_shard,
                        src_dir=opt.src_dir,
                        batch_size=opt.batch_size,
                        batch_type=opt.batch_type)
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader,
                       align_reader, opt):
    """Process every shard of a train or valid corpus and, for train, save the vocab.

    Each (src, tgt, align) shard is passed to ``process_one_shard`` in a pool of
    ``opt.num_threads`` worker processes. For the training corpus, the
    counters returned by the workers are merged. A fields vocabulary is then
    built from them, or the one returned by ``maybe_load_vocab`` is reused. It
    is saved with ``torch.save`` to ``opt.save_data + '.vocab.pt'``.

    Args:
        corpus_type: ``config.train`` or ``config.valid``.
        fields: fields object forwarded to the workers and used to build the
            training vocab.
        src_reader, tgt_reader, align_reader: readers forwarded to
            ``process_one_shard``.
        opt: preprocessing options.
    """
    assert corpus_type in [config.train, config.valid]

    if corpus_type == config.train:
        counters = defaultdict(Counter)
        srcs = opt.train_src
        tgts = opt.train_tgt
        ids = opt.train_ids
        aligns = opt.train_align
    elif corpus_type == config.valid:
        counters = None
        srcs = [opt.valid_src]
        tgts = [opt.valid_tgt]
        ids = [None]
        aligns = [opt.valid_align]

    src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
        corpus_type, counters, opt
    )

    # NOTE(review): never populated here, so the "shards already exist" branch
    # of shard_iterator below is never taken -- confirm this is intended.
    existing_shards = []

    def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                       existing_fields, corpus_type, opt):
        """Yield ``(shard_index, (src, tgt, align, corpus_id, filter_pred))`` for every corpus."""
        for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
            if maybe_id in existing_shards:
                if opt.overwrite:
                    logger.warning("Overwrite shards for corpus {}"
                                   .format(maybe_id))
                else:
                    if corpus_type == config.train:
                        assert existing_fields is not None, \
                            ("A 'vocab.pt' file should be passed to "
                             "`-src_vocab` when adding a corpus to "
                             "a set of already existing shards.")
                    logger.warning("Ignore corpus {} because "
                                   "shards already exist"
                                   .format(maybe_id))
                    continue
            # Length filtering applies to train, and to valid only when
            # opt.filter_valid is set; it needs a target side.
            if ((corpus_type == config.train or opt.filter_valid)
                    and tgt is not None):
                filter_pred = partial(
                    inputters.filter_example,
                    use_src_len=opt.data_type == "text",
                    max_src_len=opt.src_seq_length,
                    max_tgt_len=opt.tgt_seq_length)
            else:
                filter_pred = None
            src_shards = split_corpus(src, opt.shard_size)
            tgt_shards = split_corpus(tgt, opt.shard_size)
            align_shards = split_corpus(maybe_align, opt.shard_size)
            for i, (ss, ts, a_s) in enumerate(
                    zip(src_shards, tgt_shards, align_shards)):
                yield (i, (ss, ts, a_s, maybe_id, filter_pred))

    shard_iter = shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                                existing_fields, corpus_type, opt)

    with Pool(opt.num_threads) as p:
        dataset_params = (
            corpus_type, fields, src_reader, tgt_reader, align_reader,
            opt, existing_fields, src_vocab, tgt_vocab
        )
        func = partial(process_one_shard, dataset_params)
        for sub_counter in p.imap(func, shard_iter):
            # Workers return None when there is nothing to count.
            if sub_counter is not None:
                for key, value in sub_counter.items():
                    counters[key].update(value)

    # Compare with config.train (not the literal "train") to stay consistent
    # with the validation and branching at the top of this function.
    if corpus_type == config.train:
        vocab_path = opt.save_data + '.vocab.pt'
        new_fields = _build_fields_vocab(
            fields, counters, opt.data_type,
            opt.share_vocab, opt.vocab_size_multiple,
            opt.src_vocab_size, opt.src_words_min_frequency,
            opt.tgt_vocab_size, opt.tgt_words_min_frequency,
            subword_prefix=opt.subword_prefix,
            subword_prefix_is_joiner=opt.subword_prefix_is_joiner)
        if existing_fields is None:
            fields = new_fields
        else:
            fields = existing_fields

        if old_style_vocab(fields):
            fields = load_old_vocab(
                fields, opt.data_type, dynamic_dict=opt.dynamic_dict)

        # patch corpus_id
        if fields.get("corpus_id", False):
            fields["corpus_id"].vocab = new_fields["corpus_id"].vocab_cls(
                counters["corpus_id"])

        torch.save(fields, vocab_path)
def main(opt, device_id):
    """Train one multilingual model over several (src-tgt) language pairs.

    For every entry in ``opt.src_tgt`` (paired with ``opt.data`` by index)
    this builds:

    * an encoder keyed by source language and a decoder/generator keyed by
      target language;
    * train/valid iterators keyed by ``(src_lang, tgt_lang)``.

    Embeddings may be shared across pairs. Everything is then assembled by
    ``build_model`` and trained with ``build_trainer(...).train``.

    Args:
        opt: training options (must already be validated/updated).
        device_id: forwarded to ``configure_process`` and ``build_trainer``.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        # NOTE(review): this value is overwritten by the per-pair vocab loaded
        # inside the loop below, so the checkpoint vocab is effectively unused.
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        #vocab = torch.load(opt.data + '.vocab.pt')

    train_iters = OrderedDict()
    valid_iters = OrderedDict()

    encoders = OrderedDict()
    decoders = OrderedDict()

    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders and decoders
    firstTime = True
    weightToShare = None

    # we share the word embedding space when source lang and target lang are the same
    mapLang2Emb = {}
    #for (src_tgt_lang), data_path in zip(opt.src_tgt, opt.data):
    for index in range(len(opt.src_tgt)):
        src_tgt_lang = opt.src_tgt[index]
        data_path = opt.data[index]
        # Per-pair copy of the model options; selected attributes are then
        # replaced by their per-pair value via update_to_local_attr.
        local_enc_dec_opts = AttrDict({
            key: model_opt.__dict__[key]
            for key in model_opt.__dict__.keys()
        })
        local_enc_dec_opts.model_type = update_to_local_attr(
            model_opt.model_type, index)
        #local_enc_dec_opts.audio_enc_pooling = model_opt.audio_enc_pooling[index]
        local_enc_dec_opts.audio_enc_pooling = update_to_local_attr(
            model_opt.audio_enc_pooling, index)
        local_enc_dec_opts.enc_layers = update_to_local_attr(
            model_opt.enc_layers, index)
        local_enc_dec_opts.dec_layers = update_to_local_attr(
            model_opt.dec_layers, index)
        local_enc_dec_opts.rnn_type = update_to_local_attr(
            model_opt.rnn_type, index)
        local_enc_dec_opts.encoder_type = update_to_local_attr(
            model_opt.encoder_type, index)
        local_enc_dec_opts.batch_size = update_to_local_attr(
            model_opt.batch_size, index)
        local_enc_dec_opts.batch_type = update_to_local_attr(
            model_opt.batch_type, index)
        local_enc_dec_opts.normalization = update_to_local_attr(
            model_opt.normalization, index)
        #local_enc_dec_opts.dec_rnn_size = model_opt.dec_rnn_size[index]

        # Pair names are expected as "<src_lang>-<tgt_lang>".
        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(
                vocab, opt.model_type[0], dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build model.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)

        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)

        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all the encoders and decoders - preprocess with share_vocab required.
        # The first pair's source embedding weight becomes the shared weight;
        # every later pair's src and tgt embeddings are pointed at it.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        # Per-language sharing: the first embedding seen for a language is
        # remembered and reused by later encoders/decoders of that language.
        # NOTE(review): model_opt.model_type appears to be a per-pair list
        # (it is indexed as opt.model_type[0] and passed to
        # update_to_local_attr), in which case ``== "text"`` is always False.
        # local_enc_dec_opts.model_type may be what was meant -- confirm.
        #TEST
        #if src_lang in mapLang2Emb:
        if src_lang in mapLang2Emb and model_opt.model_type == "text":
            encoder.embeddings.word_lut.weight = mapLang2Emb.get(src_lang)
        #TEST
        #else:
        elif model_opt.model_type == "text":
            mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb.get(tgt_lang)
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        #TEST
        if model_opt.model_type == "text":
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'valid', fields, data_path, local_enc_dec_opts, is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model.
    # NOTE(review): ``fields`` here (and in build_trainer below) is the last
    # pair's fields; Fields_dict holds all of them.
    model = build_model(model_opt, opt, fields, encoders, decoders,
                        generators, src_vocabs, tgt_vocabs, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, generators,
                            tgt_vocabs, model_saver=model_saver)

    # TODO: not implemented yet
    #train_iterables = []
    #if len(opt.data_ids) > 1:
    #    for train_id in opt.data_ids:
    #        shard_base = "train_" + train_id
    #        iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
    #        train_iterables.append(iterable)
    #    train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
    #else:
    #    train_iter = build_dataset_iter("train", fields, opt)

    #valid_iter = build_dataset_iter(
    #    "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Train a model, optionally initialized from a checkpoint, TF GPT-2 weights or an encoder.

    Initialization sources:

    * ``opt.train_from`` resumes training; model options come from that
      checkpoint.
    * Otherwise ``opt.load_uncond_from``, if set, supplies a checkpoint and
      vocab, while ``opt`` itself is used as the model options.
    * ``opt.gpt2_params_path`` adds ``(name, array)`` pairs read from a
      TensorFlow checkpoint under ``checkpoint['gpt2_params']``.
    * ``opt.encoder_from`` adds that checkpoint's ``'model'`` entry under
      ``checkpoint['enc_model']``.

    Args:
        opt: training options (must already be validated/updated).
        device_id: forwarded to ``configure_process`` and ``build_trainer``.

    Raises:
        ValueError: if ``opt.encoder_from`` has a different src vocab.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)

        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        # Optional dependency, only imported when TF weights are requested.
        import tensorflow as tf
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, _shape in init_vars:
            # Optionally keep only variables whose names contain 'wpe'/'wte'
            # (gpt_emb_only) or only 'wpe' (gpt_wpe_only).
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        if checkpoint is None:
            checkpoint = {}
        # Materialized as a list: a bare zip() iterator would be silently
        # empty for any consumer that iterates it a second time.
        checkpoint['gpt2_params'] = list(zip(names, arrays))

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s' % opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                    map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical if using pretrained encoder'
            )
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    # (model_type 'none' has no source side to report).
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(
        model, only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))

    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    # When not resuming, a checkpoint that only carried GPT-2 weights is not
    # passed on to the optimizer.
    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a (keyphrase / news) model on one device.

    Batches come either from dataset iterators built here, or, when
    ``batch_queue`` is given, from that queue (a producer elsewhere fills it).
    In the queue case ``semaphore`` is released after each batch is taken.

    Args:
        opt: training options (must already be validated/updated).
        device_id: forwarded to ``configure_process`` and ``build_trainer``.
        batch_queue: optional queue supplying training batches.
        semaphore: required together with ``batch_queue``.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
        logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        # The vocab is read only from an explicit opt.vocab path; both other
        # branches leave it as None (the 'pretrained' branch is equivalent).
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        elif opt.encoder_type == 'pretrained':
            vocab = None
        else:
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    # NOTE(review): with vocab None, ``fields`` is None here; the keyphrase
    # patch and vocab report below index it, so presumably reload_news_fields
    # must supply the fields in that configuration -- confirm.
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    # one2one/multiple targets drop the 'sep_indices' field; other target
    # types get one added if it is missing.
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices

        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        # The model's vocab size is taken from the tokenizer.
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt, tokenizer=tokenizer)
        else:
            # NOTE(review): unlike the multi-dataset paths above, this call
            # (and the valid iterator below) does not pass ``tokenizer`` --
            # confirm this is intended when opt.pretrained_tokenizer is set.
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            # Endless generator: take a batch from the queue, then release
            # the semaphore.
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def train_impl(
        self,
        train_processed_data_dir: Path,
        val_processed_data_dir: Path,
        output_model_dir: Path,
) -> None:
    """Preprocess the data, then train an OpenNMT model configured from ``self.config``.

    Options are parsed with OpenNMT's train parser, run from inside
    ``self.open_nmt_path``, and then overridden from ``self.config``.
    Depending on ``opt.world_size`` and ``opt.gpu_ranks``, training then runs
    either as one process per GPU fed by a batch-producer process, or
    in-process on a single GPU or on CPU.

    Args:
        train_processed_data_dir: forwarded to ``self.preprocess``.
        val_processed_data_dir: forwarded to ``self.preprocess``.
        output_model_dir: preprocessed data is read from
            ``<output_model_dir>/processed-data`` and checkpoints are saved
            under ``<output_model_dir>/models/ckpt``.
    """
    self.preprocess(train_processed_data_dir, val_processed_data_dir, output_model_dir)

    # Imported lazily: these modules are only resolvable from the OpenNMT
    # checkout (see IOUtils.cd below) -- presumably; verify.
    from train import _get_parser as train_get_parser
    from train import ErrorHandler, batch_producer
    from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
    from onmt.inputters.inputter import old_style_vocab, load_old_vocab
    import onmt.utils.distributed
    from onmt.utils.parse import ArgumentParser

    with IOUtils.cd(self.open_nmt_path):
        parser = train_get_parser()
        opt = parser.parse_args(
            f" -data {output_model_dir}/processed-data"
            f" -save_model {output_model_dir}/models/ckpt")
        # Override parsed defaults with this model's configuration.
        opt.gpu_ranks = [0]
        opt.early_stopping = self.config.early_stopping_threshold
        opt.report_every = 200
        opt.valid_steps = 200
        opt.save_checkpoint_steps = 200
        opt.keep_checkpoint_max = self.config.ckpt_keep_max

        opt.optim = "adam"
        opt.learning_rate = self.config.learning_rate
        opt.max_grad_norm = self.config.max_grad_norm
        opt.batch_size = self.config.batch_size

        opt.encoder_type = self.config.encoder
        opt.decoder_type = self.config.decoder
        opt.dropout = [self.config.dropout]
        opt.src_word_vec_size = self.config.dim_embed
        opt.tgt_word_vec_size = self.config.dim_embed
        opt.layers = self.config.rnn_num_layers
        opt.enc_rnn_size = self.config.dim_encoder_hidden
        opt.dec_rnn_size = self.config.dim_decoder_hidden

        opt.__setattr__("num_srcs", len(self.config.get_src_types()))

        if self.config.use_attn:
            opt.global_attention = "general"
        else:
            opt.global_attention = "none"
        # end if

        if self.config.use_copy:
            opt.copy_attn = True
            opt.copy_attn_type = "general"
        # end if

        # train.main
        ArgumentParser.validate_train_opts(opt)
        ArgumentParser.update_model_opts(opt)
        ArgumentParser.validate_model_opts(opt)

        # Load checkpoint if we resume from a previous training.
        if opt.train_from:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            checkpoint = torch.load(
                opt.train_from,
                map_location=lambda storage, loc: storage)
            self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
        # end if

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
        # end if

        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(self.config.get_src_types(), train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = MultiSourceInputter.build_dataset_iter(self.config.get_src_types(), shard_base, fields, opt)
        # end if

        nb_gpu = len(opt.gpu_ranks)

        if opt.world_size > 1:
            queues = []
            mp = torch.multiprocessing.get_context('spawn')
            semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
            # Create a thread to listen for errors in the child processes.
            error_queue = mp.SimpleQueue()
            error_handler = ErrorHandler(error_queue)
            # Train with multiprocessing.
            procs = []
            for device_id in range(nb_gpu):
                q = mp.Queue(opt.queue_size)
                queues += [q]

                # NOTE(review): 'spawn' pickles the Process target; a function
                # defined locally (and closing over ``self``) is generally not
                # picklable, so this path likely fails at start() -- verify.
                def run(opt, device_id, error_queue, batch_queue, semaphore):
                    """ run process """
                    try:
                        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
                        if gpu_rank != opt.gpu_ranks[device_id]:
                            raise AssertionError("An error occurred in Distributed initialization")
                        # NOTE(review): called here as (opt, device_id,
                        # batch_queue, semaphore) but below as
                        # (output_model_dir, opt, device) -- at most one
                        # matches train_single's signature; confirm.
                        self.train_single(opt, device_id, batch_queue, semaphore)
                    except KeyboardInterrupt:
                        pass  # killed by parent, do nothing
                    except Exception:
                        # propagate exception to parent process, keeping original traceback
                        import traceback
                        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
                    # end try
                # end def

                procs.append(mp.Process(target=run, args=(opt, device_id, error_queue, q, semaphore), daemon=True))
                procs[device_id].start()
                self.logger.info(" Starting process pid: %d " % procs[device_id].pid)
                error_handler.add_child(procs[device_id].pid)
            # end for
            producer = mp.Process(target=batch_producer, args=(train_iter, queues, semaphore, opt,), daemon=True)
            producer.start()
            error_handler.add_child(producer.pid)

            for p in procs:
                p.join()
            producer.terminate()

        elif nb_gpu == 1:  # case 1 GPU only
            self.train_single(output_model_dir, opt, 0)
        else:  # case only CPU
            self.train_single(output_model_dir, opt, -1)
        # end if
    # end with
    return
def validate(opt, device_id=0):
    """Evaluate the checkpoint at ``opt.train_from`` on the validation set.

    The model, fields and options are rebuilt from the checkpoint. Then every
    batch of the "valid" dataset iterator is run through the model in eval
    mode under ``torch.no_grad()``, and the accumulated word count,
    perplexity, accuracy and average attention entropy are printed.

    Args:
        opt: options; ``opt.train_from`` must name a checkpoint file.
        device_id: forwarded to ``configure_process``.

    Raises:
        ValueError: if ``opt.train_from`` is not set.
    """
    configure_process(opt, device_id)
    # Everything below needs the checkpoint; fail clearly instead of with an
    # unbound-name error.
    if not opt.train_from:
        raise ValueError('validate() requires opt.train_from to name a checkpoint')

    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)
    model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    vocab = checkpoint['vocab']

    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(
        model, tgt_field, opt, train=False)

    model.eval()
    with torch.no_grad():
        stats = onmt.utils.Statistics()
        for batch in valid_iter:
            # batch.src is either (tokens, lengths) or tokens alone.
            if isinstance(batch.src, tuple):
                src, src_lengths = batch.src
            else:
                src, src_lengths = batch.src, None
            tgt = batch.tgt
            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)
            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)
            # Update statistics.
            stats.update(batch_stats)
    print('n words: %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())
def train(opt):
    """Validate options, auto-resume from the newest checkpoint, and launch training.

    ``opt.train_from`` must not be set by the caller. It is filled in with the
    ``<opt.save_model>*.pt`` file that has the highest step number, if any
    exist. Training then runs as:

    * one process per GPU fed by a batch producer (``opt.world_size > 1``);
    * ``single_main`` on one GPU;
    * ``single_main`` on CPU.

    Raises:
        Exception: if ``opt.train_from`` was set manually, or if the
            save_model directory is non-empty but contains no
            ``<opt.save_model>*.pt`` checkpoints.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Truthiness rather than ``!= ''`` so an unset ``None`` is also accepted.
    if opt.train_from:
        raise Exception(
            'train_from will be set automatically to the latest model, you should not set it manually'
        )

    # set gpu ranks automatically if not specified
    if len(opt.gpu_ranks) == 0:
        opt.gpu_ranks = list(range(opt.world_size))

    # Set train_from to latest checkpoint if it exists
    file_list = glob.glob(opt.save_model + '*.pt')
    # dirname() is '' for a bare prefix such as "model", and the directory may
    # not exist yet on a fresh run; os.listdir() raises in both cases.
    save_dir = os.path.dirname(opt.save_model) or '.'
    if os.path.isdir(save_dir) and os.listdir(save_dir) and not file_list:
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if file_list:
        def _step(path):
            # Step = text between the last '_' and the following '.',
            # e.g. "ckpt_step_1200.pt" -> 1200.
            return int(path.split('_')[-1].split('.')[0])

        # Use the matched file itself rather than rebuilding its name as
        # "<save_model>_<step>.pt", which does not exist when the saved name
        # has extra parts before the step (e.g. "<save_model>_step_<step>.pt").
        opt.train_from = max(file_list, key=_step)
        print(opt.train_from)
        assert os.path.exists(opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def train(opt):
    """Build fields and the training iterator, then launch training.

    With ``opt.world_size > 1`` one process per GPU plus a batch-producer
    process are spawned and supervised by an ``ErrorHandler``. Otherwise
    training runs in-process on a single GPU or on CPU.

    Args:
        opt: parsed training options.
    """
    ArgumentParser.validate_train_opts(opt)
    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # Convert vocab saved in the old format into fields.
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # The worker processes receive ``error_queue`` and are registered
        # with ``error_handler``; both must exist before any process is
        # spawned (previously they were never defined -> NameError).
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a model on one device, optionally guided by a teacher model.

    When ``opt.teacher_model_path`` is set, a teacher model is loaded from
    that checkpoint. It is passed to ``build_trainer`` (and, if
    ``opt.word_sampling`` is set, wrapped in an ``Emulator`` sampler), and
    the trainer is then built with the teacher's fields.

    Args:
        opt: training options.
        device_id: device index passed to ``configure_process``,
            ``Emulator`` and ``build_trainer``.
        batch_queue: optional queue of ready-made training batches (used by
            multi-process training). When ``None``, the training iterator
            is built here from ``opt.data_ids``.
        semaphore: must be given together with ``batch_queue``; released
            once per batch taken from the queue.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Model hyper-parameters come from the checkpoint, not from ``opt``.
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Load the teacher checkpoint, its (validated) options and its vocab.
    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path,
            map_location=lambda storage, loc: storage)
        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # NOTE(review): unlike ``vocab`` above, the teacher vocab is used as-is
    # and never passed through ``old_style_vocab``/``load_old_vocab``; an
    # old-format teacher checkpoint would presumably break here — confirm.
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    # The student builder receives both field sets keyed by role.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    # NOTE(review): the teacher's own checkpoint options are passed as BOTH
    # ``model_opt`` and the runtime ``opt``; any runtime settings
    # ``build_model`` reads from its second argument (e.g. device
    # placement) therefore come from the teacher's training run, not from
    # this run — confirm this is intended.
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        # NOTE(review): this checks the save path recorded in the teacher's
        # checkpoint options; the teacher is not saved by this function.
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(
        model_opt, opt, model, fields_opt, optim)

    # The start-of-sentence id comes from the teacher's target vocab when a
    # teacher is used, otherwise from the student's.
    tgt_field = dict(teacher_fields)["tgt"].base_field if teacher_model is not None \
        else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model, teacher_fields, device_id,
                           max_length=50, random_sampling_topk=5)
    else:
        sampler = None

    # With a teacher, the trainer is built with the teacher's fields.
    if teacher_model is not None:
        trainer = build_trainer(opt, device_id, model, teacher_fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt, device_id, model, fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(
                train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        # Endless generator: take the next batch from the queue and release
        # the semaphore for each batch taken.
        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps, sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def train(opt):
    """Build fields and the training iterator, then launch training.

    When resuming (``opt.train_from``), the vocab stored in the checkpoint
    is used if present. Otherwise it falls back to ``<opt.data>.vocab.pt``.
    With ``opt.world_size > 1`` one process per GPU plus a batch-producer
    process are spawned. Otherwise training runs in-process on a single GPU
    or on CPU.

    Args:
        opt: parsed training options; validated and updated in place.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    data_vocab_path = opt.data + '.vocab.pt'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Log the vocab source once, and only claim it comes from the
        # checkpoint when the checkpoint really contains one.
        if 'vocab' in checkpoint:
            logger.info(
                'Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        else:
            logger.info('No vocab in checkpoint %s; loading vocab from %s.'
                        % (opt.train_from, data_vocab_path))
            vocab = torch.load(data_vocab_path)
    else:
        vocab = torch.load(data_vocab_path)

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader,
                       tag_reader, align_reader, opt):
    """Shard one corpus split and process every shard in a worker pool.

    Each corpus carries, besides src/tgt/alignment, two tag inputs
    (``nfr_tag`` and ``flat_tag``) that are split into shards in lockstep
    with the other inputs. Every shard is handed to ``process_one_shard``.
    For the ``'train'`` split, the per-shard counters it returns are merged,
    a vocabulary is built from them, and the resulting fields are saved to
    ``<opt.save_data>.vocab.pt``.

    Args:
        corpus_type: ``'train'`` or ``'valid'``.
        fields: fields forwarded to ``process_one_shard`` and used to build
            the vocabulary.
        src_reader, tgt_reader, tag_reader, align_reader: readers forwarded
            to ``process_one_shard``.
        opt: preprocessing options.
    """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        counters = defaultdict(Counter)
        srcs, tgts, ids = opt.train_src, opt.train_tgt, opt.train_ids
        aligns = opt.train_align
        nfr_tags, flat_tags = opt.train_nfr_tag, opt.train_flat_tag
    elif corpus_type == 'valid':
        # The validation split is a single, un-named corpus.
        counters = None
        srcs, tgts, ids = [opt.valid_src], [opt.valid_tgt], [None]
        aligns = [opt.valid_align]
        nfr_tags, flat_tags = [opt.valid_nfr_tag], [opt.valid_flat_tag]

    src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
        corpus_type, counters, opt)

    existing_shards = check_existing_pt_files(
        opt, corpus_type, ids, existing_fields)

    # Nothing to do when every corpus already has its shards on disk.
    if existing_shards == ids and not opt.overwrite:
        return

    def shard_iterator(srcs, tgts, ids, aligns, nfr_tags, flat_tags,
                       existing_shards, existing_fields, corpus_type, opt):
        """
        Builds a single iterator yielding every shard of every corpus.
        """
        corpora = zip(srcs, tgts, ids, aligns, nfr_tags, flat_tags)
        for src, tgt, maybe_id, maybe_align, nfr_tag, flat_tag in corpora:
            already_sharded = maybe_id in existing_shards
            if already_sharded and not opt.overwrite:
                if corpus_type == "train":
                    assert existing_fields is not None,\
                        ("A 'vocab.pt' file should be passed to "
                         "`-src_vocab` when adding a corpus to "
                         "a set of already existing shards.")
                logger.warning("Ignore corpus {} because "
                               "shards already exist".format(maybe_id))
                continue
            if already_sharded:
                logger.warning(
                    "Overwrite shards for corpus {}".format(maybe_id))

            wants_filter = ((corpus_type == "train" or opt.filter_valid)
                            and tgt is not None)
            if wants_filter:
                filter_pred = partial(
                    inputters.filter_example,
                    use_src_len=opt.data_type == "text",
                    max_src_len=opt.src_seq_length,
                    max_tgt_len=opt.tgt_seq_length)
            else:
                filter_pred = None

            # Split every input of this corpus into aligned shards.
            per_input_shards = [
                split_corpus(path, opt.shard_size)
                for path in (src, tgt, maybe_align, nfr_tag, flat_tag)
            ]
            for idx, shard_group in enumerate(zip(*per_input_shards)):
                yield (idx, shard_group + (maybe_id, filter_pred))

    shard_iter = shard_iterator(srcs, tgts, ids, aligns, nfr_tags, flat_tags,
                                existing_shards, existing_fields,
                                corpus_type, opt)

    with Pool(opt.num_threads) as pool:
        shared_params = (corpus_type, fields, src_reader, tgt_reader,
                         tag_reader, align_reader, opt, existing_fields,
                         src_vocab, tgt_vocab)
        worker = partial(process_one_shard, shared_params)
        for shard_counter in pool.imap(worker, shard_iter):
            if shard_counter is None:
                continue
            for key, value in shard_counter.items():
                counters[key].update(value)

    if corpus_type != "train":
        return

    vocab_path = opt.save_data + '.vocab.pt'
    new_fields = _build_fields_vocab(
        fields, counters, opt.data_type,
        opt.share_vocab, opt.vocab_size_multiple,
        opt.src_vocab_size, opt.src_words_min_frequency,
        opt.tgt_vocab_size, opt.tgt_words_min_frequency,
        subword_prefix=opt.subword_prefix,
        subword_prefix_is_joiner=opt.subword_prefix_is_joiner)
    fields = new_fields if existing_fields is None else existing_fields

    if old_style_vocab(fields):
        fields = load_old_vocab(
            fields, opt.data_type, dynamic_dict=opt.dynamic_dict)

    # patch corpus_id
    if fields.get("corpus_id", False):
        fields["corpus_id"].vocab = new_fields["corpus_id"].vocab_cls(
            counters["corpus_id"])

    torch.save(fields, vocab_path)
def main(opt, device_id, batch_queue=None, semaphore=None,
         train_iter=None, passed_fields=None):
    """Train a model on one device, with optional crosslingual auxiliaries.

    Args:
        opt: training options. NOTE: ``opt`` must already have been
            validated and updated at this point.
        device_id: device index passed to ``configure_process`` and
            ``build_trainer``.
        batch_queue: optional queue of ready-made training batches (used by
            multi-process training); requires ``semaphore``.
        semaphore: released once per batch taken from ``batch_queue``.
        train_iter: optional pre-built training iterator. When given, it is
            used as-is and neither ``batch_queue`` nor the dataset shards
            are consulted.
        passed_fields: optional dict with keys ``'main'`` and
            ``'crosslingual'``. When given, its fields replace the ones
            derived from the loaded vocab.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Optionally reuse the model options stored in the checkpoint.
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A single (non-multi) field: report it under the side's name.
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint,
                        aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        # Optimise only the embedding alignment mapping layer.
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver, aux_fields=aux_fields)

    # NOTE: a ``train_iter`` passed by the caller takes precedence.
    if train_iter is None:
        if batch_queue is None:
            if len(opt.data_ids) > 1:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(
                    train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt)
        else:
            assert semaphore is not None, \
                "Using batch_queue requires semaphore as well"

            def _train_iter():
                while True:
                    batch = batch_queue.get()
                    semaphore.release()
                    yield batch

            train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`, so I
            # use 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train', fields, opt, is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid', fields, opt,
                                            is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [
            build_dataset_iter("valid", fields, opt, is_train=False)
        ]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters, valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    # ``opt.tensorboard`` alone does not guarantee that a writer exists
    # (e.g. if it is only created on some ranks). Check the writer itself,
    # as the other entry points in this file do, to avoid calling
    # ``close()`` on ``None``.
    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a model on one device (GPU index, or CPU).

    Besides encoder/decoder sizes, the number of non-trainable parameters
    reported by ``_tally_parameters`` is logged.

    Args:
        opt: training options. NOTE: ``opt`` must already have been
            validated and updated at this point.
        device_id: device index passed to ``configure_process`` and
            ``build_trainer``.
        batch_queue: optional queue of ready-made training batches (used by
            multi-process training). When ``None``, the training iterator
            is built from ``opt.data_ids``.
        semaphore: required with ``batch_queue``; released once per batch
            taken from the queue.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Resume from a checkpoint if requested; otherwise start fresh with the
    # command-line model options and the preprocessed vocab.
    checkpoint = None
    model_opt = opt
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # Vocabularies stored in the old format are converted to fields.
    fields = (load_old_vocab(vocab, opt.model_type,
                             dynamic_dict=opt.copy_attn)
              if old_style_vocab(vocab) else vocab)

    # Log the vocab size of every src/tgt field (features included).
    for side in ['src', 'tgt']:
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for name, field in named_fields:
            if field.use_vocab:
                logger.info(
                    ' * %s vocab size = %d' % (name, len(field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    total, enc_count, dec_count, frozen_count = _tally_parameters(model)
    logger.info('encoder: %d' % enc_count)
    logger.info('decoder: %d' % dec_count)
    logger.info(
        'non-trainable parameters (tgt_out_emb): %d' % frozen_count)
    logger.info('* number of parameters: %d' % total)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is not None:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()
    elif len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        first_id = opt.data_ids[0]
        shard_base = "train" if first_id is None else "train_" + first_id
        train_iter = build_dataset_iter(shard_base, fields, opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if not len(opt.gpu_ranks):
        logger.info('Starting training on CPU, could be very slow')
    else:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    writer = trainer.report_manager.tensorboard_writer
    if writer is not None:
        writer.close()
def main(opt, device_id):
    """Train a model on one device (legacy field-list API).

    Args:
        opt: raw training options; post-processed here via
            ``training_opt_postprocessing``.
        device_id: device index forwarded to option post-processing and
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    checkpoint = None
    model_opt = opt
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Start from the parser's default model options and overlay the
        # checkpoint's, so options added after the checkpoint was written
        # still get a value when re-training an older model.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        model_opt, _ = dummy_parser.parse_known_args([])
        vars(model_opt).update(vars(checkpoint['opt']))
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # Vocabularies stored in the old format are converted to fields.
    if not old_style_vocab(vocab):
        fields = vocab
    else:
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)

    # Each side maps to a list of (name, field) pairs in this API.
    for side in ['src', 'tgt']:
        for field_name, field in fields[side]:
            if field.use_vocab:
                logger.info(' * %s vocab size = %d'
                            % (field_name, len(field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    total, enc_count, dec_count = _tally_parameters(model)
    logger.info('encoder: %d' % enc_count)
    logger.info('decoder: %d' % dec_count)
    logger.info('* number of parameters: %d' % total)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # Kludge: the dataset iterators expect one flat name->field dict, while
    # ``fields`` groups (name, field) pairs per side.
    flat_fields = dict(chain.from_iterable(fields.values()))
    train_iter = build_dataset_iter("train", flat_fields, opt)
    valid_iter = build_dataset_iter("valid", flat_fields, opt, is_train=False)

    if not len(opt.gpu_ranks):
        logger.info('Starting training on CPU, could be very slow')
    else:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)

    trainer.train(train_iter, valid_iter, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Train a model on one device (GPU index, or CPU).

    Args:
        opt: training options. NOTE: ``opt`` must already have been
            validated and updated at this point.
        device_id: device index passed to ``configure_process`` and
            ``build_trainer``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Model hyper-parameters come from the checkpoint, not from ``opt``.
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A single (non-multi) field: report it under the side's name.
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    # ``opt.tensorboard`` alone does not guarantee that a writer exists
    # (e.g. if it is only created on some ranks). Check the writer itself,
    # as the other entry points in this file do, to avoid calling
    # ``close()`` on ``None``.
    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt):
    """Validate options, prepare fields and launch training.

    For ``opt.model_type == "keyphrase"`` the ``sep_indices`` and
    ``src_ex_vocab`` fields are added or removed to match ``opt.tgt_type``.
    With ``opt.world_size > 1`` one process per GPU plus a batch-producer
    process are spawned. Otherwise training runs in-process on a single GPU
    or on CPU.

    Args:
        opt: parsed training options; validated and updated in place.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            # The loaded vocab may or may not contain this field (the other
            # branch adds it when missing), so only drop it if present.
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)