def run(only_forward=False):
    """Train, or only evaluate, a knowledge-graph embedding model from FLAGS.

    Steps:

    * seed ``random`` / ``torch`` unless ``FLAGS.seed`` is 0;
    * optionally create a ``Visualizer`` and log the flag values to it;
    * attach a file handler and a stream handler to the root logger;
    * load the KG data and build the model plus its ``ModelTrainer``;
    * optionally restore a checkpoint;
    * then either call ``evaluate`` on every test split or hand off to
      ``train_loop``.

    Args:
        only_forward: if True, skip training and only evaluate each split
            listed in ``FLAGS.kg_test_files``.
    """
    # A seed of 0 means "leave the RNGs unseeded".
    if FLAGS.seed != 0:
        random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)

    # set visualization
    vis = None
    if FLAGS.has_visualization:
        vis = Visualizer(env=FLAGS.experiment_name,
                         port=FLAGS.visualization_port)
        vis.log(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True),
                win_name="Parameter")

    # set logger
    # NOTE: the handlers go on the *root* logger (getLogger() with no name),
    # so calling run() more than once in a process duplicates log lines.
    log_file = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".log")

    logger = logging.getLogger()
    log_level = logging.DEBUG if FLAGS.log_level == "debug" else logging.INFO
    logger.setLevel(level=log_level)
    # Formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # FileHandler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # StreamHandler
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    logger.info("Flag Values:\n" +
                json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # load data
    kg_path = os.path.join(FLAGS.data_path, FLAGS.dataset, 'kg')
    # ``kg_test_files`` is a ':'-separated list of file names. An unset (None)
    # or empty flag yields no eval files instead of crashing on None.split
    # or producing a bogus '' file name.
    # NOTE(review): assumes load_data accepts an empty list -- confirm.
    eval_files = FLAGS.kg_test_files.split(':') if FLAGS.kg_test_files else []

    train_dataset, eval_datasets, e_map, r_map = load_data(
        kg_path,
        eval_files,
        FLAGS.batch_size,
        logger=logger,
        negtive_samples=FLAGS.negtive_samples)

    entity_total = len(e_map)
    relation_total = len(r_map)

    # train_dataset is a 5-tuple; only the total and the head/tail dicts are
    # used here (the whole tuple is still forwarded to train_loop).
    _, train_total, _, train_head_dict, train_tail_dict = train_dataset

    # KG-only model: user and item totals are 0.
    model = init_model(FLAGS, 0, 0, entity_total, relation_total, logger)
    epoch_length = math.ceil(train_total / FLAGS.batch_size)
    trainer = ModelTrainer(model, logger, epoch_length, FLAGS)

    # os.path.join discards FLAGS.log_path when load_ckpt_file is absolute,
    # so a full checkpoint path works as well as one relative to log_path.
    if FLAGS.load_ckpt_file is not None:
        trainer.loadEmbedding(os.path.join(FLAGS.log_path,
                                           FLAGS.load_ckpt_file),
                              model.state_dict(),
                              cpu=not USE_CUDA)
        model.is_pretrained = True

    # Do an evaluation-only run.
    if only_forward:
        # Each eval split is a tuple:
        # (head_iter, tail_iter, eval_total, eval_list,
        #  eval_head_dict, eval_tail_dict)
        for i, eval_data in enumerate(eval_datasets):
            all_head_dicts = None
            all_tail_dicts = None
            if FLAGS.filter_wrong_corrupted:
                # Head/tail dicts of the training split plus every *other*
                # eval split. Presumably evaluate() uses them to skip
                # corrupted candidates that are actually known triples.
                all_head_dicts = [train_head_dict] + [
                    tmp_data[4]
                    for j, tmp_data in enumerate(eval_datasets) if j != i
                ]
                all_tail_dicts = [train_tail_dict] + [
                    tmp_data[5]
                    for j, tmp_data in enumerate(eval_datasets) if j != i
                ]
            evaluate(FLAGS,
                     model,
                     entity_total,
                     relation_total,
                     eval_data[0],
                     eval_data[1],
                     eval_data[4],
                     eval_data[5],
                     all_head_dicts,
                     all_tail_dicts,
                     logger,
                     eval_descending=False,
                     is_report=FLAGS.is_report)
    else:
        train_loop(FLAGS,
                   model,
                   trainer,
                   train_dataset,
                   eval_datasets,
                   entity_total,
                   relation_total,
                   logger,
                   vis=vis,
                   is_report=False)
    if vis is not None:
        vis.log("Finish!", win_name="Best Performances")
def run(only_forward=False):
    """Jointly train, or only evaluate, the recommendation + KG model.

    Everything is driven by FLAGS. The function:

    * seeds the RNGs unless ``FLAGS.seed`` is 0;
    * optionally opens a ``Visualizer``;
    * attaches file and console handlers to the root logger;
    * loads both the rating data and the KG triples;
    * builds the joint model and its ``ModelTrainer``;
    * optionally restores one or more ':'-separated checkpoints from
      ``FLAGS.log_path``;
    * then either evaluates every rating and KG test split
      (``only_forward=True``) or runs ``train_loop``.

    Args:
        only_forward: if True, skip training and only run ``evaluateRec`` /
            ``evaluateKG`` on the test splits.
    """

    def _split_spec(spec):
        # ':'-separated file list; None means "no files".
        return [] if spec is None else spec.split(':')

    def _field_of_other_splits(datasets, skip_index, field):
        # Tuple element ``field`` of every eval split except ``skip_index``.
        return [split[field]
                for idx, split in enumerate(datasets) if idx != skip_index]

    if FLAGS.seed != 0:
        random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)

    vis = None
    if FLAGS.has_visualization:
        vis = Visualizer(env=FLAGS.experiment_name,
                         port=FLAGS.visualization_port)
        vis.log(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True),
                win_name="Parameter")

    # Logging: one formatter shared by a file handler and a console handler,
    # both attached to the root logger.
    log_filename = os.path.join(FLAGS.log_path,
                                FLAGS.experiment_name + ".log")
    root_logger = logging.getLogger()
    root_logger.setLevel(
        level=logging.DEBUG if FLAGS.log_level == "debug" else logging.INFO)
    shared_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (logging.FileHandler(log_filename),
                    logging.StreamHandler()):
        handler.setFormatter(shared_formatter)
        root_logger.addHandler(handler)

    root_logger.info(
        "Flag Values:\n" +
        json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    (rating_train_dataset, rating_eval_datasets, u_map, i_map,
     triple_train_dataset, triple_eval_datasets, e_map, r_map,
     ikg_map) = load_data(
         os.path.join(FLAGS.data_path, FLAGS.dataset),
         _split_spec(FLAGS.rec_test_files),
         _split_spec(FLAGS.kg_test_files),
         FLAGS.batch_size,
         negtive_samples=FLAGS.negtive_samples,
         logger=root_logger)

    _, rating_train_total, _, rating_train_dict = rating_train_dataset
    (_, triple_train_total, _,
     triple_train_head_dict, triple_train_tail_dict) = triple_train_dataset

    user_total = len(u_map)
    relation_total = len(r_map)
    if FLAGS.share_embeddings:
        # Item and entity counts both come from the combined ikg_map.
        item_total = entity_total = len(ikg_map)
    else:
        item_total, entity_total = len(i_map), len(e_map)

    joint_model = init_model(FLAGS, user_total, item_total, entity_total,
                             relation_total, root_logger,
                             i_map=i_map, e_map=e_map, new_map=ikg_map)

    # Each source's size is scaled by its share (1 - joint_ratio for triples,
    # joint_ratio for ratings). The larger batch count defines one epoch.
    # A joint_ratio of exactly 0 or 1 raises ZeroDivisionError here.
    epoch_length = max(
        math.ceil(float(triple_train_total) / (1 - FLAGS.joint_ratio) /
                  FLAGS.batch_size),
        math.ceil(float(rating_train_total) / FLAGS.joint_ratio /
                  FLAGS.batch_size))

    trainer = ModelTrainer(joint_model, root_logger, epoch_length, FLAGS)

    if FLAGS.load_ckpt_file is not None:
        # With shared embeddings, e_map / i_map are handed to loadEmbedding
        # as remap tables; otherwise no extra arguments are passed.
        remap = ({'e_remap': e_map, 'i_remap': i_map}
                 if FLAGS.share_embeddings else {})
        for ckpt_name in FLAGS.load_ckpt_file.split(':'):
            trainer.loadEmbedding(os.path.join(FLAGS.log_path, ckpt_name),
                                  joint_model.state_dict(),
                                  **remap)
        joint_model.is_pretrained = True

    if only_forward:
        # Rating splits: element 0 is passed as the data source and
        # element 3 as the split's dict.
        for i, split in enumerate(rating_eval_datasets):
            known_dicts = None
            if FLAGS.filter_wrong_corrupted:
                known_dicts = ([rating_train_dict] +
                               _field_of_other_splits(rating_eval_datasets,
                                                      i, 3))
            evaluateRec(FLAGS, joint_model, split[0], split[3], known_dicts,
                        i_map, root_logger,
                        eval_descending=bool(trainer.model_target == 1),
                        is_report=FLAGS.is_report)
        # KG splits: (head_iter, tail_iter, eval_total, eval_list,
        #             eval_head_dict, eval_tail_dict)
        for i, split in enumerate(triple_eval_datasets):
            head_dicts = tail_dicts = None
            if FLAGS.filter_wrong_corrupted:
                head_dicts = ([triple_train_head_dict] +
                              _field_of_other_splits(triple_eval_datasets,
                                                     i, 4))
                tail_dicts = ([triple_train_tail_dict] +
                              _field_of_other_splits(triple_eval_datasets,
                                                     i, 5))
            evaluateKG(FLAGS, joint_model, split[0], split[1], split[4],
                       split[5], head_dicts, tail_dicts, e_map, root_logger,
                       eval_descending=False, is_report=FLAGS.is_report)
    else:
        train_loop(FLAGS, joint_model, trainer, rating_train_dataset,
                   triple_train_dataset, rating_eval_datasets,
                   triple_eval_datasets, e_map, i_map, ikg_map, root_logger,
                   vis=vis, is_report=False)

    if vis is not None:
        vis.log("Finish!", win_name="Best Performances")
# Example #3
# 0
def run(only_forward=False):
    """Train, or only evaluate, an item-recommendation model from FLAGS.

    Steps:

    * seed ``random`` / ``torch`` unless ``FLAGS.seed`` is 0;
    * optionally create a ``Visualizer``;
    * attach file and console handlers to the root logger;
    * load the rating data and build the model plus its ``ModelTrainer``;
    * optionally restore a checkpoint;
    * either evaluate every split in ``FLAGS.rec_test_files`` or run
      ``train_loop``.

    In both modes the model's ``state_dict`` is finally written under
    ``./embedding/``.

    Args:
        only_forward: if True, skip training and only run ``evaluate`` on
            the test splits (the weights are still saved afterwards).
    """
    # A seed of 0 means "leave the RNGs unseeded".
    if FLAGS.seed != 0:
        random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)

    # set visualization
    vis = None
    if FLAGS.has_visualization:
        vis = Visualizer(env=FLAGS.experiment_name,
                         port=FLAGS.visualization_port)
        vis.log(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True),
                win_name="Parameter")

    # set logger (handlers are attached to the root logger)
    log_file = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".log")
    logger = logging.getLogger()
    log_level = logging.DEBUG if FLAGS.log_level == "debug" else logging.INFO
    logger.setLevel(level=log_level)
    # Formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # FileHandler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # StreamHandler
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    logger.info("Flag Values:\n" +
                json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # load data
    dataset_path = os.path.join(FLAGS.data_path, FLAGS.dataset)
    # ':'-separated list; an unset or empty flag means "no eval files".
    eval_files = FLAGS.rec_test_files.split(
        ':') if FLAGS.rec_test_files else []

    train_dataset, eval_datasets, u_map, i_map = load_data(
        dataset_path,
        eval_files,
        FLAGS.batch_size,
        logger=logger,
        negtive_samples=FLAGS.negtive_samples)

    # train_dataset is a 4-tuple; only the total and the dict are used here
    # (the whole tuple is still forwarded to train_loop).
    _, train_total, _, train_dict = train_dataset

    # Sized as the larger of the map size and the largest mapped id.
    # NOTE(review): if ids are 0-based indices, an id equal to this total
    # would be out of range for a table of that size -- confirm the id
    # convention used by load_data. max() on an empty map raises ValueError.
    user_total = max(len(u_map), max(u_map.values()))
    item_total = max(len(i_map), max(i_map.values()))

    # Recommendation-only model: entity and relation totals are 0.
    model = init_model(FLAGS, user_total, item_total, 0, 0, logger)
    epoch_length = math.ceil(train_total / FLAGS.batch_size)
    trainer = ModelTrainer(model, logger, epoch_length, FLAGS)

    if FLAGS.load_ckpt_file is not None:
        trainer.loadEmbedding(os.path.join(FLAGS.log_path,
                                           FLAGS.load_ckpt_file),
                              model.state_dict(),
                              cpu=not USE_CUDA)
        model.is_pretrained = True

    # Do an evaluation-only run.
    if only_forward:
        for i, eval_data in enumerate(eval_datasets):
            all_dicts = None
            if FLAGS.filter_wrong_corrupted:
                # Training dict plus the dicts of every *other* eval split.
                all_dicts = [train_dict] + [
                    tmp_data[3]
                    for j, tmp_data in enumerate(eval_datasets) if j != i
                ]
            evaluate(
                FLAGS,
                model,
                eval_data[0],
                eval_data[3],
                all_dicts,
                logger,
                eval_descending=True if trainer.model_target == 1 else False,
                is_report=FLAGS.is_report)
    else:
        train_loop(FLAGS,
                   model,
                   trainer,
                   train_dataset,
                   eval_datasets,
                   user_total,
                   item_total,
                   logger,
                   vis=vis,
                   is_report=False)
    if vis is not None:
        vis.log("Finish!", win_name="Best Performances")

    # Persist the final weights. torch.save does not create parent
    # directories, so a missing ./embedding used to crash here, after all
    # the training work had already been done. Make sure it exists first.
    # The name records the hyper-parameters; early-stopping patience is
    # deliberately left out (hence "no_early_stop_steps").
    embedding_dir = './embedding'
    os.makedirs(embedding_dir, exist_ok=True)
    embedding_file = (
        'transup-{data:s}-{embed_dim:d}-{lr:f}-{batch_size:d}-'
        '{negtive_samples:d}-{lrdecay:f}-no_early_stop_steps-{steps:d}.emb'
    ).format(
        negtive_samples=FLAGS.negtive_samples,
        batch_size=FLAGS.batch_size,
        data=FLAGS.dataset,
        embed_dim=FLAGS.embedding_size,
        lr=FLAGS.learning_rate,
        lrdecay=FLAGS.learning_rate_decay_when_no_progress,
        steps=FLAGS.training_steps,
    )
    torch.save(model.state_dict(), os.path.join(embedding_dir, embedding_file))