def load_model(model_name, model_dir, model_config, gpu_id=-1):
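    """Load a link-prediction model by name for evaluation.

    `model_config` is the path to a JSON config file; `gpu_id=-1` keeps the
    model on CPU. All variants except 'bert' restore pretrained weights from
    the checkpoint directory resolved by utils.get_target_model_dir.
    """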
    # locate the target checkpoint directory under model_dir
    target_model_dir = utils.get_target_model_dir(model_dir)

    # load model
    with open(model_config, 'r') as f:
        model_config = json.load(f)
    if model_name == 'ng':
        logger.info('loading test model {}...'.format(target_model_dir))
        model = RGCNLinkPredict.from_pretrained(target_model_dir, model_config)
    elif model_name == 'bert':
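        # note: the 'bert' variant is built fresh from the config; it does
        # not restore pretrained weights from target_model_dir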
        model = BERTLinkPredict(**model_config)
    elif model_name == 'bert_transe':
        logger.info('loading test model {}...'.format(target_model_dir))
        model = BertEventTransE.from_pretrained(target_model_dir, model_config)
    elif model_name == 'bert_comp':
        logger.info('loading test model {}...'.format(target_model_dir))
        model = BertEventComp.from_pretrained(target_model_dir, model_config)
    elif model_name == 'lstm':
        logger.info('loading test model {}...'.format(target_model_dir))
        model = BertEventLSTM.from_pretrained(target_model_dir, model_config)
    else:
        raise ValueError('unsupported model_name: {}'.format(model_name))

    if gpu_id != -1:
        torch.cuda.set_device(gpu_id)
        model.cuda(gpu_id)
    return model


def test(local_rank, model_dir, data_dir, logger, args):
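    """Load the trained model and evaluate it on the test split."""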
    # load model (via load_model above, so the checkpoint is restored with
    # the class matching args.model_name rather than hardcoding
    # RGCNLinkPredict)
    model = load_model(args.model_name,
                       model_dir,
                       args.model_config,
                       gpu_id=args.gpu_id)

    # load data
    logger.info('loading test dataset {}...'.format(data_dir))
    test_dataset = get_concat_dataset(data_dir, has_target_edges=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.eval_batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 collate_fn=test_collate,
                                 pin_memory=False)

    # test
    test_metric = evaluate(local_rank,
                           model,
                           test_dataloader,
                           args.gpu_id,
                           model_name=args.model_name,
                           get_prec_recall_f1=True,
                           logger=logger)
    return test_metric


def get_optimizer(args, model, logger):
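    """Build an AdamW optimizer, excluding bias and LayerNorm weights from
    weight decay, and restore optimizer state from a checkpoint if present.
    """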
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.lr,
                      eps=args.adam_epsilon)
    if args.from_checkpoint:
        target_dir = utils.get_target_model_dir(args.from_checkpoint)
        fpath = os.path.join(target_dir, 'optimizer.pt')
        if os.path.isfile(fpath):
            logger.info('loading optimizer from {}...'.format(fpath))
            optimizer.load_state_dict(torch.load(fpath, map_location='cpu'))
    return optimizer


def get_scheduler(args, n_instances, optimizer, logger):
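    """Build a linear warmup/decay LR scheduler sized to the training run."""
    # t_total: optimizer steps per epoch (instances / effective batch size /
    # gradient accumulation) times the number of epochs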
    t_total = (n_instances // (args.train_batch_size * max(1, args.n_gpus)) //
               args.gradient_accumulation_steps) * args.n_epochs
    warmup_steps = 0
    if args.warmup_steps > 0:
        warmup_steps = args.warmup_steps
    elif args.warmup_portion > 0:
        warmup_steps = int(t_total * args.warmup_portion)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Warmup steps = %d", warmup_steps)
    if args.from_checkpoint:
        target_dir = utils.get_target_model_dir(args.from_checkpoint)
        fpath = os.path.join(target_dir, 'scheduler.pt')
        if os.path.isfile(fpath):
            logger.info('loading scheduler from {}...'.format(fpath))
            scheduler.load_state_dict(torch.load(fpath, map_location='cpu'))
    return scheduler


def get_init_model(args, logger):
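    """Instantiate the model for training, optionally restoring weights from
    a checkpoint, then move it to GPU and wrap it in DistributedDataParallel
    for multi-GPU runs."""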
    logger.info('model_name = {}'.format(args.model_name))
    with open(args.model_config, 'r') as f:
        model_config = json.load(f)
    # bert_transe / bert_comp need the number of negatives per positive
    # in their config
    if args.model_name in ('bert_transe', 'bert_comp'):
        model_config['n_negs'] = args.n_neg_per_pos
    if args.from_checkpoint:
        target_dir = utils.get_target_model_dir(args.from_checkpoint)
        logger.info('loading model from {}...'.format(target_dir))
        if args.model_name == 'ng':
            model = RGCNLinkPredict.from_pretrained(target_dir, model_config)
        elif args.model_name == 'bert_transe':
            model = BertEventTransE.from_pretrained(target_dir, model_config)
        elif args.model_name == 'bert_comp':
            model = BertEventComp.from_pretrained(target_dir, model_config)
        else:
            raise NotImplementedError(args.model_name)
    else:
        if args.model_name == 'ng':
            model = RGCNLinkPredict(**model_config)
        elif args.model_name == 'bert_transe':
            model = BertEventTransE(**model_config)
        elif args.model_name == 'bert_comp':
            model = BertEventComp(**model_config)
        else:
            raise NotImplementedError(args.model_name)
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)
    if args.n_gpus > 1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.gpu_id],
                                        find_unused_parameters=True)
    return model