def get_class_sizes(data: MoleculeDataset) -> List[List[float]]:
    """
    Determines the proportions of the different classes in the classification dataset.

    :param data: A classification dataset.
    :return: A list of lists of class proportions. Each inner list contains
             the class proportions ``[fraction_of_0s, fraction_of_1s]`` for a task.
    """
    targets = data.targets()

    # Filter out Nones: collect, per task, only the molecules that have a label.
    # (Iterate directly with enumerate instead of index arithmetic.)
    valid_targets = [[] for _ in range(data.num_tasks())]
    for mol_targets in targets:
        for task_num, target in enumerate(mol_targets):
            if target is not None:
                valid_targets[task_num].append(target)

    class_sizes = []
    for task_targets in valid_targets:
        # Make sure we're dealing with a binary classification task.
        # NOTE(review): assert is stripped under `python -O`; consider raising
        # ValueError instead if this check must always run.
        assert set(np.unique(task_targets)) <= {0, 1}

        try:
            ones = np.count_nonzero(task_targets) / len(task_targets)
        except ZeroDivisionError:
            # A task with no labeled molecules has undefined proportions.
            ones = float('nan')
            print('Warning: class has no targets')
        class_sizes.append([1 - ones, ones])

    return class_sizes
def _train(self, epoch: int, data: Union[MoleculeDataset, List[MoleculeDataset]], n_iter: int) -> int:
    """
    Trains a model for an epoch.

    :param epoch: The current epoch number (used only for logging).
    :param data: The training dataset; shuffled in place each epoch.
                 NOTE(review): the List[MoleculeDataset] variant in the annotation
                 would need `.shuffle()` and slicing support — confirm against callers.
    :param n_iter: The cumulative number of training examples seen so far
                   (across epochs); drives the LR-decay and logging schedules.
    :return: The updated cumulative number of examples seen (`n_iter`).
    """
    # Log through the provided logger when available, else stdout.
    debug = self.logger.debug if self.logger is not None else print
    debug(f'Running epoch: {epoch}')
    self.net.train()
    data.shuffle()
    loss_sum, iter_count = 0, 0
    # Round the epoch length down to a whole number of batches,
    # dropping the final partial batch.
    num_iters = len(data) // self.args.batch_size * self.args.batch_size
    iter_size = self.args.batch_size
    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        # NOTE(review): this guard looks unreachable — num_iters is already
        # floored to a multiple of batch_size, so i + batch_size <= len(data).
        if i + self.args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + self.args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        # mask zeroes out loss contributions from missing (None) targets;
        # missing targets are filled with 0 so the Tensor is rectangular.
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
        # Move tensors to GPU iff the model's parameters live there.
        if next(self.net.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()
        # NOTE(review): class_weights is all-ones, so it is currently a no-op
        # multiplier — presumably a hook for class rebalancing; confirm intent.
        class_weights = torch.ones(targets.shape)
        if self.use_cuda:
            class_weights = class_weights.cuda()
        # Run model
        self.net.zero_grad()
        # NOTE(review): the second return value `e` is unused here —
        # presumably an embedding/auxiliary output of the network.
        preds, e = self.net(batch, features_batch)
        # Per-element loss, weighted and masked, then averaged over the
        # number of *present* targets (mask.sum()), not the full batch size.
        loss = self.loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()
        loss_sum += loss.item()
        iter_count += len(mol_batch)
        loss.backward()
        self.optimizer.step()
        # Decay the learning rate every learning_rate_decay_steps batches,
        # keyed off the cumulative (pre-increment) example count.
        if (n_iter // self.args.batch_size) % self.args.learning_rate_decay_steps == 0:
            self.lr_schedule.step()
        n_iter += len(mol_batch)
        # Log and/or add to tensorboard
        if (n_iter // self.args.batch_size) % self.args.log_frequency == 0:
            lrs = self.lr_schedule.get_lr()
            # Average loss since the last log line, then reset the accumulators.
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, {lrs_str}')
            if self.writer is not None:
                self.writer.add_scalar('train_loss', loss_avg, n_iter)
                # for i, lr in enumerate(lrs):
                #     self.writer.add_scalar(f'learning_rate_{i}', lr, n_iter)
    return n_iter