def get_user_items_seq(data):
    """Group interactions by user and order each user's items by time.

    Args:
        data: iterable of ``(user, item, timestamp)`` triples.

    Returns:
        dict mapping each user to the list of items they interacted with,
        in ascending timestamp order (equal timestamps keep their input
        order because the sort is stable). When ``FLAGS.after40`` is set,
        triples whose ``to_week(timestamp)`` is below 40 are skipped.
    """
    by_user = {}
    for user, item, stamp in data:
        # Optionally discard interactions that fall before week 40.
        if FLAGS.after40 and to_week(stamp) < 40:
            continue
        by_user.setdefault(user, []).append((item, stamp))

    for user, events in by_user.items():
        # Reassigning existing keys does not resize the dict, so this is
        # safe while iterating.
        events.sort(key=lambda pair: pair[1])
        items = [item for item, _ in events]
        assert len(items) > 0
        by_user[user] = items
    return by_user
def get_data(raw_data, data_dir=FLAGS.data_dir, combine_att=FLAGS.combine_att, logits_size_tr=FLAGS.item_vocab_size, thresh=FLAGS.vocab_min_thresh, use_user_feature=FLAGS.use_user_feature, test=FLAGS.test, mylog=mylog, use_item_feature=FLAGS.use_item_feature, recommend=False): (data_tr, data_va, u_attr, i_attr, item_ind2logit_ind, logit_ind2item_ind, user_index, item_index) = read_attributed_data(raw_data_dir=raw_data, data_dir=data_dir, combine_att=combine_att, logits_size_tr=logits_size_tr, thresh=thresh, use_user_feature=use_user_feature, use_item_feature=use_item_feature, test=test, mylog=mylog) # remove unk data_tr = [p for p in data_tr if (p[1] in item_ind2logit_ind)] # remove items before week 40 if FLAGS.after40: data_tr = [p for p in data_tr if (to_week(p[2]) >= 40)] # item frequency (for sampling) item_population, p_item = item_frequency(data_tr, FLAGS.power) # UNK and START # print(len(item_ind2logit_ind)) # print(len(logit_ind2item_ind)) # print(len(item_index)) START_ID = len(item_index) # START_ID = i_attr.get_item_last_index() item_ind2logit_ind[START_ID] = 0 seq_all = form_sequence(data_tr, maxlen=FLAGS.L) seq_tr0, seq_va0 = split_train_dev(seq_all, ratio=0.05) # calculate buckets global _buckets _buckets = calculate_buckets(seq_tr0 + seq_va0, FLAGS.L, FLAGS.n_bucket) _buckets = sorted(_buckets) # split_buckets seq_tr = split_buckets(seq_tr0, _buckets) seq_va = split_buckets(seq_va0, _buckets) # get test data if recommend: from evaluate import Evaluation as Evaluate evaluation = Evaluate(raw_data, test=test) uids = evaluation.get_uinds() # abuse of 'uids' : actually uinds seq_test = form_sequence_prediction(seq_all, uids, FLAGS.L, START_ID) _buckets = calculate_buckets(seq_test, FLAGS.L, FLAGS.n_bucket) _buckets = sorted(_buckets) seq_test = split_buckets(seq_test, _buckets) else: seq_test = [] evaluation = None uids = [] # create embedAttr devices = get_device_address(FLAGS.N) with tf.device(devices[0]): u_attr.set_model_size(FLAGS.size) 
i_attr.set_model_size(FLAGS.size) # if not FLAGS.use_item_feature: # mylog("NOT using item attributes") # i_attr.num_features_cat = 1 # i_attr.num_features_mulhot = 0 # if not FLAGS.use_user_feature: # mylog("NOT using user attributes") # u_attr.num_features_cat = 1 # u_attr.num_features_mulhot = 0 embAttr = embed_attribute.EmbeddingAttribute(u_attr, i_attr, FLAGS.batch_size, FLAGS.n_sampled, _buckets[-1], FLAGS.use_sep_item, item_ind2logit_ind, logit_ind2item_ind, devices=devices) if FLAGS.loss in ["warp", 'mw']: prepare_warp(embAttr, seq_tr0, seq_va0) return seq_tr, seq_va, seq_test, embAttr, START_ID, item_population, p_item, evaluation, uids, user_index, item_index, logit_ind2item_ind