def main():
    """Train and/or evaluate the NCF recommender with Horovod data parallelism.

    Loads the pickled train/test ratings, builds the NCF graph, then either
    runs a single evaluation pass (``--mode test``) or the full training loop
    with periodic evaluation.  Metrics are reported through dllogger (rank 0
    only) and the best model is checkpointed once the target HR@10 is reached.
    """
    hvd_init()
    mpi_comm = MPI.COMM_WORLD
    args = parse_args()

    # Only rank 0 produces log output; other workers get a no-op logger.
    if hvd.rank() == 0:
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                       filename=args.log_path),
            dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
        ])
    else:
        dllogger.init(backends=[])

    dllogger.log(data=vars(args), step='PARAMETER')

    # Seed every RNG in play (TF graph ops, NumPy, CuPy) for reproducibility.
    if args.seed is not None:
        tf.random.set_random_seed(args.seed)
        np.random.seed(args.seed)
        cp.random.seed(args.seed)

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
    # Automatic mixed precision supersedes manual fp16 casting.
    if os.environ.get("TF_ENABLE_AUTO_MIXED_PRECISION") == "1":
        args.fp16 = False

    # Directory that will hold the best checkpoint.
    if args.checkpoint_dir:
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')

    # Load converted data; entity counts are derived from the max ids.
    train_df = pd.read_pickle(args.data + '/train_ratings.pickle')
    test_df = pd.read_pickle(args.data + '/test_ratings.pickle')
    nb_users, nb_items = train_df.max() + 1

    # Extract train and test feature tensors from the dataframes.
    pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
    pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
    pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
    pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
    # Negatives indicator for negative sampling: True where the (user, item)
    # pair was never seen in training.  Use the builtin ``bool`` dtype;
    # ``np.bool`` is deprecated and was removed in NumPy 1.24.
    neg_mat = np.ones((nb_users, nb_items), dtype=bool)
    neg_mat[pos_train_users, pos_train_items] = 0

    # Get this worker's shard of the training/test data.
    train_users, train_items, train_labels = get_local_train_data(
        pos_train_users, pos_train_items, args.negative_samples)
    test_users, test_items = get_local_test_data(pos_test_users,
                                                 pos_test_items)

    # Create and run the Data Generator in a separate thread.
    data_generator = DataGenerator(
        args.seed,
        hvd.rank(),
        nb_users,
        nb_items,
        neg_mat,
        train_users,
        train_items,
        train_labels,
        args.batch_size // hvd.size(),  # per-worker batch size
        args.negative_samples,
        test_users,
        test_items,
        args.valid_users_per_batch,
        args.valid_negative,
    )

    # Create the TF session, pinning this process to its local GPU.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    if args.xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    sess = tf.Session(config=config)

    # Input tensors
    users = tf.placeholder(tf.int32, shape=(None, ))
    items = tf.placeholder(tf.int32, shape=(None, ))
    labels = tf.placeholder(tf.int32, shape=(None, ))
    is_dup = tf.placeholder(tf.float32, shape=(None, ))
    dropout = tf.placeholder_with_default(args.dropout, shape=())
    # Model ops and saver
    hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
        users,
        items,
        labels,
        is_dup,
        params={
            'fp16': args.fp16,
            'val_batch_size': args.valid_negative + 1,
            'top_k': args.topk,
            'learning_rate': args.learning_rate,
            'beta_1': args.beta1,
            'beta_2': args.beta2,
            'epsilon': args.eps,
            'num_users': nb_users,
            'num_items': nb_items,
            'num_factors': args.factors,
            'mf_reg': 0,
            'layer_sizes': args.layers,
            'layer_regs': [0. for _ in args.layers],
            'dropout': dropout,
            'sigmoid': True,
            'loss_scale': args.loss_scale
        },
        mode='TRAIN' if args.mode == 'train' else 'EVAL')
    saver = tf.train.Saver()

    # Streaming-metric accumulator tensors created inside ncf_model_ops.
    graph = tf.get_default_graph()
    hr_sum = graph.get_tensor_by_name('neumf/hit_rate/total:0')
    hr_cnt = graph.get_tensor_by_name('neumf/hit_rate/count:0')
    ndcg_sum = graph.get_tensor_by_name('neumf/ndcg/total:0')
    ndcg_cnt = graph.get_tensor_by_name('neumf/ndcg/count:0')

    # Prepare evaluation data
    data_generator.prepare_eval_data()

    if args.load_checkpoint_path:
        saver.restore(sess, args.load_checkpoint_path)
    else:
        # Manually initialize weights
        sess.run(tf.global_variables_initializer())

    # If test mode, run one evaluation pass, report, and exit.
    if args.mode == 'test':
        sess.run(tf.local_variables_initializer())
        eval_start = time.time()
        for user_batch, item_batch, dup_batch \
            in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
            sess.run(eval_op,
                     feed_dict={
                         users: user_batch,
                         items: item_batch,
                         is_dup: dup_batch,
                         dropout: 0.0
                     })
        eval_duration = time.time() - eval_start

        # Aggregate the metric accumulators across workers (sum, not mean).
        # Use fresh names so the graph tensors are not shadowed.
        hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
        hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
        ndcg_total = sess.run(hvd.allreduce(ndcg_sum, average=False))
        ndcg_count = sess.run(hvd.allreduce(ndcg_cnt, average=False))

        hit_rate = hit_rate_sum / hit_rate_cnt
        ndcg = ndcg_total / ndcg_count

        if hvd.rank() == 0:
            eval_throughput = pos_test_users.shape[0] * (args.valid_negative +
                                                         1) / eval_duration
            dllogger.log(step=tuple(),
                         data={
                             'eval_throughput': eval_throughput,
                             'eval_time': eval_duration,
                             'hr@10': hit_rate,
                             'ndcg': ndcg
                         })
        return

    # Performance metrics
    train_times = list()
    eval_times = list()
    # Accuracy metrics
    first_to_target = None
    time_to_train = 0.0
    # Initialize so the final summary cannot hit an unbound name when no
    # evaluation ever improves on best_hr (e.g. epochs <= eval_after).
    time_to_best = 0.0
    best_hr = 0
    best_epoch = 0
    # Buffers for global metrics
    global_hr_sum = np.ones(1)
    global_hr_count = np.ones(1)
    global_ndcg_sum = np.ones(1)
    global_ndcg_count = np.ones(1)
    # Buffers for local metrics
    local_hr_sum = np.ones(1)
    local_hr_count = np.ones(1)
    local_ndcg_sum = np.ones(1)
    local_ndcg_count = np.ones(1)

    # Begin training
    begin_train = time.time()
    for epoch in range(args.epochs):
        # Train for one epoch
        train_start = time.time()
        data_generator.prepare_train_data()
        for user_batch, item_batch, label_batch \
            in zip(data_generator.train_users_batches,
                   data_generator.train_items_batches,
                   data_generator.train_labels_batches):
            sess.run(train_op,
                     feed_dict={
                         users: user_batch.get(),
                         items: item_batch.get(),
                         labels: label_batch.get()
                     })
        train_duration = time.time() - train_start
        # Only log "warm" epochs: epoch 0 pays one-time setup costs.
        if epoch >= 1:
            train_times.append(train_duration)
        # Evaluate
        if epoch > args.eval_after:
            eval_start = time.time()
            sess.run(tf.local_variables_initializer())
            for user_batch, item_batch, dup_batch \
                in zip(data_generator.eval_users,
                       data_generator.eval_items,
                       data_generator.dup_mask):
                sess.run(eval_op,
                         feed_dict={
                             users: user_batch,
                             items: item_batch,
                             is_dup: dup_batch,
                             dropout: 0.0
                         })
            # Read this worker's metric accumulators.
            local_hr_sum[0] = sess.run(hr_sum)
            local_hr_count[0] = sess.run(hr_cnt)
            local_ndcg_sum[0] = sess.run(ndcg_sum)
            local_ndcg_count[0] = sess.run(ndcg_cnt)
            # Reduce (sum) metrics across all workers; result lands on rank 0.
            mpi_comm.Reduce(local_hr_count, global_hr_count)
            mpi_comm.Reduce(local_hr_sum, global_hr_sum)
            mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
            mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)
            # Calculate metrics
            hit_rate = global_hr_sum[0] / global_hr_count[0]
            ndcg = global_ndcg_sum[0] / global_ndcg_count[0]

            eval_duration = time.time() - eval_start
            # Only log "warm" epochs
            if epoch >= 1:
                eval_times.append(eval_duration)

            if hvd.rank() == 0:
                dllogger.log(step=(epoch, ),
                             data={
                                 'train_time': train_duration,
                                 'eval_time': eval_duration,
                                 'hr@10': hit_rate,
                                 'ndcg': ndcg
                             })

                # Update summary metrics
                if hit_rate > args.target and first_to_target is None:
                    first_to_target = epoch
                    time_to_train = time.time() - begin_train
                if hit_rate > best_hr:
                    best_hr = hit_rate
                    best_epoch = epoch
                    time_to_best = time.time() - begin_train
                    # Only checkpoint models that meet the accuracy target.
                    if hit_rate > args.target:
                        saver.save(sess, final_checkpoint_path)

    # Final summary (rank 0 only).
    if hvd.rank() == 0:
        train_times = np.array(train_times)
        train_throughputs = pos_train_users.shape[0] * (args.negative_samples +
                                                        1) / train_times
        eval_times = np.array(eval_times)
        eval_throughputs = pos_test_users.shape[0] * (args.valid_negative +
                                                      1) / eval_times

        dllogger.log(step=tuple(),
                     data={
                         'average_train_time_per_epoch': np.mean(train_times),
                         'average_train_throughput':
                         np.mean(train_throughputs),
                         'average_eval_time_per_epoch': np.mean(eval_times),
                         'average_eval_throughput': np.mean(eval_throughputs),
                         'first_epoch_to_hit': first_to_target,
                         'time_to_train': time_to_train,
                         'time_to_best': time_to_best,
                         'best_hr': best_hr,
                         'best_epoch': best_epoch
                     })
        dllogger.flush()

    sess.close()
    return
# --- Example 2 ---
def main():
    """
    Run NCF training/evaluation with Horovod, logging to W&B and LOGGER.

    Loads the pickled train/test ratings, builds the NCF graph, then either
    runs a single evaluation pass (``--mode test``) or the full training loop
    with periodic evaluation.  Per-epoch metrics go to wandb on every rank
    (non-zero ranks run in dryrun mode); human-readable output goes through
    LOGGER on rank 0 only.
    """
    script_start = time.time()
    hvd_init()
    mpi_comm = MPI.COMM_WORLD
    args = parse_args()

    # Rank 0 logs args; all other ranks send wandb data to /dev/null.
    if hvd.rank() == 0:
        log_args(args)
    else:
        os.environ['WANDB_MODE'] = 'dryrun'
    # Resume/group runs under a shared id when WANDB_ID is set.
    wandb_id = os.environ.get('WANDB_ID', None)
    if wandb_id is None:
        wandb.init(config=args)
    else:
        wandb.init(config=args, id=f"{wandb_id}{hvd.rank()}")
    wandb.config.update({'SLURM_JOB_ID': os.environ.get('SLURM_JOB_ID', None)})
    wandb.tensorboard.patch(save=False)

    # Seed every RNG in play (TF graph ops, NumPy, CuPy) for reproducibility.
    if args.seed is not None:
        tf.random.set_random_seed(args.seed)
        np.random.seed(args.seed)
        cp.random.seed(args.seed)

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
    # Automatic mixed precision supersedes manual fp16 casting.
    if os.environ.get("TF_ENABLE_AUTO_MIXED_PRECISION") == "1":
        args.fp16 = False

    # Directory to store/read the final checkpoint.
    if args.mode == 'train' and hvd.rank() == 0:
        print("Saving best checkpoint to {}".format(args.checkpoint_dir))
    elif hvd.rank() == 0:
        print("Reading checkpoint: {}".format(args.checkpoint_dir))
    if args.checkpoint_dir:
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')

    # Load converted data; entity counts are derived from the max ids.
    train_df = pd.read_pickle(args.data + '/train_ratings.pickle')
    test_df = pd.read_pickle(args.data + '/test_ratings.pickle')
    nb_users, nb_items = train_df.max() + 1

    # Extract train and test feature tensors from the dataframes.
    pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
    pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
    pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
    pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
    # Negatives indicator for negative sampling: True where the (user, item)
    # pair was never seen in training.  Use the builtin ``bool`` dtype;
    # ``np.bool`` is deprecated and was removed in NumPy 1.24.
    neg_mat = np.ones((nb_users, nb_items), dtype=bool)
    neg_mat[pos_train_users, pos_train_items] = 0

    # Get this worker's shard of the training/test data.
    train_users, train_items, train_labels = get_local_train_data(
        pos_train_users, pos_train_items, args.negative_samples)
    test_users, test_items = get_local_test_data(pos_test_users,
                                                 pos_test_items)

    # Create and run the Data Generator in a separate thread.
    # NOTE(review): the dllogger variant of this script seeds the generator
    # with hvd.rank() instead of hvd.local_rank() -- confirm which is intended
    # for multi-node runs.
    data_generator = DataGenerator(
        args.seed,
        hvd.local_rank(),
        nb_users,
        nb_items,
        neg_mat,
        train_users,
        train_items,
        train_labels,
        args.batch_size // hvd.size(),  # per-worker batch size
        args.negative_samples,
        test_users,
        test_items,
        args.valid_users_per_batch,
        args.valid_negative,
    )

    # Create the TF session, pinning this process to its local GPU.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    if args.xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    sess = tf.Session(config=config)

    # Input tensors
    users = tf.placeholder(tf.int32, shape=(None, ))
    items = tf.placeholder(tf.int32, shape=(None, ))
    labels = tf.placeholder(tf.int32, shape=(None, ))
    is_dup = tf.placeholder(tf.float32, shape=(None, ))
    dropout = tf.placeholder_with_default(args.dropout, shape=())
    # Model ops and saver
    hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
        users,
        items,
        labels,
        is_dup,
        params={
            'fp16': args.fp16,
            'val_batch_size': args.valid_negative + 1,
            'top_k': args.topk,
            'learning_rate': args.learning_rate,
            'beta_1': args.beta1,
            'beta_2': args.beta2,
            'epsilon': args.eps,
            'num_users': nb_users,
            'num_items': nb_items,
            'num_factors': args.factors,
            'mf_reg': 0,
            'layer_sizes': args.layers,
            'layer_regs': [0. for _ in args.layers],
            'dropout': dropout,
            'sigmoid': True,
            'loss_scale': args.loss_scale
        },
        mode='TRAIN' if args.mode == 'train' else 'EVAL')
    saver = tf.train.Saver()

    # Streaming-metric accumulator tensors created inside ncf_model_ops.
    graph = tf.get_default_graph()
    hr_sum = graph.get_tensor_by_name('neumf/hit_rate/total:0')
    hr_cnt = graph.get_tensor_by_name('neumf/hit_rate/count:0')
    ndcg_sum = graph.get_tensor_by_name('neumf/ndcg/total:0')
    ndcg_cnt = graph.get_tensor_by_name('neumf/ndcg/count:0')

    # Prepare evaluation data
    data_generator.prepare_eval_data()

    if args.load_checkpoint_path:
        saver.restore(sess, args.load_checkpoint_path)
    else:
        # Manually initialize weights
        sess.run(tf.global_variables_initializer())

    # If test mode, run one evaluation pass, report, and exit.
    if args.mode == 'test':
        sess.run(tf.local_variables_initializer())
        eval_start = time.time()
        for user_batch, item_batch, dup_batch \
            in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
            sess.run(eval_op,
                     feed_dict={
                         users: user_batch,
                         items: item_batch,
                         is_dup: dup_batch,
                         dropout: 0.0
                     })
        eval_duration = time.time() - eval_start

        # Aggregate the metric accumulators across workers (sum, not mean).
        # Use fresh names so the graph tensors are not shadowed.
        hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
        hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
        ndcg_total = sess.run(hvd.allreduce(ndcg_sum, average=False))
        ndcg_count = sess.run(hvd.allreduce(ndcg_cnt, average=False))

        hit_rate = hit_rate_sum / hit_rate_cnt
        ndcg = ndcg_total / ndcg_count

        if hvd.rank() == 0:
            LOGGER.log("Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}".format(
                eval_duration, hit_rate, ndcg))

            eval_throughput = pos_test_users.shape[0] * (args.valid_negative +
                                                         1) / eval_duration
            LOGGER.log(
                'Average Eval Throughput: {:.4f}'.format(eval_throughput))
        return

    # Performance metrics
    train_times = list()
    eval_times = list()
    # Accuracy metrics
    first_to_target = None
    time_to_train = 0.0
    # Initialize so the final summary cannot hit an unbound name when no
    # evaluation ever improves on best_hr (e.g. epochs <= eval_after).
    time_to_best = 0.0
    best_hr = 0
    best_epoch = 0
    # Buffers for global metrics
    global_hr_sum = np.ones(1)
    global_hr_count = np.ones(1)
    global_ndcg_sum = np.ones(1)
    global_ndcg_count = np.ones(1)
    # Buffers for local metrics
    local_hr_sum = np.ones(1)
    local_hr_count = np.ones(1)
    local_ndcg_sum = np.ones(1)
    local_ndcg_count = np.ones(1)

    # Begin training
    begin_train = time.time()
    if hvd.rank() == 0:
        LOGGER.log("Begin Training. Setup Time: {}".format(begin_train -
                                                           script_start))
    for epoch in range(args.epochs):
        # Train for one epoch
        train_start = time.time()
        data_generator.prepare_train_data()
        for user_batch, item_batch, label_batch \
            in zip(data_generator.train_users_batches,
                   data_generator.train_items_batches,
                   data_generator.train_labels_batches):
            sess.run(train_op,
                     feed_dict={
                         users: user_batch.get(),
                         items: item_batch.get(),
                         labels: label_batch.get()
                     })
        train_duration = time.time() - train_start
        # Accumulate into the wandb step; committed once per epoch below.
        wandb.log({"train/epoch_time": train_duration}, commit=False)
        # Only log "warm" epochs: epoch 0 pays one-time setup costs.
        if epoch >= 1:
            train_times.append(train_duration)
        # Evaluate
        if epoch > args.eval_after:
            eval_start = time.time()
            sess.run(tf.local_variables_initializer())
            for user_batch, item_batch, dup_batch \
                in zip(data_generator.eval_users,
                       data_generator.eval_items,
                       data_generator.dup_mask):
                sess.run(eval_op,
                         feed_dict={
                             users: user_batch,
                             items: item_batch,
                             is_dup: dup_batch,
                             dropout: 0.0
                         })
            # Read this worker's metric accumulators.
            local_hr_sum[0] = sess.run(hr_sum)
            local_hr_count[0] = sess.run(hr_cnt)
            local_ndcg_sum[0] = sess.run(ndcg_sum)
            local_ndcg_count[0] = sess.run(ndcg_cnt)
            # Reduce (sum) metrics across all workers; result lands on rank 0.
            mpi_comm.Reduce(local_hr_count, global_hr_count)
            mpi_comm.Reduce(local_hr_sum, global_hr_sum)
            mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
            mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)
            # Calculate metrics
            hit_rate = global_hr_sum[0] / global_hr_count[0]
            ndcg = global_ndcg_sum[0] / global_ndcg_count[0]

            eval_duration = time.time() - eval_start
            wandb.log(
                {
                    "eval/time": eval_duration,
                    "eval/hit_rate": hit_rate,
                    "eval/ndcg": ndcg
                },
                commit=False)
            # Only log "warm" epochs
            if epoch >= 1:
                eval_times.append(eval_duration)

            if hvd.rank() == 0:
                if args.verbose:
                    log_string = "Epoch: {:02d}, Train Time: {:.4f}, Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}"
                    LOGGER.log(
                        log_string.format(epoch, train_duration, eval_duration,
                                          hit_rate, ndcg))

                # Update summary metrics
                if hit_rate > args.target and first_to_target is None:
                    first_to_target = epoch
                    time_to_train = time.time() - begin_train
                if hit_rate > best_hr:
                    best_hr = hit_rate
                    best_epoch = epoch
                    time_to_best = time.time() - begin_train
                    # In non-verbose mode, still announce new best epochs.
                    if not args.verbose:
                        log_string = "New Best Epoch: {:02d}, Train Time: {:.4f}, Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}"
                        LOGGER.log(
                            log_string.format(epoch, train_duration,
                                              eval_duration, hit_rate, ndcg))
                    # Save, if meets target
                    if hit_rate > args.target:
                        saver.save(sess, final_checkpoint_path)
        # Commit the accumulated wandb metrics for this epoch.
        wandb.log({"epoch": epoch + 1})

    # Final summary (rank 0 only).
    if hvd.rank() == 0:
        train_times = np.array(train_times)
        train_throughputs = pos_train_users.shape[0] * (args.negative_samples +
                                                        1) / train_times
        eval_times = np.array(eval_times)
        eval_throughputs = pos_test_users.shape[0] * (args.valid_negative +
                                                      1) / eval_times
        LOGGER.log(' ')

        LOGGER.log('batch_size: {}'.format(args.batch_size))
        LOGGER.log('num_gpus: {}'.format(hvd.size()))
        LOGGER.log('AMP: {}'.format(1 if args.amp else 0))
        LOGGER.log('seed: {}'.format(args.seed))
        LOGGER.log('Minimum Train Time per Epoch: {:.4f}'.format(
            np.min(train_times)))
        LOGGER.log('Average Train Time per Epoch: {:.4f}'.format(
            np.mean(train_times)))
        LOGGER.log('Average Train Throughput:     {:.4f}'.format(
            np.mean(train_throughputs)))
        LOGGER.log('Minimum Eval Time per Epoch:  {:.4f}'.format(
            np.min(eval_times)))
        LOGGER.log('Average Eval Time per Epoch:  {:.4f}'.format(
            np.mean(eval_times)))
        LOGGER.log('Average Eval Throughput:      {:.4f}'.format(
            np.mean(eval_throughputs)))
        LOGGER.log('First Epoch to hit:           {}'.format(first_to_target))
        LOGGER.log(
            'Time to Train:                {:.4f}'.format(time_to_train))
        LOGGER.log('Time to Best:                 {:.4f}'.format(time_to_best))
        LOGGER.log('Best HR:                      {:.4f}'.format(best_hr))
        LOGGER.log('Best Epoch:                   {}'.format(best_epoch))
        wandb.log({
            "batch_size": args.batch_size,
            "num_gpus": hvd.size(),
            "train/total_throughput": np.mean(train_throughputs),
            "eval/total_throughput": np.mean(eval_throughputs),
            "train/total_time": np.sum(train_times),
            "train/time_to_target": time_to_train,
            "train/time_to_best": time_to_best,
            "train/first_to_target": first_to_target,
            "train/best_hit_rate": best_hr,
            "train/best_epoch": best_epoch,
            "epoch": args.epochs
        })

    sess.close()
    return
def main():
    """Benchmark NCF inference latency/throughput on one random batch size.

    Builds an inference-only NCF graph, times ``args.num_batches`` forward
    passes over random (user, item) ids, then prints the results as JSON and
    optionally writes them to ``args.log_path``.
    """
    args = parse_args()

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"

    # Input tensors
    users = tf.placeholder(tf.int32, shape=(None, ))
    items = tf.placeholder(tf.int32, shape=(None, ))
    dropout = tf.placeholder_with_default(0.0, shape=())

    # Inference-only graph: training hyperparameters are passed as None and
    # unused in this mode.
    logits_op = ncf_model_ops(users=users,
                              items=items,
                              labels=None,
                              dup_mask=None,
                              params={
                                  'fp16': False,
                                  'val_batch_size': args.batch_size,
                                  'num_users': args.n_users,
                                  'num_items': args.n_items,
                                  'num_factors': args.factors,
                                  'mf_reg': 0,
                                  'layer_sizes': args.layers,
                                  'layer_regs': [0. for _ in args.layers],
                                  'dropout': 0.0,
                                  'sigmoid': True,
                                  'top_k': None,
                                  'learning_rate': None,
                                  'beta_1': None,
                                  'beta_2': None,
                                  'epsilon': None,
                                  'loss_scale': None,
                              },
                              mode='INFERENCE')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    if args.xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(config=config)

    saver = tf.train.Saver()
    if args.load_checkpoint_path:
        saver.restore(sess, args.load_checkpoint_path)
    else:
        # Manually initialize weights
        sess.run(tf.global_variables_initializer())

    sess.run(tf.local_variables_initializer())

    # Random-but-valid ids; the actual values do not affect timing.
    users_batch = np.random.randint(size=args.batch_size,
                                    low=0,
                                    high=args.n_users)
    items_batch = np.random.randint(size=args.batch_size,
                                    low=0,
                                    high=args.n_items)

    latencies = []
    # NOTE(review): unlike the multi-batch-size benchmark, no warm-up
    # iterations are discarded here, so the first (slow) batch is included
    # in the statistics -- confirm this is intended.
    for _ in range(args.num_batches):
        start = time.time()
        # The result values are irrelevant; we only time the run.
        sess.run(logits_op,
                 feed_dict={
                     users: users_batch,
                     items: items_batch,
                     dropout: 0.0
                 })
        latencies.append(time.time() - start)

    sess.close()

    results = {
        'args': vars(args),
        'best_inference_throughput': args.batch_size / min(latencies),
        'best_inference_latency': min(latencies),
        'inference_latencies': latencies
    }
    print('RESULTS: ', json.dumps(results, indent=4))
    if args.log_path is not None:
        # Context manager guarantees the log file is flushed and closed.
        with open(args.log_path, 'w') as log_file:
            json.dump(results, log_file, indent=4)
def main():
    """Benchmark NCF inference latency across a list of batch sizes.

    For each batch size in ``args.batch_sizes`` a fresh graph and session are
    built, ``args.num_batches`` forward passes over random inputs are timed
    (the first 10 are treated as warm-up and discarded), and mean/percentile
    statistics are reported through dllogger.
    """
    args = parse_args()

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"

    dllogger.init(backends=[
        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                   filename=args.log_path),
        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
    ])
    dllogger.log(data=vars(args), step='PARAMETER')

    # Comma-separated list, e.g. "256,1024,4096".
    batch_sizes = [int(s) for s in args.batch_sizes.split(',')]
    result_data = {}
    # Number of initial, untimed iterations per batch size.
    warmup_iters = 10

    for batch_size in batch_sizes:
        print('Benchmarking batch size', batch_size)
        # Start from a clean graph for every batch size.
        tf.reset_default_graph()

        # Input tensors
        users = tf.placeholder(tf.int32, shape=(None, ))
        items = tf.placeholder(tf.int32, shape=(None, ))
        dropout = tf.placeholder_with_default(0.0, shape=())

        # Inference-only graph: training hyperparameters are passed as None
        # and unused in this mode.
        logits_op = ncf_model_ops(users=users,
                                  items=items,
                                  labels=None,
                                  dup_mask=None,
                                  mode='INFERENCE',
                                  params={
                                      'fp16': False,
                                      'val_batch_size': batch_size,
                                      'num_users': args.n_users,
                                      'num_items': args.n_items,
                                      'num_factors': args.factors,
                                      'mf_reg': 0,
                                      'layer_sizes': args.layers,
                                      'layer_regs': [0. for _ in args.layers],
                                      'dropout': 0.0,
                                      'sigmoid': True,
                                      'top_k': None,
                                      'learning_rate': None,
                                      'beta_1': None,
                                      'beta_2': None,
                                      'epsilon': None,
                                      'loss_scale': None,
                                  })

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        if args.xla:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        sess = tf.Session(config=config)

        saver = tf.train.Saver()
        if args.load_checkpoint_path:
            saver.restore(sess, args.load_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        sess.run(tf.local_variables_initializer())

        # Random-but-valid ids; the actual values do not affect timing.
        users_batch = np.random.randint(size=batch_size,
                                        low=0,
                                        high=args.n_users)
        items_batch = np.random.randint(size=batch_size,
                                        low=0,
                                        high=args.n_items)

        latencies = []
        # NOTE(review): if args.num_batches <= warmup_iters no samples are
        # collected and the statistics below become NaN -- confirm callers
        # always pass a larger value.
        for i in range(args.num_batches):
            start = time.time()
            _ = sess.run(logits_op,
                         feed_dict={
                             users: users_batch,
                             items: items_batch,
                             dropout: 0.0
                         })
            end = time.time()

            if i < warmup_iters:  # discard warm-up iterations
                continue

            latencies.append(end - start)

        # Close the session before building the next graph so GPU memory
        # does not accumulate across batch sizes.
        sess.close()

        result_data[
            f'batch_{batch_size}_mean_throughput'] = batch_size / np.mean(
                latencies)
        result_data[f'batch_{batch_size}_mean_latency'] = np.mean(latencies)
        # Tail-latency percentiles.
        for pct in (90, 95, 99):
            result_data[f'batch_{batch_size}_p{pct}_latency'] = np.percentile(
                latencies, pct)

    dllogger.log(data=result_data, step=tuple())
    dllogger.flush()