Example #1
0
def train(raw_data=FLAGS.raw_data):
  """Train a skip-gram / CBOW item model, checkpointing on dev-loss improvement.

  Data comes from ``get_data``. Batches come from a ``DataIterator``:
  skip-gram batches when ``FLAGS.model == 'sg'``, CBOW batches otherwise.

  Every ``FLAGS.steps_per_checkpoint`` steps:
    * the averaged training loss is logged (perplexity for the 'ce'/'mce'
      losses);
    * the learning rate is decayed when that loss exceeds the max of the
      previous three checkpoint losses;
    * if ``FLAGS.eval`` is set, the model is evaluated on full batches of the
      dev set and saved to ``<FLAGS.train_dir>/best.ckpt`` whenever the dev
      loss improves (or unconditionally when ``FLAGS.test`` is set).

  Training stops after ``steps_per_epoch * FLAGS.n_epoch`` steps. Outside test
  mode it also stops once ``patience`` drops below zero. ``patience`` is
  decremented on every evaluation worse than the best dev loss and reset when
  a new best is recorded.

  With ``FLAGS.profile``, a Chrome trace is written to ``timeline.json`` at the
  first checkpoint (step 30) and the process exits.

  Args:
    raw_data: raw data location, forwarded to ``get_data``.

  Raises:
    ValueError: if ``FLAGS.eval`` is set but the dev set holds fewer than
      ``FLAGS.batch_size`` examples, so no dev batch could be evaluated.
  """
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
    log_device_placement=FLAGS.device_log)) as sess:
    run_options = None
    run_metadata = None
    if FLAGS.profile:
      # Bind `timeline` here so the trace dump below cannot hit a NameError.
      from tensorflow.python.client import timeline
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      # Short run: the trace is written (and the process exits) at the first
      # checkpoint.
      FLAGS.steps_per_checkpoint = 30

    mylog("reading data")

    (seq_tr, items_dev, data_tr, data_va, u_attributes, i_attributes,
      item_ind2logit_ind, logit_ind2item_ind, end_ind, _, _) = get_data(raw_data,
      data_dir=FLAGS.data_dir)

    # Dev evaluation only uses full batches. With a smaller dev set every
    # evaluation would divide by a zero batch count, so fail fast instead.
    if FLAGS.eval and len(data_va) < FLAGS.batch_size:
      raise ValueError(
        "dev set has %d examples but batch_size is %d; nothing to evaluate"
        % (len(data_va), FLAGS.batch_size))

    item_pop, p_item = item_frequency(data_tr, FLAGS.power)

    # Pool passed to create_model / sample_items: every logit index, or the
    # population returned by item_frequency.
    if FLAGS.use_more_train:
      item_population = range(len(item_ind2logit_ind))
    else:
      item_population = item_pop

    model = create_model(sess, u_attributes, i_attributes, item_ind2logit_ind,
      logit_ind2item_ind, loss=FLAGS.loss, ind_item=item_population)

    # data iterators
    n_skips = FLAGS.ni if FLAGS.model == 'cbow' else FLAGS.num_skips
    dite = DataIterator(seq_tr, end_ind, FLAGS.batch_size, n_skips,
      FLAGS.skip_window, False)
    if FLAGS.model == 'sg':
      ite = dite.get_next_sg()
    else:
      ite = dite.get_next_cbow()

    mylog('started training')
    step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0
    patience = FLAGS.patience

    # Resume the metric history if a previous run left it in train_dir.
    if os.path.isfile(os.path.join(FLAGS.train_dir, 'auc_train.npy')):
      auc_train = list(np.load(os.path.join(FLAGS.train_dir, 'auc_train.npy')))
      auc_dev = list(np.load(os.path.join(FLAGS.train_dir, 'auc_dev.npy')))
      previous_losses = list(np.load(os.path.join(FLAGS.train_dir,
        'loss_train.npy')))
      losses_dev = list(np.load(os.path.join(FLAGS.train_dir, 'loss_dev.npy')))
      best_auc = max(auc_dev)
      best_loss = min(losses_dev)
    else:
      previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
      best_auc, best_loss = -1, 1000000

    item_sampled, item_sampled_id2idx = None, None

    train_total_size = float(len(data_tr))
    steps_per_epoch = int(train_total_size / FLAGS.batch_size)
    total_steps = steps_per_epoch * FLAGS.n_epoch

    mylog("Train:")
    mylog("total: {}".format(train_total_size))
    mylog("Steps_per_epoch: {}".format(steps_per_epoch))
    mylog("Total_steps:{}".format(total_steps))
    mylog("Dev:")
    mylog("total: {}".format(len(data_va)))

    while True:

      start_time = time.time()
      # Generate a training batch. next() works on Python 2 and 3 iterators
      # alike, unlike the Python-2-only ite.next().
      (user_input, input_items, output_items) = next(ite)
      if current_step < 5:
        mylog("current step is {}".format(current_step))
        mylog('user')
        mylog(user_input)
        mylog('input_item')
        mylog(input_items)
        mylog('output_item')
        mylog(output_items)

      # For 'mw'/'mce' losses, draw a fresh item sample every n_resample steps.
      if FLAGS.loss in ['mw', 'mce'] and current_step % FLAGS.n_resample == 0:
        item_sampled, item_sampled_id2idx = sample_items(item_population,
          FLAGS.n_sampled, p_item)
      else:
        item_sampled = None

      step_loss = model.step(sess, user_input, input_items,
        output_items, item_sampled, item_sampled_id2idx, loss=FLAGS.loss,
        run_op=run_options, run_meta=run_metadata)

      # Running averages over one checkpoint window.
      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
      loss += step_loss / FLAGS.steps_per_checkpoint
      current_step += 1
      if current_step > total_steps:
        mylog("Training reaches maximum steps. Terminating...")
        break

      if current_step % FLAGS.steps_per_checkpoint == 0:

        if FLAGS.loss in ['ce', 'mce']:
          # Clamp to avoid math.exp overflow on a diverged loss.
          perplexity = math.exp(loss) if loss < 300 else float('inf')
          mylog("global step %d learning rate %.4f step-time %.4f perplexity %.2f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity))
        else:
          mylog("global step %d learning rate %.4f step-time %.4f loss %.3f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, loss))
        if FLAGS.profile:
          # Create the Timeline object, and write it to a json
          tl = timeline.Timeline(run_metadata.step_stats)
          ctf = tl.generate_chrome_trace_format()
          with open('timeline.json', 'w') as f:
            f.write(ctf)
          sys.exit()

        # Decrease learning rate if no improvement was seen over last 3 times.
        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
          sess.run(model.learning_rate_decay_op)
        previous_losses.append(loss)
        auc_train.append(auc)
        step_time, loss, auc = 0.0, 0.0, 0.0

        if not FLAGS.eval:
          continue

        # Run evals on the development set (full batches only) and print
        # their loss. eval_auc is never accumulated here, so it logs as 0.
        l_va = len(data_va)
        eval_loss, eval_auc = 0.0, 0.0
        count_va = 0
        eval_start = time.time()
        the_loss = 'warp' if FLAGS.loss == 'mw' else FLAGS.loss
        for idx_s in range(0, l_va, FLAGS.batch_size):
          idx_e = idx_s + FLAGS.batch_size
          if idx_e > l_va:
            break
          lt = data_va[idx_s:idx_e]
          user_va = [x[0] for x in lt]
          # Transpose the per-example input lists (zip(*...)) into one list
          # per position. Build real lists: on Python 3, map() would hand the
          # model a one-shot lazy iterator.
          item_va_input = [list(col)
                           for col in zip(*[items_dev[x[0]] for x in lt])]
          item_va = [x[1] for x in lt]

          eval_loss += model.step(sess, user_va, item_va_input, item_va,
            forward_only=True, loss=the_loss)
          count_va += 1
        eval_loss /= count_va
        eval_auc /= count_va
        # Kept separate from step_time: overwriting step_time here would leak
        # dev timing into the next training checkpoint's step-time.
        eval_step_time = (time.time() - eval_start) / count_va
        if FLAGS.loss in ['ce', 'mce']:
          eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
          mylog("  dev: perplexity %.2f eval_auc %.4f step-time %.4f" % (
            eval_ppx, eval_auc, eval_step_time))
        else:
          mylog("  dev: loss %.3f eval_auc %.4f step-time %.4f" % (eval_loss,
            eval_auc, eval_step_time))
        sys.stdout.flush()

        if eval_loss < best_loss and not FLAGS.test:
          best_loss = eval_loss
          patience = FLAGS.patience

          # Save checkpoint and zero timer and loss.
          checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
          model.saver.save(sess, checkpoint_path,
            global_step=0, write_meta_graph = False)
          mylog('Saving best model...')

        if FLAGS.test:
          # Test mode always keeps the latest model.
          checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
          model.saver.save(sess, checkpoint_path,
            global_step=0, write_meta_graph = False)
          mylog('Saving current model...')

        if eval_loss > best_loss: # and eval_auc < best_auc:
          patience -= 1

        auc_dev.append(eval_auc)
        losses_dev.append(eval_loss)

        if patience < 0 and not FLAGS.test:
          mylog("no improvement for too long.. terminating..")
          # best_auc is never updated in this loop; it only reflects the
          # history loaded at start-up (or -1).
          mylog("best auc %.4f" % best_auc)
          mylog("best loss %.4f" % best_loss)
          sys.stdout.flush()
          break
  return
Example #2
0
def get_data(raw_data,
             data_dir=FLAGS.data_dir,
             combine_att=FLAGS.combine_att,
             logits_size_tr=FLAGS.item_vocab_size,
             thresh=FLAGS.vocab_min_thresh,
             use_user_feature=FLAGS.use_user_feature,
             test=FLAGS.test,
             mylog=mylog,
             use_item_feature=FLAGS.use_item_feature,
             recommend=False):
    """Load attributed data, build bucketed sequences and the embedding layer.

    Steps:
      1. Read interaction data with ``read_attributed_data``.
      2. Keep only training records whose item (``rec[1]``) has a logit index.
         When ``FLAGS.after40`` is set, also keep only records with
         ``to_week(rec[2]) >= 40``.
      3. Form sequences of at most ``FLAGS.L`` items.
      4. Split them into train/dev with ``ratio=0.05``.
      5. Bucket both splits.

    When ``recommend`` is true it also builds prediction sequences for the
    users returned by ``evaluate.Evaluation.get_uinds()`` and re-buckets them.

    Side effects:
        * rebinds the module-level ``_buckets``. When ``recommend`` is true it
          ends up holding the test buckets;
        * adds ``START_ID -> 0`` to ``item_ind2logit_ind``;
        * calls ``prepare_warp`` on the embedding object for 'warp'/'mw'
          losses.

    Returns:
        A 12-tuple ``(seq_tr, seq_va, seq_test, embAttr, START_ID,
        item_population, p_item, evaluation, uids, user_index, item_index,
        logit_ind2item_ind)``. ``seq_test`` and ``uids`` are ``[]`` and
        ``evaluation`` is ``None`` unless ``recommend`` is true.
    """
    global _buckets

    (records, _, u_attr, i_attr, item_ind2logit_ind, logit_ind2item_ind,
     user_index, item_index) = read_attributed_data(
        raw_data_dir=raw_data,
        data_dir=data_dir,
        combine_att=combine_att,
        logits_size_tr=logits_size_tr,
        thresh=thresh,
        use_user_feature=use_user_feature,
        use_item_feature=use_item_feature,
        test=test,
        mylog=mylog)

    # Drop records whose item is outside the logit vocabulary (UNK) and,
    # optionally, records dated before week 40.
    only_after_40 = FLAGS.after40
    records = [
        rec for rec in records
        if rec[1] in item_ind2logit_ind
        and (not only_after_40 or to_week(rec[2]) >= 40)
    ]

    # Item frequencies, returned to the caller for sampling.
    item_population, p_item = item_frequency(records, FLAGS.power)

    # START token: the first index past item_index, mapped to logit 0.
    start_id = len(item_index)
    item_ind2logit_ind[start_id] = 0

    sequences = form_sequence(records, maxlen=FLAGS.L)
    train_seqs, dev_seqs = split_train_dev(sequences, ratio=0.05)

    _buckets = sorted(
        calculate_buckets(train_seqs + dev_seqs, FLAGS.L, FLAGS.n_bucket))
    seq_tr, seq_va = (split_buckets(train_seqs, _buckets),
                      split_buckets(dev_seqs, _buckets))

    seq_test, evaluation, uids = [], None, []
    if recommend:
        from evaluate import Evaluation
        evaluation = Evaluation(raw_data, test=test)
        # Despite the name, these are user *indices* (uinds), not raw ids.
        uids = evaluation.get_uinds()
        test_seqs = form_sequence_prediction(sequences, uids, FLAGS.L,
                                             start_id)
        _buckets = sorted(
            calculate_buckets(test_seqs, FLAGS.L, FLAGS.n_bucket))
        seq_test = split_buckets(test_seqs, _buckets)

    # Build the attribute embedding on the first device returned by
    # get_device_address.
    devices = get_device_address(FLAGS.N)
    with tf.device(devices[0]):
        for attr in (u_attr, i_attr):
            attr.set_model_size(FLAGS.size)

        emb_attr = embed_attribute.EmbeddingAttribute(
            u_attr, i_attr, FLAGS.batch_size, FLAGS.n_sampled, _buckets[-1],
            FLAGS.use_sep_item, item_ind2logit_ind, logit_ind2item_ind,
            devices=devices)

        if FLAGS.loss in ("warp", "mw"):
            prepare_warp(emb_attr, train_seqs, dev_seqs)

    return (seq_tr, seq_va, seq_test, emb_attr, start_id, item_population,
            p_item, evaluation, uids, user_index, item_index,
            logit_ind2item_ind)
Example #3
0
def train(raw_data=FLAGS.raw_data,
          train_dir=FLAGS.train_dir,
          mylog=mylog,
          data_dir=FLAGS.data_dir,
          combine_att=FLAGS.combine_att,
          test=FLAGS.test,
          logits_size_tr=FLAGS.item_vocab_size,
          thresh=FLAGS.item_vocab_min_thresh,
          use_user_feature=FLAGS.use_user_feature,
          use_item_feature=FLAGS.use_item_feature,
          batch_size=FLAGS.batch_size,
          steps_per_checkpoint=FLAGS.steps_per_checkpoint,
          loss_func=FLAGS.loss,
          max_patience=FLAGS.patience,
          go_test=FLAGS.test,
          max_epoch=FLAGS.n_epoch,
          sample_type=FLAGS.sample_type,
          power=FLAGS.power,
          use_more_train=FLAGS.use_more_train,
          profile=FLAGS.profile,
          device_log=FLAGS.device_log):
    """Train the user/item recommendation model with dev-set early stopping.

    Steps:
      1. Read attributed data with ``read_data``.
      2. Drop train and dev records whose item has no logit index.
      3. Build the model with ``create_model``.
      4. Train it on batches from ``model.get_batch``
         (``sample_type='random'``) or ``model.get_permuted_batch``
         (``'permute'``).

    Every ``steps_per_checkpoint`` steps:
      * the averaged training loss is logged (perplexity for 'ce'/'mce');
      * the learning rate is decayed when that loss exceeds the max of the
        previous three checkpoint losses;
      * if ``FLAGS.eval`` is set, the model is evaluated on full dev batches
        and saved to ``<train_dir>/best.ckpt`` when the dev loss improves (or
        every time when ``go_test`` is true).

    Training stops after ``steps_per_epoch * max_epoch`` steps. Outside test
    mode it also stops when ``patience`` drops below zero. ``patience`` is
    reset to ``max_patience`` on each improvement and decremented on each
    worse evaluation.

    With ``profile``, a Chrome trace is written to ``timeline.json`` at step 30
    and the process exits.

    Most keyword arguments default to the ``FLAGS`` value of the same meaning.
    The data-loading ones are forwarded to ``read_data``.

    Raises:
        ValueError: if ``sample_type`` is not 'random' or 'permute', or if
            ``FLAGS.eval`` is set and the filtered dev set holds fewer than
            ``batch_size`` examples (no dev batch could be evaluated).
    """
    # Validate before any data loading or graph construction. Previously an
    # unknown value printed a message and exited with status 0 (success).
    if sample_type not in ('random', 'permute'):
        raise ValueError("sample_type %r is not implemented; expected "
                         "'random' or 'permute'" % (sample_type,))

    with tf.Session(
            config=tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=device_log)) as sess:
        run_options = None
        run_metadata = None
        if profile:
            # in order to profile
            from tensorflow.python.client import timeline
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            # Short run: the trace is written at the first checkpoint.
            steps_per_checkpoint = 30

        mylog("reading data")
        (data_tr, data_va, u_attributes, i_attributes, item_ind2logit_ind,
         logit_ind2item_ind, _,
         _) = read_data(raw_data_dir=raw_data,
                        data_dir=data_dir,
                        combine_att=combine_att,
                        logits_size_tr=logits_size_tr,
                        thresh=thresh,
                        use_user_feature=use_user_feature,
                        use_item_feature=use_item_feature,
                        test=test,
                        mylog=mylog)

        # Remove rare items (those without a logit index) from both the train
        # and valid set; this helps keep the train/valid distributions
        # similar to each other.
        mylog("original train/dev size: %d/%d" % (len(data_tr), len(data_va)))
        data_tr = [p for p in data_tr if p[1] in item_ind2logit_ind]
        data_va = [p for p in data_va if p[1] in item_ind2logit_ind]
        mylog("new train/dev size: %d/%d" % (len(data_tr), len(data_va)))

        # Dev evaluation only uses full batches. With a smaller dev set every
        # evaluation would divide by a zero batch count, so fail fast instead.
        if FLAGS.eval and len(data_va) < batch_size:
            raise ValueError(
                "dev set has %d examples after filtering but batch_size is "
                "%d; nothing to evaluate" % (len(data_va), batch_size))

        item_pop, p_item = item_frequency(data_tr, power)

        # Pool passed to create_model / sample_items: every logit index, or
        # the population returned by item_frequency.
        if use_more_train:
            item_population = range(len(item_ind2logit_ind))
        else:
            item_population = item_pop

        model = create_model(sess,
                             u_attributes,
                             i_attributes,
                             item_ind2logit_ind,
                             logit_ind2item_ind,
                             loss=loss_func,
                             ind_item=item_population)

        pos_item_list, pos_item_list_val = None, None
        if loss_func in ['warp', 'mw', 'rs', 'rs-sig', 'bbpr']:
            pos_item_list, pos_item_list_val = positive_items(data_tr, data_va)
            model.prepare_warp(pos_item_list, pos_item_list_val)

        mylog('started training')
        step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0

        # Each dev batch is evaluated 5 times for BPR losses, once otherwise.
        repeat = 5 if loss_func.startswith('bpr') else 1
        patience = max_patience

        # Resume the metric history if a previous run left it in train_dir.
        if os.path.isfile(os.path.join(train_dir, 'auc_train.npy')):
            auc_train = list(np.load(os.path.join(train_dir, 'auc_train.npy')))
            auc_dev = list(np.load(os.path.join(train_dir, 'auc_dev.npy')))
            previous_losses = list(
                np.load(os.path.join(train_dir, 'loss_train.npy')))
            losses_dev = list(np.load(os.path.join(train_dir, 'loss_dev.npy')))
            best_loss = min(losses_dev)
        else:
            previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
            best_loss = 1000000

        item_sampled, item_sampled_id2idx = None, None

        # sample_type was validated on entry.
        if sample_type == 'random':
            get_next_batch = model.get_batch
        else:
            get_next_batch = model.get_permuted_batch

        train_total_size = float(len(data_tr))
        steps_per_epoch = int(train_total_size / batch_size)
        total_steps = steps_per_epoch * max_epoch

        mylog("Train:")
        mylog("total: {}".format(train_total_size))
        mylog("Steps_per_epoch: {}".format(steps_per_epoch))
        mylog("Total_steps:{}".format(total_steps))
        mylog("Dev:")
        mylog("total: {}".format(len(data_va)))

        # Loss name used for dev evaluation ('mw' is evaluated as 'warp').
        eval_loss_func = 'warp' if loss_func == 'mw' else loss_func

        mylog("\n\ntraining start!")
        while True:
            start_time = time.time()
            (user_input, item_input, neg_item_input) = get_next_batch(data_tr)

            # For 'mw'/'mce', draw a fresh item sample every n_resample steps.
            if (loss_func in ['mw', 'mce']
                    and current_step % FLAGS.n_resample == 0):
                item_sampled, item_sampled_id2idx = sample_items(
                    item_population, FLAGS.n_sampled, p_item)
            else:
                item_sampled = None

            step_loss = model.step(sess,
                                   user_input,
                                   item_input,
                                   neg_item_input,
                                   item_sampled,
                                   item_sampled_id2idx,
                                   loss=loss_func,
                                   run_op=run_options,
                                   run_meta=run_metadata)

            # Running averages over one checkpoint window.
            step_time += (time.time() - start_time) / steps_per_checkpoint
            loss += step_loss / steps_per_checkpoint
            current_step += 1
            if current_step > total_steps:
                mylog("Training reaches maximum steps. Terminating...")
                break

            if current_step % steps_per_checkpoint == 0:

                if loss_func in ['ce', 'mce']:
                    # Clamp to avoid math.exp overflow on a diverged loss.
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    mylog(
                        "global step %d learning rate %.4f step-time %.4f perplexity %.2f"
                        % (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, perplexity))
                else:
                    mylog(
                        "global step %d learning rate %.4f step-time %.4f loss %.3f"
                        % (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, loss))
                if profile:
                    # Create the Timeline object, and write it to a json
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline.json', 'w') as f:
                        f.write(ctf)
                    sys.exit()

                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                auc_train.append(auc)

                # Reset timer and loss.
                step_time, loss, auc = 0.0, 0.0, 0.0

                if not FLAGS.eval:
                    continue

                # Run evals on development set (full batches only) and print
                # their loss. eval_auc is never accumulated, so it logs as 0.
                l_va = len(data_va)
                eval_loss, eval_auc = 0.0, 0.0
                count_va = 0
                eval_start = time.time()
                for idx_s in range(0, l_va, batch_size):
                    idx_e = idx_s + batch_size
                    if idx_e > l_va:
                        break
                    lt = data_va[idx_s:idx_e]
                    user_va = [x[0] for x in lt]
                    item_va = [x[1] for x in lt]
                    # No negatives are passed; presumably the model draws its
                    # own, which is why BPR repeats the batch — TODO confirm.
                    for _ in range(repeat):
                        eval_loss += model.step(sess,
                                                user_va,
                                                item_va,
                                                None,
                                                None,
                                                None,
                                                forward_only=True,
                                                loss=eval_loss_func)
                        count_va += 1
                eval_loss /= count_va
                eval_auc /= count_va
                # Kept separate from step_time: overwriting step_time here
                # would leak dev timing into the next training checkpoint.
                eval_step_time = (time.time() - eval_start) / count_va
                if loss_func in ['ce', 'mce']:
                    eval_ppx = math.exp(
                        eval_loss) if eval_loss < 300 else float('inf')
                    mylog(
                        "  dev: perplexity %.2f eval_auc(not computed) %.4f step-time %.4f"
                        % (eval_ppx, eval_auc, eval_step_time))
                else:
                    mylog(
                        "  dev: loss %.3f eval_auc(not computed) %.4f step-time %.4f"
                        % (eval_loss, eval_auc, eval_step_time))
                sys.stdout.flush()

                if eval_loss < best_loss and not go_test:
                    best_loss = eval_loss
                    patience = max_patience
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=0,
                                     write_meta_graph=False)

                if go_test:
                    # Test mode always keeps the latest model.
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=0,
                                     write_meta_graph=False)

                if eval_loss > best_loss:
                    patience -= 1

                auc_dev.append(eval_auc)
                losses_dev.append(eval_loss)

                if patience < 0 and not go_test:
                    mylog("no improvement for too long.. terminating..")
                    mylog("best loss %.4f" % best_loss)
                    sys.stdout.flush()
                    break
    return