def get_fields(vocab):
    if old_style_vocab(vocab):
        return load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        return vocab
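# Hedged usage note: ``get_fields`` reads ``opt`` from the enclosing scope
# (it is not a parameter), so it can only be called where a parsed option
# object named ``opt`` is in scope, e.g.:
#   vocab = torch.load('data/demo.vocab.pt')   # path is illustrative
#   fields = get_fields(vocab)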
def load(cls, path, args):
    vocab = torch.load(path)
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, args.model_type, dynamic_dict=args.copy_attn)
    else:
        fields = vocab
    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                print(f'| [{sn}] dictionary: {len(sf.vocab)} types')
    return cls(fields)
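# Hedged usage note: the leading ``cls`` argument suggests this is bound as a
# classmethod on its owning class (not shown here), e.g.
#   model_fields = SomeFieldsContainer.load('data/demo.vocab.pt', args)
# where ``SomeFieldsContainer`` is an illustrative name, not the original one.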
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        #vocab = torch.load(opt.data + '.vocab.pt')

    train_iters = OrderedDict()
    valid_iters = OrderedDict()

    encoders = OrderedDict()
    decoders = OrderedDict()

    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders and decoders
    firstTime = True
    weightToShare = None

    # we share the word embedding space when source lang and target lang are the same
    mapLang2Emb = {}

    #for (src_tgt_lang), data_path in zip(opt.src_tgt, opt.data):
    for index in range(len(opt.src_tgt)):
        src_tgt_lang = opt.src_tgt[index]
        data_path = opt.data[index]
        local_enc_dec_opts = AttrDict({key: model_opt.__dict__[key]
                                       for key in model_opt.__dict__.keys()})
        local_enc_dec_opts.model_type = update_to_local_attr(
            model_opt.model_type, index)
        #local_enc_dec_opts.audio_enc_pooling = model_opt.audio_enc_pooling[index]
        local_enc_dec_opts.audio_enc_pooling = update_to_local_attr(
            model_opt.audio_enc_pooling, index)
        local_enc_dec_opts.enc_layers = update_to_local_attr(
            model_opt.enc_layers, index)
        local_enc_dec_opts.dec_layers = update_to_local_attr(
            model_opt.dec_layers, index)
        local_enc_dec_opts.rnn_type = update_to_local_attr(
            model_opt.rnn_type, index)
        local_enc_dec_opts.encoder_type = update_to_local_attr(
            model_opt.encoder_type, index)
        local_enc_dec_opts.batch_size = update_to_local_attr(
            model_opt.batch_size, index)
        local_enc_dec_opts.batch_type = update_to_local_attr(
            model_opt.batch_type, index)
        local_enc_dec_opts.normalization = update_to_local_attr(
            model_opt.normalization, index)
        #local_enc_dec_opts.dec_rnn_size = model_opt.dec_rnn_size[index]

        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab, opt.model_type[0],
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build model.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)

        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)

        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all the encoders and decoders -
        # preprocess with share_vocab required.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        #TEST
        #if src_lang in mapLang2Emb:
        if src_lang in mapLang2Emb and model_opt.model_type == "text":
            encoder.embeddings.word_lut.weight = mapLang2Emb.get(src_lang)
        #TEST
        #else:
        elif model_opt.model_type == "text":
            mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb.get(tgt_lang)
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        #TEST
        if model_opt.model_type == "text":
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'valid', fields, data_path, local_enc_dec_opts, is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model.
    model = build_model(model_opt, opt, fields, encoders, decoders,
                        generators, src_vocabs, tgt_vocabs, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            generators, tgt_vocabs, model_saver=model_saver)

    # TODO: not implemented yet
    #train_iterables = []
    #if len(opt.data_ids) > 1:
    #    for train_id in opt.data_ids:
    #        shard_base = "train_" + train_id
    #        iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
    #        train_iterables.append(iterable)
    #    train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
    #else:
    #    train_iter = build_dataset_iter("train", fields, opt)

    #valid_iter = build_dataset_iter(
    #    "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
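# The per-dataset option handling above relies on two helpers that are not
# defined in this file. Below is a minimal sketch of plausible implementations,
# assuming ``AttrDict`` is a dict with attribute access and
# ``update_to_local_attr`` selects the index-th value of a per-dataset option
# list; the real project code may differ.
class AttrDict(dict):
    """Dict whose keys can also be read and written as attributes (sketch)."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value


def update_to_local_attr(attr, index):
    """Return the per-dataset value of an option (sketch, assumed behaviour)."""
    if isinstance(attr, (list, tuple)):
        # A single value may be shared across all datasets; otherwise pick
        # the entry that corresponds to this language pair.
        return attr[index] if len(attr) > 1 else attr[0]
    return attr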
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
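# Hedged usage note: this is the baseline single-dataset entry point; in
# OpenNMT-py it is normally reached through train.py, roughly as
#   python train.py -data data/demo -save_model demo-model -world_size 1 -gpu_ranks 0
# (data prefix and model name are illustrative).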
def main(opt, device_id, batch_queue=None, semaphore=None,
         train_iter=None, passed_fields=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver, aux_fields=aux_fields)

    if train_iter is not None:
        pass  # NOTE Use the passed one.
    elif batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`, so I use
            # 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train', fields, opt, is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid', fields, opt, is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [build_dataset_iter("valid", fields, opt, is_train=False)]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters,
                  valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
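# Hedged sketch of the producer side that pairs with the ``batch_queue`` /
# ``semaphore`` branch above (illustrative only; the real producer lives in the
# multi-GPU launcher, not in this file). The semaphore caps how many prepared
# batches can sit in the queue ahead of the consumer's ``_train_iter``.
def _batch_producer(iterator, batch_queue, semaphore):
    for batch in iterator:
        semaphore.acquire()      # wait for the trainer to free up a slot
        batch_queue.put(batch)   # hand the batch over; the consumer releases the slot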
def main(opt, device_id):
    import pickle
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # Load checkpoint if we resume from a previous training.
    import json
    train_iters = []
    with open(opt.data) as json_file:
        data = json.load(json_file)
    vocab = data["vocab"]
    vocab2 = vocab

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(vocab)

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    for key, value in data.items():
        if key == ("vocab"):
            continue
        elif key.startswith("valid"):
            valid_iter = (key.split("valid-")[1].split("-"),
                          build_dataset_iter(value, fields, opt, is_train=False))
        else:
            train_iters.append(
                (key.split("-"), build_dataset_iter(value, fields, opt)))

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.critic = critic()
    model.critic.to(model.device)
    if model.decoder2 is not None:
        model.critic2 = critic()
        model.critic2.to(model.device)
    else:
        model.critic2 = None
    model.critic3 = None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    vocab = torch.load(vocab2)
    #print (valid_iter is None)
    #valid_iter = valid_iter(datas[0][0],valid_iter)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iters,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps,
                  smooth=opt.smooth)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
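# Hedged example of the JSON file expected at ``opt.data`` (structure inferred
# from the parsing loop above; paths and language pairs are illustrative):
# {
#   "vocab": "data/multi.vocab.pt",
#   "en-de": "data/train.en-de",
#   "en-fr": "data/train.en-fr",
#   "valid-en-de": "data/valid.en-de"
# }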