def train(X_train, X_val, model, optimizer, logger, num_epochs, batch_size):
    """Train FFW model.

    Arguments:
        X_train (sparse matrix): training data, output by encode_ffw.py
        X_val (sparse matrix): validation data, output by encode_ffw.py
        model (torch Module)
        optimizer (torch optimizer)
        logger: wrapper for TensorboardX logger
        num_epochs (int): number of epochs to train for
        batch_size (int)
    """
    criterion = nn.BCEWithLogitsLoss()
    metrics = Metrics()
    train_idxs = np.arange(X_train.shape[0])
    val_idxs = np.arange(X_val.shape[0])
    step = 0

    for epoch in tqdm(range(num_epochs)):
        # Re-shuffle row order each epoch so batches differ between epochs.
        shuffle(train_idxs)
        shuffle(val_idxs)

        # Training
        for k in range(0, len(train_idxs), batch_size):
            inputs, item_ids, labels = get_tensors(
                X_train[train_idxs[k:k + batch_size]])
            inputs = inputs.to(device=args.device)
            preds = model(inputs)
            # Select, for each row of the batch, the logit of the item that
            # was actually answered; the loss is computed on those only.
            relevant_preds = preds[torch.arange(preds.shape[0]),
                                   item_ids.to(device=args.device)]
            loss = criterion(relevant_preds, labels.to(device=args.device))
            # AUC is computed on detached CPU logits before the backward pass
            # so it does not participate in (or retain) the autograd graph.
            train_auc = compute_auc(preds.detach().cpu(), item_ids, labels)

            model.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            metrics.store({"loss/train": loss.item()})
            metrics.store({"auc/train": train_auc})

            # Logging: x-axis is the number of samples seen (step * batch_size).
            if step % 20 == 0:
                logger.log_scalars(metrics.average(), step * batch_size)

        # Validation: switch to eval mode (disables dropout/batch-norm updates)
        # and skip gradient tracking for the forward passes.
        model.eval()
        for k in range(0, len(val_idxs), batch_size):
            inputs, item_ids, labels = get_tensors(
                X_val[val_idxs[k:k + batch_size]])
            inputs = inputs.to(device=args.device)
            with torch.no_grad():
                preds = model(inputs)
            val_auc = compute_auc(preds.cpu(), item_ids, labels)
            metrics.store({"auc/val": val_auc})
        model.train()
def train(df, model, optimizer, logger, num_epochs, batch_size):
    """Train SAKT model.

    Arguments:
        df (pandas DataFrame): output by prepare_data.py
        model (torch Module)
        optimizer (torch optimizer)
        logger: wrapper for TensorboardX logger
        num_epochs (int): number of epochs to train for
        batch_size (int)
    """
    train_data, val_data = get_data(df)
    criterion = nn.BCEWithLogitsLoss()
    metrics = Metrics()
    step = 0

    for epoch in range(num_epochs):
        # Batches are rebuilt every epoch so their composition varies.
        train_batches = prepare_batches(train_data, batch_size)
        val_batches = prepare_batches(val_data, batch_size)

        # Training
        for inputs, item_ids, labels in train_batches:
            inputs = inputs.cuda()
            preds = model(inputs)
            loss = compute_loss(preds, item_ids.cuda(), labels.cuda(), criterion)
            # AUC is computed on detached CPU logits before the backward pass
            # so it does not participate in (or retain) the autograd graph.
            train_auc = compute_auc(preds.detach().cpu(), item_ids, labels)

            model.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            metrics.store({'loss/train': loss.item()})
            metrics.store({'auc/train': train_auc})

            # Logging: scalars plus weight/gradient histograms every 20 steps.
            if step % 20 == 0:
                logger.log_scalars(metrics.average(), step)
                weights = {"weight/" + name: param
                           for name, param in model.named_parameters()}
                grads = {"grad/" + name: param.grad
                         for name, param in model.named_parameters()
                         if param.grad is not None}
                logger.log_histograms(weights, step)
                logger.log_histograms(grads, step)

        # Validation: switch to eval mode (disables dropout/batch-norm updates)
        # and skip gradient tracking for the forward passes.
        model.eval()
        for inputs, item_ids, labels in val_batches:
            inputs = inputs.cuda()
            with torch.no_grad():
                preds = model(inputs)
            val_auc = compute_auc(preds.cpu(), item_ids, labels)
            metrics.store({'auc/val': val_auc})
        model.train()