Example #1
0
def test(args):
    """Evaluate a saved ComplEx or DistMult model on a triplet test file.

    Loads the entity/relation vocabularies and the test triplets, picks the
    model class from ``args.method``, optionally loads the whole graph for
    filtered evaluation, then prints every metric returned by
    ``Evaluator.run_all_matric`` in sorted key order.

    Raises:
        NotImplementedError: if ``args.method`` is neither 'complex' nor
            'distmult'.
    """
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)

    print('loading model...')
    method = args.method
    if method == 'complex':
        from models.complex import ComplEx as Model
    elif method == 'distmult':
        from models.distmult import DistMult as Model
    else:
        raise NotImplementedError

    # The whole graph is only needed (and only loaded) in filtered mode.
    whole_graph = None
    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)

    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    results = evaluator.run_all_matric(model, test_dat)
    # Dict keys are unique, so sorting items orders by metric name only.
    for name, value in sorted(results.items()):
        print('{:20s}: {}'.format(name, value))
def test(args):
    """Evaluate a saved embedding model on a triplet test file.

    Supported ``args.method`` values: complex, distmult, transe, hole,
    rescal, analogy, randwalk, lr. For 'randwalk' the word-vector model at
    ``args.wv_model`` is attached after loading. Every metric returned by
    ``Evaluator.run_all_matric`` is printed in sorted key order.

    Raises:
        NotImplementedError: for an unrecognised ``args.method``.
    """
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    # The auxiliary graph is currently disabled; it was previously loaded with
    # GraphDataset.load(args.knowledge, ent_vocab, rel_vocab).
    graph = None

    print('loading model...')
    method = args.method
    if method == 'complex':
        from models.complex import ComplEx as Model
    elif method == 'distmult':
        from models.distmult import DistMult as Model
    elif method == 'transe':
        from models.transe import TransE as Model
    elif method == 'hole':
        from models.hole import HolE as Model
    elif method == 'rescal':
        from models.rescal import RESCAL as Model
    elif method == 'analogy':
        from models.analogy import ANALOGY as Model
    elif method == 'randwalk':
        from models.randwalk import RandWalk as Model
    elif method == 'lr':
        from models.lr import LogisticReg as Model
    else:
        raise NotImplementedError

    whole_graph = None
    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)

    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    # RandWalk needs its word-vector model attached after deserialisation.
    if method == 'randwalk':
        model.load_wv_model(args.wv_model)

    results = evaluator.run_all_matric(model, test_dat, graph)
    for name, value in sorted(results.items()):
        print('{:20s}: {}'.format(name, value))
Example #3
0
def test(args):
    """Evaluate a saved model on the 'kbc' or 'tc' task.

    ``args.task`` selects the dataset reader ('kbc' -> TripletDataset,
    'tc' -> LabeledTripletDataset). When ``args.metric == 'all'`` every
    metric from ``Evaluator.run_all_matric`` is printed in sorted order;
    otherwise only ``args.metric`` is computed with ``Evaluator.run``
    (unfiltered, no graph).

    Raises:
        ValueError: for an unknown ``args.task``.
        NotImplementedError: for an unsupported ``args.method``.
    """
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    task = args.task
    if task == 'kbc':
        dataset_cls = TripletDataset
    elif task == 'tc':
        dataset_cls = LabeledTripletDataset
    else:
        raise ValueError('Invalid task: {}'.format(task))
    test_dat = dataset_cls.load(args.data, ent_vocab, rel_vocab)

    print('loading model...')
    method = args.method
    if method == 'transe':
        from models.transe import TransE as Model
    elif method == 'complex':
        from models.complex import ComplEx as Model
    elif method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError

    graphall = None
    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        # Alternative (pre-serialised graph): TensorTypeGraph.load(args.graphall)
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                 rel_vocab)

    model = Model.load_model(args.model)

    if args.metric != 'all':
        # Single-metric evaluation never uses filtering or the whole graph.
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print('{:20s}: {}'.format(args.metric, res))
        return

    evaluator = Evaluator('all', None, args.filtered, False, graphall)
    if args.filtered:
        evaluator.prepare_valid(test_dat)

    results = evaluator.run_all_matric(model, test_dat)
    for name, value in sorted(results.items()):
        print('{:20s}: {}'.format(name, value))
Example #4
0
def path_analysis(args):
    """Relate path-query length to the number of entities reached at its end.

    For every ``(sub, rels, _)`` sample in the path-query dataset, starts from
    the set ``{sub}`` and, for each relation ``r`` in ``rels``, replaces the
    current set with the union of ``g.walk(e, r)`` over its members. Records
    ``len(rels)`` and the size of the final set, prints both lists, and prints
    their correlation coefficient taken from ``np.corrcoef``.

    NOTE(review): the statements from ``args = p.parse_args()`` onward look
    like the body of a separate evaluation script pasted into this function.
    ``p`` is never defined here, so reaching that line raises NameError.
    Confirm and move that code back to its own entry point.
    """
    ent_vocab = Vocab.load(args.entity)
    # inv_flg=True presumably adds inverse relations to the vocab/graph -- confirm in RelationVocab
    rel_vocab = RelationVocab.load(args.relation, inv_flg=True)
    triple_dat = TripletDataset.load(args.triple, ent_vocab, rel_vocab)
    pq_dat = PathQueryDataset.load(args.query, ent_vocab, rel_vocab)
    g = LabeledDiGraph(triple_dat, inv_flg=True)

    # traverse each path query
    n_rel = []
    n_tail = []
    for (sub, rels, _) in pq_dat.samples:
        # frontier: entities reachable after following the relations seen so far
        cur_ents = set([sub])
        for r in rels:
            next_ents = set()
            for e in cur_ents:
                new_ents = g.walk(e, r)
                next_ents.update(new_ents)
            cur_ents = next_ents
        n_rel.append(len(rels))
        n_tail.append(len(cur_ents))
    print(n_rel)
    print(n_tail)
    print('Correlation Coefficient: {}'.format(
        np.corrcoef(n_rel, n_tail)[0, 1]))
    # NOTE(review): `p` is not defined in this function (see docstring).
    args = p.parse_args()

    assert args.task in ['kbc', 'pq'], 'Invalid task: {}'.format(args.task)
    assert args.metric in ['mrr',
                           'hits'], 'Invalid metric: {}'.format(args.metric)
    if args.metric == 'hits':
        assert args.nbest, 'Please indecate n-best in using hits'

    model = GaussianBilinearModel.load_model(args.model)

    print('Preparing dataset...')
    if args.task == 'kbc':
        ent_vocab = Vocab.load(args.entity)
        rel_vocab = Vocab.load(args.relation)
        dataset = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    elif args.task == 'pq':
        ent_vocab = Vocab.load(args.entity)
        rel_vocab = RelationVocab.load(args.relation, inv_flg=True)
        dataset = PathQueryDataset.load(args.data, ent_vocab, rel_vocab)
        # 'pq' loads relations with inv_flg=True; if the model was not built
        # with inverse relations, initialise their representations now.
        if not hasattr(model, 'inv_flg') or not model.inv_flg:
            print('initializing inverse relation representations...')
            model.init_inverse()

    print('Start evaluation...')
    # NOTE(review): only 'mrr' is handled; 'hits' passes the checks above but
    # currently produces no output.
    if args.metric == 'mrr':
        from evaluation import mrr
        # res = mrr.cal_mrr(model, dataset)
        res = mrr.multi_cal_mrr(model, dataset)
        print('MRR: {}'.format(res))
Example #6
0
def train(args):
    """Train a sparse ComplEx model with an Adagrad-RDA optimiser.

    Sets up logging into ``args.log`` (created if missing), loads the
    vocabularies and data for the 'kbc' or 'tc' task, builds an
    ``AdagradRDA``/``AdagradRDAmul`` optimiser whose total regularisation
    ``args.reg`` is split into an L1 part (``args.reg * args.l1_ratio``) and
    an L2 part (``args.reg * (1 - args.l1_ratio)``), builds a ComplEx model
    and fits it with ``SingleTrainer``.

    Raises:
        ValueError: for an unknown ``args.task``.
        NotImplementedError: for an unsupported optimiser, method or mode
            ('pairwise' mode is not available in this script).
        AssertionError: if ``args.l1_ratio`` is outside [0, 1], or if the
            'tc' task is used with a metric other than 'acc'.
    """
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    if args.task == 'kbc':
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    elif args.task == 'tc':
        assert args.metric == 'acc'
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = LabeledTripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    assert 0 <= args.l1_ratio <= 1.0
    # Split the total regularisation strength between the L1 and L2 terms.
    l1_reg = args.reg * args.l1_ratio
    l2_reg = args.reg * (1 - args.l1_ratio)
    if args.l1_ratio == 0:
        logger.warning("===== WARNING : l1_ratio has zero value. not inducing sparsity =====")
    if args.opt == 'adarda':
        opt = AdagradRDA(args.lr, l1_reg)
    elif args.opt == 'adardamul':
        opt = AdagradRDAmul(args.lr, l1_reg)
    else:
        raise NotImplementedError

    if l2_reg > 0:
        opt.set_l2_reg(l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    else:
        raise NotImplementedError

    if args.filtered:
        # Use the logger like the rest of this function so the message also
        # reaches the log file.
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        graphall = None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered, True, graphall) if args.valid else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        # Pairwise training is deliberately disabled in this script; the
        # PairwiseTrainer construction that used to follow this raise was
        # unreachable and has been removed.
        raise NotImplementedError('pairwise mode is not implemented in this script')
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step,
                                batchsize=args.batch, logger=logger,
                                evaluator=evaluator, valid_dat=valid_dat,
                                n_negative=args.negative, epoch=args.epoch,
                                model_dir=args.log, restart=args.restart,
                                add_re=args.add_re)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)
Example #7
0
def train(args):
    """Continue training a model deserialised with dill from ``args.load``.

    Sets up logging into ``args.log`` (created if missing), loads the
    vocabularies and train/valid triplets, optionally loads the whole graph
    for filtered validation, builds an SGD or Adagrad optimiser, and fits the
    loaded model with a pairwise or single trainer chosen by ``args.mode``.

    Raises:
        NotImplementedError: for an unsupported optimiser or mode.
    """
    # setting for logging
    log_dir = args.log
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    file_handler = logging.FileHandler(os.path.join(log_dir, 'log'))
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for name, value in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(name, value))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = None
    if args.valid:
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab)

    whole_graph = None
    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('loading model...')
    # NOTE(review): dill (like pickle) can execute arbitrary code while
    # loading; only point args.load at trusted files.
    with open(args.load, 'rb') as model_file:
        model = dill.load(model_file)

    evaluator = None
    if args.valid:
        evaluator = Evaluator(args.metric, args.nbest, args.filtered,
                              whole_graph)
        if args.filtered:
            evaluator.prepare_valid(valid_dat)

    if args.mode == 'pairwise':
        trainer_cls = PairwiseTrainer
    elif args.mode == 'single':
        trainer_cls = SingleTrainer
    else:
        raise NotImplementedError
    trainer = trainer_cls(model=model,
                          opt=opt,
                          save_step=args.save_step,
                          batchsize=args.batch,
                          logger=logger,
                          evaluator=evaluator,
                          valid_dat=valid_dat,
                          n_negative=args.negative,
                          epoch=args.epoch,
                          model_dir=log_dir)

    trainer.fit(train_dat)

    logger.info('done all')
def train(args):
    """Train a GaussianBilinearModel sample by sample with negative sampling.

    Logs to ``args.log`` (or to a timestamped directory next to this file),
    loads vocabularies and training triplets, then either restarts from
    ``args.restart`` or builds a new model with an SGD optimiser. Each epoch,
    every positive sample is paired with ``args.num_negative`` negatives made
    by replacing its third element with a random entity id (presumably the
    tail), and ``model.update`` is called per pair. A checkpoint
    ``model<epoch>`` is saved every epoch. When ``args.valid`` is set, the
    model is scored every ``args.evalstep`` epochs and the best one is saved
    as ``bestmodel``.

    Raises:
        NotImplementedError: if entity/relation vocab paths are not given.
        AssertionError: if validation is requested with a missing or
            unsupported metric, or 'hits' without ``args.nbest``.
    """
    if args.log:
        log_dir = args.log
    else:
        log_dir = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            '{}'.format(datetime.now().strftime('%Y%m%d_%H:%M')))

    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    # setting for logging
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(log_dir, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{} : {}'.format(arg, val))

    logger.info('Preparing dataset...')
    if not args.entity or not args.relation:
        # make vocab from train set
        logger.info('Making entity/relation vocab from train data...')
        raise NotImplementedError()
    ent_vocab = Vocab.load(args.entity)
    rel_vocab = Vocab.load(args.relation)

    n_entity, n_relation = len(ent_vocab), len(rel_vocab)
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    logger.info('')
    valid_dat = None
    if args.valid:
        # Check presence first so a missing metric gets the specific message
        # instead of "Invalid evaluation metric: None".
        assert args.metric, 'Please indecate evaluation metric for validation'
        assert args.metric in ['mrr',
                               'hits'], 'Invalid evaluation metric: {}'.format(
                                   args.metric)
        if args.metric == 'hits':
            assert args.nbest, 'Please indecate nbest for hits'
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab)

    if args.restart:
        logger.info('Restarting training: {}'.format(args.restart))
        model = GaussianBilinearModel.load_model(args.restart)
    else:
        logger.info('Building new model')
        opt = SGD(args.lr, args.gradclip)
        model = GaussianBilinearModel(n_entity, n_relation, args.dim,
                                      args.cmin, args.cmax, opt, args.tri,
                                      args.init_sigma)

    # Initialise all best-model bookkeeping so the final save cannot hit an
    # unbound name when no validation round ran or none improved.
    best_model = None
    best_val = -1
    best_epoch = None
    for epoch in range(args.epoch):
        logger.info('start {} epoch'.format(epoch + 1))
        sum_loss = 0
        start = time.time()
        for i, pos_sample in enumerate(data_iter(train_dat)):
            neg_samples = [(pos_sample[0], pos_sample[1],
                            np.random.randint(n_entity))
                           for _ in range(args.num_negative)]
            for neg_sample in neg_samples:
                loss = model.update(pos_sample, neg_sample)
                sum_loss += loss
            print('processing {} samples in this epoch'.format(i + 1))
        logger.info('sum loss: {}'.format(sum_loss))
        logger.info('{} sec/epoch for training'.format(time.time() - start))
        model_path = os.path.join(log_dir, 'model{}'.format(epoch + 1))
        model.save_model(model_path)
        if args.valid and (epoch + 1) % args.evalstep == 0:
            val = evaluation(model, valid_dat, args.metric, args.nbest)
            logger.info('{} in validation: {}'.format(args.metric, val))
            if val > best_val:
                best_model = copy.deepcopy(model)
                best_val = val
                best_epoch = epoch + 1

    if args.valid:
        if best_model is None:
            # Happens when args.epoch < args.evalstep or no score beat -1.
            logger.warning(
                'no model was selected in validation (epoch={}, evalstep={}); '
                'bestmodel not saved'.format(args.epoch, args.evalstep))
        else:
            logger.info('best model is {} epoch'.format(best_epoch))
            model_path = os.path.join(log_dir, 'bestmodel')
            best_model.save_model(model_path)

    logger.info('done all')
Example #9
0
        CAUTION : this relation sampling method is NOT UNIFORM
        """
        _batchsize = len(pos_triplets)
        sample_size = _batchsize * n_negative
        neg_rel_ents = self.sample(sample_size)
        neg_triplets = np.tile(pos_triplets, (n_negative, 1))
        head_rel_tail = np.random.randint(0, 3, sample_size)
        rel_idxs = np.where(head_rel_tail == 1)
        # TODO: fix to sample uniformly
        neg_rel_ents[rel_idxs] = neg_rel_ents[rel_idxs] % self.n_rel
        neg_triplets[np.arange(sample_size), head_rel_tail] = neg_rel_ents
        return neg_triplets

    def sample(self, size):
        """Return ``size`` integers drawn uniformly from ``[0, self.n_ent)``."""
        upper = self.n_ent
        return np.random.randint(0, upper, size=size)


if __name__ == '__main__':
    from collections import defaultdict
    from utils.dataset import TripletDataset

    # Tiny toy dataset of three triplets to exercise the sampler.
    toy_data = TripletDataset([[0, 0, 1], [0, 0, 2], [1, 0, 1]])
    sampler = UnigramIntSampler.build(toy_data, 3)

    # Draw five batches of ten ids, print each batch, and tally the ids.
    counts = defaultdict(lambda: 0)
    for _ in range(5):
        batch = sampler.sample(10)
        print(batch)
        for drawn_id in batch:
            counts[drawn_id] += 1
    print(counts)
Example #10
0
File: train.py Project: shaoyx/kbc
def train(args):
    """Build a fresh embedding model and train it.

    Sets up logging into ``args.log`` (created if missing), loads vocabularies
    and train/valid triplets, optionally loads the whole graph for filtered
    validation, builds an SGD / Adagrad / DecaySGD optimiser, instantiates the
    model named by ``args.method`` and fits it with a pairwise or single
    trainer chosen by ``args.mode``.

    Raises:
        NotImplementedError: for an unsupported optimiser, method or mode.
    """
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    file_handler = logging.FileHandler(os.path.join(args.log, 'log'))
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for name, value in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(name, value))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = None
    if args.valid:
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab)

    whole_graph = None
    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    elif args.opt == 'dsgd':
        opt = DecaySGD(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    method = args.method
    extra_kwargs = {}
    if method == 'complex':
        from models.complex import ComplEx as model_cls
    elif method == 'distmult':
        from models.distmult import DistMult as model_cls
    elif method == 'transe':
        from models.transe import TransE as model_cls
    elif method == 'hole':
        from models.hole import HolE as model_cls
    elif method == 'rescal':
        from models.rescal import RESCAL as model_cls
    elif method == 'analogy':
        from models.analogy import ANALOGY as model_cls
        extra_kwargs['cp_ratio'] = args.cp_ratio
    elif method == 'transe_set':
        from models.transe_set import TransE_set as model_cls
    elif method == 'line':
        from models.line_model import LineModel as model_cls
    else:
        raise NotImplementedError
    # All supported models share this constructor signature; ANALOGY also
    # takes cp_ratio.
    model = model_cls(n_entity=n_entity,
                      n_relation=n_relation,
                      margin=args.margin,
                      dim=args.dim,
                      mode=args.mode,
                      **extra_kwargs)

    evaluator = None
    if args.valid:
        evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph)
        if args.filtered:
            evaluator.prepare_valid(valid_dat)

    if args.mode == 'pairwise':
        trainer_cls = PairwiseTrainer
    elif args.mode == 'single':
        trainer_cls = SingleTrainer
    else:
        raise NotImplementedError
    trainer = trainer_cls(model=model, opt=opt, save_step=args.save_step,
                          batchsize=args.batch, logger=logger,
                          evaluator=evaluator, valid_dat=valid_dat,
                          n_negative=args.negative, epoch=args.epoch,
                          model_dir=args.log)

    trainer.fit(train_dat)

    logger.info('done all')
def train(args):
    """Build and train an embedding model, or run a self-training model.

    Sets up logging into ``args.log`` (created if missing), loads
    vocabularies and train/valid triplets, optionally loads the whole graph
    for filtered validation, and builds an SGD or Adagrad optimiser.

    The 'randwalk' and 'lr' methods train themselves, save the model into
    ``args.log`` and return early. Every other supported method is fitted
    with a pairwise or single trainer chosen by ``args.mode``.

    Raises:
        NotImplementedError: for an unsupported optimiser, method or mode.
    """
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    # Sorted for a stable, readable argument dump.
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab,
                                    rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    method = args.method

    # Self-training methods: train, save and return without a Trainer.
    if method == 'randwalk':
        from models.randwalk import RandWalk
        logger.info(
            'using random walk model to learning embedding unsupervisedly.')
        model = RandWalk(n_entity=n_entity,
                         n_relation=n_relation,
                         knowledge_path=args.train,
                         ent_vocab=ent_vocab,
                         rel_vocab=rel_vocab,
                         dim=args.dim,
                         output=args.log)
        model.train()
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return
    if method == "lr":
        from models.lr import LogisticReg
        model = LogisticReg(n_entity=n_entity,
                            n_relation=n_relation,
                            train_path=args.train,
                            ent_vocab=ent_vocab,
                            rel_vocab=rel_vocab,
                            dim=args.dim,
                            output=args.log,
                            wv_model_path=args.wv_model,
                            negative=args.negative,
                            feat_type=args.feat_type)
        starttime = time()
        if args.mode == "triplet_cls":
            logger.info("Training a triple classifer")
            model.train_triple_classifer()
        else:
            model.train()
        endtime = time()
        logger.info("lr model train time {:.6f}".format(endtime - starttime))
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return

    # Trainer-driven methods share one constructor signature; ANALOGY also
    # takes cp_ratio.
    extra_kwargs = {}
    if method == 'complex':
        from models.complex import ComplEx as model_cls
    elif method == 'distmult':
        from models.distmult import DistMult as model_cls
    elif method == 'transe':
        from models.transe import TransE as model_cls
    elif method == 'hole':
        from models.hole import HolE as model_cls
    elif method == 'rescal':
        from models.rescal import RESCAL as model_cls
    elif method == 'analogy':
        from models.analogy import ANALOGY as model_cls
        extra_kwargs['cp_ratio'] = args.cp_ratio
    else:
        raise NotImplementedError
    model = model_cls(n_entity=n_entity,
                      n_relation=n_relation,
                      margin=args.margin,
                      dim=args.dim,
                      mode=args.mode,
                      **extra_kwargs)

    # args.synthetic may not be defined by every argument parser; treat a
    # missing attribute as False instead of raising AttributeError whenever
    # args.valid is falsy.
    use_evaluator = args.valid or getattr(args, 'synthetic', False)
    evaluator = Evaluator(
        args.metric, args.nbest, args.filtered,
        whole_graph) if use_evaluator else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)

    if args.mode == 'pairwise':
        trainer_cls = PairwiseTrainer
    elif args.mode == 'single':
        trainer_cls = SingleTrainer
    else:
        raise NotImplementedError
    trainer = trainer_cls(model=model,
                          opt=opt,
                          save_step=args.save_step,
                          batchsize=args.batch,
                          logger=logger,
                          evaluator=evaluator,
                          valid_dat=valid_dat,
                          n_negative=args.negative,
                          epoch=args.epoch,
                          model_dir=args.log)

    trainer.fit(train_dat)

    logger.info('done all')