# NOTE: Vocab, TripletDataset, LabeledTripletDataset and Evaluator are assumed to be
# imported at module level elsewhere in this script; only the model classes are
# imported lazily inside the function, as in the original code.
def test(args):
    # load entity / relation vocabularies
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    if args.task == 'kbc':
        test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    elif args.task == 'tc':
        test_dat = LabeledTripletDataset.load(args.data, ent_vocab, rel_vocab)
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    print('loading model...')
    # select the scoring model; imported lazily so unused backends are never loaded
    if args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError

    if args.filtered:
        # filtered evaluation needs the whole graph to exclude known true triples
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab)
        # graphall = TensorTypeGraph.load(args.graphall)
    else:
        graphall = None

    model = Model.load_model(args.model)

    if args.metric == 'all':
        # compute every supported metric in one pass
        evaluator = Evaluator('all', None, args.filtered, False, graphall)
        if args.filtered:
            evaluator.prepare_valid(test_dat)
        all_res = evaluator.run_all_matric(model, test_dat)
        for metric in sorted(all_res.keys()):
            print('{:20s}: {}'.format(metric, all_res[metric]))
    else:
        # compute a single metric
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print('{:20s}: {}'.format(args.metric, res))
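
# --- usage sketch (illustrative only, not part of the original script) ---
# test() only relies on attribute access on `args`, so it can be driven without a
# CLI by building an argparse.Namespace. Every field below mirrors an attribute the
# function reads above; the concrete paths and values are placeholders, not the
# repository's actual defaults.
def _example_test_call():
    from argparse import Namespace
    args = Namespace(
        ent='data/entity.vocab',    # placeholder path to the entity vocabulary
        rel='data/relation.vocab',  # placeholder path to the relation vocabulary
        task='kbc',                 # 'kbc' (knowledge base completion) or 'tc' (triple classification)
        data='data/test.txt',       # placeholder path to the test triples
        method='complex',           # 'transe', 'complex', or 'analogy'
        filtered=True,              # use the filtered evaluation protocol
        graphall='data/whole.txt',  # full triple file, only needed when filtered=True
        model='results/model.bin',  # placeholder path to a trained model
        metric='all',               # 'all' reports every metric; otherwise a single metric name
    )
    test(args)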
# NOTE: os, logging, tensorflow (as tf), train_test_split, StandardScaler,
# gen_synthetic_data, SimpleTrainer, Evaluator and the constants NUM, DIM, DIM_X,
# DIM_Y are assumed to be imported / defined at module level.
def train(args):
    # setting for logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    # build a synthetic regression dataset, either with scikit-learn or the custom generator
    if args.sk:
        from sklearn.datasets import make_regression
        x, y = make_regression(n_samples=NUM, n_features=DIM_X, n_targets=DIM_Y,
                               n_informative=args.info, noise=1.)
    else:
        x, y = gen_synthetic_data(DIM, DIM_X, DIM_Y, NUM, args.skewx, DIM_X - args.info)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)

    if args.stdx:
        logger.info('standardize data...')
        sc = StandardScaler()
        sc.fit(train_x)
        train_x = sc.transform(train_x)
        test_x = sc.transform(test_x)

    evaluator = Evaluator(mode=args.mode)

    if args.mode == 'online':
        # gradient-based training with TensorFlow
        with tf.Graph().as_default():
            tf.set_random_seed(46)
            sess = tf.Session()
            opt = tf.train.AdamOptimizer()  # TODO: make other optimizers available
            if args.method == 'ridgex':
                from models.ridge_x import RidgeX
                model = RidgeX(d_x=DIM_X, d_y=DIM_Y, l=args.l)
            elif args.method == 'ridgey':
                from models.ridge_y import RidgeY
                model = RidgeY(d_x=DIM_X, d_y=DIM_Y, l=args.l, l1=args.l1)
            else:
                raise NotImplementedError
            trainer = SimpleTrainer(model=model, epoch=args.epoch, opt=opt, sess=sess, logger=logger)
            trainer.fit(train_x, train_y)

            logger.info('evaluation...')
            evaluator.set_sess(sess)
            accuracy, skewness = evaluator.run(model, test_x, test_y)

    elif args.mode == 'closed':
        # analytic (closed-form) solution
        if args.method == 'ridgex':
            from models.closed_solver import RidgeX
            model = RidgeX(d_x=DIM_X, d_y=DIM_Y, l=args.l)
        elif args.method == 'ridgey':
            from models.closed_solver import RidgeY
            model = RidgeY(d_x=DIM_X, d_y=DIM_Y, l=args.l)
        else:
            raise NotImplementedError
        logger.info('calculating closed form solution...')
        model.solve(train_x, train_y)

        logger.info('evaluation...')
        accuracy, skewness = evaluator.run(model, test_x, test_y)

    elif args.mode == 'cd':
        # coordinate-descent solution via elastic net
        if args.method == 'ridgex':
            from models.elasticnet import ElasticNetX
            model = ElasticNetX(d_x=DIM_X, d_y=DIM_Y, alpha=args.l, l1_ratio=args.l1_ratio)
        elif args.method == 'ridgey':
            from models.elasticnet import ElasticNetY
            model = ElasticNetY(d_x=DIM_X, d_y=DIM_Y, alpha=args.l, l1_ratio=args.l1_ratio)
        else:
            raise NotImplementedError  # guard so `model` is never left undefined
        logger.info('calculating solution...')
        model.solve(train_x, train_y)
        print(model.get_param().tolist())

        logger.info('evaluation...')
        accuracy, skewness = evaluator.run(model, test_x, test_y)

    else:
        raise ValueError('Invalid mode: {}'.format(args.mode))

    logger.info(' accuracy : {}'.format(accuracy))
    logger.info('skewness@10 : {}'.format(skewness))
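
# --- usage sketch (illustrative only, not part of the original script) ---
# train() likewise only reads attributes from `args`, so a Namespace is enough to
# exercise one configuration. Each field corresponds to an attribute the function
# reads above; the numeric values are placeholders, not the repository's defaults.
def _example_train_call():
    from argparse import Namespace
    args = Namespace(
        log='log',        # directory for the log file (created if missing)
        sk=False,         # True: sklearn make_regression, False: gen_synthetic_data
        info=5,           # number of informative features (placeholder value)
        skewx=1.0,        # skewness parameter passed to gen_synthetic_data (placeholder value)
        stdx=True,        # standardize features with StandardScaler
        mode='closed',    # 'online', 'closed', or 'cd'
        method='ridgex',  # 'ridgex' or 'ridgey'
        l=1.0,            # regularization strength
        l1=0.1,           # l1 weight (used by RidgeY in 'online' mode)
        l1_ratio=0.5,     # elastic-net mixing parameter (used in 'cd' mode)
        epoch=100,        # number of training epochs (used in 'online' mode)
    )
    train(args)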