Example #1
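# Module-level names this snippet relies on (not shown here): math, os, sys,
# time, numpy as np, tensorflow as tf, tensorflow.python.client.timeline,
# plus the project's FLAGS, mylog, get_data, create_model, DataIterator,
# item_frequency and sample_items helpers.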
def train(raw_data=FLAGS.raw_data):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 
    log_device_placement=FLAGS.device_log)) as sess:
    run_options = None
    run_metadata = None
    if FLAGS.profile:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      FLAGS.steps_per_checkpoint = 30
    
    mylog("reading data")

    (seq_tr, items_dev, data_tr, data_va, u_attributes, i_attributes,
      item_ind2logit_ind, logit_ind2item_ind, end_ind, _, _) = get_data(raw_data,
      data_dir=FLAGS.data_dir)

    power = FLAGS.power
    item_pop, p_item = item_frequency(data_tr, power)

    if FLAGS.use_more_train:
      item_population = range(len(item_ind2logit_ind))
    else:
      item_population = item_pop

    model = create_model(sess, u_attributes, i_attributes, item_ind2logit_ind,
      logit_ind2item_ind, loss=FLAGS.loss, ind_item=item_population)

    # data iterators
    n_skips = FLAGS.ni if FLAGS.model == 'cbow' else FLAGS.num_skips
    dite = DataIterator(seq_tr, end_ind, FLAGS.batch_size, n_skips, 
      FLAGS.skip_window, False)
    if FLAGS.model == 'sg':
      ite = dite.get_next_sg()
    else:
      ite = dite.get_next_cbow()

    mylog('started training')
    step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0
    
    repeat = 5 if FLAGS.loss.startswith('bpr') else 1   
    patience = FLAGS.patience

    if os.path.isfile(os.path.join(FLAGS.train_dir, 'auc_train.npy')):
      auc_train = list(np.load(os.path.join(FLAGS.train_dir, 'auc_train.npy')))
      auc_dev = list(np.load(os.path.join(FLAGS.train_dir, 'auc_dev.npy')))
      previous_losses = list(np.load(os.path.join(FLAGS.train_dir, 
        'loss_train.npy')))
      losses_dev = list(np.load(os.path.join(FLAGS.train_dir, 'loss_dev.npy')))
      best_auc = max(auc_dev)
      best_loss = min(losses_dev)
    else:
      previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
      best_auc, best_loss = -1, 1000000

    item_sampled, item_sampled_id2idx = None, None

    train_total_size = float(len(data_tr))
    n_epoch = FLAGS.n_epoch
    steps_per_epoch = int(1.0 * train_total_size / FLAGS.batch_size)
    total_steps = steps_per_epoch * n_epoch

    mylog("Train:")
    mylog("total: {}".format(train_total_size))
    mylog("Steps_per_epoch: {}".format(steps_per_epoch))
    mylog("Total_steps:{}".format(total_steps))
    mylog("Dev:")
    mylog("total: {}".format(len(data_va)))

    while True:

      start_time = time.time()
      # generate batch of training
      (user_input, input_items, output_items) = ite.next()
      if current_step < 5:
        mylog("current step is {}".format(current_step))
        mylog('user')
        mylog(user_input)
        mylog('input_item')
        mylog(input_items)      
        mylog('output_item')
        mylog(output_items)
      

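      # For the sampled losses ('mw', 'mce'), refresh the set of sampled
      # candidate items every FLAGS.n_resample steps; other losses do not
      # use item_sampled.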
      if FLAGS.loss in ['mw', 'mce'] and current_step % FLAGS.n_resample == 0:
        item_sampled, item_sampled_id2idx = sample_items(item_population, 
          FLAGS.n_sampled, p_item)
      else:
        item_sampled = None

      step_loss = model.step(sess, user_input, input_items, output_items,
        item_sampled, item_sampled_id2idx, loss=FLAGS.loss,
        run_op=run_options, run_meta=run_metadata)
      # step_loss = 0

      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
      loss += step_loss / FLAGS.steps_per_checkpoint
      # auc += step_auc / FLAGS.steps_per_checkpoint
      current_step += 1
      if current_step > total_steps:
        mylog("Training reaches maximum steps. Terminating...")
        break

      if current_step % FLAGS.steps_per_checkpoint == 0:

        if FLAGS.loss in ['ce', 'mce']:
          perplexity = math.exp(loss) if loss < 300 else float('inf')
          mylog("global step %d learning rate %.4f step-time %.4f perplexity %.2f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity))
        else:
          mylog("global step %d learning rate %.4f step-time %.4f loss %.3f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, loss))
        if FLAGS.profile:
          # Create the Timeline object, and write it to a json
          tl = timeline.Timeline(run_metadata.step_stats)
          ctf = tl.generate_chrome_trace_format()
          with open('timeline.json', 'w') as f:
            f.write(ctf)
          exit()

        # Decrease learning rate if no improvement was seen over last 3 times.
        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
          sess.run(model.learning_rate_decay_op) 
        previous_losses.append(loss)
        auc_train.append(auc)
        step_time, loss, auc = 0.0, 0.0, 0.0

        if not FLAGS.eval:
          continue
        # # Save checkpoint and zero timer and loss.
        # checkpoint_path = os.path.join(FLAGS.train_dir, "go.ckpt")
        # current_model = model.saver.save(sess, checkpoint_path, 
        #   global_step=model.global_step)
        
        # Run evals on development set and print their loss/auc.
        l_va = len(data_va)
        eval_loss, eval_auc = 0.0, 0.0
        count_va = 0
        start_time = time.time()
        for idx_s in range(0, l_va, FLAGS.batch_size):
          idx_e = idx_s + FLAGS.batch_size
          if idx_e > l_va:
            break
          lt = data_va[idx_s:idx_e]
          user_va = [x[0] for x in lt]
          item_va_input = [items_dev[x[0]] for x in lt]
          item_va_input = map(list, zip(*item_va_input))
          item_va = [x[1] for x in lt]
          
          the_loss = 'warp' if FLAGS.loss == 'mw' else FLAGS.loss
          eval_loss0 = model.step(sess, user_va, item_va_input, item_va, 
            forward_only=True, loss=the_loss)
          eval_loss += eval_loss0
          count_va += 1
        eval_loss /= count_va
        eval_auc /= count_va
        step_time = (time.time() - start_time) / count_va
        if FLAGS.loss in ['ce', 'mce']:
          eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
          mylog("  dev: perplexity %.2f eval_auc %.4f step-time %.4f" % (
            eval_ppx, eval_auc, step_time))
        else:
          mylog("  dev: loss %.3f eval_auc %.4f step-time %.4f" % (eval_loss, 
            eval_auc, step_time))
        sys.stdout.flush()
        
        if eval_loss < best_loss and not FLAGS.test:
          best_loss = eval_loss
          patience = FLAGS.patience

          # Save checkpoint and zero timer and loss.
          checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
          model.saver.save(sess, checkpoint_path,
            global_step=0, write_meta_graph=False)
          mylog('Saving best model...')

        if FLAGS.test:
          checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt")
          model.saver.save(sess, checkpoint_path,
            global_step=0, write_meta_graph=False)
          mylog('Saving current model...')

        if eval_loss > best_loss: # and eval_auc < best_auc:
          patience -= 1

        auc_dev.append(eval_auc)
        losses_dev.append(eval_loss)

        if patience < 0 and not FLAGS.test:
          mylog("no improvement for too long.. terminating..")
          mylog("best auc %.4f" % best_auc)
          mylog("best loss %.4f" % best_loss)
          sys.stdout.flush()
          break
  return
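
The item_frequency and sample_items helpers used above are not shown on this page. As a rough sketch of what a power-smoothed popularity sampler like this typically looks like (the names, (user, item) pair format, and return shapes are inferred from how they are called, not taken from the project):

import numpy as np

def item_frequency(data_tr, power):
    # Count how often each item occurs in the (user, item) training pairs,
    # raise the counts to `power` (e.g. 0.5) and normalize, so popular items
    # are sampled more often but the distribution is flattened.
    counts = {}
    for _, item in data_tr:
        counts[item] = counts.get(item, 0) + 1
    item_pop = sorted(counts)
    freq = np.array([counts[i] for i in item_pop], dtype=np.float64)
    p_item = freq ** power
    p_item /= p_item.sum()
    return item_pop, p_item

def sample_items(item_population, n_sampled, p_item=None):
    # Draw n_sampled candidate items (with replacement) and return them
    # together with an id -> position lookup, as the training loop expects.
    sampled = list(np.random.choice(item_population, size=n_sampled, p=p_item))
    id2idx = dict((item, idx) for idx, item in enumerate(sampled))
    return sampled, id2idx
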
Example #2
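# Module-level names this snippet relies on (not shown here): math, os, sys,
# time, numpy as np, tensorflow as tf, tensorflow.python.client.timeline,
# the bucket definition _buckets, plus the project's FLAGS, mylog, get_data,
# create_model, show_all_variables, DataIterator, sample_items and evaluate.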
def train(raw_data=FLAGS.raw_data):

    # Read Data
    mylog("Reading Data...")
    train_set, dev_set, test_set, embAttr, START_ID, item_population, p_item, _, _, _, _, _ = get_data(
        raw_data, data_dir=FLAGS.data_dir)
    n_targets_train = np.sum(
        [np.sum([len(items) for uid, items in x]) for x in train_set])
    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
    train_total_size = float(sum(train_bucket_sizes))
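    # Cumulative share of training examples in buckets 0..i; the random
    # iterator below presumably draws a uniform number in [0, 1) and picks
    # the first bucket whose cumulative share exceeds it.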
    train_buckets_scale = [
        sum(train_bucket_sizes[:i + 1]) / train_total_size
        for i in xrange(len(train_bucket_sizes))
    ]
    dev_bucket_sizes = [len(dev_set[b]) for b in xrange(len(_buckets))]
    dev_total_size = int(sum(dev_bucket_sizes))

    # steps
    batch_size = FLAGS.batch_size
    n_epoch = FLAGS.n_epoch
    steps_per_epoch = int(train_total_size / batch_size)
    steps_per_dev = int(dev_total_size / batch_size)

    steps_per_checkpoint = int(steps_per_epoch / 2)
    total_steps = steps_per_epoch * n_epoch

    # reports
    mylog(_buckets)
    mylog("Train:")
    mylog("total: {}".format(train_total_size))
    mylog("bucket sizes: {}".format(train_bucket_sizes))
    mylog("Dev:")
    mylog("total: {}".format(dev_total_size))
    mylog("bucket sizes: {}".format(dev_bucket_sizes))
    mylog("")
    mylog("Steps_per_epoch: {}".format(steps_per_epoch))
    mylog("Total_steps:{}".format(total_steps))
    mylog("Steps_per_checkpoint: {}".format(steps_per_checkpoint))

    # with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement = False, device_count={'CPU':8, 'GPU':1})) as sess:
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:

        # runtime profile
        if FLAGS.profile:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            run_options = None
            run_metadata = None

        mylog("Creating Model.. (this can take a few minutes)")
        model = create_model(sess, embAttr, START_ID, run_options,
                             run_metadata)
        show_all_variables()

        # Data Iterators
        dite = DataIterator(model, train_set, len(train_buckets_scale),
                            batch_size, train_buckets_scale)

        iteType = 0
        if iteType == 0:
            mylog("withRandom")
            ite = dite.next_random()
        elif iteType == 1:
            mylog("withSequence")
            ite = dite.next_sequence()

        # statistics during training
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        his = []
        low_ppx = float("inf")
        low_ppx_step = 0
        steps_per_report = 30
        n_targets_report = 0
        report_time = 0
        n_valid_sents = 0
        patience = FLAGS.patience
        item_sampled, item_sampled_id2idx = None, None

        while current_step < total_steps:

            # start
            start_time = time.time()

            # re-sample every once a while
            if FLAGS.loss in ['mw', 'mce'
                              ] and current_step % FLAGS.n_resample == 0:
                item_sampled, item_sampled_id2idx = sample_items(
                    item_population, FLAGS.n_sampled, p_item)
            else:
                item_sampled = None

            # data and train
            users, inputs, outputs, weights, bucket_id = ite.next()

            L = model.step(sess,
                           users,
                           inputs,
                           outputs,
                           weights,
                           bucket_id,
                           item_sampled=item_sampled,
                           item_sampled_id2idx=item_sampled_id2idx)

            # loss and time
            step_time += (time.time() - start_time) / steps_per_checkpoint

            loss += L
            current_step += 1
            n_valid_sents += np.sum(np.sign(weights[0]))

            # for report
            report_time += (time.time() - start_time)
            n_targets_report += np.sum(weights)

            if current_step % steps_per_report == 0:
                mylog("--------------------" + "Report" + str(current_step) +
                      "-------------------")
                mylog(
                    "StepTime: {} Speed: {} targets / sec in total {} targets".
                    format(report_time / steps_per_report,
                           n_targets_report * 1.0 / report_time,
                           n_targets_train))

                report_time = 0
                n_targets_report = 0

                # Create the Timeline object, and write it to a json
                if FLAGS.profile:
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline.json', 'w') as f:
                        f.write(ctf)
                    exit()

            if current_step % steps_per_checkpoint == 0:
                mylog("--------------------" + "TRAIN" + str(current_step) +
                      "-------------------")
                # Print statistics for the previous epoch.

                loss = loss / n_valid_sents
                perplexity = math.exp(
                    float(loss)) if loss < 300 else float("inf")
                mylog(
                    "global step %d learning rate %.4f step-time %.2f perplexity "
                    "%.2f" %
                    (model.global_step.eval(), model.learning_rate.eval(),
                     step_time, perplexity))

                train_ppx = perplexity

                # Save checkpoint and zero timer and loss.
                step_time, loss, n_valid_sents = 0.0, 0.0, 0

                # dev data
                mylog("--------------------" + "DEV" + str(current_step) +
                      "-------------------")
                eval_loss, eval_ppx = evaluate(
                    sess,
                    model,
                    dev_set,
                    item_sampled_id2idx=item_sampled_id2idx)
                mylog("dev: ppx: {}".format(eval_ppx))

                his.append([current_step, train_ppx, eval_ppx])

                if eval_ppx < low_ppx:
                    patience = FLAGS.patience
                    low_ppx = eval_ppx
                    low_ppx_step = current_step
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   "best.ckpt")
                    mylog("Saving best model....")
                    s = time.time()
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=0,
                                     write_meta_graph=False)
                    mylog("Best model saved using {} sec".format(time.time() -
                                                                 s))
                else:
                    patience -= 1

                if patience <= 0:
                    mylog("Training finished. Running out of patience.")
                    break

                sys.stdout.flush()
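
All three examples use the same TF 1.x profiling pattern: a RunOptions with FULL_TRACE and a RunMetadata object are passed into the training step, and the collected step stats are serialized as a Chrome trace. A self-contained sketch of just that pattern (the toy graph is made up; the timeline calls are the standard tf 1.x API):

import tensorflow as tf
from tensorflow.python.client import timeline

# Toy op to profile; any real graph works the same way.
x = tf.random_normal([256, 256])
y = tf.matmul(x, x)

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

with tf.Session() as sess:
    sess.run(y, options=run_options, run_metadata=run_metadata)

# Serialize the step stats to a Chrome trace; open chrome://tracing and
# load timeline.json to inspect per-op timings and device placement.
tl = timeline.Timeline(run_metadata.step_stats)
with open('timeline.json', 'w') as f:
    f.write(tl.generate_chrome_trace_format())
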
Example #3
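# Module-level names this snippet relies on (not shown here): math, os, sys,
# time, numpy as np, tensorflow as tf, plus the project's FLAGS, mylog,
# read_data, item_frequency, positive_items, create_model and sample_items.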
def train(raw_data=FLAGS.raw_data,
          train_dir=FLAGS.train_dir,
          mylog=mylog,
          data_dir=FLAGS.data_dir,
          combine_att=FLAGS.combine_att,
          test=FLAGS.test,
          logits_size_tr=FLAGS.item_vocab_size,
          thresh=FLAGS.item_vocab_min_thresh,
          use_user_feature=FLAGS.use_user_feature,
          use_item_feature=FLAGS.use_item_feature,
          batch_size=FLAGS.batch_size,
          steps_per_checkpoint=FLAGS.steps_per_checkpoint,
          loss_func=FLAGS.loss,
          max_patience=FLAGS.patience,
          go_test=FLAGS.test,
          max_epoch=FLAGS.n_epoch,
          sample_type=FLAGS.sample_type,
          power=FLAGS.power,
          use_more_train=FLAGS.use_more_train,
          profile=FLAGS.profile,
          device_log=FLAGS.device_log):
    with tf.Session(
            config=tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=device_log)) as sess:
        run_options = None
        run_metadata = None
        if profile:
            # in order to profile
            from tensorflow.python.client import timeline
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            steps_per_checkpoint = 30

        mylog("reading data")
        (data_tr, data_va, u_attributes, i_attributes, item_ind2logit_ind,
         logit_ind2item_ind, _,
         _) = read_data(raw_data_dir=raw_data,
                        data_dir=data_dir,
                        combine_att=combine_att,
                        logits_size_tr=logits_size_tr,
                        thresh=thresh,
                        use_user_feature=use_user_feature,
                        use_item_feature=use_item_feature,
                        test=test,
                        mylog=mylog)

        mylog("train/dev size: %d/%d" % (len(data_tr), len(data_va)))
        '''
        Remove some rare items from both the train and valid sets;
        this helps make the train/valid distributions similar to
        each other.
        '''
        mylog("original train/dev size: %d/%d" % (len(data_tr), len(data_va)))
        data_tr = [p for p in data_tr if (p[1] in item_ind2logit_ind)]
        data_va = [p for p in data_va if (p[1] in item_ind2logit_ind)]
        mylog("new train/dev size: %d/%d" % (len(data_tr), len(data_va)))

        item_pop, p_item = item_frequency(data_tr, power)

        if use_more_train:
            item_population = range(len(item_ind2logit_ind))
        else:
            item_population = item_pop

        model = create_model(sess,
                             u_attributes,
                             i_attributes,
                             item_ind2logit_ind,
                             logit_ind2item_ind,
                             loss=loss_func,
                             ind_item=item_population)

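        # WARP-style losses rank the positive item against sampled negatives,
        # so the model presumably needs each user's positive-item lists to
        # avoid treating a true positive as a negative.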
        pos_item_list, pos_item_list_val = None, None
        if loss_func in ['warp', 'mw', 'rs', 'rs-sig', 'bbpr']:
            pos_item_list, pos_item_list_val = positive_items(data_tr, data_va)
            model.prepare_warp(pos_item_list, pos_item_list_val)

        mylog('started training')
        step_time, loss, current_step, auc = 0.0, 0.0, 0, 0.0

        repeat = 5 if loss_func.startswith('bpr') else 1
        patience = max_patience

        if os.path.isfile(os.path.join(train_dir, 'auc_train.npy')):
            auc_train = list(np.load(os.path.join(train_dir, 'auc_train.npy')))
            auc_dev = list(np.load(os.path.join(train_dir, 'auc_dev.npy')))
            previous_losses = list(
                np.load(os.path.join(train_dir, 'loss_train.npy')))
            losses_dev = list(np.load(os.path.join(train_dir, 'loss_dev.npy')))
            best_auc = max(auc_dev)
            best_loss = min(losses_dev)
        else:
            previous_losses, auc_train, auc_dev, losses_dev = [], [], [], []
            best_auc, best_loss = -1, 1000000

        item_sampled, item_sampled_id2idx = None, None

        if sample_type == 'random':
            get_next_batch = model.get_batch
        elif sample_type == 'permute':
            get_next_batch = model.get_permuted_batch
        else:
            print('not implemented!')
            exit()

        train_total_size = float(len(data_tr))
        n_epoch = max_epoch
        steps_per_epoch = int(1.0 * train_total_size / batch_size)
        total_steps = steps_per_epoch * n_epoch

        mylog("Train:")
        mylog("total: {}".format(train_total_size))
        mylog("Steps_per_epoch: {}".format(steps_per_epoch))
        mylog("Total_steps:{}".format(total_steps))
        mylog("Dev:")
        mylog("total: {}".format(len(data_va)))

        mylog("\n\ntraining start!")
        while True:
            random_number_01 = np.random.random_sample()  # drawn but not used below
            start_time = time.time()
            (user_input, item_input, neg_item_input) = get_next_batch(data_tr)

            if loss_func in ['mw', 'mce'
                             ] and current_step % FLAGS.n_resample == 0:
                item_sampled, item_sampled_id2idx = sample_items(
                    item_population, FLAGS.n_sampled, p_item)
            else:
                item_sampled = None

            step_loss = model.step(sess,
                                   user_input,
                                   item_input,
                                   neg_item_input,
                                   item_sampled,
                                   item_sampled_id2idx,
                                   loss=loss_func,
                                   run_op=run_options,
                                   run_meta=run_metadata)

            step_time += (time.time() - start_time) / steps_per_checkpoint
            loss += step_loss / steps_per_checkpoint
            current_step += 1
            if current_step > total_steps:
                mylog("Training reaches maximum steps. Terminating...")
                break

            if current_step % steps_per_checkpoint == 0:

                if loss_func in ['ce', 'mce']:
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    mylog(
                        "global step %d learning rate %.4f step-time %.4f perplexity %.2f"
                        % (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, perplexity))
                else:
                    mylog(
                        "global step %d learning rate %.4f step-time %.4f loss %.3f"
                        % (model.global_step.eval(),
                           model.learning_rate.eval(), step_time, loss))
                if profile:
                    # Create the Timeline object, and write it to a json
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline.json', 'w') as f:
                        f.write(ctf)
                    exit()

                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                auc_train.append(auc)

                # Reset timer and loss.
                step_time, loss, auc = 0.0, 0.0, 0.0

                if not FLAGS.eval:
                    continue

                # Run evals on development set and print their loss.
                l_va = len(data_va)
                eval_loss, eval_auc = 0.0, 0.0
                count_va = 0
                start_time = time.time()
                for idx_s in range(0, l_va, batch_size):
                    idx_e = idx_s + batch_size
                    if idx_e > l_va:
                        break
                    lt = data_va[idx_s:idx_e]
                    user_va = [x[0] for x in lt]
                    item_va = [x[1] for x in lt]
                    for _ in range(repeat):
                        item_va_neg = None
                        the_loss = 'warp' if loss_func == 'mw' else loss_func
                        eval_loss0 = model.step(sess,
                                                user_va,
                                                item_va,
                                                item_va_neg,
                                                None,
                                                None,
                                                forward_only=True,
                                                loss=the_loss)
                        eval_loss += eval_loss0
                        count_va += 1
                eval_loss /= count_va
                eval_auc /= count_va
                step_time = (time.time() - start_time) / count_va
                if loss_func in ['ce', 'mce']:
                    eval_ppx = math.exp(
                        eval_loss) if eval_loss < 300 else float('inf')
                    mylog(
                        "  dev: perplexity %.2f eval_auc(not computed) %.4f step-time %.4f"
                        % (eval_ppx, eval_auc, step_time))
                else:
                    mylog(
                        "  dev: loss %.3f eval_auc(not computed) %.4f step-time %.4f"
                        % (eval_loss, eval_auc, step_time))
                sys.stdout.flush()

                if eval_loss < best_loss and not go_test:
                    best_loss = eval_loss
                    patience = max_patience
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=0,
                                     write_meta_graph=False)

                if go_test:
                    checkpoint_path = os.path.join(train_dir, "best.ckpt")
                    mylog('Saving best model...')
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=0,
                                     write_meta_graph=False)

                if eval_loss > best_loss:
                    patience -= 1

                auc_dev.append(eval_auc)
                losses_dev.append(eval_loss)

                if patience < 0 and not go_test:
                    mylog("no improvement for too long.. terminating..")
                    mylog("best loss %.4f" % best_loss)
                    sys.stdout.flush()
                    break
    return
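
All three variants save their best checkpoint with global_step=0, so the files on disk are prefixed best.ckpt-0. A minimal sketch, assuming the same create_model setup as whichever example produced the checkpoint, of how it could be restored for later evaluation:

import os

def restore_best(sess, model, train_dir):
    # Saver.save(..., global_step=0) writes files prefixed "best.ckpt-0",
    # so that prefix is what Saver.restore expects back.
    ckpt_path = os.path.join(train_dir, "best.ckpt-0")
    model.saver.restore(sess, ckpt_path)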