def main():

  if not os.path.exists(FLAGS.checkpoint_path):
    os.makedirs(FLAGS.checkpoint_path)
  checkpoint_file_path = os.path.join(FLAGS.checkpoint_path, "checkpoint.ckpt")
  latest_checkpoint_file_path = tf.train.latest_checkpoint(
      FLAGS.checkpoint_path)

  if not os.path.exists(FLAGS.output_path):
    os.makedirs(FLAGS.output_path)

  # Step 1: Construct the dataset op
  epoch_number = FLAGS.epoch_number
  if epoch_number <= 0:
    # repeat(-1) makes the dataset repeat indefinitely
    epoch_number = -1
  train_buffer_size = FLAGS.train_batch_size * 3
  validation_buffer_size = FLAGS.validation_batch_size * 3

  train_filename_list = [filename for filename in FLAGS.train_files.split(",")]
  train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
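  # Note: shuffle() is applied after batch() below, so it shuffles whole
  # batches; placing shuffle() before batch() would shuffle individual records.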
  train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
      epoch_number).batch(FLAGS.train_batch_size).shuffle(
          buffer_size=train_buffer_size)
  train_dataset_iterator = train_dataset.make_initializable_iterator()
  batch_labels, batch_ids, batch_values = train_dataset_iterator.get_next()

  validation_filename_list = FLAGS.validation_files.split(",")
  validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  validation_dataset = tf.data.TFRecordDataset(validation_filename_placeholder)
  validation_dataset = validation_dataset.map(parse_tfrecords_function).repeat(
  ).batch(FLAGS.validation_batch_size).shuffle(
      buffer_size=validation_buffer_size)
  validation_dataset_iterator = validation_dataset.make_initializable_iterator()
  validation_labels, validation_ids, validation_values = (
      validation_dataset_iterator.get_next())

  # Define the model
  logits = inference(batch_ids, batch_values, True)
  batch_labels = tf.to_int64(batch_labels)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=batch_labels)
  loss = tf.reduce_mean(cross_entropy, name="loss")
  global_step = tf.Variable(0, name="global_step", trainable=False)
  if FLAGS.enable_lr_decay:
    logging.info(
        "Enable learning rate decay, decay rate: {}".format(FLAGS.lr_decay_rate))
    starter_learning_rate = FLAGS.learning_rate
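    # With staircase=True this computes:
    #   learning_rate = starter_learning_rate * lr_decay_rate ** (global_step // 100000)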
    learning_rate = tf.train.exponential_decay(
        starter_learning_rate,
        global_step,
        100000,
        FLAGS.lr_decay_rate,
        staircase=True)
  else:
    learning_rate = FLAGS.learning_rate
  optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)
  # Reuse variables so the evaluation graphs below share the training weights
  tf.get_variable_scope().reuse_variables()

  # Define accuracy op for train data
  train_accuracy_logits = inference(batch_ids, batch_values, False)
  train_softmax = tf.nn.softmax(train_accuracy_logits)
  train_correct_prediction = tf.equal(
      tf.argmax(train_softmax, 1), batch_labels)
  train_accuracy = tf.reduce_mean(
      tf.cast(train_correct_prediction, tf.float32))

  # Define auc op for train data
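  # The block below expands the integer labels into a dense one-hot matrix for
  # streaming_auc; it is equivalent to tf.one_hot(batch_labels, FLAGS.label_size).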
  batch_labels = tf.cast(batch_labels, tf.int32)
  sparse_labels = tf.reshape(batch_labels, [-1, 1])
  derived_size = tf.shape(batch_labels)[0]
  indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
  concated = tf.concat(axis=1, values=[indices, sparse_labels])
  outshape = tf.stack([derived_size, FLAGS.label_size])
  new_train_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
  _, train_auc = tf.contrib.metrics.streaming_auc(train_softmax,
                                                  new_train_batch_labels)

  # Define accuracy op for validate data
  validate_accuracy_logits = inference(validation_ids, validation_values,
                                       False)
  validate_softmax = tf.nn.softmax(validate_accuracy_logits)
  validate_batch_labels = tf.to_int64(validation_labels)
  validate_correct_prediction = tf.equal(
      tf.argmax(validate_softmax, 1), validate_batch_labels)
  validate_accuracy = tf.reduce_mean(
      tf.cast(validate_correct_prediction, tf.float32))

  # Define auc op for validate data
  validate_batch_labels = tf.cast(validate_batch_labels, tf.int32)
  sparse_labels = tf.reshape(validate_batch_labels, [-1, 1])
  derived_size = tf.shape(validate_batch_labels)[0]
  indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
  concated = tf.concat(axis=1, values=[indices, sparse_labels])
  outshape = tf.stack([derived_size, FLAGS.label_size])
  new_validate_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
  _, validate_auc = tf.contrib.metrics.streaming_auc(validate_softmax,
                                                     new_validate_batch_labels)

  # Define inference op
  sparse_index = tf.placeholder(tf.int64, [None, 2])
  sparse_ids = tf.placeholder(tf.int64, [None])
  sparse_values = tf.placeholder(tf.float32, [None])
  sparse_shape = tf.placeholder(tf.int64, [2])
  inference_ids = tf.SparseTensor(sparse_index, sparse_ids, sparse_shape)
  inference_values = tf.SparseTensor(sparse_index, sparse_values, sparse_shape)
  inference_logits = inference(inference_ids, inference_values, False)
  inference_softmax = tf.nn.softmax(inference_logits)
  inference_op = tf.argmax(inference_softmax, 1)
  keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
  keys = tf.identity(keys_placeholder)

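  # The serving signature exposes the sparse placeholders as named inputs and
  # returns the keys, softmax probabilities and argmax predictions.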
  signature_def_map = {
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
      signature_def_utils.build_signature_def(
          inputs={
              "keys": utils.build_tensor_info(keys_placeholder),
              "indexs": utils.build_tensor_info(sparse_index),
              "ids": utils.build_tensor_info(sparse_ids),
              "values": utils.build_tensor_info(sparse_values),
              "shape": utils.build_tensor_info(sparse_shape)
          },
          outputs={
              "keys": utils.build_tensor_info(keys),
              "softmax": utils.build_tensor_info(inference_softmax),
              "prediction": utils.build_tensor_info(inference_op)
          },
          method_name=signature_constants.PREDICT_METHOD_NAME)
  }

  # Initialize saver and summary
  saver = tf.train.Saver()
  tf.summary.scalar("loss", loss)
  tf.summary.scalar("train_accuracy", train_accuracy)
  tf.summary.scalar("train_auc", train_auc)
  tf.summary.scalar("validate_accuracy", validate_accuracy)
  tf.summary.scalar("validate_auc", validate_auc)
  summary_op = tf.summary.merge_all()
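  # local_variables_initializer() is required because streaming_auc keeps its
  # counters in local variables.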
  init_op = [
      tf.global_variables_initializer(),
      tf.local_variables_initializer()
  ]

  # Create session to run
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
    sess.run(init_op)
    sess.run(
        train_dataset_iterator.initializer,
        feed_dict={train_filename_placeholder: train_filename_list})
    sess.run(
        validation_dataset_iterator.initializer,
        feed_dict={validation_filename_placeholder: validation_filename_list})

    if FLAGS.mode == "train":
      # Restore session and start queue runner
      util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path)
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord, sess=sess)
      start_time = datetime.datetime.now()

      try:
        while not coord.should_stop():
          if FLAGS.benchmark_mode:
            sess.run(train_op)
          else:
            _, step = sess.run([train_op, global_step])

            # Print state while training
            if step % FLAGS.steps_to_validate == 0:
              loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run(
                  [
                      loss, train_accuracy, train_auc, validate_accuracy,
                      validate_auc, summary_op
                  ])
              end_time = datetime.datetime.now()

              logging.info(
                  "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
                  format(end_time - start_time, step, loss_value,
                         train_accuracy_value, train_auc_value,
                         validate_accuracy_value, auc_value))
              writer.add_summary(summary_value, step)
              saver.save(sess, checkpoint_file_path, global_step=step)
              start_time = end_time
      except tf.errors.OutOfRangeError:
        if FLAGS.benchmark_mode:
          print("Finish training for benchmark")
          exit(0)
        else:
          # Export the model after training
          util.save_model(
              FLAGS.model_path,
              FLAGS.model_version,
              sess,
              signature_def_map,
              is_save_graph=False)
      finally:
        coord.request_stop()
      coord.join(threads)

    elif FLAGS.mode == "save_model":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      util.save_model(
          FLAGS.model_path,
          FLAGS.model_version,
          sess,
          signature_def_map,
          is_save_graph=False)

    elif FLAGS.mode == "inference":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      # Load inference test data
      inference_result_file_name = "./inference_result.txt"
      inference_test_file_name = "./data/a8a_test.libsvm"
      labels = []
      feature_ids = []
      feature_values = []
      feature_index = []
      ins_num = 0
      for line in open(inference_test_file_name, "r"):
        tokens = line.split(" ")
        labels.append(int(tokens[0]))
        feature_num = 0
        for feature in tokens[1:]:
          feature_id, feature_value = feature.split(":")
          feature_ids.append(int(feature_id))
          feature_values.append(float(feature_value))
          feature_index.append([ins_num, feature_num])
          feature_num += 1
        ins_num += 1

      # Run inference
      start_time = datetime.datetime.now()
      prediction, prediction_softmax = sess.run(
          [inference_op, inference_softmax],
          feed_dict={
              sparse_index: feature_index,
              sparse_ids: feature_ids,
              sparse_values: feature_values,
              sparse_shape: [ins_num, FLAGS.feature_size]
          })

      end_time = datetime.datetime.now()

      # Compute accuracy
      label_number = len(labels)
      correct_label_number = 0
      for i in range(label_number):
        if labels[i] == prediction[i]:
          correct_label_number += 1
      accuracy = float(correct_label_number) / label_number

      # Compute auc
      expected_labels = np.array(labels)
      predict_labels = prediction_softmax[:, 0]
      fpr, tpr, thresholds = metrics.roc_curve(
          expected_labels, predict_labels, pos_label=0)
      auc = metrics.auc(fpr, tpr)
      logging.info("[{}] Inference accuracy: {}, auc: {}".format(
          end_time - start_time, accuracy, auc))

      # Save result into the file
      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
      logging.info(
          "Save result to file: {}".format(inference_result_file_name))

    elif FLAGS.mode == "inference_with_tfrecords":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      # Load inference test data
      inference_result_file_name = "./inference_result.txt"
      inference_test_file_name = "./data/a8a/a8a_test.libsvm.tfrecords"

      batch_feature_index = []
      batch_labels = []
      batch_ids = []
      batch_values = []
      ins_num = 0

      # Read from TFRecords files
      for serialized_example in tf.python_io.tf_record_iterator(
          inference_test_file_name):
        # Get serialized example from file
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        label = example.features.feature["label"].float_list.value
        ids = example.features.feature["ids"].int64_list.value
        values = example.features.feature["values"].float_list.value
        #print("label: {}, features: {}".format(label, " ".join([str(id) + ":" + str(value) for id, value in zip(ids, values)])))
        # label is a length-1 float_list; take the scalar so the accuracy and
        # auc computations below compare numbers, not repeated containers
        batch_labels.append(label[0])
        # Notice that using extend() instead of append() to flatten the values
        batch_ids.extend(ids)
        batch_values.extend(values)
        for i in range(len(ids)):
          batch_feature_index.append([ins_num, i])

        ins_num += 1

      # Run inference
      start_time = datetime.datetime.now()
      prediction, prediction_softmax = sess.run(
          [inference_op, inference_softmax],
          feed_dict={
              sparse_index: batch_feature_index,
              sparse_ids: batch_ids,
              sparse_values: batch_values,
              sparse_shape: [ins_num, FLAGS.feature_size]
          })

      end_time = datetime.datetime.now()

      # Compute accuracy
      label_number = len(batch_labels)
      correct_label_number = 0
      for i in range(label_number):
        if batch_labels[i] == prediction[i]:
          correct_label_number += 1
      accuracy = float(correct_label_number) / label_number

      # Compute auc
      expected_labels = np.array(batch_labels)
      predict_labels = prediction_softmax[:, 0]
      fpr, tpr, thresholds = metrics.roc_curve(
          expected_labels, predict_labels, pos_label=0)
      auc = metrics.auc(fpr, tpr)
      logging.info("[{}] Inference accuracy: {}, auc: {}".format(
          end_time - start_time, accuracy, auc))

      # Save result into the file
      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
      logging.info(
          "Save result to file: {}".format(inference_result_file_name))
Example #2
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            #init_method=args.dist_url,
            #world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens,
                                     args.n_layer,
                                     args.n_head,
                                     args.d_model,
                                     args.d_head,
                                     args.d_inner,
                                     args.dropout,
                                     args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      f" and optimizer from {args.optim_state_dict}" if args.
                      optim_state_dict else "")
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)

    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
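        # FP16_Optimizer wraps the base optimizer with (dynamic) loss scaling
        # so that small gradients survive the float16 round trip.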
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-1',
                                     generate_text=False,
                                     reset_mems_interval=1)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-2',
                                     generate_text=False,
                                     reset_mems_interval=2)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-3',
                                     generate_text=False,
                                     reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model,
                                         g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count //
                                                   1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model,
                                          optimizer,
                                          args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
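
# util.dist_mean used above is project-specific and not shown. A plausible
# sketch that averages a scalar across workers with torch.distributed; this is
# an assumption, not the original util implementation:
def _dist_mean_sketch(value, device=None):
    import torch
    import torch.distributed as dist
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process fallback (args.local mode)
    # With the NCCL backend the tensor must live on the current CUDA device.
    t = torch.tensor(float(value), device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.item() / dist.get_world_size()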
Example #4
def main():
  """
  Train the TensorFlow models.
  """

  # Get hyper-parameters
  if not os.path.exists(FLAGS.checkpoint_path):
    os.makedirs(FLAGS.checkpoint_path)
  checkpoint_file_path = os.path.join(FLAGS.checkpoint_path, "checkpoint.ckpt")
  latest_checkpoint_file_path = tf.train.latest_checkpoint(
      FLAGS.checkpoint_path)

  if not os.path.exists(FLAGS.output_path):
    os.makedirs(FLAGS.output_path)

  # Step 1: Construct the dataset op
  epoch_number = FLAGS.epoch_number
  if epoch_number <= 0:
    # repeat(-1) makes the dataset repeat indefinitely
    epoch_number = -1
  train_buffer_size = FLAGS.train_batch_size * 3
  validation_buffer_size = FLAGS.validation_batch_size * 3

  train_filename_list = [filename for filename in FLAGS.train_files.split(",")]
  train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  if FLAGS.file_format == "tfrecords":
    train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
    train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
        epoch_number).batch(FLAGS.train_batch_size).shuffle(
            buffer_size=train_buffer_size)
  elif FLAGS.file_format == "csv":
    # Skip the header or not
    train_dataset = tf.data.TextLineDataset(train_filename_placeholder)
    train_dataset = train_dataset.map(parse_csv_function).repeat(
        epoch_number).batch(FLAGS.train_batch_size).shuffle(
            buffer_size=train_buffer_size)
  train_dataset_iterator = train_dataset.make_initializable_iterator()
  train_features_op, train_label_op = train_dataset_iterator.get_next()

  validation_filename_list = FLAGS.validation_files.split(",")
  validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  if FLAGS.file_format == "tfrecords":
    validation_dataset = tf.data.TFRecordDataset(
        validation_filename_placeholder)
    validation_dataset = validation_dataset.map(
        parse_tfrecords_function).repeat(epoch_number).batch(
            FLAGS.validation_batch_size).shuffle(
                buffer_size=validation_buffer_size)
  elif FLAGS.file_format == "csv":
    validation_dataset = tf.data.TextLineDataset(
        validation_filename_placeholder)
    validation_dataset = validation_dataset.map(parse_csv_function).repeat(
        epoch_number).batch(FLAGS.validation_batch_size).shuffle(
            buffer_size=validation_buffer_size)
  validation_dataset_iterator = validation_dataset.make_initializable_iterator()
  validation_features_op, validation_label_op = (
      validation_dataset_iterator.get_next())

  # Step 2: Define the model
  input_units = FLAGS.feature_size
  output_units = FLAGS.label_size
  logits = inference(train_features_op, input_units, output_units, True)

  if FLAGS.loss == "sparse_cross_entropy":
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=train_label_op)
    loss = tf.reduce_mean(cross_entropy, name="loss")
  elif FLAGS.loss == "cross_entropy":
    cross_entropy = tf.nn.cross_entropy_with_logits(
        logits=logits, labels=train_label_op)
    loss = tf.reduce_mean(cross_entropy, name="loss")
  elif FLAGS.loss == "mean_square":
    msl = tf.square(logits - train_label_op, name="msl")
    loss = tf.reduce_mean(msl, name="loss")

  global_step = tf.Variable(0, name="global_step", trainable=False)
  learning_rate = FLAGS.learning_rate

  if FLAGS.enable_lr_decay:
    logging.info(
        "Enable learning rate decay, decay rate: {}".format(FLAGS.lr_decay_rate))
    starter_learning_rate = FLAGS.learning_rate
    learning_rate = tf.train.exponential_decay(
        starter_learning_rate,
        global_step,
        100000,
        FLAGS.lr_decay_rate,
        staircase=True)

  optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)

  # Need to re-use the Variables for training and validation
  tf.get_variable_scope().reuse_variables()

  # Define accuracy op and auc op for train
  train_accuracy_logits = inference(train_features_op, input_units,
                                    output_units, False)
  train_softmax_op, train_accuracy_op = model.compute_softmax_and_accuracy(
      train_accuracy_logits, train_label_op)
  train_auc_op = model.compute_auc(train_softmax_op, train_label_op,
                                   FLAGS.label_size)

  # Define accuracy op and auc op for validation
  validation_accuracy_logits = inference(validation_features_op, input_units,
                                         output_units, False)
  validation_softmax_op, validation_accuracy_op = model.compute_softmax_and_accuracy(
      validation_accuracy_logits, validation_label_op)
  validation_auc_op = model.compute_auc(validation_softmax_op,
                                        validation_label_op, FLAGS.label_size)

  # Define inference op
  inference_features = tf.placeholder(
      "float", [None, FLAGS.feature_size], name="features")
  inference_logits = inference(inference_features, input_units, output_units,
                               False)
  inference_softmax_op = tf.nn.softmax(
      inference_logits, name="inference_softmax")
  inference_prediction_op = tf.argmax(
      inference_softmax_op, 1, name="inference_prediction")
  keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1], name="keys")
  keys_identity = tf.identity(keys_placeholder, name="inference_keys")

  signature_def_map = {
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
      signature_def_utils.build_signature_def(
          inputs={
              "keys": utils.build_tensor_info(keys_placeholder),
              "features": utils.build_tensor_info(inference_features)
          },
          outputs={
              "keys": utils.build_tensor_info(keys_identity),
              "prediction": utils.build_tensor_info(inference_prediction_op),
          },
          method_name="tensorflow/serving/predictss"),
      "serving_detail":
      signature_def_utils.build_signature_def(
          inputs={
              "keys": utils.build_tensor_info(keys_placeholder),
              "features": utils.build_tensor_info(inference_features)
          },
          outputs={
              "keys": utils.build_tensor_info(keys_identity),
              "prediction": utils.build_tensor_info(inference_prediction_op),
              "softmax": utils.build_tensor_info(inference_softmax_op),
          },
          method_name="sdfas")
  }

  # Initialize saver and summary
  saver = tf.train.Saver()
  tf.summary.scalar("loss", loss)
  if FLAGS.scenario == "classification":
    tf.summary.scalar("train_accuracy", train_accuracy_op)
    tf.summary.scalar("train_auc", train_auc_op)
    tf.summary.scalar("validate_accuracy", validation_accuracy_op)
    tf.summary.scalar("validate_auc", validation_auc_op)
  summary_op = tf.summary.merge_all()
  init_op = [
      tf.global_variables_initializer(),
      tf.local_variables_initializer()
  ]

  # Step 3: Create session to run
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
    sess.run(init_op)
    sess.run(
        [
            train_dataset_iterator.initializer,
            validation_dataset_iterator.initializer
        ],
        feed_dict={
            train_filename_placeholder: train_filename_list,
            validation_filename_placeholder: validation_filename_list
        })

    if FLAGS.mode == "train":
      if FLAGS.resume_from_checkpoint:
        util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path)

      try:
        start_time = datetime.datetime.now()

        while True:
          if FLAGS.enable_benchmark:
            sess.run(train_op)
          else:

            _, global_step_value = sess.run([train_op, global_step])

            # Step 4: Display training metrics after steps
            if global_step_value % FLAGS.steps_to_validate == 0:
              if FLAGS.scenario == "classification":
                loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
                    [
                        loss, train_accuracy_op, train_auc_op,
                        validation_accuracy_op, validation_auc_op, summary_op
                    ])
                end_time = datetime.datetime.now()

                logging.info(
                    "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
                    format(end_time - start_time, global_step_value,
                           loss_value, train_accuracy_value, train_auc_value,
                           validate_accuracy_value, validate_auc_value))

              elif FLAGS.scenario == "regression":
                loss_value, summary_value = sess.run([loss, summary_op])
                end_time = datetime.datetime.now()
                logging.info("[{}] Step: {}, loss: {}".format(
                    end_time - start_time, global_step_value, loss_value))

              writer.add_summary(summary_value, global_step_value)
              saver.save(
                  sess, checkpoint_file_path, global_step=global_step_value)

              start_time = end_time

      except tf.errors.OutOfRangeError:
        if FLAGS.enable_benchmark:
          logging.info("Finish training for benchmark")
        else:
          # Step 5: Export the model after training
          util.save_model(
              FLAGS.model_path,
              FLAGS.model_version,
              sess,
              signature_def_map,
              is_save_graph=False)

    elif FLAGS.mode == "savedmodel":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint for exporting model, exit now")
        return

      util.save_model(
          FLAGS.model_path,
          FLAGS.model_version,
          sess,
          signature_def_map,
          is_save_graph=False)

    elif FLAGS.mode == "inference":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint for inference, exit now")
        return

      # Load test data
      inference_result_file_name = FLAGS.inference_result_file
      inference_test_file_name = FLAGS.inference_data_file
      inference_data = np.genfromtxt(inference_test_file_name, delimiter=",")
      inference_data_features = inference_data[:, 0:9]
      inference_data_labels = inference_data[:, 9]

      # Run inference
      start_time = datetime.datetime.now()
      prediction, prediction_softmax = sess.run(
          [inference_prediction_op, inference_softmax_op],
          feed_dict={inference_features: inference_data_features})
      end_time = datetime.datetime.now()

      # Compute accuracy
      label_number = len(inference_data_labels)
      correct_label_number = 0
      for i in range(label_number):
        if inference_data_labels[i] == prediction[i]:
          correct_label_number += 1
      accuracy = float(correct_label_number) / label_number

      # Compute auc
      y_true = np.array(inference_data_labels)
      y_score = prediction_softmax[:, 1]
      fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
      auc = metrics.auc(fpr, tpr)
      logging.info("[{}] Inference accuracy: {}, auc: {}".format(
          end_time - start_time, accuracy, auc))

      # Save result into the file
      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
      logging.info(
          "Save result to file: {}".format(inference_result_file_name))
Example #5
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank <= 0:
        # cancel any scheduled shutdown; skipped when a non-zero rank is explicitly set
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert util.get_world_size() == dist.get_world_size()
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

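    # MemTransformerLM is the Transformer-XL language model (a transformer with
    # segment-level memory); the arguments below mirror its constructor flags.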
    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        scheduler = util.NoOp()  # no-op stand-in so scheduler calls stay safe (as in main_loop below)

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
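    # FP16_Module / FP16_Optimizer are presumably NVIDIA apex's mixed-precision
    # wrappers; the optimizer handles loss scaling for fp16 training.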
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
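
Note: T_max for the cosine schedule above is args.max_tokens // 1e6, which suggests the schedule is advanced once per million trained tokens rather than once per optimizer step. A minimal, self-contained sketch of that stepping pattern (the toy model, the 1e6 granularity, and the tokens_seen counter are illustrative assumptions, not code from train()):

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(8, 8)  # stand-in model for the sketch
    opt = optim.SGD(model.parameters(), lr=0.1)
    max_tokens = 50_000_000
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, max_tokens // 1e6, eta_min=0.0)

    tokens_seen, sched_steps = 0, 0
    for batch_tokens in [1_000_000] * 5:  # pretend each batch carries 1M tokens
        opt.step()                        # (gradient computation elided)
        tokens_seen += batch_tokens
        while sched_steps < tokens_seen // 1e6:  # advance the schedule in 1M-token units
            scheduler.step()
            sched_steps += 1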
Example #6
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert util.get_world_size() == dist.get_world_size()
        g.logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    if args.load_state_fn:
        g.state = load_state(args.load_state_fn)
        g.logger.info(f"Restoring training from {args.load_state_fn}")
    else:
        g.logger.info("creating new model")
        g.state = TrainState(args)

        g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head, args.d_model,
                                         args.d_head, args.d_inner, args.dropout, args.dropatt,
                                         tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
                                         tie_projs=g.tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                                         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=g.cutoffs,
                                         same_length=args.same_length, attn_type=args.attn_type,
                                         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
        if args.checkpoint:
            util.restore_from_checkpoint(g.state.model, checkpoint_fn=args.checkpoint)
        else:
            g.state.model.apply(weights_init)
            g.state.model.word_emb.apply(
                weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.to(g.device)
        optimizer_setup(g.state)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if not g.args.load_state_fn:
        if args.scheduler == 'cosine':
            # Divide by 1e6 for numerical stability.
            g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_tokens // 1e6,
                                                                     eta_min=args.eta_min)
        elif args.scheduler == 'finder':
            g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3)
        else:
            assert args.scheduler == 'constant'
            g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.
Example #7
def main():
    """
  Train the TensorFlow models.
  """

    # Create the checkpoint and output directories if they do not exist
    if not os.path.exists(FLAGS.checkpoint_path):
        os.makedirs(FLAGS.checkpoint_path)
    checkpoint_file_path = os.path.join(FLAGS.checkpoint_path,
                                        "checkpoint.ckpt")
    latest_checkpoint_file_path = tf.train.latest_checkpoint(
        FLAGS.checkpoint_path)

    if not os.path.exists(FLAGS.output_path):
        os.makedirs(FLAGS.output_path)

    # Step 1: Construct the dataset op
    epoch_number = FLAGS.epoch_number
    if epoch_number <= 0:
        epoch_number = -1
    train_buffer_size = FLAGS.train_batch_size * 3
    validation_buffer_size = FLAGS.train_batch_size * 3

    train_filename_list = FLAGS.train_files.split(",")
    train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
    if FLAGS.file_format == "tfrecords":
        train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
        train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
            epoch_number).batch(
                FLAGS.train_batch_size).shuffle(buffer_size=train_buffer_size)
    elif FLAGS.file_format == "csv":
        # Note: a CSV header row, if present, is not skipped here
        train_dataset = tf.data.TextLineDataset(train_filename_placeholder)
        train_dataset = train_dataset.map(parse_csv_function).repeat(
            epoch_number).batch(
                FLAGS.train_batch_size).shuffle(buffer_size=train_buffer_size)
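    # Note: calling shuffle() after batch() only shuffles whole batches;
    # shuffle() before batch() would mix individual examples instead.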
    train_dataset_iterator = train_dataset.make_initializable_iterator()
    train_features_op, train_label_op = train_dataset_iterator.get_next()

    validation_filename_list = FLAGS.validation_files.split(",")
    validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
    if FLAGS.file_format == "tfrecords":
        validation_dataset = tf.data.TFRecordDataset(
            validation_filename_placeholder)
        validation_dataset = validation_dataset.map(
            parse_tfrecords_function).repeat(epoch_number).batch(
                FLAGS.validation_batch_size).shuffle(
                    buffer_size=validation_buffer_size)
    elif FLAGS.file_format == "csv":
        validation_dataset = tf.data.TextLineDataset(
            validation_filename_placeholder)
        validation_dataset = validation_dataset.map(parse_csv_function).repeat(
            epoch_number).batch(FLAGS.validation_batch_size).shuffle(
                buffer_size=validation_buffer_size)
    validation_dataset_iterator = validation_dataset.make_initializable_iterator()
    validation_features_op, validation_label_op = (
        validation_dataset_iterator.get_next())

    # Step 2: Define the model
    input_units = FLAGS.feature_size
    output_units = FLAGS.label_size
    logits = inference(train_features_op, input_units, output_units, True)

    if FLAGS.loss == "sparse_cross_entropy":
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=train_label_op)
        loss = tf.reduce_mean(cross_entropy, name="loss")
    elif FLAGS.loss == "cross_entropy":
        cross_entropy = tf.nn.cross_entropy_with_logits(logits=logits,
                                                        labels=train_label_op)
        loss = tf.reduce_mean(cross_entropy, name="loss")
    elif FLAGS.loss == "mean_square":
        msl = tf.square(logits - train_label_op, name="msl")
        loss = tf.reduce_mean(msl, name="loss")

    global_step = tf.Variable(0, name="global_step", trainable=False)
    learning_rate = FLAGS.learning_rate

    if FLAGS.enable_lr_decay:
        logging.info("Enable learning rate decay rate: {}".format(
            FLAGS.lr_decay_rate))
        starter_learning_rate = FLAGS.learning_rate
        learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                                   global_step,
                                                   100000,
                                                   FLAGS.lr_decay_rate,
                                                   staircase=True)

    optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate)
    train_op = optimizer.minimize(loss, global_step=global_step)

    # Need to re-use the Variables for training and validation
    tf.get_variable_scope().reuse_variables()

    # Define accuracy op and auc op for train
    train_accuracy_logits = inference(train_features_op, input_units,
                                      output_units, False)
    train_softmax_op, train_accuracy_op = model.compute_softmax_and_accuracy(
        train_accuracy_logits, train_label_op)
    train_auc_op = model.compute_auc(train_softmax_op, train_label_op,
                                     FLAGS.label_size)

    # Define accuracy op and auc op for validation
    validation_accuracy_logits = inference(validation_features_op, input_units,
                                           output_units, False)
    validation_softmax_op, validation_accuracy_op = model.compute_softmax_and_accuracy(
        validation_accuracy_logits, validation_label_op)
    validation_auc_op = model.compute_auc(validation_softmax_op,
                                          validation_label_op,
                                          FLAGS.label_size)

    # Define inference op
    inference_features = tf.placeholder(tf.float32,
                                        [None, FLAGS.feature_size],
                                        name="features")
    inference_logits = inference(inference_features, input_units, output_units,
                                 False)
    inference_softmax_op = tf.nn.softmax(inference_logits,
                                         name="inference_softmax")
    inference_prediction_op = tf.argmax(inference_softmax_op,
                                        1,
                                        name="inference_prediction")
    keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1], name="keys")
    keys_identity = tf.identity(keys_placeholder, name="inference_keys")

    signature_def_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        signature_def_utils.build_signature_def(
            inputs={
                "keys": utils.build_tensor_info(keys_placeholder),
                "features": utils.build_tensor_info(inference_features)
            },
            outputs={
                "keys": utils.build_tensor_info(keys_identity),
                "prediction": utils.build_tensor_info(inference_prediction_op),
            },
            method_name="tensorflow/serving/predictss"),
        "serving_detail":
        signature_def_utils.build_signature_def(
            inputs={
                "keys": utils.build_tensor_info(keys_placeholder),
                "features": utils.build_tensor_info(inference_features)
            },
            outputs={
                "keys": utils.build_tensor_info(keys_identity),
                "prediction": utils.build_tensor_info(inference_prediction_op),
                "softmax": utils.build_tensor_info(inference_softmax_op),
            },
            method_name="sdfas")
    }

    # Initialize saver and summary
    saver = tf.train.Saver()
    tf.summary.scalar("loss", loss)
    if FLAGS.scenario == "classification":
        tf.summary.scalar("train_accuracy", train_accuracy_op)
        tf.summary.scalar("train_auc", train_auc_op)
        tf.summary.scalar("validate_accuracy", validation_accuracy_op)
        tf.summary.scalar("validate_auc", validation_auc_op)
    summary_op = tf.summary.merge_all()
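    # The streaming AUC ops keep their running totals in TF local variables,
    # so both global and local initializers are required below.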
    init_op = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]

    # Step 3: Create session to run
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
        sess.run(init_op)
        sess.run(
            [
                train_dataset_iterator.initializer,
                validation_dataset_iterator.initializer
            ],
            feed_dict={
                train_filename_placeholder: train_filename_list,
                validation_filename_placeholder: validation_filename_list
            })

        if FLAGS.mode == "train":
            if FLAGS.resume_from_checkpoint:
                util.restore_from_checkpoint(sess, saver,
                                             latest_checkpoint_file_path)

            try:
                start_time = datetime.datetime.now()

                while True:
                    if FLAGS.enable_benchmark:
                        sess.run(train_op)
                    else:
                        _, global_step_value = sess.run(
                            [train_op, global_step])

                        # Step 4: Display training metrics after steps
                        if global_step_value % FLAGS.steps_to_validate == 0:
                            if FLAGS.scenario == "classification":
                                loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
                                    [
                                        loss, train_accuracy_op, train_auc_op,
                                        validation_accuracy_op,
                                        validation_auc_op, summary_op
                                    ])
                                end_time = datetime.datetime.now()

                                logging.info(
                                    "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}"
                                    .format(end_time - start_time,
                                            global_step_value, loss_value,
                                            train_accuracy_value,
                                            train_auc_value,
                                            validate_accuracy_value,
                                            validate_auc_value))

                            elif FLAGS.scenario == "regression":
                                loss_value, summary_value = sess.run(
                                    [loss, summary_op])
                                end_time = datetime.datetime.now()
                                logging.info("[{}] Step: {}, loss: {}".format(
                                    end_time - start_time, global_step_value,
                                    loss_value))

                            writer.add_summary(summary_value,
                                               global_step_value)
                            saver.save(sess,
                                       checkpoint_file_path,
                                       global_step=global_step_value)

                            start_time = end_time

            except tf.errors.OutOfRangeError:
                if FLAGS.enable_benchmark:
                    logging.info("Finish training for benchmark")
                else:
                    # Step 5: Export the model after training
                    util.save_model(FLAGS.model_path,
                                    FLAGS.model_version,
                                    sess,
                                    signature_def_map,
                                    is_save_graph=False)

        elif FLAGS.mode == "savedmodel":
            if not util.restore_from_checkpoint(
                    sess, saver, latest_checkpoint_file_path):
                logging.error("No checkpoint for exporting model, exit now")
                return

            util.save_model(FLAGS.model_path,
                            FLAGS.model_version,
                            sess,
                            signature_def_map,
                            is_save_graph=False)

        elif FLAGS.mode == "inference":
            if not util.restore_from_checkpoint(
                    sess, saver, latest_checkpoint_file_path):
                logging.error("No checkpoint for inference, exit now")
                return

            # Load test data
            inference_result_file_name = FLAGS.inference_result_file
            inference_test_file_name = FLAGS.inference_data_file
            inference_data = np.genfromtxt(inference_test_file_name,
                                           delimiter=",")
            inference_data_features = inference_data[:, 0:9]
            inference_data_labels = inference_data[:, 9]

            # Run inference
            start_time = datetime.datetime.now()
            prediction, prediction_softmax = sess.run(
                [inference_prediction_op, inference_softmax_op],
                feed_dict={inference_features: inference_data_features})
            end_time = datetime.datetime.now()

            # Compute accuracy as the fraction of predictions matching labels
            accuracy = float(np.mean(inference_data_labels == prediction))

            # Compute auc
            y_true = np.array(inference_data_labels)
            y_score = prediction_softmax[:, 1]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_score,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            logging.info("[{}] Inference accuracy: {}, auc: {}".format(
                end_time - start_time, accuracy, auc))

            # Save result into the file
            np.savetxt(inference_result_file_name,
                       prediction_softmax,
                       delimiter=",")
            logging.info(
                "Save result to file: {}".format(inference_result_file_name))