def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opt values, then overwrite them with the opts stored
        # in the checkpoint. This is useful when re-training a model after
        # adding a new option (one not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Load fields generated during the preprocess phase.
    fields = load_fields(opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
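# A minimal launcher for the main() above -- a sketch only, assuming the
# same OpenNMT-py-style `opts` module imported in this file
# (`opts.model_opts` is used above; `opts.train_opts`, which would register
# train_from, gpu_ranks, train_steps, etc., is an assumption, as is the
# script name).
if __name__ == "__main__":
    parser = configargparse.ArgumentParser(description='train.py')
    opts.model_opts(parser)
    opts.train_opts(parser)  # assumed helper; mirrors model_opts above
    opt = parser.parse_args()
    main(opt, device_id=-1)  # -1 runs on CPU; 0 would pick the first GPU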
def main(opt):
    if opt.gpu == 0:
        device_id = 0
    else:
        device_id = -1

    # dummy_parser = configargparse.ArgumentParser(description='reinforce.py')
    # opts.model_opts(dummy_parser)
    # dummy_opt = dummy_parser.parse_known_args([])[0]
    # # Build the model and get the checkpoint and fields.
    # fields, model = nmt_model.load_reinforce_model(opt, dummy_opt.__dict__)

    opt = training_opt_reinforcing(opt, device_id)
    init_logger(opt.log_file)
    logger.info("Input args: %r", opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opt values, then overwrite them with the opts stored
        # in the checkpoint. This is useful when re-training a model after
        # adding a new option (one not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Load fields generated during the preprocess phase.
    fields = load_fields(opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    optim = build_optim(model, opt, checkpoint)
    # Override whatever learning rate the checkpoint carried; REINFORCE
    # fine-tuning uses a small fixed rate here.
    optim.learning_rate = 1e-5

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    reinforcor = build_reinforcor(model, fields, opt,
                                  model_saver=model_saver, optim=optim)

    # out_file = codecs.open(opt.output, 'w+', 'utf-8')
    # X_train, X_valid, X_test, y_train, y_valid, y_test = \
    #     data_loader.test_mosei_emotion_data()
    # src_path = X_train
    # src_iter = make_text_iterator_from_file(src_path)  # (opt.src)
    # tgt_path = y_train
    # tgt_iter = make_text_iterator_from_file(tgt_path)

    def train_iter_fct():
        return build_dataset_iter(load_dataset("train", opt), fields, opt)

    # if opt.tgt is not None:
    #     tgt_iter = make_text_iterator_from_file(opt.tgt)
    # else:
    #     tgt_iter = None
    # reinforcor.reinforce(src_data_iter=src_iter,
    #                      tgt_data_iter=tgt_iter,
    #                      batch_size=opt.batch_size,
    #                      out_file=out_file)
    reinforcor.reinforce(train_iter_fct, opt.rein_steps)
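# For reference, the policy-gradient update a reinforce(...) loop of this
# kind typically performs -- a self-contained sketch, not the actual
# Reinforcor implementation; the function name, shapes, baseline, and
# reward source are all assumptions:
import torch
import torch.nn.functional as F

def reinforce_loss(logits, sampled_ids, rewards, baseline=0.0):
    """REINFORCE loss: -(reward - baseline) * log p(sampled sequence).

    logits:      (batch, seq_len, vocab) unnormalized decoder scores
    sampled_ids: (batch, seq_len) tokens sampled from the model
    rewards:     (batch,) sequence-level rewards (e.g. BLEU, accuracy)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability of each sampled token, then of the whole sequence.
    tok_logp = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    seq_logp = tok_logp.sum(dim=1)
    # Scale by the (baselined) reward and average over the batch.
    return -((rewards - baseline) * seq_logp).mean()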
def main(opt, device_id):
    # device_id = -1  # initialize the GPU
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # Load all tensors onto the CPU.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opt values, then overwrite them with the opts stored
        # in the checkpoint. This is useful when re-training a model after
        # adding a new option (one not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        # parse_known_args returns two values; the first has the same type
        # as the return value of parse_args().
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        # Merge the options already present in opt into the new option
        # list. In other words, options can only be added, never deleted
        # or modified. If so, is opt even needed later on?
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        # First run, no checkpoint to load.
        checkpoint = None
        model_opt = opt

    # Load fields generated during the preprocess phase:
    # {"src": Field, "tgt": Field, "indices": Field}
    # The key attribute of a Field is vocab, which holds freqs, itos, stoi.
    # freqs stores word frequencies, excluding special tokens.
    # src: stoi contains <unk> and <blank>, but not <s> or </s>.
    # tgt: stoi contains <unk>, <blank>, <s>, and </s>.
    # <unk> = 0, <blank> (pad) = 1
    fields = load_fields(opt, checkpoint)

    # Build model.
    # On the first run the opt argument should not be needed;
    # model_opt could be used instead.
    model = build_model(model_opt, opt, fields, checkpoint)
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(name)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    # Create the model save directory if it does not exist.
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # Print all trainable model parameters.
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(param)

    def train_iter_fct():
        return build_dataset_iter(
            load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
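# To illustrate the field/vocab layout described in the comments above, a
# sketch of how the loaded fields can be inspected. The special-token ids
# follow the torchtext convention noted there (<unk> = 0, <blank>/pad = 1);
# the helper name and the exact token strings are assumptions:
def describe_fields(fields):
    for name in ("src", "tgt"):
        vocab = fields[name].vocab
        print(name, "vocab size:", len(vocab.itos))
        print("  <unk> id:", vocab.stoi['<unk>'])            # expected 0
        print("  <blank> (pad) id:", vocab.stoi['<blank>'])  # expected 1
        # freqs is a collections.Counter over training tokens,
        # excluding the special tokens.
        print("  top tokens:", [w for w, _ in vocab.freqs.most_common(5)])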