Example No. 1
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)

    print('loading model...')
    if args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'distmult':
        from models.distmult import DistMult as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None
    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    all_res = evaluator.run_all_matric(model, test_dat)
    for metric in sorted(all_res.keys()):
        print('{:20s}: {}'.format(metric, all_res[metric]))
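The function assumes an args namespace carrying file paths and flags. A minimal argparse front end consistent with the attributes read above (the flag names and defaults are assumptions, not the project's actual CLI) might look like this:

import argparse

# Hypothetical CLI sketch; flag names mirror the attributes read by test(args).
def parse_test_args():
    p = argparse.ArgumentParser(description='evaluate a trained KBC model')
    p.add_argument('--ent', required=True, help='entity vocabulary file')
    p.add_argument('--rel', required=True, help='relation vocabulary file')
    p.add_argument('--data', required=True, help='test triplet file')
    p.add_argument('--method', default='complex', choices=['complex', 'distmult'])
    p.add_argument('--model', required=True, help='path to the saved model')
    p.add_argument('--filtered', action='store_true', help='use filtered ranking')
    p.add_argument('--graphall', help='whole-graph file, needed with --filtered')
    return p.parse_args()

if __name__ == '__main__':
    test(parse_test_args())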
Example No. 2
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    # graph = GraphDataset.load(args.knowledge, ent_vocab, rel_vocab)
    graph = None

    print('loading model...')
    if args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'distmult':
        from models.distmult import DistMult as Model
    elif args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'hole':
        from models.hole import HolE as Model
    elif args.method == 'rescal':
        from models.rescal import RESCAL as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    elif args.method == 'randwalk':
        from models.randwalk import RandWalk as Model
    elif args.method == 'lr':
        from models.lr import LogisticReg as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None
    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    if args.method == 'randwalk':
        model.load_wv_model(args.wv_model)

    all_res = evaluator.run_all_matric(model, test_dat, graph)
    for metric in sorted(all_res.keys()):
        print('{:20s}: {}'.format(metric, all_res[metric]))
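The if/elif chain above can also be expressed as a lookup table. The sketch below is only an alternative, not the project's code; the module paths are taken from the imports in the example and everything else is assumed.

import importlib

# Sketch: table-driven model selection equivalent to the if/elif chain above.
MODEL_REGISTRY = {
    'complex':  ('models.complex',  'ComplEx'),
    'distmult': ('models.distmult', 'DistMult'),
    'transe':   ('models.transe',   'TransE'),
    'hole':     ('models.hole',     'HolE'),
    'rescal':   ('models.rescal',   'RESCAL'),
    'analogy':  ('models.analogy',  'ANALOGY'),
    'randwalk': ('models.randwalk', 'RandWalk'),
    'lr':       ('models.lr',       'LogisticReg'),
}

def resolve_model(method):
    try:
        module_name, class_name = MODEL_REGISTRY[method]
    except KeyError:
        raise NotImplementedError('unknown method: {}'.format(method))
    return getattr(importlib.import_module(module_name), class_name)

Usage would be Model = resolve_model(args.method), leaving the rest of the function unchanged.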
Example No. 3
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    if args.task == 'kbc':
        test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    elif args.task == 'tc':
        test_dat = LabeledTripletDataset.load(args.data, ent_vocab, rel_vocab)
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    print('loading model...')
    if args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                 rel_vocab)
        # graphall = TensorTypeGraph.load(args.graphall)
    else:
        graphall = None

    model = Model.load_model(args.model)

    if args.metric == 'all':
        evaluator = Evaluator('all', None, args.filtered, False, graphall)
        if args.filtered:
            evaluator.prepare_valid(test_dat)

        all_res = evaluator.run_all_matric(model, test_dat)
        for metric in sorted(all_res.keys()):
            print('{:20s}: {}'.format(metric, all_res[metric]))
    else:
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print('{:20s}: {}'.format(args.metric, res))
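Vocab itself is not shown in these examples. A minimal stand-in consistent with how it is used here (Vocab.load(path), len(vocab), symbol-to-id lookup) could look as follows; the one-symbol-per-line file format is an assumption, not the project's specification.

# Hypothetical stand-in for Vocab; the file format is assumed.
class Vocab(object):
    def __init__(self, words):
        self.id2word = list(words)
        self.word2id = {w: i for i, w in enumerate(self.id2word)}

    def __len__(self):
        return len(self.id2word)

    def __getitem__(self, word):
        return self.word2id[word]

    @classmethod
    def load(cls, path):
        with open(path) as f:
            return cls(line.strip() for line in f if line.strip())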
Example No. 4
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    if args.task == 'kbc':
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    elif args.task == 'tc':
        assert args.metric == 'acc'
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = LabeledTripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    assert args.l1_ratio >= 0 and args.l1_ratio <= 1.0
    if args.l1_ratio == 0:
        logger.info("===== WARNING : l1_ratio has zero value. not inducing sparsity =====")
    if args.opt == 'adarda':
        opt = AdagradRDA(args.lr, args.reg*args.l1_ratio)
    elif args.opt == 'adardamul':
        opt = AdagradRDAmul(args.lr, args.reg*args.l1_ratio)
    else:
        raise NotImplementedError

    if args.reg*(1-args.l1_ratio) > 0:
        opt.set_l2_reg(args.reg*(1-args.l1_ratio))
    # elif args.reg*(1-args.l1_ratio) > 0 and args.onlyl1:
    #     opt.sel_ent_l2_reg(args.reg*(1-args.l1_ratio))
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        graphall = None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered, True, graphall) if args.valid else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        # pairwise mode is currently disabled; the trainer construction below is never reached
        raise NotImplementedError
        trainer = PairwiseTrainer(model=model, opt=opt, save_step=args.save_step,
                                  batchsize=args.batch, logger=logger,
                                  evaluator=evaluator, valid_dat=valid_dat,
                                  n_negative=args.negative, epoch=args.epoch,
                                  model_dir=args.log, restart=args.restart,
                                  add_re=args.add_re)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step,
                                batchsize=args.batch, logger=logger,
                                evaluator=evaluator, valid_dat=valid_dat,
                                n_negative=args.negative, epoch=args.epoch,
                                model_dir=args.log, restart=args.restart,
                                add_re=args.add_re)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)
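The optimizer setup splits a single regularization weight between an L1 term (passed to the AdagradRDA constructor) and an L2 term via l1_ratio. A tiny worked example with made-up numbers:

# Illustration of the reg / l1_ratio split used above (values are made up).
reg, l1_ratio = 0.01, 0.7
l1_strength = reg * l1_ratio          # 0.007, passed to AdagradRDA
l2_strength = reg * (1 - l1_ratio)    # 0.003, passed to opt.set_l2_reg
assert abs((l1_strength + l2_strength) - reg) < 1e-12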
Example No. 5
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab,
                                    rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('loading model...')
    with open(args.load, 'rb') as f:
        model = dill.load(f)

    # evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid or args.synthetic else None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered,
                          whole_graph) if args.valid else None
    # the args.synthetic condition from the commented-out line above was dropped so this runs
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model,
                                  opt=opt,
                                  save_step=args.save_step,
                                  batchsize=args.batch,
                                  logger=logger,
                                  evaluator=evaluator,
                                  valid_dat=valid_dat,
                                  n_negative=args.negative,
                                  epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model,
                                opt=opt,
                                save_step=args.save_step,
                                batchsize=args.batch,
                                logger=logger,
                                evaluator=evaluator,
                                valid_dat=valid_dat,
                                n_negative=args.negative,
                                epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')
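This variant restarts from a model pickled with dill instead of constructing one. A minimal sketch of the matching save/load pair (dill.dump and dill.load are the library's real API; the helper names and file paths are arbitrary):

import dill

# Minimal sketch of persisting an arbitrary model object with dill.
def save_model(model, path):
    with open(path, 'wb') as f:
        dill.dump(model, f)

def load_model(path):
    with open(path, 'rb') as f:
        return dill.load(f)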
Example No. 6
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    if args.sk:
        from sklearn.datasets import make_regression
        x, y = make_regression(n_samples=NUM,
                               n_features=DIM_X,
                               n_targets=DIM_Y,
                               n_informative=args.info,
                               noise=1.)
    else:
        x, y = gen_synthetic_data(DIM, DIM_X, DIM_Y, NUM, args.skewx,
                                  DIM_X - args.info)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)
    if args.stdx:
        logger.info('standardize data...')
        sc = StandardScaler()
        sc.fit(train_x)
        train_x = sc.transform(train_x)
        test_x = sc.transform(test_x)
    evaluator = Evaluator(mode=args.mode)

    if args.mode == 'online':
        with tf.Graph().as_default():
            tf.set_random_seed(46)
            sess = tf.Session()
            opt = tf.train.AdamOptimizer()  # TODO: make other optimizers available
            if args.method == 'ridgex':
                from models.ridge_x import RidgeX
                model = RidgeX(d_x=DIM_X, d_y=DIM_Y, l=args.l)
            elif args.method == 'ridgey':
                from models.ridge_y import RidgeY
                model = RidgeY(d_x=DIM_X, d_y=DIM_Y, l=args.l, l1=args.l1)
            else:
                raise NotImplementedError

            trainer = SimpleTrainer(model=model,
                                    epoch=args.epoch,
                                    opt=opt,
                                    sess=sess,
                                    logger=logger)
            trainer.fit(train_x, train_y)

            logger.info('evaluation...')
            evaluator.set_sess(sess)
            accuracy, skewness = evaluator.run(model, test_x, test_y)

    elif args.mode == 'closed':
        if args.method == 'ridgex':
            from models.closed_solver import RidgeX
            model = RidgeX(d_x=DIM_X, d_y=DIM_Y, l=args.l)
        elif args.method == 'ridgey':
            from models.closed_solver import RidgeY
            model = RidgeY(d_x=DIM_X, d_y=DIM_Y, l=args.l)
        else:
            raise NotImplementedError

        logger.info('calculating closed form solution...')
        model.solve(train_x, train_y)

        logger.info('evaluation...')
        accuracy, skewness = evaluator.run(model, test_x, test_y)

    elif args.mode == 'cd':
        if args.method == 'ridgex':
            from models.elasticnet import ElasticNetX
            model = ElasticNetX(d_x=DIM_X,
                                d_y=DIM_Y,
                                alpha=args.l,
                                l1_ratio=args.l1_ratio)
        elif args.method == 'ridgey':
            from models.elasticnet import ElasticNetY
            model = ElasticNetY(d_x=DIM_X,
                                d_y=DIM_Y,
                                alpha=args.l,
                                l1_ratio=args.l1_ratio)
        else:
            raise NotImplementedError

        logger.info('calculating solution...')
        model.solve(train_x, train_y)
        print(model.get_param().tolist())

        logger.info('evaluation...')
        accuracy, skewness = evaluator.run(model, test_x, test_y)

    else:
        raise ValueError('Invalid mode: {}'.format(args.mode))

    logger.info('  accuracy  : {}'.format(accuracy))
    logger.info('skewness@10 : {}'.format(skewness))
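In the 'closed' branch, model.solve(train_x, train_y) presumably computes a closed-form ridge solution. The textbook formula, shown as a generic sketch rather than the project's RidgeX/RidgeY implementation, is W = (X^T X + l*I)^{-1} X^T Y:

import numpy as np

# Generic closed-form ridge regression; not the project's RidgeX/RidgeY code.
def ridge_closed_form(x, y, l):
    d = x.shape[1]
    a = x.T.dot(x) + l * np.eye(d)         # (d_x, d_x)
    return np.linalg.solve(a, x.T.dot(y))  # (d_x, d_y) weight matrix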
Example No. 7
File: train.py  Project: shaoyx/kbc
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    elif args.opt == 'dsgd':
        opt = DecaySGD(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    elif args.method == 'distmult':
        from models.distmult import DistMult
        model = DistMult(n_entity=n_entity,
                         n_relation=n_relation,
                         margin=args.margin,
                         dim=args.dim,
                         mode=args.mode)
    elif args.method == 'transe':
        from models.transe import TransE
        model = TransE(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'hole':
        from models.hole import HolE
        model = HolE(n_entity=n_entity,
                     n_relation=n_relation,
                     margin=args.margin,
                     dim=args.dim,
                     mode=args.mode)
    elif args.method == 'rescal':
        from models.rescal import RESCAL
        model = RESCAL(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'analogy':
        from models.analogy import ANALOGY
        model = ANALOGY(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        cp_ratio=args.cp_ratio,
                        mode=args.mode)
    elif args.method == 'transe_set':
        from models.transe_set import TransE_set
        model = TransE_set(n_entity=n_entity,
                           n_relation=n_relation,
                           margin=args.margin,
                           dim=args.dim,
                           mode=args.mode)
    elif args.method == 'line':
        from models.line_model import LineModel
        model = LineModel(n_entity=n_entity,
                          n_relation=n_relation,
                          margin=args.margin,
                          dim=args.dim,
                          mode=args.mode)

    else:
        raise NotImplementedError

    # evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid or args.synthetic else None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid else None
    # the args.synthetic condition from the commented-out line above was dropped so this runs
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model, opt=opt, save_step=args.save_step,
                                  batchsize=args.batch, logger=logger,
                                  evaluator=evaluator, valid_dat=valid_dat,
                                  n_negative=args.negative, epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step,
                                batchsize=args.batch, logger=logger,
                                evaluator=evaluator, valid_dat=valid_dat,
                                n_negative=args.negative, epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')
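TripletDataset is also defined elsewhere in the project. A minimal stand-in consistent with TripletDataset.load(path, ent_vocab, rel_vocab) follows; the tab-separated "head, relation, tail" line format and the vocab indexing are assumptions.

# Hypothetical stand-in for TripletDataset; the file format is assumed.
class TripletDataset(object):
    def __init__(self, samples):
        self.samples = samples              # list of (head_id, rel_id, tail_id)

    def __len__(self):
        return len(self.samples)

    @classmethod
    def load(cls, path, ent_vocab, rel_vocab):
        samples = []
        with open(path) as f:
            for line in f:
                head, rel, tail = line.strip().split('\t')
                samples.append((ent_vocab[head], rel_vocab[rel], ent_vocab[tail]))
        return cls(samples)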
Example No. 8
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab,
                                    rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    elif args.method == 'distmult':
        from models.distmult import DistMult
        model = DistMult(n_entity=n_entity,
                         n_relation=n_relation,
                         margin=args.margin,
                         dim=args.dim,
                         mode=args.mode)
    elif args.method == 'transe':
        from models.transe import TransE
        model = TransE(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'hole':
        from models.hole import HolE
        model = HolE(n_entity=n_entity,
                     n_relation=n_relation,
                     margin=args.margin,
                     dim=args.dim,
                     mode=args.mode)
    elif args.method == 'rescal':
        from models.rescal import RESCAL
        model = RESCAL(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'analogy':
        from models.analogy import ANALOGY
        model = ANALOGY(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        cp_ratio=args.cp_ratio,
                        mode=args.mode)
    elif args.method == 'randwalk':
        from models.randwalk import RandWalk
        logger.info(
            'using the random walk model to learn embeddings in an unsupervised fashion.')
        model = RandWalk(n_entity=n_entity,
                         n_relation=n_relation,
                         knowledge_path=args.train,
                         ent_vocab=ent_vocab,
                         rel_vocab=rel_vocab,
                         dim=args.dim,
                         output=args.log)
        model.train()
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return
    elif args.method == "lr":
        from models.lr import LogisticReg
        model = LogisticReg(n_entity=n_entity,
                            n_relation=n_relation,
                            train_path=args.train,
                            ent_vocab=ent_vocab,
                            rel_vocab=rel_vocab,
                            dim=args.dim,
                            output=args.log,
                            wv_model_path=args.wv_model,
                            negative=args.negative,
                            feat_type=args.feat_type)
        starttime = time()
        if args.mode == "triplet_cls":
            logger.info("Training a triple classifer")
            model.train_triple_classifer()
        else:
            model.train()
        endtime = time()
        logger.info("lr model train time {:.6f}".format(endtime - starttime))
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return
    else:
        raise NotImplementedError

    evaluator = Evaluator(
        args.metric, args.nbest, args.filtered,
        whole_graph) if args.valid or args.synthetic else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model,
                                  opt=opt,
                                  save_step=args.save_step,
                                  batchsize=args.batch,
                                  logger=logger,
                                  evaluator=evaluator,
                                  valid_dat=valid_dat,
                                  n_negative=args.negative,
                                  epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model,
                                opt=opt,
                                save_step=args.save_step,
                                batchsize=args.batch,
                                logger=logger,
                                evaluator=evaluator,
                                valid_dat=valid_dat,
                                n_negative=args.negative,
                                epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')
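The args.filtered path loads the whole graph so that evaluation can rank the correct entity after discarding candidates that form other known triples. A schematic of that filtered ranking step for a single test triple, independent of this project's Evaluator, might look like this:

import numpy as np

# Schematic filtered ranking for one test triple (h, r, t): score every candidate
# tail, mask candidates that are already true triples (other than the test one),
# and report the rank of the true tail. Higher score is assumed to be better.
def filtered_rank(score_fn, h, r, t, n_entity, known_tails):
    scores = np.array([score_fn(h, r, cand) for cand in range(n_entity)])
    for cand in known_tails:               # tails already seen with (h, r)
        if cand != t:
            scores[cand] = -np.inf         # filter out other true triples
    return int((scores > scores[t]).sum()) + 1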