def f_load_data(self, args):
        """Load the pickled review corpus and the JSON vocabulary, then build datasets.

        Reads ``args.data_file`` (a pickle) and ``vocab.json`` from
        ``args.data_dir``.  The unpickled object is indexed with the keys
        ``'review'`` (which in turn is indexed with ``'train'`` and
        ``'valid'``), ``'user'`` and ``'item'``.  The vocabulary JSON is
        indexed with ``'w2i'`` and ``'i2w'``.

        Side effects: sets ``self.m_data_name`` and ``self.m_vocab_file`` and
        prints the data directory.

        Args:
            args: Run configuration.  ``data_name``, ``data_dir`` and
                ``data_file`` are read here; the whole object is also passed
                on to the dataset constructors.

        Returns:
            A ``(train_data, valid_data, vocab_obj)`` tuple holding a
            ``_MOVIE`` instance, a ``_MOVIE_TEST`` instance and the populated
            ``_Vocab`` instance.
        """
        self.m_data_name = args.data_name
        self.m_vocab_file = "vocab.json"
        print("data_dir", args.data_dir)

        data_path = os.path.join(args.data_dir, args.data_file)
        vocab_path = os.path.join(args.data_dir, self.m_vocab_file)

        # NOTE: unpickling can execute arbitrary code; only point this at
        # data files from a trusted source.
        with open(data_path, 'rb') as file:
            data = pickle.load(file)

        # Explicit UTF-8 so decoding the vocabulary does not depend on the
        # platform's default locale encoding.
        with open(vocab_path, 'r', encoding='utf8') as file:
            vocab = json.load(file)

        review_corpus = data['review']

        vocab_obj = _Vocab()
        vocab_obj.f_set_vocab(vocab['w2i'], vocab['i2w'])
        vocab_obj.f_set_user(data['user'])
        vocab_obj.f_set_item(data['item'])

        train_data = _MOVIE(args, vocab_obj, review_corpus['train'])
        valid_data = _MOVIE_TEST(args, vocab_obj, review_corpus['valid'])

        return train_data, valid_data, vocab_obj
예제 #2
0
    def f_load_data_movie(self, args):
        """Load the movie train/valid frames and vocabulary, and build data loaders.

        Reads ``new_train.pickle`` and ``new_valid.pickle`` from
        ``args.data_dir`` with ``pd.read_pickle``, plus three JSON files named
        by ``args.vocab_file``, ``args.item_boa_file`` and
        ``args.user_boa_file``.  The vocabulary JSON is indexed with the keys
        ``'t2i'``, ``'i2t'``, ``'user_index'`` and ``'item_index'``.

        Side effects: sets ``self.m_data_name``, ``self.m_vocab_file``,
        ``self.m_item_boa_file`` and ``self.m_user_boa_file``; prints the
        number of distinct ``userid`` values in the training frame and the
        vocabulary size.

        Args:
            args: Run configuration.  Besides the file names above,
                ``data_name``, ``data_dir``, ``batch_size`` and ``parallel``
                are read here; the whole object is also passed on to the
                dataset constructors.

        Returns:
            A ``(train_loader, test_loader, vocab_obj)`` tuple.  Note that
            ``test_loader`` iterates the validation split.
        """
        self.m_data_name = args.data_name
        self.m_vocab_file = args.vocab_file
        self.m_item_boa_file = args.item_boa_file
        self.m_user_boa_file = args.user_boa_file

        def _read_json(file_name):
            # Read one UTF-8 encoded JSON file located in the data directory.
            with open(os.path.join(args.data_dir, file_name), 'r', encoding='utf8') as f:
                return json.load(f)

        # NOTE: unpickling can execute arbitrary code; only point this at
        # data files from a trusted source.
        # Only the train and valid frames are loaded: the former "test" frame
        # was a second, unused copy of new_valid.pickle.
        train_df = pd.read_pickle(os.path.join(args.data_dir, "new_train.pickle"))
        valid_df = pd.read_pickle(os.path.join(args.data_dir, "new_valid.pickle"))

        user_num = train_df.userid.nunique()
        print("user num", user_num)

        vocab = _read_json(self.m_vocab_file)
        item_boa_dict = _read_json(self.m_item_boa_file)
        user_boa_dict = _read_json(self.m_user_boa_file)

        vocab_obj = _Vocab()
        vocab_obj.f_set_vocab(vocab['t2i'], vocab['i2t'])
        vocab_obj.f_set_user_num(user_num)
        vocab_obj.f_set_user(vocab['user_index'])
        vocab_obj.f_set_item(vocab['item_index'])

        print("vocab size", vocab_obj.m_vocab_size)

        train_data = _MOVIE(args, vocab_obj, train_df, item_boa_dict, user_boa_dict)
        valid_data = _MOVIE_TEST(args, vocab_obj, valid_df, item_boa_dict, user_boa_dict)

        batch_size = args.batch_size

        if args.parallel:
            # A sampler and shuffle=True are mutually exclusive in DataLoader,
            # so the distributed branch passes only the sampler.
            # NOTE(review): the training loop presumably has to call
            # train_loader.sampler.set_epoch(epoch) each epoch - confirm.
            train_sampler = DistributedSampler(dataset=train_data)
            train_loader = DataLoader(
                dataset=train_data,
                batch_size=batch_size,
                sampler=train_sampler,
                num_workers=8,
                collate_fn=train_data.collate,
            )
        else:
            train_loader = DataLoader(
                dataset=train_data,
                batch_size=batch_size,
                shuffle=True,
                num_workers=8,
                collate_fn=train_data.collate,
            )

        test_loader = DataLoader(
            dataset=valid_data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            collate_fn=valid_data.collate,
        )

        return train_loader, test_loader, vocab_obj
예제 #3
0
    def f_load_data_yelp(self, args):
        """Load the yelp train/valid frames and vocabulary, and build data loaders.

        Reads ``train_100.pickle`` and ``valid.pickle`` from ``args.data_dir``
        with ``pd.read_pickle``, plus the JSON vocabulary named by
        ``args.vocab_file``.  The vocabulary JSON is indexed with the keys
        ``'t2i'``, ``'i2t'``, ``'u2i'`` and ``'i2i'``.

        ``test.pickle`` is not read: the frame loaded from it was never used.

        Side effects: sets ``self.m_data_name`` and ``self.m_vocab_file``;
        prints the number of distinct ``userid``, ``itemid`` and ``pos_tagid``
        values in the training frame.

        Args:
            args: Run configuration.  ``data_name``, ``data_dir``,
                ``vocab_file``, ``batch_size`` and ``parallel`` are read here;
                the whole object is also passed on to the dataset
                constructors.

        Returns:
            A ``(train_loader, test_loader, vocab_obj)`` tuple.  Note that
            ``test_loader`` iterates the validation split.
        """
        self.m_data_name = args.data_name

        # NOTE: unpickling can execute arbitrary code; only point this at
        # data files from a trusted source.
        train_df = pd.read_pickle(os.path.join(args.data_dir, "train_100.pickle"))
        valid_df = pd.read_pickle(os.path.join(args.data_dir, "valid.pickle"))

        self.m_vocab_file = args.vocab_file

        vocab_path = os.path.join(args.data_dir, self.m_vocab_file)
        with open(vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)

        vocab_obj = _Vocab()
        vocab_obj.f_set_vocab(vocab['t2i'], vocab['i2t'])
        vocab_obj.f_set_user(vocab['u2i'])
        vocab_obj.f_set_item(vocab['i2i'])

        # Report how many distinct users, items and positive tags the
        # training split contains.
        train_user_num = train_df.userid.nunique()
        print("train user num", train_user_num)

        train_item_num = train_df.itemid.nunique()
        print("train item num", train_item_num)

        train_pos_tag_num = train_df.pos_tagid.nunique()
        print("train tag num", train_pos_tag_num)

        train_data = _MOVIE(args, train_df)
        valid_data = _MOVIE_TEST(args, valid_df)

        batch_size = args.batch_size

        if args.parallel:
            # A sampler and shuffle=True are mutually exclusive in DataLoader,
            # so the distributed branch passes only the sampler.  This branch
            # loads batches in the main process (num_workers=0).
            train_sampler = DistributedSampler(dataset=train_data)
            train_loader = DataLoader(
                dataset=train_data,
                batch_size=batch_size,
                sampler=train_sampler,
                num_workers=0,
                collate_fn=train_data.collate,
            )
        else:
            train_loader = DataLoader(
                dataset=train_data,
                batch_size=batch_size,
                shuffle=True,
                num_workers=8,
                collate_fn=train_data.collate,
            )

        test_loader = DataLoader(
            dataset=valid_data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            collate_fn=valid_data.collate,
        )

        return train_loader, test_loader, vocab_obj