コード例 #1
0
def main():
    """Benchmark NeuMF inference latency/throughput on one random batch.

    Builds the model from the parsed command-line arguments, optionally
    restores weights from ``args.load_checkpoint_path``, then times
    ``args.num_batches`` forward passes over a fixed batch of random user and
    item ids. Logs the batch size, every measured latency, and the best
    latency/throughput through ``LOGGER``.
    """
    log_hardware()
    args = parse_args()
    log_args(args)

    model = NeuMF(nb_users=args.n_users, nb_items=args.n_items, mf_dim=args.factors,
                  mlp_layer_sizes=args.layers, dropout=args.dropout)

    model = model.cuda()

    if args.load_checkpoint_path:
        state_dict = torch.load(args.load_checkpoint_path)
        model.load_state_dict(state_dict)

    if args.opt_level == "O2":
        model = amp.initialize(model, opt_level=args.opt_level,
                               keep_batchnorm_fp32=False, loss_scale='dynamic')

    # This is an inference benchmark: put the module in eval mode so that
    # training-only behaviour (e.g. the dropout configured above) is off.
    model.eval()

    # One fixed random batch is reused for every timed iteration.
    users = torch.cuda.LongTensor(args.batch_size).random_(0, args.n_users)
    items = torch.cuda.LongTensor(args.batch_size).random_(0, args.n_items)

    latencies = []
    # No gradients are needed; skip autograd bookkeeping while timing.
    with torch.no_grad():
        for _ in range(args.num_batches):
            # Synchronize before and after so the wall-clock interval covers
            # the whole GPU forward pass rather than just the kernel launch.
            torch.cuda.synchronize()
            start = time.time()
            model(users, items, sigmoid=True)
            torch.cuda.synchronize()
            latencies.append(time.time() - start)

    best_latency = min(latencies)
    LOGGER.log(key='batch_size', value=args.batch_size)
    LOGGER.log(key='best_inference_throughput', value=args.batch_size / best_latency)
    LOGGER.log(key='best_inference_latency', value=best_latency)
    LOGGER.log(key='inference_latencies', value=latencies)
    return
コード例 #2
0
def main():
    """Train NeuMF, or only evaluate it when ``args.mode == 'test'``.

    Loads the pre-converted rating tensors from ``args.data``, builds the
    model, optimizer and loss, optionally restores a checkpoint, and then
    either runs a single evaluation pass (test mode) or trains for up to
    ``args.epochs`` epochs, evaluating after each one. The best model (by hit
    rate) is saved to ``<args.checkpoint_dir>/model.pth`` by local rank 0, and
    throughput/accuracy metrics are logged through ``LOGGER``.
    """
    log_hardware()
    args = parse_args()
    args.distributed, args.world_size = init_distributed(args.local_rank)
    log_args(args)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print("Saving results to {}".format(args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir) and args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    # The default of np.random.choice is replace=True, so does pytorch random_()
    LOGGER.log(key=tags.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_STEP_EVAL_NEG_GEN)

    # sync workers before timing
    if args.distributed:
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
    torch.cuda.synchronize()

    main_start_time = time.time()
    LOGGER.log(key=tags.RUN_START)

    # Load every dataset tensor directly onto this worker's GPU.
    train_ratings = torch.load(args.data + '/train_ratings.pt',
                               map_location=torch.device('cuda:{}'.format(
                                   args.local_rank)))
    test_ratings = torch.load(args.data + '/test_ratings.pt',
                              map_location=torch.device('cuda:{}'.format(
                                  args.local_rank)))
    test_negs = torch.load(args.data + '/test_negatives.pt',
                           map_location=torch.device('cuda:{}'.format(
                               args.local_rank)))

    valid_negative = test_negs.shape[1]
    LOGGER.log(key=tags.PREPROC_HP_NUM_EVAL, value=valid_negative)

    # Column-wise maxima of the training ratings give the largest user id
    # (column 0) and item id (column 1); ids are zero-based, hence the +1.
    nb_maxs = torch.max(train_ratings, 0)[0]
    nb_users = nb_maxs[0].item() + 1
    nb_items = nb_maxs[1].item() + 1
    LOGGER.log(key=tags.INPUT_SIZE, value=len(train_ratings))

    all_test_users = test_ratings.shape[0]

    test_users, test_items, dup_mask, real_indices = dataloading.create_test_data(
        test_ratings, test_negs, args)

    # make pytorch memory behavior more consistent later
    torch.cuda.empty_cache()

    LOGGER.log(key=tags.INPUT_BATCH_SIZE, value=args.batch_size)
    LOGGER.log(key=tags.INPUT_ORDER)  # we shuffled later with randperm

    # Create model
    model = NeuMF(nb_users,
                  nb_items,
                  mf_dim=args.factors,
                  mlp_layer_sizes=args.layers,
                  dropout=args.dropout)

    optimizer = FusedAdam(model.parameters(),
                          lr=args.learning_rate,
                          betas=(args.beta1, args.beta2),
                          eps=args.eps)

    criterion = nn.BCEWithLogitsLoss(
        reduction='none'
    )  # use torch.mean() with dim later to avoid copy to host
    # Move model and loss to GPU
    model = model.cuda()
    criterion = criterion.cuda()

    if args.opt_level == "O2":
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          keep_batchnorm_fp32=False,
                                          loss_scale='dynamic')

    if args.distributed:
        model = DDP(model)

    local_batch = args.batch_size // args.world_size
    traced_criterion = torch.jit.trace(
        criterion.forward,
        (torch.rand(local_batch, 1), torch.rand(local_batch, 1)))

    print(model)
    print("{} parameters".format(admm_utils.count_parameters(model)))
    LOGGER.log(key=tags.OPT_LR, value=args.learning_rate)
    LOGGER.log(key=tags.OPT_NAME, value="Adam")
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA1, value=args.beta1)
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA2, value=args.beta2)
    LOGGER.log(key=tags.OPT_HP_ADAM_EPSILON, value=args.eps)
    LOGGER.log(key=tags.MODEL_HP_LOSS_FN, value=tags.VALUE_BCE)

    if args.load_checkpoint_path:
        state_dict = torch.load(args.load_checkpoint_path)
        # Checkpoints saved from a distributed run carry a 'module.' prefix on
        # every key; strip it so the weights match the bare model.
        state_dict = {
            k.replace('module.', ''): v
            for k, v in state_dict.items()
        }
        # The prefix-free keys belong to the underlying network, so when the
        # model has been wrapped in DDP above, load into the wrapped module
        # rather than into the wrapper (whose keys are 'module.'-prefixed).
        load_target = model.module if args.distributed else model
        load_target.load_state_dict(state_dict)

    if args.mode == 'test':
        # Evaluation-only run: a single validation pass, then exit.
        LOGGER.log(key=tags.EVAL_START, value=0)
        start = time.time()
        hr, ndcg = val_epoch(model,
                             test_users,
                             test_items,
                             dup_mask,
                             real_indices,
                             args.topk,
                             samples_per_user=valid_negative + 1,
                             num_user=all_test_users,
                             distributed=args.distributed)
        print('HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}'.format(
            K=args.topk, hit_rate=hr, ndcg=ndcg))
        val_time = time.time() - start
        eval_size = all_test_users * (valid_negative + 1)
        eval_throughput = eval_size / val_time

        LOGGER.log(key=tags.EVAL_ACCURACY, value={"epoch": 0, "value": hr})
        LOGGER.log(key=tags.EVAL_STOP, value=0)
        LOGGER.log(key='best_eval_throughput', value=eval_throughput)
        return

    success = False
    max_hr = 0
    # Stays None until a checkpoint is actually written, so the final summary
    # cannot hit an unbound name when the hit rate never rises above 0.
    best_model_timestamp = None
    train_throughputs, eval_throughputs = [], []

    LOGGER.log(key=tags.TRAIN_LOOP)
    for epoch in range(args.epochs):

        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)
        LOGGER.log(key=tags.INPUT_HP_NUM_NEG, value=args.negative_samples)
        LOGGER.log(key=tags.INPUT_STEP_TRAIN_NEG_GEN)

        begin = time.time()

        epoch_users, epoch_items, epoch_label = dataloading.prepare_epoch_train_data(
            train_ratings, nb_items, args)
        num_batches = len(epoch_users)
        # Gradient accumulation: run `grads_accumulated` backward passes per
        # optimizer step. Trailing batches that do not fill a whole
        # accumulation group are dropped by the integer division.
        for i in range(num_batches // args.grads_accumulated):
            for j in range(args.grads_accumulated):
                batch_idx = (args.grads_accumulated * i) + j
                user = epoch_users[batch_idx]
                item = epoch_items[batch_idx]
                label = epoch_label[batch_idx].view(-1, 1)

                outputs = model(user, item)
                loss = traced_criterion(outputs, label).float()
                loss = torch.mean(loss.view(-1), 0)

                if args.opt_level == "O2":
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            optimizer.step()

            # Reset gradients by dropping the tensors instead of zeroing them.
            for p in model.parameters():
                p.grad = None

        del epoch_users, epoch_items, epoch_label
        train_time = time.time() - begin
        begin = time.time()

        epoch_samples = len(train_ratings) * (args.negative_samples + 1)
        train_throughput = epoch_samples / train_time
        train_throughputs.append(train_throughput)
        LOGGER.log(key='train_throughput', value=train_throughput)
        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        LOGGER.log(key=tags.EVAL_START, value=epoch)

        hr, ndcg = val_epoch(model,
                             test_users,
                             test_items,
                             dup_mask,
                             real_indices,
                             args.topk,
                             samples_per_user=valid_negative + 1,
                             num_user=all_test_users,
                             epoch=epoch,
                             distributed=args.distributed)

        val_time = time.time() - begin
        print(
            'Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f},'
            ' train_time = {train_time:.2f}, val_time = {val_time:.2f}'.format(
                epoch=epoch,
                K=args.topk,
                hit_rate=hr,
                ndcg=ndcg,
                train_time=train_time,
                val_time=val_time))

        LOGGER.log(key=tags.EVAL_ACCURACY, value={"epoch": epoch, "value": hr})
        LOGGER.log(key=tags.EVAL_TARGET, value=args.threshold)
        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        eval_size = all_test_users * (valid_negative + 1)
        eval_throughput = eval_size / val_time
        eval_throughputs.append(eval_throughput)
        LOGGER.log(key='eval_throughput', value=eval_throughput)

        # Only local rank 0 tracks the best hit rate and writes checkpoints.
        if hr > max_hr and args.local_rank == 0:
            max_hr = hr
            save_checkpoint_path = os.path.join(args.checkpoint_dir,
                                                'model.pth')
            print("New best hr! Saving the model to: ", save_checkpoint_path)
            torch.save(model.state_dict(), save_checkpoint_path)
            best_model_timestamp = time.time()

        # Early stop once the target hit rate is reached.
        if args.threshold is not None:
            if hr >= args.threshold:
                print("Hit threshold of {}".format(args.threshold))
                success = True
                break

    if args.local_rank == 0:
        LOGGER.log(key='best_train_throughput', value=max(train_throughputs))
        LOGGER.log(key='best_eval_throughput', value=max(eval_throughputs))
        LOGGER.log(key='best_accuracy', value=max_hr)
        LOGGER.log(key='time_to_target', value=time.time() - main_start_time)
        # If no epoch ever improved on the initial hit rate, no model was
        # saved and there is no "time to best model" to report.
        if best_model_timestamp is not None:
            time_to_best_model = best_model_timestamp - main_start_time
        else:
            time_to_best_model = None
        LOGGER.log(key='time_to_best_model', value=time_to_best_model)

        LOGGER.log(key=tags.RUN_STOP, value={"success": success})
        LOGGER.log(key=tags.RUN_FINAL)
コード例 #3
0
def main():
    """
    Run training/evaluation

    Initializes Horovod, MPI and Weights & Biases, loads the pickled rating
    dataframes from ``args.data``, builds the TensorFlow graph through
    ``ncf_model_ops`` and then either runs a single evaluation pass
    (``args.mode == 'test'``) or trains for ``args.epochs`` epochs with
    periodic evaluation. Rank 0 logs a final timing/accuracy summary and
    saves a checkpoint whenever a new best hit rate exceeds ``args.target``.
    """
    script_start = time.time()
    hvd_init()
    mpi_comm = MPI.COMM_WORLD
    args = parse_args()

    # Only rank 0 logs the arguments; all other ranks put wandb in dry-run mode.
    if hvd.rank() == 0:
        log_args(args)
    else:
        os.environ['WANDB_MODE'] = 'dryrun'
    wandb_id = os.environ.get('WANDB_ID', None)
    if wandb_id is None:
        wandb.init(config=args)
    else:
        # Suffix the shared id with the rank so each worker gets its own run id.
        wandb.init(config=args, id=f"{wandb_id}{hvd.rank()}")
    wandb.config.update({'SLURM_JOB_ID': os.environ.get('SLURM_JOB_ID', None)})
    wandb.tensorboard.patch(save=False)

    if args.seed is not None:
        tf.random.set_random_seed(args.seed)
        np.random.seed(args.seed)
        cp.random.seed(args.seed)

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
    # When automatic mixed precision is enabled (by flag or pre-set in the
    # environment), the manual fp16 path is switched off.
    if "TF_ENABLE_AUTO_MIXED_PRECISION" in os.environ \
       and os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] == "1":
        args.fp16 = False

    # directory to store/read final checkpoint
    if args.mode == 'train' and hvd.rank() == 0:
        print("Saving best checkpoint to {}".format(args.checkpoint_dir))
    elif hvd.rank() == 0:
        print("Reading checkpoint: {}".format(args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir) and args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')

    # Load converted data and get statistics
    train_df = pd.read_pickle(args.data + '/train_ratings.pickle')
    test_df = pd.read_pickle(args.data + '/test_ratings.pickle')
    # Per-column maxima + 1; unpacking requires exactly two columns
    # (user id, item id).
    nb_users, nb_items = train_df.max() + 1

    # Extract train and test feature tensors from dataframe
    pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
    pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
    pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
    pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
    # Negatives indicator for negatives generation
    # (builtin `bool` instead of the removed `np.bool` alias)
    neg_mat = np.ones((nb_users, nb_items), dtype=bool)
    neg_mat[pos_train_users, pos_train_items] = 0

    # Get the local training/test data
    train_users, train_items, train_labels = get_local_train_data(
        pos_train_users, pos_train_items, args.negative_samples)
    test_users, test_items = get_local_test_data(pos_test_users,
                                                 pos_test_items)

    # Create and run Data Generator in a separate thread
    data_generator = DataGenerator(
        args.seed,
        hvd.local_rank(),
        nb_users,
        nb_items,
        neg_mat,
        train_users,
        train_items,
        train_labels,
        args.batch_size // hvd.size(),
        args.negative_samples,
        test_users,
        test_items,
        args.valid_users_per_batch,
        args.valid_negative,
    )

    # Create tensorflow session and saver
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Pin this process to the GPU matching its Horovod local rank.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    if args.xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    sess = tf.Session(config=config)

    # Input tensors
    users = tf.placeholder(tf.int32, shape=(None, ))
    items = tf.placeholder(tf.int32, shape=(None, ))
    labels = tf.placeholder(tf.int32, shape=(None, ))
    is_dup = tf.placeholder(tf.float32, shape=(None, ))
    # Defaults to the training dropout rate; evaluation feeds 0.0 explicitly.
    dropout = tf.placeholder_with_default(args.dropout, shape=())
    # Model ops and saver
    hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
        users,
        items,
        labels,
        is_dup,
        params={
            'fp16': args.fp16,
            'val_batch_size': args.valid_negative + 1,
            'top_k': args.topk,
            'learning_rate': args.learning_rate,
            'beta_1': args.beta1,
            'beta_2': args.beta2,
            'epsilon': args.eps,
            'num_users': nb_users,
            'num_items': nb_items,
            'num_factors': args.factors,
            'mf_reg': 0,
            'layer_sizes': args.layers,
            'layer_regs': [0. for i in args.layers],
            'dropout': dropout,
            'sigmoid': True,
            'loss_scale': args.loss_scale
        },
        mode='TRAIN' if args.mode == 'train' else 'EVAL')
    saver = tf.train.Saver()

    # Accuracy metric tensors
    hr_sum = tf.get_default_graph().get_tensor_by_name(
        'neumf/hit_rate/total:0')
    hr_cnt = tf.get_default_graph().get_tensor_by_name(
        'neumf/hit_rate/count:0')
    ndcg_sum = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/total:0')
    ndcg_cnt = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/count:0')

    # Prepare evaluation data
    data_generator.prepare_eval_data()

    if args.load_checkpoint_path:
        saver.restore(sess, args.load_checkpoint_path)
    else:
        # Manual initialize weights
        sess.run(tf.global_variables_initializer())

    # If test mode, run one eval
    if args.mode == 'test':
        # Local variables hold the metric accumulators read back below.
        sess.run(tf.local_variables_initializer())
        eval_start = time.time()
        for user_batch, item_batch, dup_batch \
            in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
            sess.run(eval_op,
                     feed_dict={
                         users: user_batch,
                         items: item_batch,
                         is_dup: dup_batch,
                         dropout: 0.0
                     })
        eval_duration = time.time() - eval_start

        # Report results
        # Sum (not average) the per-worker accumulators across all workers.
        hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
        hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
        ndcg_sum = sess.run(hvd.allreduce(ndcg_sum, average=False))
        ndcg_cnt = sess.run(hvd.allreduce(ndcg_cnt, average=False))

        hit_rate = hit_rate_sum / hit_rate_cnt
        ndcg = ndcg_sum / ndcg_cnt

        if hvd.rank() == 0:
            LOGGER.log("Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}".format(
                eval_duration, hit_rate, ndcg))

            eval_throughput = pos_test_users.shape[0] * (args.valid_negative +
                                                         1) / eval_duration
            LOGGER.log(
                'Average Eval Throughput: {:.4f}'.format(eval_throughput))
        # Release the session on this early exit too, as the training path
        # does at the end of the function.
        sess.close()
        return

    # Performance Metrics
    train_times = list()
    eval_times = list()
    # Accuracy Metrics
    first_to_target = None
    time_to_train = 0.0
    # Initialized here so the final summary cannot hit an unbound name when no
    # evaluation ever improves on best_hr (e.g. evaluation never runs).
    time_to_best = 0.0
    best_hr = 0
    best_epoch = 0
    # Buffers for global metrics
    global_hr_sum = np.ones(1)
    global_hr_count = np.ones(1)
    global_ndcg_sum = np.ones(1)
    global_ndcg_count = np.ones(1)
    # Buffers for local metrics
    local_hr_sum = np.ones(1)
    local_hr_count = np.ones(1)
    local_ndcg_sum = np.ones(1)
    local_ndcg_count = np.ones(1)

    # Begin training
    begin_train = time.time()
    if hvd.rank() == 0:
        LOGGER.log("Begin Training. Setup Time: {}".format(begin_train -
                                                           script_start))
    for epoch in range(args.epochs):
        # Train for one epoch
        train_start = time.time()
        data_generator.prepare_train_data()
        for user_batch, item_batch, label_batch \
            in zip(data_generator.train_users_batches,
                   data_generator.train_items_batches,
                   data_generator.train_labels_batches):
            # NOTE(review): `.get()` presumably copies CuPy device arrays to
            # host NumPy arrays for the feed dict - confirm in DataGenerator.
            sess.run(train_op,
                     feed_dict={
                         users: user_batch.get(),
                         items: item_batch.get(),
                         labels: label_batch.get()
                     })
        train_duration = time.time() - train_start
        wandb.log({"train/epoch_time": train_duration}, commit=False)
        ## Only log "warm" epochs
        if epoch >= 1:
            train_times.append(train_duration)
        # Evaluate
        if epoch > args.eval_after:
            eval_start = time.time()
            # Reset the metric accumulators before this evaluation pass.
            sess.run(tf.local_variables_initializer())
            for user_batch, item_batch, dup_batch \
                in zip(data_generator.eval_users,
                       data_generator.eval_items,
                       data_generator.dup_mask):
                sess.run(eval_op,
                         feed_dict={
                             users: user_batch,
                             items: item_batch,
                             is_dup: dup_batch,
                             dropout: 0.0
                         })
            # Compute local metrics
            local_hr_sum[0] = sess.run(hr_sum)
            local_hr_count[0] = sess.run(hr_cnt)
            local_ndcg_sum[0] = sess.run(ndcg_sum)
            local_ndcg_count[0] = sess.run(ndcg_cnt)
            # Reduce metrics across all workers
            # NOTE(review): Reduce is called with its default op/root, so
            # presumably only the root rank receives the reduced values -
            # confirm against mpi4py; rank 0 is the consumer of the summary.
            mpi_comm.Reduce(local_hr_count, global_hr_count)
            mpi_comm.Reduce(local_hr_sum, global_hr_sum)
            mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
            mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)
            # Calculate metrics
            hit_rate = global_hr_sum[0] / global_hr_count[0]
            ndcg = global_ndcg_sum[0] / global_ndcg_count[0]

            eval_duration = time.time() - eval_start
            wandb.log(
                {
                    "eval/time": eval_duration,
                    "eval/hit_rate": hit_rate,
                    "eval/ndcg": ndcg
                },
                commit=False)
            ## Only log "warm" epochs
            if epoch >= 1:
                eval_times.append(eval_duration)

            if hvd.rank() == 0:
                if args.verbose:
                    log_string = "Epoch: {:02d}, Train Time: {:.4f}, Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}"
                    LOGGER.log(
                        log_string.format(epoch, train_duration, eval_duration,
                                          hit_rate, ndcg))

                # Update summary metrics
                if hit_rate > args.target and first_to_target is None:
                    first_to_target = epoch
                    time_to_train = time.time() - begin_train
                if hit_rate > best_hr:
                    best_hr = hit_rate
                    best_epoch = epoch
                    time_to_best = time.time() - begin_train
                    if not args.verbose:
                        log_string = "New Best Epoch: {:02d}, Train Time: {:.4f}, Eval Time: {:.4f}, HR: {:.4f}, NDCG: {:.4f}"
                        LOGGER.log(
                            log_string.format(epoch, train_duration,
                                              eval_duration, hit_rate, ndcg))
                    # Save, if meets target
                    if hit_rate > args.target:
                        saver.save(sess, final_checkpoint_path)
        # This call commits the metrics queued above with commit=False.
        wandb.log({"epoch": epoch + 1})

    # Final Summary
    if hvd.rank() == 0:
        train_times = np.array(train_times)
        train_throughputs = pos_train_users.shape[0] * (args.negative_samples +
                                                        1) / train_times
        eval_times = np.array(eval_times)
        eval_throughputs = pos_test_users.shape[0] * (args.valid_negative +
                                                      1) / eval_times
        LOGGER.log(' ')

        LOGGER.log('batch_size: {}'.format(args.batch_size))
        LOGGER.log('num_gpus: {}'.format(hvd.size()))
        LOGGER.log('AMP: {}'.format(1 if args.amp else 0))
        LOGGER.log('seed: {}'.format(args.seed))
        LOGGER.log('Minimum Train Time per Epoch: {:.4f}'.format(
            np.min(train_times)))
        LOGGER.log('Average Train Time per Epoch: {:.4f}'.format(
            np.mean(train_times)))
        LOGGER.log('Average Train Throughput:     {:.4f}'.format(
            np.mean(train_throughputs)))
        LOGGER.log('Minimum Eval Time per Epoch:  {:.4f}'.format(
            np.min(eval_times)))
        LOGGER.log('Average Eval Time per Epoch:  {:.4f}'.format(
            np.mean(eval_times)))
        LOGGER.log('Average Eval Throughput:      {:.4f}'.format(
            np.mean(eval_throughputs)))
        LOGGER.log('First Epoch to hit:           {}'.format(first_to_target))
        LOGGER.log(
            'Time to Train:                {:.4f}'.format(time_to_train))
        LOGGER.log('Time to Best:                 {:.4f}'.format(time_to_best))
        LOGGER.log('Best HR:                      {:.4f}'.format(best_hr))
        LOGGER.log('Best Epoch:                   {}'.format(best_epoch))
        wandb.log({
            "batch_size": args.batch_size,
            "num_gpus": hvd.size(),
            "train/total_throughput": np.mean(train_throughputs),
            "eval/total_throughput": np.mean(eval_throughputs),
            "train/total_time": np.sum(train_times),
            "train/time_to_target": time_to_train,
            "train/time_to_best": time_to_best,
            "train/first_to_target": first_to_target,
            "train/best_hit_rate": best_hr,
            "train/best_epoch": best_epoch,
            "epoch": args.epochs
        })

    sess.close()
    return
コード例 #4
0
def main():
    log_hardware()

    args = parse_args()
    args.distributed, args.world_size = init_distributed(args.local_rank)
    log_args(args)

    main_start_time = time.time()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Save configuration to file
    timestamp = "{:.0f}".format(datetime.utcnow().timestamp())
    run_dir = "./run/neumf/{}.{}".format(timestamp, args.local_rank)
    print("Saving results to {}".format(run_dir))
    if not os.path.exists(run_dir) and run_dir != '':
        os.makedirs(run_dir)

    # more like load trigger timer now
    LOGGER.log(key=tags.PREPROC_HP_NUM_EVAL, value=args.valid_negative)
    # The default of np.random.choice is replace=True, so does pytorch random_()
    LOGGER.log(key=tags.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_STEP_EVAL_NEG_GEN)

    # sync worker before timing.
    if args.distributed:
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
    torch.cuda.synchronize()

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    # load not converted data, just seperate one for test
    train_ratings = torch.load(args.data + '/train_ratings.pt',
                               map_location=torch.device('cuda:{}'.format(
                                   args.local_rank)))
    test_ratings = torch.load(args.data + '/test_ratings.pt',
                              map_location=torch.device('cuda:{}'.format(
                                  args.local_rank)))

    # get input data
    # get dims
    nb_maxs = torch.max(train_ratings, 0)[0]
    nb_users = nb_maxs[0].item() + 1
    nb_items = nb_maxs[1].item() + 1
    train_users = train_ratings[:, 0]
    train_items = train_ratings[:, 1]
    del nb_maxs, train_ratings
    LOGGER.log(key=tags.INPUT_SIZE, value=len(train_users))
    # produce things not change between epoch
    # mask for filtering duplicates with real sample
    # note: test data is removed before create mask, same as reference
    mat = torch.cuda.ByteTensor(nb_users, nb_items).fill_(1)
    mat[train_users, train_items] = 0
    # create label
    train_label = torch.ones_like(train_users, dtype=torch.float32)
    neg_label = torch.zeros_like(train_label, dtype=torch.float32)
    neg_label = neg_label.repeat(args.negative_samples)
    train_label = torch.cat((train_label, neg_label))
    del neg_label
    if args.fp16:
        train_label = train_label.half()

    # produce validation negative sample on GPU
    all_test_users = test_ratings.shape[0]

    test_users = test_ratings[:, 0]
    test_pos = test_ratings[:, 1].reshape(-1, 1)
    test_negs = generate_neg(test_users, mat, nb_items, args.valid_negative,
                             True)[1]

    # create items with real sample at last position
    test_users = test_users.reshape(-1, 1).repeat(1, 1 + args.valid_negative)
    test_items = torch.cat(
        (test_negs.reshape(-1, args.valid_negative), test_pos), dim=1)
    del test_ratings, test_negs

    # generate dup mask and real indice for exact same behavior on duplication compare to reference
    # here we need a sort that is stable(keep order of duplicates)
    # this is a version works on integer
    sorted_items, indices = torch.sort(test_items)  # [1,1,1,2], [3,1,0,2]
    sum_item_indices = sorted_items.float() + indices.float() / len(
        indices[0])  #[1.75,1.25,1.0,2.5]
    indices_order = torch.sort(sum_item_indices)[1]  #[2,1,0,3]
    stable_indices = torch.gather(indices, 1, indices_order)  #[0,1,3,2]
    # produce -1 mask
    dup_mask = (sorted_items[:, 0:-1] == sorted_items[:, 1:])
    dup_mask = torch.cat(
        (torch.zeros_like(test_pos, dtype=torch.uint8), dup_mask), dim=1)
    dup_mask = torch.gather(dup_mask, 1, stable_indices.sort()[1])
    # produce real sample indices to later check in topk
    sorted_items, indices = (test_items != test_pos).sort()
    sum_item_indices = sorted_items.float() + indices.float() / len(indices[0])
    indices_order = torch.sort(sum_item_indices)[1]
    stable_indices = torch.gather(indices, 1, indices_order)
    real_indices = stable_indices[:, 0]
    del sorted_items, indices, sum_item_indices, indices_order, stable_indices, test_pos

    if args.distributed:
        test_users = torch.chunk(test_users, args.world_size)[args.local_rank]
        test_items = torch.chunk(test_items, args.world_size)[args.local_rank]
        dup_mask = torch.chunk(dup_mask, args.world_size)[args.local_rank]
        real_indices = torch.chunk(real_indices,
                                   args.world_size)[args.local_rank]

    # make pytorch memory behavior more consistent later
    torch.cuda.empty_cache()

    LOGGER.log(key=tags.INPUT_BATCH_SIZE, value=args.batch_size)
    LOGGER.log(key=tags.INPUT_ORDER)  # we shuffled later with randperm

    print('Load data done [%.1f s]. #user=%d, #item=%d, #train=%d, #test=%d' %
          (time.time() - run_start_time, nb_users, nb_items, len(train_users),
           nb_users))

    # Create model
    model = NeuMF(nb_users,
                  nb_items,
                  mf_dim=args.factors,
                  mf_reg=0.,
                  mlp_layer_sizes=args.layers,
                  mlp_layer_regs=[0. for i in args.layers],
                  dropout=args.dropout)

    if args.fp16:
        model = model.half()

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    # Add optimizer and loss to graph
    if args.fp16:
        fp_optimizer = Fp16Optimizer(model, args.loss_scale)
        params = fp_optimizer.fp32_params
    else:
        params = model.parameters()

    optimizer = FusedAdam(params,
                          lr=args.learning_rate,
                          betas=(args.beta1, args.beta2),
                          eps=args.eps,
                          eps_inside_sqrt=False)
    criterion = nn.BCEWithLogitsLoss(
        reduction='none'
    )  # use torch.mean() with dim later to avoid copy to host
    LOGGER.log(key=tags.OPT_LR, value=args.learning_rate)
    LOGGER.log(key=tags.OPT_NAME, value="Adam")
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA1, value=args.beta1)
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA2, value=args.beta2)
    LOGGER.log(key=tags.OPT_HP_ADAM_EPSILON, value=args.eps)
    LOGGER.log(key=tags.MODEL_HP_LOSS_FN, value=tags.VALUE_BCE)

    # Move model and loss to GPU
    model = model.cuda()
    criterion = criterion.cuda()

    if args.distributed:
        model = DDP(model)
        local_batch = args.batch_size // int(os.environ['WORLD_SIZE'])
    else:
        local_batch = args.batch_size
    traced_criterion = torch.jit.trace(
        criterion.forward,
        (torch.rand(local_batch, 1), torch.rand(local_batch, 1)))

    train_users_per_worker = len(train_label) / int(os.environ['WORLD_SIZE'])
    train_users_begin = int(train_users_per_worker * args.local_rank)
    train_users_end = int(train_users_per_worker * (args.local_rank + 1))

    # Create files for tracking training
    valid_results_file = os.path.join(run_dir, 'valid_results.csv')
    # Calculate initial Hit Ratio and NDCG
    test_x = test_users.view(-1).split(args.valid_batch_size)
    test_y = test_items.view(-1).split(args.valid_batch_size)

    if args.mode == 'test':
        state_dict = torch.load(args.checkpoint_path)
        model.load_state_dict(state_dict)

    begin = time.time()
    LOGGER.log(key=tags.EVAL_START, value=-1)

    hr, ndcg = val_epoch(model,
                         test_x,
                         test_y,
                         dup_mask,
                         real_indices,
                         args.topk,
                         samples_per_user=test_items.size(1),
                         num_user=all_test_users,
                         distributed=args.distributed)
    val_time = time.time() - begin
    print(
        'Initial HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, valid_time: {val_time:.4f}'
        .format(K=args.topk, hit_rate=hr, ndcg=ndcg, val_time=val_time))

    LOGGER.log(key=tags.EVAL_ACCURACY, value={"epoch": -1, "value": hr})
    LOGGER.log(key=tags.EVAL_TARGET, value=args.threshold)
    LOGGER.log(key=tags.EVAL_STOP, value=-1)

    if args.mode == 'test':
        return

    success = False
    max_hr = 0
    LOGGER.log(key=tags.TRAIN_LOOP)
    train_throughputs = []
    eval_throughputs = []

    for epoch in range(args.epochs):

        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)
        LOGGER.log(key=tags.INPUT_HP_NUM_NEG, value=args.negative_samples)
        LOGGER.log(key=tags.INPUT_STEP_TRAIN_NEG_GEN)

        begin = time.time()

        # prepare data for epoch
        neg_users, neg_items = generate_neg(train_users, mat, nb_items,
                                            args.negative_samples)
        epoch_users = torch.cat((train_users, neg_users))
        epoch_items = torch.cat((train_items, neg_items))

        del neg_users, neg_items

        # shuffle prepared data and split into batches
        epoch_indices = torch.randperm(train_users_end - train_users_begin,
                                       device='cuda:{}'.format(
                                           args.local_rank))
        epoch_indices += train_users_begin

        epoch_users = epoch_users[epoch_indices]
        epoch_items = epoch_items[epoch_indices]
        epoch_label = train_label[epoch_indices]

        epoch_users_list = epoch_users.split(local_batch)
        epoch_items_list = epoch_items.split(local_batch)
        epoch_label_list = epoch_label.split(local_batch)

        # number of local batches this worker processes in the epoch
        num_batches = len(epoch_users_list)
        # abort in the extremely rare case where the epoch size modulo the
        # global batch size is smaller than the number of workers
        if len(epoch_users) % args.batch_size < args.world_size:
            print("epoch_size % batch_size < number of worker!")
            exit(1)

        for i in range(num_batches // args.grads_accumulated):
            for j in range(args.grads_accumulated):
                batch_idx = (args.grads_accumulated * i) + j
                user = epoch_users_list[batch_idx]
                item = epoch_items_list[batch_idx]
                label = epoch_label_list[batch_idx].view(-1, 1)

                outputs = model(user, item)
                loss = traced_criterion(outputs, label).float()
                loss = torch.mean(loss.view(-1), 0)
                if args.fp16:
                    fp_optimizer.backward(loss)
                else:
                    loss.backward()

            if args.fp16:
                fp_optimizer.step(optimizer)
            else:
                optimizer.step()

            for p in model.parameters():
                p.grad = None

        del epoch_users, epoch_items, epoch_label, epoch_users_list, epoch_items_list, epoch_label_list, user, item, label
        train_time = time.time() - begin
        begin = time.time()

        epoch_samples = len(train_users) * (args.negative_samples + 1)
        train_throughput = epoch_samples / train_time
        train_throughputs.append(train_throughput)
        LOGGER.log(key='train_throughput', value=train_throughput)
        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        LOGGER.log(key=tags.EVAL_START, value=epoch)

        hr, ndcg = val_epoch(model,
                             test_x,
                             test_y,
                             dup_mask,
                             real_indices,
                             args.topk,
                             samples_per_user=test_items.size(1),
                             num_user=all_test_users,
                             output=valid_results_file,
                             epoch=epoch,
                             distributed=args.distributed)

        val_time = time.time() - begin
        print(
            'Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f},'
            ' train_time = {train_time:.2f}, val_time = {val_time:.2f}'.format(
                epoch=epoch,
                K=args.topk,
                hit_rate=hr,
                ndcg=ndcg,
                train_time=train_time,
                val_time=val_time))

        LOGGER.log(key=tags.EVAL_ACCURACY, value={"epoch": epoch, "value": hr})
        LOGGER.log(key=tags.EVAL_TARGET, value=args.threshold)
        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        eval_size = all_test_users * test_items.size(1)
        eval_throughput = eval_size / val_time
        eval_throughputs.append(eval_throughput)
        LOGGER.log(key='eval_throughput', value=eval_throughput)

        if hr > max_hr and args.local_rank == 0:
            max_hr = hr
            print("New best hr! Saving the model to: ", args.checkpoint_path)
            torch.save(model.state_dict(), args.checkpoint_path)

        if args.threshold is not None:
            if hr >= args.threshold:
                print("Hit threshold of {}".format(args.threshold))
                success = True
                break

    LOGGER.log(key='best_train_throughput', value=max(train_throughputs))
    LOGGER.log(key='best_eval_throughput', value=max(eval_throughputs))
    LOGGER.log(key='best_accuracy', value=max_hr)
    LOGGER.log(key='time_to_target', value=time.time() - main_start_time)

    LOGGER.log(key=tags.RUN_STOP, value={"success": success})
    LOGGER.log(key=tags.RUN_FINAL)