Example #1
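A command-line entry point for the RecSys 2018 Challenge experiments: it parses the chosen recommender plus model paths and dispatches to the matching baseline, Word2Rec, or Title2Rec implementation, or dumps per-playlist title embeddings to a NumPy file.
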
import argparse

import numpy as np

# Project-local modules (import paths assumed from the identifiers used below).
import baseline
import sentence
import title2rec
from dataset import Dataset


def run():
    parser = argparse.ArgumentParser(description='RecSys 2018 Challenge')
    parser.add_argument('recommender',
                        type=str,
                        help='Recommendation algorithm')
    parser.add_argument('output', type=str, help='Output file')
    parser.add_argument('--dataset',
                        type=str,
                        default='dataset',
                        help='Dataset path')
    # --no-dry turns the default dry run into a real run
    parser.add_argument('--no-dry',
                        dest='dry',
                        action='store_false',
                        default=True,
                        help='Real run')
    parser.add_argument('--w2r',
                        type=str,
                        default='models/fast_text/w2r.bin',
                        help='Word2Rec model')
    parser.add_argument('--pl',
                        type=str,
                        default='models/fast_text/pl.bin',
                        help='Playlist embeddings')
    parser.add_argument('--ft',
                        type=str,
                        default='models/fast_text/ft.bin',
                        help='FastText model')
    parser.add_argument('--ft_vec',
                        type=str,
                        default='models/fast_text/ft_vec.bin',
                        help='FastText vector')
    parser.add_argument('--cluster',
                        type=str,
                        default='models/fast_text/cluster.bin',
                        help='Cluster model')
    args = parser.parse_args()

    dataset = Dataset(args.dataset)

    if args.recommender == 'mp':
        baseline.MostPopular(dataset, dry=args.dry).run(args.output)
    elif args.recommender == 'random':
        baseline.Random(dataset, dry=args.dry).run(args.output)
    elif args.recommender == 'random_mp':
        baseline.Random(dataset, dry=args.dry, weighted=True).run(args.output)
    elif args.recommender == 'word2rec_item':
        baseline.Word2Rec(dataset,
                          dry=args.dry,
                          model_file=args.w2r,
                          mode=sentence.Mode.ITEM).run(args.output)
    elif args.recommender == 'word2rec_album':
        baseline.Word2Rec(dataset,
                          dry=args.dry,
                          model_file=args.w2r,
                          mode=sentence.Mode.ALBUM).run(args.output)
    elif args.recommender == 'word2rec_artist':
        baseline.Word2Rec(dataset,
                          dry=args.dry,
                          model_file=args.w2r,
                          mode=sentence.Mode.ARTIST).run(args.output)
    elif args.recommender == 'title2rec':
        title2rec.Title2Rec(dataset,
                            dry=args.dry,
                            w2r_model_file=args.w2r,
                            pl_model_file=args.pl,
                            ft_model_file=args.ft,
                            ft_vec_file=args.ft_vec,
                            cluster_file=args.cluster).run(args.output)
    elif args.recommender == 'wordplustitle2rec':
        title2rec.WordPlusTitle2Rec(dataset,
                                    dry=args.dry,
                                    w2r_model_file=args.w2r,
                                    pl_model_file=args.pl,
                                    ft_model_file=args.ft,
                                    ft_vec_file=args.ft_vec,
                                    cluster_file=args.cluster).run(args.output)
    elif args.recommender == 'title2rec_embs':
        t2r = title2rec.Title2Rec(rnn=True, ft_model_file=args.ft)
        # one 100-dimensional title embedding per playlist; pids are shifted
        # by one, presumably reserving row 0 (e.g. for padding)
        embeddings = np.zeros((1049362, 100), dtype=np.float32)
        for playlist in dataset.reader('playlists.csv', 'items.csv'):
            pid = int(playlist['pid']) + 1
            embeddings[pid] = t2r.get_title_vector_from_playlist(playlist)
        for playlist in dataset.reader('playlists_challenge.csv',
                                       'items_challenge.csv'):
            pid = int(playlist['pid']) + 1
            embeddings[pid] = t2r.get_title_vector_from_playlist(playlist)
        np.save(args.output, embeddings)
    else:
        print('Unknown recommender', args.recommender)
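
Since the dry-run flag above is easy to misread, here is a minimal, self-contained sketch (the script name run.py is hypothetical) showing that --no-dry stores False into args.dry, i.e. every run is a dry run unless the flag is passed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('recommender')
parser.add_argument('output')
parser.add_argument('--no-dry', dest='dry', action='store_false', default=True)

# e.g. `python run.py mp recs.csv` vs `python run.py mp recs.csv --no-dry`
print(parser.parse_args(['mp', 'recs.csv']).dry)              # True  -> dry run
print(parser.parse_args(['mp', 'recs.csv', '--no-dry']).dry)  # False -> real run
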
Example #2
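The RNN entry point, adapted from TensorFlow's PTB language-model tutorial: it builds train/validation/test copies of PTBModel over shared variables, optionally injects pretrained track and playlist-title embeddings, and then either samples playlist continuations or trains with a validation-driven learning-rate schedule.
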
import numpy as np
import tensorflow as tf

# Project-local imports (paths assumed from the identifiers used below).
# FLAGS, get_config, get_items_embeddings, run_epoch, do_rank, do_summed_rank,
# do_sample and PTBModel are expected to be defined elsewhere in this module.
import reader
from dataset import Dataset
from title2rec import Title2Rec


def main(_):
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to the dataset directory")

    dataset = Dataset(FLAGS.data_path)

    # the number of tracks plus <eos>
    vocab_size = len(dataset.tracks_uri2id) + 1

    config = get_config()
    config.vocab_size = vocab_size
    eval_config = get_config()
    eval_config.vocab_size = vocab_size
    # batch_size = num_steps = 1 so the test model can be fed one track at a
    # time when sampling
    eval_config.batch_size = 1
    eval_config.num_steps = 1

    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        with tf.name_scope("Train"):

            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                m = PTBModel(is_training=True, config=config)
                tf.summary.scalar("Training Loss", m.cost)
                tf.summary.scalar("Learning Rate", m.lr)

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                mvalid = PTBModel(is_training=False, config=config)
                tf.summary.scalar("Validation Loss", mvalid.cost)

        with tf.name_scope("Test"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                mtest = PTBModel(is_training=False, config=eval_config)

        saver = tf.train.Saver(name='saver',
                               write_version=tf.train.SaverDef.V2)
        sv = tf.train.Supervisor(logdir=FLAGS.save_path,
                                 save_model_secs=0,
                                 save_summaries_secs=0,
                                 saver=saver)

        old_valid_perplexity = 1e10  # best validation perplexity seen so far

        # sessconfig = tf.ConfigProto(allow_soft_placement=True)
        # sessconfig.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        with sv.managed_session() as session:
            if FLAGS.restore_path is not None:
                saver.restore(session,
                              tf.train.latest_checkpoint(FLAGS.restore_path))

            if FLAGS.embs is not None:
                items_embeddings = get_items_embeddings(
                    vocab_size, dataset, FLAGS.lyrics)
                m.assign_items_embeddings(session, items_embeddings)
                mvalid.assign_items_embeddings(session, items_embeddings)
                mtest.assign_items_embeddings(session, items_embeddings)

                if FLAGS.title_embs is not None:
                    playlists_embeddings = np.load(FLAGS.title_embs)
                    m.assign_playlists_embeddings(session,
                                                  playlists_embeddings)
                    mvalid.assign_playlists_embeddings(session,
                                                       playlists_embeddings)
                    mtest.assign_playlists_embeddings(session,
                                                      playlists_embeddings)

            if FLAGS.sample_file is not None:

                fallback = Title2Rec(dataset,
                                     dry=FLAGS.is_dry,
                                     ft_model_file='models/fast_text/ft.bin',
                                     ft_vec_file='models/fast_text/ft_vec.bin')
                writer = dataset.writer(FLAGS.sample_file)

                if FLAGS.is_dry:
                    dataset_reader = dataset.reader('playlists_test.csv',
                                                    'items_test_x.csv')
                else:
                    dataset_reader = dataset.reader('playlists_challenge.csv',
                                                    'items_challenge.csv')

                for i, playlist in enumerate(dataset_reader):
                    print('sampling playlist', i)

                    if len(playlist['items']) == 0:
                        # title-only playlist: no seed tracks to feed the
                        # RNN, so fall back to Title2Rec
                        fallback.recommend(playlist)
                    else:
                        # every strategy emits the 500 tracks the challenge
                        # requires per playlist
                        if FLAGS.strategy == "rank":
                            do_rank(session, mtest, playlist, 500)
                        elif FLAGS.strategy == "summed_rank":
                            do_summed_rank(session,
                                           mtest,
                                           playlist,
                                           500,
                                           smooth=FLAGS.smooth)
                        elif FLAGS.strategy == "sample":
                            do_sample(session, mtest, playlist, 500)
                        else:
                            raise RuntimeError("Unknown strategy " +
                                               FLAGS.strategy)

                    writer.write(playlist)

            else:

                train_data, valid_data = reader.ptb_raw_data(dataset)
                print('Distinct terms: %d' % vocab_size)

                extra_decay = 1.0  # extra LR decay applied after roll-backs

                for i in range(config.max_max_epoch):

                    lr_decay = config.lr_decay**max(i - config.max_epoch, 0)
                    lr_decay *= extra_decay
                    m.assign_lr(session, config.learning_rate * lr_decay)
                    print("Epoch: %d Learning rate: %.3f" %
                          (i + 1, session.run(m.lr)))
                    train_perplexity = run_epoch(session,
                                                 m,
                                                 train_data,
                                                 is_train=True,
                                                 verbose=True,
                                                 sv=sv,
                                                 epoch=i)
                    print("Epoch: %d Train Perplexity: %.3f" %
                          (i + 1, train_perplexity))
                    valid_perplexity = run_epoch(session, mvalid, valid_data)
                    print("Epoch: %d Valid Perplexity: %.3f" %
                          (i + 1, valid_perplexity))
                    # Validation-driven schedule: save on improvement; stop
                    # once perplexity exceeds the best seen by 30%; otherwise
                    # roll back to the last checkpoint and halve the LR.
                    if valid_perplexity < old_valid_perplexity:
                        old_valid_perplexity = valid_perplexity
                        sv.saver.save(session, FLAGS.save_path, (i + 1) * 100)
                    elif valid_perplexity >= 1.3 * old_valid_perplexity:
                        if len(sv.saver.last_checkpoints) > 0:
                            sv.saver.restore(session,
                                             sv.saver.last_checkpoints[-1])
                        break
                    else:
                        if len(sv.saver.last_checkpoints) > 0:
                            sv.saver.restore(session,
                                             sv.saver.last_checkpoints[-1])
                        # the original halved lr_decay here, which had no
                        # effect (it is recomputed each epoch); carry the
                        # halving in extra_decay instead
                        extra_decay *= 0.5
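
The save/stop/halve schedule at the end is easiest to see in isolation. A minimal sketch, with illustrative names not taken from the original module:

def schedule_step(valid_ppl, best_ppl, tolerance=1.3):
    """Return (keep_training, new_best, halve_lr) for one epoch's result."""
    if valid_ppl < best_ppl:
        return True, valid_ppl, False   # improved: save checkpoint, continue
    if valid_ppl >= tolerance * best_ppl:
        return False, best_ppl, False   # diverged: roll back and stop
    return True, best_ppl, True         # tolerable: roll back, halve the LR

assert schedule_step(90.0, 100.0) == (True, 90.0, False)
assert schedule_step(150.0, 100.0) == (False, 100.0, False)
assert schedule_step(110.0, 100.0) == (True, 100.0, True)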