def train(raw_data=FLAGS.raw_data):
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.device_log)) as sess:
        run_options = None
        run_metadata = None
        if FLAGS.profile:
            # Needed to write a Chrome trace at the first checkpoint.
            from tensorflow.python.client import timeline
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            FLAGS.steps_per_checkpoint = 30

        mylog("reading data")
        (seq_tr, items_dev, data_tr, data_va, u_attributes, i_attributes,
         item_ind2logit_ind, logit_ind2item_ind, end_ind, _,
         _) = get_data(raw_data, data_dir=FLAGS.data_dir)

        power = FLAGS.power
        item_pop, p_item = item_frequency(data_tr, power)

        if FLAGS.use_more_train:
            item_population = list(range(len(item_ind2logit_ind)))
        else:
            item_population = item_pop

        model = create_model(sess, u_attributes, i_attributes,
                             item_ind2logit_ind, logit_ind2item_ind,
                             loss=FLAGS.loss, ind_item=item_population)

        # Data iterators: CBOW consumes FLAGS.ni context items per target,
        # skip-gram consumes FLAGS.num_skips.
        n_skips = FLAGS.ni if FLAGS.model == 'cbow' else FLAGS.num_skips
        dite = DataIterator(seq_tr, end_ind, FLAGS.batch_size, n_skips,
                            FLAGS.skip_window, False)
        if FLAGS.model == 'sg':
            ite = dite.get_next_sg()
        else:
            ite = dite.get_next_cbow()

        mylog('started training')
        step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0
        repeat = 5 if FLAGS.loss.startswith('bpr') else 1  # not used in this loop
        patience = FLAGS.patience

        # Resume training statistics if a previous run left them behind.
        if os.path.isfile(os.path.join(FLAGS.train_dir, 'auc_train.npy')):
            auc_train = list(np.load(os.path.join(FLAGS.train_dir,
                                                  'auc_train.npy')))
            auc_dev = list(np.load(os.path.join(FLAGS.train_dir,
                                                'auc_dev.npy')))
            previous_losses = list(np.load(os.path.join(FLAGS.train_dir,
                                                        'loss_train.npy')))
            losses_dev = list(np.load(os.path.join(FLAGS.train_dir,
                                                   'loss_dev.npy')))
            best_auc = max(auc_dev)
            best_loss = min(losses_dev)
        else:
            previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
            best_auc, best_loss = -1, 1000000

        item_sampled, item_sampled_id2idx = None, None

        train_total_size = float(len(data_tr))
        n_epoch = FLAGS.n_epoch
        steps_per_epoch = int(1.0 * train_total_size / FLAGS.batch_size)
        total_steps = steps_per_epoch * n_epoch

        mylog("Train:")
        mylog("total: {}".format(train_total_size))
        mylog("Steps_per_epoch: {}".format(steps_per_epoch))
        mylog("Total_steps: {}".format(total_steps))
        mylog("Dev:")
        mylog("total: {}".format(len(data_va)))

        while True:
            start_time = time.time()

            # Generate a batch of training examples.
            (user_input, input_items, output_items) = next(ite)
            if current_step < 5:
                mylog("current step is {}".format(current_step))
                mylog('user')
                mylog(user_input)
                mylog('input_item')
                mylog(input_items)
                mylog('output_item')
                mylog(output_items)

            # Periodically resample the candidate items for sampled losses.
            if (FLAGS.loss in ['mw', 'mce'] and
                    current_step % FLAGS.n_resample == 0):
                item_sampled, item_sampled_id2idx = sample_items(
                    item_population, FLAGS.n_sampled, p_item)
            else:
                item_sampled = None

            step_loss = model.step(sess, user_input, input_items,
                                   output_items, item_sampled,
                                   item_sampled_id2idx, loss=FLAGS.loss,
                                   run_op=run_options, run_meta=run_metadata)

            step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            if current_step > total_steps:
                mylog("Training reaches maximum steps. Terminating...")
                break

            if current_step % FLAGS.steps_per_checkpoint == 0:
                if FLAGS.loss in ['ce', 'mce']:
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    mylog("global step %d learning rate %.4f step-time %.4f "
                          "perplexity %.2f" %
                          (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, perplexity))
                else:
                    mylog("global step %d learning rate %.4f step-time %.4f "
                          "loss %.3f" %
                          (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, loss))

                if FLAGS.profile:
                    # Create the Timeline object and write it to a json file.
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline.json', 'w') as f:
                        f.write(ctf)
                    sys.exit()

                # Decrease the learning rate if no improvement was seen over
                # the last 3 checkpoints.
                if (len(previous_losses) > 2 and
                        loss > max(previous_losses[-3:])):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                auc_train.append(auc)

                # Reset timer and loss.
                step_time, loss, auc = 0.0, 0.0, 0.0

                if not FLAGS.eval:
                    continue

                # Run evals on the development set and report loss/auc.
                l_va = len(data_va)
                eval_loss, eval_auc = 0.0, 0.0
                count_va = 0
                start_time = time.time()
                for idx_s in range(0, l_va, FLAGS.batch_size):
                    idx_e = idx_s + FLAGS.batch_size
                    if idx_e > l_va:
                        break
                    lt = data_va[idx_s:idx_e]
                    user_va = [x[0] for x in lt]
                    item_va_input = [items_dev[x[0]] for x in lt]
                    # Transpose to feature-major layout; list() keeps this
                    # working under Python 3, where map() is lazy.
                    item_va_input = list(map(list, zip(*item_va_input)))
                    item_va = [x[1] for x in lt]

                    the_loss = 'warp' if FLAGS.loss == 'mw' else FLAGS.loss
                    eval_loss0 = model.step(sess, user_va, item_va_input,
                                            item_va, forward_only=True,
                                            loss=the_loss)
                    eval_loss += eval_loss0
                    count_va += 1
                eval_loss /= count_va
                eval_auc /= count_va  # auc is not computed here; stays 0.
                step_time = (time.time() - start_time) / count_va

                if FLAGS.loss in ['ce', 'mce']:
                    eval_ppx = (math.exp(eval_loss)
                                if eval_loss < 300 else float('inf'))
                    mylog("  dev: perplexity %.2f eval_auc %.4f "
                          "step-time %.4f" % (eval_ppx, eval_auc, step_time))
                else:
                    mylog("  dev: loss %.3f eval_auc %.4f step-time %.4f" %
                          (eval_loss, eval_auc, step_time))
                sys.stdout.flush()

                if eval_loss < best_loss and not FLAGS.test:
                    best_loss = eval_loss
                    patience = FLAGS.patience
                    # Save a checkpoint of the best model so far.
                    checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
                    model.saver.save(sess, checkpoint_path, global_step=0,
                                     write_meta_graph=False)
                    mylog('Saving best model...')

                if FLAGS.test:
                    checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
                    model.saver.save(sess, checkpoint_path, global_step=0,
                                     write_meta_graph=False)
                    mylog('Saving current model...')

                if eval_loss > best_loss:  # and eval_auc < best_auc:
                    patience -= 1

                auc_dev.append(eval_auc)
                losses_dev.append(eval_loss)

                if patience < 0 and not FLAGS.test:
                    mylog("no improvement for too long.. terminating..")
                    mylog("best auc %.4f" % best_auc)
                    mylog("best loss %.4f" % best_loss)
                    sys.stdout.flush()
                    break
    return
def get_data(raw_data, data_dir=FLAGS.data_dir, combine_att=FLAGS.combine_att,
             logits_size_tr=FLAGS.item_vocab_size,
             thresh=FLAGS.vocab_min_thresh,
             use_user_feature=FLAGS.use_user_feature, test=FLAGS.test,
             mylog=mylog, use_item_feature=FLAGS.use_item_feature,
             recommend=False):

    (data_tr, data_va, u_attr, i_attr, item_ind2logit_ind,
     logit_ind2item_ind, user_index, item_index) = read_attributed_data(
         raw_data_dir=raw_data, data_dir=data_dir, combine_att=combine_att,
         logits_size_tr=logits_size_tr, thresh=thresh,
         use_user_feature=use_user_feature,
         use_item_feature=use_item_feature, test=test, mylog=mylog)

    # Remove records whose item is unknown (not in the training vocabulary).
    data_tr = [p for p in data_tr if (p[1] in item_ind2logit_ind)]

    # Optionally remove interactions before week 40.
    if FLAGS.after40:
        data_tr = [p for p in data_tr if (to_week(p[2]) >= 40)]

    # Item frequency (for sampling).
    item_population, p_item = item_frequency(data_tr, FLAGS.power)

    # UNK and START: reserve the index right after the item vocabulary for
    # the START token; it maps to logit index 0.
    START_ID = len(item_index)
    item_ind2logit_ind[START_ID] = 0

    seq_all = form_sequence(data_tr, maxlen=FLAGS.L)
    seq_tr0, seq_va0 = split_train_dev(seq_all, ratio=0.05)

    # Calculate buckets.
    global _buckets
    _buckets = calculate_buckets(seq_tr0 + seq_va0, FLAGS.L, FLAGS.n_bucket)
    _buckets = sorted(_buckets)

    # Assign sequences to buckets.
    seq_tr = split_buckets(seq_tr0, _buckets)
    seq_va = split_buckets(seq_va0, _buckets)

    # Get test data.
    if recommend:
        from evaluate import Evaluation as Evaluate
        evaluation = Evaluate(raw_data, test=test)
        uids = evaluation.get_uinds()  # 'uids' actually holds user indices (uinds)
        seq_test = form_sequence_prediction(seq_all, uids, FLAGS.L, START_ID)
        _buckets = calculate_buckets(seq_test, FLAGS.L, FLAGS.n_bucket)
        _buckets = sorted(_buckets)
        seq_test = split_buckets(seq_test, _buckets)
    else:
        seq_test = []
        evaluation = None
        uids = []

    # Create the embedded-attribute object.
    devices = get_device_address(FLAGS.N)
    with tf.device(devices[0]):
        u_attr.set_model_size(FLAGS.size)
        i_attr.set_model_size(FLAGS.size)

        embAttr = embed_attribute.EmbeddingAttribute(
            u_attr, i_attr, FLAGS.batch_size, FLAGS.n_sampled, _buckets[-1],
            FLAGS.use_sep_item, item_ind2logit_ind, logit_ind2item_ind,
            devices=devices)

        if FLAGS.loss in ["warp", 'mw']:
            prepare_warp(embAttr, seq_tr0, seq_va0)

    return (seq_tr, seq_va, seq_test, embAttr, START_ID, item_population,
            p_item, evaluation, uids, user_index, item_index,
            logit_ind2item_ind)
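# ---------------------------------------------------------------------------
# Reference sketch (not part of the original module): calculate_buckets() and
# split_buckets() are imported helpers. The reconstruction below shows the
# usual TF1-era bucketing scheme under assumed signatures: pick up to
# n_bucket cutoff lengths (capped at max_len), then file each sequence into
# the smallest bucket that fits, so batches pad to a bucket boundary rather
# than to the global maximum length.
def calculate_buckets_sketch(sequences, max_len, n_bucket):
    # Place bucket boundaries at (roughly) evenly spaced length quantiles.
    lengths = sorted(min(len(s), max_len) for s in sequences)
    cutoffs = set()
    for k in range(1, n_bucket + 1):
        idx = max(0, k * len(lengths) // n_bucket - 1)
        cutoffs.add(lengths[idx])
    cutoffs.add(max_len)  # ensure the longest sequences fit somewhere
    return sorted(cutoffs)

def split_buckets_sketch(sequences, buckets):
    # Assign each sequence to the smallest bucket that can hold it. Keeping
    # the most recent items when truncating is an assumption.
    split = [[] for _ in buckets]
    for s in sequences:
        s = s[-buckets[-1]:]
        for b, size in enumerate(buckets):
            if len(s) <= size:
                split[b].append(s)
                break
    return split
# ---------------------------------------------------------------------------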
def train(raw_data=FLAGS.raw_data, train_dir=FLAGS.train_dir, mylog=mylog,
          data_dir=FLAGS.data_dir, combine_att=FLAGS.combine_att,
          test=FLAGS.test, logits_size_tr=FLAGS.item_vocab_size,
          thresh=FLAGS.item_vocab_min_thresh,
          use_user_feature=FLAGS.use_user_feature,
          use_item_feature=FLAGS.use_item_feature,
          batch_size=FLAGS.batch_size,
          steps_per_checkpoint=FLAGS.steps_per_checkpoint,
          loss_func=FLAGS.loss, max_patience=FLAGS.patience,
          go_test=FLAGS.test, max_epoch=FLAGS.n_epoch,
          sample_type=FLAGS.sample_type, power=FLAGS.power,
          use_more_train=FLAGS.use_more_train, profile=FLAGS.profile,
          device_log=FLAGS.device_log):
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=device_log)) as sess:
        run_options = None
        run_metadata = None
        if profile:
            # Needed in order to profile.
            from tensorflow.python.client import timeline
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            steps_per_checkpoint = 30

        mylog("reading data")
        (data_tr, data_va, u_attributes, i_attributes, item_ind2logit_ind,
         logit_ind2item_ind, _, _) = read_data(
             raw_data_dir=raw_data, data_dir=data_dir,
             combine_att=combine_att, logits_size_tr=logits_size_tr,
             thresh=thresh, use_user_feature=use_user_feature,
             use_item_feature=use_item_feature, test=test, mylog=mylog)

        # Remove rare items from both the train and valid sets; this makes
        # the train/valid distributions more similar to each other.
        mylog("original train/dev size: %d/%d" % (len(data_tr), len(data_va)))
        data_tr = [p for p in data_tr if (p[1] in item_ind2logit_ind)]
        data_va = [p for p in data_va if (p[1] in item_ind2logit_ind)]
        mylog("new train/dev size: %d/%d" % (len(data_tr), len(data_va)))

        item_pop, p_item = item_frequency(data_tr, power)

        if use_more_train:
            item_population = list(range(len(item_ind2logit_ind)))
        else:
            item_population = item_pop

        model = create_model(sess, u_attributes, i_attributes,
                             item_ind2logit_ind, logit_ind2item_ind,
                             loss=loss_func, ind_item=item_population)

        pos_item_list, pos_item_list_val = None, None
        if loss_func in ['warp', 'mw', 'rs', 'rs-sig', 'bbpr']:
            pos_item_list, pos_item_list_val = positive_items(data_tr,
                                                              data_va)
            model.prepare_warp(pos_item_list, pos_item_list_val)

        mylog('started training')
        step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0
        # bpr-style losses average the dev loss over 5 passes per batch.
        repeat = 5 if loss_func.startswith('bpr') else 1
        patience = max_patience

        # Resume training statistics if a previous run left them behind.
        if os.path.isfile(os.path.join(train_dir, 'auc_train.npy')):
            auc_train = list(np.load(os.path.join(train_dir,
                                                  'auc_train.npy')))
            auc_dev = list(np.load(os.path.join(train_dir, 'auc_dev.npy')))
            previous_losses = list(np.load(os.path.join(train_dir,
                                                        'loss_train.npy')))
            losses_dev = list(np.load(os.path.join(train_dir,
                                                   'loss_dev.npy')))
            best_auc = max(auc_dev)
            best_loss = min(losses_dev)
        else:
            previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
            best_auc, best_loss = -1, 1000000

        item_sampled, item_sampled_id2idx = None, None

        if sample_type == 'random':
            get_next_batch = model.get_batch
        elif sample_type == 'permute':
            get_next_batch = model.get_permuted_batch
        else:
            mylog('not implemented!')
            sys.exit()

        train_total_size = float(len(data_tr))
        n_epoch = max_epoch
        steps_per_epoch = int(1.0 * train_total_size / batch_size)
        total_steps = steps_per_epoch * n_epoch

        mylog("Train:")
        mylog("total: {}".format(train_total_size))
        mylog("Steps_per_epoch: {}".format(steps_per_epoch))
        mylog("Total_steps: {}".format(total_steps))
        mylog("Dev:")
        mylog("total: {}".format(len(data_va)))
        mylog("\n\ntraining start!")

        while True:
            start_time = time.time()
            (user_input, item_input, neg_item_input) = get_next_batch(data_tr)

            # Periodically resample the candidate items for sampled losses.
            if (loss_func in ['mw', 'mce'] and
                    current_step % FLAGS.n_resample == 0):
                item_sampled, item_sampled_id2idx = sample_items(
                    item_population, FLAGS.n_sampled, p_item)
            else:
                item_sampled = None

            step_loss = model.step(sess, user_input, item_input,
                                   neg_item_input, item_sampled,
                                   item_sampled_id2idx, loss=loss_func,
                                   run_op=run_options, run_meta=run_metadata)

            step_time += (time.time() - start_time) / steps_per_checkpoint
            loss += step_loss / steps_per_checkpoint
            current_step += 1

            if current_step > total_steps:
                mylog("Training reaches maximum steps. Terminating...")
                break

            if current_step % steps_per_checkpoint == 0:
                if loss_func in ['ce', 'mce']:
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    mylog("global step %d learning rate %.4f step-time %.4f "
                          "perplexity %.2f" %
                          (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, perplexity))
                else:
                    mylog("global step %d learning rate %.4f step-time %.4f "
                          "loss %.3f" %
                          (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, loss))

                if profile:
                    # Create the Timeline object and write it to a json file.
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline.json', 'w') as f:
                        f.write(ctf)
                    sys.exit()

                # Decrease the learning rate if no improvement was seen over
                # the last 3 checkpoints.
                if (len(previous_losses) > 2 and
                        loss > max(previous_losses[-3:])):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                auc_train.append(auc)

                # Reset timer and loss.
                step_time, loss, auc = 0.0, 0.0, 0.0

                if not FLAGS.eval:
                    continue

                # Run evals on the development set and report the loss.
                l_va = len(data_va)
                eval_loss, eval_auc = 0.0, 0.0
                count_va = 0
                start_time = time.time()
                for idx_s in range(0, l_va, batch_size):
                    idx_e = idx_s + batch_size
                    if idx_e > l_va:
                        break
                    lt = data_va[idx_s:idx_e]
                    user_va = [x[0] for x in lt]
                    item_va = [x[1] for x in lt]
                    for _ in range(repeat):
                        item_va_neg = None
                        the_loss = 'warp' if loss_func == 'mw' else loss_func
                        eval_loss0 = model.step(sess, user_va, item_va,
                                                item_va_neg, None, None,
                                                forward_only=True,
                                                loss=the_loss)
                        eval_loss += eval_loss0
                        count_va += 1
                eval_loss /= count_va
                eval_auc /= count_va  # auc is not computed; stays 0.
                step_time = (time.time() - start_time) / count_va

                if loss_func in ['ce', 'mce']:
                    eval_ppx = (math.exp(eval_loss)
                                if eval_loss < 300 else float('inf'))
                    mylog("  dev: perplexity %.2f eval_auc(not computed) "
                          "%.4f step-time %.4f" %
                          (eval_ppx, eval_auc, step_time))
                else:
                    mylog("  dev: loss %.3f eval_auc(not computed) %.4f "
                          "step-time %.4f" % (eval_loss, eval_auc, step_time))
                sys.stdout.flush()

                if eval_loss < best_loss and not go_test:
                    best_loss = eval_loss
                    patience = max_patience
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess, checkpoint_path, global_step=0,
                                     write_meta_graph=False)

                if go_test:
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess, checkpoint_path, global_step=0,
                                     write_meta_graph=False)

                if eval_loss > best_loss:
                    patience -= 1

                auc_dev.append(eval_auc)
                losses_dev.append(eval_loss)

                if patience < 0 and not go_test:
                    mylog("no improvement for too long.. terminating..")
                    mylog("best loss %.4f" % best_loss)
                    sys.stdout.flush()
                    break
    return
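# ---------------------------------------------------------------------------
# Reference sketch (not part of the original module): both train() loops call
# sess.run(model.learning_rate_decay_op) when the running loss fails to beat
# the worst of the last three checkpoints. The op itself is built inside the
# model class (via create_model); in TF1-style code it is typically just an
# assign that scales the rate by a fixed factor. The class name and the
# default values below are illustrative assumptions.
import tensorflow as tf

class LearningRateDecaySketch(object):
    def __init__(self, initial_lr=0.1, decay_factor=0.5):
        self.learning_rate = tf.Variable(float(initial_lr), trainable=False,
                                         dtype=tf.float32)
        # Running this op once multiplies the current rate by decay_factor,
        # e.g. sess.run(model.learning_rate_decay_op) as in the loops above.
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * decay_factor)
# ---------------------------------------------------------------------------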