def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
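# Illustrative sketch only: how a train(opt) like the one above is typically wired to
# the command line in OpenNMT-py-style forks. The `_get_parser` name and the onmt.opts
# helpers are assumed to exist as in upstream OpenNMT-py; adapt them to the fork at hand.
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser


def _get_parser():
    parser = ArgumentParser(description='train.py')
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser


if __name__ == "__main__":
    parser = _get_parser()
    opt = parser.parse_args()
    train(opt)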
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    set_random_seed(opt.seed, False)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
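# Illustrative sketch only (not the actual OpenNMT-py implementation): a minimal
# batch_producer compatible with the queue/semaphore handoff used above. It fills the
# per-GPU queues round-robin; each semaphore.acquire() here pairs with a
# semaphore.release() in the consumer's _train_iter, bounding the number of batches
# in flight at once. `batch.dataset = None` is an assumed simplification so the batch
# can be pickled across the 'spawn' process boundary.
def batch_producer(generator_to_serve, queues, semaphore, opt):
    for i, batch in enumerate(generator_to_serve):
        semaphore.acquire()              # wait until a worker has consumed a batch
        device_id = i % len(queues)      # round-robin over the per-GPU queues
        batch.dataset = None             # drop unpicklable references before queuing
        queues[device_id].put(batch)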
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % nontrainable)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
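# Illustrative sketch only: one way the _tally_parameters helper used above can split
# parameter counts by name prefix. The 4-value variant above additionally reports
# non-trainable parameters (assumed here to mean requires_grad == False, e.g. a frozen
# tgt output embedding); the exact helper differs between forks.
def _tally_parameters(model):
    enc, dec, nontrainable = 0, 0, 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            nontrainable += param.nelement()
        elif 'encoder' in name:
            enc += param.nelement()
        else:
            dec += param.nelement()
    n_params = enc + dec + nontrainable
    return n_params, enc, dec, nontrainable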
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if opt.train_from != '':
        raise Exception(
            'train_from will be set automatically to the latest model, '
            'you should not set it manually')

    # set gpu ranks automatically if not specified
    if len(opt.gpu_ranks) == 0:
        opt.gpu_ranks = list(range(opt.world_size))

    # Set train_from to latest checkpoint if it exists
    file_list = glob.glob(opt.save_model + '*.pt')
    if len(os.listdir(os.path.dirname(opt.save_model))) > 0 and len(file_list) == 0:
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if len(file_list) > 0:
        ckpt_nos = list(
            map(lambda x: int(x.split('_')[-1].split('.')[0]), file_list))
        ckpt_no = max(ckpt_nos)
        opt.train_from = opt.save_model + '_' + str(ckpt_no) + '.pt'
        print(opt.train_from)
        assert os.path.exists(opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(
        model_opt, opt, model, fields_opt, optim)

    tgt_field = dict(teacher_fields)["tgt"].base_field if teacher_model is not None \
        else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model, teacher_fields, device_id,
                           max_length=50, random_sampling_topk=5)
    else:
        sampler = None

    if teacher_model is not None:
        trainer = build_trainer(opt, device_id, model, teacher_fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt, device_id, model, fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        sos_id=sos_id,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        elif opt.encoder_type == 'pretrained':
            vocab = None
        else:
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(
            opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(
                    shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(
                    train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    # and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    if not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    # Scan previous checkpoint to resume training
    latest_step = 0
    latest_ckpt = None
    for subdir, dirs, filenames in os.walk(opt.exp_dir):
        for filename in sorted(filenames):
            if not filename.endswith('.pt'):
                continue
            step = int(filename[filename.rfind('_') + 1: filename.rfind('.pt')])
            if step > latest_step:
                latest_ckpt = os.path.join(subdir, filename)
                latest_step = step
    # if not saved in the exp folder, check opt.save_model
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = opt.save_model[opt.save_model.rfind(os.path.sep) + 1:]
        for subdir, dirs, filenames in os.walk(save_model_dir):
            for filename in sorted(filenames):
                if not filename.endswith('.pt'):
                    continue
                if not filename.startswith(model_prefix):
                    continue
                step = int(filename[filename.rfind('_') + 1: filename.rfind('.pt')])
                if step > latest_step:
                    latest_ckpt = os.path.join(subdir, filename)
                    latest_step = step
    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" % latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(
                vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    elif opt.encoder_type == 'pretrained':
        vocab = None
    else:
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(
            opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)
    # elif opt.data_type == 'keyphrase':
    #     fields = reload_keyphrase_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt, multi=True)
        else:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def main(opt, device_id, batch_queue=None, semaphore=None,
         train_iter=None, passed_fields=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver, aux_fields=aux_fields)

    if train_iter is not None:
        pass  # NOTE Use the passed one.
    elif batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`,
            # so I use 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train', fields, opt, is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid', fields, opt, is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [build_dataset_iter("valid", fields, opt, is_train=False)]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iters=valid_iters,
        valid_steps=opt.valid_steps,
        cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def train_single(self, output_model_dir: Path, opt, device_id,
                 batch_queue=None, semaphore=None):
    from roosterize.ml.onmt.CustomTrainer import CustomTrainer
    from onmt.inputters.inputter import build_dataset_iter, load_old_vocab, \
        old_style_vocab, build_dataset_iter_multiple
    from onmt.model_builder import build_model
    from onmt.train_single import configure_process, _tally_parameters, \
        _check_save_model_path
    from onmt.models import build_model_saver
    from onmt.utils.optimizers import Optimizer
    from onmt.utils.parse import ArgumentParser

    configure_process(opt, device_id)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        self.logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')
    # end if

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # end if

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        # end try
        for sn, sf in f_iter:
            if sf.use_vocab:
                self.logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    # end for

    # Build model
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    self.logger.info('encoder: %d' % enc)
    self.logger.info('decoder: %d' % dec)
    self.logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = CustomTrainer.build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = build_dataset_iter(shard_base, fields, opt)
        # end if
    else:
        assert semaphore is not None, "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch
            # end while
        # end def

        train_iter = _train_iter()
    # end if

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        self.logger.info('Starting training on CPU, could be very slow')
    # end if
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        self.logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    # end if

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    time_begin = trainer.report_manager.start_time
    time_end = time.time()

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    # Dump train metrics
    train_history = trainer.report_manager.get_joint_history()
    train_metrics = {
        "time_begin": time_begin,
        "time_end": time_end,
        "time": time_end - time_begin,
        "train_history": train_history,
    }
    IOUtils.dump(output_model_dir / "train-metrics.json", train_metrics,
                 IOUtils.Format.jsonNoSort)

    # Get the best step, depending on the lowest val_xent (cross entropy)
    best_loss = min([th["val_xent"] for th in train_history])
    best_step = [th["step"] for th in train_history
                 if th["val_xent"] == best_loss][-1]  # Take the last if multiple
    IOUtils.dump(output_model_dir / "best-step.json", best_step, IOUtils.Format.json)
    return
def train_impl(self,
               train_processed_data_dir: Path,
               val_processed_data_dir: Path,
               output_model_dir: Path,
               ) -> NoReturn:
    from train import _get_parser as train_get_parser
    from train import ErrorHandler, batch_producer
    from onmt.inputters.inputter import old_style_vocab, load_old_vocab, \
        build_dataset_iter, build_dataset_iter_multiple
    import onmt.utils.distributed
    from onmt.utils.parse import ArgumentParser

    with IOUtils.cd(self.open_nmt_path):
        parser = train_get_parser()
        opt = parser.parse_args(
            f" -data {output_model_dir}/processed-data"
            f" -save_model {output_model_dir}/models/ckpt"
        )
        opt.gpu_ranks = [0]
        opt.early_stopping = self.config.early_stopping_threshold
        opt.report_every = 200
        opt.valid_steps = 200
        opt.save_checkpoint_steps = 200
        opt.keep_checkpoint_max = self.config.ckpt_keep_max

        opt.optim = "adam"
        opt.learning_rate = self.config.learning_rate
        opt.max_grad_norm = self.config.max_grad_norm
        opt.batch_size = self.config.batch_size

        opt.encoder_type = self.config.encoder
        opt.decoder_type = self.config.decoder
        opt.dropout = [self.config.dropout]
        opt.src_word_vec_size = self.config.dim_embed
        opt.tgt_word_vec_size = self.config.dim_embed
        opt.layers = self.config.rnn_num_layers
        opt.enc_rnn_size = self.config.dim_encoder_hidden
        opt.dec_rnn_size = self.config.dim_decoder_hidden

        if self.config.use_attn:
            opt.global_attention = "general"
        else:
            opt.global_attention = "none"
        # end if
        if self.config.use_copy:
            opt.copy_attn = True
            opt.copy_attn_type = "general"
        # end if

        # train.main, one gpu case
        ArgumentParser.validate_train_opts(opt)
        ArgumentParser.update_model_opts(opt)
        ArgumentParser.validate_model_opts(opt)

        # Load checkpoint if we resume from a previous training.
        if opt.train_from:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            checkpoint = torch.load(opt.train_from,
                                    map_location=lambda storage, loc: storage)
            self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
        # end if

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
        # end if

        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = build_dataset_iter(shard_base, fields, opt)
        # end if

        nb_gpu = len(opt.gpu_ranks)

        if opt.world_size > 1:
            queues = []
            mp = torch.multiprocessing.get_context('spawn')
            semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
            # Create a thread to listen for errors in the child processes.
            error_queue = mp.SimpleQueue()
            error_handler = ErrorHandler(error_queue)
            # Train with multiprocessing.
            procs = []
            for device_id in range(nb_gpu):
                q = mp.Queue(opt.queue_size)
                queues += [q]

                def run(opt, device_id, error_queue, batch_queue, semaphore):
                    """ run process """
                    try:
                        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
                        if gpu_rank != opt.gpu_ranks[device_id]:
                            raise AssertionError(
                                "An error occurred in Distributed initialization")
                        self.train_single(output_model_dir, opt, device_id,
                                          batch_queue, semaphore)
                    except KeyboardInterrupt:
                        pass  # killed by parent, do nothing
                    except Exception:
                        # propagate exception to parent process,
                        # keeping original traceback
                        import traceback
                        error_queue.put(
                            (opt.gpu_ranks[device_id], traceback.format_exc()))
                    # end try
                # end def

                procs.append(mp.Process(target=run, args=(
                    opt, device_id, error_queue, q, semaphore), daemon=True))
                procs[device_id].start()
                self.logger.info(" Starting process pid: %d " % procs[device_id].pid)
                error_handler.add_child(procs[device_id].pid)
            # end for

            producer = mp.Process(target=batch_producer,
                                  args=(train_iter, queues, semaphore, opt,),
                                  daemon=True)
            producer.start()
            error_handler.add_child(producer.pid)

            for p in procs:
                p.join()
            producer.terminate()
        elif nb_gpu == 1:  # case 1 GPU only
            self.train_single(output_model_dir, opt, 0)
        else:  # case only CPU
            self.train_single(output_model_dir, opt, -1)
        # end if
    # end with
    return
def main(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    def get_fields(vocab):
        if old_style_vocab(vocab):
            return load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        else:
            return vocab

    fields = get_fields(vocab)
    aux_fields = None
    if opt.crosslingual:
        aux_fields = get_fields(aux_vocab)

    if opt.crosslingual:
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data',
                 Eat2PlainMonoTask, 'base', opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainAuxMonoTask, 'aux', opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data',
                 Eat2PlainMonoTask, 'base', opt.eat_formats[0]),
                ('train', fields, 'data',
                 EatLMMonoTask, 'lm', opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainAuxMonoTask, 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data',
                 EatLMCrosslingualTask, 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()
    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE Only pass train_iter in my crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {
            'main': fields,
            'crosslingual': aux_fields
        } if opt.crosslingual else None
        single_main(opt, device_id, train_iter=train_iter,
                    passed_fields=passed_fields)
def main(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)
    print(os.environ['PATH'])

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()
    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)