예제 #1
0
파일: DLA.py 프로젝트: DongHande/AutoDebias
def train_and_eval(train_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_class=MF,
                   base_model_args: dict = None,
                   position_model_args: dict = None,
                   training_args: dict = None):
    """Train and evaluate a DLA-style debiasing model.

    Jointly trains a rating model (``MF``) and a position-bias model
    (``Position``): each model's MSE loss is re-weighted by detached
    inverse weights derived from the other (IPW for the rating model,
    IRW for the position model). Early-stops on validation AUC, then
    reports metrics on the validation and test sets.

    Args:
        train_data: dict holding sparse 'position' and 'rating' tensors
            of shape (n_user, n_item).
        val_data, test_data: interaction data accepted by
            ``utils.data_loader.Interactions``.
        device: torch device string for all tensors and models.
        model_class: unused here (``MF`` is instantiated directly);
            kept only for interface compatibility with callers.
        base_model_args: hyper-parameters of the rating model
            (``emb_dim``, ``learning_rate``, ``weight_decay``).
        position_model_args: hyper-parameters of the position model
            (``learning_rate``, ``weight_decay``).
        training_args: loop settings (``batch_size``, ``epochs``,
            ``patience``, ``block_batch``).

    Returns:
        Tuple ``(val_results, test_results)`` of metric dicts.
    """
    # Materialize defaults here instead of the signature: a mutable dict
    # default would be shared (and mutable) across calls.
    if base_model_args is None:
        base_model_args = {
            'emb_dim': 64,
            'learning_rate': 0.05,
            'weight_decay': 0.05
        }
    if position_model_args is None:
        position_model_args = {
            'learning_rate': 0.05,
            'weight_decay': 0.05
        }
    if training_args is None:
        training_args = {
            'batch_size': 1024,
            'epochs': 100,
            'patience': 20,
            'block_batch': [1000, 100]
        }

    train_position = train_data['position']
    train_rating = train_data['rating']

    # build data_loader.
    train_loader = utils.data_loader.User(
        train_position,
        train_rating,
        u_batch_size=training_args['block_batch'][0],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    n_user, n_item = train_position.shape
    # Positions are 0-based in the sparse values, so count = max + 1.
    n_position = torch.max(train_position._values()).item() + 1

    base_model = MF(n_user, n_item, dim=base_model_args['emb_dim']).to(device)
    # weight_decay=0: L2 regularization is added explicitly via l2_norm below.
    base_optimizer = torch.optim.SGD(base_model.parameters(),
                                     lr=base_model_args['learning_rate'],
                                     weight_decay=0)

    position_model = Position(n_position).to(device)
    position_optimizer = torch.optim.SGD(
        position_model.parameters(),
        lr=position_model_args['learning_rate'],
        weight_decay=0)

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(base_model, **stopping_args)

    none_criterion = nn.MSELoss(reduction='none')

    for epo in range(early_stopping.max_epochs):
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            base_model.train()
            position_model.train()

            # observation data in this batch
            users_train, items_train, positions_train, y_train = train_loader.get_batch(
                users)
            # Row/col of the position-0 entry of each length-n_position list
            # (assumes each user contributes exactly n_position consecutive
            # entries — TODO confirm against utils.data_loader.User).
            first_index_row = torch.where(
                positions_train == 0)[0] // n_position
            first_index_col = torch.where(positions_train == 0)[0] % n_position

            y_hat = base_model(users_train, items_train)
            p_hat = position_model(positions_train)

            # Reshape flat batches to (n_lists, n_position).
            y_train = y_train.view(-1, n_position)

            y_hat = y_hat.view(-1, n_position)
            p_hat = p_hat.view(-1, n_position)

            y_hat_som = torch.softmax(y_hat, dim=-1)  # softmax
            p_hat_som = torch.softmax(p_hat, dim=-1)

            # Inverse relevance / propensity weights, normalized by the
            # position-0 entry; detached so each weight acts as a constant
            # when training the other model.
            IRW = torch.detach(
                y_hat_som[first_index_row, first_index_col].unsqueeze(dim=-1) *
                torch.reciprocal(y_hat_som))
            IPW = torch.detach(
                p_hat_som[first_index_row, first_index_col].unsqueeze(dim=-1) *
                torch.reciprocal(p_hat_som))

            cost_base = none_criterion(y_hat, y_train)
            loss_IPW = torch.sum(IPW * cost_base) + base_model_args[
                'weight_decay'] * base_model.l2_norm(users_train, items_train)
            cost_position = none_criterion(p_hat, y_train)
            loss_IRW = torch.sum(IRW * cost_position) + position_model_args[
                'weight_decay'] * position_model.l2_norm(positions_train)

            base_optimizer.zero_grad()
            loss_IPW.backward()
            base_optimizer.step()

            position_optimizer.zero_grad()
            loss_IRW.backward()
            position_optimizer.step()

        base_model.eval()
        with torch.no_grad():
            # training metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                users_train, items_train, positions_train, y_train = train_loader.get_batch(
                    users)
                pre_ratings = base_model(users_train, items_train)
                train_pre_ratings = torch.cat((train_pre_ratings, pre_ratings))
                train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = base_model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        # Early stopping tracks validation AUC and snapshots the best state.
        if early_stopping.check([val_results['AUC']], epo):
            break

    # testing loss
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    base_model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(val_loader):
        pre_ratings = base_model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics (user/item ids kept for ranking metrics below)
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(test_loader):
        pre_ratings = base_model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)
    return val_results, test_results
예제 #2
0
def train_and_eval(train_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_class=MF,
                   model_args: dict = None,
                   training_args: dict = None):
    """Train and evaluate an MF model with imputed all-pair supervision.

    For each user/item block, combines the observed-rating MSE with an
    imputation loss that pulls every user-item pair in the block toward
    the constant -1, plus explicit L2 regularization. Early-stops on
    validation AUC and reports validation/test metrics.

    Args:
        train_data: sparse rating tensor of shape (n_user, n_item).
        val_data, test_data: interaction data accepted by
            ``utils.data_loader.Interactions``.
        device: torch device string for all tensors and models.
        model_class: unused here (``MF`` is instantiated directly);
            kept only for interface compatibility with callers.
        model_args: hyper-parameters (``emb_dim``, ``learning_rate``,
            ``imputaion_lambda``, ``weight_decay``).
        training_args: loop settings (``batch_size``, ``epochs``,
            ``patience``, ``block_batch``).

    Returns:
        Tuple ``(val_results, test_results)`` of metric dicts.
    """
    # Materialize defaults here instead of the signature: a mutable dict
    # default would be shared across calls. Also include
    # 'imputaion_lambda', which the loss below reads — the original
    # default dict omitted it, so calling with defaults raised KeyError.
    if model_args is None:
        model_args = {
            'emb_dim': 64,
            'learning_rate': 0.05,
            'imputaion_lambda': 0.01,
            'weight_decay': 0.05
        }
    if training_args is None:
        training_args = {
            'batch_size': 1024,
            'epochs': 100,
            'patience': 20,
            'block_batch': [1000, 100]
        }

    # build data_loader.
    train_loader = utils.data_loader.Block(
        train_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    # data shape
    n_user, n_item = train_data.shape

    # model and its optimizer. weight_decay=0: L2 is added via l2_norm below.
    model = MF(n_user, n_item, dim=model_args['emb_dim'], dropout=0).to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=model_args['learning_rate'],
                                weight_decay=0)

    # loss_criterion
    criterion = nn.MSELoss(reduction='sum')

    def complement(u, i, u_all, i_all):
        # 1 where the (u_all, i_all) pair is NOT in the observed batch.
        # Currently unused (its call site is commented out below).
        mask_u = np.isin(u_all.cpu().numpy(), u.cpu().numpy())
        mask_i = np.isin(i_all.cpu().numpy(), i.cpu().numpy())
        # Use the `device` argument rather than hard-coding 'cuda'.
        mask = torch.tensor(1 - mask_u * mask_i).to(device)
        return mask

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(model, **stopping_args)
    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            for i_batch_idx, items in enumerate(train_loader.Item_loader):
                # loss of training set
                model.train()
                users_train, items_train, y_train = train_loader.get_batch(
                    users, items)
                y_hat_obs = model(users_train, items_train)
                loss_obs = criterion(y_hat_obs, y_train)

                # All user-item pairs in this block are imputed with -1.
                all_pair = torch.cartesian_prod(users, items)
                users_all, items_all = all_pair[:, 0], all_pair[:, 1]
                y_hat_all = model(users_all, items_all)
                impu_all = torch.zeros((users_all.shape)).to(device) - 1
                loss_all = criterion(y_hat_all, impu_all)

                loss = loss_obs + model_args[
                    'imputaion_lambda'] * loss_all + model_args[
                        'weight_decay'] * model.l2_norm(users_all, items_all)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                training_loss += loss.item()

        model.eval()
        with torch.no_grad():
            # train metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                for i_batch_idx, items in enumerate(train_loader.Item_loader):
                    users_train, items_train, y_train = train_loader.get_batch(
                        users, items)
                    pre_ratings = model(users_train, items_train)
                    train_pre_ratings = torch.cat(
                        (train_pre_ratings, pre_ratings))
                    train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        # Early stopping tracks validation AUC and snapshots the best state.
        if early_stopping.check([val_results['AUC']], epo):
            break

    # testing loss
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(val_loader):
        pre_ratings = model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics (user/item ids kept for ranking metrics below)
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(test_loader):
        pre_ratings = model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)
    return val_results, test_results
예제 #3
0
def train_and_eval(train_data,
                   unif_train_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_args: dict = None,
                   training_args: dict = None):
    """Train and evaluate an MF model with Naive-Bayes IPS weighting.

    Estimates per-rating-value propensities P(O=1 | y) from the biased
    training set and a small uniformly-collected set, then trains MF with
    inverse-propensity-weighted MSE. Early-stops on validation AUC and
    reports validation/test metrics.

    Args:
        train_data: sparse rating tensor (biased logging data).
        unif_train_data: sparse rating tensor from a uniform policy,
            used only for the propensity estimate.
        val_data, test_data: interaction data accepted by
            ``utils.data_loader.Interactions``.
        device: torch device string for all tensors and models.
        model_args: hyper-parameters (``emb_dim``, ``learning_rate``,
            ``weight_decay``).
        training_args: loop settings (``batch_size``, ``epochs``,
            ``patience``, ``block_batch``).

    Returns:
        Tuple ``(val_results, test_results)`` of metric dicts.
    """
    # Materialize defaults here instead of the signature: a mutable dict
    # default would be shared across calls.
    if model_args is None:
        model_args = {
            'emb_dim': 64,
            'learning_rate': 0.01,
            'weight_decay': 0.1
        }
    if training_args is None:
        training_args = {
            'batch_size': 1024,
            'epochs': 100,
            'patience': 20,
            'block_batch': [1000, 100]
        }

    # build data_loader.
    train_loader = utils.data_loader.Block(
        train_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    # Naive Bayes propensity Estimator
    def Naive_Bayes_Propensity(train, unif):
        """Return (unique rating values, P(O=1 | y) for each value)."""
        # Marginal observation rate P(O=1) over the full rating matrix.
        P_Oeq1 = train._nnz() / (train.size()[0] * train.size()[1])

        y_unique = torch.unique(train._values())
        P_y_givenO = torch.zeros(y_unique.shape).to(device)
        P_y = torch.zeros(y_unique.shape).to(device)

        for i in range(len(y_unique)):
            # P(y | O=1): rating distribution in the (biased) observed data.
            P_y_givenO[i] = torch.sum(
                train._values() == y_unique[i]) / torch.sum(
                    torch.ones(train._values().shape).to(device))
            # P(y): rating distribution in the uniform data.
            P_y[i] = torch.sum(unif._values() == y_unique[i]) / torch.sum(
                torch.ones(unif._values().shape).to(device))

        # Bayes rule: P(O=1 | y) = P(y | O=1) * P(O=1) / P(y).
        Propensity = P_y_givenO * P_Oeq1 / P_y

        return y_unique, Propensity

    y_unique, Propensity = Naive_Bayes_Propensity(train_data, unif_train_data)
    InvP = torch.reciprocal(Propensity)

    # data shape
    n_user, n_item = train_data.shape

    # model and its optimizer. weight_decay=0: L2 is added via l2_norm below.
    model = MF(n_user, n_item, dim=model_args['emb_dim'], dropout=0).to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=model_args['learning_rate'],
                                weight_decay=0)

    # loss_criterion
    none_criterion = nn.MSELoss(reduction='none')

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(model, **stopping_args)
    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            for i_batch_idx, items in enumerate(train_loader.Item_loader):
                # loss of training set
                model.train()
                users_train, items_train, y_train = train_loader.get_batch(
                    users, items)
                y_hat = model(users_train, items_train)

                # Weight each sample by 1 / P(O=1 | y) for its rating value.
                cost = none_criterion(y_hat, y_train)
                weight = torch.ones(y_train.shape).to(device)
                for i in range(len(y_unique)):
                    weight[y_train == y_unique[i]] = InvP[i]

                # NOTE(review): l2_norm is applied to the block ids
                # (`users`, `items`), not the observed batch
                # (`users_train`, `items_train`) — looks intentional in
                # the original, but worth confirming.
                loss = torch.sum(
                    weight * cost
                ) + model_args['weight_decay'] * model.l2_norm(users, items)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                training_loss += loss.item()

        model.eval()
        with torch.no_grad():
            # train metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                for i_batch_idx, items in enumerate(train_loader.Item_loader):
                    users_train, items_train, y_train = train_loader.get_batch(
                        users, items)
                    pre_ratings = model(users_train, items_train)
                    train_pre_ratings = torch.cat(
                        (train_pre_ratings, pre_ratings))
                    train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        # Early stopping tracks validation AUC and snapshots the best state.
        if early_stopping.check([val_results['AUC']], epo):
            break

    # testing loss
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(val_loader):
        pre_ratings = model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics (user/item ids kept for ranking metrics below)
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(test_loader):
        pre_ratings = model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)
    return val_results, test_results
예제 #4
0
def train_and_eval(
    train_data,
    unif_train_data,
    val_data,
    test_data,
    device='cuda',
    base_model_args: dict = None,
    weight1_model_args: dict = None,
    weight2_model_args: dict = None,
    imputation_model_args: dict = None,
    training_args: dict = None):
    """Train and evaluate a meta-learned debiasing model (AutoDebias style).

    Learns a base ``MetaMF`` rating model together with three meta
    modules — two loss-weight networks (observed pairs / all pairs) and
    an imputation network. Each step: (1) a one-step-ahead copy of the
    base model is virtually updated, (2) its loss on the uniform data
    updates the meta modules through the virtual update, (3) the base
    model is updated with the refreshed weights. Early-stops on
    validation AUC (after a 50-epoch warm-up) and reports
    validation/test metrics.

    Args:
        train_data: dict holding sparse 'position' and 'rating' tensors.
        unif_train_data: sparse ratings from a uniform policy (meta set).
        val_data, test_data: interaction data accepted by
            ``utils.data_loader.Interactions``.
        device: torch device string for all tensors and models.
        base_model_args: base-model hyper-parameters (``emb_dim``,
            ``dropout``, ``learning_rate``, ``imputaion_lambda``,
            ``weight_decay``).
        weight1_model_args, weight2_model_args, imputation_model_args:
            meta-module optimizer settings.
        training_args: loop settings (``batch_size``, ``epochs``,
            ``patience``, ``block_batch``).

    Returns:
        Tuple ``(val_results, test_results)`` of metric dicts.
    """
    # Materialize defaults here instead of the signature: a mutable dict
    # default would be shared across calls.
    if base_model_args is None:
        base_model_args = {
            'emb_dim': 64,
            'dropout': 0.0,
            'learning_rate': 0.05,
            'imputaion_lambda': 0.01,
            'weight_decay': 0.05
        }
    if weight1_model_args is None:
        weight1_model_args = {'learning_rate': 0.1, 'weight_decay': 0.005}
    if weight2_model_args is None:
        weight2_model_args = {'learning_rate': 0.1, 'weight_decay': 0.005}
    if imputation_model_args is None:
        imputation_model_args = {'learning_rate': 0.01, 'weight_decay': 0.5}
    if training_args is None:
        training_args = {
            'batch_size': 1024,
            'epochs': 100,
            'patience': 20,
            'block_batch': [1000, 100]
        }

    train_position = train_data['position']
    train_rating = train_data['rating']
    train_dense = train_rating.to_dense()
    # uniform data
    users_unif = unif_train_data._indices()[0]
    items_unif = unif_train_data._indices()[1]
    y_unif = unif_train_data._values()

    # build data_loader.
    train_loader = utils.data_loader.User(
        train_position,
        train_rating,
        u_batch_size=training_args['block_batch'][0],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    n_user, n_item = train_position.shape
    # Positions are 0-based in the sparse values, so count = max + 1.
    n_position = torch.max(train_position._values()).item() + 1

    # base model and its optimizer
    base_model = MetaMF(n_user,
                        n_item,
                        dim=base_model_args['emb_dim'],
                        dropout=0).to(device)
    base_optimizer = torch.optim.SGD(
        base_model.params(),
        lr=base_model_args['learning_rate'],
        weight_decay=0)  # todo: other optimizer SGD

    # meta models and their optimizer
    weight1_model = FourLinear(n_user, n_item, 2,
                               n_position=n_position).to(device)
    weight1_optimizer = torch.optim.Adam(
        weight1_model.parameters(),
        lr=weight1_model_args['learning_rate'],
        weight_decay=weight1_model_args['weight_decay'])

    weight2_model = ThreeLinear(n_user, n_item, 2).to(device)
    weight2_optimizer = torch.optim.Adam(
        weight2_model.parameters(),
        lr=weight2_model_args['learning_rate'],
        weight_decay=weight2_model_args['weight_decay'])

    imputation_model = OneLinear(3).to(device)
    imputation_optimizer = torch.optim.Adam(
        imputation_model.parameters(),
        lr=imputation_model_args['learning_rate'],
        weight_decay=imputation_model_args['weight_decay'])

    # loss_criterion
    sum_criterion = nn.MSELoss(reduction='sum')
    none_criterion = nn.MSELoss(reduction='none')

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(base_model, **stopping_args)

    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            users_train, items_train, positions_train, y_train = train_loader.get_batch(
                users)

            # all pair: every (user-in-batch, item) combination
            all_pair = torch.cartesian_prod(users,
                                            torch.arange(n_item).to(device))
            users_all, items_all = all_pair[:, 0], all_pair[:, 1]

            # calculate weight1 (per observed sample)
            weight1_model.train()
            weight1 = weight1_model(users_train, items_train,
                                    (y_train == 1) * 1, positions_train)
            weight1 = torch.exp(weight1 / 5)  # for stable training

            # calculate weight2 (per all-pair sample)
            weight2_model.train()
            weight2 = weight2_model(users_all, items_all,
                                    (train_dense[users_all, items_all] != 0) *
                                    1)
            weight2 = torch.exp(weight2 / 5)  # for stable training

            # calculate imputation values
            imputation_model.train()
            impu_f_all = torch.tanh(
                imputation_model((train_dense[users_all, items_all]).long() +
                                 1))

            # one_step_model: assumed model, just update one step on base model. it is for updating weight parameters
            # NOTE(review): one_step_model is never moved to `device`;
            # load_state_dict copies parameter values but not device
            # placement — confirm this runs when device='cuda'.
            one_step_model = MetaMF(n_user,
                                    n_item,
                                    dim=base_model_args['emb_dim'],
                                    dropout=0)
            one_step_model.load_state_dict(base_model.state_dict())

            # formal parameter: Using training set to update parameters
            one_step_model.train()
            # all pair data in this block
            y_hat_f_all = one_step_model(users_all, items_all)
            cost_f_all = none_criterion(y_hat_f_all, impu_f_all)
            loss_f_all = torch.sum(cost_f_all * weight2)
            # observation data
            y_hat_f_obs = one_step_model(users_train, items_train)
            cost_f_obs = none_criterion(y_hat_f_obs, y_train)
            loss_f_obs = torch.sum(cost_f_obs * weight1)
            loss_f = loss_f_obs + base_model_args[
                'imputaion_lambda'] * loss_f_all + base_model_args[
                    'weight_decay'] * one_step_model.l2_norm(
                        users_all, items_all)

            # update parameters of one_step_model; create_graph=True keeps
            # the virtual update differentiable w.r.t. the meta modules.
            one_step_model.zero_grad()
            grads = torch.autograd.grad(loss_f, (one_step_model.params()),
                                        create_graph=True)
            one_step_model.update_params(base_model_args['learning_rate'],
                                         source_params=grads)

            # latter hyper_parameter: Using uniform set to update hyper_parameters
            y_hat_l = one_step_model(users_unif, items_unif)
            loss_l = sum_criterion(y_hat_l, y_unif)

            # update weight parameters
            weight1_optimizer.zero_grad()
            weight2_optimizer.zero_grad()
            imputation_optimizer.zero_grad()
            loss_l.backward()

            # Weight networks are frozen for the first 20 epochs; only the
            # imputation network trains from the start.
            if epo >= 20:
                weight1_optimizer.step()
                weight2_optimizer.step()
            imputation_optimizer.step()

            # use new weights to update parameters
            weight1_model.train()
            weight1 = weight1_model(users_train, items_train,
                                    (y_train == 1) * 1, positions_train)
            weight1 = torch.exp(weight1 / 5)  # for stable training

            # calculate weight2
            weight2_model.train()
            weight2 = weight2_model(users_all, items_all,
                                    (train_dense[users_all, items_all] != 0) *
                                    1)
            weight2 = torch.exp(weight2 / 5)  # for stable training

            # use new imputation to update parameters
            imputation_model.train()
            impu_all = torch.tanh(
                imputation_model((train_dense[users_all, items_all]).long() +
                                 1))

            # loss of training set
            base_model.train()
            # all pair
            y_hat_all = base_model(users_all, items_all)
            cost_all = none_criterion(y_hat_all, impu_all)
            loss_all = torch.sum(cost_all * weight2)
            # observation
            y_hat_obs = base_model(users_train, items_train)
            cost_obs = none_criterion(y_hat_obs, y_train)
            loss_obs = torch.sum(cost_obs * weight1)
            loss = loss_obs + base_model_args[
                'imputaion_lambda'] * loss_all + base_model_args[
                    'weight_decay'] * base_model.l2_norm(users_all, items_all)

            base_optimizer.zero_grad()
            loss.backward()
            base_optimizer.step()

            training_loss += loss.item()

        base_model.eval()
        with torch.no_grad():
            # training metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                users_train, items_train, positions_train, y_train = train_loader.get_batch(
                    users)
                pre_ratings = base_model(users_train, items_train)
                train_pre_ratings = torch.cat((train_pre_ratings, pre_ratings))
                train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = base_model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        # Early stopping only armed after a 50-epoch warm-up.
        if epo >= 50 and early_stopping.check([val_results['AUC']], epo):
            break

    # restore best model
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    base_model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(val_loader):
        pre_ratings = base_model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics (user/item ids kept for ranking metrics below)
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(test_loader):
        pre_ratings = base_model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)
    return val_results, test_results
예제 #5
0
def train_and_eval(train_data,
                   unif_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_class=MF,
                   base_model_args: dict = None,
                   teacher_model_args: dict = None,
                   training_args: dict = None):
    """Train a base MF model on biased data, a teacher MF model on uniformly
    collected data, then re-tune the base model towards targets augmented by
    the teacher's (min-max normalised) batch predictions.

    Args:
        train_data: sparse user-item rating matrix from the (biased) logging
            policy; ``train_data.shape`` gives ``(n_user, n_item)``.
        unif_data: sparse rating matrix collected under a uniform policy,
            used to fit the teacher model.
        val_data: interaction set consumed by ``utils.data_loader.Interactions``.
        test_data: interaction set consumed by ``utils.data_loader.Interactions``.
        device: torch device string.
        model_class: kept for interface compatibility.
            NOTE(review): unused — the code instantiates ``MF`` directly,
            exactly as the original implementation did.
        base_model_args: emb_dim / learning_rate / weight_decay of the base MF.
        teacher_model_args: emb_dim / learning_rate / weight_decay of the teacher MF.
        training_args: batch_size, epochs, patience and block_batch sizes.

    Returns:
        ``(val_results, test_results)`` of the re-tuned base model, as
        produced by ``utils.metrics.evaluate``.
    """
    # Fix for the mutable-default-argument anti-pattern: the original code
    # used dict literals as defaults, which are shared between calls and can
    # be mutated by callers.  Effective defaults are unchanged.
    if base_model_args is None:
        base_model_args = {'emb_dim': 10, 'learning_rate': 0.001,
                           'weight_decay': 0.0}
    if teacher_model_args is None:
        teacher_model_args = {'emb_dim': 10, 'learning_rate': 0.1,
                              'weight_decay': 0.0}
    if training_args is None:
        training_args = {'batch_size': 1024, 'epochs': 100, 'patience': 20,
                         'block_batch': [1000, 100]}

    # Data loaders: block loaders over the two training matrices, plain
    # batch loaders over the validation / test interactions.
    train_loader = utils.data_loader.Block(
        train_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    unif_loader = utils.data_loader.Block(
        unif_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    # data shape
    n_user, n_item = train_data.shape

    # Base model and its optimizer.  Weight decay is applied manually through
    # l2_norm in the loss, so the optimizer's own weight_decay stays 0.
    base_model = MF(n_user, n_item, dim=base_model_args['emb_dim'],
                    dropout=0).to(device)
    base_optimizer = torch.optim.SGD(base_model.parameters(),
                                     lr=base_model_args['learning_rate'],
                                     weight_decay=0)

    # Teacher model and its optimizer.
    teacher_model = MF(n_user, n_item, dim=teacher_model_args['emb_dim'],
                       dropout=0).to(device)
    teacher_optimizer = torch.optim.SGD(
        teacher_model.parameters(),
        lr=teacher_model_args['learning_rate'],
        weight_decay=0)

    # Summed squared error over each batch.
    criterion = nn.MSELoss(reduction='sum')

    def _format(results):
        # 'MSE:0.123 NLL:0.456 ...' summary used by every print below.
        return ' '.join(key + ':' + '%.3f' % results[key] for key in results)

    def _collect_block(model, loader):
        # Run `model` over every (user-block x item-block) pair of a Block
        # loader; return flat (predictions, ground-truth) tensors.
        pre_all = torch.empty(0).to(device)
        y_all = torch.empty(0).to(device)
        for users in loader.User_loader:
            for items in loader.Item_loader:
                users_b, items_b, y_b = loader.get_batch(users, items)
                pre_all = torch.cat((pre_all, model(users_b, items_b)))
                y_all = torch.cat((y_all, y_b))
        return pre_all, y_all

    def _collect_interactions(model, loader, keep_ids=False):
        # Run `model` over a (users, items, ratings) DataLoader.  When
        # `keep_ids` is set, also return the visited user/item ids (needed
        # for the ranking metrics on the test set).
        users_all = torch.empty(0, dtype=torch.int64).to(device)
        items_all = torch.empty(0, dtype=torch.int64).to(device)
        pre_all = torch.empty(0).to(device)
        y_all = torch.empty(0).to(device)
        for users, items, ratings in loader:
            pre_all = torch.cat((pre_all, model(users, items)))
            y_all = torch.cat((y_all, ratings))
            if keep_ids:
                users_all = torch.cat((users_all, users))
                items_all = torch.cat((items_all, items))
        if keep_ids:
            return pre_all, y_all, users_all, items_all
        return pre_all, y_all

    def _train_phase(model, optimizer, loader, weight_decay, target_fn=None):
        # One full early-stopped training run of `model` on `loader`, then
        # restore the best checkpoint.  `target_fn`, when given, maps
        # (users, items, y) -> training target (used by the re-tune phase to
        # mix in the teacher signal); the per-epoch evaluation always uses
        # the raw observed ratings, matching the original code.
        stopping_args = Stop_args(patience=training_args['patience'],
                                  max_epochs=training_args['epochs'])
        early_stopping = EarlyStopping(model, **stopping_args)
        for epo in range(early_stopping.max_epochs):
            for users in loader.User_loader:
                for items in loader.Item_loader:
                    model.train()
                    users_b, items_b, y_b = loader.get_batch(users, items)
                    if target_fn is not None:
                        y_b = target_fn(users_b, items_b, y_b)
                    y_hat = model(users_b, items_b)
                    loss = criterion(y_hat, y_b) + \
                        weight_decay * model.l2_norm(users_b, items_b)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            model.eval()
            with torch.no_grad():
                train_pre, train_y = _collect_block(model, loader)
                val_pre, val_y = _collect_interactions(model, val_loader)

            train_results = utils.metrics.evaluate(train_pre, train_y,
                                                   ['MSE', 'NLL'])
            val_results = utils.metrics.evaluate(val_pre, val_y,
                                                 ['MSE', 'NLL', 'AUC'])

            # 'Traning' typo kept verbatim so existing log parsers still match.
            print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
                epo, training_args['epochs'], _format(train_results),
                _format(val_results)))

            if early_stopping.check([val_results['AUC']], epo):
                break

        # restore the best model seen in this phase
        print('Loading {}th epoch'.format(early_stopping.best_epoch))
        model.load_state_dict(early_stopping.best_state)

    def _report(model):
        # Final validation/test metrics (including ranking metrics) for
        # `model`.  no_grad keeps this pure inference.
        model.eval()
        with torch.no_grad():
            val_pre, val_y = _collect_interactions(model, val_loader)
            test_pre, test_y, test_users, test_items = _collect_interactions(
                model, test_loader, keep_ids=True)

        val_results = utils.metrics.evaluate(val_pre, val_y,
                                             ['MSE', 'NLL', 'AUC'])
        test_results = utils.metrics.evaluate(
            test_pre, test_y, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
            users=test_users,
            items=test_items)
        print('-' * 30)
        print('The performance of validation set: {}'.format(
            _format(val_results)))
        print('The performance of testing set: {}'.format(
            _format(test_results)))
        print('-' * 30)
        return val_results, test_results

    def _teacher_target(users_b, items_b, y_b):
        # Min-max normalise the teacher's scores for this batch and add them
        # to the observed ratings.  Gradients flow through the teacher here
        # (no detach), exactly as in the original implementation.
        y_res = teacher_model(users_b, items_b)
        y_res = (y_res - torch.min(y_res)) / (torch.max(y_res) -
                                              torch.min(y_res))
        return y_b + y_res

    # phase 1: base model on the biased training data
    _train_phase(base_model, base_optimizer, train_loader,
                 base_model_args['weight_decay'])
    _report(base_model)

    # phase 2: teacher model on the uniformly collected data
    _train_phase(teacher_model, teacher_optimizer, unif_loader,
                 teacher_model_args['weight_decay'])
    _report(teacher_model)

    # phase 3: re-tune the base model with teacher-augmented targets
    _train_phase(base_model, base_optimizer, train_loader,
                 base_model_args['weight_decay'], target_fn=_teacher_target)
    return _report(base_model)
Example #6
0
File: DR.py  Project: DongHande/AutoDebias
def train_and_eval(train_data, unif_train_data, val_data, test_data,
                   device='cuda',
                   base_model_args: dict = None,
                   imputation_model_args: dict = None,
                   training_args: dict = None):
    """Doubly-Robust (DR) debiased training of an MF rating model.

    Combines inverse-propensity weighting (propensities estimated with a
    Naive-Bayes estimator from the uniform data) with a learned imputation
    model for the error on unobserved pairs.

    Args:
        train_data: sparse user-item rating matrix from the (biased) logging
            policy; ``train_data.shape`` gives ``(n_user, n_item)``.
        unif_train_data: sparse rating matrix collected under a uniform
            policy, used only for the propensity estimate.
        val_data: interaction set consumed by ``utils.data_loader.Interactions``.
        test_data: interaction set consumed by ``utils.data_loader.Interactions``.
        device: torch device string.
        base_model_args: emb_dim / learning_rate / weight_decay of the
            prediction MF model.
        imputation_model_args: emb_dim / learning_rate / weight_decay of the
            imputation MF model.
        training_args: batch_size, epochs, patience and block_batch sizes.

    Returns:
        ``(val_results, test_results)`` as produced by ``utils.metrics.evaluate``.
    """
    # Fix for the mutable-default-argument anti-pattern: dict literals as
    # defaults are shared between calls.  Effective defaults are unchanged.
    if base_model_args is None:
        base_model_args = {'emb_dim': 64, 'learning_rate': 0.01,
                           'weight_decay': 0.1}
    if imputation_model_args is None:
        imputation_model_args = {'emb_dim': 10, 'learning_rate': 0.1,
                                 'weight_decay': 0.1}
    if training_args is None:
        training_args = {'batch_size': 1024, 'epochs': 100, 'patience': 20,
                         'block_batch': [1000, 100]}

    # build data_loader: block loader for training, plain batch loaders for
    # validation / test interactions.
    train_loader = utils.data_loader.Block(
        train_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'], shuffle=False, num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'], shuffle=False, num_workers=0)

    def Naive_Bayes_Propensity(train, unif):
        """Estimate P(O=1 | y) = P(y | O=1) * P(O=1) / P(y) per rating value.

        P(y|O=1) comes from the biased observations, P(y) from the uniform
        data, and P(O=1) is the overall observation rate of the matrix.
        """
        # overall observation rate of the rating matrix
        P_Oeq1 = train._nnz() / (train.size()[0] * train.size()[1])

        y_unique = torch.unique(train._values())
        P_y_givenO = torch.zeros(y_unique.shape).to(device)
        P_y = torch.zeros(y_unique.shape).to(device)

        for i in range(len(y_unique)):
            P_y_givenO[i] = torch.sum(train._values() == y_unique[i]) / \
                torch.sum(torch.ones(train._values().shape).to(device))
            P_y[i] = torch.sum(unif._values() == y_unique[i]) / \
                torch.sum(torch.ones(unif._values().shape).to(device))

        # NOTE(review): if a rating value present in `train` never occurs in
        # `unif`, P_y is 0 here and the propensity (and its reciprocal below)
        # becomes inf — same behaviour as the original code; verify inputs.
        Propensity = P_y_givenO / P_y * P_Oeq1

        return y_unique, Propensity

    y_unique, Propensity = Naive_Bayes_Propensity(train_data, unif_train_data)
    InvP = torch.reciprocal(Propensity)  # inverse-propensity weights

    # data shape
    n_user, n_item = train_data.shape

    # prediction model and its optimizer (weight decay added manually via
    # l2_norm, so the optimizer's own weight_decay stays 0).
    base_model = MF(n_user, n_item, dim=base_model_args['emb_dim'],
                    dropout=0).to(device)
    base_optimizer = torch.optim.SGD(base_model.parameters(),
                                     lr=base_model_args['learning_rate'],
                                     weight_decay=0)

    # imputation model estimates the prediction error on unobserved pairs.
    imputation_model = MF(n_user, n_item,
                          dim=imputation_model_args['emb_dim'],
                          dropout=0).to(device)
    imputation_optimizer = torch.optim.SGD(
        imputation_model.parameters(),
        lr=imputation_model_args['learning_rate'], weight_decay=0)

    # loss criteria: per-element for the weighted DR terms, summed otherwise.
    none_criterion = nn.MSELoss(reduction='none')
    sum_criterion = nn.MSELoss(reduction='sum')

    # begin training with early stopping on validation AUC
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(base_model, **stopping_args)
    for epo in range(early_stopping.max_epochs):
        for users in train_loader.User_loader:
            for items in train_loader.Item_loader:
                # observation data in this batch, with per-sample inverse
                # propensity weights looked up by rating value.
                users_train, items_train, y_train = train_loader.get_batch(
                    users, items)
                weight = torch.ones(y_train.shape).to(device)
                for i in range(len(y_unique)):
                    weight[y_train == y_unique[i]] = InvP[i]

                # step 1: update the imputation (error) model
                imputation_model.train()

                e_hat = imputation_model(users_train, items_train)  # imputed error
                e = y_train - base_model(users_train, items_train)  # prediction error
                # cost of the imputation: gap between imputed and true error
                cost_e = none_criterion(e_hat, e)

                loss_imp = torch.sum(weight * cost_e) + \
                    imputation_model_args['weight_decay'] * \
                    imputation_model.l2_norm(users_train, items_train)

                imputation_optimizer.zero_grad()
                loss_imp.backward()
                imputation_optimizer.step()

                # step 2: update the prediction model
                base_model.train()

                # all user-item pairs in this block
                all_pair = torch.cartesian_prod(users, items)
                users_all, items_all = all_pair[:, 0], all_pair[:, 1]

                y_hat_all = base_model(users_all, items_all)
                # detach so only e_hat (through y_hat_all) drives this term
                y_hat_all_detach = torch.detach(y_hat_all)
                g_all = imputation_model(users_all, items_all)

                loss_all = sum_criterion(y_hat_all,
                                         g_all + y_hat_all_detach)  # sum(e_hat)

                # observed pairs: IPS-weighted correction of the imputed error
                y_hat_obs = base_model(users_train, items_train)
                y_hat_obs_detach = torch.detach(y_hat_obs)
                g_obs = imputation_model(users_train, items_train)

                e_obs = none_criterion(y_hat_obs, y_train)
                e_hat_obs = none_criterion(y_hat_obs,
                                           g_obs + y_hat_obs_detach)

                cost_obs = e_obs - e_hat_obs
                loss_obs = torch.sum(weight * cost_obs)

                loss_base = loss_all + loss_obs + \
                    base_model_args['weight_decay'] * \
                    base_model.l2_norm(users_all, items_all)

                base_optimizer.zero_grad()
                loss_base.backward()
                base_optimizer.step()

        base_model.eval()
        with torch.no_grad():
            # training metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for users in train_loader.User_loader:
                for items in train_loader.Item_loader:
                    users_train, items_train, y_train = train_loader.get_batch(
                        users, items)
                    pre_ratings = base_model(users_train, items_train)
                    train_pre_ratings = torch.cat(
                        (train_pre_ratings, pre_ratings))
                    train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for users, items, ratings in val_loader:
                pre_ratings = base_model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        # 'Traning' typo kept verbatim so existing log parsers still match.
        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'],
            ' '.join(key + ':' + '%.3f' % train_results[key]
                     for key in train_results),
            ' '.join(key + ':' + '%.3f' % val_results[key]
                     for key in val_results)))

        if early_stopping.check([val_results['AUC']], epo):
            break

    # restore the best model before final evaluation
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    base_model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for users, items, ratings in val_loader:
        pre_ratings = base_model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics (user/item ids are kept for the ranking metrics)
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for users, items, ratings in test_loader:
        pre_ratings = base_model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings, test_ratings,
        ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users, items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        key + ':' + '%.3f' % val_results[key] for key in val_results)))
    print('The performance of testing set: {}'.format(' '.join(
        key + ':' + '%.3f' % test_results[key] for key in test_results)))
    print('-' * 30)
    return val_results, test_results
Example #7
0
def train_and_eval(train_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_class=MF,
                   base_model_args: dict = {
                       'emb_dim': 64,
                       'learning_rate': 0.05,
                       'weight_decay': 0.05
                   },
                   position_model_args: dict = {
                       'learning_rate': 0.05,
                       'weight_decay': 0.05
                   },
                   training_args: dict = {
                       'batch_size': 1024,
                       'epochs': 100,
                       'patience': 20,
                       'block_batch': [1000, 100]
                   }):
    train_position = train_data['position']
    train_rating = train_data['rating']

    # build data_loader.
    train_loader = utils.data_loader.User(
        train_position,
        train_rating,
        u_batch_size=training_args['block_batch'][0],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    n_user, n_item = train_position.shape
    n_position = torch.max(train_position._values()).item() + 1

    criterion = nn.MSELoss(reduction='sum')

    # train Heckman model
    heckman_model = MF(n_user, n_item,
                       dim=base_model_args['emb_dim']).to(device)
    heckman_optimizer = torch.optim.SGD(heckman_model.parameters(),
                                        lr=base_model_args['learning_rate'],
                                        weight_decay=0)

    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(heckman_model, **stopping_args)

    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            heckman_model.train()

            # observation data in this batch
            users_train, items_train, positions_train, y_train = train_loader.get_batch(
                users)
            y_hat_obs = heckman_model(users_train, items_train)
            loss_obs = torch.sum((y_hat_obs - 1)**2)

            all_pair = torch.cartesian_prod(users,
                                            torch.arange(n_item).to(device))
            users_all, items_all = all_pair[:, 0], all_pair[:, 1]
            y_hat_all = heckman_model(users_all, items_all)
            loss_all = torch.sum((y_hat_all + 1)**2)

            loss = loss_obs + base_model_args[
                'imputaion_lambda'] * loss_all  # + 0.000 * heckman_model.l2_norm(users_all, items_all)

            heckman_optimizer.zero_grad()
            loss.backward()
            heckman_optimizer.step()

            training_loss += loss.item()

        # print('Epoch: {0:2d} / {1}, Traning: {2}'.
        #         format(epo, training_args['epochs'], training_loss))

        if early_stopping.check([-training_loss], epo):
            break

    print('Loading {}th epoch for heckman model'.format(
        early_stopping.best_epoch))
    heckman_model.load_state_dict(early_stopping.best_state)
    heckman_model.eval()

    # train click model
    click_model = MF_heckman(n_user, n_item,
                             dim=base_model_args['emb_dim']).to(device)
    click_optimizer = torch.optim.SGD(click_model.parameters(),
                                      lr=base_model_args['learning_rate'],
                                      weight_decay=0)

    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(click_model, **stopping_args)

    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            click_model.train()

            # observation data in this batch
            users_train, items_train, positions_train, y_train = train_loader.get_batch(
                users)
            heckman_train = heckman_model(users_train, items_train)
            lam_train = (2 * math.pi)**(-0.5) * torch.exp(
                -heckman_train**2 / 2) / torch.sigmoid(1.7 * heckman_train)

            y_hat = click_model(users_train, items_train, lam_train)
            loss = criterion(y_hat, y_train) + base_model_args[
                'weight_decay'] * click_model.l2_norm(users_train, items_train)

            click_optimizer.zero_grad()
            loss.backward()
            click_optimizer.step()

            training_loss += loss.item()

        click_model.eval()
        with torch.no_grad():
            # training metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                users_train, items_train, positions_train, y_train = train_loader.get_batch(
                    users)
                heckman_train = heckman_model(users_train, items_train)
                lams = (2 * math.pi)**(-0.5) * torch.exp(
                    -heckman_train**2 / 2) / torch.sigmoid(1.7 * heckman_train)

                pre_ratings = click_model(users_train, items_train, lams)
                train_pre_ratings = torch.cat((train_pre_ratings, pre_ratings))
                train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                heckman_train = heckman_model(users, items)
                lams = (2 * math.pi)**(-0.5) * torch.exp(
                    -heckman_train**2 / 2) / torch.sigmoid(1.7 * heckman_train)
                pre_ratings = click_model(users, items, lams)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Traning: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        if early_stopping.check([val_results['AUC']], epo):
            break

    # Restore the click model's weights from its best early-stopping epoch.
    print('Loading {}th epoch for click model'.format(
        early_stopping.best_epoch))
    click_model.load_state_dict(early_stopping.best_state)

    # Constant factor of the standard normal pdf, hoisted out of the loops.
    norm_coef = (2 * math.pi)**(-0.5)

    # validation metrics
    val_pre_ratings_1 = torch.empty(0).to(device)
    val_ratings_1 = torch.empty(0).to(device)
    for users, items, ratings in val_loader:
        scores = heckman_model(users, items)
        lam = norm_coef * torch.exp(
            -scores**2 / 2) / torch.sigmoid(1.7 * scores)
        preds = click_model(users, items, lam)
        val_pre_ratings_1 = torch.cat((val_pre_ratings_1, preds))
        val_ratings_1 = torch.cat((val_ratings_1, ratings))

    # test metrics (user/item ids are kept for the ranking metrics)
    test_users_1 = torch.empty(0, dtype=torch.int64).to(device)
    test_items_1 = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings_1 = torch.empty(0).to(device)
    test_ratings_1 = torch.empty(0).to(device)
    for users, items, ratings in test_loader:
        scores = heckman_model(users, items)
        lam = norm_coef * torch.exp(
            -scores**2 / 2) / torch.sigmoid(1.7 * scores)
        preds = click_model(users, items, lam)
        test_users_1 = torch.cat((test_users_1, users))
        test_items_1 = torch.cat((test_items_1, items))
        test_pre_ratings_1 = torch.cat((test_pre_ratings_1, preds))
        test_ratings_1 = torch.cat((test_ratings_1, ratings))

    # Score both splits and report.
    val_results_1 = utils.metrics.evaluate(val_pre_ratings_1, val_ratings_1,
                                           ['MSE', 'NLL', 'AUC'])
    test_results_1 = utils.metrics.evaluate(
        test_pre_ratings_1,
        test_ratings_1, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users_1,
        items=test_items_1)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results_1[key] for key in val_results_1])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results_1[key] for key in test_results_1])))
    print('-' * 30)

    # train position model (DLA)
    # Fresh MF relevance model and positional-bias model. Both optimizers use
    # weight_decay=0 because L2 regularization is added manually to the losses
    # via the models' l2_norm terms inside the training loop.
    base_model = MF(n_user, n_item, dim=base_model_args['emb_dim']).to(device)
    position_model = Position(n_position).to(device)

    base_optimizer = torch.optim.SGD(base_model.parameters(),
                                     lr=base_model_args['learning_rate'],
                                     weight_decay=0)
    position_optimizer = torch.optim.SGD(
        position_model.parameters(),
        lr=position_model_args['learning_rate'],
        weight_decay=0)

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(base_model, **stopping_args)

    # Per-element squared error; kept unreduced so it can be reweighted by
    # the IPW/IRW estimates before summation.
    none_criterion = nn.MSELoss(reduction='none')

    # Dual training loop: the relevance (base) model and the position-bias
    # model reweight each other's losses with detached inverse weights.
    for epo in range(early_stopping.max_epochs):
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            base_model.train()
            position_model.train()

            # observation data in this batch
            users_train, items_train, positions_train, y_train = train_loader.get_batch(
                users)
            # Flat indices of entries at position 0, converted to (row, col)
            # coordinates of the (-1, n_position) views built below; the
            # position-0 entry normalizes the dual weights of each row.
            first_index_row = torch.where(
                positions_train == 0)[0] // n_position
            first_index_col = torch.where(positions_train == 0)[0] % n_position

            y_hat = base_model(users_train, items_train)
            p_hat = position_model(positions_train)

            # Reshape so each row holds the n_position entries of one list.
            # NOTE(review): assumes get_batch returns entries in contiguous
            # runs of n_position -- confirm against the loader.
            y_train = y_train.view(-1, n_position)
            y_hat = y_hat.view(-1, n_position)
            p_hat = p_hat.view(-1, n_position)

            # Normalize scores within each row.
            y_hat_som = torch.softmax(y_hat, dim=-1)
            p_hat_som = torch.softmax(p_hat, dim=-1)

            # IRW weights the position-model loss by the relevance model's
            # softmax; IPW weights the relevance loss by the position model's
            # softmax. Both are detached so each model treats the other's
            # output as a fixed constant, and both are normalized by the
            # position-0 entry of each row.
            IRW = torch.detach(
                y_hat_som[first_index_row, first_index_col].unsqueeze(dim=-1) *
                torch.reciprocal(y_hat_som))
            IPW = torch.detach(
                p_hat_som[first_index_row, first_index_col].unsqueeze(dim=-1) *
                torch.reciprocal(p_hat_som))

            # Weighted MSE plus manual L2 regularization (the optimizers were
            # created with weight_decay=0). NOTE(review): the factor 5 on the
            # base model's weight decay looks like a tuned constant -- confirm.
            cost_base = none_criterion(y_hat, y_train)
            loss_IPW = torch.sum(IPW * cost_base) + 5 * base_model_args[
                'weight_decay'] * base_model.l2_norm(users_train, items_train)
            cost_position = none_criterion(p_hat, y_train)
            loss_IRW = torch.sum(IRW * cost_position) + position_model_args[
                'weight_decay'] * position_model.l2_norm(positions_train)

            base_optimizer.zero_grad()
            loss_IPW.backward()
            base_optimizer.step()

            position_optimizer.zero_grad()
            loss_IRW.backward()
            position_optimizer.step()

        # End-of-epoch evaluation of the base model (no gradient tracking).
        base_model.eval()
        with torch.no_grad():
            # training metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                users_train, items_train, positions_train, y_train = train_loader.get_batch(
                    users)
                pre_ratings = base_model(users_train, items_train)
                train_pre_ratings = torch.cat((train_pre_ratings, pre_ratings))
                train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = base_model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        # BUGFIX: corrected the misspelled 'Traning' in the progress message.
        print('Epoch: {0:2d} / {1}, Training: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        # Early stopping on validation AUC; snapshots the best model state.
        if early_stopping.check([val_results['AUC']], epo):
            break

    # testing loss
    # Restore the base model's weights from its best early-stopping epoch.
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    base_model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings_2 = torch.empty(0).to(device)
    val_ratings_2 = torch.empty(0).to(device)
    for users, items, ratings in val_loader:
        preds = base_model(users, items)
        val_pre_ratings_2 = torch.cat((val_pre_ratings_2, preds))
        val_ratings_2 = torch.cat((val_ratings_2, ratings))

    # test metrics (user/item ids are kept for the ranking metrics)
    test_users_2 = torch.empty(0, dtype=torch.int64).to(device)
    test_items_2 = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings_2 = torch.empty(0).to(device)
    test_ratings_2 = torch.empty(0).to(device)
    for users, items, ratings in test_loader:
        preds = base_model(users, items)
        test_users_2 = torch.cat((test_users_2, users))
        test_items_2 = torch.cat((test_items_2, items))
        test_pre_ratings_2 = torch.cat((test_pre_ratings_2, preds))
        test_ratings_2 = torch.cat((test_ratings_2, ratings))

    val_results_2 = utils.metrics.evaluate(val_pre_ratings_2, val_ratings_2,
                                           ['MSE', 'NLL', 'AUC'])
    test_results_2 = utils.metrics.evaluate(
        test_pre_ratings_2,
        test_ratings_2, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users_2,
        items=test_items_2)

    print('-' * 30)
    val_summary_2 = ' '.join(
        key + ':' + '%.3f' % val_results_2[key] for key in val_results_2)
    test_summary_2 = ' '.join(
        key + ':' + '%.3f' % test_results_2[key] for key in test_results_2)
    print('The performance of validation set: {}'.format(val_summary_2))
    print('The performance of testing set: {}'.format(test_summary_2))
    print('-' * 30)

    alp = 0.4
    val_ratings = val_ratings_1
    val_pre_ratings = alp * val_pre_ratings_1 + (1 - alp) * val_pre_ratings_2

    test_users = test_users_1
    test_items = test_items_1
    test_ratings = test_ratings_1
    test_pre_ratings = alp * test_pre_ratings_1 + (1 -
                                                   alp) * test_pre_ratings_2

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)

    return val_results, test_results