def train(args, config):
    # load data
    print("Loading data")
    train_dataset = data.get_loader(config.train_file, config.batch_size, config.version)
    dev_dataset = data.get_loader(config.dev_file, config.val_batch_size, config.version)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Loading embeddings")
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)

    # create model
    print("Creating model")
    if config.model_type == "model6":
        Model = model_utils.get_model_func(config.model_type, config.version)
        # load the pretrained QANet checkpoint given in the config; its keys carry
        # the "module." prefix added by DataParallel, so strip the first 7 characters
        model_dict_ = torch.load(config.pretrained_model, pickle_module=dill)["model"]
        model_dict = {}
        for key in model_dict_:
            model_dict[key[7:]] = model_dict_[key]
        from models import QANet
        pretrained_model = QANet(word_mat, char_mat, config)
        # load its state
        model_data = pretrained_model.state_dict()
        model_data.update(model_dict)
        pretrained_model.load_state_dict(model_data)
        model = Model(pretrained_model, config).to(config.device)
        model = torch.nn.DataParallel(model)
    else:
        Model = model_utils.get_model_func(config.model_type, config.version)
        model = Model(word_mat, char_mat, config).to(config.device)
        model = torch.nn.DataParallel(model)

    print("Training model")
    trainer = Trainer(model, train_dataset, dev_dataset, dev_eval_file, config)
    if args.model_file is not None:
        trainer.load(args.model_file)
        trainer.ema.resume(trainer.model)
    trainer.train()
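# A minimal, self-contained sketch of the "module." prefix handling used in train()
# above. Checkpoints saved from a torch.nn.DataParallel-wrapped model store every
# parameter name with a "module." prefix (7 characters), which is why the code
# slices key[7:] before loading the weights into an unwrapped QANet. This helper
# and the example names below are illustrative only, not part of this repository.
def strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the DataParallel prefix removed from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Example:
#   strip_data_parallel_prefix({"module.linear.weight": w, "module.linear.bias": b})
#   -> {"linear.weight": w, "linear.bias": b}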
def train_entry(config):
    from models import QANet

    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    lr = config.learning_rate
    base_lr = 1
    lr_warm_up_num = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)
    if torch.cuda.device_count() > 1:
        print("Using multiple GPUs")
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    model.load_state_dict(
        torch.load('/home/cn/AI/QANet-pytorch-/model_state_dict.pt'))

    # register an exponential moving average of all trainable parameters
    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr, betas=(0.9, 0.999), eps=1e-7,
                           weight_decay=5e-8, params=parameters)
    # warm the learning rate up logarithmically for the first lr_warm_up_num steps
    cr = lr / math.log2(lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1) if ee < lr_warm_up_num else lr)

    best_f1 = 0
    best_em = 0
    patience = 0
    for epoch in range(config.num_epoch):
        train(model, optimizer, scheduler, train_dataset, dev_dataset,
              dev_eval_file, epoch, ema)
        print(epoch)
        # evaluate with the EMA weights swapped in
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file,
                       (epoch + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # early stopping: count epochs where neither dev F1 nor EM improves
        if dev_f1 < best_f1 and dev_em < best_em:
            patience += 1
            if patience > config.early_stop:
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        fn = os.path.join(config.save_dir, "model.pt")
        torch.save(model, fn)
        torch.save(model.state_dict(), 'model_state_dict.pt')
        # restore the raw training weights before the next epoch
        ema.resume(model)
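# train_entry() relies on an EMA helper exposing register(), assign(), and resume().
# The class below is a sketch of one common way such a helper is implemented (a
# shadow copy of each parameter maintained as an exponential moving average); the
# actual EMA class used by this repository may differ in detail.
class EMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}    # EMA value of each registered parameter
        self.original = {}  # backup of the raw weights while EMA weights are assigned

    def register(self, name, param):
        # start the moving average from the current parameter value
        self.shadow[name] = param.clone()

    def update(self, name, param):
        # shadow = decay * shadow + (1 - decay) * param
        assert name in self.shadow
        new_average = self.decay * self.shadow[name] + (1.0 - self.decay) * param
        self.shadow[name] = new_average.clone()

    def assign(self, model):
        # swap EMA weights in for evaluation, keeping a backup of the raw weights
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.original[name] = param.data.clone()
                param.data = self.shadow[name]

    def resume(self, model):
        # restore the raw training weights after evaluation
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.original[name]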