    def do_epoch(self, gen):
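        """Compute the loss for a single batch drawn from ``gen``.

        With ``use_utility_loss``, the complement (``x_c``) and supplement
        (``x_s``) item sets are scored as well and the utility loss is
        combined with input-gradient terms via ``mrs_loss``; otherwise a
        plain MSE loss is returned.
        """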

        self.h_init = self.srnn.init_hidden()

        if self.use_utility_loss:
            x_batch, y_batch, x_c_batch, y_c, x_s_batch, y_s = gen.get_batch(
                as_tensor=True)
            x_batch = x_batch[:, 1:]
            self.optimizer.zero_grad()

            y_hat, h = self.srnn.forward(x_batch, self.h_init)

            x_c_batch = torch.transpose(x_c_batch, 0, 1)
            x_s_batch = torch.transpose(x_s_batch, 0, 1)

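            # Run the SRNN over each complement/supplement sequence from the
            # same initial hidden state so the set scores are comparable.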
            set_dims = (self.seq_len, self.k, self.batch_size)

            y_hat_c = torch.zeros(set_dims)
            for i in range(x_c_batch.shape[0]):
                y_hat_c[i], h = self.srnn.forward(x_c_batch[i], self.h_init)

            y_hat_s = torch.zeros(set_dims)
            for i in range(x_s_batch.shape[0]):
                y_hat_s[i], h = self.srnn.forward(x_s_batch[i], self.h_init)

            y_hat_c = torch.transpose(y_hat_c, 0, 1)
            y_hat_s = torch.transpose(y_hat_s, 0, 1)

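            # Combine the three prediction/target pairs into one utility loss.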
            loss_u = utility_loss(y_hat, y_hat_c, y_hat_s,
                                  torch.transpose(y_batch, 0, 1),
                                  torch.transpose(y_c, 0, 1),
                                  torch.transpose(y_s, 0, 1))
            loss_u.backward(retain_graph=True)

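            # Gradients of the utility loss w.r.t. each input feed mrs_loss
            # (MRS presumably standing for marginal rate of substitution).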
            x_grad = self.srnn.get_input_grad(x_batch)
            x_c_grad = self.srnn.get_input_grad(x_c_batch)
            x_s_grad = self.srnn.get_input_grad(x_s_batch)

            loss = mrs_loss(loss_u, x_grad.unsqueeze(2),
                            torch.transpose(x_c_grad, 0, 1),
                            torch.transpose(x_s_grad, 0, 1))

        else:
            x_batch, y_batch = gen.get_batch(as_tensor=True)
            # only consider items as features
            x_batch = x_batch[:, 1:]
            self.optimizer.zero_grad()
            y_hat, h = self.srnn.forward(x_batch, self.h_init)
            loss = loss_mse(y_true=torch.transpose(y_batch, 0, 1),
                            y_hat=y_hat)

        return loss

    def fit_utility_loss(self):
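        """Fit the model with the utility/MRS loss, iterating over a
        DataLoader for ``n_epochs``; returns the logged average losses."""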

        self.print_device_specs()

        if self.X_val is not None:
            _ = self.get_validation_loss(self.X_val[:, 1:], self.y_val)

        loss_arr = []

        iter = 0
        cum_loss = 0
        prev_loss = -1
        stop = False

        self.dataset = self.get_dataset(self.users, self.items, self.y_train,
                                        True)

        dataloader = DataLoader(dataset=self.dataset,
                                batch_size=self.batch_size,
                                shuffle=False,
                                num_workers=self.num_workers)
        num_batches = len(dataloader)

        for epoch in range(self.n_epochs):
            for i, batch in enumerate(dataloader):

                batch['y'] = batch['y'].to(self.device)
                batch['y_c'] = batch['y_c'].to(self.device)
                batch['y_s'] = batch['y_s'].to(self.device)

                batch['items'] = batch['items'].requires_grad_(True).to(
                    self.device)
                batch['x_c'] = batch['x_c'].requires_grad_(True).to(
                    self.device)
                batch['x_s'] = batch['x_s'].requires_grad_(True).to(
                    self.device)
                batch['users'] = batch['users'].to(self.device)

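                # Score the observed items, then the complement and supplement
                # sets; set scores are reshaped to (batch, seq_len, set_size).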
                y_hat = self.model.forward(batch['users'],
                                           batch['items']).to(self.device)

                batch["x_c"] = batch["x_c"].view(self.batch_size,
                                                 self.c_size * self.seq_len,
                                                 -1)
                y_hat_c = self.model.forward(batch['users'],
                                             batch['x_c']).to(self.device)
                y_hat_c = y_hat_c.view(self.batch_size, self.seq_len,
                                       self.c_size)

                batch["x_s"] = batch["x_s"].view(self.batch_size,
                                                 self.s_size * self.seq_len,
                                                 -1)
                y_hat_s = self.model.forward(batch['users'],
                                             batch['x_s']).to(self.device)
                y_hat_s = y_hat_s.view(self.batch_size, self.seq_len,
                                       self.s_size)

                # TODO: Make this function flexible in the loss type (e.g., MSE, binary CE)
                loss_u = utility_loss(y_hat, torch.squeeze(y_hat_c),
                                      torch.squeeze(y_hat_s), batch['y'],
                                      batch['y_c'], batch['y_s'])

                if self.n_gpu > 1:
                    loss_u = loss_u.mean()

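                # Input gradients of the utility loss, reshaped per set, are
                # what mrs_loss regularizes.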
                x_grad = self._get_input_grad(loss_u, batch['items'])
                x_c_grad = self._get_input_grad(loss_u, batch['x_c'])
                x_s_grad = self._get_input_grad(loss_u, batch['x_s'])

                # x_grad = x_grad.view(self.batch_size, self.seq_len)
                x_c_grad = x_c_grad.view(self.batch_size, self.seq_len,
                                         self.c_size)
                x_s_grad = x_s_grad.view(self.batch_size, self.seq_len,
                                         self.s_size)

                loss = mrs_loss(loss_u,
                                x_grad.unsqueeze(-1),
                                x_c_grad,
                                x_s_grad,
                                lmbda=self.lmbda)

                if self.n_gpu > 1:
                    loss = loss.mean()

                # zero gradient
                self.optimizer.zero_grad()
                loss.backward()

                if self.grad_clip:
                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.grad_clip)

                self.optimizer.step()
                loss = loss.detach()
                cum_loss += loss

                if iter % self.loss_step == 0:
                    if iter == 0:
                        avg_loss = cum_loss
                    else:
                        avg_loss = cum_loss / self.loss_step
                    print("iteration: {} - loss: {:.5f}".format(
                        iter, avg_loss))
                    cum_loss = 0

                    loss_arr.append(avg_loss)

                    if abs(prev_loss - loss) < self.eps:
                        print(
                            'early stopping criterion met. Finishing training')
                        print("{:.4f} --> {:.5f}".format(prev_loss, loss))
                        stop = True
                        break
                    else:
                        prev_loss = loss

                if i == (num_batches - 1):
                    # Epoch is ending: checkpoint and get evaluation metrics
                    self.checkpoint_model(suffix=iter)
                    if self.X_val is not None:
                        _ = self.get_validation_loss(
                            self.X_val[:, 1:], self.y_val)

                iter += 1

                stop = self._check_max_iter(iter)
                if stop:
                    break

            if stop:
                break

        self.checkpoint_model(suffix='done')
        return loss_arr

    def fit_pairwise_utility_loss(self):
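        """Pairwise variant of ``fit_utility_loss``: classifies the
        difference between each predicted score and the first supplement
        sample before computing the utility/MRS loss."""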

        self.print_device_specs()

        if self.X_val is not None:
            _ = self.get_validation_loss(self.X_val[:, 1:], self.y_val)

        loss_arr = []

        iter = 0
        cum_loss = 0
        prev_loss = -1
        stop = False

        self.dataset = self.get_dataset(self.users, self.items, self.y_train,
                                        True)

        dataloader = DataLoader(dataset=self.dataset,
                                batch_size=self.batch_size,
                                shuffle=False)
        num_batches = len(dataloader)

        for epoch in range(self.n_epochs):
            for i, batch in enumerate(dataloader):

                batch['y'] = batch['y'].to(self.device)
                batch['y_c'] = batch['y_c'].to(self.device)
                batch['y_s'] = batch['y_s'].to(self.device)

                batch['items'] = batch['items'].requires_grad_(True).to(
                    self.device)
                batch['x_c'] = batch['x_c'].requires_grad_(True).to(
                    self.device)
                batch['x_s'] = batch['x_s'].requires_grad_(True).to(
                    self.device)
                batch['users'] = batch['users'].to(self.device)

                y_hat = self.model.forward(batch['users'],
                                           batch['items']).to(self.device)

                if y_hat.ndim == 3:
                    y_hat = y_hat.squeeze(-1)

                y_hat_c = torch.sigmoid(
                    self.model.forward(batch['users'],
                                       batch['x_c']).to(self.device))
                y_hat_s = torch.sigmoid(
                    self.model.forward(batch['users'],
                                       batch['x_s']).to(self.device))

                # classify difference of each x_ui to the first sample x_ij
                y_hat_diff = torch.sigmoid(y_hat - y_hat_s[:, 0, :])

                # TODO: Make this function flexible in the loss type (e.g., MSE, binary CE)
                loss_u = utility_loss(y_hat_diff, torch.squeeze(y_hat_c),
                                      torch.squeeze(y_hat_s), batch['y'],
                                      batch['y_c'], batch['y_s'], self.loss)

                if self.n_gpu > 1:
                    loss_u = loss_u.mean()

                x_grad = self._get_input_grad(loss_u, batch['items'])
                x_c_grad = self._get_input_grad(loss_u, batch['x_c'])
                x_s_grad = self._get_input_grad(loss_u, batch['x_s'])

                loss = mrs_loss(loss_u,
                                x_grad.reshape(-1, 1),
                                x_c_grad,
                                x_s_grad,
                                lmbda=self.lmbda)

                if self.n_gpu > 1:
                    loss = loss.mean()

                # zero gradient
                self.optimizer.zero_grad()

                loss.backward()
                self.optimizer.step()
                loss = loss.detach()
                cum_loss += loss

                if iter % self.loss_step == 0:
                    if iter == 0:
                        avg_loss = cum_loss
                    else:
                        avg_loss = cum_loss / self.loss_step
                    print("iteration: {} - loss: {:.5f}".format(
                        iter, avg_loss))
                    cum_loss = 0

                    loss_arr.append(avg_loss)

                    if abs(prev_loss - loss) < self.eps:
                        print(
                            'early stopping criterion met. Finishing training')
                        print("{:.4f} --> {:.5f}".format(prev_loss, loss))
                        stop = True
                        break
                    else:
                        prev_loss = loss

                if i == (num_batches - 1):
                    # Check if epoch is ending. Checkpoint and get evaluation metrics
                    self.checkpoint_model(suffix=iter)
                    if self.X_val is not None:
                        _ = self.get_validation_loss(self.X_val[:, 1:],
                                                     self.y_val)

                iter += 1

                stop = self._check_max_iter(iter)
                if stop:
                    break

            if stop:
                break

        self.checkpoint_model(suffix='done')
        return loss_arr

Example #4

    def fit_utility_loss(self):
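        """Same utility/MRS training loop as ``fit_utility_loss`` above, but
        driven by a batch generator rather than a DataLoader."""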

        self.print_device_specs()

        if self.X_val is not None:
            _ = self.get_validation_loss(self.X_val[:, 1:], self.y_val)

        loss_arr = []

        iter = 0
        cum_loss = 0
        prev_loss = -1

        self.generator = self.get_generator(self.users, self.items,
                                            self.y_train, True)

        while self.generator.epoch_cntr < self.n_epochs:

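            # Draw the next batch; the generator tracks epoch boundaries
            # itself via epoch_cntr and check().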
            batch = self.generator.get_batch(as_tensor=True)

            batch['y'] = batch['y'].to(self.device)
            batch['y_c'] = batch['y_c'].to(self.device)
            batch['y_s'] = batch['y_s'].to(self.device)

            batch['items'] = batch['items'].requires_grad_(True).to(
                self.device)
            batch['x_c'] = batch['x_c'].requires_grad_(True).to(self.device)
            batch['x_s'] = batch['x_s'].requires_grad_(True).to(self.device)
            batch['users'] = batch['users'].to(self.device)

            y_hat = self.model.forward(batch['users'],
                                       batch['items']).to(self.device)
            y_hat_c = self.model.forward(batch['users'],
                                         batch['x_c']).to(self.device)
            y_hat_s = self.model.forward(batch['users'],
                                         batch['x_s']).to(self.device)

            # TODO: Make this function flexible in the loss type (e.g., MSE, binary CE)
            loss_u = utility_loss(y_hat, torch.squeeze(y_hat_c),
                                  torch.squeeze(y_hat_s), batch['y'],
                                  batch['y_c'], batch['y_s'], self.loss)

            if self.n_gpu > 1:
                loss_u = loss_u.mean()

            x_grad = self._get_input_grad(loss_u, batch['items'])
            x_c_grad = self._get_input_grad(loss_u, batch['x_c'])
            x_s_grad = self._get_input_grad(loss_u, batch['x_s'])

            loss = mrs_loss(loss_u,
                            x_grad.reshape(-1, 1),
                            x_c_grad,
                            x_s_grad,
                            lmbda=self.lmbda)

            if self.n_gpu > 1:
                loss = loss.mean()

            # zero gradient
            self.optimizer.zero_grad()

            loss.backward()
            self.optimizer.step()
            loss = loss.detach()
            cum_loss += loss

            if iter % self.loss_step == 0:
                if iter == 0:
                    avg_loss = cum_loss
                else:
                    avg_loss = cum_loss / self.loss_step
                print("iteration: {} - loss: {:.5f}".format(iter, avg_loss))
                cum_loss = 0

                loss_arr.append(avg_loss)

                if abs(prev_loss - loss) < self.eps:
                    print('early stopping criterion met. Finishing training')
                    print("{:.4f} --> {:.5f}".format(prev_loss, loss))
                    break
                else:
                    prev_loss = loss

            if self.generator.check():
                # Check if epoch is ending. Checkpoint and get evaluation metrics
                self.checkpoint_model(suffix=iter)
                if self.X_val is not None:
                    _ = self.get_validation_loss(self.X_val[:, 1:], self.y_val)

            iter += 1

            stop = self._check_max_iter(iter)
            if stop:
                break

        self.checkpoint_model(suffix='done')
        return loss_arr
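
# Note: the _get_input_grad helper used in the examples above is not shown
# here. A minimal sketch of such a helper, assuming it simply wraps
# torch.autograd.grad (an illustration, not the original implementation):
import torch

def _get_input_grad(loss, x):
    # create_graph=True keeps the returned gradient differentiable so that a
    # loss built on it (such as mrs_loss above) can still backpropagate.
    grad, = torch.autograd.grad(loss, x, retain_graph=True, create_graph=True)
    return grad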