Example #1
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)

    print('loading model...')
    if args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'distmult':
        from models.distmult import DistMult as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None
    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    all_res = evaluator.run_all_matric(model, test_dat)
    for metric in sorted(all_res.keys()):
        print('{:20s}: {}'.format(metric, all_res[metric]))
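The function above only reads attributes off args (args.ent, args.rel, args.data, args.method, args.model, args.filtered, args.graphall), so it is presumably driven by an argparse entry point that this excerpt omits. A minimal sketch of such a driver follows; the flag names are inferred from those attribute accesses, not taken from the project's actual CLI.

# Hypothetical CLI wrapper inferred from the attributes read by test(args).
# The flag names below are assumptions, not the repository's real argparse setup.
import argparse

def build_parser():
    p = argparse.ArgumentParser(description='evaluate a trained KBC model')
    p.add_argument('--ent', required=True, help='entity vocabulary file')
    p.add_argument('--rel', required=True, help='relation vocabulary file')
    p.add_argument('--data', required=True, help='test triplets')
    p.add_argument('--method', default='complex', choices=['complex', 'distmult'])
    p.add_argument('--model', required=True, help='path to the trained model')
    p.add_argument('--filtered', action='store_true', help='use filtered ranking metrics')
    p.add_argument('--graphall', default=None, help='whole-graph file (needed with --filtered)')
    return p

if __name__ == '__main__':
    test(build_parser().parse_args())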
Example #2
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    # graph = GraphDataset.load(args.knowledge, ent_vocab, rel_vocab)
    graph = None

    print('loading model...')
    if args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'distmult':
        from models.distmult import DistMult as Model
    elif args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'hole':
        from models.hole import HolE as Model
    elif args.method == 'rescal':
        from models.rescal import RESCAL as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    elif args.method == 'randwalk':
        from models.randwalk import RandWalk as Model
    elif args.method == 'lr':
        from models.lr import LogisticReg as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None
    evaluator = Evaluator('all', None, args.filtered, whole_graph)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    model = Model.load_model(args.model)

    if args.method == 'randwalk':
        model.load_wv_model(args.wv_model)

    all_res = evaluator.run_all_matric(model, test_dat, graph)
    for metric in sorted(all_res.keys()):
        print('{:20s}: {}'.format(metric, all_res[metric]))
Example #3
def test(args):
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    if args.task == 'kbc':
        test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    elif args.task == 'tc':
        test_dat = LabeledTripletDataset.load(args.data, ent_vocab, rel_vocab)
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    print('loading model...')
    if args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                 rel_vocab)
        # graphall = TensorTypeGraph.load(args.graphall)
    else:
        graphall = None

    model = Model.load_model(args.model)

    if args.metric == 'all':
        evaluator = Evaluator('all', None, args.filtered, False, graphall)
        if args.filtered:
            evaluator.prepare_valid(test_dat)

        all_res = evaluator.run_all_matric(model, test_dat)
        for metric in sorted(all_res.keys()):
            print('{:20s}: {}'.format(metric, all_res[metric]))
    else:
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print('{:20s}: {}'.format(args.metric, res))
Example #4
    def __init__(self, model_file_path):

        model_name = os.path.basename(model_file_path)
        self._test_dir = os.path.join(config.log_root,
                                      'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
        for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path=config.decode_data_path,
                               vocab=self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)
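Examples #4, #5, #6, and #9 further down all pull paths and hyperparameters from a shared config module that the excerpts do not include. A minimal stand-in covering the fields they reference might look like the following; the concrete paths, sizes, and token strings are placeholders, not the project's actual values.

# Hypothetical config module covering the fields referenced in Examples #4-#6 and #9.
# Every value below is a placeholder; the real project defines its own settings.
import os

root_dir = os.path.expanduser('~/pointer_summarizer')
log_root = os.path.join(root_dir, 'log')

vocab_path = os.path.join(root_dir, 'data', 'vocab')
train_data_path = os.path.join(root_dir, 'data', 'chunked', 'train_*')
eval_data_path = os.path.join(root_dir, 'data', 'chunked', 'val_*')
decode_data_path = os.path.join(root_dir, 'data', 'chunked', 'test_*')

vocab_size = 50000
batch_size = 8
beam_size = 4
max_dec_steps = 100
min_dec_steps = 35
is_coverage = False
pointer_gen = True

BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
UNK_TOKEN = '[UNK]'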
Example #5
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab, config.train_data_path,
                               config.batch_size, single_pass=False, mode='train')
        time.sleep(10)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)
Example #6
    def __init__(self, model_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.eval_data_path,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_path, is_eval=True)
Example #7
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    if args.task == 'kbc':
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    elif args.task == 'tc':
        assert args.metric == 'acc'
        train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
        valid_dat = LabeledTripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    assert args.l1_ratio >= 0 and args.l1_ratio <= 1.0
    if args.l1_ratio == 0:
        logger.info("===== WARNING : l1_ratio has zero value. not inducing sparsity =====")
    if args.opt == 'adarda':
        opt = AdagradRDA(args.lr, args.reg*args.l1_ratio)
    elif args.opt == 'adardamul':
        opt = AdagradRDAmul(args.lr, args.reg*args.l1_ratio)
    else:
        raise NotImplementedError

    if args.reg*(1-args.l1_ratio) > 0:
        opt.set_l2_reg(args.reg*(1-args.l1_ratio))
    # elif args.reg*(1-args.l1_ratio) > 0 and args.onlyl1:
    #     opt.sel_ent_l2_reg(args.reg*(1-args.l1_ratio))
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    else:
        raise NotImplementedError

    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        graphall = None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered, True, graphall) if args.valid else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        # pairwise mode is disabled in this script, so the trainer construction below is unreachable
        raise NotImplementedError
        trainer = PairwiseTrainer(model=model, opt=opt, save_step=args.save_step,
                                  batchsize=args.batch, logger=logger,
                                  evaluator=evaluator, valid_dat=valid_dat,
                                  n_negative=args.negative, epoch=args.epoch,
                                  model_dir=args.log, restart=args.restart,
                                  add_re=args.add_re)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step,
                                batchsize=args.batch, logger=logger,
                                evaluator=evaluator, valid_dat=valid_dat,
                                n_negative=args.negative, epoch=args.epoch,
                                model_dir=args.log, restart=args.restart,
                                add_re=args.add_re)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)
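Note how a single regularization weight args.reg is split elastic-net style: the L1 fraction args.reg*args.l1_ratio goes into the RDA optimizer, while the remaining args.reg*(1-args.l1_ratio) is applied as L2 through opt.set_l2_reg. A quick check of that arithmetic with made-up numbers:

# Illustrative split of one regularization weight into L1/L2 parts,
# mirroring the args.reg * args.l1_ratio arithmetic above (numbers are invented).
reg, l1_ratio = 0.01, 0.3

l1_strength = reg * l1_ratio          # 0.003, handed to AdagradRDA
l2_strength = reg * (1.0 - l1_ratio)  # 0.007, applied via opt.set_l2_reg(...)

assert abs(l1_strength + l2_strength - reg) < 1e-12
print(l1_strength, l2_strength)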
Example #8
File: re_train.py  Project: shaoyx/kbc
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab,
                                    rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('loading model...')
    with open(args.load, 'rb') as f:
        model = dill.load(f)

    # evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid or args.synthetic else None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered,
                          whole_graph) if args.valid else None
    # delete args.synthetic to run
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model,
                                  opt=opt,
                                  save_step=args.save_step,
                                  batchsize=args.batch,
                                  logger=logger,
                                  evaluator=evaluator,
                                  valid_dat=valid_dat,
                                  n_negative=args.negative,
                                  epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model,
                                opt=opt,
                                save_step=args.save_step,
                                batchsize=args.batch,
                                logger=logger,
                                evaluator=evaluator,
                                valid_dat=valid_dat,
                                n_negative=args.negative,
                                epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')
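Unlike the other training scripts, re_train.py restores the model object directly with dill.load, so the checkpoint it consumes is simply a serialized model object. A minimal sketch of writing such a checkpoint, assuming the model object is picklable; the function name and file path here are illustrative, not part of the project.

# Minimal sketch of writing a checkpoint that re_train.py's dill.load can read.
# `model` stands for any trained model object from this codebase; the path is illustrative.
import dill

def save_for_retraining(model, path='model.dill'):
    with open(path, 'wb') as f:
        dill.dump(model, f)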
Example #9
class BeamSearch(object):
    def __init__(self, model_file_path):

        model_name = os.path.basename(model_file_path)
        self._test_dir = os.path.join(config.log_root,
                                      'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
        for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path=config.decode_data_path,
                               vocab=self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def beam_search(self, batch):
        # single example repeated across the batch
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
            get_input_from_batch(batch, use_cuda)

        enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
        s_t = self.model.reduce_state(enc_h)

        dec_h, dec_c = s_t  # b x hidden_dim
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size hypotheses, all identical copies of the single example
        beams = [
            Beam(tokens=[self.vocab.word2id(config.BOS_TOKEN)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t[0],
                 coverage=(coverage[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        steps = 0
        results = []
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(config.UNK_TOKEN) \
                             for t in latest_tokens]
            y_t = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t = y_t.cuda()
            all_state_h = [h.state[0] for h in beams]
            all_state_c = [h.state[1] for h in beams]
            all_context = [h.context for h in beams]

            s_t = (torch.stack(all_state_h, 0).unsqueeze(0),
                   torch.stack(all_state_c, 0).unsqueeze(0))
            c_t = torch.stack(all_context, 0)

            coverage_t = None
            if config.is_coverage:
                all_coverage = [h.coverage for h in beams]
                coverage_t = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t, extra_zeros,
                enc_batch_extend_vocab, coverage_t, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                # use the coverage updated by the decoder step, not the initial batch coverage
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(config.EOS_TOKEN):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]

    def run(self):

        counter = 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = utils.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(dataset.EOS_TOKEN)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            # notice: "original_abstract_sents": 'original' means its datetype is bytes-like.
            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._test_dir)
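beam_search above depends on a Beam hypothesis object (constructor keywords tokens/log_probs/state/context/coverage, plus extend, latest_token, and avg_log_prob) that this excerpt does not define. Below is a minimal stand-in consistent with how it is used, assuming the usual pointer-generator length-normalized score; the project's own class may differ.

# Minimal Beam hypothesis sketch, inferred from how beam_search uses it above.
# The real project defines its own class; this is an assumption-based stand-in.
class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # decoded token ids so far
        self.log_probs = log_probs  # per-token log probabilities
        self.state = state          # (hidden, cell) decoder state
        self.context = context      # attention context vector
        self.coverage = coverage    # coverage vector (or None)

    def extend(self, token, log_prob, state, context, coverage):
        # return a new hypothesis with one more decoded token appended
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state,
                    context=context,
                    coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # length-normalized score used by sort_beams
        return sum(self.log_probs) / len(self.tokens)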
Example #10
                rel_type[rel_vocab.word2id[rel]] = 'n-n'
                rnn = rnn + 1
            elif hpt > 1.5:
                rel_type[rel_vocab.word2id[rel]] = '1-n'
                rn1 = rn1 + 1
            elif tph > 1.5:
                rel_type[rel_vocab.word2id[rel]] = 'n-1'
                r1n = r1n + 1
            else:
                rel_type[rel_vocab.word2id[rel]] = '1-1'
                r11 = r11 + 1
            out_path = "./test.txt"
    return rel_type
    # with open(out_path, "a") as out_file:
    # 	print(rel, rel_type[rel], file = out_file)

    # with open(out_path, "a") as out_file:
    # 	print("n-n", rnn, file=out_file)
    # 	print("1-1", r11, file=out_file)
    # 	print("1-n", r1n, file=out_file)
    # 	print("n-1", rn1, file=out_file)


if __name__ == '__main__':
    ent_path = "../dat/FB15k/train.entlist"
    rel_path = "../dat/FB15k/train.rellist"
    dat_path = "../dat/FB15k/whole.txt"
    print("loading entities & relation")
    ent_vocab = Vocab.load(ent_path)
    rel_vocab = Vocab.load(rel_path)
    rel_type = rel_classify(ent_vocab, rel_vocab, dat_path)
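The classification above relies on two per-relation statistics, tph (average tails per head) and hpt (average heads per tail), whose computation is cut off in this excerpt. Below is a hedged sketch of how they are commonly derived from (head, relation, tail) triples, in the spirit of the usual 1-1/1-n/n-1/n-n relation split; it is not necessarily this script's exact code.

# Hedged sketch of computing tails-per-head (tph) and heads-per-tail (hpt)
# per relation from (head, relation, tail) triples; not this script's exact code.
from collections import defaultdict

def relation_stats(triples):
    heads_by_rel_tail = defaultdict(set)   # (rel, tail) -> {heads}
    tails_by_rel_head = defaultdict(set)   # (rel, head) -> {tails}
    for h, r, t in triples:
        heads_by_rel_tail[(r, t)].add(h)
        tails_by_rel_head[(r, h)].add(t)

    stats = {}
    rels = {r for (r, _) in tails_by_rel_head}
    for r in rels:
        tph_counts = [len(ts) for (rr, _), ts in tails_by_rel_head.items() if rr == r]
        hpt_counts = [len(hs) for (rr, _), hs in heads_by_rel_tail.items() if rr == r]
        stats[r] = (sum(tph_counts) / len(tph_counts),   # tph
                    sum(hpt_counts) / len(hpt_counts))   # hpt
    return stats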
Example #11
File: train.py  Project: shaoyx/kbc
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in sorted(vars(args).items()):
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    elif args.opt == 'dsgd':
        opt = DecaySGD(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    elif args.method == 'distmult':
        from models.distmult import DistMult
        model = DistMult(n_entity=n_entity,
                         n_relation=n_relation,
                         margin=args.margin,
                         dim=args.dim,
                         mode=args.mode)
    elif args.method == 'transe':
        from models.transe import TransE
        model = TransE(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'hole':
        from models.hole import HolE
        model = HolE(n_entity=n_entity,
                     n_relation=n_relation,
                     margin=args.margin,
                     dim=args.dim,
                     mode=args.mode)
    elif args.method == 'rescal':
        from models.rescal import RESCAL
        model = RESCAL(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'analogy':
        from models.analogy import ANALOGY
        model = ANALOGY(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        cp_ratio=args.cp_ratio,
                        mode=args.mode)
    elif args.method == 'transe_set':
        from models.transe_set import TransE_set
        model = TransE_set(n_entity=n_entity,
                           n_relation=n_relation,
                           margin=args.margin,
                           dim=args.dim,
                           mode=args.mode)
    elif args.method == 'line':
        from models.line_model import LineModel
        model = LineModel(n_entity=n_entity,
                          n_relation=n_relation,
                          margin=args.margin,
                          dim=args.dim,
                          mode=args.mode)

    else:
        raise NotImplementedError

    # evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid or args.synthetic else None
    evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid else None
    # delete args.synthetic to run
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model, opt=opt, save_step=args.save_step,
                                  batchsize=args.batch, logger=logger,
                                  evaluator=evaluator, valid_dat=valid_dat,
                                  n_negative=args.negative, epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step,
                                batchsize=args.batch, logger=logger,
                                evaluator=evaluator, valid_dat=valid_dat,
                                n_negative=args.negative, epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')
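All eight model branches above pass the same keyword arguments (ANALOGY additionally takes cp_ratio), so the elif chain could be collapsed into a lookup table. Here is a sketch under that assumption, reusing the module paths shown in the imports above; it is a refactoring suggestion, not the project's code.

# Sketch: replace the elif chain with a name-to-class registry.
# Module paths mirror the imports above; ANALOGY additionally takes cp_ratio.
def build_model(args, n_entity, n_relation):
    from models.complex import ComplEx
    from models.distmult import DistMult
    from models.transe import TransE
    from models.hole import HolE
    from models.rescal import RESCAL
    from models.analogy import ANALOGY
    from models.transe_set import TransE_set
    from models.line_model import LineModel

    registry = {'complex': ComplEx, 'distmult': DistMult, 'transe': TransE,
                'hole': HolE, 'rescal': RESCAL, 'analogy': ANALOGY,
                'transe_set': TransE_set, 'line': LineModel}
    if args.method not in registry:
        raise NotImplementedError(args.method)

    kwargs = dict(n_entity=n_entity, n_relation=n_relation,
                  margin=args.margin, dim=args.dim, mode=args.mode)
    if args.method == 'analogy':
        kwargs['cp_ratio'] = args.cp_ratio
    return registry[args.method](**kwargs)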
Example #12
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # TODO: develop the recording of arguments in logging
    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    n_entity, n_relation = len(ent_vocab), len(rel_vocab)

    # preparing data
    logger.info('preparing data...')
    train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab)
    valid_dat = TripletDataset.load(args.valid, ent_vocab,
                                    rel_vocab) if args.valid else None

    if args.filtered:
        logger.info('loading whole graph...')
        from utils.graph import TensorTypeGraph
        whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                    rel_vocab)
    else:
        whole_graph = None

    if args.opt == 'sgd':
        opt = SGD(args.lr)
    elif args.opt == 'adagrad':
        opt = Adagrad(args.lr)
    else:
        raise NotImplementedError

    if args.l2_reg > 0:
        opt.set_l2_reg(args.l2_reg)
    if args.gradclip > 0:
        opt.set_gradclip(args.gradclip)

    logger.info('building model...')
    if args.method == 'complex':
        from models.complex import ComplEx
        model = ComplEx(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        mode=args.mode)
    elif args.method == 'distmult':
        from models.distmult import DistMult
        model = DistMult(n_entity=n_entity,
                         n_relation=n_relation,
                         margin=args.margin,
                         dim=args.dim,
                         mode=args.mode)
    elif args.method == 'transe':
        from models.transe import TransE
        model = TransE(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'hole':
        from models.hole import HolE
        model = HolE(n_entity=n_entity,
                     n_relation=n_relation,
                     margin=args.margin,
                     dim=args.dim,
                     mode=args.mode)
    elif args.method == 'rescal':
        from models.rescal import RESCAL
        model = RESCAL(n_entity=n_entity,
                       n_relation=n_relation,
                       margin=args.margin,
                       dim=args.dim,
                       mode=args.mode)
    elif args.method == 'analogy':
        from models.analogy import ANALOGY
        model = ANALOGY(n_entity=n_entity,
                        n_relation=n_relation,
                        margin=args.margin,
                        dim=args.dim,
                        cp_ratio=args.cp_ratio,
                        mode=args.mode)
    elif args.method == 'randwalk':
        from models.randwalk import RandWalk
        logger.info(
            'using random walk model to learn embeddings in an unsupervised manner.')
        model = RandWalk(n_entity=n_entity,
                         n_relation=n_relation,
                         knowledge_path=args.train,
                         ent_vocab=ent_vocab,
                         rel_vocab=rel_vocab,
                         dim=args.dim,
                         output=args.log)
        model.train()
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return
    elif args.method == "lr":
        from models.lr import LogisticReg
        model = LogisticReg(n_entity=n_entity,
                            n_relation=n_relation,
                            train_path=args.train,
                            ent_vocab=ent_vocab,
                            rel_vocab=rel_vocab,
                            dim=args.dim,
                            output=args.log,
                            wv_model_path=args.wv_model,
                            negative=args.negative,
                            feat_type=args.feat_type)
        starttime = time()
        if args.mode == "triplet_cls":
            logger.info("Training a triple classifer")
            model.train_triple_classifer()
        else:
            model.train()
        endtime = time()
        logger.info("lr model train time {:.6f}".format(endtime - starttime))
        model.save_model(os.path.join(args.log, model.__class__.__name__))
        return
    else:
        raise NotImplementedError

    evaluator = Evaluator(
        args.metric, args.nbest, args.filtered,
        whole_graph) if args.valid or args.synthetic else None
    if args.filtered and args.valid:
        evaluator.prepare_valid(valid_dat)
    if args.mode == 'pairwise':
        trainer = PairwiseTrainer(model=model,
                                  opt=opt,
                                  save_step=args.save_step,
                                  batchsize=args.batch,
                                  logger=logger,
                                  evaluator=evaluator,
                                  valid_dat=valid_dat,
                                  n_negative=args.negative,
                                  epoch=args.epoch,
                                  model_dir=args.log)
    elif args.mode == 'single':
        trainer = SingleTrainer(model=model,
                                opt=opt,
                                save_step=args.save_step,
                                batchsize=args.batch,
                                logger=logger,
                                evaluator=evaluator,
                                valid_dat=valid_dat,
                                n_negative=args.negative,
                                epoch=args.epoch,
                                model_dir=args.log)
    else:
        raise NotImplementedError

    trainer.fit(train_dat)

    logger.info('done all')