Beispiel #1
0
def train(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module,
          loss_func: Any, optimizer: torch.optim.Optimizer, epoch: int,
          args: argparse.Namespace):
    """
    Train the given model for a single epoch using the given dataloader.

    Args:
        dataloader: The dataloader containing the training data.
        model: Instance of the model that is being trained.
        loss_func: A loss function to compute the error between the
            actual and the desired output of the model.
        optimizer: An instance of an optimizer that is used to compute
            and perform the updates to the weights of the network.
        epoch: The current training epoch (1-based; used for the
            TensorBoard global step and console output).
        args: Namespace object containing some global variable (e.g.,
            command line arguments, such as the batch size). Expected
            attributes: device, tensorboard, logger, n_train_batches,
            batch_size, log_interval, epochs.
    """

    # -------------------------------------------------------------------------
    # Preliminaries
    # -------------------------------------------------------------------------

    # Activate training mode
    model.train()

    # Keep track of the time to process a batch, as well as the batch losses
    batch_times = AverageMeter()
    batch_losses = AverageMeter()

    # -------------------------------------------------------------------------
    # Process the training dataset in mini-batches
    # -------------------------------------------------------------------------

    for batch_idx, (data, target) in enumerate(dataloader):

        # Initialize start time of the batch
        batch_start = time.time()

        # Fetch data and move to device
        data, target = data.to(args.device), target.to(args.device)
        target = target.squeeze()

        # Clear gradients
        optimizer.zero_grad()

        # Compute forward pass through model. Call the module itself rather
        # than model.forward() so registered forward hooks are not bypassed.
        output = model(data).squeeze()

        # Calculate the loss for the batch
        loss = loss_func(output, target)

        # Back-propagate the loss and update the weights
        loss.backward()
        optimizer.step()

        # ---------------------------------------------------------------------
        # Log information about current batch to TensorBoard
        # ---------------------------------------------------------------------

        if args.tensorboard:

            # Compute how many examples we have processed already and log the
            # loss value for the current batch
            global_step = ((epoch - 1) * args.n_train_batches + batch_idx) * \
                          args.batch_size
            args.logger.add_scalar(tag='loss/train',
                                   scalar_value=loss.item(),
                                   global_step=global_step)

        # ---------------------------------------------------------------------
        # Additional logging to console
        # ---------------------------------------------------------------------

        # Store the loss and processing time for the current batch
        batch_losses.update(loss.item())
        batch_times.update(time.time() - batch_start)

        # Print information to console, if applicable
        if batch_idx % args.log_interval == 0:

            # Which fraction of batches have we already processed this epoch?
            percent = 100. * batch_idx / args.n_train_batches

            # Print some information about how the training is going
            print(f'Epoch: {epoch:>3}/{args.epochs}', end=' | ', flush=True)
            print(f'Batch: {batch_idx:>3}/{args.n_train_batches}',
                  flush=True,
                  end=' ')
            print(f'({percent:>4.1f}%)', end=' | ', flush=True)
            print(f'Loss: {loss.item():.6f}', end=' | ', flush=True)
            print(f'Time: {batch_times.value:>6.3f}s', flush=True)
Beispiel #2
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0,
                    neptune=None):
    """Train a DETR-style model for a single epoch.

    Args:
        model: Model to train; called as ``model(samples)``.
        criterion: Loss module; must expose ``weight_dict`` (mapping loss
            names to scalar weights) and return a dict of per-term losses.
            It must also produce a ``'class_error'`` entry for logging.
        data_loader: Iterable of ``(samples, targets)``; ``targets`` is a
            list of dicts of tensors, one dict per image.
        optimizer: Optimizer stepped once per batch.
        device: Device to which samples and targets are moved.
        epoch: Current epoch index (used only in the log header).
        max_norm: If > 0, clip gradients to this L2 norm before stepping.
        neptune: Optional Neptune experiment; when given, the per-batch
            reduced loss is logged under ``'train/loss'``.

    Returns:
        Dict mapping each tracked metric name to its global average.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Total training loss: weighted sum over the loss terms that have a
        # weight configured; unweighted terms are logged but not optimized.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        # The reduced (cross-GPU) scaled loss is what gets logged and checked
        # for finiteness; the local `losses` tensor is what is backpropagated.
        loss_value = losses_reduced_scaled.item()
        if neptune:
            neptune.log_metric('train/loss', loss_value)

        if not math.isfinite(loss_value):
            # A NaN/inf loss would corrupt the weights — abort immediately.
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #3
0
 def optimizer_step(self, optimizer: torch.optim.Optimizer,
                    lambda_closure: Callable, **kwargs):
     """Advance the optimizer by one step.

     Args:
         optimizer: The optimizer whose ``step`` is invoked.
         lambda_closure: Closure that re-evaluates the model and returns
             the loss; forwarded as the ``closure`` argument.
         **kwargs: Any additional keyword arguments, passed through
             unchanged to ``optimizer.step``.
     """
     optimizer.step(closure=lambda_closure, **kwargs)
Beispiel #4
0
def train_epoch_kontschieder(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:
    """Train a ProtoTree for one epoch using Kontschieder-style leaf updates.

    Leaves are first trained with a derivative-free algorithm (repeated 10x
    when the tree uses Kontschieder normalization), then prototypes and the
    network are trained with gradient descent over all batches.

    Args:
        tree: The ProtoTree to train (moved to `device`).
        train_loader: DataLoader yielding (inputs, labels) batches.
        optimizer: Optimizer for prototypes/network (and leaves, if
            derivative-free leaf optimization is disabled).
        epoch: Current epoch number (1-based; a log is created at epoch 1).
        disable_derivative_free_leaf_optim: If True, skip the derivative-free
            leaf updates and let gradient descent handle leaves too.
        device: Device on which to run training.
        log: Optional Log object for per-batch loss/accuracy values.
        log_prefix: Prefix for the loss-log name.
        progress_prefix: Prefix for the progress-bar description.

    Returns:
        Dict with the epoch's mean 'loss' and 'train_accuracy'.

    Raises:
        ValueError: If `train_loader` yields no batches (the epoch averages
            would otherwise be undefined).
    """

    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch==1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn leaves following Kontschieder's approach
            for _ in range(10):
                # Train leaves with derivative-free algorithm using normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)
    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leafs are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                        total=len(train_loader),
                        desc=progress_prefix+' %s'%epoch,
                        ncols=0)
    # Make sure the model is in train mode
    tree.train()
    # Count processed batches explicitly so the epoch averages are well
    # defined (and we fail loudly) even for an empty loader.
    n_batches = 0
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network. Call the module itself
        # (not .forward()) so registered forward hooks are honored.
        ys_pred, _ = tree(xs)
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Count the number of correct classifications
        ys_pred = torch.argmax(ys_pred, dim=1)

        correct = torch.sum(torch.eq(ys_pred, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss+=loss.item()
        total_acc+=acc
        n_batches += 1

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    if n_batches == 0:
        raise ValueError('train_loader yielded no batches; cannot compute epoch averages')
    train_info['loss'] = total_loss/float(n_batches)
    train_info['train_accuracy'] = total_acc/float(n_batches)
    return train_info
def train(model: torch.nn.Module,
          train_dls: List[DataLoader],
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: MultiDatasetClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "wandb_local",
          gradient_accumulation: int = 1,
          domain_name: str = ''):
    """Train a multi-domain classifier with early stopping on validation F1.

    Batches are drawn from the per-domain loaders in a freshly shuffled order
    on every pass, with optional gradient accumulation. After each epoch the
    model is evaluated; the checkpoint with the best validation F1 is saved
    and training stops after `patience` epochs without improvement.

    Args:
        model: Model whose forward accepts input ids, attention masks,
            domains and (optionally) labels, returning (loss, logits, alpha).
        train_dls: One DataLoader per domain. NOTE(review): batches from the
            LAST loader get their labels nulled (treated as unlabeled test
            data) — confirm this convention with callers.
        optimizer: Optimizer, stepped once per accumulated group of batches.
        scheduler: LR scheduler stepped after each optimizer step (may be None).
        validation_evaluator: Evaluator returning ((loss, acc, P, R, F1), _).
        n_epochs: Maximum number of epochs.
        device: Device identifier the batch tensors are moved to.
        log_interval: Log the loss to wandb every this many batches.
        patience: Number of consecutive non-improving epochs before stopping.
        model_dir: Root directory for saved checkpoints.
        gradient_accumulation: Number of batches per optimizer step.
        domain_name: Suffix used in the checkpoint filename.
    """
    #best_loss = float('inf')
    best_f1 = 0.0
    patience_counter = 0

    epoch_counter = 0
    # Total number of batches across all loaders (sizes the progress bar)
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        # Fresh iterators each epoch; finished[d] == 1 marks an exhausted loader
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                # Visit domains in a new random order on every pass
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        # Collect up to `gradient_accumulation` batches; a
                        # partial group is still processed when the loader
                        # runs out mid-collection.
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Null the labels if its the test data
                        if d == len(train_dls) - 1:
                            labels = None
                        # Testing with random domains to see if any effect
                        #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
                        domains = batch[3]

                        loss, logits, alpha = model(input_ids,
                                                    attention_mask=masks,
                                                    domains=domains,
                                                    labels=labels,
                                                    ret_alpha=True)
                        # Scale the loss so the accumulated gradient matches
                        # a single full-size batch.
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            # wandb.log({
                            #     "Loss": loss.item(),
                            #     "alpha0": alpha[:,0].cpu(),
                            #     "alpha1": alpha[:, 1].cpu(),
                            #     "alpha2": alpha[:, 2].cpu(),
                            #     "alpha_shared": alpha[:, 3].cpu()
                            # })
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    # One optimizer step per accumulated group of batches
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation F1: {F1}")

        #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if F1 > best_f1:
            # NOTE(review): best_model is assigned but never read or returned
            # in this function — confirm whether it is intentionally unused.
            best_model = model.state_dict()
            #best_loss = val_loss
            best_f1 = F1
            #wandb.run.summary['best_validation_loss'] = best_loss
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth'
            )
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
Beispiel #6
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0,
                    writer=None,
                    args=None):
    """Train a detection model for one epoch, with spine-specific evaluation.

    In addition to the usual DETR-style weighted-loss optimization, every
    batch is scored with `spine_evaluation` (FN/FP/TP counts and center
    distances), and epoch-level metrics are written to TensorBoard.

    Args:
        model: Model to train; called as ``model(samples)``.
        criterion: Loss module returning ``(loss_dict, indices)`` and
            exposing a ``weight_dict`` of per-term loss weights. The loss
            dict must contain 'loss_centers', 'loss_bce', 'loss_spine_l1'
            and 'class_error' entries (used for console logging).
        data_loader: Iterable of ``(samples, targets, info)`` batches;
            ``info`` carries per-sample metadata including 'patient_id'.
        optimizer: Optimizer stepped once per batch.
        device: Device to which samples/targets are moved.
        epoch: Current epoch index; images are plotted every 50 epochs and
            on the final epoch.
        max_norm: If > 0, clip gradients to this L2 norm.
        writer: TensorBoard SummaryWriter for metrics and image plots.
        args: Namespace providing epochs, batch_size, comment, and whatever
            `spine_evaluation` requires.

    Returns:
        Dict mapping each tracked metric name to its global average.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    losses_items = []

    # Per-epoch accumulators: false negatives, false positives, true
    # positives, center distances, and total target counts.
    FNs, FPs, TPs, AVGs, TAR = [], [], [], [], []

    # for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
    for i, (samples, targets, info) in enumerate(data_loader):
        samples = samples.to(device)
        targets = [t.to(device) for t in targets]

        outputs = model(samples)

        # import numpy as np
        # couples = []
        # for x in np.arange(1/10, 1, 1/5):
        #     for y in np.arange(1/12, 1, 1/6):
        #         couples.append(torch.tensor([x, y]))
        # outputs['pred_boxes'][0] = torch.cat(couples).view(-1, 2)

        loss_dict, indices = criterion(outputs, targets)

        # Periodically dump prediction/target visualizations to TensorBoard
        if epoch % 50 == 0 or epoch == (args.epochs - 1):
            step = (epoch * len(data_loader) + i) * args.batch_size
            plot_images(writer,
                        step,
                        samples,
                        outputs,
                        targets,
                        indices,
                        epoch,
                        i,
                        tag='train',
                        folder=args.comment)

        # Score every sample in the batch against its ground-truth targets
        for d in range(len(samples)):
            FN, FP, TP, in_dist = spine_evaluation(outputs['pred_boxes'][d],
                                                   outputs['pred_logits'][d],
                                                   targets[d][:, 1:3], info[d],
                                                   args)
            FNs.append(FN)
            FPs.append(FP)
            TPs.append(TP)
            TAR.append(len(targets[d]))
            AVGs.append(in_dist)

        weight_dict = criterion.weight_dict
        # Weighted sum of the configured loss terms (unweighted terms are
        # only logged, not optimized)
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict).float()

        # Warn once per epoch about loss terms that carry no weight
        not_used_keys = [
            k for k in loss_dict.keys() if k not in weight_dict.keys()
        ]
        if len(not_used_keys) > 0 and i == 0:
            print(
                f'[WARNING] these keys are not used to calculate the loss: {not_used_keys}'
            )

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()
        losses_items.append(loss_value)
        print(
            f"{epoch:03d}_{i:03d} loss_value: {loss_value:.04f} mean {mean(losses_items):.04f} loss_centers {loss_dict['loss_centers'].item():.04f} loss_bce {loss_dict['loss_bce'].item():.04f} loss_spine_l1 {loss_dict['loss_spine_l1'].item():.04f} id: {info[0]['patient_id']}"
        )
        if not math.isfinite(loss_value):
            # A NaN/inf loss would corrupt the weights — abort immediately.
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Epoch-level detection metrics, normalized by the total target count
    writer.add_scalar('train_metric/FN',
                      sum(FNs) / sum(TAR),
                      global_step=epoch)
    writer.add_scalar('train_metric/FP',
                      sum(FPs) / sum(TAR),
                      global_step=epoch)
    writer.add_scalar('train_metric/TP',
                      sum(TPs) / sum(TAR),
                      global_step=epoch)
    if len(torch.cat(AVGs)) > 0:
        writer.add_scalar('train_metric/avg_dist',
                          torch.cat(AVGs).mean(),
                          global_step=epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #7
0
def train_epoch(model: nn.Module, optimizer: torch.optim.Optimizer,
                loss_func: nn.Module, loader: DataLoader, cfg: Dict,
                epoch: int, use_mse: bool):
    """Train model for a single epoch.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to train
    optimizer : torch.optim.Optimizer
        Optimizer used for weight updating
    loss_func : nn.Module
        The loss function, implemented as a PyTorch Module
    loader : DataLoader
        PyTorch DataLoader containing the training data in batches.
    cfg : Dict
        Dictionary containing the run config
    epoch : int
        Current Number of epoch
    use_mse : bool
        If True, loss_func is nn.MSELoss(), else NSELoss() which expects additional std of
        discharge vector

    Raises
    ------
    ValueError
        If a batch contains neither 3 (LSTM) nor 4 (EALSTM) elements.
    """
    # NOTE: the docstring must be the first statement of the function; in the
    # original it followed this print and was therefore a dead string literal.
    print('train_epoch')
    model.train()

    # process bar handle
    pbar = tqdm(loader, file=sys.stdout)
    pbar.set_description(f'# Epoch {epoch}')

    # Iterate in batches over training set
    for data in pbar:
        # delete old gradients
        optimizer.zero_grad()

        # forward pass through LSTM
        if len(data) == 3:
            x, y, q_stds = data
            x, y, q_stds = x.to(DEVICE), y.to(DEVICE), q_stds.to(DEVICE)
            predictions = model(x)[0]

        # forward pass through EALSTM
        elif len(data) == 4:
            x_d, x_s, y, q_stds = data
            x_d, x_s, y = x_d.to(DEVICE), x_s.to(DEVICE), y.to(DEVICE)
            predictions = model(x_d, x_s)[0]

        else:
            # Previously this fell through and crashed later with an
            # UnboundLocalError on `predictions`; fail with a clear message.
            raise ValueError(
                f'Expected batches of 3 (LSTM) or 4 (EALSTM) elements, got {len(data)}')

        # Mask out NaN predictions so the loss is computed on valid values only
        mask = ~torch.isnan(predictions)
        if use_mse:
            loss = loss_func(predictions[mask], y[mask])

        # NSELoss needs std of each basin for each sample
        else:
            q_stds = q_stds.to(DEVICE)
            loss = loss_func(predictions[mask], y[mask], q_stds)

        # calculate gradients
        loss.backward()

        if cfg["clip_norm"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           cfg["clip_value"])

        # perform parameter update
        optimizer.step()

        pbar.set_postfix_str(f"Loss: {loss.item():5f}")
Beispiel #8
0
    def attach(self, optimizer: torch.optim.Optimizer):
        r"""
        Attaches the privacy engine to the optimizer.

        Attaches to the ``PrivacyEngine`` an optimizer object,and injects
        itself into the optimizer's step. To do that it,

        1. Validates that the model does not have unsupported layers.

        2. Adds a pointer to this object (the ``PrivacyEngine``) inside the optimizer.

        3. Moves optimizer's original ``step()`` function to ``original_step()``.

        4. Monkeypatches the optimizer's ``step()`` function to call ``step()`` on
        the query engine automatically whenever it would call ``step()`` for itself.

        Args:
            optimizer: The optimizer to which the privacy engine will attach

        Raises:
            ValueError: If the optimizer is already attached to a *different*
                privacy engine. Re-attaching to the same engine is a no-op
                (with a warning).
        """
        # Guard against double attachment: a second attach to the same engine
        # is harmless, but attaching to a different engine would corrupt the
        # monkeypatched step/zero_grad chain.
        if hasattr(optimizer, "privacy_engine"):
            if optimizer.privacy_engine != self:
                raise ValueError(
                    f"Trying to attach to optimizer: {optimizer}, but that optimizer is "
                    f"already attached to a different Privacy Engine: {optimizer.privacy_engine}."
                )
            else:
                warnings.warn(
                    "Trying to attach twice to the same optimizer. Nothing to do."
                )
                return

        # Fail early if the model contains layers unsupported by DP training
        self.validator.validate(self.module)
        # A scalar max_grad_norm clips the flattened gradient; a list clips
        # per layer.
        norm_clipper = (clipping.ConstantFlatClipper(self.max_grad_norm)
                        if not isinstance(self.max_grad_norm, list) else
                        clipping.ConstantPerLayerClipper(self.max_grad_norm))

        if self.misc_settings.get("experimental", False):
            # Experimental: dynamically adjusted clipping thresholds
            norm_clipper = clipping._Dynamic_Clipper_(
                [self.max_grad_norm],
                self.misc_settings.get("clip_per_layer", False),
                self.misc_settings.get("clipping_method",
                                       clipping.ClippingMethod.STATIC),
                self.misc_settings.get("clipping_ratio", 0.0),
                self.misc_settings.get("clipping_momentum", 0.0),
            )

        self.clipper = PerSampleGradientClipper(
            self.module,
            norm_clipper,
            self.batch_first,
            self.loss_reduction,
        )

        # Replacement methods bound onto the optimizer below. Note: inside
        # these functions `self` is the *optimizer*, not the engine.
        def dp_zero_grad(self):
            self.privacy_engine.zero_grad()
            self.original_zero_grad()

        def dp_step(self, closure=None, is_empty=False):
            self.privacy_engine.step(is_empty)
            if isinstance(self.privacy_engine.module,
                          DifferentiallyPrivateDistributedDataParallel):
                average_gradients(self.privacy_engine.module)
            self.original_step(closure)

        def poisson_dp_step(self, closure=None):
            # Perform one step as usual
            self.dp_step(closure)

            # Taking empty steps to simulate empty batches
            num_empty_batches = self.privacy_engine._sample_poisson_empty_batches(
            )
            for _ in range(num_empty_batches):
                self.zero_grad()
                self.dp_step(closure, is_empty=True)

        optimizer.privacy_engine = self

        # Preserve the original step/zero_grad, then swap in the DP-aware
        # versions. Order matters: original_step must be saved before step
        # is replaced.
        optimizer.dp_step = types.MethodType(dp_step, optimizer)
        optimizer.original_step = optimizer.step
        optimizer.step = types.MethodType(
            poisson_dp_step if self.poisson else dp_step, optimizer)

        optimizer.original_zero_grad = optimizer.zero_grad
        optimizer.zero_grad = types.MethodType(dp_zero_grad, optimizer)

        def virtual_step(self):
            self.privacy_engine.virtual_step()

        optimizer.virtual_step = types.MethodType(virtual_step, optimizer)

        # create a cross reference for detaching
        self.optimizer = optimizer

        if self.poisson:
            # Optional initial step on empty batch
            num_empty_batches = self._sample_poisson_empty_batches()
            for _ in range(num_empty_batches):
                self.optimizer.zero_grad()
                for p in self.module.parameters():
                    if p.requires_grad:
                        p.grad = torch.zeros_like(p)
                self.optimizer.dp_step(closure=None, is_empty=True)
Beispiel #9
0
def train(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    epochs: int,
    batch_size: int,
    optimizer: torch.optim.Optimizer,
    stopping_delta: Optional[float] = None,
    collate_fn=default_collate,
    cuda: bool = True,
    sampler: Optional[torch.utils.data.sampler.Sampler] = None,
    silent: bool = False,
    update_freq: int = 10,
    evaluate_batch_size: int = 1024,
    update_callback: Optional[Callable[[float, float], None]] = None,
    epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None,
) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    :param dataset: instance of Dataset to use for training
    :param model: instance of DEC model to train
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, None disables, default 10
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: optional function of accuracy and loss to update, default None
    :param epoch_callback: optional function of epoch and model, default None
    :return: None
    """
    # Deterministic-order loader for the initial k-means pass, shuffled
    # loader for the actual training epochs.
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=False,
        sampler=sampler,
        shuffle=False,
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        shuffle=True,
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit="batch",
        postfix={
            "epo": -1,
            "acc": "%.4f" % 0.0,
            "lss": "%.8f" % 0.0,
            "dlb": "%.4f" % -1,
        },
        disable=silent,
    )
    kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
    model.train()
    features = []
    actual = []
    # form initial cluster centres
    for index, batch in enumerate(data_iterator):
        if (isinstance(batch, tuple) or isinstance(batch, list)) and len(batch) == 2:
            batch, value = batch  # if we have a prediction label, separate it to actual
            actual.append(value)
        if cuda:
            batch = batch.cuda(non_blocking=True)
        features.append(model.encoder(batch).detach().cpu())
    actual = torch.cat(actual).long()
    predicted = kmeans.fit_predict(torch.cat(features).numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy = cluster_accuracy(predicted, actual.cpu().numpy())
    cluster_centers = torch.tensor(
        kmeans.cluster_centers_, dtype=torch.float, requires_grad=True
    )
    if cuda:
        cluster_centers = cluster_centers.cuda(non_blocking=True)
    with torch.no_grad():
        # initialise the cluster centers
        model.state_dict()["assignment.cluster_centers"].copy_(cluster_centers)
    # reduction="sum" is the documented replacement for the long-deprecated
    # size_average=False (same behavior: sum instead of mean over elements)
    loss_function = nn.KLDivLoss(reduction="sum")
    delta_label = None
    for epoch in range(epochs):
        # NOTE(review): these features are collected but never read in this
        # loop — candidate for removal after confirming no external use.
        features = []
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit="batch",
            postfix={
                "epo": epoch,
                "acc": "%.4f" % (accuracy or 0.0),
                "lss": "%.8f" % 0.0,
                "dlb": "%.4f" % (delta_label or 0.0),
            },
            disable=silent,
        )
        model.train()
        for index, batch in enumerate(data_iterator):
            if (isinstance(batch, tuple) or isinstance(batch, list)) and len(
                batch
            ) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            output = model(batch)
            # Target distribution is detached: only the assignment follows it
            target = target_distribution(output).detach()
            loss = loss_function(output.log(), target) / output.shape[0]
            data_iterator.set_postfix(
                epo=epoch,
                acc="%.4f" % (accuracy or 0.0),
                lss="%.8f" % float(loss.item()),
                dlb="%.4f" % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
            features.append(model.encoder(batch).detach().cpu())
            if update_freq is not None and index % update_freq == 0:
                loss_value = float(loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    acc="%.4f" % (accuracy or 0.0),
                    lss="%.8f" % loss_value,
                    dlb="%.4f" % (delta_label or 0.0),
                )
                if update_callback is not None:
                    update_callback(accuracy, loss_value, delta_label)
        # Re-predict cluster assignments over the whole dataset to measure
        # how many labels changed since the previous epoch
        predicted, actual = predict(
            dataset,
            model,
            batch_size=evaluate_batch_size,
            collate_fn=collate_fn,
            silent=True,
            return_actual=True,
            cuda=cuda,
        )
        delta_label = (
            float((predicted != predicted_previous).float().sum().item())
            / predicted_previous.shape[0]
        )
        if stopping_delta is not None and delta_label < stopping_delta:
            print(
                'Early stopping as label delta "%1.5f" less than "%1.5f".'
                % (delta_label, stopping_delta)
            )
            break
        predicted_previous = predicted
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        data_iterator.set_postfix(
            epo=epoch,
            acc="%.4f" % (accuracy or 0.0),
            lss="%.8f" % 0.0,
            dlb="%.4f" % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, model)
def fine_tune_train_and_eval(
    input_ids: torch.Tensor,
    token_type_ids: torch.Tensor,
    attention_masks: torch.Tensor,
    start_positions: torch.Tensor,
    end_positions: torch.Tensor,
    batch_size: Tuple[int, int],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_ratio: float = 0.9,
    training_epochs: int = 3,
    lr_scheduler_warmup_steps: int = 0,
    save_model_path: Optional[str] = None,
    save_stats_dict_path: Optional[str] = None,
    device_: Optional[
        str] = None  # if None, it automatically detects if a GPU is available, if not uses a CPU
) -> Tuple[torch.nn.Module, Dict[str, Dict[str, Union[float, str]]]]:
    """
    Performs the fine tuning of the model and returns the trained model as well as a dictionary with evaluation
    statistics at each epoch which can be used to check overfitting and training time.
    :param input_ids: torch.tensor of shape (N, max_len) representing the ids of each token of the N encoded sequence
           pairs, with padding at the end up to max_len. If decoded, the input_ids will consist of a "[CLS]" token,
           followed by the question's tokens, followed by a "[SEP]" token, followed by the context's tokens, followed
           by a "[SEP]" token, followed by "[PAD]" tokens, if relevant, up to max_len.
    :param token_type_ids: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for token
           positions in the context text, 0 elsewhere (i.e. in question and padding)
    :param attention_masks: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for
           non-"[PAD]" tokens, 0 for "[PAD]" tokens.
    :param start_positions: torch.tensor of shape (N) containing the index of the first answer token for each answer
    :param end_positions: torch.tensor of shape (N) containing the index of the last answer token for each answer
    :param batch_size: a tuple of 2 integers, representing the batch size of the train and validation dataloaders
           respectively.
    :param model: the model to use (must be instance of torch.nn.Module). For question answering,
           transformers.BertForQuestionAnswering is recommended.
    :param optimizer: the optimizer to use for the model (must be instance of torch.optim.Optimizer).
    :param train_ratio: the train / (train + validation) split ratio. Default: 0.9 (i.e. 90% of the input data will
           go to the train dataloader and 10% to the validation dataloader). The split is random.
    :param training_epochs: the number of training epochs. Default: 3.
    :param lr_scheduler_warmup_steps: the number of warmup steps of the learning rate scheduler. Default: 0.
           Note: the purpose of this scheduler is to update the learning rate over the course of the training. It is
           preferable for the learning rate to gradually get smaller and smaller so that training makes gradually
           finer adjustments to the weights as the loss gets smaller.
    :param save_model_path: if specified, the path where to save the model (should have '.pt' extension). The model
           will be saved at every epoch with the epoch suffix, for easy comparison. Default: None.
    :param save_stats_dict_path: if specified, the path where to save the dictionary of statistics (should have
           '.json' extension). Default: None.
    :param device_: if specified, the device used for the computations. Can be one of cpu, cuda, mkldnn, opengl,
           opencl, ideep, hip, msnpu. If set to None, it will default to GPU (cuda) if one is available, else it will
           use a CPU. Default: None
    :return: model: the fine tuned model.
             training_stats: a dictionary with a number of statistics. For each epoch, the training loss, validation
             loss, validation accuracy, training time and validation time are included.
    """
    assert all(
        [
            isinstance(i, torch.Tensor) for i in [
                input_ids, token_type_ids, attention_masks, start_positions,
                end_positions
            ]
        ]
    ), "Some inputs are not tensors. When training, start_positions and end_positions must be tensors, not lists."
    assert input_ids.shape == token_type_ids.shape == attention_masks.shape, "Some input shapes are incompatible."
    assert input_ids.shape[0] == len(start_positions) == len(
        end_positions), "Some input shapes are incompatible"

    train_dataloader, valid_dataloader = _build_dataloaders(
        input_ids, token_type_ids, attention_masks, start_positions,
        end_positions, batch_size, train_ratio)
    training_steps = training_epochs * len(
        train_dataloader)  # epochs * number of batches
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=lr_scheduler_warmup_steps,
        num_training_steps=training_steps)
    device = set_hardware_acceleration(default=device_)
    model = model.to(device)

    # Resolve the checkpoint base path once, up front.
    # BUG FIX: the previous code ran save_model_path.split(".")[0] inside the
    # epoch loop, truncating at the FIRST dot; any path containing another dot
    # (e.g. "./models/qa.pt") was mangled ("./models/qa.pt" -> "").
    # Strip only a trailing ".pt" extension instead.
    if save_model_path is not None:
        save_model_base = (save_model_path[:-len(".pt")]
                           if save_model_path.endswith(".pt") else
                           save_model_path)

    training_stats = {}
    for epoch in range(training_epochs):
        logger.info(
            f"Training epoch {epoch + 1} of {training_epochs}. Running training."
        )
        t_i = time()
        model.train()
        cumulative_train_loss_per_epoch = 0.
        for batch_num, batch in tqdm(enumerate(train_dataloader),
                                     total=len(train_dataloader)):
            logger.debug(
                f"Running training batch {batch_num + 1} of {len(train_dataloader)}."
            )
            batch_input_ids, batch_token_type_ids, batch_attention_masks, batch_start_positions, batch_end_positions = \
                batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device), batch[4].to(device)
            model.zero_grad()
            #  model.zero_grad() and optimizer.zero_grad() are the same IF all model parameters are in that optimizer.
            #  It could be safer to call model.zero_grad() if you have two or more optimizers for one model.
            loss, start_logits, end_logits = model(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_masks,
                token_type_ids=batch_token_type_ids,
                start_positions=batch_start_positions,
                end_positions=batch_end_positions
            )  # BertForQuestionAnswering uses CrossEntropyLoss by default, no need to calculate explicitly

            cumulative_train_loss_per_epoch += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            # clipping the norm of the gradients to 1.0 to help prevent the "exploding gradients" issues.
            optimizer.step()  # update model parameters
            lr_scheduler.step()  # update the learning rate

        average_training_loss_per_batch = cumulative_train_loss_per_epoch / len(
            train_dataloader)
        training_time = format_time(time() - t_i)
        logger.info(f"Epoch {epoch + 1} took {training_time} to train.")
        logger.info(
            f"Average training loss: {average_training_loss_per_batch}. \n Running validation."
        )
        if torch.cuda.is_available():
            logger.info(f"GPU memory usage: \n{gpu_memory_usage()}")

        t_i = time()
        model.eval()

        # initialising tensors for accumulating predicted / true answer spans
        pred_start = torch.tensor([], dtype=torch.long, device=device)
        pred_end = torch.tensor([], dtype=torch.long, device=device)
        true_start = torch.tensor([], dtype=torch.long, device=device)
        true_end = torch.tensor([], dtype=torch.long, device=device)

        cumulative_eval_loss_per_epoch = 0

        for batch_num, batch in tqdm(enumerate(valid_dataloader),
                                     total=len(valid_dataloader)):
            logger.info(
                f"Running validation batch {batch_num + 1} of {len(valid_dataloader)}."
            )
            batch_input_ids, batch_token_type_ids, batch_attention_masks, batch_start_positions, batch_end_positions = \
                batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device), batch[4].to(device)
            with torch.no_grad():
                loss, start_logits, end_logits = model(
                    input_ids=batch_input_ids,
                    attention_mask=batch_attention_masks,
                    token_type_ids=batch_token_type_ids,
                    start_positions=batch_start_positions,
                    end_positions=batch_end_positions
                )  # if we pass it the true labels, i.e. start_positions and end_positions it will also return the loss
                cumulative_eval_loss_per_epoch += loss.item()

                pred_start_positions = torch.argmax(start_logits, dim=1)
                pred_end_positions = torch.argmax(end_logits, dim=1)

                pred_start = torch.cat((pred_start, pred_start_positions))
                pred_end = torch.cat((pred_end, pred_end_positions))
                true_start = torch.cat((true_start, batch_start_positions))
                true_end = torch.cat((true_end, batch_end_positions))
            if torch.cuda.is_available():
                logger.debug(f"GPU memory usage: \n{gpu_memory_usage()}")

        # tensor .sum() instead of Python sum() — one C-level reduction
        # rather than per-element iteration over the tensor.
        total_correct_start = int((pred_start == true_start).sum().item())
        total_correct_end = int((pred_end == true_end).sum().item())
        total_correct = total_correct_start + total_correct_end
        total_indices = len(true_start) + len(true_end)

        average_validation_accuracy_per_epoch = total_correct / total_indices
        average_validation_loss_per_batch = cumulative_eval_loss_per_epoch / len(
            valid_dataloader)
        valid_time = format_time(time() - t_i)
        logger.info(f"Epoch {epoch + 1} took {valid_time} to validate.")
        logger.info(
            f"Average validation loss: {average_validation_loss_per_batch}.")
        logger.info(
            f"Average validation accuracy (out of 1): {average_validation_accuracy_per_epoch}."
        )
        if torch.cuda.is_available():
            logger.info(f"GPU memory usage: \n{gpu_memory_usage()}")

        training_stats[f"epoch_{epoch + 1}"] = {
            "training_loss": average_training_loss_per_batch,
            "valid_loss": average_validation_loss_per_batch,
            "valid_accuracy": average_validation_accuracy_per_epoch,
            "training_time": training_time,
            "valid_time": valid_time
        }
        if save_model_path is not None:
            # one checkpoint per epoch, suffixed with the epoch number
            torch.save(model.state_dict(),
                       f"{save_model_base}_epoch_{epoch + 1}.pt")
    if save_stats_dict_path is not None:
        with open(save_stats_dict_path, "w") as file:
            json.dump(training_stats, file)
    return model, training_stats
Beispiel #11
0
def train_model(epoch: int, opt: argparse.Namespace, conf: Dict,
                model: BiaffineParser, optimizer: torch.optim.Optimizer,
                train_batch: BatcherBase, valid_batch: Batcher,
                test_batch: Batcher, ix2label: Dict, best_valid: float,
                test_result: float) -> Tuple[float, float, bool]:
    """Train the biaffine parser for one epoch, validating periodically.

    Args:
        epoch: current epoch number (used in log messages only).
        opt: command-line options; this function reads ``report_steps``,
            ``eval_steps``, ``model`` (output directory), ``gold_valid_path``
            and ``gold_test_path``.
        conf: configuration dict; ``conf['optimizer']['clip_grad']`` (if
            present) enables gradient-norm clipping with that threshold.
        model: the parser being trained.
        optimizer: optimizer updating the model parameters.
        train_batch: batcher whose ``get()`` yields
            ``(inputs, head_indices, head_tags, _)`` tuples.
        valid_batch: batcher for the periodic validation runs.
        test_batch: optional batcher for the test set (may be ``None``).
        ix2label: index-to-label mapping, passed through to ``eval_model``.
        best_valid: best validation score observed so far.
        test_result: test score corresponding to the current best model.

    Returns:
        ``(best_valid, test_result, witnessed_improved_valid_result)`` —
        the (possibly updated) best validation score, the matching test
        score, and whether this epoch produced a new best model.
    """
    model.reset_timer()
    model.train()

    cnt = 0
    start_time = time.time()

    witnessed_improved_valid_result = False
    # Losses are accumulated weighted by the number of tokens (n_tags) so
    # that the reported figures are per-token averages.
    total_loss, total_arc_loss, total_tag_loss, total_n_tags = 0., 0., 0., 0.
    for inputs, head_indices, head_tags, _ in train_batch.get():
        cnt += 1
        model.zero_grad()
        forward_output_dict = model.forward(inputs, head_tags, head_indices)

        # number of tokens in this batch
        n_tags = inputs['length'].sum().item()
        loss = forward_output_dict['loss']

        total_loss += loss.item() * n_tags
        total_arc_loss += forward_output_dict['arc_loss'].item() * n_tags
        total_tag_loss += forward_output_dict['tag_loss'].item() * n_tags
        total_n_tags += n_tags

        loss.backward()
        # Optional gradient clipping, controlled by the optimizer config.
        if 'clip_grad' in conf['optimizer']:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           conf['optimizer']['clip_grad'])

        optimizer.step()

        # Periodic progress logging (running per-token loss averages).
        if cnt % opt.report_steps == 0:
            logger.info(
                "| epoch {:3d} | step {:>6d} | lr {:.3g} | ms/batch {:5.2f} | loss {:.4f} "
                "(arc {:.4f} rel {:.4f}) |".format(
                    epoch, cnt, optimizer.param_groups[0]['lr'],
                    1000 * (time.time() - start_time) / opt.report_steps,
                    total_loss / total_n_tags, total_arc_loss / total_n_tags,
                    total_tag_loss / total_n_tags))
            start_time = time.time()

        # Periodic validation; saves the model on improvement.
        if cnt % opt.eval_steps == 0:
            eval_time = time.time()
            valid_result = eval_model(model, valid_batch, ix2label, opt,
                                      opt.gold_valid_path)
            logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | loss {:.4f} (arc {:.4f} rel {:.4f}) " \
                          "| dev {:.4f} |".format(epoch, cnt, optimizer.param_groups[0]['lr'],
                                                  total_loss / total_n_tags,
                                                  total_arc_loss / total_n_tags,
                                                  total_tag_loss / total_n_tags,
                                                  valid_result)
            if valid_result > best_valid:
                logging_str = logging_str + ' NEW |'

            logger.info(logging_str)

            if valid_result > best_valid:
                witnessed_improved_valid_result = True
                # checkpoint the new best model
                torch.save(model.state_dict(),
                           os.path.join(opt.model, 'model.pkl'))
                best_valid = valid_result
                # Evaluate on test only when validation improved, so
                # test_result always corresponds to the best checkpoint.
                if test_batch is not None:
                    test_result = eval_model(model, test_batch, ix2label, opt,
                                             opt.gold_test_path)
                    logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | " \
                                  "| test {:.4f} |".format(epoch, cnt, optimizer.param_groups[0]['lr'],
                                                           test_result)
                    logger.info(logging_str)
            # Exclude evaluation wall time from the ms/batch figure by
            # shifting start_time forward by the time spent evaluating.
            eval_time = time.time() - eval_time
            start_time += eval_time

    # End-of-epoch summary of per-token loss averages.
    logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | loss {:.4f} " \
                  "(arc {:.4f} rel {:.4f}) |".format(epoch, cnt, optimizer.param_groups[0]['lr'],
                                                     total_loss / total_n_tags,
                                                     total_arc_loss / total_n_tags,
                                                     total_tag_loss / total_n_tags)
    logger.info(logging_str)
    # Timing breakdown collected by the model's internal timers
    # (reset at the top of this function).
    logger.info(
        "| time tracking | input {:.2f}s | context {:.2f}s | classification {:.2f}s"
        .format(model.input_encoding_timer.total_eclipsed_time(),
                model.context_encoding_timer.total_eclipsed_time(),
                model.classification_timer.total_eclipsed_time()))
    return best_valid, test_result, witnessed_improved_valid_result
Beispiel #12
0
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    criterion: torch.nn.Module,
                    device: torch.device,
                    epoch: int,
                    summary: TensorboardSummary,
                    max_norm: float = 0,
                    amp: object = None):
    """
    Train the model for one epoch.

    Args:
        model: network being trained.
        data_loader: iterable of batch dicts consumed by ``forward_pass``.
        optimizer: optimizer stepping the model parameters.
        criterion: loss module (put into train mode alongside the model).
        device: device on which ``forward_pass`` runs the batch.
        epoch: epoch index, used only for tensorboard logging.
        summary: tensorboard summary writer wrapper.
        max_norm: if > 0, clip gradients to this L2 norm.
        amp: optional NVIDIA apex ``amp`` module; when given, the loss is
            scaled for mixed-precision backprop.

    Per-batch statistics are accumulated in-place into ``train_stats`` by
    ``forward_pass`` and written to tensorboard at the end; nothing is
    returned.
    """
    model.train()
    criterion.train()

    # running sums, filled in-place by forward_pass
    train_stats = {
        'l1': 0.0,
        'occ_be': 0.0,
        'l1_raw': 0.0,
        'iou': 0.0,
        'rr': 0.0,
        'epe': 0.0,
        'error_px': 0.0,
        'total_px': 0.0
    }

    tbar = tqdm(data_loader)
    for idx, data in enumerate(tbar):
        # forward pass; the sampled disparity output is not needed here
        # (removed the unused `aa = data["disp"]` leftover debug read)
        _, losses, _ = forward_pass(model, data, device, criterion,
                                    train_stats)

        # terminate training if the loss exploded
        if not math.isfinite(losses['aggregated'].item()):
            print("Loss is {}, stopping training".format(
                losses['aggregated'].item()))
            sys.exit(1)

        # backprop (loss-scaled when running under apex amp)
        optimizer.zero_grad()
        if amp is not None:
            with amp.scale_loss(losses['aggregated'],
                                optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            losses['aggregated'].backward()

        # clip norm
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # step optimizer
        optimizer.step()

        print('pixel_error', losses['error_px'] / losses['total_px'])

        # NOTE(review): emptying the CUDA cache every iteration is costly;
        # kept as-is because the original relied on it to bound memory use.
        torch.cuda.empty_cache()

    # epoch-level pixel error rate from the accumulated counts
    train_stats[
        'px_error_rate'] = train_stats['error_px'] / train_stats['total_px']

    # log to tensorboard
    write_summary(train_stats, summary, epoch, 'train')

    print('Training loss', train_stats['l1'], 'pixel error rate',
          train_stats['px_error_rate'])
    print('RR loss', train_stats['rr'])

    return
Beispiel #13
0
def train_one_epoch(model: torch.nn.Module,
                    d_vae: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    log_writer=None,
                    lr_scheduler=None,
                    start_steps=None,
                    lr_schedule_values=None,
                    wd_schedule_values=None):
    """Train a masked-image-modeling model for one epoch.

    The model predicts, for each masked patch position, the discrete token
    assigned to that patch by the ``d_vae`` codebook; the objective is
    cross-entropy over those visual tokens.

    Args:
        model: the network being trained; called with the masked samples.
        d_vae: frozen tokenizer; ``get_codebook_indices`` maps images to
            discrete codebook ids used as prediction targets.
        data_loader: yields ``(batch, _)`` where ``batch`` unpacks to
            ``(samples, images, bool_masked_pos)``.
        optimizer: optimizer whose param groups may carry ``lr_scale`` and
            ``weight_decay`` entries used by the per-step schedules.
        device: target device for the batch tensors.
        epoch: epoch index (used for the log header).
        loss_scaler: AMP loss scaler; called with the loss, optimizer,
            clip threshold and parameters, and returns the gradient norm
            (presumably performing backward, unscale, clip and step —
            timm-style NativeScaler; TODO confirm against its definition).
        max_norm: gradient clipping threshold forwarded to the scaler.
        log_writer: optional tensorboard-style writer.
        lr_scheduler: optional scheduler advanced per step via
            ``step_update``.
        start_steps: global iteration count at the start of this epoch;
            offsets the per-step schedule lookups.
        lr_schedule_values / wd_schedule_values: precomputed per-iteration
            learning-rate / weight-decay values, indexed by the global
            iteration number.

    Returns:
        Dict mapping each tracked metric name to its global average.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for step, (batch, _) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        # assign learning rate & weight decay for each step
        it = start_steps + step  # global training iteration
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    # per-group lr_scale allows layer-wise lr decay
                    param_group["lr"] = lr_schedule_values[it] * param_group[
                        "lr_scale"]
                if wd_schedule_values is not None and param_group[
                        "weight_decay"] > 0:
                    # groups with weight_decay == 0 (e.g. biases/norms)
                    # are deliberately left untouched
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples, images, bool_masked_pos = batch
        images = images.to(device, non_blocking=True)
        samples = samples.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True)

        # Tokenize the clean images with the frozen d_vae and keep only the
        # tokens at masked positions as the prediction targets.
        with torch.no_grad():
            input_ids = d_vae.get_codebook_indices(images).flatten(1)
            bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
            labels = input_ids[bool_masked_pos]

        # Mixed-precision forward: the model returns logits only for the
        # masked positions (return_all_tokens=False).
        with torch.cuda.amp.autocast():
            outputs = model(samples,
                            bool_masked_pos=bool_masked_pos,
                            return_all_tokens=False)
            loss = nn.CrossEntropyLoss()(input=outputs, target=labels)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss,
                                optimizer,
                                clip_grad=max_norm,
                                parameters=model.parameters(),
                                create_graph=is_second_order)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        # fraction of masked positions whose token was predicted correctly
        mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item()

        metric_logger.update(mlm_acc=mlm_acc)
        if log_writer is not None:
            log_writer.update(mlm_acc=mlm_acc, head="loss")

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_scale=loss_scale_value)
        # track the lr spread across param groups for logging
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        # weight decay of the last group with wd > 0 (all decayed groups
        # share the same scheduled value above)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()

        if lr_scheduler is not None:
            lr_scheduler.step_update(start_steps + step)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def run_epoch(experiment: comet_ml.Experiment,
              network: ProbabilisticExternalAgentCurvePredictor,
              optimizer: torch.optim.Optimizer,
              dataloader: data_utils.DataLoader,
              debug: bool = False,
              use_tqdm=False):
    """Train the probabilistic curve-prediction network for one epoch.

    The network predicts Bezier control points (``means``) plus per-curve
    covariance factors; the loss is a point-estimate MSE on sampled curve
    points plus a small negative-log-likelihood term under a multivariate
    normal built from the predicted covariances.

    Args:
        experiment: comet.ml experiment used for metric logging.
        network: model mapping past motion to Bezier control points and
            covariance factors; its ``input_dim`` must be 4
            (positions + velocities) or 8 (plus quaternions).
        optimizer: optimizer updating the network parameters.
        dataloader: yields dicts with past/future positions, velocities,
            quaternions, a validity mask, and future timestamps.
        debug: accepted for interface compatibility; currently unused.
        use_tqdm: wrap the loop in a tqdm progress bar with a live postfix.
    """
    if use_tqdm:
        t = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        t = enumerate(dataloader)
    network.train()  # This is important to call before training!

    # we are only doing single-device training for now, so this works fine.
    dev = next(network.parameters()).device
    dtype = next(network.parameters()).dtype
    d = network.output_dim
    for (i, datadict) in t:
        valid_mask = datadict["valid_mask"]
        past_positions = datadict["past_positions"]
        past_velocities = datadict["past_velocities"]
        past_quaternions = datadict["past_quaternions"]
        future_positions = datadict["future_positions"]
        tfuture = datadict["tfuture"]

        # Keep only valid samples; the [0, 2] index selects two spatial
        # axes of the position/velocity vectors (presumably the ground
        # plane — TODO confirm against the dataset convention).
        valid_past_positions = (
            past_positions[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_past_velocities = (
            past_velocities[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_past_quaternions = past_quaternions[valid_mask].type(dtype).to(
            dev)
        valid_future_positions = (
            future_positions[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_tfuture = tfuture[valid_mask].type(dtype).to(dev)
        if network.input_dim == 4:
            networkinput = torch.cat(
                [valid_past_positions, valid_past_velocities], dim=2)
        elif network.input_dim == 8:
            networkinput = torch.cat([
                valid_past_positions, valid_past_velocities,
                valid_past_quaternions
            ],
                                     dim=2)
        else:
            raise ValueError(
                "Currently, only input dimensions of 4 and 8 are supported")
        batch_size = networkinput.shape[0]
        means, varfactors, covarfactors = network(networkinput)
        # Prepend the (known) first future position as the initial Bezier
        # control point.
        meancurves = torch.cat(
            [valid_future_positions[:, 0].unsqueeze(1), means], dim=1)

        # Normalize future timestamps to [0, 1] for Bezier evaluation.
        dt = valid_tfuture[:, -1] - valid_tfuture[:, 0]
        s_torch_cur = (valid_tfuture - valid_tfuture[:, 0, None]) / dt[:, None]
        Mpos = mu.bezierM(s_torch_cur, network.bezier_order)
        Msquare = torch.square(Mpos)
        pred_points = torch.matmul(Mpos, meancurves)
        deltas = pred_points - valid_future_positions
        squared_norms = torch.sum(torch.square(deltas), dim=2)
        point_estimate_loss = torch.mean(squared_norms)

        # Lower-triangular scale factor -> full covariance per control
        # point, then propagated to per-sample-point covariance via the
        # squared Bezier basis.
        scale_trils = torch.diag_embed(varfactors) + torch.diag_embed(
            covarfactors, offset=-1)
        covars = torch.matmul(scale_trils, scale_trils.transpose(2, 3))
        covars_expand = covars.unsqueeze(1).expand(batch_size,
                                                   Msquare.shape[1],
                                                   Msquare.shape[2], d, d)
        poscovar = torch.sum(Msquare[:, :, :, None, None] * covars_expand,
                             dim=2)
        distpos = torch.distributions.MultivariateNormal(
            pred_points, covariance_matrix=poscovar, validate_args=False)
        log_probs = distpos.log_prob(valid_future_positions)
        NLL = 0.0001 * torch.mean(-log_probs)

        loss = point_estimate_loss + NLL
        # Skip the update entirely when the loss is NaN (previously the
        # obscure `not (loss == loss)` idiom).
        if torch.isnan(loss):
            continue

        optimizer.zero_grad()
        loss.backward()
        # Weight and bias updates.
        optimizer.step()
        # logging information
        if ((i % 15) == 0):
            experiment.log_metric("point_estimate_loss",
                                  point_estimate_loss.item())
            experiment.log_metric("NLL", NLL.item())
        if use_tqdm:
            t.set_postfix({
                "point_estimate_loss": point_estimate_loss.item(),
                "NLL": NLL.item()
            })
Beispiel #15
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    args,
                    postprocessors=None):
    """Run one training epoch.

    Computes the weighted sum of the criterion's individual losses for each
    batch, backpropagates it (with optional gradient clipping via
    ``args.clip_max_norm``), and logs GPU-reduced loss values.

    Returns:
        A tuple of (dict of globally-averaged metrics, the last batch's
        raw loss dict from the criterion).
    """
    model.train()
    criterion.train()

    metric_logger = utils.MetricLogger(delimiter='  ')
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    # Stage 3 training does not track a classification error meter.
    if args.stage != 3:
        metric_logger.add_meter(
            'class_error', utils.SmoothedValue(window_size=1,
                                               fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    max_norm = args.clip_max_norm

    for (vid_name_list, locations, inputs, annotations, num_frames, base,
         s_e_scores) in metric_logger.log_every(data_loader, print_freq,
                                                header):

        inputs = inputs.to(device)
        s_e_scores = s_e_scores.to(device)
        annotations = [{key: val.to(device)
                        for key, val in ann.items()} for ann in annotations]

        predictions = model(locations, inputs, s_e_scores)
        raw_losses = criterion(predictions, annotations)
        loss_weights = criterion.weight_dict
        weighted_total = sum(raw_losses[name] * loss_weights[name]
                             for name in raw_losses if name in loss_weights)

        # Average the individual loss values across all GPUs so the logged
        # numbers are comparable in distributed runs.
        synced = utils.reduce_dict(raw_losses)
        synced_unscaled = {
            f'{name}_unscaled': val
            for name, val in synced.items()
        }
        synced_scaled = {
            name: val * loss_weights[name]
            for name, val in synced.items() if name in loss_weights
        }
        total_loss_scalar = sum(synced_scaled.values()).item()

        # Bail out rather than continue training on a diverged loss.
        if not math.isfinite(total_loss_scalar):
            print('Loss is {}, stopping training'.format(total_loss_scalar))
            print(synced)
            sys.exit(1)

        optimizer.zero_grad()
        weighted_total.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=total_loss_scalar,
                             **synced_scaled,
                             **synced_unscaled)
        if args.stage != 3:
            metric_logger.update(class_error=synced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]['lr'])

    metric_logger.synchronize_between_processes()
    return {name: meter.global_avg
            for name, meter in metric_logger.meters.items()}, raw_losses
Beispiel #16
0
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    amp: bool = True, teacher_model: torch.nn.Module = None,
                    teach_loss: torch.nn.Module = None, distill_token: bool=False, choices=None, mode='super', retrain_config=None):
    """One training epoch for the supernet / retrain setup.

    In 'super' mode a random sub-network configuration is sampled for every
    batch; in 'retrain' mode a fixed configuration (``retrain_config``) is
    applied once up front and re-applied per batch.  When a
    ``teacher_model`` is supplied, the hard-label distillation loss is mixed
    50/50 with the task criterion.

    Returns:
        Dict mapping each tracked metric name to its global average.
    """
    model.train()
    criterion.train()

    # Per-epoch seed so the sampled sub-network sequence is reproducible.
    random.seed(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if mode == 'retrain':
        config = retrain_config
        net = unwrap_model(model)
        print(config)
        net.set_sample_config(config=config)
        print(net.get_sampled_params_numel(config))

    for inputs, labels in metric_logger.log_every(data_loader, print_freq, header):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Select the sub-network to train on this batch.
        if mode == 'super':
            config = sample_configs(choices=choices)
            net = unwrap_model(model)
            net.set_sample_config(config=config)
        elif mode == 'retrain':
            config = retrain_config
            net = unwrap_model(model)
            net.set_sample_config(config=config)

        if mixup_fn is not None:
            inputs, labels = mixup_fn(inputs, labels)

        if amp:
            with torch.cuda.amp.autocast():
                if teacher_model:
                    with torch.no_grad():
                        teacher_logits = teacher_model(inputs)
                    _, teacher_targets = teacher_logits.topk(1, 1, True, True)
                    if distill_token:
                        # separate classification / distillation heads
                        cls_out, dis_out = model(inputs)
                        loss = 0.5 * criterion(cls_out, labels) + 0.5 * teach_loss(dis_out, teacher_targets.squeeze())
                    else:
                        outputs = model(inputs)
                        loss = 0.5 * criterion(outputs, labels) + 0.5 * teach_loss(outputs, teacher_targets.squeeze())
                else:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
        else:
            outputs = model(inputs)
            if teacher_model:
                with torch.no_grad():
                    teacher_logits = teacher_model(inputs)
                _, teacher_targets = teacher_logits.topk(1, 1, True, True)
                loss = 0.5 * criterion(outputs, labels) + 0.5 * teach_loss(outputs, teacher_targets.squeeze())
            else:
                loss = criterion(outputs, labels)

        batch_loss = loss.item()

        if not math.isfinite(batch_loss):
            print("Loss is {}, stopping training".format(batch_loss))
            sys.exit(1)

        optimizer.zero_grad()

        if amp:
            # timm's adahessian flags itself via `is_second_order` and needs
            # create_graph=True on backward.
            second_order = getattr(optimizer, 'is_second_order', False)
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward()
            optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=batch_loss)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
def train(encoder: EncoderBILSTM,
          decoder: DecoderLSTM,
          epoch_count: int,
          train_loader: DataLoader,
          criterion,
          optimizer_enc: torch.optim.Optimizer,
          optimizer_dec: torch.optim.Optimizer,
          is_cuda: bool,
          teacher_forcing: bool = False,
          debug: bool = False,
          lr_schedule=False,
          start_epoch_at: int = 0):
    """
    Train a seq2seq question generator (BiLSTM encoder + LSTM decoder).

    Each batch encodes the answer sequence, then decodes the question either
    with teacher forcing (the whole truncated target fed at once) or
    free-running (the decoder's argmax prediction at step j is fed back as
    the input of step j+1). Gradients of both networks are clipped to norm
    5. Checkpoints are written to ``model_weights/`` whenever the summed
    epoch loss improves, plus a final checkpoint after the last epoch.

    Args:
        encoder: BiLSTM encoder run over the answers.
        decoder: Attention LSTM decoder producing question-token scores.
        epoch_count: Number of epochs to train.
        train_loader: Yields (questions, question_lens, answers,
            answer_lens, pID) batches.
        criterion: Loss comparing decoder scores with target token ids.
        optimizer_enc: Optimizer for the encoder parameters.
        optimizer_dec: Optimizer for the decoder parameters.
        is_cuda: If True, move batch tensors and decoder inputs to GPU.
        teacher_forcing: Use teacher forcing instead of free-running decoding.
        debug: Print every batch loss.
        lr_schedule: If truthy, apply exp_lr_scheduler decay (8-epoch cycle).
        start_epoch_at: First epoch index (for resuming / checkpoint names).

    Returns:
        List of the total training loss of each epoch.
    """
    losses = []
    best_loss = 1000000  # sentinel: any real epoch loss is assumed smaller
    for epoch in range(start_epoch_at, start_epoch_at + epoch_count):
        total_batch_loss = 0
        for ind, batch in enumerate(train_loader):
            loss = 0
            questions, questions_org_len, answers, answers_org_len, pID = batch
            # NOTE(review): `break` abandons the REST of the epoch when one
            # overlong batch appears; `continue` may have been intended —
            # confirm.
            if questions.shape[1] > 1000:
                break
            if is_cuda:
                questions = questions.cuda()
                answers = answers.cuda()

            # Answers are encoded; questions are the decoding target.
            encoder_input, encoder_len = answers, answers_org_len
            decoder_input, decoder_len = questions, questions_org_len

            # Both branches run the same logic; only the device of the
            # length tensor and of the initial decoder input (token id 1,
            # presumably <SOS> — verify against the vocabulary) differs.
            if is_cuda:
                encoder_out, encoder_hidden = encoder(
                    encoder_input,
                    torch.LongTensor(encoder_len).cuda(), False)
                encoder_len = torch.FloatTensor(encoder_len)
                if not teacher_forcing:
                    decoder_inp = torch.ones((len(questions), 1),
                                             dtype=torch.long).cuda()
            else:
                encoder_out, encoder_hidden = encoder(
                    encoder_input, torch.LongTensor(encoder_len), False)
                encoder_len = torch.FloatTensor(encoder_len)
                if not teacher_forcing:
                    decoder_inp = torch.ones((len(questions), 1),
                                             dtype=torch.long)

            if teacher_forcing:
                # Single decoder pass over the whole truncated target.
                decoder_out, decoder_hidden, attn_scores = decoder(
                    decoder_input[:, :-1], encoder_hidden, encoder_out,
                    encoder_len, False)
                # Rearrange decoder scores so the class dimension is second,
                # as sequence cross-entropy style criteria expect.
                decoder_out = decoder_out.transpose(0, 1).contiguous()
                decoder_out = decoder_out.transpose(1, 2).contiguous()
                # NOTE(review): target is questions[:, :-1], i.e. NOT
                # shifted relative to the decoder input — confirm the
                # decoder's outputs align with these positions.
                loss = criterion(decoder_out, questions[:, :-1])
            else:
                # Free-running decoding: start from the encoder state and
                # feed the model's own predictions back token by token.
                decoder_hidden = (encoder_hidden[0].clone(),
                                  encoder_hidden[1].clone())
                eval_mode = False
                for j in range(questions.shape[1]):
                    decoder_out, decoder_hidden, attn_scores = decoder(
                        decoder_inp,
                        decoder_hidden,
                        encoder_out,
                        encoder_len,
                        eval_mode=eval_mode)
                    # obtaining log_softmax scores we need to minimize log softmax over a span.
                    decoder_out = decoder_out.squeeze(0)
                    prediction = torch.argmax(decoder_out, 1).unsqueeze(1)
                    loss_val = criterion(decoder_out, questions[:, j])
                    # Average the per-step loss over the sequence length.
                    loss += loss_val / questions.shape[1]
                    # detach: don't backprop through the fed-back prediction.
                    decoder_inp = prediction.clone().detach()
                    eval_mode = True

            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

            loss.backward()
            # Clip both networks' gradients to stabilise LSTM training.
            clip_grad_norm_(encoder.parameters(), 5)
            clip_grad_norm_(decoder.parameters(), 5)
            optimizer_enc.step()
            optimizer_dec.step()

            if lr_schedule:
                optimizer_enc = exp_lr_scheduler(optimizer_enc,
                                                 epoch,
                                                 lr_decay_epoch=8)
                optimizer_dec = exp_lr_scheduler(optimizer_dec,
                                                 epoch,
                                                 lr_decay_epoch=8)

            total_batch_loss += loss.item()
            if debug: print("Batch Loss: %f" % loss.item())
            if ind % 1000 == 0:
                print("Batch %d Loss: %f" % (ind, loss.item()))
        losses.append(total_batch_loss)

        print("Epoch[%d] Loss: %f" % (epoch, total_batch_loss))
        # Checkpoint whenever the summed epoch loss improves on the best.
        if total_batch_loss < best_loss:
            torch.save(encoder.state_dict(),
                       "model_weights/%d-encoder-SGD-small.pth" % epoch)
            torch.save(decoder.state_dict(),
                       "model_weights/%d-decoder-SGD-small.pth" % epoch)
            best_loss = total_batch_loss
    torch.save(encoder.state_dict(),
               "model_weights/final-encoder-SGD-small.pth")
    torch.save(decoder.state_dict(),
               "model_weights/final-decoder-SGD-small.pth")
    return losses
Beispiel #18
0
def _torch_step(optimizer: torch.optim.Optimizer, scaler: Optional[torch.cuda.amp.GradScaler] = None) -> None:
    if scaler is None:
        optimizer.step()
    else:
        scaler.step(optimizer)
    optimizer.zero_grad()
Beispiel #19
0
def train_person_segmentor(
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        valid_loader: torch.utils.data.DataLoader,
        criterion: callable,
        optimiser: torch.optim.Optimizer,
        *,
        save_model_path: Path,
        learning_rate: Number = 6e-2,
        scheduler: torch.optim.lr_scheduler = None,
        n_epochs: int = 100,
        writer: ImageWriterMixin = MockWriter(),
):
    """Train a person-segmentation model, checkpointing on best validation loss.

    Per epoch: one optimisation pass over ``train_loader`` followed by a
    ``no_grad`` validation pass over ``valid_loader``. The state dict is
    saved to ``save_model_path`` whenever the mean validation loss improves.
    Batch losses are accumulated weighted by batch size and normalised by
    the dataset length, so partial final batches are averaged correctly.

    :param model: network whose forward returns the prediction as the first
        element of a tuple (``output, *_ = model(...)``)
    :param train_loader: yields ``(data, target)`` training batches
    :param valid_loader: yields ``(data, target)`` validation batches
    :param criterion: loss callable ``criterion(output, target.float())``
    :param optimiser: optimiser updating ``model``'s parameters
    :param save_model_path: destination of the best-model checkpoint
    :param learning_rate: base LR passed to ``reschedule_learning_rate``
    :param scheduler: optional LR scheduler, stepped once per epoch
    :param n_epochs: number of epochs to run (must be > 0)
    :param writer: metric/image sink; NOTE(review): the ``MockWriter()``
        default is a shared mutable default — benign only if it is stateless
    :return: the trained model (last-epoch weights, not necessarily best)
    """
    # Fix: ``numpy.Inf`` was removed in NumPy 2.0; ``numpy.inf`` is the
    # canonical spelling and equivalent on older versions.
    valid_loss_min = numpy.inf  # track change in validation loss
    assert n_epochs > 0, n_epochs
    E = tqdm(range(1, n_epochs + 1))
    for epoch_i in E:
        train_loss = 0.0
        valid_loss = 0.0

        with TorchTrainSession(model):
            for data, target in tqdm(train_loader):
                output, *_ = model(data.to(global_torch_device()))
                loss = criterion(output,
                                 target.to(global_torch_device()).float())

                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

                # weight by batch size so the dataset-level mean is exact
                train_loss += loss.cpu().item() * data.size(0)

        with TorchEvalSession(model):
            with torch.no_grad():
                for data, target in tqdm(valid_loader):
                    (
                        output,
                        *_,
                    ) = model(  # forward pass: compute predicted outputs by passing inputs to the model
                        data.to(global_torch_device()))
                    validation_loss = criterion(  # calculate the batch loss
                        output,
                        target.to(global_torch_device()).float())
                    writer.scalar(
                        "dice_validation",
                        dice_loss(output,
                                  target.to(global_torch_device()).float()),
                    )

                    valid_loss += validation_loss.detach().cpu().item(
                    ) * data.size(0)  # update average validation loss
                writer.image("prediction", torch.sigmoid(output),
                             epoch_i)  # write the last batch

        # calculate average losses
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}).  Saving model ..."
            )
            torch.save(model.state_dict(), save_model_path)
            valid_loss_min = valid_loss

        if scheduler:
            scheduler.step()
            optimiser, scheduler = reschedule_learning_rate(
                model,
                optimiser,
                epoch_i,
                scheduler,
                starting_learning_rate=learning_rate,
            )

        # print training/validation statistics
        current_lr = next(iter(optimiser.param_groups))["lr"]
        E.set_description(f"Epoch: {epoch_i} "
                          f"Training Loss: {train_loss:.6f} "
                          f"Validation Loss: {valid_loss:.6f} "
                          f"Learning rate: {current_lr:.6f}")
        writer.scalar("training_loss", train_loss)
        writer.scalar("validation_loss", valid_loss)
        writer.scalar("learning_rate", current_lr)

    return model
def train_one_epoch(model: torch.nn.Module,
                    criterion: DistillationLoss,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None,
                    mixup_fn: Optional[Mixup] = None,
                    teacher=None,
                    set_training_mode=True):
    """Train one epoch with AMP, optional mixup, two_mix and auxiliary losses.

    The objective is a soft-target cross entropy on the two_mix-augmented
    batch plus the model's auxiliary ``r_loss``/``s_loss`` terms; backward
    and optimizer step go through ``loss_scaler`` (timm-style, AMP-aware,
    handles gradient clipping).

    Args:
        model: Network returning ``(outputs, r_loss, s_loss, proj)`` when
            given samples and aux targets.
        criterion: Unused in this body (kept for interface compatibility);
            the loss is computed inline.
        data_loader: Iterable of (samples, targets) batches.
        optimizer: Optimizer stepped through ``loss_scaler``.
        device: Device the batches are moved to.
        epoch: Epoch number (log header only).
        loss_scaler: AMP loss scaler callable.
        max_norm: Gradient clipping norm forwarded to the scaler (0 = off).
        model_ema: Optional EMA copy of the model, updated after each step.
        mixup_fn: Optional mixup applied before two_mix.
        teacher: Unused (kept for interface compatibility).
        set_training_mode: Currently ignored (see TODO); train mode forced.

    Returns:
        Dict of global-average metrics from the metric logger.
    """
    # TODO fix this for finetuning
    # model.train(set_training_mode)
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        samples, targets, mix_rate, aux_targets = two_mix(
            samples, targets, num_patch=samples.shape[-1] // 16)

        with torch.cuda.amp.autocast():
            outputs, r_loss, s_loss, proj = model(samples, aux_targets)
            # Soft-target cross entropy (targets may be mixed/smoothed);
            # 1e-8 guards the log against exact zeros.
            loss = torch.sum(-targets * (1e-8 + outputs.softmax(dim=-1)).log(),
                             dim=-1).mean()

            # Logged value: classification term only, before regularisers.
            loss_value = loss.item()
            loss += 1. * (r_loss + 1. * s_loss)

        # Bug fix: check AND report the total optimised loss; previously the
        # stale classification-only value was printed while the total was
        # checked, hiding which term actually diverged.
        total_loss_value = loss.item()
        if not math.isfinite(total_loss_value):
            print("Loss is {}, stopping training".format(total_loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss,
                    optimizer,
                    clip_grad=max_norm,
                    parameters=model.parameters(),
                    create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['r'].update(r_loss.item(), n=targets.shape[0])
        metric_logger.meters['s'].update(s_loss.item(), n=targets.shape[0])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    """Train a domain classifier with early stopping on validation accuracy.

    The model is trained to predict the DOMAIN id of each example (the 4th
    batch element), with optional gradient accumulation. After every epoch
    the validation accuracy is computed; best-scoring weights are written to
    ``{model_dir}/{wandb-run}/model_domainclassifier_{domain_name}.pth`` and
    training stops after ``patience`` epochs without improvement.

    Args:
        model: Classifier returning ``(loss, logits)`` when given labels.
        train_dl: Loader of (input_ids, masks, labels, domains) batches.
        optimizer: Optimizer stepped every ``gradient_accumulation`` batches.
        scheduler: Optional per-batch LR scheduler (may be None).
        validation_evaluator: Returns ``((loss, acc, P, R, F1), _)``.
        n_epochs: Maximum number of epochs.
        device: Device identifier the batch tensors are moved to.
        log_interval: Unused here (kept for interface compatibility).
        patience: Epochs without accuracy improvement before stopping.
        model_dir: Root directory for checkpoints.
        gradient_accumulation: Batches to accumulate before each step.
        domain_name: Suffix for the checkpoint file name.
    """
    #best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    # Bug fix: the original summed over the undefined name ``train_dls``
    # (NameError at entry); this function receives a single dataloader.
    total = len(train_dl)

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]
            # Testing with random domains to see if any effect
            #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
            domains = batch[3]

            loss, logits = model(input_ids,
                                 attention_mask=masks,
                                 labels=domains)
            # Scale so the accumulated gradient matches a large-batch step.
            loss = loss / gradient_accumulation

            # Bug fix: accumulate gradients on EVERY batch; previously
            # ``loss.backward()`` was itself gated on the accumulation
            # boundary, silently discarding most batches whenever
            # gradient_accumulation > 1.
            loss.backward()
            if (i + 1) % gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
Beispiel #22
0
def train_one_epoch(args,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    dataloader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    """
    Train a few-shot detection model for one epoch.

    Each iteration samples a subset of support categories, filters the
    ground truth down to those categories, and runs the model conditioned on
    the sampled support images/targets. The optimised objective is the
    weighted sum of the criterion's loss dict; a dist-reduced copy of the
    losses is used only for logging and the divergence check.

    Args:
        args: Experiment configuration forwarded to
            ``sample_support_categories``.
        model: Detection model taking query samples plus support inputs.
        criterion: Criterion producing a loss dict and a ``weight_dict``.
        dataloader: Yields (samples, targets, support_images,
            support_class_ids, support_targets) tuples.
        optimizer: Optimizer for the model parameters.
        device: Device batches are moved to.
        epoch: Epoch number (used in the log header).
        max_norm: If > 0, clip gradients to this norm.

    Returns:
        Dict of global-average metrics from the metric logger.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets, support_images, support_class_ids, support_targets in metric_logger.log_every(dataloader, print_freq, header):

        # * Sample Support Categories;
        # * Filters Targets (only keep GTs within support categories);
        # * Samples Support Images and Targets
        targets, support_images, support_class_ids, support_targets = \
            sample_support_categories(args, targets, support_images, support_class_ids, support_targets)

        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        support_images = support_images.to(device)
        support_class_ids = support_class_ids.to(device)
        support_targets = [{k: v.to(device) for k, v in t.items()} for t in support_targets]

        outputs = model(samples, targets=targets, supp_samples=support_images, supp_class_ids=support_class_ids, supp_targets=support_targets)
        loss_dict = criterion(outputs)
        weight_dict = criterion.weight_dict
        # Total loss: weighted sum of the criterion terms that have weights.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        # Abort on divergence; the reduced (cross-process) value is checked,
        # so every process exits consistently.
        if not math.isfinite(loss_value):
            print("Loss is NaN - {}. \nTraining terminated unexpectedly.\n".format(loss_value))
            print("loss dict:")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            # No clipping: only measure the gradient norm for logging.
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # Drop references to the last batch's (potentially large, GPU-resident)
    # tensors before returning so their memory can be freed promptly.
    del support_images
    del support_class_ids
    del support_targets
    del samples
    del targets
    del outputs
    del weight_dict
    del grad_total_norm
    del loss_value
    del losses
    del loss_dict
    del loss_dict_reduced
    del loss_dict_reduced_scaled
    del loss_dict_reduced_unscaled

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #23
0
def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:
    """
    Train a ProtoTree for one epoch.

    Prototypes and the backbone network are optimised by gradient descent on
    the NLL of the tree's predictions. Unless disabled, the leaf class
    distributions are instead updated after every batch with a
    derivative-free rule: each batch removes a 1/nr_batches share of the
    epoch-start snapshot and adds the current batch's posterior-weighted
    target mass, run under ``no_grad``.

    Args:
        tree: ProtoTree to train (moved to ``device`` here).
        train_loader: Loader of (inputs, labels) batches.
        optimizer: Optimizer over prototypes/network (and leaves when
            ``disable_derivative_free_leaf_optim`` is True).
        epoch: Current epoch index (progress display and logging).
        disable_derivative_free_leaf_optim: Train leaves by gradient descent
            instead of the closed-form update.
        device: Target device.
        log: Optional Log receiving per-batch loss/accuracy values.
        log_prefix: Prefix of the per-epoch loss log name.
        progress_prefix: Prefix of the tqdm progress-bar description.

    Returns:
        Dict with mean 'loss' and 'train_accuracy' over the epoch.

    NOTE(review): if ``train_loader`` is empty, ``i`` is never bound and the
    final averaging raises NameError — confirm loaders are non-empty.
    """
    tree = tree.to(device)
    # Make sure the model is in eval mode
    tree.eval()
    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.
    # Create a log if required
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))
    with torch.no_grad():
        # Snapshot the leaf distributions at epoch start; the per-batch
        # update below subtracts a 1/nr_batches share of this snapshot.
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in leafs
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                    total=len(train_loader),
                    desc=progress_prefix+' %s'%epoch,
                    ncols=0)
    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent. 
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss; in the log-space variant the predictions are
        # already log-probabilities, otherwise take the log here.
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()
        
        if not disable_derivative_free_leaf_optim:
            #Update leaves with derivate-free algorithm
            #Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys] #shape (batchsize, num_classes) 
                for leaf in tree.leaves:  
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index] + leaf.distribution() + torch.log(target) - ys_pred, dim=0))
                    else:
                        update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target)/ys_pred, dim=0)  
                    leaf._dist_params -= (_old_dist_params[leaf]/nr_batches)
                    F.relu_(leaf._dist_params) #dist_params values can get slightly negative because of floating point issues. therefore, set to zero.
                    leaf._dist_params += update

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)
        
        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss+=loss.item()
        total_acc+=acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    # Average over the number of processed batches (i is the last index).
    train_info['loss'] = total_loss/float(i+1)
    train_info['train_accuracy'] = total_acc/float(i+1)
    return train_info 
Beispiel #24
0
def resume_training(args: argparse.Namespace, hp: HParams, tier: int, model: Tier,
                    optimizer: torch.optim.Optimizer, logger: logging.Logger) \
        -> Tuple[Tier, torch.optim.Optimizer]:
    """
    Restore model and optimizer state from ``args.checkpoint_path``.

    The checkpoint's stored hyperparameters are validated against the
    current ones: audio/dataset/training mismatches only produce warnings
    (the new values win), while a differing network structure or tier
    number aborts with an exception.

    Args:
        args (argparse.Namespace): must provide at least ``path_config``,
            ``tier`` and ``checkpoint_path``.
        hp (HParams): current hyperparameters to compare with the checkpoint.
        tier (int): number of the tier about to be trained.
        model (Tier): model whose weights will be loaded in place.
        optimizer (torch.optim.Optimizer): optimizer whose state is loaded.
        logger (logging.Logger): receives progress and mismatch reports.

    Returns:
        The ``(model, optimizer)`` pair with the restored state.
    """
    if not Path(args.checkpoint_path).exists():
        missing = f"Path for resuming training {args.checkpoint_path} does not exist."
        logger.error(missing)
        raise Exception(missing)

    logger.info(f"Resuming training with weights from: {args.checkpoint_path}")
    checkpoint = torch.load(args.checkpoint_path)
    hp_chkpt = checkpoint["hp"]

    # Soft check: warn and continue with the new audio params.
    if hp_chkpt.audio != hp.audio:
        logger.warning("New params for audio are different from checkpoint. "
                       "It will use new params.")

    # Hard check: the architecture must match the stored weights.
    if hp_chkpt.network != hp.network:
        net_mismatch = "New params for network structure are different from checkpoint."
        logger.error(net_mismatch)
        raise Exception(net_mismatch)

    # Hard check: a checkpoint only makes sense for the same tier.
    if checkpoint["tier_idx"] != tier:
        tier_mismatch = f"New tier to train ({tier}) is different from checkpoint ({checkpoint['tier']})."
        logger.error(tier_mismatch)
        raise Exception(tier_mismatch)

    # Soft checks: warn and continue with the new dataset/training params.
    if hp_chkpt.data != hp.data:
        logger.warning("New params for dataset are different from checkpoint. "
                       "It will use new params.")

    if hp_chkpt.training != hp.training:
        logger.warning("New params for training are different from checkpoint. "
                       "It will use new params.")

    model.load_state_dict(checkpoint["tier"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer
def run_epoch(
    data_iterator: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer = None,
    is_test: bool = False,
    is_metadata: bool = False
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """
    Run one full pass over the data, training when the model is in train mode.

    Parameters
    ----------
    data_iterator:
        DataLoader yielding (image, labels[, metadata]) batches
    model:
        Pytorch model; ``model.training`` decides whether gradients flow and
        whether an optimizer step is taken
    optimizer:
        Pytorch optimizer (required only when the model is in training mode)
    is_test:
        When True, also accumulate and return a 3x3 confusion matrix
    is_metadata:
        When True, the batch carries extra metadata passed to the model

    Returns
    -------
    Mean loss and accuracy of the epoch, and optionally a confusion matrix

    """
    batch_losses = []
    batch_accuracies = []
    confusion = torch.zeros((3, 3))
    with tqdm(total=len(data_iterator)) as progress:
        for batch_idx, batch in enumerate(data_iterator):
            progress.update(1)
            image = batch[0].to(device)
            labels = batch[1].to(device)

            # Build the forward arguments once; wrap in no_grad outside of
            # training mode so evaluation never tracks gradients.
            inputs = (image, batch[2].to(device)) if is_metadata else (image,)
            if model.training:
                model_out = model.forward(*inputs)
            else:
                with torch.no_grad():
                    model_out = model.forward(*inputs)

            batch_loss = nn.functional.cross_entropy(
                model_out,
                labels,
                weight=torch.FloatTensor(weights_train).to(device))
            batch_losses.append(batch_loss.item())
            prediction = torch.argmax(model_out, dim=1)
            correct_mask = np.equal(prediction.cpu().numpy(),
                                    labels.cpu().numpy())
            batch_accuracies.append(np.mean(correct_mask))
            if is_test:
                # Rows are true labels, columns are predictions.
                for sample_idx, label in enumerate(labels):
                    confusion[label.item(), prediction[sample_idx].item()] += 1
            if model.training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
    if is_test:
        return np.mean(batch_losses), np.mean(batch_accuracies), confusion.numpy()
    return np.mean(batch_losses), np.mean(batch_accuracies), None
Beispiel #26
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0,
                    accumulate_batches=1):
    """
    Train the model for a single epoch with optional gradient accumulation.

    Args:
        model: The network being trained.
        criterion: DETR-style loss module exposing ``weight_dict``, a mapping
            from loss-component name to its scalar weight.
        data_loader: Iterable yielding ``(samples, targets)`` pairs, where
            ``targets`` is a list of per-image dicts of tensors.
        optimizer: Optimizer used to update the model parameters.
        device: Device to which each batch is moved.
        epoch: Current epoch number (a warmup LR schedule runs during epoch 0).
        max_norm: Max gradient norm for clipping; 0 disables clipping.
        accumulate_batches: Step the optimizer only every this many batches,
            accumulating gradients in between.

    Returns:
        Dict mapping each logged metric name to its global average.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100
    num_samples = len(data_loader)

    # Linear LR warmup over (at most) the first 1000 iterations of epoch 0.
    warmup_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters,
                                               warmup_factor)

    optimizer.zero_grad()
    # start=1 so `step % accumulate_batches == 0` fires after exactly
    # `accumulate_batches` batches, and `step == num_samples` flushes a
    # final, possibly short, accumulation group.
    for step, (samples, targets) in enumerate(metric_logger.log_every(
            data_loader, print_freq, header),
                                              start=1):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Total loss is the weighted sum of only those components the
        # criterion declares a weight for; others are logging-only.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        # Abort on NaN/inf loss rather than silently corrupting the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        losses.backward()
        # Step only at the end of each accumulation group (or at the end of
        # the epoch); gradients from the intervening batches accumulate.
        if ((step % accumulate_batches) == 0) or (step == num_samples):
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad()

        if warmup_scheduler is not None:
            warmup_scheduler.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Global scale", model.global_scale)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Beispiel #27
0
    def attach(self, optimizer: torch.optim.Optimizer):
        r"""
        Attaches the privacy engine to the optimizer.

        Attaches to the ``PrivacyEngine`` an optimizer object, and injects
        itself into the optimizer's step. To do that it,

        1. Validates that the model does not have unsupported layers.

        2. Adds a pointer to this object (the ``PrivacyEngine``) inside the optimizer.

        3. Moves optimizer's original ``step()`` function to ``original_step()``.

        4. Monkeypatches the optimizer's ``step()`` function to call ``step()`` on
        the query engine automatically whenever it would call ``step()`` for itself.

        Args:
            optimizer: The optimizer to which the privacy engine will attach
        """

        # Fail fast if the wrapped module contains layers the engine cannot
        # compute per-sample gradients for.
        self.validator.validate(self.module)
        # A scalar max_grad_norm clips the whole model's flattened gradient;
        # a list clips each layer's gradient against its own threshold.
        norm_clipper = (
            # pyre-fixme[6]: Expected `float` for 1st param but got
            #  `Union[List[float], float]`.
            clipping.ConstantFlatClipper(self.max_grad_norm)
            if not isinstance(self.max_grad_norm, list)
            # pyre-fixme[6]: Expected `List[float]` for 1st param but got
            #  `Union[List[float], float]`.
            else clipping.ConstantPerLayerClipper(self.max_grad_norm))

        # Experimental mode replaces the constant clipper with one whose
        # thresholds can adapt according to the configured method/ratio/
        # momentum settings.
        if self.misc_settings.get("experimental", False):
            norm_clipper = clipping._Dynamic_Clipper_(
                # pyre-fixme[6]: Expected `List[float]` for 1st param but got
                #  `List[Union[List[float], float]]`.
                [self.max_grad_norm],
                self.misc_settings.get("clip_per_layer", False),
                self.misc_settings.get("clipping_method",
                                       clipping.ClippingMethod.STATIC),
                self.misc_settings.get("clipping_ratio", 0.0),
                self.misc_settings.get("clipping_momentum", 0.0),
            )

        self.clipper = PerSampleGradientClipper(self.module, norm_clipper,
                                                self.batch_first)

        # Replacement step(): run the privacy engine's step (clipping + noise)
        # before delegating to the optimizer's original step. Note: `self`
        # here is the *optimizer*, not the engine.
        def dp_step(self, closure=None):
            self.privacy_engine.step()
            self.original_step(closure)

        # Pyre doesn't like monkeypatching. But we'll do it anyway :)
        optimizer.privacy_engine = self  # pyre-ignore
        # Save the original step *before* patching so dp_step can call it.
        optimizer.original_step = optimizer.step  # pyre-ignore
        optimizer.step = types.MethodType(dp_step, optimizer)  # pyre-ignore

        # Expose virtual_step() on the optimizer for virtual-batch training.
        def virtual_step(self):
            self.privacy_engine.virtual_step()

        # pyre-ignore
        optimizer.virtual_step = types.MethodType(virtual_step, optimizer)

        # create a cross reference for detaching
        self.optimizer = optimizer  # pyre-ignore
Beispiel #28
0
    def step_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
        clip_grads: Optional[Callable[[Iterator], None]] = None,
        auto_zero_grads: bool = True,
        scaler: Optional[Any] = None,
        # Should be torch.cuda.amp.GradScaler, but:
        #   * other implementations might be possible
        #   * requiring this type forces upgrades to PyTorch 1.6+
    ) -> None:
        """
        Perform a single optimization step.

        This function must be called once for each optimizer. However, the order of
        different optimizers' steps can be specified by calling this function in different
        orders. Also, gradient accumulation across iterations is performed by the Determined
        training loop by setting the experiment configuration field
        :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

        Here is a code example:

        .. code-block:: python

            def clip_grads(params):
                torch.nn.utils.clip_grad_norm_(params, 0.0001),

            self.context.step_optimizer(self.opt1, clip_grads)

        Arguments:
            optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
            clip_grads(a function, optional): This function should have one argument for
                parameters in order to clip the gradients.
            auto_zero_grads(bool, optional): Automatically zero out gradients automatically after
                stepping the optimizer. If false, you need to call ``optimizer.zero_grad()``
                manually. Note that if :ref:`optimizations.aggregation_frequency
                <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must be true.
            scaler(``torch.cuda.amp.GradScaler``, optional): The scaler to use for stepping the
                optimizer. This should be unset if not using AMP, and is necessary if
                ``wrap_scaler()`` was called directly.
        """

        # Precondition: auto_zero_grads may only be disabled when no gradient
        # aggregation is happening (aggregation_frequency == 1).
        check.true(
            auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
            "if optimizations.aggregation_frequency is larger than 1, "
            "you can only set auto_zero_grads to be true. ",
        )

        # Skip the whole step on iterations where gradients are only being
        # accumulated (no communication/update this iteration).
        if not self._should_communicate_and_update():
            return

        # Communication needs to be synchronized so that is completed
        # before we apply gradient clipping and `step()`. In the case of APEX
        # this is called in backward() instead, so that it's inside the context
        # manager and before unscaling.
        if self.hvd_config.use and not self._use_apex:
            optimizer.synchronize()  # type: ignore

        # Under APEX, clipping must operate on the master (FP32) copies of the
        # parameters; otherwise, collect them from the optimizer's groups.
        parameters = ([
            p for group in optimizer.param_groups
            for p in group.get("params", [])
        ] if not self._use_apex else apex.amp.master_params(optimizer))

        if self.hvd_config.average_aggregated_gradients:
            self._average_gradients(
                parameters=parameters,
                divisor=self.hvd_config.aggregation_frequency)

        if clip_grads is not None:
            # With auto AMP, gradients must be unscaled before clipping so the
            # clip threshold applies to true gradient magnitudes.
            if self._scaler and self.experimental._auto_amp:
                self._scaler.unscale_(optimizer)
            clip_grads(parameters)

        # For stepping the optimizer we will operate on the scaler passed
        # in, or fall back to the wrapped scaler (if any).
        if scaler is None and self.experimental._auto_amp:
            scaler = self._scaler
        if scaler:

            def step_fn() -> None:
                scaler.step(optimizer)  # type: ignore

        else:
            step_fn = optimizer.step  # type: ignore

        # Horovod: we already synchronized above, so suppress the implicit
        # synchronize inside step().
        if self.hvd_config.use:
            with optimizer.skip_synchronize():  # type: ignore
                step_fn()
        else:
            step_fn()

        if auto_zero_grads:
            optimizer.zero_grad()
Beispiel #29
0
    def step_optimizer(
        self,
        optimizer: torch.optim.Optimizer,  # type: ignore
        clip_grads: Optional[Callable[[Iterator], None]] = None,
        auto_zero_grads: bool = True,
    ) -> None:
        """
        Perform a single optimization step.

        This function must be called once for each optimizer. However, the order of
        different optimizers' steps can be specified by calling this function in different
        orders. Also, gradient accumulation across iterations is performed by the Determined
        training loop by setting the experiment configuration field
        :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

        Here is a code example:

        .. code-block:: python

            def clip_grads(params):
                torch.nn.utils.clip_grad_norm_(params, 0.0001),

            self.context.step_optimizer(self.opt1, clip_grads)

        Arguments:
            optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
            clip_grads(a function, optional): This function should have one argument for
                parameters in order to clip the gradients.
            auto_zero_grads(bool, optional): Automatically zero out gradients automatically after
                stepping the optimizer. If false, you need to call ``optimizer.zero_grad()``
                manually. Note that if :ref:`optimizations.aggregation_frequency
                <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must be true.
        """

        # BUG FIX: the check previously read `aggregation_frequency > 1`,
        # which *passed* in exactly the forbidden configuration
        # (auto_zero_grads=False with aggregation) and *failed* in the legal
        # one (auto_zero_grads=False, frequency == 1). The intended
        # precondition — matching the error message — is that
        # auto_zero_grads may only be False when aggregation_frequency == 1.
        check.true(
            auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
            "if optimizations.aggregation_frequency is larger than 1, "
            "you can only set auto_zero_grads to be true. ",
        )
        # On pure accumulation iterations there is no communication or
        # parameter update; skip everything (including zeroing grads).
        if self._should_communicate_and_update():
            # Communication needs to be synchronized so that is completed
            # before we apply gradient clipping and `step()`.
            if self.hvd_config.use and not self._use_amp:
                optimizer.synchronize()

            # Under APEX AMP, clip against the master (FP32) parameter copies;
            # otherwise gather parameters from the optimizer's groups.
            parameters = (
                [p for group in optimizer.param_groups for p in group.get("params", [])]
                if not self._use_amp
                else apex.amp.master_params(optimizer)
            )

            if self.hvd_config.average_aggregated_gradients:
                self._average_gradients(
                    parameters=parameters, divisor=self.hvd_config.aggregation_frequency
                )

            if clip_grads is not None:
                clip_grads(parameters)

            # Horovod: already synchronized above, so suppress the implicit
            # synchronize inside step().
            if self.hvd_config.use:
                with optimizer.skip_synchronize():
                    optimizer.step()
            else:
                optimizer.step()

            if auto_zero_grads:
                optimizer.zero_grad()
Beispiel #30
0
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    """
    Run one training epoch of ``optim_params.epoch_len`` batches sampled from
    the replay buffer, periodically syncing weights back to the data actors.

    Args:
        train_device: Device on which batches are trained.
        model: Scripted model providing ``loss(...)`` and ``state_dict()``.
        ddpmodel: Optional DDP wrapper; when given, the loss is computed
            through it instead of through ``model`` directly.
        model_path: Path associated with the model (unused here; kept for
            interface compatibility with the caller).
        optim: Optimizer updating the model parameters.
        assembler: Replay-buffer channel used for sampling batches and
            pushing updated weights to the actors.
        stat: Counter set accumulating "loss" and "grad_norm" metrics.
        epoch: Current epoch index (used for the sync-period schedule).
        optim_params: Holds ``epoch_len``, ``batchsize`` and ``grad_clip``.
        sync_period: Push weights to the actors every this many global steps.
    """
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
    # Throttle: sleep if previous epochs consumed the buffer much faster
    # than the actors could refill it (see the rate check below).
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss, pred_pi, pred_v = model.loss(lossmodel, batch["s"], batch["v"],
                                           batch["pi"], batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        # Every `sync_period` global steps, push the latest weights to the
        # actors so freshly generated data reflects the current model.
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    # If sampling outpaced data generation by more than 8x, grow the sleep
    # before the next epoch to let the actors catch up; otherwise training
    # and generation are balanced, so reset the throttle.
    if delta_sample > 8 * delta_add:
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    # Format-spec fix: was `{sync_s:2f}` (width 2, default 6 decimals);
    # `.2f` gives the intended two decimal places.
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    wandb.log({
        "epoch": epoch,
        "loss": stat["loss"].mean(),
        "grad_norm": stat["grad_norm"].mean()
    })
    stat.reset()