Example #1
    def fit_one(
        self,
        model,
        tasks: List[Task],
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        training_logger: ResultLogger,
        train_model=True,
    ) -> float:
        model.train(train_model)

        for task in tasks:
            task.train()

        task_names = [t.name for t in tasks]

        losses = 0.0
        total = 0.0
        for i, (x, y) in enumerate(dataloader):
            outputs = self._step(model, tasks, x, y, optimizer)

            training_logger.log(outputs, task_names, i + 1, len(dataloader))

            for o in outputs:
                losses += o.loss.item()
                total += 1.0
        scheduler.step()
        return losses / total
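
Example #1 steps the optimizer every mini-batch but the scheduler only once, as fit_one returns. A minimal self-contained sketch of the same placement, using only built-in PyTorch pieces (the linear model and random data are hypothetical stand-ins, not part of the snippet above):

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 1)                                 # hypothetical stand-in model
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)    # halve the LR after every epoch

for epoch in range(3):
    for _ in range(5):                                   # pretend mini-batches
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                                 # once per batch
    scheduler.step()                                     # once per epoch, as in fit_one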
Example #2
    def train(
        self,
        problem: Problem,
        startEpoch: int,
        nEpochs: int,
        batchSize: int,
        scheduler: _LRScheduler,
    ) -> Iterator[FractionalPerformanceSummary]:
        self.cur_epoch = startEpoch

        d_types = [d.data_type for d in problem.datasets]
        assert Split.TRAIN in d_types, "training dataset should be included"

        ################### Start Training #####################################

        logger.info("Model layers")
        logger.info(str(self.model))

        loaders = self._get_loaders(problem, batchSize=batchSize)

        # start epochs
        while self.cur_epoch < nEpochs:
            self.cur_epoch += 1
            logger.info("Starting epoch %d" % self.cur_epoch)
            epoch_stats = self._pass_one_epoch(problem, loaders, Mode.TRAIN)
            logger.info("Finished epoch %d" % self.cur_epoch)

            scheduler.step()

            yield self._get_worker_performance_summary(epoch_stats)
Example #3
    def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer,
                num_steps: int, loss_function: Callable,
                lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]:
        pbar = tqdm(range(num_steps), leave=False)
        latent_path = []
        noise_path = []

        best_latent = best_noise = best_psnr = None

        for i in pbar:
            img_gen, _ = self.generate(latents)

            batch, channel, height, width = img_gen.shape

            if height > 256:
                factor = height // 256

                img_gen = img_gen.reshape(
                    batch, channel, height // factor, factor, width // factor, factor
                )
                img_gen = img_gen.mean([3, 5])

            # # n_loss = noise_regularize(noises)
            loss, loss_dict = loss_function(img_gen, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_dict['psnr'] = self.psnr(img_gen, images).item()
            loss_dict['lr'] = optimizer.param_groups[0]["lr"]

            if lr_scheduler is not None:
                lr_scheduler.step()

            self.log.append(loss_dict)

            if best_psnr is None or best_psnr < loss_dict['psnr']:
                best_psnr = loss_dict['psnr']
                best_latent = latents.latent.detach().clone().cpu()
                best_noise = [noise.detach().clone().cpu() for noise in latents.noise]

            if i % self.debug_step == 0:
                latent_path.append(latents.latent.detach().clone().cpu())
                noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

            loss_description = "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items())
            pbar.set_description(loss_description)

            loss_dict['iteration'] = i
            if self.abort_condition is not None and self.abort_condition(loss_dict):
                break

        latent_path.append(latents.latent.detach().clone().cpu())
        noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

        return LatentPaths(latent_path, noise_path), Latents(best_latent, best_noise)
Example #4
def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """
    Reads the current learning rate via get_last_lr, runs one scheduler step, and repeats for `steps` iterations. Returns the collected LR values.
    """
    lrs = []
    for _ in range(steps):
        lr = scheduler.get_last_lr()  # type: ignore
        assert isinstance(lr, list)
        assert len(lr) == 1
        lrs.append(lr[0])
        scheduler.step()
    return lrs
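
A possible way to exercise enumerate_scheduler; the optimizer, scheduler, and step count below are arbitrary choices, not taken from the source project:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.zeros(1))]            # single param group
optimizer = SGD(params, lr=1.0)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

print(enumerate_scheduler(scheduler, steps=6))
# -> [1.0, 1.0, 0.1, 0.1, 0.01, 0.01]  (LR decays by 10x every 2 steps)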
Example #5
def train_step(net: nn.Module,
               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
               inputs: torch.Tensor, targets: torch.Tensor, grad_clip: float) -> Tuple[torch.Tensor, float]:
    outputs = net(inputs)

    loss = crit(outputs, targets)
    optim.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)

    optim.step()
    if sched and not sched_on_epoch:
        sched.step()
    return outputs, loss.item()
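
train_step only calls sched.step() when sched_on_epoch is False, i.e. for schedules that advance per batch. A hedged sketch of how a caller might complete the picture (the epoch loop, epochs, and loader are assumptions, not from this snippet):

# Hypothetical epoch loop around train_step:
#   per-batch schedules (e.g. OneCycleLR)  -> sched_on_epoch=False, stepped inside train_step
#   per-epoch schedules (e.g. StepLR)      -> sched_on_epoch=True, stepped here
for epoch in range(epochs):
    for inputs, targets in loader:
        outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
                                   inputs, targets, grad_clip=5.0)
    if sched and sched_on_epoch:
        sched.step()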
Example #6
def train(model: nn.Module,
          loader: DataLoader,
          class_loss: nn.Module,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          epoch: int,
          callback: VisdomLogger,
          freq: int,
          ex: Experiment = None) -> None:
    model.train()
    device = next(model.parameters()).device
    to_device = lambda x: x.to(device, non_blocking=True)
    loader_length = len(loader)
    train_losses = AverageMeter(device=device, length=loader_length)
    train_accs = AverageMeter(device=device, length=loader_length)

    pbar = tqdm(loader, ncols=80, desc='Training   [{:03d}]'.format(epoch))
    for i, (batch, labels, indices) in enumerate(pbar):
        batch, labels, indices = map(to_device, (batch, labels, indices))
        logits, features = model(batch)
        loss = class_loss(logits, labels).mean()
        acc = (logits.detach().argmax(1) == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_losses.append(loss)
        train_accs.append(acc)

        if callback is not None and not (i + 1) % freq:
            step = epoch + i / loader_length
            callback.scalar('xent',
                            step,
                            train_losses.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc',
                            step,
                            train_accs.last_avg,
                            title='Train Acc')

    if ex is not None:
        for i, (loss, acc) in enumerate(
                zip(train_losses.values_list, train_accs.values_list)):
            step = epoch + i / loader_length
            ex.log_scalar('train.loss', loss, step=step)
            ex.log_scalar('train.acc', acc, step=step)
Example #7
def train(model: nn.Module,
          data: MolPairDataset,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MolPairDataset (or a list of MolPairDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    # Very important this is done before conversion to maintain randomness in contrastive dataset.
    data.shuffle()

    loss_sum, iter_count = 0, 0

    if args.loss_func == 'contrastive':
        data = convert2contrast(data)
    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MolPairDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
        if args.loss_func == 'contrastive':
            mask = targets
        else:
            mask = torch.Tensor([[x is not None for x in tb]
                                 for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.dataset_type == 'regression':
            class_weights = torch.ones(targets.shape)
        else:
            class_weights = targets * (args.class_weights - 1) + 1

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += args.batch_size

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
Example #8
    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               validate_interval: int = 10, save: bool = False, amp: bool = False, verbose: bool = True, indent: int = 0,
               loader_train: torch.utils.data.DataLoader = None, loader_valid: torch.utils.data.DataLoader = None,
               get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               validate_func: Callable[..., tuple[float, ...]] = None, epoch_func: Callable[[], None] = None,
               save_fn: Callable = None, file_path: str = None, folder_path: str = None, suffix: str = None, **kwargs):
        loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
        get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
        loss_fn = loss_fn if loss_fn is not None else self.loss
        validate_func = validate_func if validate_func is not None else self._validate
        save_fn = save_fn if save_fn is not None else self.save

        scaler: torch.cuda.amp.GradScaler = None
        if amp and env['num_gpus']:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                       verbose=verbose, indent=indent, **kwargs)
        losses = AverageMeter('Loss')
        top1 = AverageMeter('Acc@1')
        top5 = AverageMeter('Acc@5')
        params: list[list[nn.Parameter]] = [param_group['params'] for param_group in optimizer.param_groups]
        for _epoch in range(epoch):
            if epoch_func is not None:
                self.activate_params([])
                epoch_func()
                self.activate_params(params)
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            loader = loader_train
            if verbose and env['tqdm']:
                loader = tqdm(loader_train)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for data in loader:
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                if amp and env['num_gpus']:
                    loss = loss_fn(_input, _label, amp=True)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss = loss_fn(_input, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.get_logits(_input)
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
                empty_cache()
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.eval()
            self.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
                    f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                                  verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path, folder_path=folder_path, suffix=suffix, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.zero_grad()
Example #9
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: Number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: Total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break

        mol_batch = MoleculeDataset(data[i:i + args.batch_size])

        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

        mask = torch.Tensor([[not np.isnan(x) for x in tb]
                             for tb in target_batch])

        targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        # todo: change the loss function for property prediction tasks

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :],
                                        targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(
                f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(
                f'\nLoss = {loss_avg:.4e}, PNorm = {pnorm:.4f},'
                f' GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)

                for idx, learn_rate in enumerate(lrs):
                    writer.add_scalar(
                        f'learning_rate_{idx}', learn_rate, n_iter)

    return n_iter
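
The masked-loss pattern shared by these chemprop-style train functions boils down to a few lines; a standalone toy sketch with invented values, not taken from any of the projects above:

import torch

target_batch = [[0.5, None], [None, 1.0], [0.2, 0.3]]   # None marks a missing label
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

preds = torch.zeros_like(targets)                        # stand-in model output
loss = torch.nn.functional.mse_loss(preds, targets, reduction='none') * mask
loss = loss.sum() / mask.sum()                           # average over observed targets only
print(loss)  # (0.25 + 1.00 + 0.04 + 0.09) / 4 = 0.345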
Example #10
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    iter_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        # only works for single class classification
        assert len(data[0].targets) == 1
        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        new_data = []
        pos_size = iter_size // 2
        pos_index = neg_index = 0
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]

            if len(new_pos) == 0 or len(new_neg) == 0:
                break

            if len(new_pos) + len(new_neg) < iter_size:
                new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]

            new_data += new_pos + new_neg

            pos_index += len(new_pos)
            neg_index += len(new_neg)

        data = new_data

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch, weight_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets(), mol_batch.weights()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
        weights = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in weight_batch])
        #        print (weight_batch)
        #        print (weights)

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.enable_weight:
            class_weights = weights
        else:
            class_weights = torch.ones(targets.shape)
#        print(class_weights)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
#        print ("loss")
#        print (loss)
#        print (class_weights)

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
Example #11
def train(
    model: Union[nn.Module, nn.DataParallel],
    train_loader: DataLoader,
    metrics: Dict[str, Metric],
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    device: torch.device,
    epoch: int,
    log_interval: int,
    hooks: Optional[Sequence[Hook]] = None,
    teacher: Optional[Union[nn.Module, nn.DataParallel]] = None,
) -> Dict[str, float]:
    """
    Train a model on some data using some criterion and with some optimizer.

    Args:
        model: Model to train
        train_loader: Data loader for loading training data
        metrics: A dict mapping evaluation metric names to metrics classes
        optimizer: PyTorch optimizer
        scheduler: PyTorch scheduler
        device: PyTorch device object
        epoch: Current epoch, where the first epoch should start at 1
        log_interval: Number of batches before printing loss
        hooks: A sequence of functions that can implement custom behavior
        teacher: teacher network for knowledge distillation, if any

    Returns:
        A dictionary mapping evaluation metric names to computed values for the training set.
    """
    if hooks is None:
        hooks = []

    model.train()
    for metric in metrics.values():
        metric.reset()

    loss_fn = model.module.loss_fn if isinstance(
        model, nn.DataParallel) else model.loss_fn

    seen_examples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        if teacher is None:
            teacher_output = None
            loss = loss_fn(output, target)  # type: ignore
        else:
            teacher_output = teacher(data)
            loss = loss_fn(output, teacher_output, target)  # type: ignore
        loss.backward()
        optimizer.step()
        project(optimizer)
        scheduler.step()  # type: ignore

        with torch.no_grad():
            for metric in metrics.values():
                metric.update(output, target, teacher_output=teacher_output)

        for hook in hooks:
            hook(
                epoch=epoch,
                global_step=1 + (epoch - 1) * len(train_loader.dataset) +
                batch_idx,
                values_dict={'lr': _get_lr(optimizer)},
                log_interval=log_interval,
            )

        seen_examples += len(data)
        if batch_idx % log_interval == 0:
            logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                    epoch,
                    seen_examples,
                    len(train_loader.dataset),
                    100 * batch_idx / len(train_loader),
                    loss.item(),
                ))

    # Computing evaluation metrics for training set
    computed_metrics = {
        name: metric.compute()
        for name, metric in metrics.items()
    }

    logger.info('Training set evaluation metrics:')
    for name, metric in metrics.items():
        logger.info(f'{name}: {metric}')

    return computed_metrics
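
The hooks in Example #11 are invoked with the keyword arguments epoch, global_step, values_dict, and log_interval; a minimal compatible hook might look like this (the body is a hypothetical illustration):

def print_lr_hook(*, epoch: int, global_step: int,
                  values_dict: dict, log_interval: int, **_) -> None:
    # Report the learning rate every `log_interval` global steps.
    if global_step % log_interval == 0:
        print(f"epoch={epoch} step={global_step} lr={values_dict['lr']:.3e}")

# e.g. train(..., hooks=[print_lr_hook], ...)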
Example #12
def train_model(
    model: Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: _Loss,
    optimizer: Optimizer,
    device: device,
    scheduler: _LRScheduler,
    num_epochs: int,
) -> Module:
    """Trains the segmentation model on the training data.

    Parameters
    ----------
    model : Module
        Model to train.
    train_loader : DataLoader
        Train data.
    val_loader : DataLoader
        Validation data.
    loss : _Loss
        Loss function.
    optimizer : Optimizer
        Selected optimizer which updates the weights of the model.
    device : device
        Device on which the model is located.
    scheduler : _LRScheduler
        Selected scheduler of the learning rate (stepped with the training loss).
    num_epochs : int
        Number of training epochs.

    Returns
    -------
    Module
        Model with the best validation loss during training.
    """
    train_size = len(train_loader.dataset)
    val_size = len(val_loader.dataset)

    best_val_loss = 10**8
    best_model = None
    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print("-" * 12)

        train_epoch_loss = 0.0
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            with set_grad_enabled(True):
                outputs = model(inputs)
                l = loss(outputs, labels)
                l.backward()
                optimizer.step()

                train_epoch_loss += l.item() * inputs.size(0)

        scheduler.step(train_epoch_loss)

        val_epoch_loss = 0.0
        model.eval()

        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            with set_grad_enabled(False):
                outputs = model(inputs)
                l = loss(outputs, labels)
                val_epoch_loss += l.item() * inputs.size(0)

        if val_epoch_loss < best_val_loss:
            best_val_loss = val_epoch_loss
            best_model = deepcopy(model.state_dict())

        print(
            f"Train loss: {train_epoch_loss/train_size}, Val. loss: {val_epoch_loss/val_size}"
        )

    model.load_state_dict(best_model)
    return model
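
scheduler.step(train_epoch_loss) above passes a metric to the scheduler, which matches ReduceLROnPlateau's interface rather than a plain _LRScheduler. A hedged construction sketch (model is a placeholder for the segmentation network passed to train_model):

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = Adam(model.parameters(), lr=1e-3)            # `model` is a placeholder
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
# train_model(...) then calls scheduler.step(train_epoch_loss), lowering the LR
# when the summed training loss stops improving.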
Example #13
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_features(), batch.data_weights()

        mask = torch.tensor(mask_batch, dtype=torch.bool) # shape(batch, tasks)
        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch]) # shape(batch, tasks)

        if args.target_weights is not None:
            target_weights = torch.tensor(args.target_weights).unsqueeze(0) # shape(1,tasks)
        else:
            target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
        data_weights = torch.tensor(data_weights_batch).unsqueeze(1) # shape(batch,1)

        if args.loss_function == 'bounded_mse':
            lt_target_batch = batch.lt_targets() # shape(batch, tasks)
            gt_target_batch = batch.gt_targets() # shape(batch, tasks)
            lt_target_batch = torch.tensor(lt_target_batch)
            gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)

        # Move tensors to correct device
        torch_device = preds.device
        mask = mask.to(torch_device)
        targets = targets.to(torch_device)
        target_weights = target_weights.to(torch_device)
        data_weights = data_weights.to(torch_device)
        if args.loss_function == 'bounded_mse':
            lt_target_batch = lt_target_batch.to(torch_device)
            gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses
        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
            loss = loss_func(preds, targets, data_weights, mask) * target_weights.squeeze(0)
        elif args.loss_function == 'mcc': # multiclass dataset type
            targets = targets.long()
            target_losses = []
            for target_index in range(preds.size(1)):
                target_loss = loss_func(preds[:, target_index, :], targets[:, target_index], data_weights, mask[:, target_index]).unsqueeze(0)
                target_losses.append(target_loss)
            loss = torch.cat(target_losses).to(torch_device) * target_weights.squeeze(0)
        elif args.dataset_type == 'multiclass':
            targets = targets.long()
            if args.loss_function == 'dirichlet':
                loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
            else:
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * mask
        elif args.dataset_type == 'spectra':
            loss = loss_func(preds, targets, mask) * target_weights * data_weights * mask
        elif args.loss_function == 'bounded_mse':
            loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) * target_weights * data_weights * mask
        elif args.loss_function == 'evidential':
            loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
        elif args.loss_function == 'dirichlet': # classification
            loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
        else:
            loss = loss_func(preds, targets) * target_weights * data_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
Example #14
    def train(
        self,
        epochs: int,
        train_loader: DataLoader,
        optimizer: Optimizer,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]],
        acc_fn: Callable[
            [List[Tuple[str, torch.Tensor]], torch.Tensor],
            List[Tuple[str, torch.Tensor]],
        ],
        scheduler: LRScheduler = None,
        test_loader: DataLoader = None,
        device: str = "cpu",
        use_tqdm: bool = False,
    ) -> History:
        """
        Trains the model.

        Note: 
            
            look at the notes in :meth:`athena.solvers.regression_solver.RegressionSolver.train_step` and in \
            :meth:`athena.solvers.regression_solver.RegressionSolver.test_step`.

        Args:
            epochs (int): The number of epochs to train for.
            train_loader (DataLoader): The ``DataLoader`` for the training data.
            optimizer (Optimizer): The optimizer to use.
            loss_fn (Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]]): The loss function to use. \
                the loss function should take in the predicted output of the model and target output from the dataset as \
                the arguments and return a list of tuples, in which the first element of each tuple is a label for the \
                loss and the second element is the loss value.
            acc_fn (Callable[[List[Tuple[str, torch.Tensor]], torch.Tensor], List[Tuple[str, torch.Tensor]]]): The accuracy \
                function to use. The function should take in two arguments, first, a list of tuples, where the first element \ 
                of each tuple is the label for the loss and the second element is the loss value, and the second argument \
                the target output from the dataset. The function should return a list of tuples, first element of the tuple \
                should be the label of the accuracy and the second element should be the accuracy value.
            scheduler (LRScheduler, optional): The ``LRScheduler`` to use. Defaults to None.
            test_loader (DataLoader, optional): The ``DataLoader`` for the test data. Defaults to None.
            device (str, optional): A valid pytorch device string. Defaults to ``cpu``.
            use_tqdm (bool, optional): If True, uses tqdm instead of a keras style progress bar (``pkbar``). Defaults to False.

        Returns:
            History: An History object containing training information.
        """
        history = History()

        for epoch in range(epochs):
            print("Epoch: %d / %d" % (epoch + 1, epochs), flush=use_tqdm)

            # performing train step
            train_data = self.train_step(
                train_loader, optimizer, scheduler, device, loss_fn, acc_fn, use_tqdm
            )

            # adding metrics to history
            for label, data in train_data:
                history.add_metric(label, data)

            # stepping scheduler
            if scheduler is not None and not isinstance(scheduler, OneCycleLR):
                scheduler.step()

            # performing test step
            if test_loader is not None:
                test_data = self.test_step(
                    test_loader, device, loss_fn, flush_print=use_tqdm
                )

                # adding metrics to history
                for label, data in test_data:
                    history.add_metric(label, data)

        return history
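
The loss_fn/acc_fn contract described in the docstring (lists of (label, value) tuples) can be satisfied as follows; a sketch under those assumptions, using MSE since the docstring points at RegressionSolver:

import torch
import torch.nn.functional as F
from typing import List, Tuple

def loss_fn(y_pred: torch.Tensor, target: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
    # one labelled loss term; more tuples could be appended for multi-term losses
    return [("mse", F.mse_loss(y_pred, target))]

def acc_fn(losses: List[Tuple[str, torch.Tensor]],
           target: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
    # derive a labelled metric from the loss list; here RMSE from the "mse" entry
    mse = dict(losses)["mse"]
    return [("rmse", mse.detach().sqrt())]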
Example #15
    def train_step(
        self,
        train_loader: DataLoader,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        device: str,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]],
        acc_fn: Callable[
            [List[Tuple[str, torch.Tensor]], torch.Tensor],
            List[Tuple[str, torch.Tensor]],
        ],
        use_tqdm: bool,
    ) -> List[Tuple[str, torch.Tensor]]:
        """
        Performs a single train step.
            
        Note: 
            
            the losses and accuracies returned by the ``loss_fn`` and ``acc_fn`` are divided by the \
            number of batches in the dataset while recording them for an epoch (averaging). So make \
            sure any reduction in your functions are ``mean``.
            
        Args:
            train_loader (DataLoader): The ``DataLoader`` for the training data.
            optimizer (Optimizer): The optimizer to use.
            scheduler (LRScheduler): The LR scheduler to use.
            device (str): A valid pytorch device string.
            loss_fn (Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]]): The loss function to use. \
                the loss function should take in the predicted output of the model and target output from the dataset as \
                the arguments and return a list of tuples, in which the first element of each tuple is a label for the \
                loss and the second element is the loss value.
            acc_fn (Callable[[List[Tuple[str, torch.Tensor]], torch.Tensor], List[Tuple[str, torch.Tensor]]]): The accuracy \
                function to use. The function should take in two arguments, first, a list of tuples, where the first element \ 
                of each tuple is the label for the loss and the second element is the loss value, and the second argument \
                the target output from the dataset. The function should return a list of tuples, first element of the tuple \
                should be the label of the accuracy and the second element should be the accuracy value.
            use_tqdm (bool): If True, uses tqdm instead of a keras style progress bar (``pkbar``).

        Returns:
            List[Tuple[str, torch.Tensor]]: A list containing tuples in which the first element of the tuple is the label \
                describing the value and the second element is the value itself.
        """

        # setting model in train mode
        self.model.train()

        # creating progress bar
        if use_tqdm:
            pbar = tqdm(train_loader)
            iterator = pbar
        else:
            pbar = Kbar(len(train_loader), stateful_metrics=["loss", "accuracy"])
            iterator = train_loader

        # defining variables
        correct = 0
        processed = 0
        train_losses: np.ndarray = None
        train_accs: np.ndarray = None
        for batch_idx, (data, target) in enumerate(iterator):
            # casting to device
            data, target = data.to(device), target.to(device)

            # zeroing out accumulated gradients
            optimizer.zero_grad()

            # forward prop
            y_pred = self.model(data)

            # calculating loss (look at function documentation for details on what is returned by
            # the loss_fn)
            losses_data: List[Tuple[str, torch.Tensor]] = loss_fn(y_pred, target)
            if train_losses is None:
                train_losses = np.fromiter(
                    [x[-1] for x in losses_data], dtype=np.float32
                )
            else:
                train_losses = train_losses + np.fromiter(
                    [x[-1] for x in losses_data], dtype=np.float32
                )

            # backpropagation
            for _, loss in losses_data:
                loss.backward()
            optimizer.step()

            # calculating the accuracies (look at function documentation for details on what is returned by
            # the acc_fn)
            acc_data: List[Tuple[str, torch.Tensor]] = acc_fn(losses_data, target)
            if train_accs is None:
                train_accs = np.fromiter([x[-1] for x in acc_data], dtype=np.float32)
            else:
                train_accs = train_accs + np.fromiter(
                    [x[-1] for x in acc_data], dtype=np.float32
                )

            # updating progress bar with instantaneous losses and accuracies
            if use_tqdm:
                losses_desc = " - ".join(
                    [f"{name}: {value:0.4f}" for name, value in losses_data]
                )
                accs_desc = " - ".join(
                    [f"{name}: {value:0.4f}" for name, value in acc_data]
                )
                pbar.set_description(
                    desc=f"Batch_id: {batch_idx + 1} - {losses_desc} - {accs_desc}"
                )
            else:
                pbar.update(batch_idx, values=[*losses_data, *acc_data])

            if isinstance(scheduler, OneCycleLR):
                scheduler.step()

        if not use_tqdm:
            # required for pkbar
            pbar.add(1, values=[*losses_data, *acc_data])

        return [
            *list(
                zip(
                    # getting the labels of each loss value
                    [x[0] for x in losses_data],
                    # dividing the value of each of the losses by the number of batches in the dataset
                    [loss / len(train_loader) for loss in train_losses],
                )
            ),
            *list(
                zip(
                    # getting the labels of each accuracy value
                    [x[0] for x in acc_data],
                    # dividing the value of each accuracy by the number of batches in the dataset
                    [acc / len(train_loader) for acc in train_accs],
                )
            ),
        ]
Example #16
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        loader_train = (loader_train if loader_train is not None
                        else self.dataset.loader['train'])
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
        #     iter_fn = getattr(self, 'iter_fn')
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False
        if amp:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)
        for _epoch in range(epoch):
            _epoch += 1
            if callable(epoch_fn):
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                _iter = _epoch * len(loader_train) + i
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                    optimizer.step()
                optimizer.zero_grad()
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache()  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss, acc = (logger.meters['loss'].global_avg,
                         logger.meters['top1'].global_avg)
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(
                                **ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()
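
Example #16 clips gradients only on the non-AMP path; when a GradScaler is in play, the usual PyTorch recipe is to unscale the gradients before clipping. A generic sketch that is not wired into the class above (loader_train, loss_fn, params, optimizer, and grad_clip stand for whatever the caller provides):

import torch
from torch import nn

scaler = torch.cuda.amp.GradScaler()
for _input, _label in loader_train:          # placeholder loader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(_input, _label)       # placeholder loss callable
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)               # bring gradients back to fp32 scale
    nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    scaler.step(optimizer)
    scaler.update()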
Example #17
def train(model: nn.Module, data_loader: MoleculeDataLoader,
          loss_func: Callable, optimizer: Optimizer,
          scheduler: _LRScheduler, args: Namespace,
          n_iter: int = 0, logger: Optional[Logger] = None,
          writer: Optional = None) -> int:
    """Trains a model for an epoch

    Parameters
    ----------
    model : nn.Module
        the model to train
    data_loader : MoleculeDataLoader
        an iterable of MoleculeDatasets
    loss_func : Callable
        the loss function
    optimizer : Optimizer
        the optimizer
    scheduler : _LRScheduler
        the learning rate scheduler
    args : Namespace
        a Namespace object the containing necessary attributes
    n_iter : int
        the current number of training iterations
    logger : Optional[Logger] (Default = None)
        a logger for printing intermediate results. If None, print
        intermediate results to stdout
    writer : Optional[SummaryWriter] (Default = None)
        A tensorboardX SummaryWriter

    Returns
    -------
    n_iter : int
        The total number of iterations (training examples) trained on so far
    """
    model.train()
    # loss_sum = 0
    # iter_count = 0

    for batch in tqdm(data_loader, desc='Step', unit='minibatch',
                      leave=False):
        # Prepare batch
        mol_batch = batch.batch_graph()
        features_batch = batch.features()

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)        

        targets = batch.targets()   # targets might have None's
        mask = torch.Tensor(
            [list(map(bool, ys)) for ys in targets]).to(preds.device)
        targets = torch.Tensor(
            [[y or 0 for y in ys] for ys in targets]).to(preds.device)
        class_weights = torch.ones(targets.shape).to(preds.device)
        
        # if args.dataset_type == 'multiclass':
        #     targets = targets.long()
        #     loss = (torch.cat([
        #         loss_func(preds[:, target_index, :],
        #                    targets[:, target_index]).unsqueeze(1)
        #         for target_index in range(preds.size(1))
        #         ], dim=1) * class_weights * mask
        #     )

        if model.uncertainty:
            pred_means = preds[:, 0::2]
            pred_vars = preds[:, 1::2]

            loss = loss_func(pred_means, pred_vars, targets)
        else:
            loss = loss_func(preds, targets) * class_weights * mask

        loss = loss.sum() / mask.sum()

        # loss_sum += loss.item()
        # iter_count += len(batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        # if (n_iter // args.batch_size) % args.log_frequency == 0:
        #     lrs = scheduler.get_lr()
        #     pnorm = compute_pnorm(model)
        #     gnorm = compute_gnorm(model)
        #     loss_avg = loss_sum / iter_count
        #     loss_sum, iter_count = 0, 0

        #     lrs_str = ', '.join(
        #         f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
        #     debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, '
        #           + f'GNorm = {gnorm:.4f}, {lrs_str}')

        #     if writer:
        #         writer.add_scalar('train_loss', loss_avg, n_iter)
        #         writer.add_scalar('param_norm', pnorm, n_iter)
        #         writer.add_scalar('gradient_norm', gnorm, n_iter)
        #         for i, lr in enumerate(lrs):
        #             writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
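
Example #17 relies on a masking trick shared by most of the snippets in this collection: multi-task targets may contain None entries, which are replaced by 0 and excluded from the loss through a 0/1 mask. A minimal sketch under assumed shapes (3 molecules, 2 tasks), with plain MSE standing in for the project's loss function:

import torch
import torch.nn as nn

# raw targets: None marks "no label for this task on this molecule"
target_batch = [[0.5, None], [None, 1.2], [0.3, 0.7]]

mask = torch.Tensor([[x is not None for x in ys] for ys in target_batch])               # 1 where labeled
targets = torch.Tensor([[0.0 if x is None else x for x in ys] for ys in target_batch])  # None -> 0

preds = torch.randn(3, 2)                              # stand-in for model output
loss = nn.MSELoss(reduction='none')(preds, targets) * mask
loss = loss.sum() / mask.sum()                         # average over labeled entries only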
Example #18
0
File: cv.py  Project: tbwxmu/SAMPN
def train_batch(args,
                fold_i,
                model: nn.Module,
                data: DataLoader,
                loss_func: Callable,
                optimizer: Optimizer,
                scheduler: _LRScheduler,
                logger: logging.Logger = None,
                writer: SummaryWriter = None):
    debug = logger.debug if logger is not None else print
    loss_sum, iter_count, epoch_loss = 0, 0, 0
    for it, result_batch in enumerate(tqdm(data)):

        model.zero_grad()
        batch = result_batch['sm']
        label_batch = result_batch['labels']

        mask = torch.Tensor([[x is not None for x in batch_t]
                             for batch_t in result_batch['labels']])
        targets = torch.Tensor([[0 if x is None else x for x in batch_t]
                                for batch_t in result_batch['labels']])
        args.num_tasks = len(result_batch['labels'][0])
        if args.dataset_type == 'classification':
            if args.class_balance:
                class_weights = []
                for task_num in range(args.n_task):
                    class_weights.append(args.class_weights[fold_i][task_num][
                        targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()

            else:
                class_weights = torch.ones(targets.shape)
        if next(model.parameters()).is_cuda and args.gpuUSE:

            mask, targets = mask.cuda(), targets.cuda()

        preds = model(batch)

        if args.dataset_type == 'classification':
            loss = loss_func(preds, targets) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * mask

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        epoch_loss += loss.item()
        iter_count += targets.size(0)
        loss.backward()
        optimizer.step()
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if it % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / (iter_count * targets.size(0))
            loss_sum, iter_count = 0, 0
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))

            if writer is not None:
                writer.add_scalar('train_loss_batch', loss_avg, it)
                writer.add_scalar('param_norm_batch', pnorm, it)
                writer.add_scalar('gradient_norm_batch', gnorm, it)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}_batch', lr, it)

    return it, it * targets.size(0), loss_avg, epoch_loss
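
Examples #17 and #18 step the scheduler inside the batch loop, but only when it is chemprop's NoamLR warmup scheduler; other schedulers are stepped once per epoch. A sketch of that per-iteration vs. per-epoch dispatch using stock PyTorch schedulers in place of the project-specific NoamLR:

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import OneCycleLR, ExponentialLR

model = nn.Linear(8, 1)
optimizer = SGD(model.parameters(), lr=0.1)
num_epochs, batches_per_epoch = 3, 10

# OneCycleLR is a per-iteration scheduler (like NoamLR above); ExponentialLR is per-epoch.
scheduler = OneCycleLR(optimizer, max_lr=0.1, total_steps=num_epochs * batches_per_epoch)

for epoch in range(num_epochs):
    for _ in range(batches_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        if isinstance(scheduler, OneCycleLR):   # per-iteration schedulers step every batch
            scheduler.step()
    if isinstance(scheduler, ExponentialLR):    # per-epoch schedulers step once per epoch
        scheduler.step()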
Example #19
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch else len(
            data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(
                    train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(
            ), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(
                )
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(
            ), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(
                    target_batch['features']
                ) if target_batch[
                    'features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:,
                                                             task_num].long()])
                class_weights = torch.stack(
                    class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(
                        preds[:, task, :], targets[:, task]
                    ) * class_weights[:,
                                      task] * mask[:,
                                                   task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(
            ), task_test_data.features(), [
                t[task_idx] for t in task_test_data.targets()
            ]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(
                smiles_batch
            )  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr)
                                for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(
                loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug(
                    "D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(
                        d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(
            0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
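
Example #19's MAML branch builds the adapted parameters theta_prime by hand and re-instantiates the model with them via build_model(params=...). A compact sketch of the same inner/outer-loop idea using torch.func.functional_call (PyTorch >= 2.0); the data, the MSE loss, and the use of create_graph=True (which the snippet above omits) are assumptions of this sketch:

import torch
from torch import nn
from torch.func import functional_call

model = nn.Linear(4, 1)
inner_lr = 0.1

# support (inner) batch
x_s, y_s = torch.randn(8, 4), torch.randn(8, 1)
loss_s = nn.functional.mse_loss(model(x_s), y_s)

# one inner gradient step: theta' = theta - lr * grad
params = dict(model.named_parameters())
grads = torch.autograd.grad(loss_s, list(params.values()), create_graph=True)
theta_prime = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# query (outer) batch evaluated with the adapted parameters
x_q, y_q = torch.randn(8, 4), torch.randn(8, 1)
loss_q = nn.functional.mse_loss(functional_call(model, theta_prime, (x_q,)), y_q)
loss_q.backward()          # gradients flow back to the original parameters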
Example #20
0
def train(model: nn.Module,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          gp_switch: bool = False,
          likelihood = None,
          bbp_switch = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data_loader: A MoleculeDataLoader.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param gp_switch: If True, minimize the negative of ``loss_func`` (ELBO-style objective).
    :param likelihood: Optional likelihood module; if provided, it is also set to train mode.
    :param bbp_switch: Selects the Bayesian loss variant (None: standard loss; 1/2: BBP
        without/with sampling; 3/4: DUN without/with sampling).
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    if likelihood is not None:
        likelihood.train()
    
    loss_sum = 0
    if bbp_switch is not None:
        data_loss_sum = 0
        kl_loss_sum = 0
        kl_loss_depth_sum = 0

    #for batch in tqdm(data_loader, total=len(data_loader)):
    for batch in data_loader:
        # Prepare batch
        batch: MoleculeDataset

        # .batch_graph() returns BatchMolGraph
        # .features() returns None if no additional features
        # .targets() returns list of lists of floats containing the targets
        mol_batch, features_batch, target_batch = batch.batch_graph(), batch.features(), batch.targets()
        
        # mask is 1 where targets are not None
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # where targets are None, replace with 0
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Move tensors to correct device
        mask = mask.to(args.device)
        targets = targets.to(args.device)
        class_weights = torch.ones(targets.shape, device=args.device)
        
        # zero gradients
        model.zero_grad()
        optimizer.zero_grad()

        ##### FORWARD PASS AND LOSS COMPUTATION #####

        if bbp_switch is None:
        
            # forward pass
            preds = model(mol_batch, features_batch)
    
            # compute loss
            if gp_switch:
                loss = -loss_func(preds, targets)
            else:
                loss = loss_func(preds, targets, torch.exp(model.log_noise))
        
        
        ### bbp non sample option
        if bbp_switch == 1:    
            preds, kl_loss = model(mol_batch, features_batch, sample = False)
            data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
            kl_loss /= args.train_data_size
            loss = data_loss + kl_loss  
            
        ### bbp sample option
        if bbp_switch == 2:

            if args.samples_bbp == 1:
                preds, kl_loss = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
                kl_loss /= args.train_data_size
        
            elif args.samples_bbp > 1:
                data_loss_cum = 0
                kl_loss_cum = 0
        
                for i in range(args.samples_bbp):
                    preds, kl_loss_i = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds, targets, torch.exp(model.log_noise))                    
                    kl_loss_i /= args.train_data_size                    
                    
                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_bbp
                kl_loss = kl_loss_cum / args.samples_bbp
            
            loss = data_loss + kl_loss

        ### DUN non sample option
        if bbp_switch == 3:
            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))    
            _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=False)
            data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
            kl_loss /= args.train_data_size
            kl_loss_depth /= args.train_data_size
            loss = data_loss + kl_loss + kl_loss_depth
            #print('-----')
            #print(data_loss)
            #print(kl_loss)
            #print(cat)

        ### DUN sample option
        if bbp_switch == 4:

            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))

            if args.samples_dun == 1:
                _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                kl_loss /= args.train_data_size
                kl_loss_depth /= args.train_data_size
        
            elif args.samples_dun > 1:
                data_loss_cum = 0
                kl_loss_cum = 0

                for i in range(args.samples_dun):
                    _, preds_list, kl_loss_i, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                    kl_loss_i /= args.train_data_size                    
                    kl_loss_depth /= args.train_data_size

                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_dun
                kl_loss = kl_loss_cum / args.samples_dun
            
            loss = data_loss + kl_loss + kl_loss_depth

            #print('-----')
            #print(data_loss)
            #print(kl_loss)
            #print(cat)
            
        #############################################
        
        
        # backward pass; update weights
        loss.backward()
        optimizer.step()
        
        
        #for name, parameter in model.named_parameters():
            #print(name)#, parameter.grad)
            #print(np.sum(np.array(parameter.grad)))

        # add to loss_sum and iter_count
        loss_sum += loss.item() * len(batch)
        if bbp_switch is not None:
            data_loss_sum += data_loss.item() * len(batch)
            kl_loss_sum += kl_loss.item() * len(batch)
            if bbp_switch > 2:
                kl_loss_depth_sum += kl_loss_depth * len(batch)

        # update learning rate by taking a step
        if isinstance(scheduler, NoamLR) or isinstance(scheduler, OneCycleLR):
            scheduler.step()

        # increment n_iter (total number of examples across epochs)
        n_iter += len(batch)

        ########### per epoch REPORTING
        if n_iter % args.train_data_size == 0:
            lrs = scheduler.get_last_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            
            loss_avg = loss_sum / args.train_data_size

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')
            
            if bbp_switch is not None:
                data_loss_avg = data_loss_sum / args.train_data_size
                kl_loss_avg = kl_loss_sum / args.train_data_size
                wandb.log({"Total loss": loss_avg}, commit=False)
                wandb.log({"Likelihood cost": data_loss_avg}, commit=False)
                wandb.log({"KL cost": kl_loss_avg}, commit=False)

                if bbp_switch > 2:
                    kl_loss_depth_avg = kl_loss_depth_sum / args.train_data_size
                    wandb.log({"KL cost DEPTH": kl_loss_depth_avg}, commit=False)

                    # log variational categorical distribution
                    wandb.log({"d_1": cat.detach().cpu().numpy()[0]}, commit=False)
                    wandb.log({"d_2": cat.detach().cpu().numpy()[1]}, commit=False)
                    wandb.log({"d_3": cat.detach().cpu().numpy()[2]}, commit=False)
                    wandb.log({"d_4": cat.detach().cpu().numpy()[3]}, commit=False)
                    wandb.log({"d_5": cat.detach().cpu().numpy()[4]}, commit=False)

            else:
                if gp_switch:
                    wandb.log({"Negative ELBO": loss_avg}, commit=False)
                else:
                    wandb.log({"Negative log likelihood (scaled)": loss_avg}, commit=False)
            
            if args.pdts:
                wandb.log({"Learning rate": lrs[0]}, commit=True)
            else:
                wandb.log({"Learning rate": lrs[0]}, commit=False)
            
    if args.pdts and args.swag:
        return loss_avg, n_iter
    else:
        return n_iter
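
In Example #20 the Bayes-by-Backprop branches minimize an ELBO estimate: a data term averaged over samples_bbp stochastic forward passes plus a KL term scaled by 1/train_data_size. A schematic sketch of that objective; the NoisyRegressor stand-in, its toy KL surrogate, and the MSE data term are assumptions, not the project's model or loss:

import torch
from torch import nn

class NoisyRegressor(nn.Module):
    """Toy stand-in for a Bayesian model that returns (preds, kl)."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, sample: bool = True):
        preds = self.linear(x)
        if sample:                                   # crude stand-in for weight sampling
            preds = preds + 0.01 * torch.randn_like(preds)
        kl = 0.5 * sum((p ** 2).sum() for p in self.parameters())  # toy KL surrogate
        return preds, kl

model = NoisyRegressor()
x, targets = torch.randn(16, 4), torch.randn(16, 1)
train_data_size, n_samples = 1000, 4

data_loss_cum, kl_cum = 0.0, 0.0
for _ in range(n_samples):                           # average over stochastic forward passes
    preds, kl = model(x, sample=True)
    data_loss_cum += nn.functional.mse_loss(preds, targets)
    kl_cum += kl / train_data_size                   # KL is scaled by the dataset size
loss = data_loss_cum / n_samples + kl_cum / n_samples
loss.backward()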
Example #21
0
def train(module: nn.Module,
          num_classes: int,
          epochs: int,
          optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train',
          start_epoch: int = 0,
          resume: int = 0,
          validate_interval: int = 10,
          save: bool = False,
          amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None,
          file_path: str = None,
          folder_path: str = None,
          suffix: str = None,
          writer=None,
          main_tag: str = 'train',
          tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True,
          output_freq: str = 'iter',
          indent: int = 0,
          change_train_eval: bool = True,
          lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model"""
    if epochs <= 0:
        return
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs: F.cross_entropy(
        _output if _output is not None else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None,
                                           tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose,
                                           indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(header), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=print_prefix,
                                    tqdm_header='Epoch',
                                    indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     _epoch=_epoch,
                     epochs=epochs,
                     start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train,
                                            header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()

            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)

            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose,
                                          indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f}    '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path,
                            folder_path=folder_path,
                            suffix=suffix,
                            verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
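
Example #21's amp branch follows the canonical GradScaler recipe: scale the loss for backward, unscale before gradient clipping, then scaler.step / scaler.update. A self-contained sketch of just that recipe (the model, data, and clipping threshold are placeholders; enabled=False makes it a no-op on CPU):

import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
grad_clip = 5.0

for _ in range(3):                                    # a few dummy iterations
    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 10, (8,), device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                        # unscale before clipping the gradients
    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()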
Example #22
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          metric_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param metric_func: Metric function used to evaluate predictions during training.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, metric_sum, iter_count = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                       [0]*(len(args.atom_targets) + len(args.bond_targets)), 0

    loss_weights = args.loss_weights

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        #mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])

        # FIXME assign 0 to None in target
        # targets = [[0 if x is None else x for x in tb] for tb in target_batch]

        targets = [torch.Tensor(np.concatenate(x)) for x in zip(*target_batch)]
        if next(model.parameters()).is_cuda:
            #   mask, targets = mask.cuda(), targets.cuda()
            targets = [x.cuda() for x in targets]
        # FIXME
        #class_weights = torch.ones(targets.shape)

        #if args.cuda:
        #   class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        targets = [x.reshape([-1, 1]) for x in targets]

        #FIXME mutlticlass
        '''
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        '''
        loss_multi_task = []
        metric_multi_task = []
        for target, pred, lw in zip(targets, preds, loss_weights):
            loss = loss_func(pred, target)
            loss = loss.sum() / target.shape[0]
            loss_multi_task.append(loss * lw)
            if args.cuda:
                metric = metric_func(pred.data.cpu().numpy(),
                                     target.data.cpu().numpy())
            else:
                metric = metric_func(pred.data.numpy(), target.data.numpy())
            metric_multi_task.append(metric)

        loss_sum = [x + y for x, y in zip(loss_sum, loss_multi_task)]
        iter_count += 1

        sum(loss_multi_task).backward()
        optimizer.step()

        metric_sum = [x + y for x, y in zip(metric_sum, metric_multi_task)]

        if isinstance(scheduler, NoamLR) or isinstance(scheduler, SinexpLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = [x / iter_count for x in loss_sum]
            metric_avg = [x / iter_count for x in metric_sum]
            loss_sum, iter_count, metric_sum = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                               0, \
                                               [0]*(len(args.atom_targets) + len(args.bond_targets))

            loss_str = ', '.join(f'lss_{i} = {lss:.4e}'
                                 for i, lss in enumerate(loss_avg))
            metric_str = ', '.join(f'mc_{i} = {mc:.4e}'
                                   for i, mc in enumerate(metric_avg))
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'{loss_str}, {metric_str}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                for i, lss in enumerate(loss_avg):
                    writer.add_scalar(f'train_loss_{i}', lss, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
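
The atom/bond variant above combines one loss per target type: each task loss is normalized by its own number of targets, weighted by args.loss_weights, and only the weighted sum is backpropagated. A stripped-down sketch with made-up tensor sizes and plain MSE in place of the project's loss function:

import torch
from torch import nn

loss_func = nn.MSELoss(reduction='sum')
loss_weights = [1.0, 0.5]                             # one weight per task (illustrative)

# per-task predictions/targets of different lengths (e.g. atom vs. bond targets)
preds = [torch.randn(30, 1, requires_grad=True), torch.randn(50, 1, requires_grad=True)]
targets = [torch.randn(30, 1), torch.randn(50, 1)]

loss_multi_task = []
for target, pred, lw in zip(targets, preds, loss_weights):
    loss = loss_func(pred, target) / target.shape[0]  # normalize by the per-task count
    loss_multi_task.append(loss * lw)

total = sum(loss_multi_task)                          # the weighted sum drives one backward pass
total.backward()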
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data = deepcopy(data)

    data.shuffle()

    if args.uncertainty == 'bootstrap':
        data.sample(int(4 * len(data) / args.ensemble_size))

    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if model.uncertainty:
            pred_targets = preds[:, 0::2]
            pred_var = preds[:, 1::2]
            loss = loss_func(pred_targets, pred_var, targets)
            # sigma = ((pred_targets - targets) ** 2).detach()
            # loss = loss_func(pred_targets, targets) * class_weights * mask
            # loss += nn.MSELoss(reduction='none')(pred_sigma, sigma) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
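
When model.uncertainty is set, the snippets above interleave predicted means and variances in the output columns and feed them to a heteroscedastic loss. A minimal sketch of that column splitting with torch's built-in nn.GaussianNLLLoss in place of the project's loss_func; the softplus on the variance columns is an assumption to keep them positive:

import torch
from torch import nn

n_tasks = 3
preds = torch.randn(8, 2 * n_tasks, requires_grad=True)  # stand-in for the model output

pred_means = preds[:, 0::2]                               # even columns: predicted means
pred_vars = nn.functional.softplus(preds[:, 1::2])        # odd columns: predicted variances (> 0)

targets = torch.randn(8, n_tasks)
loss = nn.GaussianNLLLoss(reduction='mean')(pred_means, targets, pred_vars)
loss.backward()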
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        #print('class_weight',class_weights.size(),class_weights)
        #print('mask',mask.size(),mask)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        ############ add L1 regularization ############

        ffn_d0_L1_reg_loss = 0
        ffn_d1_L1_reg_loss = 0
        ffn_d2_L1_reg_loss = 0
        ffn_final_L1_reg_loss = 0
        ffn_mol_L1_reg_loss = 0

        lamda_ffn_d0 = 0
        lamda_ffn_d1 = 0
        lamda_ffn_d2 = 0
        lamda_ffn_final = 0
        lamda_ffn_mol = 0

        for param in model.ffn_d0.parameters():
            ffn_d0_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_final.parameters():
            ffn_final_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L1_reg_loss += torch.sum(torch.abs(param))

        loss += lamda_ffn_d0 * ffn_d0_L1_reg_loss + lamda_ffn_d1 * ffn_d1_L1_reg_loss + lamda_ffn_d2 * ffn_d2_L1_reg_loss + lamda_ffn_final * ffn_final_L1_reg_loss + lamda_ffn_mol * ffn_mol_L1_reg_loss

        ############ add L1 regularization ############

        ############ add L2 regularization ############
        '''
        ffn_d0_L2_reg_loss = 0
        ffn_d1_L2_reg_loss = 0
        ffn_d2_L2_reg_loss = 0
        ffn_final_L2_reg_loss = 0
        ffn_mol_L2_reg_loss = 0

        lamda_ffn_d0 = 1e-6
        lamda_ffn_d1 = 1e-6
        lamda_ffn_d2 = 1e-5
        lamda_ffn_final = 1e-4
        lamda_ffn_mol = 1e-3 

        for param in model.ffn_d0.parameters():
            ffn_d0_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_final.parameters():
            ffn_final_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L2_reg_loss += torch.sum(torch.square(param))

        loss += lamda_ffn_d0 * ffn_d0_L2_reg_loss + lamda_ffn_d1 * ffn_d1_L2_reg_loss + lamda_ffn_d2 * ffn_d2_L2_reg_loss + lamda_ffn_final * ffn_final_L2_reg_loss + lamda_ffn_mol * ffn_mol_L2_reg_loss
        '''
        ############ add L2 regularization ############

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        #loss.backward(retain_graph=True)  # wei, retain_graph=True
        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)
    #print(model)
    return n_iter
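
The last variant above adds (currently zero-weighted) L1 penalties over selected FFN sub-modules to the task loss. A generic sketch of adding an L1 penalty to a training loss; the toy model and the coefficient are placeholders:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
l1_lambda = 1e-5                                      # illustrative strength; 0 disables the penalty

x, y = torch.randn(32, 8), torch.randn(32, 1)
data_loss = nn.functional.mse_loss(model(x), y)

# L1 penalty summed over the parameters of the module(s) to sparsify
l1_penalty = sum(p.abs().sum() for p in model.parameters())
loss = data_loss + l1_lambda * l1_penalty
loss.backward()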
def train_epoch(epoch, epochs_count, model: nn.Module, device: torch.device,
                optimizer: optim.Optimizer, scaler: Optional[GradScaler],
                criterion_list: List[nn.Module], train_dataloader: DataLoader,
                cfg, train_metrics: Dict, progress_bar: tqdm, run_folder,
                limit_samples, fold_index, folds_count,
                scheduler: _LRScheduler, schedule_lr_each_step):
    epoch_metrics = {'loss': []}
    epoch_metrics.update({metric: [] for metric in cfg['metrics']})

    grad_acc_iters = max(cfg['train_params'].get('grad_acc_iters', 1), 1)
    grad_clipping = cfg['train_params'].get('grad_clipping', 0)
    pseudo_label = cfg['train_params'].get('pseudo_label', False)
    use_fp = cfg['train_data_loader'].get('use_fp', False)
    aux_weights = cfg['train_params'].get('aux_weights', None)

    batch_size = cfg['train_data_loader']['batch_size']
    epoch_samples = len(train_dataloader) * batch_size
    seen_samples = 0

    model.train()

    # reset CustomMetrics
    for metric in criterion_list:  # type:CustomMetrics
        if isinstance(metric, CustomMetrics):
            metric.reset()

    optimizer.zero_grad()
    is_optimizer_update_finished = True

    for iteration_id, data in enumerate(train_dataloader):
        if limit_samples is not None and seen_samples >= limit_samples:  # DEBUG
            break

        data_img, data_class, data_record_ids = data
        batch_size = len(data_img)

        if progress_bar.n == 0:
            # reset the start time so DataLoader worker start-up is not counted
            progress_bar.reset()

        if scaler is None:
            # plain fp32 path; mirrors the autocast branch below
            outputs, loss, metrics = forward_pass(data_img, data_class, model, device, criterion_list, pseudo_label,
                                                  use_fp, aux_weights=aux_weights)
            if grad_acc_iters > 1:
                loss = loss / grad_acc_iters
            loss.backward()
        else:
            with autocast():
                outputs, loss, metrics = forward_pass(data_img,
                                                      data_class,
                                                      model,
                                                      device,
                                                      criterion_list,
                                                      pseudo_label,
                                                      use_fp,
                                                      aux_weights=aux_weights)
                if grad_acc_iters > 1:
                    loss = loss / grad_acc_iters
            scaler.scale(loss).backward()
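            # GradScaler scales the loss so small fp16 gradients do not underflow;
            # scaler.step() unscales the gradients again before the update.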

        is_optimizer_update_finished = False

        if grad_acc_iters <= 1 or (iteration_id + 1) % grad_acc_iters == 0:
            if scaler is None:
                optimizer.step()
            else:
                if grad_clipping > 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   grad_clipping)

                scaler.step(optimizer)
                scaler.update()

            optimizer.zero_grad()
            is_optimizer_update_finished = True

        # Metrics and progress
        seen_samples += batch_size

        collect_metrics(cfg, epoch_metrics, loss * grad_acc_iters, metrics,
                        criterion_list, '')
        lr = [group['lr'] for group in optimizer.param_groups][0]
        print_info(progress_bar, 'Train', epoch_metrics, loss * grad_acc_iters,
                   lr, '', batch_size, epoch, epochs_count, seen_samples,
                   epoch_samples, fold_index, folds_count)

        if (iteration_id + 1) % 5 == 0 and iteration_id + 1 < len(train_dataloader):
            log_dict = {key: value[-1] for key, value in epoch_metrics.items()}
            if schedule_lr_each_step and scheduler is not None:
                log_dict['lr'] = optimizer.param_groups[0]['lr']
            wandb.log(log_dict,
                      step=(epoch - 1) * epoch_samples + seen_samples,
                      commit=True)

        # Step lr scheduler
        if schedule_lr_each_step and scheduler is not None:
            scheduler.step()

    # Finish optimizer step after the not completed gradient accumulation batch
    if not is_optimizer_update_finished:
        if scaler is None:
            optimizer.step()
        else:
            if grad_clipping > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(optimizer)
                # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               grad_clipping)

            scaler.step(optimizer)
            scaler.update()

    for idx, item in enumerate(epoch_metrics.items()):
        key, value = item
        if isinstance(criterion_list[idx], CustomMetrics):
            value = [criterion_list[idx].compute()]
        train_metrics[key].append(np.mean(value))

    progress_bar.write('')
Example #26
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          n_iter: int = 0,
          logger: logging.Logger = None) -> int:
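    """Train for one epoch over ``data`` and return the updated iteration count.

    Relies on an externally defined ``config`` object for batch size, device,
    verbosity and dataset-type settings.
    """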

    model.train()
    data.shuffle()
    loss_sum, iter_count = 0, 0
    # don't use the last batch if it's small, for stability
    num_iters = len(data) // config.batch_size * config.batch_size
    iter_size = config.batch_size

    if config.verbose:
        generator = trange
    else:
        generator = range

    for i in generator(0, num_iters, iter_size):
        # Prepare batch
        if i + config.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + config.batch_size])
        smiles_batch, target_batch = mol_batch.smiles(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
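        # The mask marks which targets are present; missing targets are filled
        # with 0 and later excluded from the loss through the mask.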

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if config.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch)

        if config.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

    return n_iter
Example #27
0
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = (
            batch.batch_graph(), batch.features(), batch.targets())
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)
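        # class_weights is all ones here, so it leaves the loss unchanged
        # (presumably a placeholder for class-weighted training).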

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()
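        # Average over the observed (non-missing) targets only.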

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
Example #28
0
    def fit(self,
            train_data_loader: DataLoader,
            validation_data_loader: DataLoader,
            loss_function: _Loss,
            optimizer: Optimizer,
            lr_scheduler: _LRScheduler,
            num_epochs: int,
            hooks: List[Callable[[], None]] = None):
        """
        Runs the experiment with the given parameters and returns the trained forecaster.
        Train and validation losses are stored in ``self.train_loss`` and ``self.validation_loss``.
        :param train_data_loader: Train data.
        :param validation_data_loader: Validation data.
        :param loss_function: Loss functions which will be used along side the optimizer.
        :param optimizer: Optimizer.
        :param lr_scheduler: Learning rate scheduler for the optimizer.
        :param num_epochs: How many epochs to train the model.
        :param hooks: Functions to invoke at the start of each training step.
        :return: The trained forecaster (``self``).
        """
        self.to(device=self.device)
        loss_function = loss_function.to(device=self.device)

        self.train_loss = np.empty(num_epochs)
        self.validation_loss = np.empty(num_epochs)
        # self.validation_forecast = np.empty(len(validation_data_loader.dataset))

        states_dict = None

        avg_validate_loss = float("inf")
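        # Lowest validation loss seen so far; used to checkpoint the best weights.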

        for epoch in range(num_epochs):
            # training
            self.train()
            epoch_train_loss = 0
            for seq, labels in train_data_loader:
                if hooks is not None:
                    for hook in hooks:
                        hook()
                optimizer.zero_grad()

                seq, labels = seq.to(self.device), labels.to(self.device)
                y_pred = self(seq)

                single_loss = loss_function(y_pred.squeeze(), labels.squeeze())
                epoch_train_loss += single_loss.item()
                single_loss.backward()
                optimizer.step()

            # validation
            self.eval()
            epoch_validate_loss = 0
            with torch.no_grad():
                for i, (seq, labels) in enumerate(validation_data_loader):
                    seq, labels = seq.to(self.device), labels.to(self.device)
                    y_pred = self(seq)

                    single_loss = loss_function(y_pred.squeeze(), labels.squeeze())
                    epoch_validate_loss += single_loss.item()

                    # if epoch == num_epochs - 1:
                    #     for j in range(len(seq)):
                    #         self.validation_forecast[i*len(seq) + j] = y_pred[j].item()

            avg_train_loss = epoch_train_loss / len(train_data_loader.dataset)

            new_avg_validate_loss = epoch_validate_loss / len(validation_data_loader.dataset)
            if new_avg_validate_loss < avg_validate_loss:
                avg_validate_loss = new_avg_validate_loss
                # clone the tensors so later updates do not overwrite the best checkpoint
                states_dict = {k: v.detach().clone()
                               for k, v in self.state_dict().items()}

            self.train_loss[epoch] = avg_train_loss
            self.validation_loss[epoch] = new_avg_validate_loss

            self.logger.info(f'Epoch {epoch + 1}/{num_epochs}')
            self.logger.info(f'Avg train loss: {avg_train_loss :10.8f}')
            self.logger.info(f'Avg validation loss: {new_avg_validate_loss :10.8f}')

            lr_scheduler.step(epoch=epoch)
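            # Passing epoch to step() is deprecated in newer PyTorch releases;
            # a plain lr_scheduler.step() call is normally sufficient.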

        self.trained = True
        self.load_state_dict(states_dict)
        return self