def load_model(model_name, model_dir, model_config, gpu_id=-1):
    """Instantiate a model by name, loading saved weights where applicable.

    Args:
        model_name: one of 'ng', 'bert', 'bert_transe', 'bert_comp', 'lstm'.
        model_dir: directory searched via ``utils.get_target_model_dir`` for the
            checkpoint to load.
        model_config: path to a JSON file holding the model's config dict.
        gpu_id: CUDA device index; -1 keeps the model where it was created.

    Returns:
        The constructed model (moved to ``gpu_id`` when ``gpu_id != -1``).

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    # search models
    target_model_dir = utils.get_target_model_dir(model_dir)
    # load model config; the context manager closes the file handle promptly
    # (the previous bare ``open`` call leaked it).
    with open(model_config, 'r') as f:
        config = json.load(f)

    if model_name == 'bert':
        # This branch builds the model from the config only; no weights are
        # read from target_model_dir.
        model = BERTLinkPredict(**config)
    else:
        # Classes are selected lazily so only the requested one is referenced.
        if model_name == 'ng':
            model_cls = RGCNLinkPredict
        elif model_name == 'bert_transe':
            model_cls = BertEventTransE
        elif model_name == 'bert_comp':
            model_cls = BertEventComp
        elif model_name == 'lstm':
            model_cls = BertEventLSTM
        else:
            raise ValueError('not implemented model_name: {}'.format(model_name))
        logger.info('loading test model {}...'.format(target_model_dir))
        model = model_cls.from_pretrained(target_model_dir, config)

    if gpu_id != -1:
        torch.cuda.set_device(gpu_id)
        model.cuda(gpu_id)
    return model
def test(local_rank, model_dir, data_dir, logger, args):
    """Load the checkpoint found under ``model_dir`` and evaluate it on ``data_dir``.

    Args:
        local_rank: forwarded unchanged to ``evaluate``.
        model_dir: directory searched via ``utils.get_target_model_dir``.
        data_dir: test data directory passed to ``get_concat_dataset``.
        logger: logger used for progress messages.
        args: namespace; reads ``model_config``, ``gpu_id``,
            ``eval_batch_size`` and ``model_name``.

    Returns:
        Whatever ``evaluate`` returns when called with ``get_prec_recall_f1=True``.
    """
    # search models
    target_model_dir = utils.get_target_model_dir(model_dir)

    # load model
    logger.info('loading test model {}...'.format(target_model_dir))
    # Context manager closes the config file (the bare ``open`` leaked it).
    with open(args.model_config, 'r') as f:
        model_config = json.load(f)
    # NOTE(review): always loads RGCNLinkPredict, even though args.model_name
    # is passed to evaluate() below and load_model() dispatches on the model
    # name -- confirm this is intended.
    model = RGCNLinkPredict.from_pretrained(target_model_dir, model_config)
    if args.gpu_id != -1:
        torch.cuda.set_device(args.gpu_id)
        model.cuda(args.gpu_id)

    # load data
    logger.info('loading test dataset {}...'.format(data_dir))
    test_dataset = get_concat_dataset(data_dir, has_target_edges=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.eval_batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 collate_fn=test_collate,
                                 pin_memory=False)

    # test
    test_metric = evaluate(local_rank, model, test_dataloader, args.gpu_id,
                           model_name=args.model_name,
                           get_prec_recall_f1=True,
                           logger=logger)
    return test_metric
def get_optimizer(args, model, logger):
    """Create an AdamW optimizer with two weight-decay groups.

    Parameters whose names contain "bias" or "LayerNorm.weight" get a weight
    decay of 0.0; every other parameter uses ``args.weight_decay``.

    If ``args.from_checkpoint`` is set and the resolved checkpoint directory
    contains ``optimizer.pt``, that state is loaded (mapped to CPU).

    Args:
        args: namespace; reads ``weight_decay``, ``lr``, ``adam_epsilon`` and
            ``from_checkpoint``.
        model: object exposing ``named_parameters()``.
        logger: logger used when restoring optimizer state.

    Returns:
        The AdamW optimizer.
    """
    no_decay = ["bias", "LayerNorm.weight"]

    # Single pass; order within each group matches named_parameters() order.
    decayed_params, exempt_params = [], []
    for param_name, param in model.named_parameters():
        is_exempt = any(token in param_name for token in no_decay)
        (exempt_params if is_exempt else decayed_params).append(param)

    param_groups = [
        {"params": decayed_params, "weight_decay": args.weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]
    optimizer = AdamW(param_groups, lr=args.lr, eps=args.adam_epsilon)

    if not args.from_checkpoint:
        return optimizer

    checkpoint_dir = utils.get_target_model_dir(args.from_checkpoint)
    state_path = os.path.join(checkpoint_dir, 'optimizer.pt')
    if os.path.isfile(state_path):
        logger.info('loading optimizer from {}...'.format(state_path))
        optimizer.load_state_dict(torch.load(state_path, map_location='cpu'))
    return optimizer
def get_scheduler(args, n_instances, optimizer, logger):
    """Build a linear warmup/decay LR scheduler sized to the training run.

    The total number of optimization steps is
    ``n_instances // (train_batch_size * max(1, n_gpus))
    // gradient_accumulation_steps`` multiplied by ``n_epochs``.
    ``args.warmup_steps`` takes priority when positive; otherwise a positive
    ``args.warmup_portion`` is applied to the total; otherwise warmup is 0.

    If ``args.from_checkpoint`` is set and the resolved checkpoint directory
    contains ``scheduler.pt``, that state is loaded (mapped to CPU).

    Returns:
        The scheduler returned by ``get_linear_schedule_with_warmup``.
    """
    device_count = max(1, args.n_gpus)
    steps_per_epoch = (n_instances // (args.train_batch_size * device_count)
                       // args.gradient_accumulation_steps)
    t_total = steps_per_epoch * args.n_epochs

    if args.warmup_steps > 0:
        warmup_steps = args.warmup_steps
    elif args.warmup_portion > 0:
        warmup_steps = int(t_total * args.warmup_portion)
    else:
        warmup_steps = 0

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)
    logger.info(" Total optimization steps = %d", t_total)
    logger.info(" Warmup steps = %d", warmup_steps)

    if not args.from_checkpoint:
        return scheduler

    checkpoint_dir = utils.get_target_model_dir(args.from_checkpoint)
    state_path = os.path.join(checkpoint_dir, 'scheduler.pt')
    if os.path.isfile(state_path):
        logger.info('loading scheduler from {}...'.format(state_path))
        scheduler.load_state_dict(torch.load(state_path, map_location='cpu'))
    return scheduler
def get_init_model(args, logger):
    """Create the training model, optionally restoring weights from a checkpoint.

    Supported ``args.model_name`` values: 'ng', 'bert_transe', 'bert_comp'.
    For 'bert_transe' and 'bert_comp', ``model_config['n_negs']`` is set from
    ``args.n_neg_per_pos`` before construction.

    Args:
        args: namespace; reads ``model_name``, ``model_config`` (JSON path),
            ``from_checkpoint``, ``n_neg_per_pos``, ``gpu_id`` and ``n_gpus``.
        logger: logger used for progress messages.

    Returns:
        The model, moved to ``args.gpu_id`` when it is not -1, and wrapped in
        ``DistributedDataParallel`` when ``args.n_gpus > 1``.

    Raises:
        NotImplementedError: if ``args.model_name`` is not supported.
    """
    logger.info('model_name = {}'.format(args.model_name))
    # Context manager closes the config file (the bare ``open`` leaked it).
    with open(args.model_config) as f:
        model_config = json.load(f)

    if args.from_checkpoint:
        target_dir = utils.get_target_model_dir(args.from_checkpoint)
        logger.info('loading model from {}...'.format(target_dir))

    # Select the class lazily so only the requested one is referenced.
    if args.model_name == 'ng':
        model_cls = RGCNLinkPredict
    elif args.model_name == 'bert_transe':
        model_cls = BertEventTransE
    elif args.model_name == 'bert_comp':
        model_cls = BertEventComp
    else:
        raise NotImplementedError(
            'not implemented model_name: {}'.format(args.model_name))

    if args.model_name in ('bert_transe', 'bert_comp'):
        model_config['n_negs'] = args.n_neg_per_pos

    if args.from_checkpoint:
        model = model_cls.from_pretrained(target_dir, model_config)
    else:
        model = model_cls(**model_config)

    if args.gpu_id != -1:
        model.cuda(args.gpu_id)
    if args.n_gpus > 1:
        # find_unused_parameters=True tolerates parameters that receive no
        # gradient in a given forward pass.
        model = DistributedDataParallel(model,
                                        device_ids=[args.gpu_id],
                                        find_unused_parameters=True)
    return model