示例#1
0
    def train(
        self,
        problem: Problem,
        startEpoch: int,
        nEpochs: int,
        batchSize: int,
        scheduler: _LRScheduler,
    ) -> Iterator[FractionalPerformanceSummary]:
        """Lazily run training epochs, yielding one summary per finished epoch.

        ``self.cur_epoch`` is set to ``startEpoch`` and incremented *before*
        each pass, so the first epoch that runs is ``startEpoch + 1`` and the
        loop ends once ``self.cur_epoch`` is no longer below ``nEpochs``.

        Args:
            problem: Problem description; one of its ``datasets`` must have
                ``data_type == Split.TRAIN``.
            startEpoch: Value the epoch counter starts from.
            nEpochs: Upper bound for the epoch counter.
            batchSize: Forwarded to ``self._get_loaders``.
            scheduler: Learning-rate scheduler, stepped once after every epoch.

        Yields:
            The result of ``self._get_worker_performance_summary`` for the
            statistics of each epoch.
        """
        self.cur_epoch = startEpoch

        dataset_splits = [dataset.data_type for dataset in problem.datasets]
        assert Split.TRAIN in dataset_splits, "training dataset should be included"

        logger.info("Model layers")
        logger.info(str(self.model.modules))

        # Loaders are built once and reused for every epoch.
        epoch_loaders = self._get_loaders(problem, batchSize=batchSize)

        while self.cur_epoch < nEpochs:
            self.cur_epoch += 1

            logger.info("Starting epoch %d" % self.cur_epoch)
            stats = self._pass_one_epoch(problem, epoch_loaders, Mode.TRAIN)
            logger.info("Finished epoch %d" % self.cur_epoch)

            scheduler.step()

            summary = self._get_worker_performance_summary(stats)
            yield summary
示例#2
0
    def fit_one(
        self,
        model,
        tasks: List[Task],
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        training_logger: ResultLogger,
        train_model=True,
    ) -> float:
        """Run one pass over ``dataloader`` and return the mean output loss.

        Args:
            model: Model passed to ``self._step``; its mode is set with
                ``model.train(train_model)``.
            tasks: Tasks passed to ``self._step``; each is switched to
                training mode via ``task.train()``.
            dataloader: Iterable of ``(x, y)`` batches that supports ``len``.
            optimizer: Optimizer forwarded to ``self._step``.
            scheduler: Stepped exactly once, after the whole pass.
            training_logger: Receives the outputs of every batch through
                ``training_logger.log``.
            train_model: Flag forwarded to ``model.train``.

        Returns:
            The average of ``o.loss.item()`` over every output of every
            batch, or ``0.0`` when no output was produced (for example an
            empty dataloader) instead of raising ``ZeroDivisionError``.
        """
        model.train(train_model)

        for task in tasks:
            task.train()

        task_names = [t.name for t in tasks]

        loss_sum = 0.0
        output_count = 0
        for i, (x, y) in enumerate(dataloader):
            outputs = self._step(model, tasks, x, y, optimizer)

            # Batch index is reported 1-based.
            training_logger.log(outputs, task_names, i + 1, len(dataloader))

            for o in outputs:
                loss_sum += o.loss.item()
                output_count += 1

        scheduler.step()

        if output_count == 0:
            # Nothing was trained on; avoid dividing by zero.
            return 0.0
        return loss_sum / output_count
 def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
     """Restore this object (and optionally optimizer/scheduler) from a checkpoint.

     The checkpoint must contain the keys ``'state_dict'`` and ``'step'``;
     ``'optimizer_state_dict'`` / ``'scheduler_state_dict'`` are only read when
     the corresponding argument is given.

     Args:
         path_to_checkpoint: File readable by ``torch.load``.
         optimizer: If not ``None``, receives ``checkpoint['optimizer_state_dict']``.
         scheduler: If not ``None``, receives ``checkpoint['scheduler_state_dict']``.

     Returns:
         The value stored under ``checkpoint['step']`` (not the model itself).
     """
     # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
     checkpoint = torch.load(path_to_checkpoint)
     self.load_state_dict(checkpoint['state_dict'])
     step = checkpoint['step']

     if optimizer is not None:
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
     if scheduler is not None:
         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

     return step
示例#4
0
    def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer, num_steps: int, loss_function: Callable, lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]:
        """Optimise ``latents`` so the generated image matches ``images``.

        Each step generates an image from ``latents``, average-pools it when
        its height exceeds 256, evaluates ``loss_function`` against ``images``
        and takes one optimizer step. The per-step loss dict (extended with
        ``'psnr'``, ``'lr'`` and ``'iteration'``) is appended to ``self.log``.

        Args:
            latents: Object exposing ``latent`` and ``noise`` that
                ``self.generate`` consumes.
            images: Target images handed to ``loss_function`` and ``self.psnr``.
            optimizer: Optimizer stepped once per iteration.
            num_steps: Maximum number of iterations.
            loss_function: Callable returning ``(loss, loss_dict)``.
            lr_scheduler: Optional scheduler stepped once per iteration.

        Returns:
            ``LatentPaths`` holding the recorded latent/noise snapshots (every
            ``self.debug_step`` iterations plus the final state) and
            ``Latents`` holding the snapshot taken at the highest PSNR.
        """

        def latent_snapshot():
            # Detached CPU copy so later optimizer steps cannot alter it.
            return latents.latent.detach().clone().cpu()

        def noise_snapshot():
            return [noise.detach().clone().cpu() for noise in latents.noise]

        progress = tqdm(range(num_steps), leave=False)
        latent_history = []
        noise_history = []

        best_latent = None
        best_noise = None
        best_psnr = None

        for step in progress:
            generated, _ = self.generate(latents)

            batch, channel, height, width = generated.shape

            if height > 256:
                # Average non-overlapping factor x factor patches.
                factor = height // 256
                patches = generated.reshape(
                    batch, channel, height // factor, factor, width // factor, factor
                )
                generated = patches.mean([3, 5])

            loss, loss_dict = loss_function(generated, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_dict['psnr'] = self.psnr(generated, images).item()
            loss_dict['lr'] = optimizer.param_groups[0]["lr"]

            if lr_scheduler is not None:
                lr_scheduler.step()

            self.log.append(loss_dict)

            # NOTE(review): the PSNR comes from the image generated before
            # optimizer.step(), while the snapshot below is taken after it.
            current_psnr = loss_dict['psnr']
            if best_psnr is None or best_psnr < current_psnr:
                best_psnr = current_psnr
                best_latent = latent_snapshot()
                best_noise = noise_snapshot()

            if step % self.debug_step == 0:
                latent_history.append(latent_snapshot())
                noise_history.append(noise_snapshot())

            progress.set_description(
                "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items())
            )

            # The same dict object is already in self.log, so the log entry
            # gains this key as well.
            loss_dict['iteration'] = step
            if self.abort_condition is not None and self.abort_condition(loss_dict):
                break

        # Always record the final state, whether or not the loop aborted.
        latent_history.append(latent_snapshot())
        noise_history.append(noise_snapshot())

        return LatentPaths(latent_history, noise_history), Latents(best_latent, best_noise)
示例#5
0
def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """
    Collect the learning rates a scheduler produces over ``steps`` steps.

    For each step the current rate is read with ``get_last_lr`` and the
    scheduler is then advanced once. The scheduler must report exactly one
    learning rate; this is asserted on every read.
    """

    def read_then_advance() -> float:
        current = scheduler.get_last_lr()  # type: ignore
        assert isinstance(current, list)
        assert len(current) == 1
        scheduler.step()
        return current[0]

    return [read_then_advance() for _ in range(steps)]
示例#6
0
def train_step(net: nn.Module,
               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
               inputs: torch.Tensor, targets: torch.Tensor, grad_clip:float) -> Tuple[torch.Tensor, float]:
    """Perform a single optimisation step on one batch.

    Runs the forward pass, back-propagates the criterion, clips the gradient
    norm of ``net``'s parameters to ``grad_clip`` and steps the optimizer.
    The scheduler is stepped here only when one is supplied and it is not
    configured to be stepped per epoch (``sched_on_epoch`` is false).

    Returns:
        A tuple of the network outputs and the loss as a Python float.
    """
    predictions = net(inputs)
    batch_loss = crit(predictions, targets)

    optim.zero_grad()
    batch_loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
    optim.step()

    if sched:
        if not sched_on_epoch:
            sched.step()

    loss_value = batch_loss.item()
    return predictions, loss_value
示例#7
0
    def simulate_values(  # type: ignore[override]
            cls, num_events: int, lr_scheduler: _LRScheduler,
            **kwargs: Any) -> List[List[int]]:
        """Method to simulate scheduled values during num_events events.

        The state of ``lr_scheduler`` and of its optimizer is cached to a
        temporary file before the simulation and loaded back afterwards, so
        the passed scheduler ends up in its initial state even if the
        simulation raises.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.
            **kwargs: extra keyword arguments forwarded to ``cls``.

        Returns:
            list of pairs: [event_index, value] (one value per optimizer
            param group follows the index).

        Raises:
            TypeError: if ``lr_scheduler`` is not a torch LR scheduler.
        """

        # torch >= 2.0 exposes the public base class ``LRScheduler``; the
        # built-in schedulers derive from it rather than from the private
        # ``_LRScheduler`` subclass, so accept either one.
        public_base = getattr(torch.optim.lr_scheduler, "LRScheduler",
                              _LRScheduler)
        if not isinstance(lr_scheduler, (public_base, _LRScheduler)):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                f"but given {type(lr_scheduler)}")

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(
                ),  # type: ignore[attr-defined]
            }
            torch.save(obj, cache_filepath.as_posix())

            values = []
            try:
                scheduler = cls(save_history=False,
                                lr_scheduler=lr_scheduler,
                                **kwargs)  # type: ignore[call-arg]
                for i in range(num_events):
                    params = [
                        p[scheduler.param_name]
                        for p in scheduler.optimizer_param_groups
                    ]
                    values.append([i] + params)
                    scheduler(engine=None)
            finally:
                # Restore the caller's scheduler/optimizer on success and on
                # failure alike.
                obj = torch.load(cache_filepath.as_posix())
                lr_scheduler.load_state_dict(obj["lr_scheduler"])
                lr_scheduler.optimizer.load_state_dict(
                    obj["optimizer"])  # type: ignore[attr-defined]

            return values
def train(model: nn.Module,
          loader: DataLoader,
          class_loss: nn.Module,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          epoch: int,
          callback: VisdomLogger,
          freq: int,
          ex: Experiment = None) -> None:
    """Train ``model`` for one epoch over ``loader``.

    Every batch is moved to the device of the model's first parameter, the
    mean of ``class_loss`` is back-propagated, and both the optimizer and the
    scheduler are stepped once per batch. Loss and accuracy are recorded in
    ``AverageMeter`` objects.

    Args:
        model: Network returning ``(logits, features)`` for a batch.
        loader: Yields ``(batch, labels, indices)`` triples.
        class_loss: Loss applied to ``(logits, labels)``; its mean is used.
        optimizer: Optimizer stepped after each backward pass.
        scheduler: Scheduler stepped after each optimizer step.
        epoch: Epoch number, used in the progress description and as the
            integer part of the logged step.
        callback: Optional logger; every ``freq`` batches it receives the
            meters' ``last_avg`` values.
        freq: Logging period in batches.
        ex: Optional experiment; after the epoch it receives every recorded
            loss/accuracy value through ``log_scalar``.
    """
    model.train()
    device = next(model.parameters()).device

    def to_device(tensor):
        return tensor.to(device, non_blocking=True)

    loader_length = len(loader)
    loss_meter = AverageMeter(device=device, length=loader_length)
    acc_meter = AverageMeter(device=device, length=loader_length)

    progress = tqdm(loader, ncols=80, desc=f'Training   [{epoch:03d}]')
    for batch_idx, (batch, labels, indices) in enumerate(progress):
        batch = to_device(batch)
        labels = to_device(labels)
        indices = to_device(indices)

        logits, _features = model(batch)
        loss = class_loss(logits, labels).mean()
        predictions = logits.detach().argmax(1)
        acc = (predictions == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_meter.append(loss)
        acc_meter.append(acc)

        if callback is not None and not (batch_idx + 1) % freq:
            # Fractional step: epoch plus progress within the epoch.
            step = epoch + batch_idx / loader_length
            callback.scalar('xent',
                            step,
                            loss_meter.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc',
                            step,
                            acc_meter.last_avg,
                            title='Train Acc')

    if ex is None:
        return

    recorded = zip(loss_meter.values_list, acc_meter.values_list)
    for idx, (loss, acc) in enumerate(recorded):
        step = epoch + idx / loader_length
        ex.log_scalar('train.loss', loss, step=step)
        ex.log_scalar('train.acc', acc, step=step)
示例#9
0
def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    Save the full training checkpoint (large) and the model-only state dict (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        current_epoch (int): in the epoch, model **will** be trained
        full_net_path (str): the path for saving the full model parameters
        state_net_path (str): the path for saving the state dict.
    """
    # Keys are inserted in the order in which the checkpoint is laid out.
    full_state = {"arch": exp_name, "epoch": current_epoch}
    full_state["net_state"] = model.state_dict()
    full_state["opti_state"] = optimizer.state_dict()
    full_state["sche_state"] = scheduler.state_dict()
    if amp:
        full_state["amp_state"] = amp.state_dict()
    else:
        full_state["amp_state"] = None

    torch.save(full_state, full_net_path)
    torch.save(model.state_dict(), state_net_path)
示例#10
0
 def snapshot(self,
              net: torch.nn.Module,
              opt: Optimizer,
              sched: _LRScheduler = None,
              epoch: int = None,
              subdir='.'):
     """
     Writes a snapshot of the training, i.e. network weights, optimizer state and scheduler state to a file
     in the log directory.
     :param net: the neural network
     :param opt: the optimizer used
     :param sched: the learning rate scheduler used; when omitted (None), ``None`` is stored under 'sched'
     :param epoch: the current epoch
     :param subdir: if given, creates a subdirectory in the log directory. The data is written to a file
         in this subdirectory instead.
     :return: the path of the written snapshot file
     """
     outfile = pt.join(self.dir, subdir, 'snapshot.pt')
     # exist_ok avoids the check-then-create race of a separate exists() test.
     os.makedirs(os.path.dirname(outfile), exist_ok=True)
     torch.save(
         {
             'net': net.state_dict(),
             'opt': opt.state_dict(),
             # The scheduler is optional (default None), so guard the call.
             'sched': sched.state_dict() if sched is not None else None,
             'epoch': epoch
         }, outfile)
     return outfile
示例#11
0
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    Restore a model from a saved checkpoint.

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        exp_name (str): exp_name
        load_path (str): path of the checkpoint file
        mode (str): which restore mode to use:
            - 'all': restore the full checkpoint, including the training state;
            - 'onlynet': restore only the model weights

    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    path_is_file = os.path.exists(load_path) and os.path.isfile(load_path)
    if not path_is_file:
        raise Exception(f"{load_path}路径不正常,请检查")

    construct_print(f"Loading checkpoint '{load_path}'")
    checkpoint = torch.load(load_path)

    if mode == "all":
        # The checkpoint must have been written for this experiment name.
        if exp_name != checkpoint["arch"]:
            raise Exception(f"{load_path} does not match.")

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["net_state"])
        optimizer.load_state_dict(checkpoint["opti_state"])
        scheduler.load_state_dict(checkpoint["sche_state"])
        construct_print(
            f"Loaded '{load_path}' (will train at epoch {checkpoint['epoch']})"
        )
        return start_epoch

    if mode == "onlynet":
        # Here the file holds a bare state dict rather than a full checkpoint.
        model.load_state_dict(checkpoint)
        construct_print(
            f"Loaded checkpoint '{load_path}' (only has the model's weight params)"
        )
        return None

    raise NotImplementedError
 def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
     """Write a checkpoint named ``model-{step}.pth`` and return its path.

     The file stores this object's state dict, the step number and the state
     dicts of ``optimizer`` and ``scheduler`` under the keys ``'state_dict'``,
     ``'step'``, ``'optimizer_state_dict'`` and ``'scheduler_state_dict'``.
     """
     checkpoint_file = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')

     payload = {}
     payload['state_dict'] = self.state_dict()
     payload['step'] = step
     payload['optimizer_state_dict'] = optimizer.state_dict()
     payload['scheduler_state_dict'] = scheduler.state_dict()

     torch.save(payload, checkpoint_file)
     return checkpoint_file
示例#13
0
 def _better_lr_sched_repr(lr_sched: _LRScheduler) -> str:
     """Return a multi-line description of a scheduler.

     The result is the class name followed by one ``key: value`` line for
     every ``state_dict()`` entry whose key does not start with an underscore.
     """
     class_name = lr_sched.__class__.__name__
     public_entries = [
         f"{key}: {value}"
         for key, value in lr_sched.state_dict().items()
         if not key.startswith("_")
     ]
     body = "\n    ".join(public_entries)
     return f"{class_name}(\n    {body}\n)"
示例#14
0
    def fit_support(
        self,
        model,
        tasks: List[Task],
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        training_logger: ResultLogger,
    ):
        """Fit the tasks on a support set while the model weights are frozen.

        ``self.fit_one`` is called repeatedly (with ``train_model=False``)
        until the returned loss is no longer above ``self.support_min_loss``
        or ``self.support_max_epochs`` epochs have run.

        The optimizer and scheduler states are snapshotted beforehand and
        restored afterwards, and the model weights are unfrozen again. This
        cleanup also runs when ``fit_one`` raises, so a failure does not
        leave the model frozen or the optimizer/scheduler altered.

        Args:
            model: Object providing ``freeze_weights`` / ``defreeze_weights``.
            tasks: Tasks to fit; each is ``reset()`` first.
            dataloader: Support-set batches forwarded to ``fit_one``.
            optimizer: Optimizer forwarded to ``fit_one``.
            scheduler: Scheduler forwarded to ``fit_one``.
            training_logger: Logger whose ``epoch(...)`` result is forwarded
                to ``fit_one`` for every epoch.
        """
        support_loss = 1.0
        support_epoch = 0

        # Don't change default optimizer and scheduler states
        optimizer_state_dict = deepcopy(optimizer.state_dict())
        scheduler_state_dict = deepcopy(scheduler.state_dict())

        # Reset tasks states
        for task in tasks:
            task.reset()

        model.freeze_weights()

        try:
            while (support_loss > self.support_min_loss
                   and support_epoch < self.support_max_epochs):
                support_epoch += 1
                support_loss = self.fit_one(
                    model,
                    tasks,
                    dataloader,
                    optimizer,
                    scheduler,
                    training_logger.epoch(support_epoch,
                                          self.support_max_epochs),
                    train_model=False,
                )
        finally:
            # Undo every temporary change, on success and on failure alike.
            optimizer.load_state_dict(optimizer_state_dict)
            scheduler.load_state_dict(scheduler_state_dict)
            model.defreeze_weights()
示例#15
0
    def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        r"""Assemble everything that goes into a checkpoint.

        Args:
          iteration(float or int): The current iteration.
          model(EmmentalModel): The model to checkpoint.
          optimizer(Optimizer): The optimizer used during training process.
          lr_scheduler(_LRScheduler): Learning rate scheduler; a falsy value
            is stored as ``None``.
          metric_dict(dict): the metric dict.

        Returns:
          dict: The state dict with the keys ``iteration``, ``model``,
          ``optimizer``, ``lr_scheduler`` and ``metric_dict``.
        """
        # Only the model name and the result of model.collect_state_dict()
        # are captured for the model.
        model_state = {
            "name": model.name,
            "module_pool": model.collect_state_dict(),
        }
        optimizer_state = optimizer.state_dict()

        if lr_scheduler:
            scheduler_state = lr_scheduler.state_dict()
        else:
            scheduler_state = None

        return {
            "iteration": iteration,
            "model": model_state,
            "optimizer": optimizer_state,
            "lr_scheduler": scheduler_state,
            "metric_dict": metric_dict,
        }
示例#16
0
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Missing targets (``None``) are replaced by 0 and excluded from the loss
    through a mask. The logged ``Loss`` value is the mean of the per-batch
    losses seen since the previous log line.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler. Only stepped here (once per batch) when it is a ``NoamLR``.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = batch.batch_graph(
        ), batch.features(), batch.targets()
        # mask is 1 where a target is present and 0 where it is None.
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            # Apply the loss per target column and stitch the columns back.
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        # ``loss`` is already a per-batch mean, so count batches (not
        # examples) to get a correctly scaled running average.
        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
示例#17
0
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    Restore a model from a saved checkpoint.

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        load_path (str): path of the checkpoint file
        mode (str): which restore mode to use:
            - 'all': restore the full checkpoint, including the training state;
            - 'onlynet': restore only the model weights

    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    path_is_file = os.path.exists(load_path) and os.path.isfile(load_path)
    if not path_is_file:
        raise Exception(f"{load_path}路径不正常,请检查")

    construct_print(f"Loading checkpoint '{load_path}'")
    checkpoint = torch.load(load_path)

    if mode == "all":
        # A given exp_name must match checkpoint["arch"]; an empty exp_name
        # skips the check.
        if exp_name and exp_name != checkpoint["arch"]:
            raise Exception(
                f"We can not match {exp_name} with {load_path}.")

        start_epoch = checkpoint["epoch"]

        # Wrapped models expose the real network as ``.module``; load into it.
        target_model = model.module if hasattr(model, "module") else model
        target_model.load_state_dict(checkpoint["net_state"])
        optimizer.load_state_dict(checkpoint["opti_state"])
        scheduler.load_state_dict(checkpoint["sche_state"])

        saved_amp_state = checkpoint.get("amp_state", None)
        if not saved_amp_state:
            construct_print("The state_dict of amp is None.")
        elif amp:
            amp.load_state_dict(checkpoint["amp_state"])
        else:
            construct_print("You are not using amp.")

        construct_print(
            f"Loaded '{load_path}' (will train at epoch {checkpoint['epoch']})"
        )
        return start_epoch

    if mode == "onlynet":
        # Here the file holds a bare state dict rather than a full checkpoint.
        target_model = model.module if hasattr(model, "module") else model
        target_model.load_state_dict(checkpoint)
        construct_print(
            f"Loaded checkpoint '{load_path}' (only has the model's weight params)"
        )
        return None

    raise NotImplementedError
示例#18
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          metric_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a multi-target model for an epoch.

    Each target gets its own loss (scaled by the matching entry of
    ``args.loss_weights``) and its own metric; the weighted losses are summed
    for the backward pass. Running sums are kept as plain Python numbers so
    no autograd graph is kept alive between iterations.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe). It is shuffled in place.
    :param loss_func: Loss function.
    :param metric_func: Metric called as ``metric_func(pred_array, target_array)`` on numpy arrays.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. Only stepped here (once per batch) when it is a ``NoamLR`` or ``SinexpLR``.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    n_targets = len(args.atom_targets) + len(args.bond_targets)
    loss_sum, metric_sum, iter_count = [0] * n_targets, [0] * n_targets, 0

    loss_weights = args.loss_weights

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch

        # FIXME: missing (None) targets, class weights and the multiclass
        # dataset type are not handled by this variant.
        # One flat tensor per target, concatenated over the molecules.
        targets = [torch.Tensor(np.concatenate(x)) for x in zip(*target_batch)]
        if next(model.parameters()).is_cuda:
            targets = [x.cuda() for x in targets]

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        targets = [x.reshape([-1, 1]) for x in targets]

        loss_multi_task = []
        metric_multi_task = []
        for target, pred, lw in zip(targets, preds, loss_weights):
            loss = loss_func(pred, target)
            loss = loss.sum() / target.shape[0]
            loss_multi_task.append(loss * lw)
            # .cpu() is a no-op for CPU tensors, so this is correct for any
            # device, independent of args.cuda.
            metric = metric_func(pred.detach().cpu().numpy(),
                                 target.detach().cpu().numpy())
            metric_multi_task.append(metric)

        # Accumulate plain floats: summing the loss tensors themselves would
        # keep references to every iteration's autograd graph until the next
        # log reset.
        loss_sum = [
            x + y.item() for x, y in zip(loss_sum, loss_multi_task)
        ]
        iter_count += 1

        sum(loss_multi_task).backward()
        optimizer.step()

        metric_sum = [x + y for x, y in zip(metric_sum, metric_multi_task)]

        if isinstance(scheduler, NoamLR) or isinstance(scheduler, SinexpLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = [x / iter_count for x in loss_sum]
            metric_avg = [x / iter_count for x in metric_sum]
            loss_sum, iter_count, metric_sum = [0] * n_targets, 0, [0] * n_targets

            loss_str = ', '.join(f'lss_{i} = {lss:.4e}'
                                 for i, lss in enumerate(loss_avg))
            metric_str = ', '.join(f'mc_{i} = {mc:.4e}'
                                   for i, mc in enumerate(metric_avg))
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'{loss_str}, {metric_str}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                for i, lss in enumerate(loss_avg):
                    writer.add_scalar(f'train_loss_{i}', lss, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Works on a deep copy of ``data``, so the caller's dataset is neither
    shuffled nor sub-sampled. The logged ``Loss`` value is the mean of the
    per-batch losses seen since the previous log line.

    :param model: Model. When ``model.uncertainty`` is truthy, even output columns are treated as
        predictions and odd columns as predicted variances.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function. Called as ``loss_func(pred_targets, pred_var, targets)`` in the
        uncertainty case and as ``loss_func(preds, targets)`` otherwise.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. Only stepped here (once per batch) when it is a ``NoamLR``.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data = deepcopy(data)

    data.shuffle()

    if args.uncertainty == 'bootstrap':
        data.sample(int(4 * len(data) / args.ensemble_size))

    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        # mask is 1 where a target is present and 0 where it is None.
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        # Create the weights on the targets' device so they always match,
        # even when args.cuda disagrees with where the model actually lives.
        class_weights = torch.ones(targets.shape, device=targets.device)

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if model.uncertainty:
            # Even columns hold predictions, odd columns hold variances.
            pred_targets = preds[:, [
                j for j in range(len(preds[0])) if j % 2 == 0
            ]]
            pred_var = preds[:,
                             [j for j in range(len(preds[0])) if j % 2 == 1]]
            loss = loss_func(pred_targets, pred_var, targets)
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        # ``loss`` is already a per-batch mean, so count batches (not
        # examples) to get a correctly scaled running average.
        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Every per-batch tensor (mask, targets, class weights) is created on the
    device that holds the model's parameters, so the batch and the model can
    never end up on different devices.

    :param model: Model. It is called as ``model(smiles_batch, features_batch)``
        and must expose the sub-modules ``ffn_d0``, ``ffn_d1``, ``ffn_d2``,
        ``ffn_final`` and ``ffn_mol`` used by the L1 penalty.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function returning an unreduced (per-element) loss.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped after every
        batch only when it is a NoamLR; ``get_lr()`` is read for logging.
    :param args: Arguments. ``batch_size``, ``dataset_type`` and
        ``log_frequency`` are read here.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results (falls back to ``print``).
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    # Don't use the last batch if it's small, for stability. Because
    # num_iters is a multiple of the batch size, every slice taken below is a
    # full batch and no extra bounds check is needed.
    num_iters = len(data) // args.batch_size * args.batch_size

    # Follow the model's actual device instead of trusting args.cuda.
    device = next(model.parameters()).device

    # L1 penalty weight per feed-forward sub-module, in summation order. All
    # weights are currently zero, so the penalty adds nothing to the loss
    # value, but the term is still part of the autograd graph.
    # An L2 variant (sum of squared parameters, with weights 1e-6, 1e-6, 1e-5,
    # 1e-4 and 1e-3 for d0, d1, d2, final and mol) used to be kept here as
    # disabled code.
    l1_weights = (
        ('ffn_d0', 0),
        ('ffn_d1', 0),
        ('ffn_d2', 0),
        ('ffn_final', 0),
        ('ffn_mol', 0),
    )

    for i in trange(0, num_iters, args.batch_size):
        # Prepare batch
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch = mol_batch.smiles()
        features_batch = mol_batch.features()
        target_batch = mol_batch.targets()
        batch = smiles_batch

        # mask is 1 where a target is present; missing targets are replaced
        # by 0 and then zeroed out of the loss through the mask.
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch]).to(device)
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch]).to(device)

        class_weights = torch.ones(targets.shape, device=device)

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            # One loss column per task, computed from that task's class scores.
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # Average over the targets that are actually present.
        loss = loss.sum() / mask.sum()

        # L1 regularization over the feed-forward sub-modules.
        l1_penalty = 0
        for module_name, weight in l1_weights:
            module_l1 = 0
            for param in getattr(model, module_name).parameters():
                module_l1 += torch.sum(torch.abs(param))
            l1_penalty += weight * module_l1
        loss += l1_penalty

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
示例#21
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          n_iter: int = 0,
          logger: logging.Logger = None) -> int:
    """
    Trains a model for an epoch, reading hyper-parameters from ``config``.

    Every per-batch tensor (mask, targets, class weights) is created on the
    device that holds the model's parameters, so the batch and the model can
    never end up on different devices.

    :param model: Model. It is called as ``model(smiles_batch)``.
    :param data: A MoleculeDataset; it is shuffled in place before training.
    :param loss_func: Loss function returning an unreduced (per-element) loss.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped after every
        batch only when it is a NoamLR.
    :param n_iter: The number of training examples trained on so far.
    :param logger: Accepted for interface compatibility; not used here.
    :return: The total number of training examples trained on so far.
    """
    model.train()
    data.shuffle()

    # Don't use the last batch if it's small, for stability. Because
    # num_iters is a multiple of the batch size, every slice taken below is a
    # full batch and no extra bounds check is needed.
    num_iters = len(data) // config.batch_size * config.batch_size

    # Show a progress bar only in verbose mode.
    batch_starts = trange if config.verbose else range

    # Follow the model's actual device instead of trusting config.cuda.
    device = next(model.parameters()).device

    for i in batch_starts(0, num_iters, config.batch_size):
        # Prepare batch
        mol_batch = MoleculeDataset(data[i:i + config.batch_size])
        smiles_batch = mol_batch.smiles()
        target_batch = mol_batch.targets()

        # mask is 1 where a target is present; missing targets are replaced
        # by 0 and then zeroed out of the loss through the mask.
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch]).to(device)
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch]).to(device)

        class_weights = torch.ones(targets.shape, device=device)

        # Run model
        model.zero_grad()
        preds = model(smiles_batch)

        if config.dataset_type == 'multiclass':
            targets = targets.long()
            # One loss column per task, computed from that task's class scores.
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # Average over the targets that are actually present.
        loss = loss.sum() / mask.sum()

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

    return n_iter
def fit(
    step         :FunctionType,
    epochs       :int,
    model        :ModuleType,
    optimizer    :OptimizerType,
    scheduler    :SchedulerType,
    data_loader  :DataLoader,
    model_path   :str,
    logger       :object,
    early_stop   :bool            =False,
    pgd_kwargs   :Optional[dict]  =None,
    verbose      :bool            =False
) -> Tuple[ModuleType, dict]:
    """Standard pytorch boilerplate for training a model.
    Allows for early stopping wrt robust accuracy rather than the usual clean accuracy.

    :param step: Callable invoked once per batch as
        ``step(X, y, model=..., optimizer=..., scheduler=...)``; it must return
        ``(loss, logits)``. Any parameter update or scheduler stepping is its
        responsibility, none is done here.
    :param epochs: Number of passes over ``data_loader``.
    :param model: Model to train; batches are moved to the device of its
        first parameter.
    :param optimizer: Optimizer, forwarded to ``step`` and to ``attack_pgd``.
    :param scheduler: Scheduler, forwarded to ``step``; ``get_last_lr()`` is
        read for logging.
    :param data_loader: Yields ``(X, y)`` batches.
    :param model_path: Where the returned state dict is saved with ``torch.save``.
    :param logger: Object with a ``logging``-style ``info`` method.
    :param early_stop: If true, after each epoch the first batch of that epoch
        is attacked with PGD and training stops as soon as robust accuracy
        drops by more than 0.2 relative to the previous epoch.
    :param pgd_kwargs: Keyword arguments for ``attack_pgd``; must also contain
        ``lower_limit`` and ``upper_limit``. Required when ``early_stop`` is set.
    :param verbose: Wrap the batch loop in a ``tqdm`` progress bar.
    :return: ``(model, best_state_dict)``. With early stopping the state dict
        is a snapshot from the last epoch that passed the robustness check;
        otherwise it is the model's final state dict. ``model`` itself is
        returned as training left it.
    """
    device = next(model.parameters()).device
    prev_robust_acc = 0.
    # Stays None until an epoch passes the robustness check, which never
    # happens when no epoch runs at all (epochs <= 0).
    best_state_dict = None
    start_train_time = time.time()
    logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc')

    for epoch in range(epochs):
        start_epoch_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0

        data_generator = enumerate(data_loader)
        if verbose:
            data_generator = tqdm(data_generator, total=len(data_loader), desc=f'Epoch {epoch + 1}')

        for i, (X, y) in data_generator:
            X, y = X.to(device), y.to(device)
            if i == 0:
                # Kept for the end-of-epoch robustness check.
                first_batch = (X, y)

            loss, logits = step(X, y,
                model    =model,
                optimizer=optimizer,
                scheduler=scheduler)

            # Weight the batch loss by the batch size so the epoch average is
            # per sample.
            train_loss += loss.item() * y.size(0)
            train_acc += (logits.argmax(dim=1) == y).sum().item()
            train_n += y.size(0)

        if early_stop:
            assert pgd_kwargs is not None
            # Check current PGD robustness of the model on the first
            # minibatch of this epoch
            X, y = first_batch

            pgd_delta = attack_pgd(
                model=model,
                X    =X,
                y    =y,
                opt  =optimizer,
                **pgd_kwargs)

            model.eval()
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)], pgd_kwargs['lower_limit'], pgd_kwargs['upper_limit']))
            robust_acc = (output.softmax(dim=1).argmax(dim=1) == y).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.2:
                logger.info('EARLY STOPPING TRIGGERED')
                break
            prev_robust_acc = robust_acc
            # Snapshot the weights; later epochs keep mutating the live
            # parameters in place.
            best_state_dict = copy.deepcopy(model.state_dict())
            model.train()

        epoch_time = time.time()
        lr = scheduler.get_last_lr()[0]
        logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f',
            epoch,
            epoch_time - start_epoch_time,
            lr,
            train_loss / train_n,
            train_acc / train_n)

    train_time = time.time()
    # Without early stopping, or when no snapshot was ever taken, fall back
    # to the model's current weights so there is always something to save.
    if not early_stop or best_state_dict is None:
        best_state_dict = model.state_dict()
    torch.save(best_state_dict, model_path)
    logger.info('Total train time: %.4f minutes', (train_time - start_train_time)/60)

    return model, best_state_dict
def _apply_optimizer_update(optimizer, scaler, model, grad_clipping):
    """Apply one optimizer update, going through the GradScaler when one is used."""
    if scaler is None:
        optimizer.step()
        return

    if grad_clipping > 0:
        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)
        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)

    scaler.step(optimizer)
    scaler.update()


def train_epoch(epoch, epochs_count, model: nn.Module, device: torch.device,
                optimizer: optim.Optimizer, scaler: Optional[GradScaler],
                criterion_list: List[nn.Module], train_dataloader: DataLoader,
                cfg, train_metrics: Dict, progress_bar: tqdm, run_folder,
                limit_samples, fold_index, folds_count,
                scheduler: _LRScheduler, schedule_lr_each_step):
    """Train ``model`` for one epoch and append the epoch means to ``train_metrics``.

    Supports gradient accumulation (``cfg['train_params']['grad_acc_iters']``),
    mixed precision (when ``scaler`` is given) and gradient-norm clipping
    (``cfg['train_params']['grad_clipping']``, applied on the mixed-precision
    path only).

    :param epoch: Current epoch number, used for progress output and to derive
        the wandb step (presumably 1-based - TODO confirm with the caller).
    :param epochs_count: Total number of epochs, for progress output.
    :param model: Model to train.
    :param device: Device forwarded to ``forward_pass``.
    :param optimizer: Optimizer to step.
    :param scaler: GradScaler for mixed precision; ``None`` trains in full
        precision without autocast.
    :param criterion_list: Losses/metrics forwarded to ``forward_pass``;
        ``CustomMetrics`` entries are reset at the start of the epoch.
    :param train_dataloader: Yields ``(images, classes, record_ids)`` batches.
    :param cfg: Configuration dict; ``metrics``, ``train_params`` and
        ``train_data_loader`` are read.
    :param train_metrics: Dict of lists, extended in place with one value per
        metric for this epoch.
    :param progress_bar: tqdm bar used for progress output.
    :param run_folder: Not used by this function.
    :param limit_samples: If not ``None``, stop once this many samples were seen.
    :param fold_index: Current fold, for progress output.
    :param folds_count: Number of folds, for progress output.
    :param scheduler: LR scheduler; stepped after every batch when
        ``schedule_lr_each_step`` is true.
    :param schedule_lr_each_step: Whether the scheduler is stepped per batch.
    """
    epoch_metrics = {'loss': []}
    epoch_metrics.update({metric: [] for metric in cfg['metrics']})

    grad_acc_iters = max(cfg['train_params'].get('grad_acc_iters', 1), 1)
    grad_clipping = cfg['train_params'].get('grad_clipping', 0)
    pseudo_label = cfg['train_params'].get('pseudo_label', False)
    use_fp = cfg['train_data_loader'].get('use_fp', False)
    aux_weights = cfg['train_params'].get('aux_weights', None)

    batch_size = cfg['train_data_loader']['batch_size']
    epoch_samples = len(train_dataloader) * batch_size
    seen_samples = 0

    model.train()

    # reset CustomMetrics
    for metric in criterion_list:  # type:CustomMetrics
        if isinstance(metric, CustomMetrics):
            metric.reset()

    optimizer.zero_grad()
    # False while accumulated gradients are waiting for an optimizer update.
    is_optimizer_update_finished = True

    for iteration_id, data in enumerate(train_dataloader):
        if limit_samples is not None and seen_samples >= limit_samples:  # DEBUG
            break

        data_img, data_class, data_record_ids = data
        batch_size = len(data_img)

        if progress_bar.n == 0:
            # reset start time to remove time while DataLoader was populating processes
            progress_bar.reset()

        if scaler is None:
            # Full-precision path: same forward pass, no autocast and no
            # loss scaling.
            outputs, loss, metrics = forward_pass(data_img,
                                                  data_class,
                                                  model,
                                                  device,
                                                  criterion_list,
                                                  pseudo_label,
                                                  use_fp,
                                                  aux_weights=aux_weights)
            if grad_acc_iters > 1:
                loss = loss / grad_acc_iters
            loss.backward()
        else:
            with autocast():
                outputs, loss, metrics = forward_pass(data_img,
                                                      data_class,
                                                      model,
                                                      device,
                                                      criterion_list,
                                                      pseudo_label,
                                                      use_fp,
                                                      aux_weights=aux_weights)
                if grad_acc_iters > 1:
                    # Scale down so the accumulated gradient is the mean
                    # over the accumulation window.
                    loss = loss / grad_acc_iters
            scaler.scale(loss).backward()

        is_optimizer_update_finished = False

        if grad_acc_iters <= 1 or (iteration_id + 1) % grad_acc_iters == 0:
            _apply_optimizer_update(optimizer, scaler, model, grad_clipping)
            optimizer.zero_grad()
            is_optimizer_update_finished = True

        # Metrics and progress
        seen_samples += batch_size

        # Multiply back so the reported loss is independent of accumulation.
        collect_metrics(cfg, epoch_metrics, loss * grad_acc_iters, metrics,
                        criterion_list, '')
        lr = optimizer.param_groups[0]['lr']
        print_info(progress_bar, 'Train', epoch_metrics, loss * grad_acc_iters,
                   lr, '', batch_size, epoch, epochs_count, seen_samples,
                   epoch_samples, fold_index, folds_count)

        # Log every 5th batch, except the last batch of the epoch.
        if (iteration_id + 1) % 5 == 0 and iteration_id + 1 < len(train_dataloader):
            log_dict = {key: value[-1] for key, value in epoch_metrics.items()}
            if schedule_lr_each_step and scheduler is not None:
                log_dict['lr'] = optimizer.param_groups[0]['lr']
            wandb.log(log_dict, (epoch - 1) * epoch_samples + seen_samples,
                      commit=True)

        # Step lr scheduler
        if schedule_lr_each_step and scheduler is not None:
            scheduler.step()

    # Finish optimizer step after the not completed gradient accumulation batch
    if not is_optimizer_update_finished:
        _apply_optimizer_update(optimizer, scaler, model, grad_clipping)

    # NOTE(review): this indexes criterion_list by the position of each
    # epoch_metrics key, so it assumes both are aligned - confirm.
    for idx, (key, value) in enumerate(epoch_metrics.items()):
        if isinstance(criterion_list[idx], CustomMetrics):
            value = [criterion_list[idx].compute()]
        train_metrics[key].append(np.mean(value))

    progress_bar.write('')
示例#24
0
    def _train(self, epoch: int, optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
               validate_interval: int = 10, save: bool = False, amp: bool = False, verbose: bool = True, indent: int = 0,
               loader_train: torch.utils.data.DataLoader = None, loader_valid: torch.utils.data.DataLoader = None,
               get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               validate_func: Callable[..., tuple[float, ...]] = None, epoch_func: Callable[[], None] = None,
               save_fn: Callable = None, file_path: str = None, folder_path: str = None, suffix: str = None, **kwargs):
        """Train for ``epoch`` epochs, validating periodically and saving on improvement.

        Arguments left as ``None`` fall back to the instance's own training
        loader, ``get_data``, ``loss``, ``_validate`` and ``save``. Mixed
        precision is used only when ``amp`` is set and ``env['num_gpus']`` is
        truthy. Validation runs once before training to get the baseline
        accuracy, then every ``validate_interval`` epochs and after the last
        epoch (never, if ``validate_interval`` is 0). ``save_fn`` is called
        when ``save`` is set and accuracy is at least the best seen so far.
        """
        # Resolve optional collaborators.
        if loader_train is None:
            loader_train = self.dataset.loader['train']
        if get_data_fn is None:
            get_data_fn = self.get_data
        if loss_fn is None:
            loss_fn = self.loss
        if validate_func is None:
            validate_func = self._validate
        if save_fn is None:
            save_fn = self.save

        use_amp = amp and env['num_gpus']
        scaler: torch.cuda.amp.GradScaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        # Baseline accuracy before any training.
        _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                       verbose=verbose, indent=indent, **kwargs)

        loss_meter = AverageMeter('Loss')
        top1_meter = AverageMeter('Acc@1')
        top5_meter = AverageMeter('Acc@5')
        param_groups: list[list[nn.Parameter]] = [group['params'] for group in optimizer.param_groups]

        for epoch_no in range(1, epoch + 1):
            if epoch_func is not None:
                # Run the hook with no parameters activated.
                self.activate_params([])
                epoch_func()
                self.activate_params(param_groups)

            for meter in (loss_meter, top1_meter, top5_meter):
                meter.reset()

            tic = time.perf_counter()
            batches = tqdm(loader_train) if verbose and env['tqdm'] else loader_train

            self.train()
            self.activate_params(param_groups)
            optimizer.zero_grad()
            for data in batches:
                _input, _label = get_data_fn(data, mode='train')
                if use_amp:
                    loss = loss_fn(_input, _label, amp=True)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss = loss_fn(_input, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()

                # Accuracy is measured with a separate forward pass made
                # after the update.
                with torch.no_grad():
                    _output = self.get_logits(_input)
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))

                n_samples = int(_label.size(0))
                loss_meter.update(loss.item(), n_samples)
                top1_meter.update(acc1, n_samples)
                top5_meter.update(acc5, n_samples)
                empty_cache()

            elapsed = datetime.timedelta(seconds=int(time.perf_counter() - tic))
            epoch_time = str(elapsed)

            self.eval()
            self.activate_params([])

            if verbose:
                title_width = 64 if env['color'] else 35
                title = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(epoch_no, epoch), **ansi).ljust(title_width)
                columns = [
                    f'Loss: {loss_meter.avg:.4f},'.ljust(20),
                    f'Top1 Acc: {top1_meter.avg:.3f}, '.ljust(20),
                    f'Top5 Acc: {top5_meter.avg:.3f},'.ljust(20),
                    f'Time: {epoch_time},'.ljust(20),
                ]
                line_prefix = '{upline}{clear_line}'.format(**ansi) if env['tqdm'] else ''
                prints(title, ' '.join(columns), prefix=line_prefix, indent=indent)

            if lr_scheduler:
                lr_scheduler.step()

            # Validate on the configured interval and always after the last epoch.
            if validate_interval == 0:
                continue
            if epoch_no % validate_interval != 0 and epoch_no != epoch:
                continue

            _, cur_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn, loss_fn=loss_fn,
                                          verbose=verbose, indent=indent, **kwargs)
            if cur_acc >= best_acc:
                prints('best result update!', indent=indent)
                prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path, folder_path=folder_path, suffix=suffix, verbose=verbose)
            if verbose:
                print('-' * 50)

        self.zero_grad()
示例#25
0
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        """Train for ``epoch`` epochs, validating periodically and saving on improvement.

        Callables that are not supplied fall back to the instance's own
        ``get_data``, ``loss``, ``_validate`` and ``save``; ``epoch_fn`` and
        ``after_loss_fn`` fall back to same-named attributes when they exist.

        Args:
            epoch: Number of epochs to run.
            optimizer: Optimizer; its parameters are the ones activated for training.
            lr_scheduler: Optional scheduler, stepped once per epoch.
            grad_clip: If not ``None``, max gradient norm used to clip the
                gradients on the full-precision path.
            print_prefix: Label shown in the per-epoch progress header.
            start_epoch: Offset added to the epoch number for validation and
                tensorboard steps.
            validate_interval: Validate every this many epochs and after the
                last one; ``0`` disables validation during training.
            save: Call ``save_fn`` whenever validation accuracy is at least
                the best seen so far.
            amp: Use mixed precision; forced off when ``env['num_gpus']`` is falsy.
            loader_train: Training loader (defaults to ``self.dataset.loader['train']``).
            loader_valid: Loader forwarded to ``validate_fn``.
            epoch_fn: Hook run at the start of every epoch.
            get_data_fn: Converts a loader item into ``(_input, _label)``.
            loss_fn: Computes the loss from input, label and model output.
            after_loss_fn: Hook run after ``backward`` and before the optimizer step.
            validate_fn: Returns ``(loss, accuracy)``-like pair; the second
                item is compared as accuracy.
            save_fn: Called with ``file_path``, ``folder_path``, ``suffix``, ``verbose``.
            file_path: Forwarded to ``save_fn``.
            folder_path: Forwarded to ``save_fn``.
            suffix: Forwarded to ``save_fn``.
            writer: Optional tensorboard ``SummaryWriter``.
            main_tag: Tensorboard main tag suffix.
            tag: Tensorboard scalar tag, also forwarded to ``validate_fn``.
            verbose: Print progress and validation messages.
            indent: Indentation for printed output.
            **kwargs: Forwarded to ``validate_fn``.
        """
        loader_train = loader_train if loader_train is not None else self.dataset.loader[
            'train']
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False
        if amp:
            scaler = torch.cuda.amp.GradScaler()

        # Baseline accuracy before any training (not written to tensorboard).
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)

        # _epoch is 1-based throughout the loop body.
        for _epoch in range(1, epoch + 1):
            if callable(epoch_fn):
                # Run the hook with no parameters activated.
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                # NOTE(review): _epoch is 1-based here, so _iter runs from
                # len(loader_train) up to total_iter + len(loader_train) - 1
                # rather than 0 .. total_iter - 1 - confirm what after_loss_fn
                # implementations expect before changing it.
                _iter = _epoch * len(loader_train) + i
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)

                if amp:
                    # NOTE(review): grad_clip is not applied on this path;
                    # clipping scaled gradients would first need
                    # scaler.unscale_(optimizer).
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        # max_norm is a required argument of clip_grad_norm_.
                        nn.utils.clip_grad_norm_(params, grad_clip)

                if callable(after_loss_fn):
                    after_loss_fn(_input=_input,
                                  _label=_label,
                                  _output=_output,
                                  loss=loss,
                                  optimizer=optimizer,
                                  loss_fn=loss_fn,
                                  amp=amp,
                                  scaler=scaler,
                                  _iter=_iter,
                                  total_iter=total_iter)

                if amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache(
                )  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss, acc = logger.meters['loss'].global_avg, logger.meters[
                'top1'].global_avg
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                # Validate on the configured interval and always after the
                # last epoch.
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(
                                **ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()
示例#26
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    ``data`` is shuffled in place and consumed in full batches of
    ``args.batch_size``; a trailing partial batch is never used. Targets that
    are NaN are zero-filled and excluded from the loss through a 0/1 mask.

    Reads ``args.batch_size``, ``args.dataset_type`` and
    ``args.log_frequency``. Tensors are placed on the device of the model's
    first parameter, so ``args.cuda`` is not consulted.

    :param model: Model.
    :param data: A MoleculeDataset (or list of MoleculeDatasets if using moe).
    :param loss_func: Loss function. Its output is multiplied elementwise by
        the target mask, so it is presumably unreduced -- TODO confirm.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped once per batch,
        and only when it is a NoamLR; ``get_lr()`` is used for logging.
    :param args: Arguments.
    :param n_iter: Number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results. When None,
        messages go to ``print``.
    :param writer: A tensorboardX SummaryWriter.
    :return: Total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    batch_size = args.batch_size

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // batch_size * batch_size

    # Sum of per-batch mean losses and number of batches since the last log.
    loss_sum, batch_count = 0, 0

    # Every start offset below leaves room for a full batch because num_iters
    # is a multiple of batch_size that does not exceed len(data).
    for start in trange(0, num_iters, batch_size):
        # Prepare batch
        mol_batch = MoleculeDataset(data[start:start + batch_size])
        smiles_batch = mol_batch.smiles()
        features_batch = mol_batch.features()
        target_batch = mol_batch.targets()

        # Follow the model's device for every tensor so that mask, targets
        # and class weights can never end up on different devices.
        device = next(model.parameters()).device

        # 1 where a target is present, 0 where it is NaN (missing)
        mask = torch.Tensor([[not np.isnan(x) for x in tb]
                             for tb in target_batch]).to(device)
        # Missing targets are replaced by 0; the mask removes them again
        targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb]
                                for tb in target_batch]).to(device)
        class_weights = torch.ones(targets.shape, device=device)

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        # todo: change the loss function for property prediction tasks

        if args.dataset_type == 'multiclass':
            # The loss is applied one task at a time and the per-task columns
            # are concatenated back into a (molecule, task) layout.
            targets = targets.long()
            per_task_losses = [
                loss_func(preds[:, task, :], targets[:, task]).unsqueeze(1)
                for task in range(preds.size(1))
            ]
            loss = torch.cat(per_task_losses, dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask

        # Mean over the targets that are actually present
        loss = loss.sum() / mask.sum()

        # loss is already a per-batch mean, so count batches (not molecules)
        # to make the logged value an average of those means.
        loss_sum += loss.item()
        batch_count += 1

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / batch_count
            loss_sum, batch_count = 0, 0

            lrs_str = ', '.join(
                f'lr_{idx} = {lr:.4e}' for idx, lr in enumerate(lrs))
            debug(
                f'\nLoss = {loss_avg:.4e}, PNorm = {pnorm:.4f},'
                f' GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)

                for idx, learn_rate in enumerate(lrs):
                    writer.add_scalar(
                        f'learning_rate_{idx}', learn_rate, n_iter)

    return n_iter
示例#27
0
def train(module: nn.Module,
          num_classes: int,
          epochs: int,
          optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train',
          start_epoch: int = 0,
          resume: int = 0,
          validate_interval: int = 10,
          save: bool = False,
          amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None,
          file_path: str = None,
          folder_path: str = None,
          suffix: str = None,
          writer=None,
          main_tag: str = 'train',
          tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True,
          output_freq: str = 'iter',
          indent: int = 0,
          change_train_eval: bool = True,
          lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model.

    Runs epochs ``resume + 1 .. epochs`` over ``loader_train``. Each batch is
    fetched with ``get_data_fn(data, mode='train')``, pushed through
    ``forward_fn`` and ``loss_fn`` and, when ``backward_and_step`` is set,
    back-propagated and applied with ``optimizer``. Validation runs once
    before training and then every ``validate_interval`` epochs as well as on
    the last epoch; the validation result whose first element (treated as
    accuracy) is highest is kept, later results winning ties.

    Args:
        module: Model being trained. Switched between ``train()`` and
            ``eval()`` when ``change_train_eval`` is set.
        num_classes: Forwarded to ``accuracy_fn`` and ``validate_fn``.
        epochs: Index of the last epoch. Nothing is done when ``<= 0``.
        optimizer: Optimizer. The parameters of all its param groups are the
            ones passed to ``activate_params`` and to gradient clipping.
        lr_scheduler: Optional scheduler, stepped per batch or per epoch
            according to ``lr_scheduler_freq`` (``'iter'`` / ``'epoch'``).
            It is also stepped ``resume`` times up front.
        lr_warmup_epochs: While ``_epoch <= lr_warmup_epochs`` the EMA
            ``n_averaged`` buffer is reset after every EMA update.
        model_ema: Optional EMA wrapper, updated every ``model_ema_steps``
            batches.
        model_ema_steps: Batch interval between EMA updates.
        grad_clip: When not None, max norm passed to
            ``nn.utils.clip_grad_norm_``.
        pre_conditioner: Optional KFAC / EKFAC pre-conditioner. Tracking is
            enabled before the forward pass only when AMP is off, and it is
            stepped only in the non-AMP branch.
        print_prefix: Header text for epoch-level progress output.
        start_epoch: Offset added to the epoch index when reporting to
            ``writer`` and ``validate_fn``.
        resume: Number of epochs to skip.
        validate_interval: Epoch interval between validations; ``0`` turns
            validation off.
        save: Call ``save_fn`` whenever the best result is updated.
        amp: Use ``torch.cuda.amp.GradScaler``. Forced off when
            ``env['num_gpus']`` is falsy.
        loader_train: Training loader.
        loader_valid: Validation loader, forwarded to ``validate_fn``.
        epoch_fn: Optional hook called at the start of every epoch.
        get_data_fn: Maps a loader item to ``(_input, _label)``. Defaults to
            returning the item unchanged.
        forward_fn: Forward function, called with ``amp=`` and ``parallel=``
            keywords. Defaults to ``module.__call__``.
        loss_fn: Loss function, called as
            ``loss_fn(_input, _label, _output=..., amp=...)``. Defaults to
            cross entropy on ``_output``.
        after_loss_fn: Optional hook called after ``backward()`` and before
            the optimizer step.
        validate_fn: Validation function. Defaults to ``validate``.
        save_fn: Called with ``file_path``, ``folder_path``, ``suffix`` and
            ``verbose`` when saving.
        file_path: Forwarded to ``save_fn``.
        folder_path: Forwarded to ``save_fn``.
        suffix: Forwarded to ``save_fn``.
        writer: Optional ``torch.utils.tensorboard.SummaryWriter``.
        main_tag: Suffix of the ``Loss/`` and ``Acc/`` scalar groups.
        tag: Key of the scalar inside each group.
        accuracy_fn: Returns ``(acc1, acc5)`` for ``topk=(1, 5)``. Defaults
            to ``accuracy``.
        verbose: Print progress and best-result messages.
        output_freq: ``'iter'`` logs per batch, ``'epoch'`` logs per epoch.
        indent: Indentation forwarded to the printing helpers.
        change_train_eval: Toggle ``module.train()`` / ``module.eval()``
            around each epoch.
        lr_scheduler_freq: ``'iter'`` or ``'epoch'``.
        backward_and_step: When False, only forward, loss and metrics are
            computed; no backward pass or optimizer step is performed here.
        **kwargs: Forwarded to ``validate_fn``.

    Returns:
        ``None`` when ``epochs <= 0``; otherwise the best validation result,
        or ``(0.0, float('inf'))`` when validation is turned off.
    """
    if epochs <= 0:
        return

    if get_data_fn is None:
        # The loop calls get_data_fn(data, mode='train'), so the identity
        # fallback has to tolerate extra keyword arguments.
        def get_data_fn(data, **_ignored):
            return data

    # NOTE(review): the fallback receives amp= / parallel= keywords below;
    # presumably module.forward accepts them -- confirm against the caller.
    forward_fn = forward_fn or module.__call__

    if loss_fn is None:
        # A tensor has no unambiguous truth value, so test against None
        # explicitly; extra keywords such as amp= are accepted and ignored.
        def loss_fn(_input, _label, _output=None, **_ignored):
            if _output is None:
                _output = forward_fn(_input)
            return F.cross_entropy(_output, _label)

    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()

    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        # Baseline validation before any training; module and num_classes are
        # passed exactly as in the per-epoch validation further down.
        best_validate_result = validate_fn(module=module,
                                           num_classes=num_classes,
                                           loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None,
                                           tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose,
                                           indent=indent,
                                           **kwargs)
    # Always bound, even when validation is turned off.
    best_acc = best_validate_result[0]

    # All parameters managed by the optimizer
    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    # Fast-forward the scheduler over the epochs that are being skipped
    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        # NOTE(review): the formatted header is built but print_prefix is
        # what gets passed to log_every -- confirm which one is intended.
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(header), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=print_prefix,
                                    tqdm_header='Epoch',
                                    indent=indent)
    for _epoch in iterator:
        _epoch += 1  # epochs are reported 1-based
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     _epoch=_epoch,
                     epochs=epochs,
                     start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train,
                                            header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    # Gradients must be unscaled before anything inspects or
                    # clips them.
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()

            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)

            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        # Epoch-level averages of the training loss and top-1 accuracy
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        # Validate on the configured interval and always on the last epoch
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose,
                                          indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            # ">=" lets a later epoch replace an equally good earlier one
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f}    '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path,
                            folder_path=folder_path,
                            suffix=suffix,
                            verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
示例#28
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    ``data`` is shuffled in place and consumed in full batches of
    ``args.batch_size``; a trailing partial batch is never used. Targets that
    are None are zero-filled and excluded from the loss through a 0/1 mask.
    With ``args.class_balance`` the data is first rebuilt so that batches mix
    positives and negatives; with ``args.enable_weight`` the per-target
    weights of each batch scale the loss.

    Reads ``args.batch_size``, ``args.class_balance``, ``args.enable_weight``,
    ``args.dataset_type`` and ``args.log_frequency``. Tensors are placed on
    the device of the model's first parameter, so ``args.cuda`` is not
    consulted.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function. Its output is multiplied elementwise by
        the weights and the target mask, so it is presumably unreduced --
        TODO confirm.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped once per batch,
        and only when it is a NoamLR; ``get_lr()`` is used for logging.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results. When None,
        messages go to ``print``.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    batch_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        assert len(
            data[0].targets) == 1  # only works for single class classification
        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        balanced = []
        pos_size = batch_size // 2
        pos_index = neg_index = 0
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + batch_size - len(new_pos)]

            # Stop as soon as either class is exhausted
            if not new_pos or not new_neg:
                break

            # Negatives ran short: top the batch up with extra positives
            if len(new_pos) + len(new_neg) < batch_size:
                new_pos = pos[pos_index:pos_index + batch_size - len(new_neg)]

            balanced += new_pos + new_neg

            pos_index += len(new_pos)
            neg_index += len(new_neg)

        data = balanced

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // batch_size * batch_size

    # Sum of per-batch mean losses and number of batches since the last log.
    loss_sum, batch_count = 0, 0

    # Every start offset below leaves room for a full batch because num_iters
    # is a multiple of batch_size that does not exceed len(data).
    for start in trange(0, num_iters, batch_size):
        # Prepare batch
        mol_batch = MoleculeDataset(data[start:start + batch_size])
        smiles_batch = mol_batch.smiles()
        features_batch = mol_batch.features()
        target_batch = mol_batch.targets()
        weight_batch = mol_batch.weights()

        # Follow the model's device for every tensor so that mask, targets
        # and class weights can never end up on different devices.
        device = next(model.parameters()).device

        # 1 where a target is present, 0 where it is None (missing)
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch]).to(device)
        # Missing targets are replaced by 0; the mask removes them again
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch]).to(device)

        if args.enable_weight:
            class_weights = torch.Tensor(
                [[0 if x is None else x for x in wb]
                 for wb in weight_batch]).to(device)
        else:
            class_weights = torch.ones(targets.shape, device=device)

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        if args.dataset_type == 'multiclass':
            # The loss is applied one task at a time and the per-task columns
            # are concatenated back into a (molecule, task) layout.
            targets = targets.long()
            per_task_losses = [
                loss_func(preds[:, task, :], targets[:, task]).unsqueeze(1)
                for task in range(preds.size(1))
            ]
            loss = torch.cat(per_task_losses, dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask

        # Mean over the targets that are actually present
        loss = loss.sum() / mask.sum()

        # loss is already a per-batch mean, so count batches (not molecules)
        # to make the logged value an average of those means.
        loss_sum += loss.item()
        batch_count += 1

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / batch_count
            loss_sum, batch_count = 0, 0

            lrs_str = ', '.join(f'lr_{idx} = {lr:.4e}'
                                for idx, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for idx, learn_rate in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{idx}', learn_rate,
                                      n_iter)

    return n_iter
示例#29
0
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    Batches are built in one of three mutually exclusive ways, selected by
    ``args.moe``, ``args.maml`` or neither (the default path). On top of that,
    ``chunk_names`` switches to a recursive per-chunk mode, and
    ``args.adversarial``, ``args.adjust_weight_decay`` and
    ``args.parallel_featurization`` enable additional per-batch work.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped once per batch,
    and only when it is a NoamLR; ``get_lr()`` is used for logging.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results. When None,
    messages go to ``print``.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    # Extra loss term that is only used on the bert_pretraining path below
    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        # Chunked mode: each entry of data unpacks to (chunk path, memo path).
        # Every chunk is loaded, shuffled and trained on by a recursive call
        # with chunk_names=False.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # these paths must only ever point at trusted data.
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            # NOTE(review): the memo is read into SMILES_TO_FEATURES above but
            # SMILES_TO_GRAPH is what gets written here -- confirm which
            # cache the memo file is meant to hold.
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    # In moe mode each source dataset is shuffled separately further down
    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        # Pool of train + validation smiles that the adversarial updates
        # sample from
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    # Decide how many examples the loop below walks over
    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        # Bounded by the test set and by the smallest source dataset
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        # Unless args.last_batch is set, the trailing partial batch is dropped
        num_iters = len(data) if args.last_batch else len(
            data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        # A separate process runs async_mol2graph and hands batches over
        # through batch_queue; exit_queue is used at the end to signal it
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    # maml advances one task per iteration; every other mode one batch
    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            # Mixture-of-experts: one batch per source dataset plus a slice
            # of test smiles; the model computes the loss itself
            if not args.batch_domain_encs:
                model.compute_domain_encs(
                    train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            # Here mol_batch is a list, so this adds the number of sources
            iter_count += len(mol_batch)
        elif args.maml:
            # MAML inner step: take one gradient step on the task's training
            # split and keep the adapted parameters in theta_prime
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(
            ), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            # Parameters without gradients are carried over; adding zeros
            # yields a new tensor (presumably to avoid sharing the parameter
            # object itself -- TODO confirm)
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                # Pull pre-featurized batches from the worker process,
                # refilling the local list whenever it runs empty
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(
                )
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(
            ), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(
                    target_batch['features']
                ) if target_batch[
                    'features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                # 1 where a target is present, 0 where it is None; missing
                # targets are zero-filled and removed again by the mask
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                # Look up a per-task weight for every molecule using its
                # target value as the index into args.class_weights
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:,
                                                             task_num].long()])
                class_weights = torch.stack(
                    class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            # NOTE(review): mask/targets follow the model's device while
            # class_weights follows args.cuda -- the two are assumed to agree.
            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(
                        preds[:, task, :], targets[:, task]
                    ) * class_weights[:,
                                      task] * mask[:,
                                                   task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    # Columns before the last args.features_size ones are
                    # counted args.task_weight times in both the numerator
                    # and the denominator
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            # MAML outer step: evaluate the adapted parameters on the task's
            # test split and accumulate that loss; the optimizer only steps
            # once every args.maml_batch_size tasks
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(
            ), task_test_data.features(), [
                t[task_idx] for t in task_test_data.targets()
            ]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(
                smiles_batch
            )  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            # moe and default modes: ordinary backward pass with optional
            # gradient-norm clipping
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            # Lower the weight decay (never below 0) while the parameter norm
            # is under the target, raise it otherwise.
            # Note: the loops below rebind i; the outer loop index is not
            # read again during this iteration.
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        # Only NoamLR is stepped per batch here
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            # args.gan_d_per_g discriminator updates followed by a single
            # generator update, each on freshly sampled smiles.
            # NOTE(review): d_loss / gp_norm are only bound inside the loop,
            # so gan_d_per_g is assumed to be at least 1.
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            # NOTE(review): loss_sum adds up already-normalised batch losses
            # while iter_count adds up len(mol_batch), so loss_avg is not a
            # plain mean of the batch losses -- confirm this is intended.
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr)
                                for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(
                loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug(
                    "D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(
                        d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    # Tell the featurization process to stop and wait for it to finish
    if args.parallel_featurization:
        exit_queue.put(
            0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
示例#30
0
def train(model: nn.Module, data_loader: MoleculeDataLoader,
          loss_func: Callable, optimizer: Optimizer,
          scheduler: _LRScheduler, args: Namespace,
          n_iter: int = 0, logger: Optional[Logger] = None,
          writer: Optional = None) -> int:
    """Trains a model for an epoch

    For every batch yielded by ``data_loader`` this runs a forward pass,
    computes a loss averaged over the targets that are present (not None),
    backpropagates, and steps the optimizer. The scheduler is stepped once
    per batch, but only when it is a ``NoamLR``.

    Parameters
    ----------
    model : nn.Module
        the model to train. Must expose a truthy/falsy ``uncertainty``
        attribute; when truthy, the even-indexed columns of the model output
        are passed to the loss as means and the odd-indexed columns as
        variances
    data_loader : MoleculeDataLoader
        an iterable of MoleculeDatasets
    loss_func : Callable
        the loss function. Called as ``loss_func(preds, targets)``, or as
        ``loss_func(pred_means, pred_vars, targets)`` when
        ``model.uncertainty`` is truthy
    optimizer : Optimizer
        the optimizer
    scheduler : _LRScheduler
        the learning rate scheduler
    args : Namespace
        a Namespace object the containing necessary attributes
        (currently not read by this function)
    n_iter : int
        the current number of training iterations
    logger : Optional[Logger] (Default = None)
        a logger for printing intermediate results (currently unused;
        accepted for interface compatibility)
    writer : Optional[SummaryWriter] (Default = None)
        A tensorboardX SummaryWriter (currently unused; accepted for
        interface compatibility)

    Returns
    -------
    n_iter : int
        The total number of iterations (training examples) trained on so far
    """
    model.train()

    for batch in tqdm(data_loader, desc='Step', unit='minibatch',
                      leave=False):
        # Prepare batch
        mol_batch = batch.batch_graph()
        features_batch = batch.features()

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        raw_targets = batch.targets()   # targets might have None's

        # A target is "missing" only when it is None. Testing truthiness
        # here would also mask out legitimate zero-valued targets (e.g. the
        # negative class in binary classification or a 0.0 regression
        # value), silently excluding them from the loss.
        mask = torch.Tensor(
            [[y is not None for y in ys] for ys in raw_targets]
        ).to(preds.device)
        # Missing targets get a placeholder of 0; in the non-uncertainty
        # branch their loss contribution is zeroed out by the mask below.
        targets = torch.Tensor(
            [[0 if y is None else y for y in ys] for ys in raw_targets]
        ).to(preds.device)

        if model.uncertainty:
            pred_means = preds[:, 0::2]
            pred_vars = preds[:, 1::2]

            # NOTE(review): the mask is not applied in this branch, yet the
            # result is still normalised by mask.sum() below -- confirm
            # that loss_func handles missing targets itself.
            loss = loss_func(pred_means, pred_vars, targets)
        else:
            loss = loss_func(preds, targets) * mask

        # Average over the number of targets that are actually present
        loss = loss.sum() / mask.sum()

        loss.backward()
        optimizer.step()

        # Only NoamLR is stepped per batch here; other schedulers are not
        # stepped inside this function.
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

    return n_iter